Kernel Fusion

Kernel Fusion is a GPU optimization technique that combines multiple sequential operations into a single kernel, eliminating intermediate tensor materialization and reducing memory I/O. Instead of writing results to slow off-chip HBM between operations, fused kernels keep intermediate data in fast on-chip SRAM.

Avoiding the Memory Round-Trip

Imagine computing y = log(relu(x + b)) naively: each operation reads from and writes to slow GPU memory. With fusion, we load x and b once, compute everything in fast cache, and write only y. The speedup comes from avoiding the “memory tax” on each operation.
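The round-trip can be modeled in plain Python (an illustrative sketch, not GPU code): lists stand in for tensors materialized in HBM, and local variables stand in for registers/SRAM.

```python
import math

def unfused(x, b):
    # Each list is a materialized intermediate -- a stand-in for an
    # HBM round-trip between three separate kernels.
    z = [xi + b for xi in x]            # kernel 1: z = x + b
    a = [max(zi, 0.0) for zi in z]      # kernel 2: a = relu(z)
    return [math.log(ai) for ai in a]   # kernel 3: y = log(a)
    # (assumes relu outputs are positive, as in the example above)

def fused(x, b):
    # One pass: z and a exist only as local values; only y is "written".
    return [math.log(max(xi + b, 0.0)) for xi in x]
```

Both functions compute the same y; the fused version simply never builds the z and a lists.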

The Problem: Intermediate Tensor Materialization

┌─────────────────────────────────────────────────────────────┐
│                 Naive (Unfused) Execution                    │
│                                                              │
│   Operation 1: z = x + b                                     │
│   ┌─────────┐     ┌─────────┐     ┌─────────┐               │
│   │  HBM    │ ──► │ Compute │ ──► │  HBM    │  (write z)    │
│   │ (read x,b)    │  x + b  │     │ (store z)│               │
│   └─────────┘     └─────────┘     └─────────┘               │
│                                        │                     │
│   Operation 2: a = relu(z)             ▼                     │
│   ┌─────────┐     ┌─────────┐     ┌─────────┐               │
│   │  HBM    │ ──► │ Compute │ ──► │  HBM    │  (write a)    │
│   │ (read z)│     │ relu(z) │     │ (store a)│               │
│   └─────────┘     └─────────┘     └─────────┘               │
│                                        │                     │
│   Operation 3: y = log(a)              ▼                     │
│   ┌─────────┐     ┌─────────┐     ┌─────────┐               │
│   │  HBM    │ ──► │ Compute │ ──► │  HBM    │  (write y)    │
│   │ (read a)│     │ log(a)  │     │ (store y)│               │
│   └─────────┘     └─────────┘     └─────────┘               │
│                                                              │
│   Total HBM transfers: 7 (read x, b, z, a; write z, a, y)   │
└─────────────────────────────────────────────────────────────┘

The Solution: Fused Kernel

┌─────────────────────────────────────────────────────────────┐
│                   Fused Execution                            │
│                                                              │
│   ┌─────────┐     ┌─────────────────────────┐  ┌─────────┐ │
│   │  HBM    │ ──► │        SRAM             │  │  HBM    │ │
│   │(read x,b)     │  z = x + b              │  │(write y)│ │
│   └─────────┘     │  a = relu(z) [in SRAM]  │  └─────────┘ │
│                   │  y = log(a)  [in SRAM]  │──►           │
│                   └─────────────────────────┘               │
│                                                              │
│   Total HBM transfers: 3 (read x, b; write y)               │
│   Result: less than half the HBM traffic of the unfused path│
└─────────────────────────────────────────────────────────────┘

Mathematical Formulation

Memory Savings from Fusion

For a sequence of k elementwise operations on a tensor of n elements, each s bytes wide:

Unfused: ≈ 2kns bytes transferred (each operation reads its input and writes its output)

Fused: ≈ 2ns bytes transferred (read the input once, write the final output once), a k-fold reduction in HBM traffic
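As a quick sanity check of this traffic model (a toy calculation; the function name and simplifications, such as ignoring small extra inputs like a bias, are my own):

```python
def hbm_bytes(k, n, dtype_bytes=4, fused=False):
    # Unfused: each of the k ops reads one n-element tensor and writes
    # one back. Fused: one read of the input, one write of the output.
    passes = 2 if fused else 2 * k
    return passes * n * dtype_bytes

# 3 elementwise ops on 1M fp32 elements:
unfused_mb = hbm_bytes(3, 1_000_000) / 1e6              # 24.0 MB moved
fused_mb = hbm_bytes(3, 1_000_000, fused=True) / 1e6    # 8.0 MB moved
```

For this chain the model predicts a 3x reduction in bytes moved, matching the k-fold factor above.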

Key Properties

  • Eliminates Intermediate Storage: Tensors between operations never hit slow HBM
  • Reduces Kernel Launch Overhead: One launch instead of many
  • Improves Cache Utilization: Data stays in fast SRAM/registers
  • Enables Online Algorithms: Reductions can be computed incrementally
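The last bullet is the property that makes FlashAttention-style fusion possible: a reduction such as log-sum-exp can be computed in one streaming pass with a running maximum, so the full vector of values is never materialized. A minimal sketch of the idea (illustrative Python, not kernel code):

```python
import math

def online_logsumexp(stream):
    # Running max m and running sum s of exp(v - m); when a new max
    # arrives, the old sum is rescaled instead of being recomputed.
    m, s = float("-inf"), 0.0
    for v in stream:
        if v > m:
            s = s * math.exp(m - v) + 1.0   # rescale old sum to new max
            m = v
        else:
            s += math.exp(v - m)
    return m + math.log(s)
```

Each value is consumed as soon as it is produced, which is exactly what lets a fused kernel feed a reduction without an HBM round-trip.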

Common Fusion Patterns

Pattern               Operations            Example
Activation Fusion     Linear + Activation   relu(Wx + b)
Normalization Fusion  Stats + Normalize     LayerNorm, BatchNorm
Attention Fusion      QK^T + Softmax + V    FlashAttention
Reduction Fusion      Matmul + Max/Sum      Sparton
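The first pattern, activation fusion, can be sketched in plain Python (a conceptual model, not a GPU kernel): the pre-activation Wx + b is accumulated per output element and never stored as its own tensor.

```python
def fused_linear_relu(W, x, b):
    # For each output row: dot product, bias add, and ReLU in one pass.
    # The intermediate Wx + b lives only in the accumulator `acc`.
    out = []
    for row, bias in zip(W, b):
        acc = bias
        for w, xi in zip(row, x):
            acc += w * xi
        out.append(acc if acc > 0.0 else 0.0)
    return out
```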

Variants

Manual Fusion (Triton/CUDA)

Write custom kernels that combine operations explicitly. Maximum control but requires GPU programming expertise.

Compiler Fusion (XLA, TorchScript)

Automatic fusion by ML compilers. Less effort but may miss optimization opportunities.

Framework-Level Fusion

Libraries like FlashAttention provide pre-fused implementations of common patterns.

Example: SPLADE LM Head

Naive vs. Fused SPLADE

Naive: separate kernels for each operation. The full intermediate tensors (notably the sequence-by-vocabulary logits) are materialized in HBM between steps.

Fused (Sparton): a single kernel with an online reduction. No large intermediates are ever written to HBM.
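A plain-Python model of the online reduction (a hedged sketch of the idea, not the actual Sparton kernel): SPLADE's term weight is a max over sequence positions of log(1 + relu(logit)), and with a running max each logit can be consumed as soon as it is produced, so the sequence-by-vocabulary logit matrix never exists.

```python
import math

def splade_weights_fused(hidden, emb):
    # hidden: list of per-position hidden vectors (the sequence)
    # emb:    list of per-term output embedding vectors (the vocab)
    # w[j] = max over positions t of log(1 + relu(h_t . e_j)),
    # kept as a running max instead of a materialized logit matrix.
    weights = [0.0] * len(emb)
    for h in hidden:                     # stream over sequence positions
        for j, e in enumerate(emb):
            logit = sum(hi * ei for hi, ei in zip(h, e))
            val = math.log1p(max(logit, 0.0))
            if val > weights[j]:
                weights[j] = val
    return weights
```

A real fused kernel does the same thing tile by tile in SRAM; the running-max accumulator is what removes the large intermediate.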

Connections

Appears In