Transformer Implementation from Scratch
Overview
This is a decoder-only transformer built from scratch in JAX. The point was not to wrap a library model, but to write the machinery directly enough that every tensor shape and parameter path was visible.
Components
The implementation includes:
- Multi-head causal self-attention
- Sinusoidal positional encodings
- Layer normalization
- Residual MLP blocks
- Explicit parameter threading
- A JIT-compiled training loop
- Manual gradient computation with jax.grad
I avoided higher-level frameworks like Flax and Haiku so that the model stayed close to raw JAX arrays and pure functions.
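As a rough illustration of what "raw JAX arrays and pure functions" can look like, here is a minimal sketch of two of the components listed above: sinusoidal positional encodings and layer normalization. The function names, the params dictionary layout ("scale", "bias"), and the sizes are assumptions made for this example, not the project's actual code.

```python
import jax
import jax.numpy as jnp

def sinusoidal_positions(seq_len, d_model):
    """Fixed (seq_len, d_model) table of sinusoidal encodings; assumes d_model is even."""
    pos = jnp.arange(seq_len)[:, None]               # (seq_len, 1)
    i = jnp.arange(0, d_model, 2)[None, :]           # (1, d_model // 2)
    angles = pos / jnp.power(10000.0, i / d_model)   # (seq_len, d_model // 2)
    table = jnp.zeros((seq_len, d_model))
    table = table.at[:, 0::2].set(jnp.sin(angles))   # even columns: sine
    table = table.at[:, 1::2].set(jnp.cos(angles))   # odd columns: cosine
    return table

def layer_norm(params, x, eps=1e-5):
    """Normalize over the last axis, then apply a learned scale and bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return params["scale"] * (x - mean) / jnp.sqrt(var + eps) + params["bias"]

# Hypothetical sizes: 8 positions, model width 16.
table = sinusoidal_positions(8, 16)
ln_params = {"scale": jnp.ones(16), "bias": jnp.zeros(16)}
x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
h = layer_norm(ln_params, x + table)                 # (8, 16)
```

Both functions take everything they need as arguments and return new arrays, so they compose directly with jax.jit and jax.grad without any module system in between.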
What Clicked
JAX makes the transformer feel like a functional program: parameters are passed explicitly, transformations are composable, and training is a sequence of pure updates once the random keys are handled carefully.
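The pattern described above can be sketched with a much smaller model than the transformer. The example below is only an illustration under stated assumptions: a two-layer MLP stands in for the full model, the names init_params, loss_fn, and train_step are hypothetical, the update rule is plain gradient descent, and jax.value_and_grad is used in place of jax.grad so the loss comes back alongside the gradients.

```python
import jax
import jax.numpy as jnp

def init_params(key, d_in, d_hidden, d_out):
    # Each weight matrix gets its own subkey; the parent key is never reused.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, d_hidden)) * 0.02,
        "b1": jnp.zeros(d_hidden),
        "w2": jax.random.normal(k2, (d_hidden, d_out)) * 0.02,
        "b2": jnp.zeros(d_out),
    }

def loss_fn(params, x, y):
    # Pure function of (params, data): no hidden state anywhere.
    h = jax.nn.gelu(x @ params["w1"] + params["b1"])
    pred = h @ params["w2"] + params["b2"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr):
    # Differentiate with respect to the first argument (the params pytree).
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Return a new pytree instead of mutating the old one.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

key = jax.random.PRNGKey(0)
key, init_key, data_key = jax.random.split(key, 3)
params = init_params(init_key, 4, 16, 1)
x = jax.random.normal(data_key, (32, 4))
y = jnp.sum(x, axis=-1, keepdims=True)   # toy regression target

losses = []
for _ in range(200):
    params, loss = train_step(params, x, y, 0.05)
    losses.append(float(loss))
```

Because train_step returns the updated parameters rather than modifying them, the training loop is just repeated rebinding of params, which is what lets jax.jit compile the whole step.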
The exercise made attention less mysterious. At the implementation level, the core operation is still structured matrix multiplication; the hard part is making the shapes, masks, and residual paths line up exactly.
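To make the shape bookkeeping concrete, here is a minimal sketch of multi-head causal self-attention on a single sequence with no batch axis. It is a sketch under assumptions rather than the project's code: the parameter names ("wq", "wk", "wv", "wo"), the omission of bias terms, and the placement of the residual connection inside the function are all choices made for this example.

```python
import jax
import jax.numpy as jnp

def causal_self_attention(params, x, n_heads):
    """x: (seq_len, d_model) -> (seq_len, d_model); assumes d_model % n_heads == 0."""
    seq_len, d_model = x.shape
    d_head = d_model // n_heads

    def split_heads(t):
        # (seq_len, d_model) -> (n_heads, seq_len, d_head)
        return t.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    q = split_heads(x @ params["wq"])
    k = split_heads(x @ params["wk"])
    v = split_heads(x @ params["wv"])

    # Scaled dot-product scores: (n_heads, seq_len, seq_len)
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(d_head)
    # Lower-triangular mask: position q may attend only to positions k <= q.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)

    out = jnp.einsum("hqk,hkd->hqd", weights, v)             # (n_heads, seq_len, d_head)
    out = out.transpose(1, 0, 2).reshape(seq_len, d_model)   # merge heads back
    return x + out @ params["wo"]                            # residual path

# Hypothetical sizes for a quick check.
d_model, n_heads, seq_len = 16, 4, 6
kq, kk, kv, ko, kx = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    name: jax.random.normal(k, (d_model, d_model)) * 0.1
    for name, k in zip(("wq", "wk", "wv", "wo"), (kq, kk, kv, ko))
}
x = jax.random.normal(kx, (seq_len, d_model))
y = causal_self_attention(params, x, n_heads)

# Perturb only the last token: earlier outputs should not move.
x_perturbed = x.at[-1].add(1.0)
y_perturbed = causal_self_attention(params, x_perturbed, n_heads)
```

The perturbation at the end is a cheap way to check the mask: if changing the final token alters any earlier row of the output, information is leaking from the future.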