
Transformer Implementation from Scratch

transformers · JAX · einops · deep-learning

Overview

This is a decoder-only transformer built from scratch in JAX. The point was not to wrap a library model, but to write the machinery directly enough that every tensor shape and parameter path was visible.

Components

The implementation includes:

  • Multi-head causal self-attention
  • Sinusoidal positional encodings
  • Layer normalization
  • Residual MLP blocks
  • Explicit parameter threading
  • A JIT-compiled training loop
  • Gradients computed by calling jax.grad directly on the loss function

I avoided higher-level frameworks like Flax and Haiku so that the model stayed close to raw JAX arrays and pure functions.
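Two of the smaller components above can be written as pure functions over raw arrays. The following is a minimal sketch, not the project's actual code: the function names, the base of 10000, and the eps value are assumptions chosen for illustration, and d_model is assumed to be even.

```python
import jax.numpy as jnp

def sinusoidal_positions(seq_len, d_model):
    """Return a (seq_len, d_model) table of fixed sinusoidal encodings."""
    pos = jnp.arange(seq_len)[:, None]               # (seq_len, 1)
    i = jnp.arange(0, d_model, 2)[None, :]           # (1, d_model // 2)
    angles = pos / jnp.power(10000.0, i / d_model)   # (seq_len, d_model // 2)
    table = jnp.zeros((seq_len, d_model))
    table = table.at[:, 0::2].set(jnp.sin(angles))   # even columns: sine
    table = table.at[:, 1::2].set(jnp.cos(angles))   # odd columns: cosine
    return table

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize over the last axis, then apply a learned scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta

pe = sinusoidal_positions(8, 16)
x = jnp.ones((2, 8, 16)) + pe                        # table broadcasts over the batch axis
y = layer_norm(x, jnp.ones(16), jnp.zeros(16))
```

Because both functions take their parameters as arguments and return new arrays, they compose with jax.jit and jax.grad without any extra machinery.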

What Clicked

JAX makes the transformer feel like a functional program: parameters are passed explicitly, transformations are composable, and training is a sequence of pure updates once the random keys are handled carefully.
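The pattern of explicit parameters, split random keys, and pure updates can be sketched as follows. This is an illustration under stated assumptions: a toy linear model stands in for the transformer, plain gradient descent stands in for whatever optimizer the project used, and all names (init_params, loss_fn, train_step, LEARNING_RATE) are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # assumed value, for illustration only

def init_params(key, d_in, d_out):
    # Parameters live in a plain dict (a pytree), not inside an object.
    return {"w": 0.1 * jax.random.normal(key, (d_in, d_out)),
            "b": jnp.zeros((d_out,))}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # jax.grad differentiates with respect to the first argument (params).
    grads = jax.grad(loss_fn)(params, x, y)
    # A pure update: returns a new pytree, the old one is left unchanged.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Each consumer gets its own key; a key is never reused after being split.
key = jax.random.PRNGKey(0)
key, init_key, data_key = jax.random.split(key, 3)

params = init_params(init_key, 4, 1)
x = jax.random.normal(data_key, (32, 4))
y = x @ jnp.array([[1.0], [-2.0], [0.5], [3.0]])

loss_before = loss_fn(params, x, y)
for _ in range(100):
    params = train_step(params, x, y)
loss_after = loss_fn(params, x, y)
```

The same train_step shape carries over to a full model: only init_params and loss_fn change, while the parameter pytree is threaded through in exactly the same way.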

The exercise made attention less mysterious. At the implementation level, the core operation is still structured matrix multiplication; the hard part is making the shapes, masks, and residual paths line up exactly.
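To make the "structured matrix multiplication" point concrete, here is a minimal sketch of multi-head causal self-attention. It is an assumption-laden illustration rather than the project's code: the parameter names (wq, wk, wv, wo) are hypothetical, biases are omitted, and jnp.einsum with reshape is used where the project may have used einops.

```python
import jax
import jax.numpy as jnp

def causal_self_attention(params, x, n_heads):
    """x: (batch, seq, d_model) -> (batch, seq, d_model)."""
    b, t, d = x.shape
    d_head = d // n_heads

    def split_heads(z):
        # (b, t, d) -> (b, n_heads, t, d_head)
        return z.reshape(b, t, n_heads, d_head).transpose(0, 2, 1, 3)

    q = split_heads(x @ params["wq"])
    k = split_heads(x @ params["wk"])
    v = split_heads(x @ params["wv"])

    # Scaled dot-product scores: (b, n_heads, t_query, t_key)
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(d_head)

    # Lower-triangular mask: position q may only attend to keys at positions <= q.
    mask = jnp.tril(jnp.ones((t, t), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)

    out = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
    out = out.transpose(0, 2, 1, 3).reshape(b, t, d)  # merge heads back to d_model
    return out @ params["wo"]

keys = jax.random.split(jax.random.PRNGKey(0), 5)
d_model, n_heads = 16, 4
params = {name: 0.1 * jax.random.normal(kk, (d_model, d_model))
          for name, kk in zip(["wq", "wk", "wv", "wo"], keys[:4])}
x = jax.random.normal(keys[4], (2, 6, d_model))
y = causal_self_attention(params, x, n_heads)

# Causality check: altering the last token should leave earlier outputs unchanged.
x_changed = x.at[:, -1, :].set(0.0)
y_changed = causal_self_attention(params, x_changed, n_heads)
```

Comparing y and y_changed on all but the last position is a cheap way to confirm that the mask is oriented correctly, which is one of the easier details to get backwards.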