Sparton

Sparton (Sparse Triton) is a highly optimized implementation of the SPLADE LM head using Triton kernels. It achieves ~5x speedup and ~4x memory reduction by combining operator reordering, online reduction, tiled matrix multiplication, and sparse gradient computation to eliminate the massive intermediate tensor.

Never Build What You Don't Need

The key insight is that SPLADE computes a huge intermediate tensor only to immediately reduce it with Max. Sparton reorders operations so the Max happens during computation, tile by tile, never materializing the full tensor. It’s like computing a running maximum while streaming through data instead of storing everything first.

The Problem: SPLADE Memory Bottleneck

Standard SPLADE LM Head

$$Y_j = \max_{i=1}^{S} \log\left(1 + \mathrm{ReLU}\left((HE^T)_{i,j} + b_j\right)\right)$$

where:

  • $H \in \mathbb{R}^{B \times S \times d}$ — hidden states
  • $E \in \mathbb{R}^{|V| \times d}$ — vocabulary embeddings
  • The intermediate tensor $HE^T \in \mathbb{R}^{B \times S \times |V|}$ is massive

Memory explosion example:

  • In fp32 the intermediate tensor costs $B \times S \times |V| \times 4$ bytes
  • Even a modest configuration (e.g. $B = 16$, $S = 512$, $|V| = 30522$) gives $\approx$ 1 GB
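The arithmetic behind these sizes is easy to check. A quick Python sketch, using the benchmark configuration from the results below ($B = 320$, $S = 512$, $|V| = 30522$) and assuming fp32:

```python
# Memory footprint of the full SPLADE intermediate vs. the reduced output.
# Shapes follow the benchmark configuration; fp32 = 4 bytes per element.
B, S, V, BYTES = 320, 512, 30522, 4

intermediate = B * S * V * BYTES   # full [B, S, |V|] logits tensor
reduced = B * V * BYTES            # [B, |V|] tensor after the max over S

print(f"intermediate: {intermediate / 2**30:.1f} GiB")  # 18.6 GiB
print(f"reduced:      {reduced / 2**20:.1f} MiB")       # 37.3 MiB
print(f"ratio:        {intermediate // reduced}x")      # 512x (= S)
```

The ratio between the two is exactly $S$: the sequence dimension is what the reduction eliminates.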

Innovation 1: Operator Reordering

Sparton Reordered LM Head

$$Y_j = \log\left(1 + \mathrm{ReLU}\left(\max_{i=1}^{S}\left((HE^T)_{i,j} + b_j\right)\right)\right)$$

The Max moves inside the monotonic functions (ReLU, Log1p), reducing the tensor the elementwise operations see from $B \times S \times |V|$ to $B \times |V|$.

Why This Works

For a monotonically non-decreasing function $f$: $\max_i f(x_i) = f(\max_i x_i)$. Since ReLU and $\log(1+x)$ are both monotonically non-decreasing (the latter on $x \ge 0$, which is all ReLU can produce), their composition commutes with the max and we can swap the order.

Standard:  MatMul → Mask → ReLU → Log1p → Max
           [B×S×|V|]      [B×S×|V|]      [B×|V|]
                   ↑ All stored in memory

Sparton:   MatMul → Mask → Max → ReLU → Log1p
           [B×S×|V|]     [B×|V|]  (small!)
                   ↑ Never fully materialized
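The equivalence is easy to verify numerically. A minimal NumPy sketch, with random scores standing in for $HE^T$:

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 16, 100))  # stand-in for H @ E.T: [B, S, |V|]

# Standard order: elementwise ReLU/Log1p on the full tensor, then max over S
standard = np.log1p(np.maximum(scores, 0)).max(axis=1)

# Reordered: max over S first, then ReLU/Log1p on the small [B, |V|] result
reordered = np.log1p(np.maximum(scores.max(axis=1), 0))

assert np.allclose(standard, reordered)  # identical, by monotonicity
```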

Innovation 2: Online Reduction / Tiled MatMul

Instead of computing the full tensor, Sparton processes vocabulary tiles with a running maximum:

for each vocabulary tile (size T):
    1. Compute partial: P = H × E^T[tile]  → [B × S × T]
    2. Apply mask
    3. Update running max: max_acc = max(max_acc, P.max(dim=S))
    4. Discard P (don't store!)

Result: Only store max_acc [B × |V|], never full [B × S × |V|]

Online Reduction

An online algorithm processes data incrementally without storing all intermediate results. For max: maintain a running maximum, updating it as each tile is processed.
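The tile loop above can be sketched in NumPy (the tile size T here is illustrative; the real implementation runs inside a Triton kernel, but the memory behavior is the same). Only the [B, |V|] accumulator persists across iterations; each [B, S, T] partial is dropped as soon as its maxima are folded in:

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, D, V, T = 2, 32, 8, 100, 25   # T = vocabulary tile size (illustrative)
H = rng.normal(size=(B, S, D))      # hidden states
E = rng.normal(size=(V, D))         # vocabulary embeddings

max_acc = np.full((B, V), -np.inf)  # the only persistent buffer: [B, |V|]
for v0 in range(0, V, T):
    partial = H @ E[v0:v0 + T].T    # [B, S, T], lives only this iteration
    max_acc[:, v0:v0 + T] = np.maximum(max_acc[:, v0:v0 + T],
                                       partial.max(axis=1))

reference = (H @ E.T).max(axis=1)   # materializes the full [B, S, |V|]
assert np.allclose(max_acc, reference)
```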

Innovation 3: Sparse Gradient Computation

During backpropagation, gradients are non-zero only where the max was achieved:

Sparse Gradient

With $Y_j = \max_k (HE^T)_{k,j}$ the stored max and $O_j = \log(1 + \mathrm{ReLU}(Y_j))$ the output:

$$\frac{\partial O_j}{\partial (HE^T)_{i,j}} = \begin{cases} \frac{1}{1 + Y_j} & \text{if } i = \operatorname{argmax}_k (HE^T)_{k,j} \text{ and } Y_j > 0 \\ 0 & \text{otherwise} \end{cases}$$

Only $B \times |V|$ positions have non-zero gradients (vs. $B \times S \times |V|$ naively).

The forward pass stores:

  • argmax indices $I_j$ — which sequence position achieved the max
  • max values $Y_j$ — the actual maximum values
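Given those two stored tensors, the backward rule can be sketched in NumPy for a single batch row and checked against central finite differences on the standard dense computation (shapes and names here are illustrative, not the actual kernel interface):

```python
import numpy as np

rng = np.random.default_rng(0)
S, V = 16, 10
scores = rng.normal(size=(S, V))    # stand-in for HE^T, one batch row

# Forward stores exactly what backward needs: argmax indices and max values
idx = scores.argmax(axis=0)         # [V] winning sequence position per vocab id
Y = scores.max(axis=0)              # [V] max values

# Backward for loss L = sum_j log1p(relu(Y_j)):
# nonzero only at (idx[j], j), and only where Y[j] > 0
grad = np.zeros_like(scores)
cols = np.arange(V)
grad[idx, cols] = np.where(Y > 0, 1.0 / (1.0 + Y), 0.0)

# Validate against central finite differences on the dense pipeline
eps = 1e-6
for j in range(V):
    i = idx[j]
    bumped = scores.copy()
    bumped[i, j] += eps
    f_plus = np.log1p(np.maximum(bumped, 0)).max(axis=0).sum()
    bumped[i, j] -= 2 * eps
    f_minus = np.log1p(np.maximum(bumped, 0)).max(axis=0).sum()
    assert abs((f_plus - f_minus) / (2 * eps) - grad[i, j]) < 1e-4
```

The entire gradient for one batch row is a single scatter of $|V|$ values, instead of a dense $S \times |V|$ tensor.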

Innovation 4: Fused Triton Kernel

All operations combined into a single kernel:

@triton.jit
def sparton_forward(...):
    # Load a tile of hidden states into SRAM (one batch row)
    h_tile = tl.load(H_ptr + offsets)            # [S, D]

    # Loop over vocabulary tiles
    for v_start in range(0, V, V_TILE):
        e_tile = tl.load(E_ptr + v_offsets)      # [D, V_TILE]
        partial = tl.dot(h_tile, e_tile) + bias  # [S, V_TILE]
        partial = tl.where(mask, partial, -INF)

        # Online max reduction over the sequence axis: partial is
        # discarded after this, only [V_TILE] values survive
        max_val = tl.max(partial, axis=0)
        max_idx = tl.argmax(partial, axis=0)

        # Apply ReLU and Log1p ONCE, on the reduced values only
        output = tl.log(1 + tl.maximum(max_val, 0))
        tl.store(out_ptr + v_offsets, output)
        tl.store(idx_ptr + v_offsets, max_idx)
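A NumPy mirror of the same logic (single batch row, no masking, an illustrative tile size) makes it easy to confirm the fused path matches the naive pipeline:

```python
import numpy as np

rng = np.random.default_rng(0)
S, D, V, V_TILE = 32, 8, 100, 25
H = rng.normal(size=(S, D))         # hidden states, one batch row
E = rng.normal(size=(V, D))         # vocabulary embeddings
bias = rng.normal(size=V)

# Sparton-style forward: per vocab tile, reduce over S, activations once
out = np.empty(V)
idx = np.empty(V, dtype=np.int64)
for v0 in range(0, V, V_TILE):
    partial = H @ E[v0:v0 + V_TILE].T + bias[v0:v0 + V_TILE]  # [S, V_TILE]
    idx[v0:v0 + V_TILE] = partial.argmax(axis=0)
    out[v0:v0 + V_TILE] = np.log1p(np.maximum(partial.max(axis=0), 0))

# Naive pipeline: full [S, V] tensor, elementwise ops, then max over S
naive = np.log1p(np.maximum(H @ E.T + bias, 0)).max(axis=0)
assert np.allclose(out, naive)
```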

Performance Results (SPLADE V3, B=320, S=512, |V|=30522)

Phase     Component            Eager Time (ms)   Eager Mem (MiB)
Fwd       Backbone + LM Head   162.1             28885.1
Fwd       Backbone + Sparton   113.7             2955.4
Fwd+Bwd   Backbone + LM Head   498.1             88875.0
Fwd+Bwd   Backbone + Sparton   330.1             51651.2

Sparton almost completely removes the memory overhead of the LM Head. Micro-benchmarks of the head in isolation show up to 4.8x speedup and 10x+ peak-memory reduction.

Key Properties

  • Memory Efficient: $O(B \times |V|)$ instead of $O(B \times S \times |V|)$
  • Fast: ~5x speedup from Kernel Fusion and reduced memory traffic
  • Mathematically Equivalent: Produces identical outputs to standard SPLADE
  • Backward Compatible: Drop-in replacement for SPLADE training

Connections

Appears In