Ascendant0
TL;DR: We need to reduce a MATLAB code's GPU memory requirement from 16 GB to 8 GB. Our college's supercomputer has 8 GB GPUs, and the job fails there with out-of-memory (OOM) errors.
We are trying to evaluate data from the Crab Nebula using MATLAB, but this computation cannot be run on a personal workstation in any reasonable time. It must be split into many parallel jobs and run on our college supercomputer.
However, while our local workstation GPU has 16 GB of memory, each GPU on the supercomputer has only 8 GB. The current implementation runs fine on the 16 GB card, but on the 8 GB supercomputer GPUs the job repeatedly fails with out-of-memory errors, and we need to keep essentially the same code structure for performance.
We are therefore trying to restructure the code to reduce peak memory usage while keeping the computation GPU-accelerated and mathematically identical.
What the code is doing
The code processes large complex voltage blocks (VBLK) from a radio astronomy pipeline.
Each input file contains approximately:
- 2048 frequency channels
- ~129k time samples
- 2 polarizations
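For each block and polarization, the code:
- Normalizes each frequency channel (zero mean, unit standard deviation).
- Constructs large time-shifted matrices (1600 lags, 0-1599 samples) to compute cross-correlations between 949 upper–lower subband pairs.
- Accumulates the absolute cross-correlation amplitudes over blocks and both polarizations.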
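Reading the index arithmetic in the jUsb loop, the repmat/transpose/reshape/cut sequence seems to produce a sliding-window matrix: row c of FgcSM1 (before normalization) is the contiguous slice x(Ng1-c+1 : Ng1-c+Rsamp3) of the Usb channel x = cDM0(jUsb,:), i.e. lag c-1. The sketch below checks that reading on small stand-in sizes (NDsamp = 64 and NDShf0 = 5, chosen only for the demo; the real run uses 129024 and 1601) by building the matrix both ways on the CPU. SM_orig, SM_direct, cIdx and nIdx are illustrative names, not part of the original program.

```matlab
% Small stand-in sizes (assumed for the demo only)
NDsamp  = 64;  NDShf0 = 5;
NDShf2  = NDShf0 - 1;              % number of lag rows
NDsamp2 = NDsamp - 1;
Rsamp3  = NDsamp - (2*NDShf0 - 1); % mirrors 129024 - 3201
Ntot2   = NDsamp*NDShf0;
Ng1     = NDShf0;
Ng2     = Ng1 + NDsamp2*(NDShf0-1) - 1;

% Random stand-in for one Usb channel cDM0(jUsb,:)
x = complex(randn(1,NDsamp,'single'), randn(1,NDsamp,'single'));

% Original construction: repmat -> transpose -> reshape -> cut
gBB1 = repmat(x,[NDShf0 1]);
gBB1 = reshape(transpose(gBB1),[1,Ntot2]);
gBB2 = transpose(reshape(gBB1(Ng1:Ng2),[NDsamp2,NDShf2]));
SM_orig = gBB2(:,1:Rsamp3);

% Direct construction: element (c,n) is x(Ng1 + n - c)
cIdx = (1:NDShf2).';               % lag rows (column)
nIdx = 1:Rsamp3;                   % samples (row)
SM_direct = x(Ng1 - cIdx + nIdx);  % implicit expansion builds the index grid
```

If the two agree, the 1601 × 129024 replicated array is not needed at all: any block of lag rows can be gathered straight from the 1 × 129024 vector. At full size the index grid for all 1600 rows would itself be about 1.6 GB in double precision, so in practice it would be built a block of rows at a time.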
The problem
On the supercomputer, the code runs out of GPU memory (8 GB limit) due to large intermediate arrays (e.g. repmat + reshape + GPU arrays) inside nested loops.
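A rough budget helps explain why 16 GB works and 8 GB does not. This sketch assumes VBLK is complex single (8 bytes per element), as the comments in the code suggest, and uses the sizes from the code; it ignores MATLAB's own overhead and any temporaries that are freed or fused early, so the totals are estimates only. The overlap counted in peakGB is one plausible reading of the normalization line, not a measurement.

```matlab
% Back-of-the-envelope GPU memory estimate (assumes complex single data)
bytesPerElem = 8;                           % 4-byte real + 4-byte imaginary part
NDsamp = 129024; NDShf0 = 1601; NDShf2 = 1600; Rsamp3 = 125823; NUsb = 949;
repmatGB = NDShf0*NDsamp*bytesPerElem/1e9;  % gBB1; transpose(gBB1) and gBB2 are roughly this size
shfGB    = NDShf2*Rsamp3*bytesPerElem/1e9;  % FgcSM1 and each normalization temporary
lsbGB    = Rsamp3*NUsb*bytesPerElem/1e9;    % gccDM1, resident for the whole jUsb loop
% Plausible overlap while normalizing FgcSM1: gBB1 and gBB2 still alive,
% the old FgcSM1, the (FgcSM1 - mean) temporary, and the result of ./std
peakGB = 2*repmatGB + 3*shfGB + lsbGB;
fprintf('repmat copy %.2f GB, shift matrix %.2f GB, gccDM1 %.2f GB, est. peak %.1f GB\n', ...
        repmatGB, shfGB, lsbGB, peakGB);
```

Note also that gBB1, gBB2 and FgcSM1 from the previous jUsb iteration stay allocated until they are overwritten, so at the start of each iteration the old copies coexist with the new gBB1. Clearing gBB1 and gBB2 as soon as FgcSM1 has been cut, and FgcSM1 after the product, may free a few GB on its own.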
We have attempted to break the computation into smaller pieces, but current approaches either:
- still trigger OOM errors, or
- increase runtime by orders of magnitude (not acceptable).
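When the lag dimension is split into tiles, the tile size trades memory against loop and kernel-launch overhead: a few large tiles should run close to the untiled speed, while very small tiles (or one lag at a time) are the likely source of large slowdowns. A minimal sketch for picking the largest tile that fits a budget, assuming complex single data and a rough guess of four tile-sized arrays alive at once (the tile, its double-precision index grid, and two normalization temporaries). The names lagChunk, budget and liveCopies are hypothetical; on the cluster, the gpuDevice object reports free memory in bytes through its AvailableMemory property.

```matlab
% Pick the largest lag tile that fits a memory budget (sketch)
useGPU       = false;            % true on the cluster (Parallel Computing Toolbox)
bytesPerElem = 8;                % complex single (assumed); a double index is also 8 bytes
Rsamp3 = 125823; NDShf2 = 1600;
if useGPU
    d = gpuDevice;
    budget = 0.5*d.AvailableMemory;  % leave headroom for gccDM1 and library workspace
else
    budget = 2e9;                    % stand-in budget in bytes for this demo
end
liveCopies = 4;                  % rough guess of tile-sized arrays alive at once
lagChunk = floor(budget/(liveCopies*bytesPerElem*Rsamp3));
lagChunk = max(1, min(lagChunk, NDShf2));
nTiles   = ceil(NDShf2/lagChunk);
fprintf('lagChunk = %d rows, %d tiles per jUsb\n', lagChunk, nTiles);
```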
We are looking for a way to tile or restructure the computation so that:
- GPU memory usage stays below ~8 GB,
- results are numerically identical,
- runtime remains practical on an HPC system.
We would especially appreciate suggestions on:
- chunking the time or subband dimension safely,
- memory-efficient alternatives to large replicated shift matrices,
- MATLAB GPU best practices for large cross-correlation problems.
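One way to chunk the lag dimension while keeping the per-row arithmetic of the original (same row slices, same mean/std along dimension 2, same matrix-vector product, same division by Rsamp3) is to build only lagChunk rows of FgcSM1 at a time, indexing directly into the Usb row instead of replicating it. A minimal sketch, with small stand-in sizes and random data in place of cDM0(jUsb,:) and gccDM1(:,kUsb); x, g, xc, lagChunk and useGPU are illustrative names. Because BLAS/cuBLAS may reduce in a different order for a smaller matrix, expect agreement to single-precision rounding rather than bit-for-bit equality.

```matlab
% Lag-tiled version of the FgcSM1 step for one (jUsb, kUsb) pair (sketch)
useGPU   = false;                 % set true on the cluster (Parallel Computing Toolbox)
NDsamp   = 64;  NDShf0 = 5;       % stand-ins; real run uses 129024 and 1601
NDShf2   = NDShf0 - 1;            % number of lag rows
Rsamp3   = NDsamp - (2*NDShf0 - 1);
Ng1      = NDShf0;
lagChunk = 3;                     % hypothetical tile size (a few hundred at full size)

% Random stand-ins for cDM0(jUsb,:) and a normalized gccDM1(:,kUsb)
x = complex(randn(1,NDsamp,'single'), randn(1,NDsamp,'single'));
g = complex(randn(Rsamp3,1,'single'), randn(Rsamp3,1,'single'));
g = (g - mean(g))./std(g);

xc = zeros(NDShf2,1,'single');    % one correlation value per lag row
if useGPU
    x = gpuArray(x); g = gpuArray(g); xc = gpuArray(xc);
end

nIdx = 1:Rsamp3;
for c0 = 1:lagChunk:NDShf2
    cIdx = (c0:min(c0+lagChunk-1, NDShf2)).';  % lag rows in this tile
    T = x(Ng1 - cIdx + nIdx);                  % those rows of FgcSM1, no repmat
    T = (T - mean(T,2))./std(T,0,2);           % same row normalization as the original
    xc(cIdx) = (T*g)/Rsamp3;                   % same product and scaling
end
absXC = abs(xc);                  % gather() before adding into Fabs when on the GPU
```

At full size with lagChunk = 200, each tile is about 200 × 125823 × 8 bytes ≈ 0.2 GB (the double index grid is the same size), so one jUsb step should need well under 1 GB on top of gccDM1. In the real loop the result would be accumulated with Fabs(:,kUsb) = Fabs(:,kUsb) + gather(absXC).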
==============================
Here is the code (MATLAB):
PGMname = 'PGM_251219_NoRev_NoHH_Abs_XC_R13_F0021_B_1_1'
'LINES WITH ************ '
'TO BE UPDATED '
'AS NEW FILES ARE PROCESSED'
% purpose: correlation of subbands
% 949 subband pairs
% 949 spacing Usb to Lsb
% No Hamming. Use full 390 kHz BW of each subband
% absolute values of XC array
% VBLK COMPLEX VOLTAGES, POLARIZATIONS 1 AND 2,
% LOADED TOGETHER, PROCESSED SEPARATELY
% FORWARD F AND REVERSE R IN SAME PGM
% ShfMat1 BOTH WAYS TO GET REVERSE
% COMPLEX SHIFT ARRAY Usb
% COMPLEX CONJUGATE DM1 Lsb
% NORMALIZE COMPLEX ARRAYS BEFORE XC
% ABS VALUE XC ARRAY SAVED
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% Run 13
% Update DFile and SAVE also !!!
BLOCK1 = 1;
BLOCKN = 1;
DFile = 21:21; % actual last digit file labels to be processed
% 13_0000 : 13_0002 for FF3 : FF5
% 13_0006 : 13_0009 for FF6 : FF9
% 13_0010 : 13_0015 for FF10: FF15
% 13_0016 : 13_0019 for FF16: FF19
% 13_0020 : 13_0023 for FF20: FF23
% 13_0024 : 13_0038 for FF24: FF38
NDFile = length(DFile);
%
% Previously, Files v1f1 and v1f1 were assigned F1,F2
% but I am now ignoring them since they had Noise Diode on
%
% 0000_3 through 0000_5 are missing
%
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
NDsamp = 129024; % original samples per channel
%
% FIRST LAG (SHIFT) = 0
% SM1(1600,NSsamp3) has range of shifts of 0-1599
% DM1 starts at channel jCh1=151+256
% of DatMat0
% gDM1 to be used for an entire Block (i.e. for all jUsb)
%
% keep FULL BLK for use to get ShfMat for each jUsb
% within current block
% so that SEGMENTS can be shifted to left sequentially from top
% lag (shift) increases with segment shift
% to EARLIER time as lag increases;
% all the multiple Lsb signals remain FIXED in time
%
% after first jUsb,
% cut one column from left of DM11
% FOR NEXT xcorr
%
% original BLK (2048,129024)
% original stdBLK(2048,129024)
%
% ===================== CIRCE PATHS (ONLY LOCATION CHANGES) =====================
workDir = "/work_bgfs/e/evansa";
inDir = "/work_bgfs/e/evansa/Data11-20-2025"; % input voltage data
outDir = "/work_bgfs/e/evansa"; % output directory
cd(workDir);
% ==============================================================================
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% NC version 251112
NUsb = 949; % based on jUsb=151:151+948 for NC version
NDchan = 2048; % original data channels (start as rows)
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% max(max(NLag1+NLags-1))=1599 from 250115
NDShf0 = 1601; % rows in repmat used to get ShfMat
NDShf2 = 1600; % rows in gShfMat1
NDsamp2 = NDsamp-1; % samp in intermediate matrix of ShfMat
NDsamp3 = NDsamp-NDShf0-1; % samples per channel in gShfMat1
Ntot0 = NDsamp*NDchan; % total original samples NDsamp*NDrows
Ntot2 = NDsamp*NDShf0; % total samples in initial repmat
% array for calculation of ShfMat
Ng1 = NDShf0; % first and last pts for gBB2 to gBB3
Ng2 = Ng1+NDsamp2*(NDShf0-1) - 1;
% for Reverse
% Rsamp2 = NDsamp+1 = 129025
% Rsamp3 = NDsamp - 3201 = 125823
% Rg2 = Ng1+Rsamp2*(NDShf0-1) - 1 = 206441600, vs
% Ntot2 = 206567424
Rsamp2 = 129025;
Rsamp3 = 125823;
Rg2 = 206441600; % Ng1=1601 (=NDShf0) is same as for Forward
%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jFile = DFile(1):DFile(NDFile) %FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
% FOR FOR FOR FOR
% FOR FOR FOR FOR
% rezero for each File, type double
Fabs = zeros(1600,949);
%*************************************************************************
%*************************************************************************
%*************************************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
data0 = "VBLK_75753_13_00" + string(jFile)+ "_";
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jBlk = BLOCK1:BLOCKN %BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
% FOR FOR FOR FOR
% FOR FOR FOR FOR
%*************************************************************************
%*************************************************************************
%*************************************************************************
data1 = data0 + string(jBlk) + ".mat"; % (CIRCE: keep same filename pattern)
% ===================== CIRCE INPUT (ONLY LOCATION CHANGE) =====================
data1_full = fullfile(inDir, data1);
% =============================================================================
%*************************************************************************
load(data1_full)
%************************************************************************
%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jPol = 1:2
% FOR FOR FOR FOR
% FOR FOR FOR FOR
%*************************************************************************
%*************************************************************************
%*************************************************************************
cd(workDir);
%GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG
reset(gpuDevice(1)) % for each Polarization
%GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG
if jPol==1
BLK=VBLK(:,:,1);
end
if jPol==2
BLK=VBLK(:,:,2);
end
%
% BLK is single(2048,129024) (Nchan, NDsamp) subband intensities
% highest subband RF (channel) top row
% subband number increases from top to bottom
%
% input BLK (2048,NDsamp) = BLK(2048,129024)
% need full BLK to get ShfMat
%
% complex normalization of each row (channel)
stdBLK = (BLK-mean(BLK,2))./std(BLK,0,2);
% rows dominant
% No HAMMING in NoHH version starting 251214
%HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH
% cDM0 and ccDM0 contain all channels and samples
cDM0 = stdBLK; % time domain (NDchan,NDsamp)
% normalize again after shift array is created
ccDM0 = conj(cDM0); % (NDchan, NDsamp)
% transpose, limit ranges,
% and normalize later
% cc= CONJUGATE COMPLEX array
% cDM0 and ccDM0 contain ALL channels and samples
%HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH
%
% use cDM0
% for construction of FgcSM1 and RgcSM1
%
% use ccDM0 for construction of gccDM1 for
% BOTH Fwd and Rev
%
% gccDM1 starts with channel 151+949
% gccDM1 always complex conjugate
%
% Lsb (gccDM1) arrays for each BLOCK will have
% first column now is original 151+949
% total columns in gccDM1 = 949 =jUsb= # of pairs
%
% Rsamp3 for both F and R
% samples in gccDM1 start at sample #
% Ng1 to maintain time synch with ShfMat
%
NR3 = Ng1+Rsamp3-1; % last sample of gccDM1
gccDM1 = gpuArray(transpose(ccDM0(151+949:2048,Ng1:NR3))); % (Rsamp3,949)
gccDM1 = (gccDM1-mean(gccDM1))./std(gccDM1) ;
%*************************************************************************
%*************************************************************************
%*************************************************************************
% FOR FOR FOR FOR
% FOR FOR FOR FOR
for jUsb = 151:151+948 % NC version
% FOR FOR FOR FOR
% FOR FOR FOR FOR
%*************************************************************************
%*************************************************************************
%*************************************************************************
[318 jFile jBlk jPol jUsb]
datetime
gBB1 = gpuArray(repmat(cDM0(jUsb,:),[NDShf0 1])); % repmat
gBB1 = reshape(transpose(gBB1),[1,Ntot2]); % reshape to linear
gBB2 = transpose((reshape(gBB1(Ng1: Ng2),[NDsamp2,NDShf2])));
FgcSM1 = gBB2(:,1:Rsamp3); % cut samples to Rsamp3
FgcSM1 = (FgcSM1-mean(FgcSM1,2))./std(FgcSM1,0,2);
kUsb = jUsb-150;
Fabs(:,kUsb)=Fabs(:,kUsb)+abs(gather((FgcSM1*gccDM1(:,kUsb))/Rsamp3));
end % jUsb
end % jPol
end % jBLK
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% save results
FFout = string(jFile);
data4 = "_" + string(BLOCK1) + "_" + string(BLOCKN) + ".mat";
data5F = "251219_Fabs_NoHH_13_" + FFout + data4;
% ===================== CIRCE OUTPUT (ONLY LOCATION CHANGE) ====================
save(fullfile(outDir, data5F), "Fabs", "-v7.3");
% ==============================================================================
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
% UPDATE AS NEEDED ********************************************************
cd(workDir);
end % jFile
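A more memory-efficient alternative to any explicit shift matrix, at the cost of bit-for-bit identity, uses the algebra of the normalized product. For lag row c with window w_c = x(Ng1-c+1 : Ng1-c+Rsamp3) and g = gccDM1(:,kUsb), the value added to Fabs is |S_c - mu_c*sum(g)| / (sigma_c*Rsamp3), where S_c = sum_n w_c(n)*g(n) and mu_c, sigma_c are the window mean and standard deviation (N-1 normalization, as in std(...,0,2)). All 1600 S_c values come from one FFT cross-correlation and mu_c, sigma_c from cumulative sums, so the largest array is a vector of about 128k samples. This is a sketch under that reading of the index arithmetic, with small stand-in sizes and random data; computed in double it should agree with the original single-precision result to rounding, though the running-sum variance can lose accuracy if a window's mean were large relative to its spread (after the channel normalization it should be near zero).

```matlab
% FFT + running-sum version of one (jUsb, kUsb) lag scan (sketch)
NDsamp = 64; NDShf0 = 5;                       % stand-ins; real run uses 129024 and 1601
NDShf2 = NDShf0 - 1; Rsamp3 = NDsamp - (2*NDShf0 - 1); Ng1 = NDShf0;
x = complex(randn(1,NDsamp,'single'), randn(1,NDsamp,'single'));  % stands in for cDM0(jUsb,:)
g = complex(randn(Rsamp3,1,'single'), randn(Rsamp3,1,'single'));
g = (g - mean(g))./std(g);                     % stands in for gccDM1(:,kUsb)

R  = Rsamp3;
xd = double(x(:));  gd = double(g(:));         % double precision for the running sums
a  = Ng1 + 1 - (1:NDShf2).';                   % first sample of the window for lag row c

% Window sums via cumulative sums: sum(x(a:a+R-1)) = cs(a+R) - cs(a)
cs  = [0; cumsum(xd)];
cs2 = [0; cumsum(abs(xd).^2)];
mu  = (cs(a+R) - cs(a))/R;
sig = sqrt(((cs2(a+R) - cs2(a)) - R*abs(mu).^2)/(R-1));

% S(c) = sum_n x(a(c)+n-1)*g(n) for every lag via one circular correlation;
% M >= numel(xs) is enough because only non-negative offsets are needed
aMin = a(end);
xs = xd(aMin : a(1)+R-1);
M  = 2^nextpow2(numel(xs));
Sk = ifft(fft(xs,M) .* conj(fft(conj(gd),M)));
S  = flipud(Sk(1:NDShf2));                     % offset k = a(c)-aMin maps to row c = NDShf2-k

absXC = abs((S - mu*sum(gd))./(sig*R));        % one Fabs column (before accumulation)
```

On the GPU the same lines should work after wrapping xd and gd in gpuArray (fft, ifft and cumsum accept gpuArray inputs), followed by gather on absXC. Since fft works column-wise on matrices, many jUsb/kUsb pairs could also be batched into one call, one column per pair, with the batch size chosen by the same memory arithmetic.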