Skip to content

Commit 1a173bd

Browse files
monetalmoneta
authored and committed
Simply use ScaleAdd when updating weight gradients
1 parent 0881dec commit 1a173bd

File tree

1 file changed

+7
-32
lines changed

1 file changed

+7
-32
lines changed

tmva/tmva/src/DNN/Architectures/Cuda/Propagation.cu

Lines changed: 7 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -353,6 +353,10 @@ void TCuda<AFloat>::CalculateConvWeightGradients(TCudaMatrix<AFloat> & weightGra
353353
const size_t filterSize = filterHeight * filterWidth;
354354
const size_t nLocalViewPixels = filterDepth * filterSize;
355355
R__ASSERT( weightGradients.GetNcols() == nLocalViewPixels);
356+
R__ASSERT( weightGradients.GetNrows() == depth);
357+
R__ASSERT( df.size() == batchSize);
358+
359+
356360

357361
const size_t tempStrideRows = 1;
358362
const size_t tempStrideCols = 1;
@@ -361,46 +365,17 @@ void TCuda<AFloat>::CalculateConvWeightGradients(TCudaMatrix<AFloat> & weightGra
361365
const size_t tempZeroPaddingHeight = (height - inputHeight + filterHeight - 1) / 2;
362366
const size_t tempZeroPaddingWidth = (width - inputWidth + filterWidth - 1) / 2;
363367

364-
std::vector< TCudaMatrix<AFloat> > vres;
365-
for (size_t i = 0; i < batchSize; i++) {
366-
vres.emplace_back(depth, nLocalViewPixels);
367-
}
368-
369368
// Convolution.
370369
TCudaMatrix<AFloat> activationsPrime(nLocalViews, nLocalViewPixels);
370+
TCudaMatrix<AFloat> resPrime(depth, nLocalViewPixels);
371371
for(size_t event = 0; event < df.size(); event++) {
372372
Im2col(activationsPrime, activationsBackward[event], inputHeight, inputWidth, filterHeight, filterWidth,
373373
tempStrideRows, tempStrideCols, tempZeroPaddingHeight, tempZeroPaddingWidth);
374374

375-
Multiply(vres[event], df[event], activationsPrime);
376-
}
377-
378-
dim3 blockDims = TDevice::BlockDims2D();
379-
dim3 gridDims = TDevice::GridDims2D(weightGradients);
380-
cudaStream_t s = weightGradients.GetComputeStream();
375+
Multiply(resPrime, df[event], activationsPrime);
381376

382-
// Get raw pointers from a vector of matrices - this is more challenging than it sounds.
383-
//
384-
// Attention: While `TCudaMatrix.GetDataPointer() returns a pointer to device memory,
385-
// std::vector (and its .data() raw pointer) resides on host memory. Therefore
386-
// we need to manually copy these pointers to the device prior to invoking the kernel.
387-
388-
const AFloat ** dB; // device pointer to device pointers.
389-
const AFloat ** hB = new const AFloat * [batchSize]; // host pointer to device pointers.
390-
391-
cudaMalloc(&dB, sizeof(AFloat *) * batchSize);
392-
for(size_t i = 0; i < batchSize; ++i) {
393-
hB[i] = vres[i].GetDataPointer();
377+
TCuda<AFloat>::ScaleAdd(weightGradients, resPrime, 1.0);
394378
}
395-
396-
cudaMemcpy(dB, hB, sizeof(AFloat *) * batchSize, cudaMemcpyHostToDevice);
397-
398-
// Launch the kernel using our device pointers.
399-
::TMVA::DNN::Cuda::UpdateWeights<<<gridDims, blockDims>>>(weightGradients.GetDataPointer(), dB, batchSize,
400-
depth, nLocalViewPixels);
401-
402-
delete [] hB;
403-
cudaFree(dB);
404379
}
405380

406381
//____________________________________________________________________________

0 commit comments

Comments
 (0)