A JAX-based quantized training framework that exports directly to TensorFlow Lite for efficient deployment on microcontrollers and embedded devices.
Quax is a training and deployment framework specifically designed for resource-constrained hardware that ingests quantized flatbuffers directly.
- Layer-level quantization control - Configure quantization independently for each individual layer
- Direct TFLite export - Export to the TFLite flatbuffer format without intermediate conversions
- TensorFlow-independent - Pure JAX/Flax implementation with no TensorFlow dependencies
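Quax's own configuration API isn't described here, so the following is only a generic JAX sketch of what layer-level quantization control typically means in quantization-aware training: each layer's weights are fake-quantized to their own bit width in the forward pass, while a straight-through estimator lets gradients flow through the rounding step. The names `fake_quantize`, `LAYER_BITS`, `init_params`, `forward`, and the layer keys are hypothetical illustrations, not Quax identifiers, and the symmetric per-tensor scheme is an assumption.

```python
import jax
import jax.numpy as jnp


def fake_quantize(x, num_bits):
    """Symmetric per-tensor fake quantization (assumed scheme, not Quax's)."""
    # Integer range is [-(2^(b-1) - 1), 2^(b-1) - 1], e.g. [-127, 127] for 8 bits.
    qmax = 2.0 ** (num_bits - 1) - 1.0
    # Scale maps the largest magnitude onto qmax; guard against all-zero tensors.
    scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-8) / qmax
    q = jnp.clip(jnp.round(x / scale), -qmax, qmax)
    x_q = q * scale
    # Straight-through estimator: forward pass sees x_q, gradient is identity.
    return x + jax.lax.stop_gradient(x_q - x)


# Hypothetical per-layer bit widths (illustrative only, not a Quax config format).
LAYER_BITS = {"dense1": 8, "dense2": 4}


def init_params(key):
    # Two bias-free dense layers: 16 -> 32 -> 10.
    k1, k2 = jax.random.split(key)
    return {
        "dense1": jax.random.normal(k1, (16, 32)) * 0.1,
        "dense2": jax.random.normal(k2, (32, 10)) * 0.1,
    }


def forward(params, x):
    # Each layer's weights are quantized with that layer's own bit width.
    w1 = fake_quantize(params["dense1"], LAYER_BITS["dense1"])
    h = jax.nn.relu(x @ w1)
    w2 = fake_quantize(params["dense2"], LAYER_BITS["dense2"])
    return h @ w2
```

Because `fake_quantize` is differentiable under the straight-through estimator, `jax.grad` of a loss built on `forward` can be computed as usual, so training updates the full-precision weights while the forward pass reflects the per-layer quantized values.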
Install Quax from source via pip:
cd quax
pip install .
Run the example model:
python3 quax_e2e_model.py