MATLAB GPU Cross-Correlation Code Fails with OOM on 8 GB Supercomputer

  • Context: MATLAB 
  • Thread starter: Ascendant0
Ascendant0
TL;DR
Need a way to reduce the memory footprint of MATLAB code that currently requires a 16 GB GPU so that it runs on an 8 GB GPU (our college's supercomputer has 8 GB GPUs, and the job fails with out-of-memory (OOM) errors)
We are trying to evaluate data from the Crab Nebula using MATLAB, but this computation cannot be run on a personal workstation in any reasonable time. It must be split into many parallel jobs and run on our college supercomputer.

However, while our local workstation GPU has 16 GB of memory, each GPU on the supercomputer has only 8 GB. The current implementation fits in 16 GB, but on the 8 GB supercomputer GPUs the job repeatedly fails with out-of-memory errors. We would like to keep the same overall code structure, because it is what gives us acceptable performance.

We are therefore trying to restructure the code to reduce peak memory usage while keeping the computation GPU-accelerated and mathematically identical.

What the code is doing

The code processes large complex voltage blocks (VBLK) from a radio astronomy pipeline.
Each input file contains approximately:
  • 2048 frequency channels
  • ~129k time samples
  • 2 polarizations
For each file and polarization, the code:
  1. Normalizes each frequency channel.
  2. Constructs large time-shifted matrices to compute cross-correlations between 949 upper–lower subband pairs.
  3. Accumulates absolute cross-correlation amplitudes.
Full time and frequency resolution must be preserved (no averaging or decimation).

The problem

On the supercomputer, the code runs out of GPU memory (8 GB limit) due to large intermediate arrays (e.g. repmat + reshape + GPU arrays) inside nested loops.
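
For scale, the sizes of the two largest per-subband arrays can be estimated from the constants in the code. This is only a rough sketch: it assumes complex single-precision data (8 bytes per element, consistent with the `single` comment on BLK), and the names `bytesPerElem`, `repmatGiB` and `shiftGiB` are introduced here purely for illustration.

```matlab
% Back-of-envelope size of the per-subband intermediates (sketch only).
NDsamp = 129024;     % samples per channel (from the code)
NDShf0 = 1601;       % rows in the initial repmat (from the code)
NDShf2 = 1600;       % rows (lags) in the shift matrix (from the code)
Rsamp3 = 125823;     % samples kept per lag (from the code)

bytesPerElem = 8;    % ASSUMPTION: complex single = 2 x 4 bytes
GiB = 2^30;

repmatGiB = NDShf0 * NDsamp * bytesPerElem / GiB;   % size of gBB1
shiftGiB  = NDShf2 * Rsamp3 * bytesPerElem / GiB;   % size of FgcSM1

fprintf('gBB1   (repmat)       : %.2f GiB\n', repmatGiB);
fprintf('FgcSM1 (shift matrix) : %.2f GiB\n', shiftGiB);
```

Each of these is about 1.5 GiB. The transposes, the `gBB1(Ng1:Ng2)` slice and the normalization temporaries can each hold another copy of similar size at the same time, which would be consistent with a peak above 8 GB. How many copies actually coexist depends on MATLAB's memory handling and was not measured here.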

We have attempted to break the computation into smaller pieces, but current approaches either:
  • still trigger OOM errors, or
  • increase runtime by orders of magnitude (not acceptable).
What we need help with

We are looking for a way to tile or restructure the computation so that:
  • GPU memory usage stays below ~8 GB,
  • Results are numerically identical,
  • Runtime remains practical on an HPC system.
In particular, advice on:
  • Chunking the time or subband dimension safely,
  • Memory-efficient alternatives to large replicated shift matrices,
  • MATLAB GPU best practices for large cross-correlation problems,
would be very helpful.
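
One possible way to chunk the lag dimension, offered as a sketch rather than a tested replacement: for the constants used in the code, row k, column r of the shift matrix appears to read sample `Ng1 - k + r` of the upper-subband row, so a few lags at a time can be built by direct indexing instead of going through the full `repmat`. The toy sizes, the test signals and the names `chunk`, `idx`, `Sref`, `out`, `sameRows` and `maxDiff` are assumptions made for illustration. The example runs on the CPU so it can be checked against the original construction.

```matlab
% Toy sizes (ASSUMPTION: illustrative only; the real run uses
% NDsamp = 129024, NDShf0 = 1601, Rsamp3 = 125823).
N   = 40;                 % samples per channel
S0  = 6;                  % rows in the repmat (NDShf0)
S2  = S0 - 1;             % number of lags     (NDShf2)
N2  = N - 1;              % NDsamp2
R   = 30;                 % samples kept per lag (Rsamp3)
Ng1 = S0;
Ng2 = Ng1 + N2*(S0-1) - 1;

% Deterministic complex test signals (hypothetical stand-ins for one
% row of cDM0 and one column of gccDM1).
x = complex(sin(0.37*(1:N)), cos(0.11*(1:N).^2));     % 1-by-N
y = complex(cos(0.23*(1:R)), sin(0.05*(1:R).^2)).';   % R-by-1

% Reference: the repmat/reshape construction from the original code.
B1    = reshape(transpose(repmat(x, [S0 1])), [1, N*S0]);
Sref  = transpose(reshape(B1(Ng1:Ng2), [N2, S2]));
Sref  = Sref(:, 1:R);
SrefN = (Sref - mean(Sref,2)) ./ std(Sref,0,2);
ref   = abs(SrefN*y / R);

% Chunked: build only a few lags (rows) at a time by index arithmetic.
chunk    = 2;             % hypothetical chunk size; tune to fit memory
out      = zeros(S2, 1);
sameRows = true;
for k0 = 1:chunk:S2
    k   = (k0:min(k0+chunk-1, S2)).';   % lags in this chunk (column)
    idx = (Ng1 - k) + (1:R);            % implicit expansion (R2016b+)
    S   = x(idx);                       % numel(k)-by-R block of the shift matrix
    sameRows = sameRows && isequal(S, Sref(k,:));
    S   = (S - mean(S,2)) ./ std(S,0,2);  % same per-row normalization
    out(k) = abs(S*y / R);
end
maxDiff = max(abs(out - ref));
fprintf('rows identical: %d, max |diff| = %g\n', sameRows, maxDiff);
```

On the GPU, `x` and `y` would be `gpuArray` values and each chunk result would be gathered before being added to `Fabs`, as in the original accumulation line. With complex single data, a chunk of 200 lags at the real `Rsamp3` is roughly 0.19 GiB plus an index array of similar size. The rows produced this way are the same values as in the full matrix, but whether the GPU sums a smaller matrix product in the same order is not guaranteed, so bit-for-bit agreement should be checked against a reference run rather than assumed.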

==============================

Here is the code:
Matlab:
PGMname = 'PGM_251219_NoRev_NoHH_Abs_XC_R13_F0021_B_1_1'

'LINES WITH ************ '
'TO BE UPDATED '
'AS NEW FILES ARE PROCESSED'
% purpose: correlation of subbands
% 949 subband pairs
% 949 spacing Usb to Lsb
% No Hamming. Use full 390 kHz BW of each subband
% absolute values of XC array
% VBLK COMPLEX VOLTAGES, POLARIZATIONS 1 AND 2,
%     LOADED TOGETHER, PROCESSED SEPARATELY
% FORWARD F AND REVERSE R IN SAME PGM
% ShfMat1 BOTH WAYS TO GET REVERSE
% COMPLEX SHIFT ARRAY Usb
% COMPLEX CONJUGATE DM1 Lsb
% NORMALIZE COMPLEX ARRAYS BEFORE XC
% ABS VALUE XC ARRAY SAVED

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

% Run 13
% Update DFile and SAVE also !!!

BLOCK1 = 1;
BLOCKN = 1;

DFile   = 21:21; % actual last digit file labels to be processed
               % 13_0000 : 13_0002 for FF3 : FF5
               % 13_0006 : 13_0009 for FF6 : FF9
               % 13_0010 : 13_0015 for FF10: FF15
               % 13_0016 : 13_0019 for FF16: FF19
               % 13_0020 : 13_0023 for FF20: FF23
               % 13_0024 : 13_0038 for FF24: FF38
NDFile  = length(DFile);
%
% Previously, Files v1f1 and v1f1 were assigned F1,F2
% but I am now ignoring them since they had Noise Diode on
%
% 0000_3 through 0000_5 are missing
%

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

NDsamp   = 129024;               % original samples per channel

%
% FIRST LAG (SHIFT) = 0
% SM1(1600,NSsamp3) has range of shifts of 0-1599
% DM1 starts at channel jCh1=151+256
% of DatMat0
% gDM1 to be used for an entire Block (i.e. for all jUsb)
%
% keep FULL BLK for use to get ShfMat for each jUsb
% within current block
% so that SEGMENTS can be shifted to left sequentially from top
% lag (shift) increases with segment shift
% to EARLIER time as lag increases;
% all the multiple Lsb signals remain FIXED in time
%
% after first jUsb,
% cut one column from left of DM11
% FOR NEXT xcorr
%
% original BLK   (2048,129024)
% original stdBLK(2048,129024)
%

% ===================== CIRCE PATHS (ONLY LOCATION CHANGES) =====================
workDir = "/work_bgfs/e/evansa";
inDir   = "/work_bgfs/e/evansa/Data11-20-2025";  % input voltage data
outDir  = "/work_bgfs/e/evansa";                 % output directory

cd(workDir);
% ==============================================================================

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

% NC version 251112

NUsb     = 949;   % based on jUsb=151:151+948  for NC version
NDchan   = 2048;  % original data channels (start as rows)

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

% max(max(NLag1+NLags-1))=1599 from 250115

NDShf0   = 1601;    % rows in repmat used to get ShfMat
NDShf2   = 1600;    % rows in gShfMat1

NDsamp2  = NDsamp-1;         % samp in intermediate matrix of ShfMat
NDsamp3  = NDsamp-NDShf0-1;  % samples per channel in gShfMat1

Ntot0    = NDsamp*NDchan;    % total original samples NDsamp*NDrows
Ntot2    = NDsamp*NDShf0;    % total samples in initial repmat
                             % array for calculation of ShfMat

Ng1      = NDShf0;           % first and last pts for gBB2 to gBB3
Ng2      = Ng1+NDsamp2*(NDShf0-1) - 1;

% for Reverse
% Rsamp2   = NDsamp+1 = 129025
% Rsamp3   = NDsamp - 3201 = 125823
% Rg2      = Ng1+Rsamp2*(NDShf0-1) - 1 = 206441600, vs
% Ntot2    =                             206567424

Rsamp2   = 129025;
Rsamp3   = 125823;
Rg2      = 206441600; % Ng1=1601 (=NDShf0) is same as for Forward

%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jFile = DFile(1):DFile(NDFile)      %FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
% FOR FOR FOR FOR
% FOR FOR FOR FOR

% rezero for each File, type double
Fabs  = zeros(1600,949);

%*************************************************************************
%*************************************************************************
%*************************************************************************

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
data0 = "VBLK_75753_13_00" + string(jFile)+ "_";
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

 %*************************************************************************
 %*************************************************************************
 %*************************************************************************
 % FOR FOR FOR FOR
 % FOR FOR FOR FOR
 for jBlk = BLOCK1:BLOCKN %BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
 % FOR FOR FOR FOR
 % FOR FOR FOR FOR
 %*************************************************************************
 %*************************************************************************
 %*************************************************************************

 data1 = data0 + string(jBlk) + ".mat";   % (CIRCE: keep same filename pattern)

 % ===================== CIRCE INPUT (ONLY LOCATION CHANGE) =====================
 data1_full = fullfile(inDir, data1);
 % =============================================================================

%*************************************************************************
load(data1_full)
%************************************************************************

%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jPol = 1:2
% FOR FOR FOR FOR
% FOR FOR FOR FOR
%*************************************************************************
%*************************************************************************
%*************************************************************************
 cd(workDir);

%GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG
reset(gpuDevice(1)) % for each Polarization
%GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG

 BLK = VBLK(:,:,jPol);   % select the current polarization

%
% BLK is single(2048,129024) (Nchan, NDsamp) subband intensities
% highest subband RF (channel) top row
% subband number increases from top to bottom
%
% input BLK (2048,NDsamp) = BLK(2048,129024)
% need full BLK  to get ShfMat
%

% complex normalization of each row (channel)

stdBLK = (BLK-mean(BLK,2))./std(BLK,0,2);
% rows dominant

% No HAMMING in NoHH version starting 251214
%HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH

% cDM0 and ccDM0 contain all channels and samples

cDM0  = stdBLK;           % time domain (NDchan,NDsamp)
                          % normalize again after shift array is created
ccDM0 = conj(cDM0);       % (NDchan, NDsamp)
                          % transpose, limit ranges,
                          % and normalize later

% cc= CONJUGATE COMPLEX array
% cDM0 and ccDM0 contain ALL channels and samples

%HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH

%
% use cDM0
% for construction of FgcSM1 and RgcSM1
%
% use ccDM0 for construction of gccDM1 for
% BOTH Fwd and Rev
%
% gccDM1 starts with channel 151+949
% gccDM1 always complex conjugate
%
% Lsb (gccDM1) arrays for each BLOCK will have
% first column now is original 151+949
% total columns in gccDM1 = 949 =jUsb= # of pairs
%
% Rsamp3 for both F and R
% samples in gccDM1 start at sample #
% Ng1 to maintain time synch with ShfMat
%

NR3 = Ng1+Rsamp3-1; % last sample of gccDM1

gccDM1     = gpuArray(transpose(ccDM0(151+949:2048,Ng1:NR3))); % (Rsamp3,949)
gccDM1     = (gccDM1-mean(gccDM1))./std(gccDM1) ;

%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jUsb = 151:151+948 % NC version
% FOR FOR FOR FOR
% FOR FOR FOR FOR
%*************************************************************************
%*************************************************************************
%*************************************************************************

[318 jFile jBlk jPol jUsb]
datetime

gBB1 = gpuArray(repmat(cDM0(jUsb,:),[NDShf0 1]));    % repmat

gBB1 = reshape(transpose(gBB1),[1,Ntot2]);    % reshape to linear

gBB2 = transpose((reshape(gBB1(Ng1: Ng2),[NDsamp2,NDShf2])));

FgcSM1 = gBB2(:,1:Rsamp3);       % cut samples to Rsamp3

FgcSM1   = (FgcSM1-mean(FgcSM1,2))./std(FgcSM1,0,2);

kUsb = jUsb-150;

Fabs(:,kUsb)=Fabs(:,kUsb)+abs(gather((FgcSM1*gccDM1(:,kUsb))/Rsamp3));

end % jUsb
end % jPol
end % jBLK

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

% save results
FFout = string(jFile);

data4  = "_" + string(BLOCK1) + "_" + string(BLOCKN) + ".mat";
data5F = "251219_Fabs_NoHH_13_" + FFout + data4;

% ===================== CIRCE OUTPUT (ONLY LOCATION CHANGE) ====================
save(fullfile(outDir, data5F), "Fabs", "-v7.3");
% ==============================================================================

% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************

cd(workDir);

end  % jFile
 
Ascendant0 said:
TL;DR: Need a way to reduce the memory footprint of MATLAB code that currently requires a 16 GB GPU so that it runs on an 8 GB GPU (our college's supercomputer has 8 GB GPUs, and the job fails with out-of-memory (OOM) errors)

The code processes large complex voltage blocks (VBLK) from a radio astronomy pipeline.
Each input file contains approximately:
  • 2048 frequency channels
  • ~129k time samples
  • 2 polarizations
For each file and polarization, the code:
  1. Normalizes each frequency channel.
  2. Constructs large time-shifted matrices to compute cross-correlations between 949 upper–lower subband pairs.
  3. Accumulates absolute cross-correlation amplitudes.
Can you say more about why you are doing cross-correlations between different frequency channels?