A Triton-inspired GPU kernel compiler. Write GPU kernels in Python, compile them via MLIR to real hardware instructions.
```python
import tiny_ton as tt
import numpy as np

@tt.jit
def vector_add(a_ptr, b_ptr, c_ptr, N):
    pid = tt.program_id(0)
    offsets = pid * 64 + tt.arange(0, 64)
    mask = offsets < N
    a = tt.load(a_ptr + offsets, mask=mask)
    b = tt.load(b_ptr + offsets, mask=mask)
    tt.store(c_ptr + offsets, a + b, mask=mask)

a = np.array([1, 2, 3, 4], dtype=np.int32)
b = np.array([10, 20, 30, 40], dtype=np.int32)
c = np.zeros(4, dtype=np.int32)

vector_add[(1,)](a, b, c, len(a))
print(c)  # [11, 22, 33, 44]
```

Python (`@jit`) → AST capture → pybind11 → C++ IRBuilder → MLIR (TinyTon dialect) → Register Allocation → CodeGen → Runtime/Simulator → Execution
- CMake 3.20+
- C++17 compiler
- LLVM/MLIR 18
- pybind11
- Python 3.10+
```bash
# Docker (recommended)
docker build -t tiny-ton .
docker run tiny-ton ttc --emit asm examples/vector_add.tgc

# Native
brew install cmake ninja llvm@18
rm -rf build
cmake -G Ninja -S . -B build \
  -DCMAKE_BUILD_TYPE=Debug \
  -DMLIR_DIR=/opt/homebrew/opt/llvm@18/lib/cmake/mlir \
  -DLLVM_DIR=/opt/homebrew/opt/llvm@18/lib/cmake/llvm \
  -DTTN_ENABLE_PYTHON=OFF
cmake --build build
./build/bin/ttc --help

# Python bindings
cd python
pip install -e .
```

Goal: run Karpathy's microgpt forward pass on GPU via tiny-ton JIT kernels.
- Element-wise arithmetic: `add`, `sub`, `mul`, `div` (i32/f32/f16)
- Math intrinsics: `exp`, `log`, `sqrt`, `rsqrt`, `abs`, `max` (f32/f16)
- Masked load/store with `program_id` threading (combined in the sketch after this list)
- NVIDIA GPU backend: MLIR → PTX via combined pass + libdevice
- Google Colab CI: build + test on T4 GPU
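A minimal sketch of how these features combine, in the style of the `vector_add` example above. `tt.exp` and `tt.max` are the intrinsics listed here; the exact elementwise signatures are assumptions:

```python
import tiny_ton as tt

# Hypothetical masked elementwise kernel: arithmetic + intrinsics + program_id.
@tt.jit
def safe_exp(x_ptr, out_ptr, N):
    pid = tt.program_id(0)
    offsets = pid * 64 + tt.arange(0, 64)   # one 64-wide block per program
    mask = offsets < N                      # guard the ragged tail
    x = tt.load(x_ptr + offsets, mask=mask)
    y = tt.exp(x - tt.max(x, 0.0))          # f32 intrinsics (scalar broadcast assumed)
    tt.store(out_ptr + offsets, y, mask=mask)
```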
Each operation is a single kernel, tested independently against NumPy.
- `tt.reduce_sum` — warp-shuffle / `gpu.all_reduce` reduction
- `tt.reduce_max` — same as above with max
- `tt.relu` — element-wise `max(x, 0)`
- `tt.gather` — embedding lookup by index
- `tt.dot` / matvec — dot product via `reduce_sum`
- `softmax` — composed: `reduce_max` → `sub` → `exp` → `reduce_sum` → `div` (5 launches; host-side driver sketched below)
- `rmsnorm` — composed: `square` → `reduce_sum` → `rsqrt` → `scale` (4 launches)
- `linear` — matvec using dot (one output per block)
- `cross_entropy` — composed: `softmax` → `gather` → `-log`
- `attention` — composed: linear projections + dot + softmax + weighted sum
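A hypothetical host-side driver for the 5-launch softmax composition above. Kernel names (`reduce_max_kernel`, `sub_kernel`, ...) are placeholders for the per-op kernels in this stage; the launch syntax follows the `vector_add` example, and one block covers the row since n ≤ 64 (the hardcoded block size):

```python
import numpy as np

def softmax(x: np.ndarray, out: np.ndarray, scratch: np.ndarray, n: int):
    grid = (1,)                               # single 64-wide block, n <= 64
    reduce_max_kernel[grid](x, scratch, n)    # 1. m = max(x)
    sub_kernel[grid](x, scratch, out, n)      # 2. out = x - m
    exp_kernel[grid](out, out, n)             # 3. out = exp(out)
    reduce_sum_kernel[grid](out, scratch, n)  # 4. s = sum(out)
    div_kernel[grid](out, scratch, out, n)    # 5. out = out / s
```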
Replace microgpt's Python ops one by one with tiny-ton GPU kernels. Each op is still a separate launch — no fusion yet.
- Replace `softmax()`, `rmsnorm()`, `linear()` with GPU kernels (see the wrapper sketch after this list)
- Replace attention + MLP with composed GPU launches
- Full forward pass end-to-end on GPU
- Benchmark vs Python CPU baseline
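Hypothetical wiring for the swap: microgpt keeps its call sites, and each op is re-pointed at a GPU-backed version one at a time. `softmax()` here is the 5-launch driver sketched in Stage 1 above; the scratch-buffer handling and the bring-up validation step are assumptions:

```python
import numpy as np

def softmax_gpu(x: np.ndarray) -> np.ndarray:
    out = np.empty_like(x)
    scratch = np.empty(1, dtype=x.dtype)
    softmax(x, out, scratch, len(x))                             # 5 GPU launches
    ref = np.exp(x - x.max())
    np.testing.assert_allclose(out, ref / ref.sum(), rtol=1e-5)  # NumPy check
    return out
```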
Reduce launch overhead, fuse kernels, improve throughput.
Benchmark context: Stage 2 ran 8,800 kernel launches for 20 inference samples at n_embd=16. Launch overhead dominated (~150 ms/launch × 8,800 ≈ 1,320 s), leaving the GPU 487x slower than the CPU baseline. Every item below attacks this.
- Fused softmax — 5 launches → 1 (warp shuffle: reduce max, sub, exp, reduce sum, div all in registers; sketched below) — `examples/fused_softmax_test.py`, `docs/10-fused-softmax.md`
- Fused rmsnorm — 4 launches → 1 (warp shuffle: square, reduce sum, rsqrt, scale in registers) — `examples/fused_rmsnorm_test.py`, `docs/12-fused-rmsnorm.md`
- Fused per-head attention — 12 launches → 7 (score scaling + softmax fused into one kernel) — `examples/fused_attention_test.py`, `docs/13-fused-attention.md`
- NumPy training — replaced scalar `Value` autograd with vectorized NumPy forward + manual backward (full BPTT through the KV cache) + Adam; 1000 steps in ~1 s vs ~minutes
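A sketch of the fused single-launch softmax, assuming the device-side `tt.reduce_max` / `tt.reduce_sum` ops broadcast their result to every lane (the warp-shuffle "all in registers" scheme described above):

```python
import tiny_ton as tt

@tt.jit
def fused_softmax(x_ptr, out_ptr, N):
    offs = tt.arange(0, 64)
    mask = offs < N
    x = tt.load(x_ptr + offs, mask=mask)  # assumes masked lanes yield a value neutral to the reductions
    m = tt.reduce_max(x)                  # shuffle-reduce, no shmem round trip
    e = tt.exp(x - m)                     # subtract + exp, still in registers
    s = tt.reduce_sum(e)                  # second shuffle reduction
    tt.store(out_ptr + offs, e / s, mask=mask)
```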
Expected: ~3x fewer kernel launches, ~3x speedup.
- Test at n_embd=64 and n_embd=128 to find the GPU crossover point:
  - n_embd=16: GPU 487x slower (4x useful work per launch)
  - n_embd=64: estimated ~30x slower
  - n_embd=512+: GPU wins
- Make block size a kernel `constexpr` parameter — today it is hardcoded to 64, so 75% of threads are idle at n_embd=16 — `examples/constexpr_test.py`, `docs/14-constexpr.md` (usage sketched below)
- Implemented in `jit.py` (parse the `PARAM: tt.constexpr` annotation, separate cache key per value, exclude from IR args); no C++ or MLIR changes required
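A usage sketch for the constexpr block size, following the `PARAM: tt.constexpr` annotation described above; the keyword-argument launch syntax is an assumption:

```python
import tiny_ton as tt

@tt.jit
def vector_add(a_ptr, b_ptr, c_ptr, N, BLOCK: tt.constexpr):
    pid = tt.program_id(0)
    offsets = pid * BLOCK + tt.arange(0, BLOCK)  # BLOCK baked in at compile time
    mask = offsets < N
    a = tt.load(a_ptr + offsets, mask=mask)
    b = tt.load(b_ptr + offsets, mask=mask)
    tt.store(c_ptr + offsets, a + b, mask=mask)

# Each distinct BLOCK value gets its own cache entry and compiled kernel;
# BLOCK never appears as a runtime IR argument.
# vector_add[(1,)](a, b, c, 16, BLOCK=16)   # no idle lanes at n_embd=16
```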
Mirrors Modular's Blackwell series, adapted for Ampere sm_87. Each kernel adds one hardware concept. Target: match cuBLAS FP32 (~2 TFLOPS), then FP16 with tensor cores (~12 TFLOPS).
Hardware context (Jetson Orin Nano): 16 SMs · 48 KB shmem/SM · 68 GB/s · FP32 peak ~2 TFLOPS · FP16 tensor core peak ~12 TFLOPS
| Kernel | Technique | Expected TFLOPS | Compiler change |
|---|---|---|---|
| K0: Naive GEMM | One block per output element, global memory only | ~0.001 | Add `//`, `%` to JIT |
| K1: Row GEMM | One block per row, A reused across N cols | ~0.005 | None (rename `tiled_matmul_kernel`) |
| K2: Shmem GEMM | A + B tiles in shared memory, 2D grid | ~0.1 | `program_id(1)` + `scf.for` runtime loop |
| K3: Swizzled GEMM | XOR-swizzle shmem layout, eliminate 8-way bank conflicts | ~0.2 | Swizzle address helper in JIT |
| K4: Vectorized GEMM | `LDG.128` — load 4 floats per instruction | ~0.5 | New `tt.load_v4` IR op |
| K5: Pipelined GEMM | `cp.async` — overlap load with compute (Ampere) | ~1.0 | New `tt.async_copy` IR op |
| K6: Tensor Core GEMM | `mma.sync.m16n8k16` — FP16 tensor cores | ~6–12 | New `tt.dot` tile op |
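A minimal sketch of K0 from the table above, assuming the tracer unrolls a Python-level loop over a constexpr K (the `scf.for` runtime K-loop only arrives with K2). The `//` and `%` below are exactly the listed compiler change: they recover `(row, col)` from a 1D program id, one program per output element:

```python
import tiny_ton as tt

@tt.jit
def naive_gemm(a_ptr, b_ptr, c_ptr, N, K: tt.constexpr):
    pid = tt.program_id(0)  # grid has M*N programs, one per output element
    row = pid // N          # FloorDiv: new in the JIT BinOp handler
    col = pid % N           # Mod: new in the JIT BinOp handler
    acc = 0.0
    for k in range(K):      # unrolled at trace time; global loads only
        acc += tt.load(a_ptr + row * K + k) * tt.load(b_ptr + k * N + col)
    tt.store(c_ptr + row * N + col, acc)
```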
Progress:
- Correctness: `tiled_gemm_test.py` — all loop_sum, tiled_dot, and tiled_matmul tests pass on Jetson
- Bug fix: `reduce_sum` partial-warp shuffle (blockSize < 32) — passes the correct `width` to `gpu::ShuffleOp` instead of a hardcoded 32
- Bug fix: `emit_mul` scalar promotion — `_promote_scalar` in `jit.py` prevents a `TypeError` when a constexpr int is passed as an IR operand
- Benchmark notebook — `examples/gemm_benchmark.ipynb` with cuBLAS reference numbers and gap analysis
- K0: Naive GEMM — add `//` (FloorDiv) and `%` (Mod) to the JIT `BinOp` handler
- K1: Row GEMM — rename/reframe `tiled_matmul_kernel` in the notebook
- K2: Shmem GEMM — `program_id(1)` (2D grid) + `scf.for` runtime K-loop
- K3: Swizzled GEMM — 128-byte XOR swizzle to eliminate shmem bank conflicts (address helper sketched below)
- K4: Vectorized GEMM — `LDG.128` vectorized loads
- K5: Pipelined GEMM — `cp.async` to overlap load and compute
- K6: Tensor Core GEMM — `mma.sync.aligned.m16n8k16` via `tt.dot`
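A hypothetical sketch of the K3 swizzle address helper as pure Python address arithmetic in the JIT (no new IR ops needed). With a 128-byte (32-float) shmem row, XORing the column with the row index permutes columns so that lanes walking a column hit all 32 banks instead of colliding; the exact bit choices are assumptions:

```python
ROW_FLOATS = 32   # 128 bytes / 4 bytes per float

def swizzled_index(row: int, col: int) -> int:
    # Permute columns within each row; linear layout becomes conflict-free
    # for column-major access patterns.
    return row * ROW_FLOATS + (col ^ (row % ROW_FLOATS))

# Sanity check: column 0 across 32 rows now touches 32 distinct banks.
assert len({swizzled_index(r, 0) % 32 for r in range(32)}) == 32
```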
See `examples/gemm_benchmark.ipynb` for the live benchmark notebook and `docs/16-tiled-gemm.md` for the current tiling design.
- Flash Attention style — tiles the KV cache into chunks and accumulates the softmax numerator/denominator across chunks; needed when seq_len > block_size (64); a NumPy sketch of the accumulation follows this list
- Pattern-matching fusion pass on the `tinyton` MLIR dialect — detects `exp` → `reduce_sum` → `div` chains etc. and merges them automatically, like XLA/TVM/torch.compile
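A NumPy sketch of the chunked softmax accumulation (names and shapes are assumptions: `q` is one query vector, `K`/`V` are the cached keys/values). The running max `m` lets each chunk's partial sums be rescaled so no full score row is ever materialized:

```python
import numpy as np

def attention_chunked(q, K, V, block_size=64):
    m = -np.inf                     # running max of scores seen so far
    num = np.zeros_like(V[0])       # running numerator  (sum of p_i * v_i)
    den = 0.0                       # running denominator (sum of p_i)
    for s in range(0, len(K), block_size):
        scores = K[s:s+block_size] @ q
        m_new = max(m, scores.max())
        scale = np.exp(m - m_new)   # rescale old partial sums to the new max
        p = np.exp(scores - m_new)
        num = num * scale + p @ V[s:s+block_size]
        den = den * scale + p.sum()
        m = m_new
    return num / den
```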
MIT — see LICENSE.