MATLAB GPU Cross-Correlation Code Fails with OOM on 8 GB Supercomputer

  • Context: MATLAB 
  • Thread starter: Ascendant0
SUMMARY

The forum discussion centers on the challenges of running MATLAB code for cross-correlation analysis of radio astronomy data on a supercomputer with 8 GB GPUs, which leads to out-of-memory (OOM) errors. The code processes large voltage blocks (VBLK) containing 2048 frequency channels and approximately 129,000 time samples, requiring memory-efficient restructuring to maintain performance. The users seek solutions to reduce peak memory usage while ensuring numerical accuracy and practical runtime, specifically through chunking and optimizing memory usage in MATLAB GPU implementations.

PREREQUISITES
  • Understanding of MATLAB GPU programming and memory management
  • Familiarity with cross-correlation techniques in signal processing
  • Knowledge of radio astronomy data structures and processing
  • Experience with high-performance computing (HPC) environments
NEXT STEPS
  • Research MATLAB GPU best practices for memory optimization
  • Explore techniques for chunking data in MATLAB to manage memory usage
  • Investigate alternatives to large replicated matrices in GPU computations
  • Learn about efficient data handling in high-performance computing systems
USEFUL FOR

This discussion is beneficial for MATLAB developers, astrophysicists working with large datasets, and researchers involved in high-performance computing and GPU programming for signal processing tasks.

Ascendant0
TL;DR: Need a way to reduce our MATLAB code from requiring a 16 GB GPU down to an 8 GB GPU (our college's supercomputer has 8 GB GPUs, and the job fails with out-of-memory errors)
We are trying to evaluate data from the Crab Nebula using MATLAB, but this computation cannot be run on a personal workstation in any reasonable time. It must be split into many parallel jobs and run on our college's supercomputer.
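Because each Usb/Lsb pair (jUsb, jUsb+949) only touches its own two channels, one way to spread the work over many jobs is to give each job a contiguous block of jUsb values. A minimal sketch, assuming a SLURM array job (the SLURM_ARRAY_TASK_ID environment variable) and a hypothetical nTasks = 8; other schedulers expose the task index under a different name:

```matlab
% Hypothetical split of the 949 jUsb values (151:1099) across nTasks array jobs.
% Assumes a SLURM array job; falls back to task 1 when run interactively.
nTasks = 8;                                         % assumed number of array tasks
taskId = str2double(getenv('SLURM_ARRAY_TASK_ID'));
if isnan(taskId)
    taskId = 1;                                     % not running under SLURM
end
allUsb = 151:151+948;                               % same range as the jUsb loop
edges  = round(linspace(0, numel(allUsb), nTasks+1));
myUsb  = allUsb(edges(taskId)+1 : edges(taskId+1)); % this task's contiguous block
fprintf('task %d: jUsb %d..%d (%d pairs)\n', taskId, myUsb(1), myUsb(end), numel(myUsb));
```

Each task would then loop over myUsb instead of 151:151+948, fill Fabs(:, myUsb-150), and include taskId in the output file name so the pieces can be combined afterwards. Note that this spreads the work but does not by itself lower the per-GPU peak, since every pair still builds the full 1600-row shift matrix.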

However, while our local workstation GPU has 16 GB of memory, each GPU on the supercomputer has only 8 GB. With the current implementation, our 16 GB GPU handles the job, but on the 8 GB supercomputer GPUs it repeatedly fails with out-of-memory errors. We need to keep essentially the same code structure to preserve performance.

We are therefore trying to restructure the code to reduce peak memory usage while keeping the computation GPU-accelerated and mathematically identical.

What the code is doing

The code processes large complex voltage blocks (VBLK) from a radio astronomy pipeline.
Each input file contains approximately:
  • 2048 frequency channels
  • ~129k time samples
  • 2 polarizations
For each file and polarization, the code:
  1. Normalizes each frequency channel.
  2. Constructs large time-shifted matrices to compute cross-correlations between 949 upper–lower subband pairs.
  3. Accumulates absolute cross-correlation amplitudes.
Full time and frequency resolution must be preserved (no averaging or decimation).
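Step 2 is easier to reason about once you see what the repmat/reshape trick produces. Working through the indexing, row k of FgcSM1 (before normalization) appears to be simply the window of one channel starting at sample Ng1+1-k, i.e. 1602-k for lags k-1 = 0..1599. A toy-size check of that claim, with hypothetical sizes N, S0 and L standing in for NDsamp, NDShf0 and Rsamp3:

```matlab
% Toy sizes (hypothetical): N ~ NDsamp, S0 ~ NDShf0, L ~ Rsamp3
N  = 12;  S0 = 4;  L = 6;
x  = 1:N;                                   % stand-in for one channel, cDM0(jUsb,:)
Ng1 = S0;
Ng2 = Ng1 + (N-1)*(S0-1) - 1;

% Same steps as the posted code (gBB1 -> gBB2 -> FgcSM1, without normalization)
B1 = repmat(x, [S0 1]);
B1 = reshape(transpose(B1), [1, N*S0]);
B2 = transpose(reshape(B1(Ng1:Ng2), [N-1, S0-1]));
SM = B2(:, 1:L);

% Direct construction: row k is the window x(S0+1-k : S0-k+L)
k = (1:S0-1).';
SMdirect = x((S0 + 1 - k) + (0:L-1));
disp(isequal(SM, SMdirect))                 % prints 1 (true)
```

So the large shift matrix is a stack of overlapping windows of a single channel; none of it has to exist all at once, which is what makes lag-chunked or FFT-based variants possible.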

The problem

On the supercomputer, the code runs out of GPU memory (8 GB limit) because of large intermediate arrays (e.g., the gpuArray built from repmat and its reshape/transpose copies) created inside nested loops.
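A back-of-the-envelope count is consistent with this. Assuming VBLK is complex single (8 bytes per element, as the code comments suggest), the large GPU arrays of one jUsb iteration are listed below. Nothing is cleared inside the loop, so when FgcSM1 = (FgcSM1-mean(FgcSM1,2))./std(FgcSM1,0,2) runs, gBB1, gBB2, FgcSM1 and gccDM1 are all still allocated, and each elementwise temporary of that expression adds roughly another 1.6 GB. That would plausibly exceed 8 GB while still fitting in 16 GB. These are estimates only; how MATLAB allocates and frees temporaries may differ.

```matlab
% Rough sizes (GB = 1e9 bytes) of the large GPU arrays in one jUsb iteration,
% assuming complex single data (2 x 4 bytes per element).
bytesPerEl = 8;
toGB = @(nEl) nEl * bytesPerEl / 1e9;
gBB1_GB   = toGB(1601*129024)   % repmat of one channel, about 1.65 (its transpose is another copy)
gBB2_GB   = toGB(1600*129023)   % reindexed and transposed copy, about 1.65
FgcSM1_GB = toGB(1600*125823)   % column cut of gBB2, about 1.61
gccDM1_GB = toGB(125823*949)    % all 949 Lsb columns, kept for the whole loop, about 0.96
liveGB    = gBB1_GB + gBB2_GB + FgcSM1_GB + gccDM1_GB   % about 5.87 before any temporaries
```

Clearing gBB1 and gBB2 as soon as FgcSM1 exists (clear gBB1 gBB2) should by itself remove about 3.3 GB from that peak without changing any arithmetic.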

We have attempted to break the computation into smaller pieces, but current approaches either:
  • still trigger OOM errors, or
  • increase runtime by orders of magnitude (not acceptable).
What we need help with

We are looking for a way to tile or restructure the computation so that:
  • GPU memory usage stays below ~8 GB,
  • Results are numerically identical,
  • Runtime remains practical on an HPC system.
In particular, advice on:
  • Chunking the time or subband dimension safely,
  • Memory-efficient alternatives to large replicated shift matrices,
  • MATLAB GPU best practices for large cross-correlation problems,
would be very helpful.
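One way to chunk safely is along the lag dimension: build only lagChunk rows of the shift matrix at a time, directly from the channel vector by indexing, and normalize and multiply each chunk exactly as FgcSM1 is. Each row's mean, std and dot product are the same operations as before, so results should agree with the full-matrix version to floating-point rounding (BLAS may sum in a different order for a smaller matrix, so bit-for-bit equality is not guaranteed). A minimal sketch with hypothetical toy sizes; lagChunk is an assumed tuning parameter:

```matlab
% Toy sizes (hypothetical) standing in for NDsamp=129024, NDShf2=1600, Rsamp3=125823, Ng1=1601
rng(0);
N = 200;  nLags = 16;  L = 150;  s1 = 17;  lagChunk = 5;
x = complex(randn(1,N,'single'), randn(1,N,'single'));   % one Usb channel, like cDM0(jUsb,:)
y = complex(randn(L,1,'single'), randn(L,1,'single'));
y = (y - mean(y)) ./ std(y);                              % one column, normalized like gccDM1

offs = 0:L-1;                              % window offsets
if isa(x, 'gpuArray')
    offs = gpuArray(offs);                 % build the index matrix on the GPU, not the host
end

% Chunked: only lagChunk rows of the shift matrix exist at any moment
rChunk = zeros(nLags, 1, 'like', y);
for k0 = 1:lagChunk:nLags
    k = (k0:min(k0+lagChunk-1, nLags)).';  % row (lag) numbers in this chunk
    S = x((s1 + 1 - k) + offs);            % row k = samples s1+1-k ... s1-k+L
    S = (S - mean(S,2)) ./ std(S,0,2);     % same per-row normalization as FgcSM1
    rChunk(k) = (S*y) / L;                 % same product as FgcSM1*gccDM1(:,kUsb)/Rsamp3
end

% Reference: the full shift matrix built in one go (what FgcSM1 holds)
S = x((s1 + 1 - (1:nLags).') + (0:L-1));
S = (S - mean(S,2)) ./ std(S,0,2);
rFull = (S*y) / L;
fprintf('max |rFull - rChunk| = %g\n', max(abs(rFull - rChunk)));
```

In the posted loop this would replace the gBB1/gBB2/FgcSM1 lines, with x = gpuArray(cDM0(jUsb,:)), y = gccDM1(:,kUsb), nLags = NDShf2, L = Rsamp3, s1 = Ng1, followed by Fabs(:,kUsb) = Fabs(:,kUsb) + abs(gather(rChunk)). With lagChunk = 200, the index matrix (double) and the window matrix (complex single) are each about 0.2 GB, so the peak should be dominated by gccDM1 (about 0.96 GB). Sending only the channel vector to the GPU also avoids transferring a 1.65 GB repmat result for every jUsb.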

==============================

Here is the code:
Matlab:
PGMname = 'PGM_251219_NoRev_NoHH_Abs_XC_R13_F0021_B_1_1'

'LINES WITH ************ '
'TO BE UPDATED '
'AS NEW FILES ARE PROCESSED'
% purpose: correlation of subbands
% 949 subband pairs
% 949 spacing Usb to Lsb
% No Hamming. Use full 390 kHz BW of each subband
% absolute values of XC array
% VBLK COMPLEX VOLTAGES, POLARIZATIONS 1 AND 2,
%     LOADED TOGETHER, PROCESSED SEPARATELY
% FORWARD F AND REVERSE R IN SAME PGM
% ShfMat1 BOTH WAYS TO GET REVERSE
% COMPLEX SHIFT ARRAY Usb
% COMPLEX CONJUGATE DM1 Lsb
% NORMALIZE COMPLEX ARRAYS BEFORE XC
% ABS VALUE XC ARRAY SAVED

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

% Run 13
% Update DFile and SAVE also !!!

BLOCK1 = 1;
BLOCKN = 1;

DFile   = 21:21; % actual last digit file labels to be processed
               % 13_0000 : 13_0002 for FF3 : FF5
               % 13_0006 : 13_0009 for FF6 : FF9
               % 13_0010 : 13_0015 for FF10: FF15
               % 13_0016 : 13_0019 for FF16: FF19
               % 13_0020 : 13_0023 for FF20: FF23
               % 13_0024 : 13_0038 for FF24: FF38
NDFile  = length(DFile);
%
% Previously, Files v1f1 and v1f1 were assigned F1,F2
% but I am now ignoring them since they had Noise Diode on
%
% 0000_3 through 0000_5 are missing
%

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

NDsamp   = 129024;               % original samples per channel

%
% FIRST LAG (SHIFT) = 0
% SM1(1600,NSsamp3) has range of shifts of 0-1599
% DM1 starts at channel jCh1=151+256
% of DatMat0
% gDM1 to be used for an entire Block (i.e. for all jUsb)
%
% keep FULL BLK for use to get ShfMat for each jUsb
% within current block
% so that SEGMENTS can be shifted to left sequentially from top
% lag (shift) increases with segment shift
% to EARLIER time as lag increases;
% all the multiple Lsb signals remain FIXED in time
%
% after first jUsb,
% cut one column from left of DM11
% FOR NEXT xcorr
%
% original BLK   (2048,129024)
% original stdBLK(2048,129024)
%

% ===================== CIRCE PATHS (ONLY LOCATION CHANGES) =====================
workDir = "/work_bgfs/e/evansa";
inDir   = "/work_bgfs/e/evansa/Data11-20-2025";  % input voltage data
outDir  = "/work_bgfs/e/evansa";                 % output directory

cd(workDir);
% ==============================================================================

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

% NC version 251112

NUsb     = 949;   % based on jUsb=151:151+948  for NC version
NDchan   = 2048;  % original data channels (start as rows)

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

% max(max(NLag1+NLags-1))=1599 from 250115

NDShf0   = 1601;    % rows in repmat used to get ShfMat
NDShf2   = 1600;    % rows in gShfMat1

NDsamp2  = NDsamp-1;         % samp in intermediate matrix of ShfMat
NDsamp3  = NDsamp-NDShf0-1;  % samples per channel in gShfMat1

Ntot0    = NDsamp*NDchan;    % total original samples NDsamp*NDrows
Ntot2    = NDsamp*NDShf0;    % total samples in initial repmat
                             % array for calculation of ShfMat

Ng1      = NDShf0;           % first and last pts for gBB2 to gBB3
Ng2      = Ng1+NDsamp2*(NDShf0-1) - 1;

% for Reverse
% Rsamp2   = NDsamp+1 = 129025
% Rsamp3   = NDsamp - 3201 = 125823
% Rg2      = Ng1+Rsamp2*(NDShf0-1) - 1 = 206441600, vs
% Ntot2    =                             206567424

Rsamp2   = 129025;
Rsamp3   = 125823;
Rg2      = 206441600; % Ng1=1601 (NDShf0) is same as for Forward

%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jFile = DFile(1):DFile(NDFile)      %FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
% FOR FOR FOR FOR
% FOR FOR FOR FOR

% rezero for each File, type double
Fabs  = zeros(1600,949);

%*************************************************************************
%*************************************************************************
%*************************************************************************

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
data0 = "VBLK_75753_13_" + sprintf("%04d", jFile) + "_";   % zero-padded label, e.g. 13_0006, 13_0021
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

 %*************************************************************************
 %*************************************************************************
 %*************************************************************************
 % FOR FOR FOR FOR
 % FOR FOR FOR FOR
 for jBlk = BLOCK1:BLOCKN %BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
 % FOR FOR FOR FOR
 % FOR FOR FOR FOR
 %*************************************************************************
 %*************************************************************************
 %*************************************************************************

 data1 = data0 + string(jBlk) + ".mat";   % (CIRCE: keep same filename pattern)

 % ===================== CIRCE INPUT (ONLY LOCATION CHANGE) =====================
 data1_full = fullfile(inDir, data1);
 % =============================================================================

%*************************************************************************
load(data1_full)
%************************************************************************

%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jPol = 1:2
% FOR FOR FOR FOR
% FOR FOR FOR FOR
%*************************************************************************
%*************************************************************************
%*************************************************************************
 cd(workDir);

%GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG
reset(gpuDevice(1)) % for each Polarization
%GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG

 BLK = VBLK(:,:,jPol);   % polarization jPol (1 or 2)

%
% BLK is single(2048,129024) (Nchan, NDsamp) subband intensities
% highest subband RF (channel) top row
% subband number increases from top to bottom
%
% input BLK (2048,NDsamp) = BLK(2048,129024)
% need full BLK  to get ShfMat
%

% complex normalization of each row (channel)

stdBLK = (BLK-mean(BLK,2))./std(BLK,0,2);
% rows dominant

% No HAMMING in NoHH version starting 251214
%HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH

% cDM0 and ccDM0 contain all channels and samples

cDM0  = stdBLK;           % time domain (NDchan,NDsamp)
                          % normalize again after shift array is created
ccDM0 = conj(cDM0);       % (NDchan, NDsamp)
                          % transpose, limit ranges,
                          % and normalize later

% cc= CONJUGATE COMPLEX array
% cDM0 and ccDM0 contain ALL channels and samples

%HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH

%
% use cDM0
% for construction of FgcSM1 and RgcSM1
%
% use ccDM0 for construction of gccDM1 for
% BOTH Fwd and Rev
%
% gccDM1 starts with channel 151+949
% gccDM1 always complex conjugate
%
% Lsb (gccDM1) arrays for each BLOCK will have
% first column now is original 151+949
% total columns in gccDM1 = 949 =jUsb= # of pairs
%
% Rsamp3 for both F and R
% samples in gccDM1 start at sample #
% Ng1 to maintain time synch with ShfMat
%

NR3 = Ng1+Rsamp3-1; % last sample of gccDM1

gccDM1     = gpuArray(transpose(ccDM0(151+949:2048,Ng1:NR3))); % (Rsamp3,949)
gccDM1     = (gccDM1-mean(gccDM1))./std(gccDM1) ;

%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jUsb = 151:151+948 % NC version
% FOR FOR FOR FOR
% FOR FOR FOR FOR
%*************************************************************************
%*************************************************************************
%*************************************************************************

[318 jFile jBlk jPol jUsb]
datetime

gBB1 = gpuArray(repmat(cDM0(jUsb,:),[NDShf0 1]));    % repmat

gBB1 = reshape(transpose(gBB1),[1,Ntot2]);    % reshape to linear

gBB2 = transpose((reshape(gBB1(Ng1: Ng2),[NDsamp2,NDShf2])));

FgcSM1 = gBB2(:,1:Rsamp3);       % cut samples to Rsamp3

FgcSM1   = (FgcSM1-mean(FgcSM1,2))./std(FgcSM1,0,2);

kUsb = jUsb-150;

Fabs(:,kUsb)=Fabs(:,kUsb)+abs(gather((FgcSM1*gccDM1(:,kUsb))/Rsamp3));

end % jUsb
end % jPol
end % jBLK

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

% save results
FFout = string(jFile);

data4  = "_" + string(BLOCK1) + "_" + string(BLOCKN) + ".mat";
data5F = "251219_Fabs_NoHH_13_" + FFout + data4;

% ===================== CIRCE OUTPUT (ONLY LOCATION CHANGE) ====================
save(fullfile(outDir, data5F), "Fabs", "-v7.3");
% ==============================================================================

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

cd(workDir);

end  % jFile
 
Last edited by a moderator:
Ascendant0 said:
TL;DR: Need a way to reduce our MATLAB code from requiring a 16 GB GPU down to an 8 GB GPU (our college's supercomputer has 8 GB GPUs, and the job fails with out-of-memory errors)

The code processes large complex voltage blocks (VBLK) from a radio astronomy pipeline.
Each input file contains approximately:
  • 2048 frequency channels
  • ~129k time samples
  • 2 polarizations
For each file and polarization, the code:
  1. Normalizes each frequency channel.
  2. Constructs large time-shifted matrices to compute cross-correlations between 949 upper–lower subband pairs.
  3. Accumulates absolute cross-correlation amplitudes.
Can you say more about why you are doing cross-correlations between different frequency channels?
 
berkeman said:
Can you say more about why you are doing cross-correlations between different frequency channels?
To clarify: the mathematics and the specific cross-correlations are exactly what we intend scientifically and are not something we’re looking to change or simplify.

The reason for mentioning this at all is just to give context for the computational structure of the code. From an implementation standpoint, the task reduces to computing many fixed-offset cross-correlations over large time windows, which currently requires constructing large time-shifted matrices.

The issue we’re trying to solve is purely computational: finding a way to execute the same math on GPUs with an 8 GB memory limit, without altering the calculation or its interpretation.
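For many fixed-offset cross-correlations over large windows, the shift matrix can be avoided entirely. The raw sums for all lags form a single cross-correlation, which an FFT computes in O(N log N) per pair, and the per-window mean and std needed by FgcSM1's normalization come from running sums (cumsum). Mathematically this is the same normalized quantity, but the additions happen in a different order, so results should match the shift-matrix version to rounding rather than bit for bit; it would need checking against the current output. A minimal sketch in double precision with hypothetical toy sizes:

```matlab
% Toy sizes (hypothetical) standing in for NDsamp, NDShf2, Rsamp3, Ng1; double precision
rng(1);
N = 300;  nLags = 20;  L = 200;  s1 = 21;
x = complex(randn(N,1), randn(N,1));          % one Usb channel
y = complex(randn(L,1), randn(L,1));
y = (y - mean(y)) ./ std(y);                  % normalized Lsb column, like gccDM1(:,kUsb)

% FFT route: raw correlation for all lags at once, no shift matrix
xs   = double(reshape(x(s1+1-nLags : s1+L-1), [], 1));  % every sample any lag touches
yc   = double(y(:));
M    = numel(xs);                             % = nLags + L - 1
nfft = 2^nextpow2(M);                         % >= M keeps cc(L:M) free of wrap-around
cc   = ifft(fft(xs, nfft) .* fft(flip(yc), nfft));
rRaw = flip(cc(L:M));                         % rRaw(k): window starting at x(s1+1-k)

% Sliding mean and std (n-1 normalization, as std uses) from running sums
first = (nLags:-1:1).';                       % start of lag k's window inside xs
cs1   = cumsum([0; xs]);
cs2   = cumsum([0; abs(xs).^2]);
mu    = (cs1(first+L) - cs1(first)) / L;
sig   = sqrt((cs2(first+L) - cs2(first) - L*abs(mu).^2) / (L-1));
rFFT  = (rRaw - mu*sum(yc)) ./ sig / L;       % normalized window dotted with y, over L

% Reference: explicit (small) shift matrix, normalized like FgcSM1
S = x((s1 + 1 - (1:nLags).') + (0:L-1));
S = (S - mean(S,2)) ./ std(S,0,2);
rDirect = (S*y) / L;
fprintf('max |rFFT - rDirect| = %g\n', max(abs(rFFT - rDirect)));
```

In the real loop, x would be the channel cDM0(jUsb,:) (the reshape only reorders, so no conjugation is introduced), y = gccDM1(:,kUsb), nLags = NDShf2, L = Rsamp3 and s1 = Ng1. The functions used here (fft, ifft, cumsum, flip) also accept gpuArray inputs, and the working set is a few vectors of length about 131k instead of a 1600-row matrix.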
 
Ascendant0 said:
The issue we’re trying to solve is purely computational: finding a way to execute the same math on GPUs with an 8 GB memory limit, without altering the calculation or its interpretation.
Well, okay I guess. You probably see my point, though. One way to slice up the work is to do the computations on different sets of frequency bands separately. But since you are doing cross-correlations between different frequency bands/channels, that would not be a candidate. The only reasons I know of to do cross-correlations between frequency bands would be 1) because you are decoding encoded information (like spread-spectrum or similar modulation schemes), which makes no sense for radio astronomy, or 2) you are trying to improve your signal-to-noise ratio.
 
Hi,

Good thing a friendly moderator (Berkeman) added the [CODE=MATLAB] tags for legibility!

Ascendant0 said:
purely computational:
Still, over 300 lines scares off most potential helpers!

Then, I don't have MATLAB, only Octave, so I had to fumble a bit.
Even then, I can't run beyond about halfway, up to
load(Data11-20-2025\VBLK_75753_13_0021_1.mat)

A file I don't have -- but you do.
Ascendant0 said:
the job repeatedly fails with out-of-memory errors on the supercomputer GPU

The prime suspect is VBLK, which is (Nchan, NDsamp, 2) complex, so 2048 * 129024 * 2 = 528 482 304 elements at the default 16 bytes each (complex double) already come close to 8 GB (7.88). In single precision it would occupy about 4 GB, which does not leave much room for other stuff.
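The arithmetic, for reference (GiB = 2^30 bytes, which is where the 7.88 figure comes from). One caveat worth adding: load puts VBLK in host RAM, so it only counts against the GPU's 8 GB once something is copied over with gpuArray; checking both host and GPU memory would show which limit is actually being hit.

```matlab
nEl = 2048*129024*2;                     % VBLK elements: channels x samples x polarizations
vblkGiB_double = nEl*16/2^30             % complex double, 16 bytes/element: 7.8750
vblkGiB_single = nEl*8/2^30              % complex single, 8 bytes/element: 3.9375
blkGiB_single  = 2048*129024*8/2^30      % one polarization (BLK), complex single: about 1.97
```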

Obvious questions arise:
  • how big is the file
  • at what point does the 8 GB machine fail? Beyond the loading?
  • if it fails at the loading, can you check the size of VBLK (whos VBLK in MATLAB, sizeof(VBLK) in Octave) on the 16 GB station?

purely computational:
With the wisdom of ignorance, my approach would be to try 'clear VBLK' right after BLK=VBLK(:,:,1); (in the jPol==1 branch) and process jPol = 1.
Then reload data1_full for the second round. Basically, the idea is to halve memory usage by unwinding the jPol loop.

If that fails, perhaps split data1_full into two files and load BLK instead of VBLK for each jPol?
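Instead of splitting the files, matfile can read one polarization at a time from the existing file, provided it is a version 7.3 MAT-file (the code saves its output with -v7.3, but the format of the VBLK input files isn't stated; for older versions MATLAB may have to read the whole variable anyway). A self-contained sketch that writes a toy VBLK to a temporary file and reads it back per polarization:

```matlab
% Toy stand-in for a VBLK file: small complex single array, saved as v7.3
fname = [tempname '.mat'];
VBLK  = complex(rand(4,6,2,'single'), rand(4,6,2,'single'));
save(fname, 'VBLK', '-v7.3');
clear VBLK                                   % nothing of VBLK left in memory

m  = matfile(fname);                         % file handle only; no data read yet
sz = size(m, 'VBLK');                        % size without loading: [4 6 2]
for jPol = 1:sz(3)
    BLK = m.VBLK(:,:,jPol);                  % reads just this polarization
    fprintf('jPol %d: BLK is %s\n', jPol, mat2str(size(BLK)));
    clear BLK                                % free before the next polarization
end
delete(fname);
```

In the posted code this would mean replacing load(data1_full) with m = matfile(data1_full) and the BLK assignment with BLK = m.VBLK(:,:,jPol), so only one polarization (about 2 GB in complex single) is held in host memory at a time.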

The memory function can help you check things.
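Note that memory is only available on Windows, so on a Linux cluster whos (host-side size of variables) and the gpuDevice properties TotalMemory and AvailableMemory may be more practical. A small sketch that could be dropped in at checkpoints, for example before and after the repmat line, to see where the 8 GB runs out; canUseGPU requires R2019b or later:

```matlab
% Host-side size of a variable (toy example: 1000 x 1000 complex single = 8 MB)
A = complex(ones(1000, 1000, 'single'));
info = whos('A');
fprintf('A: %.1f MB in host RAM\n', info.bytes/1e6);

% Free GPU memory, if a supported GPU and Parallel Computing Toolbox are available
if canUseGPU()
    g = gpuDevice;                       % current device; does not reset it
    fprintf('GPU: %.2f of %.2f GB available\n', g.AvailableMemory/1e9, g.TotalMemory/1e9);
end
```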

Best of luck!

BvU