GPU-Accelerated Convolution Operations with Shared Memory Optimization
This tutorial explores efficient implementation of convolutional operations on GPUs using CUDA with a focus on shared memory optimization. Convolutional Neural Networks (CNNs) are a cornerstone of modern deep learning for computer vision, and the convolution operation accounts for the majority of their computational workload. This makes it a prime target for GPU acceleration.
You can find the code at https://github.com/eunomia-bpf/basic-cuda-tutorial
Table of Contents
- Introduction to Convolutional Neural Networks
- Convolution Operation
- Implementation Approaches
- Direct Convolution
- Shared Memory Optimization
- Additional CNN Components
- Activation Functions
- Pooling Layers
- Performance Analysis
- Further Optimization Techniques
Introduction to Convolutional Neural Networks
Convolutional Neural Networks (CNNs) have revolutionized computer vision by capturing spatial patterns in image data. They are engineered to automatically learn spatial hierarchies of features through backpropagation, using building blocks such as:
- Convolutional layers - apply learnable filters to input data
- Activation functions - introduce non-linearity (typically ReLU)
- Pooling layers - reduce spatial dimensions
- Fully-connected layers - perform classification based on extracted features
The convolutional layer is the core building block of a CNN, which is why optimizing its performance is critical for efficient deep learning applications.
Convolution Operation
Mathematical Definition
The 2D convolution operation is defined as:

output[b][k][y][x] = sum over c, ky, kx of input[b][c][y*s - p + ky][x*s - p + kx] * kernels[k][c][ky][kx]

Where:
- b is the batch index
- c is the input channel index
- k is the output channel index (kernel number)
- x, y are spatial coordinates
- kx, ky are kernel positions
- s is the stride
- p is the padding
Dimensions and Memory Layout
For a typical convolution operation:
- Input shape: [batch_size, in_channels, height, width]
- Kernels shape: [out_channels, in_channels, kernel_height, kernel_width]
- Output shape: [batch_size, out_channels, out_height, out_width]
Where:
out_height = (height + 2*padding - kernel_height) / stride + 1
out_width = (width + 2*padding - kernel_width) / stride + 1
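The division in these formulas is integer (floor) division. As a quick sanity check, the output size can be computed with a small host-side helper; this is a hypothetical convenience function for illustration, not part of the tutorial code:

int convOutputSize(int inputSize, int kernelSize, int padding, int stride) {
    // Integer division implements the floor in the formula above.
    return (inputSize + 2 * padding - kernelSize) / stride + 1;
}

// Example: a 32x32 input with a 5x5 kernel, padding 2, stride 1 -> 32x32 output.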
Implementation Approaches
Direct Convolution
The naive implementation of convolution directly maps the mathematical definition to code:
__global__ void convolutionDirectKernel(
    float *input, float *kernels, float *output,
    int batchSize, int inputChannels, int inputSize,
    int kernelSize, int kernelCount, int outputSize,
    int padding, int stride)
{
    // Calculate output position
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    int k = blockIdx.z;  // Output channel (kernel number)
    int b = threadIdx.z; // Batch index

    // Skip out-of-bounds threads
    if (x >= outputSize || y >= outputSize || k >= kernelCount || b >= batchSize)
        return;

    // Compute convolution for this output position
    float sum = 0.0f;

    // For each input channel
    for (int c = 0; c < inputChannels; c++) {
        // For each kernel position
        for (int ky = 0; ky < kernelSize; ky++) {
            for (int kx = 0; kx < kernelSize; kx++) {
                // Input position
                int in_x = x * stride - padding + kx;
                int in_y = y * stride - padding + ky;

                // Skip if input position is outside the input
                if (in_x >= 0 && in_x < inputSize && in_y >= 0 && in_y < inputSize) {
                    // Input value
                    float in_val = input[
                        b * inputChannels * inputSize * inputSize +
                        c * inputSize * inputSize +
                        in_y * inputSize + in_x
                    ];

                    // Kernel value
                    float kernel_val = kernels[
                        k * inputChannels * kernelSize * kernelSize +
                        c * kernelSize * kernelSize +
                        ky * kernelSize + kx
                    ];

                    // Accumulate result
                    sum += in_val * kernel_val;
                }
            }
        }
    }

    // Store output
    output[
        b * kernelCount * outputSize * outputSize +
        k * outputSize * outputSize +
        y * outputSize + x
    ] = sum;
}
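As a rough sketch of how this kernel might be launched (the device pointers d_input, d_kernels, and d_output and the 16×16 block size are assumptions for illustration; mapping the batch to threadIdx.z only works while blockDim.x * blockDim.y * batchSize stays within the 1,024 threads-per-block limit, i.e. batchSize ≤ 4 with 16×16 blocks):

// Hypothetical launch configuration for convolutionDirectKernel.
dim3 block(16, 16, batchSize);   // batch mapped to threadIdx.z (batchSize <= 4 here)
dim3 grid((outputSize + block.x - 1) / block.x,
          (outputSize + block.y - 1) / block.y,
          kernelCount);          // one grid slice per output channel
convolutionDirectKernel<<<grid, block>>>(
    d_input, d_kernels, d_output,
    batchSize, inputChannels, inputSize,
    kernelSize, kernelCount, outputSize,
    padding, stride);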
Characteristics of Direct Convolution:
- Simple and straightforward implementation
- Each thread computes a single output element
- High global memory access redundancy
- Low arithmetic intensity
Shared Memory Optimization
The key insight for optimization is that adjacent output elements reuse many of the same input values. By loading input data into shared memory once and reusing it for multiple computations, we can significantly reduce global memory accesses:
__global__ void convolutionSharedKernel(
    float *input, float *kernels, float *output,
    int batchSize, int inputChannels, int inputSize,
    int kernelSize, int kernelCount, int outputSize,
    int padding, int stride)
{
    // Shared memory for input tile
    extern __shared__ float sharedData[];

    // Calculate tile dimensions
    int tileSize = blockDim.x;
    int tileSizeWithPadding = tileSize + kernelSize - 1;

    // Thread and block indices
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int bx = blockIdx.x;
    int by = blockIdx.y;
    int k = blockIdx.z;
    int b = threadIdx.z;

    // Load input data to shared memory
    // ...

    // Compute convolution using shared memory
    // ...
}
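The elided loading and compute steps are easiest to see in a stripped-down form. The sketch below is a minimal, hypothetical single-channel, single-image, stride-1 version (the names and simplifications are mine, not the repository's implementation). It assumes blockDim.x == blockDim.y == tileSize and that the launch supplies (tileSize + kernelSize - 1)^2 * sizeof(float) bytes of dynamic shared memory; it shows the cooperative load of a tile with a halo of kernelSize - 1 and the reuse of that tile by every thread in the block:

// Hypothetical sketch: tiled 2D convolution, single channel, single image, stride 1.
__global__ void convolutionTiledSketch(
    const float *input, const float *kernel, float *output,
    int inputSize, int kernelSize, int outputSize, int padding)
{
    extern __shared__ float tile[];

    int tileSize = blockDim.x;                      // output elements per block edge
    int tileWithHalo = tileSize + kernelSize - 1;   // input elements needed per block edge

    int outX = blockIdx.x * tileSize + threadIdx.x; // output coordinates
    int outY = blockIdx.y * tileSize + threadIdx.y;

    // Top-left corner of the input region this block needs (may be negative due to padding)
    int originX = blockIdx.x * tileSize - padding;
    int originY = blockIdx.y * tileSize - padding;

    // Cooperatively load the (tileWithHalo x tileWithHalo) input patch,
    // each thread loading several elements in a strided loop.
    for (int i = threadIdx.y; i < tileWithHalo; i += blockDim.y) {
        for (int j = threadIdx.x; j < tileWithHalo; j += blockDim.x) {
            int inX = originX + j;
            int inY = originY + i;
            tile[i * tileWithHalo + j] =
                (inX >= 0 && inX < inputSize && inY >= 0 && inY < inputSize)
                    ? input[inY * inputSize + inX]
                    : 0.0f;  // zero padding
        }
    }
    __syncthreads();

    // Each thread computes one output element from the shared tile.
    if (outX < outputSize && outY < outputSize) {
        float sum = 0.0f;
        for (int ky = 0; ky < kernelSize; ky++)
            for (int kx = 0; kx < kernelSize; kx++)
                sum += tile[(threadIdx.y + ky) * tileWithHalo + (threadIdx.x + kx)] *
                       kernel[ky * kernelSize + kx];
        output[outY * outputSize + outX] = sum;
    }
}

A full version would additionally loop over input channels and handle the batch and output-channel indexing shown in the skeleton above.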
Advantages of Shared Memory Approach:
1. Reduced Global Memory Access: Each input element is loaded from global memory only once, then reused multiple times from shared memory.
2. Improved Memory Access Patterns: Threads in a block access contiguous memory locations.
3. Increased Arithmetic Intensity: More computation per global memory access.
Additional CNN Components
Activation Functions
Activation functions introduce non-linearity into the network. The ReLU (Rectified Linear Unit) is the most common activation function in CNNs:
__global__ void reluActivationKernel(float *data, int size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        data[idx] = fmaxf(0.0f, data[idx]);
    }
}
ReLU is extremely parallel-friendly as each output element depends only on a single input element.
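For completeness, a typical launch might look like this, assuming d_data is a device buffer holding n floats (both names are placeholders for illustration):

// Hypothetical launch: one thread per element.
int threads = 256;
int blocks = (n + threads - 1) / threads;
reluActivationKernel<<<blocks, threads>>>(d_data, n);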
Pooling Layers
Pooling reduces the spatial dimensions of the feature maps, decreasing computational load and providing some translation invariance:
__global__ void maxPoolingKernel(
    float *input, float *output,
    int batchSize, int channels, int inputSize,
    int poolSize, int outputSize, int stride)
{
    // ...

    // Find maximum value in pooling window
    for (int dy = 0; dy < poolSize; dy++) {
        for (int dx = 0; dx < poolSize; dx++) {
            // ...
            maxVal = fmaxf(maxVal, value);
        }
    }

    // Store output
    // ...
}
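Filling in the elided pieces, a complete max-pooling kernel could look like the sketch below. The indexing mirrors the convolution kernels above; the bounds check and the -FLT_MAX initialization are my assumptions about the omitted code, not a copy of the repository's version:

#include <cfloat>  // for FLT_MAX

__global__ void maxPoolingKernelSketch(
    float *input, float *output,
    int batchSize, int channels, int inputSize,
    int poolSize, int outputSize, int stride)
{
    // Output position; channel on blockIdx.z, batch on threadIdx.z
    // (the same scheme as the convolution kernels)
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    int c = blockIdx.z;
    int b = threadIdx.z;

    if (x >= outputSize || y >= outputSize || c >= channels || b >= batchSize)
        return;

    float maxVal = -FLT_MAX;

    // Find maximum value in pooling window
    for (int dy = 0; dy < poolSize; dy++) {
        for (int dx = 0; dx < poolSize; dx++) {
            int in_x = x * stride + dx;
            int in_y = y * stride + dy;
            if (in_x < inputSize && in_y < inputSize) {
                float value = input[
                    b * channels * inputSize * inputSize +
                    c * inputSize * inputSize +
                    in_y * inputSize + in_x
                ];
                maxVal = fmaxf(maxVal, value);
            }
        }
    }

    // Store output
    output[
        b * channels * outputSize * outputSize +
        c * outputSize * outputSize +
        y * outputSize + x
    ] = maxVal;
}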
Performance Analysis
Our implementation compares two convolution approaches:
- Direct Convolution: Baseline implementation with each thread computing one output element.
- Shared Memory Convolution: Optimized implementation that loads input tiles to shared memory.
Typical performance improvements from shared memory optimization:
- For 5×5 kernels: 2-4× speedup
- For larger kernels: 3-7× speedup
- For multiple input channels: Even greater speedup
Memory Access Analysis
For a direct convolution with kernel size K×K:
- Each output element requires K×K input elements
- For an N×N output, that's N×N×K×K global memory accesses

With shared memory optimization:
- Each input element is loaded to shared memory once
- For an M×M tile (with M threads per dimension), we load (M+K-1)×(M+K-1) elements
- Total global memory accesses: (M+K-1)×(M+K-1) per tile
The reduction in global memory accesses can be substantial, especially for larger kernel sizes. For example, with a 16×16 tile (M = 16) and a 5×5 kernel (K = 5), computing 256 outputs directly requires 256 × 25 = 6,400 loads per channel, whereas the shared-memory version loads only 20 × 20 = 400 elements per tile, a 16× reduction.
Further Optimization Techniques
Beyond the shared memory optimization shown in this example, several other techniques can further accelerate CNN operations:
- Kernel Fusion: Combining convolution, bias addition, and activation into a single kernel to reduce kernel launch overhead and memory transactions (a minimal sketch follows this list).
- Winograd Algorithm: Reduces the number of multiplications needed for small kernel sizes (e.g., 3×3) at the cost of additional additions.
- FFT-based Convolution: For large kernel sizes, using the Fast Fourier Transform can accelerate convolution.
- Im2Col + GEMM: Reformatting the convolution operation as a matrix multiplication to leverage highly optimized GEMM libraries.
- Quantization: Using lower precision (INT8, FP16) to increase arithmetic throughput and reduce memory bandwidth requirements.
- Tensor Cores: On modern NVIDIA GPUs, utilizing Tensor Cores for mixed-precision matrix multiplications.
- Kernel Decomposition: Decomposing larger kernels into separable 1D filters when possible (e.g., 5×5 → 5×1 followed by 1×5).
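To make the kernel fusion idea concrete, here is a minimal, hypothetical sketch that fuses the bias addition and ReLU that typically follow a convolution into a single elementwise pass over the output (the convolution itself is assumed to have run in a preceding kernel; the names and the NCHW layout assumption are mine):

// Hypothetical fused epilogue: bias add + ReLU in one pass over the
// convolution output (NCHW layout), replacing two separate kernels.
// total = batchSize * channels * spatialSize, spatialSize = outH * outW.
__global__ void biasReluFusedKernel(float *data, const float *bias,
                                    int channels, int spatialSize, int total)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        int c = (idx / spatialSize) % channels;  // channel of this element
        data[idx] = fmaxf(0.0f, data[idx] + bias[c]);
    }
}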
Conclusion
Efficient implementation of convolution operations is crucial for CNN performance. By leveraging GPU shared memory, we can significantly reduce global memory accesses and improve throughput. The optimization techniques demonstrated in this example represent foundational approaches that modern deep learning frameworks build upon.
For production applications, it's generally recommended to use optimized libraries like cuDNN, which implement many of these optimizations (and more) with architecture-specific tuning. However, understanding the underlying principles of efficient convolution is valuable for custom implementations and future optimizations.