# Fixing LLM Inference Bottlenecks with Custom CUDA Kernels: Part 2
Today, we identify a "simple" bottleneck, write a custom CUDA kernel for it, and prove that it moves the needle.
---
## 1. Finding the Low-Hanging Fruit
We are looking for **point-wise operations**. These are operations where each output element depends only on its corresponding input element (like addition or multiplication). In vanilla LLM implementations, they are among the most common sources of avoidable overhead, because each one typically runs as its own kernel that reads and writes the entire tensor.
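To make the definition concrete, here is a small pure-Python sketch (standard library only; the helper names and toy inputs are illustrative assumptions) that perturbs one input element and reports which outputs change. For a point-wise op like GELU only the matching output moves; for softmax, which normalizes over the whole row, every output moves, which is one reason softmax calls for more involved fused kernels.

```python
import math


def gelu(x: float) -> float:
    # Tanh-approximation GELU, as used by GPT-2
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def pointwise_gelu(xs: list[float]) -> list[float]:
    # Point-wise: output j depends only on input j
    return [gelu(v) for v in xs]


def softmax(xs: list[float]) -> list[float]:
    # Not point-wise: every output depends on the sum over all inputs
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


def changed_outputs(fn, xs: list[float], i: int, delta: float = 0.5) -> list[int]:
    """Perturb input i and return the indices of the outputs that change."""
    before = fn(xs)
    perturbed = list(xs)
    perturbed[i] += delta
    after = fn(perturbed)
    return [j for j, (a, b) in enumerate(zip(before, after)) if a != b]


xs = [-1.0, 0.0, 1.0, 2.0]
print(changed_outputs(pointwise_gelu, xs, 2))  # [2]
print(changed_outputs(softmax, xs, 2))         # [0, 1, 2, 3]
```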
### The GELU Problem
GPT-2 uses the tanh approximation of the GELU activation function:
$$\operatorname{GELU}(x) = 0.5x\left(1 + \tanh\left(\sqrt{2/\pi}\left(x + 0.044715x^3\right)\right)\right)$$
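The formula above is an approximation; the exact GELU is $x\Phi(x) = 0.5x\left(1 + \operatorname{erf}(x/\sqrt{2})\right)$, where $\Phi$ is the standard normal CDF. As a quick sanity check, this pure-Python sketch (standard library only; the sampling grid is an arbitrary choice) measures how closely the two agree:

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)


def gelu_tanh(x: float) -> float:
    # Tanh approximation (the formula GPT-2 uses)
    return 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)))


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


# Sample x on [-6, 6] in steps of 0.01 and record the largest absolute gap
grid = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in grid)
print(f"max |tanh approx - exact| on [-6, 6]: {max_err:.2e}")
```

The gap is small but nonzero, so a kernel should implement whichever variant the model was trained with; GPT-2 expects the tanh form.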
In Hugging Face `transformers` (the `NewGELUActivation` module), it looks like this:
```python
import math
import torch

def forward(self, input):
    return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
```
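**The Problem:** In eager mode, PyTorch launches 8 separate CUDA kernels every time this function is called: one each for `pow`, `tanh`, the four multiplications and additions involving constants, the `input + ...` addition, and the final element-wise multiply.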
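To see where the 8 kernels come from, here is a pure-Python sketch that evaluates the expression the way eager mode does: one full pass over the data per operation, each producing a new intermediate buffer, in the order Python evaluates the expression. `EagerTrace` and `run_op` are hypothetical stand-ins for kernel launches, not PyTorch internals.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # a plain Python float, so no kernel is launched for it


class EagerTrace:
    """Records one entry per whole-buffer pass, standing in for a kernel launch."""

    def __init__(self) -> None:
        self.launches: list[str] = []

    def run_op(self, name, fn, *buffers):
        # One "kernel": read each input buffer in full, write a brand-new output buffer
        self.launches.append(name)
        return [fn(*vals) for vals in zip(*buffers)]


def unfused_gelu(x, trace):
    half_x = trace.run_op("mul", lambda v: 0.5 * v, x)                 # 0.5 * input
    cube = trace.run_op("pow", lambda v: v ** 3.0, x)                  # torch.pow(input, 3.0)
    scaled = trace.run_op("mul", lambda v: 0.044715 * v, cube)         # 0.044715 * cube
    summed = trace.run_op("add", lambda a, b: a + b, x, scaled)        # input + scaled
    inner = trace.run_op("mul", lambda v: SQRT_2_OVER_PI * v, summed)  # sqrt(2/pi) * summed
    t = trace.run_op("tanh", math.tanh, inner)                         # torch.tanh(inner)
    one_plus_t = trace.run_op("add", lambda v: 1.0 + v, t)             # 1.0 + t
    return trace.run_op("mul", lambda a, b: a * b, half_x, one_plus_t)


trace = EagerTrace()
out = unfused_gelu([-1.0, 0.0, 1.0, 2.0], trace)
print(len(trace.launches), trace.launches)
# 8 ['mul', 'pow', 'mul', 'add', 'mul', 'tanh', 'add', 'mul']
```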
These operations are memory-bound: each kernel reads its inputs from VRAM and writes its result back, and the GPU spends far more time moving that data than doing the arithmetic. On an RTX 4070, this unoptimized GELU takes roughly 12.2% of total CUDA time, most of it spent on redundant memory traffic rather than math.
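A back-of-the-envelope count shows why. Counting each full read or write of an FP32 tensor as one transfer, the six single-input ops (`pow`, `tanh`, and the four operations with constants) cost 2 transfers each, and the two tensor-tensor ops (`input + ...` and the final multiply) cost 3 each: 18 transfers in total, versus 2 for a fused kernel. The sketch below assumes a hypothetical MLP activation of shape 1 × 1024 × 3072 (GPT-2 small's MLP width is 4 × 768 = 3072) and ignores L2 cache reuse, so real traffic can be lower.

```python
# (tensors read, tensors written) per op in the eager GELU
UNFUSED_OPS = {
    "0.5 * input": (1, 1),
    "pow(input, 3)": (1, 1),
    "0.044715 * t": (1, 1),
    "input + t": (2, 1),
    "sqrt(2/pi) * t": (1, 1),
    "tanh(t)": (1, 1),
    "1.0 + t": (1, 1),
    "a * b": (2, 1),
}
FUSED_OPS = {"fused_gelu": (1, 1)}

BYTES_PER_FP32 = 4


def traffic_bytes(ops: dict[str, tuple[int, int]], numel: int) -> int:
    """Global-memory bytes moved, assuming no cache reuse between kernels."""
    transfers = sum(reads + writes for reads, writes in ops.values())
    return transfers * numel * BYTES_PER_FP32


numel = 1 * 1024 * 3072  # hypothetical batch x seq_len x MLP width
MiB = 1024 ** 2
unfused = traffic_bytes(UNFUSED_OPS, numel)
fused = traffic_bytes(FUSED_OPS, numel)
print(f"unfused: {unfused / MiB:.0f} MiB, fused: {fused / MiB:.0f} MiB, ratio: {unfused / fused:.0f}x")
# unfused: 216 MiB, fused: 24 MiB, ratio: 9x
```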
## 2. The Solution: Kernel Fusion
We are going to "fuse" these 8 operations into a single kernel. Instead of 8 round trips through global memory, each thread reads its input once, performs all of the math in registers, and writes the result once.
### The CUDA Kernel (`fused_gelu.cu`)
We use `__device__ __forceinline__` so the compiler inlines the math directly into our kernel, eliminating function-call overhead.
```cuda
#include <cuda_runtime.h>
#include <math.h>

extern "C" {

// Tanh-approximation GELU for a single element; forced inline into the kernel.
__device__ __forceinline__ float gelu(float x) {
    const float sqrt_2_over_pi = 0.7978845608028654f;  // √(2/π)
    return 0.5f * x * (1.0f + tanhf(sqrt_2_over_pi * (x + 0.044715f * x * x * x)));
}

// One thread per element: read once, compute in registers, write once.
__global__ void fused_gelu_kernel(const float* __restrict__ input,
                                  float* __restrict__ output,
                                  int total) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < total) {  // the last block may contain threads past the end
        output[idx] = gelu(input[idx]);
    }
}

void launch_fused_gelu(const float* input, float* output, int total, cudaStream_t stream) {
    if (total == 0) return;  // a launch with 0 blocks is an invalid configuration
    const int threads = 256;
    const int blocks = (total + threads - 1) / threads;  // ceil(total / threads)
    fused_gelu_kernel<<<blocks, threads, 0, stream>>>(input, output, total);
}

}  // extern "C"
```
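The launch configuration uses ceiling division so every element gets a thread, and the `idx < total` guard keeps the surplus threads in the last block from writing out of bounds. This pure-Python sketch replays that index arithmetic on the host (no GPU involved; `simulate_launch` is a hypothetical helper) and checks that each element is written exactly once:

```python
def simulate_launch(total: int, threads: int = 256) -> list[int]:
    """Count how many threads write each element, mimicking fused_gelu_kernel."""
    blocks = (total + threads - 1) // threads  # same ceiling division as the launcher
    writes = [0] * total
    for block_idx in range(blocks):
        for thread_idx in range(threads):
            idx = thread_idx + block_idx * threads  # threadIdx.x + blockIdx.x * blockDim.x
            if idx < total:  # bounds guard from the kernel
                writes[idx] += 1
    return writes


for total in (1, 255, 256, 257, 1000):
    writes = simulate_launch(total)
    print(total, (total + 255) // 256, all(w == 1 for w in writes))
```

With `total == 0` the formula yields zero blocks, which CUDA rejects as an invalid launch configuration, so that case is worth handling on the host before launching.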
### The C++ Binding (`fused_gelu_binding.cpp`)
This is the glue code. It validates the input, launches the kernel on the current PyTorch CUDA stream, and registers the function as a custom operator, `fg::fused_gelu`, with a CUDA implementation.
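```cpp
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <limits>

// Implemented in fused_gelu.cu
extern "C" void launch_fused_gelu(const float*, float*, int, cudaStream_t);

torch::Tensor fused_gelu_launcher(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "Input must be on GPU");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Only FP32 supported");
    TORCH_CHECK(input.numel() <= std::numeric_limits<int>::max(),
                "Tensor too large for 32-bit indexing");
    // The kernel indexes memory as a flat array, so it needs a dense, contiguous buffer.
    input = input.contiguous();
    auto output = torch::empty_like(input);
    launch_fused_gelu(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(input.numel()),
        c10::cuda::getCurrentCUDAStream()  // run on PyTorch's current stream
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // surface launch errors immediately
    return output;
}

// Declare the operator schema in the "fg" namespace...
TORCH_LIBRARY(fg, m) {
    m.def("fused_gelu(Tensor input) -> Tensor");
}

// ...and register the CUDA implementation for it.
TORCH_LIBRARY_IMPL(fg, CUDA, m) {
    m.impl("fused_gelu", TORCH_FN(fused_gelu_launcher));
}
```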
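The kernel treats `data_ptr()` as a flat array of `numel()` floats, which is only valid for a dense, contiguous tensor. A strided view such as `x[::2]` shares storage with `x` but skips elements, so reading the first `numel()` floats from its data pointer returns the wrong values. The pure-Python sketch below models a 1-D view as (storage, offset, stride); `read_strided` and `read_flat` are hypothetical helpers. This is why calling `.contiguous()` (or rejecting such inputs with `TORCH_CHECK(input.is_contiguous(), ...)`) before the launch is a sensible safeguard.

```python
def read_strided(storage: list[float], offset: int, stride: int, numel: int) -> list[float]:
    # What the tensor logically contains: element i lives at offset + i * stride
    return [storage[offset + i * stride] for i in range(numel)]


def read_flat(storage: list[float], offset: int, numel: int) -> list[float]:
    # What a kernel sees if it indexes data_ptr() as a flat array of numel floats
    return storage[offset:offset + numel]


storage = [float(v) for v in range(8)]  # backing memory of x = [0.0, 1.0, ..., 7.0]
# A view equivalent to x[::2]: same storage, offset 0, stride 2, 4 elements
logical = read_strided(storage, offset=0, stride=2, numel=4)
flat = read_flat(storage, offset=0, numel=4)
print(logical)  # [0.0, 2.0, 4.0, 6.0]
print(flat)     # [0.0, 1.0, 2.0, 3.0]
```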
## 3. Deployment: Monkey Patching GPT-2
We compile the kernel on the fly using `torch.utils.cpp_extension.load`. To inject it into the model without rewriting the `transformers` library, we use monkey patching.
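```python
import types

import torch
from torch.utils.cpp_extension import load
from transformers import GPT2LMHeadModel

# Compile and load the custom op. The binding registers it via TORCH_LIBRARY
# (there is no pybind11 module), so load it as a plain shared library; the op
# then becomes available as torch.ops.fg.fused_gelu.
load(
    name="fused_gelu",
    sources=["fused_gelu.cu", "fused_gelu_binding.cpp"],
    extra_cuda_cflags=["--use_fast_math"],
    is_python_module=False,
    verbose=False,
)

def fused_mlp_forward(self, hidden_states):
    # GPT-2's Conv1D stores weights as (in_features, out_features), while
    # F.linear expects (out_features, in_features), hence the transpose.
    x = torch.nn.functional.linear(hidden_states, self.c_fc.weight.t(), self.c_fc.bias)
    x = torch.ops.fg.fused_gelu(x)  # replaces self.act(x)
    return self.dropout(self.c_proj(x))  # keep the original dropout (a no-op in eval mode)

# Patch every GPT-2 MLP module in the model
def patch_model(model):
    for module in model.modules():
        if type(module).__name__ == "GPT2MLP":
            module.forward = types.MethodType(fused_mlp_forward, module)

model = GPT2LMHeadModel.from_pretrained("gpt2").cuda().eval()
patch_model(model)
```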
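If `types.MethodType` is unfamiliar: it binds a plain function to one object, so assigning the result to `module.forward` overrides the method for that instance only, while the class and every other instance keep the original. Because `nn.Module.__call__` ends up calling `self.forward`, the instance attribute takes precedence over the class method. A minimal, torch-free sketch with a hypothetical `ToyMLP` class:

```python
import types


class ToyMLP:
    # Hypothetical stand-in for GPT2MLP
    def forward(self, x: float) -> float:
        return x + 1.0


def patched_forward(self, x: float) -> float:
    # Replacement body; `self` is bound to the patched instance
    return x * 10.0


a, b = ToyMLP(), ToyMLP()
a.forward = types.MethodType(patched_forward, a)  # patch one instance only
print(a.forward(2.0))  # 20.0
print(b.forward(2.0))  # 3.0
```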
## 4. Performance vs. Accuracy
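If you run this, you will notice a slight mismatch in the logits (<5%) compared with the unpatched model.
Why? The main culprit is the `--use_fast_math` flag. It lets `nvcc` replace standard math functions with faster but less precise hardware intrinsics (for example, `expf` becomes `__expf`), and it also flushes denormals to zero and uses approximate division and square root. For most LLM inference tasks, this precision hit is negligible compared with the savings on an operation that took roughly 12% of total CUDA time. If you need closer parity with PyTorch, remove that flag. Even then, bit-wise parity is not guaranteed: the fused kernel computes x³ as `x * x * x` rather than calling `pow`, and `nvcc` contracts multiply-adds into FMA instructions by default, so the last bits can still differ.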
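Rather than eyeballing logits, it helps to compare a candidate against a reference with explicit tolerances. The sketch below reimplements the `torch.allclose` criterion, |a - b| ≤ atol + rtol·|b|, in pure Python and uses it to compare GELU in emulated FP32 (every intermediate rounded to single precision via `struct`) against a double-precision reference. The tolerances and input range are illustrative assumptions, not PyTorch's defaults, and a fast-math build may need looser ones.

```python
import math
import struct


def to_f32(v: float) -> float:
    # Round a Python float (FP64) to the nearest FP32 value
    return struct.unpack("f", struct.pack("f", v))[0]


def gelu_f64(x: float) -> float:
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_f32(x: float) -> float:
    # Same formula, rounding every intermediate to FP32 (an emulation of an FP32 kernel)
    f = to_f32
    x, k, c = f(x), f(0.044715), f(math.sqrt(2.0 / math.pi))
    cube = f(f(x * x) * x)  # x * x * x, as in the CUDA kernel
    inner = f(c * f(x + f(k * cube)))
    t = f(math.tanh(inner))
    return f(f(0.5 * x) * f(1.0 + t))


def allclose(a: list[float], b: list[float], rtol: float = 1e-5, atol: float = 1e-6) -> bool:
    # Same criterion as torch.allclose: |a - b| <= atol + rtol * |b| for every pair
    return len(a) == len(b) and all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


xs = [i / 100.0 for i in range(-600, 601)]
ref = [gelu_f64(x) for x in xs]
fp32 = [gelu_f32(x) for x in xs]
max_abs = max(abs(p - r) for p, r in zip(fp32, ref))
print(f"max abs diff: {max_abs:.2e}, allclose: {allclose(fp32, ref)}")
```

In practice you would run the same comparison on the patched and unpatched models' logits, reporting both the maximum absolute difference and whether the tolerance check passes.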
## Conclusion
By fusing 8 kernels into 1, we eliminated unnecessary memory round-trips. This is how you optimize inference: stop looking at the high-level Python and start looking at what the GPU is actually doing.
**Next Step:** In Part 3, we move beyond point-wise ops and tackle Advanced Fused Kernels for Softmax and Attention.