A pipeline for training and visualizing neural networks on modular arithmetic tasks, with reproducible runs, device abstraction, and automated result tracking.

    from main import Experiment

    config = {
        'p': 17,
        'c': [4, 1, 0],
        'd': [1, 2, 3, 0, 0],
        'embedding_dim': 128,
        'hidden': 500,
        'learning_rate': 0.005,
        'max_steps': 1000,
        'batch_size': 1024,
        'weight_decay': 1e-4,
        'split': 0.99999,
        'negs_per_ex': 20,
    }

    exp = Experiment(config)
    exp.run(animate=False)

| Parameter | Type | Default | Description |
|---|---|---|---|
| `p` | int | - | Prime modulus (>= 2) |
| `c` | list | - | Polynomial coefficients [c0, c1, c2] |
| `d` | list | - | Polynomial degrees [d0, d1, d2, d3, d4] |
| `embedding_dim` | int | - | Embedding dimension |
| `hidden` | int | - | Hidden layer size |
| `learning_rate` | float | - | Learning rate |
| `max_steps` | int | - | Training iterations |
| `batch_size` | int | - | Batch size |
| `weight_decay` | float | - | L2 regularization |
| `split` | float | - | Train/test ratio (0-1) |
| `negs_per_ex` | int | - | Negative samples per example |

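The constraints listed in the table can be checked before an experiment is built. The sketch below is a hypothetical helper, `validate_config`, that is not part of `main`. It encodes only what the table states (every key is required, `c` has 3 entries, `d` has 5, `p >= 2`, `split` between 0 and 1), plus the assumption that the integer size and step parameters should be positive.

```python
REQUIRED_KEYS = [
    "p", "c", "d", "embedding_dim", "hidden", "learning_rate",
    "max_steps", "batch_size", "weight_decay", "split", "negs_per_ex",
]


def validate_config(config):
    """Return a list of problems found in a config dict (empty if none).

    Hypothetical helper for illustration; not part of the project's API.
    """
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in config]
    if problems:
        return problems

    if not isinstance(config["p"], int) or config["p"] < 2:
        problems.append("p must be an integer >= 2")
    if len(config["c"]) != 3:
        problems.append("c must have 3 coefficients")
    if len(config["d"]) != 5:
        problems.append("d must have 5 degrees")
    if not 0 <= config["split"] <= 1:
        problems.append("split must lie in [0, 1]")
    # Assumption: sizes and step counts are positive integers.
    for key in ("embedding_dim", "hidden", "max_steps", "batch_size"):
        if not isinstance(config[key], int) or config[key] <= 0:
            problems.append(f"{key} must be a positive integer")
    return problems


example = {
    "p": 17, "c": [4, 1, 0], "d": [1, 2, 3, 0, 0],
    "embedding_dim": 128, "hidden": 500, "learning_rate": 0.005,
    "max_steps": 1000, "batch_size": 1024, "weight_decay": 1e-4,
    "split": 0.99999, "negs_per_ex": 20,
}
print(validate_config(example))
```

Returning a list of messages rather than raising on the first failure lets a caller report every problem in a config at once.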
The target function is a modular polynomial defined as:

$$f(x, y) = \left[\left(c_0 x^{d_0} + c_1 y^{d_1}\right)^{d_2} + c_2 x^{d_3} y^{d_4}\right] \bmod p$$

Components:
- $c = [c_0, c_1, c_2]$: Coefficients (3 values)
- $d = [d_0, d_1, d_2, d_3, d_4]$: Degrees (5 values)
- $p$: Prime modulus

Examples:
- Addition: $f(x, y) = (x + y) \bmod p$, with `c = [1, 1, 0]`, `d = [1, 1, 1, 0, 0]`
- Quadratic with exponent: $f(x, y) = (4x + y^2)^3 \bmod p$, with `c = [4, 1, 0]`, `d = [1, 2, 3, 0, 0]`
- With cross term: $f(x, y) = ((x + y) + xy) \bmod p$, with `c = [1, 1, 1]`, `d = [1, 1, 1, 1, 1]`

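To make the roles of `c` and `d` concrete, here is a minimal sketch of the target function. It assumes the form $(c_0 x^{d_0} + c_1 y^{d_1})^{d_2} + c_2 x^{d_3} y^{d_4} \bmod p$, which is inferred from the three examples above rather than taken from the project's code. The name `target` is hypothetical, as is the assumption that inputs range over $0, \dots, p-1$.

```python
def target(x, y, c, d, p):
    """Evaluate (c0*x^d0 + c1*y^d1)^d2 + c2*x^d3*y^d4 modulo p.

    Hypothetical helper; the formula is inferred from the README examples.
    """
    # Three-argument pow keeps intermediate values reduced modulo p.
    inner = c[0] * pow(x, d[0], p) + c[1] * pow(y, d[1], p)
    cross = c[2] * pow(x, d[3], p) * pow(y, d[4], p)
    return (pow(inner, d[2], p) + cross) % p


p = 17

# The three example configurations, evaluated at sample points.
addition = target(5, 14, [1, 1, 0], [1, 1, 1, 0, 0], p)    # (5 + 14) mod 17
quadratic = target(1, 2, [4, 1, 0], [1, 2, 3, 0, 0], p)    # (4*1 + 2^2)^3 mod 17
cross_term = target(3, 4, [1, 1, 1], [1, 1, 1, 1, 1], p)   # (3 + 4 + 3*4) mod 17

# Assumption: the dataset covers every input pair (x, y) with 0 <= x, y < p.
table = {
    (x, y): target(x, y, [4, 1, 0], [1, 2, 3, 0, 0], p)
    for x in range(p)
    for y in range(p)
}
print(addition, quadratic, cross_term, len(table))
```

Under that assumption, `p = 17` gives 289 input pairs in total.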
The full pipeline and its individual stages can be called directly:

    exp.run(animate=False)          # Full pipeline
    exp.setup_experiment_config()   # Initialize model/dataset
    exp.train_model()               # Train only
    exp.visualize_experiment()      # Visualize only
    exp.calculate_parameters()      # Compute model stats

Generated plots:
- Training and test loss curves
- Overall accuracy curves
- Positive example accuracies
- Negative example accuracies
- If `animate=True` and there is a single test example:
  - Histogram of all predictions
  - Single prediction histogram
  - Probability distribution evolution animation


The compute device is detected automatically and selected from:
- MPS (Apple Silicon)
- CUDA (NVIDIA GPU)
- CPU (fallback)
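A device selection of this kind can be sketched as follows. This is not the project's implementation: the names `choose_device` and `detect_device` are hypothetical, and the priority order (MPS, then CUDA, then CPU) is an assumption based on the order of the list above. The `getattr` guard is there because `torch.backends.mps` is only present in newer PyTorch releases, and the import fallback only keeps the sketch runnable without PyTorch installed.

```python
def choose_device(has_mps, has_cuda):
    """Pick a device name from availability flags (assumed priority order)."""
    if has_mps:
        return "mps"
    if has_cuda:
        return "cuda"
    return "cpu"


def detect_device():
    """Query PyTorch for available backends and return a device name."""
    try:
        import torch
    except ImportError:
        # Assumption for this sketch only: without PyTorch, report CPU.
        return "cpu"
    mps_backend = getattr(torch.backends, "mps", None)
    has_mps = mps_backend is not None and mps_backend.is_available()
    return choose_device(has_mps, torch.cuda.is_available())


print(detect_device())
```

The returned string can be passed to `torch.device(...)` when moving a model or tensors.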

Requirements:
- PyTorch >= 1.9
- NumPy
- Matplotlib
- IPython