@@ -353,6 +353,10 @@ void TCuda<AFloat>::CalculateConvWeightGradients(TCudaMatrix<AFloat> & weightGra
353353 const size_t filterSize = filterHeight * filterWidth;
354354 const size_t nLocalViewPixels = filterDepth * filterSize;
355355 R__ASSERT ( weightGradients.GetNcols () == nLocalViewPixels);
356+ R__ASSERT ( weightGradients.GetNrows () == depth);
357+ R__ASSERT ( df.size () == batchSize);
358+
359+
356360
357361 const size_t tempStrideRows = 1 ;
358362 const size_t tempStrideCols = 1 ;
@@ -361,46 +365,17 @@ void TCuda<AFloat>::CalculateConvWeightGradients(TCudaMatrix<AFloat> & weightGra
361365 const size_t tempZeroPaddingHeight = (height - inputHeight + filterHeight - 1 ) / 2 ;
362366 const size_t tempZeroPaddingWidth = (width - inputWidth + filterWidth - 1 ) / 2 ;
363367
364- std::vector< TCudaMatrix<AFloat> > vres;
365- for (size_t i = 0 ; i < batchSize; i++) {
366- vres.emplace_back (depth, nLocalViewPixels);
367- }
368-
369368 // Convolution.
370369 TCudaMatrix<AFloat> activationsPrime (nLocalViews, nLocalViewPixels);
370+ TCudaMatrix<AFloat> resPrime (depth, nLocalViewPixels);
371371 for (size_t event = 0 ; event < df.size (); event++) {
372372 Im2col (activationsPrime, activationsBackward[event], inputHeight, inputWidth, filterHeight, filterWidth,
373373 tempStrideRows, tempStrideCols, tempZeroPaddingHeight, tempZeroPaddingWidth);
374374
375- Multiply (vres[event], df[event], activationsPrime);
376- }
377-
378- dim3 blockDims = TDevice::BlockDims2D ();
379- dim3 gridDims = TDevice::GridDims2D (weightGradients);
380- cudaStream_t s = weightGradients.GetComputeStream ();
375+ Multiply (resPrime, df[event], activationsPrime);
381376
382- // Get raw pointers from a vector of matrices - this is more challenging than it sounds.
383- //
384- // Attention: While `TCudaMatrix.GetDataPointer() returns a pointer to device memory,
385- // std::vector (and its .data() raw pointer) resides on host memory. Therefore
386- // we need to manually copy these pointers to the device prior to invoking the kernel.
387-
388- const AFloat ** dB; // device pointer to device pointers.
389- const AFloat ** hB = new const AFloat * [batchSize]; // host pointer to device pointers.
390-
391- cudaMalloc (&dB, sizeof (AFloat *) * batchSize);
392- for (size_t i = 0 ; i < batchSize; ++i) {
393- hB[i] = vres[i].GetDataPointer ();
377+ TCuda<AFloat>::ScaleAdd (weightGradients, resPrime, 1.0 );
394378 }
395-
396- cudaMemcpy (dB, hB, sizeof (AFloat *) * batchSize, cudaMemcpyHostToDevice);
397-
398- // Launch the kernel using our device pointers.
399- ::TMVA::DNN::Cuda::UpdateWeights<<<gridDims, blockDims>>> (weightGradients.GetDataPointer (), dB, batchSize,
400- depth, nLocalViewPixels);
401-
402- delete [] hB;
403- cudaFree (dB);
404379}
405380
406381// ____________________________________________________________________________
0 commit comments