Scronkfinkle/tiny-mnist

Tiny MNIST

A ~300 line from-scratch* neural network implementation in Rust for MNIST digit classification, inspired by the 3Blue1Brown series on neural networks.

Overview

This project implements a multi-layer perceptron (MLP) neural network from the ground up, without using any ML frameworks. It covers the core concepts of:

  • Forward propagation
  • Backpropagation with gradient descent
  • Cross-entropy loss with softmax activation
  • Weight initialization and numerical stability techniques

I made it because, after watching some videos on how neural networks work at a low level, I wanted to apply those ideas and see what it actually takes to build something practical.
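To make the forward-propagation bullet above concrete, here is a minimal sketch of one layer computing `a = leaky_relu(W·x + b)` with plain `Vec`s. The names (`forward_layer`, `leaky_relu`) and the 0.01 leak factor are illustrative, not this repo's actual API:

```rust
// Leaky ReLU activation: a small negative slope instead of zero (0.01 is illustrative).
fn leaky_relu(z: f64) -> f64 {
    if z > 0.0 { z } else { 0.01 * z }
}

// One layer of the forward pass: a = leaky_relu(W·x + b).
fn forward_layer(weights: &[Vec<f64>], bias: &[f64], input: &[f64]) -> Vec<f64> {
    weights
        .iter()
        .zip(bias)
        .map(|(row, b)| {
            let z: f64 = row.iter().zip(input).map(|(w, x)| w * x).sum::<f64>() + b;
            leaky_relu(z)
        })
        .collect()
}

fn main() {
    // Two inputs, two neurons: first row gives 1*2 + (-1)*1 + 0 = 1.0,
    // second row gives 0.5*2 + 0.5*1 + 1 = 2.5.
    let w = vec![vec![1.0, -1.0], vec![0.5, 0.5]];
    let b = [0.0, 1.0];
    let a = forward_layer(&w, &b, &[2.0, 1.0]);
    println!("{:?}", a);
}
```

A full forward pass is just this applied layer by layer, feeding each layer's output into the next.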

Prerequisites

You should watch these videos; they explain most of what's going on here:

3Blue1Brown Neural Network Series:

Artem Kirsanov's Cross-Entropy explanation:

(Optional) Andrej Karpathy's Micrograd tutorial:

Quick Start

This code is not very efficient, and a plain `cargo run` will be very slow. It's recommended that you build and run it with:

```
cargo run --release
```

It takes about 4-5 minutes on my machine to train, but you should eventually see something like this:

```
# Model Performance (Training Data)
# Correct: 7701  Incorrect: 52299 Accuracy: 12.84%
# Model Performance (Test Data)
# Correct: 1238  Incorrect: 8762 Accuracy: 12.38%
Initial Loss 2.3021650570112273
Training...
(Epoch 1/10) Loss: 2.006211042872563
(Epoch 2/10) Loss: 1.6804801431140421
(Epoch 3/10) Loss: 1.5403268631741156
(Epoch 4/10) Loss: 1.4416650354726195
(Epoch 5/10) Loss: 1.421127595082744
(Epoch 6/10) Loss: 1.4114806765025723
(Epoch 7/10) Loss: 1.4064681376955193
(Epoch 8/10) Loss: 1.4039222909552094
(Epoch 9/10) Loss: 1.403353387461272
(Epoch 10/10) Loss: 1.4032988632194991
# Model Performance (Training Data)
# Correct: 56579  Incorrect: 3421 Accuracy: 94.30%
# Model Performance (Test Data)
# Correct: 9436  Incorrect: 564 Accuracy: 94.36%
```

If training fails with NaN (or if you want to try to do better), you can tune the parameters/architecture in main.rs.

Dataset

The MNIST dataset files are included in this repository to make it easy to get started. The original dataset docs can be found at Yann LeCun's website, but downloading and extracting can be tricky (the links currently return 403 Forbidden), so the binary files are vendored in this repository as:

  • data/train-images-idx3-ubyte - Training images (60,000 samples)
  • data/train-labels-idx1-ubyte - Training labels
  • data/t10k-images-idx3-ubyte - Test images (10,000 samples)
  • data/t10k-labels-idx1-ubyte - Test labels
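These IDX files have a simple layout: a big-endian u32 magic number, one big-endian u32 per dimension, then the raw u8 data. Here's a hedged sketch of parsing an image buffer (the function name and the divide-by-255 normalization are illustrative; this repo's actual parsing lives in src/mnist.rs):

```rust
// Parse an IDX3 image buffer: magic 0x00000803, then count/rows/cols, then pixel bytes.
// Pixels are normalized to [0, 1] by dividing by 255.
fn parse_idx3(buf: &[u8]) -> (usize, usize, usize, Vec<f64>) {
    let be = |i: usize| u32::from_be_bytes([buf[i], buf[i + 1], buf[i + 2], buf[i + 3]]) as usize;
    assert_eq!(be(0), 0x0803, "not an IDX3 image file");
    let (n, rows, cols) = (be(4), be(8), be(12));
    let pixels: Vec<f64> = buf[16..].iter().map(|&p| p as f64 / 255.0).collect();
    assert_eq!(pixels.len(), n * rows * cols);
    (n, rows, cols, pixels)
}

fn main() {
    // A fake 1-image, 2x2 file: 16-byte header followed by 4 pixel bytes.
    let mut buf = vec![0, 0, 0x08, 0x03, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2];
    buf.extend([0u8, 255, 128, 64]);
    let (n, rows, cols, pixels) = parse_idx3(&buf);
    println!("{} images of {}x{}, first pixel = {}", n, rows, cols, pixels[0]);
}
```

The label files (IDX1, magic 0x00000801) work the same way with a single count dimension.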

Implementation Details

  • Forward Pass: Matrix multiplication with activation functions (src/mlp.rs:73-98)
  • Backward Pass: Chain rule application for gradient computation (src/mlp.rs:100-201)
  • Weight Updates: Gradient descent with learning rate and clipping (src/mlp.rs:176)
  • Data Processing: Pixel normalization and one-hot encoding (src/mnist.rs)
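One detail worth calling out from the backward pass: when softmax is paired with cross-entropy loss, the gradient at the output layer collapses to a very simple form, dL/dz = softmax(z) − one_hot(label). A sketch under that assumption (`output_gradient` is an illustrative name, not this repo's):

```rust
// Gradient of cross-entropy loss w.r.t. the raw logits, given softmax probabilities:
// dL/dz_i = p_i - 1 for the true class, and p_i otherwise.
fn output_gradient(probs: &[f64], label: usize) -> Vec<f64> {
    probs
        .iter()
        .enumerate()
        .map(|(i, &p)| p - if i == label { 1.0 } else { 0.0 })
        .collect()
}

fn main() {
    // If the network assigns 0.75 to the correct class (index 1),
    // the gradient pushes that logit up and the wrong one down.
    let grad = output_gradient(&[0.25, 0.75], 1);
    println!("{:?}", grad);
}
```

This cancellation is a big part of why the softmax + cross-entropy combination is so common: the messy derivative terms of each piece cancel out.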

Additional Math Resources

This implementation uses several mathematical concepts beyond what's covered in the 3Blue1Brown videos. Each is documented with links in the code:

Numerical Stability (NUM_STABLE markers)

In the real world, we have to handle the fact that we're doing floating-point arithmetic, which can very easily break things. You'll see NUM_STABLE comments marking places where numerical stability techniques were needed to prevent NaN errors:

  • Weight Initialization (src/utils.rs:10-18): Uses He initialization with additional 0.1 scaling to prevent weights from becoming too large
  • Gradient Clipping (src/mlp.rs:160-172): Caps gradient magnitudes to prevent exploding gradients
  • Log Epsilon Addition (src/functions.rs:44-48): Adds small epsilon (1e-12) before taking logarithm to avoid ln(0)
  • Leaky ReLU (src/functions.rs:66-75): Prevents completely dead neurons that would stop learning
  • Raw Logits Output (src/mlp.rs:32-40): Uses raw final layer outputs instead of activated values for numerical stability in loss calculation

These modifications were essential to get the network to train successfully without hitting numerical instabilities that would cause the training to fail with NaN values.
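A few of these tricks are small enough to sketch inline. The epsilon matches the 1e-12 described above, but the clip limit is a made-up illustrative value, and the max-subtraction softmax is one standard way to work safely from raw logits (not necessarily this repo's exact code):

```rust
const EPS: f64 = 1e-12; // matches the log-epsilon described above

// Safe log: ln(0) would be -inf and poison the loss with NaNs downstream.
fn safe_ln(x: f64) -> f64 {
    (x + EPS).ln()
}

// Gradient clipping: clamp each component to [-limit, limit].
fn clip_gradient(grad: &mut [f64], limit: f64) {
    for g in grad.iter_mut() {
        *g = g.clamp(-limit, limit);
    }
}

// Numerically stable softmax: subtracting the max logit first means the largest
// exponent is exp(0) = 1, so huge logits can't overflow to infinity.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let m = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - m).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let mut grad = vec![-5.0, 0.5, 3.0];
    clip_gradient(&mut grad, 1.0);
    println!("clipped: {:?}", grad);
    println!("softmax on huge logits: {:?}", softmax(&[1000.0, 1000.0]));
    println!("safe_ln(0) = {}", safe_ln(0.0));
}
```

A naive softmax on logits near 1000 would compute exp(1000) = inf and return NaN; the max-subtracted version handles it without issue.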

*Excluding one dependency to make matrices first-class citizens in the language, and another for generating random numbers
