Usage:

```bash
export HF_OFFLINE=1
export HF_DATASETS_OFFLINE=1
NGPU=2 bash train_bert.sh --training.steps 1000
```

Or for a custom config:

```bash
WANDB_NAME="my-run-name" CONFIG_FILE=configs/my_config.toml NGPU=2 bash train_bert.sh
```

Model architecture is defined in `configs/bert_arch.json`.
Training hyperparameters are set in `configs/bert_training.toml`.
Note that the trainer won't automatically try to load from a checkpoint. If you want to resume from a checkpoint (see the TOML sketch after this list):

- Set `load_step = -1` to load the latest checkpoint
- Set `load_step = 1000` to load a specific step
- Or pass it via command line: `--checkpoint.load_step -1`
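A minimal sketch of the corresponding config entry, assuming the training TOML uses a `[checkpoint]` table (matching the `--checkpoint.load_step` flag above):

```toml
[checkpoint]
load_step = -1  # -1 loads the latest checkpoint; a positive value loads that specific step
```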
To debug flame/train.py with specific arguments, run:

```bash
python -m debugpy --listen 5678 --wait-for-client -m torch.distributed.run \
    --nproc_per_node=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 \
    -m flame.train --job.config_file configs/bert_training.toml
```

Then attach using "Debug Training (Attach to Process)" in VSCode.
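For reference, a `launch.json` attach entry compatible with the debugpy port above; this is a generic sketch and may differ from the project's actual configuration:

```json
{
    "name": "Debug Training (Attach to Process)",
    "type": "debugpy",
    "request": "attach",
    "connect": { "host": "localhost", "port": 5678 }
}
```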
Components:

- `MatrixTokenizer`: discretizes values into `n_bins` bins, with special tokens [PAD], [MASK], [CLS], [SEP], [ROW_SEP]
- `generate_matrix_dataset()`: generates low-rank matrices offline with:
  - variance normalization (all matrices have var = 1)
  - storage of both `input_ids` (tokenized) and `values` (continuous floats)
  - logged stats: nuclear norm, spectral gap, mean, variance, trace, range
- `DataCollatorForMatrixCompletion`: MLM-style masking that outputs:
  - `input_ids`: with [MASK] tokens at masked positions
  - `labels`: continuous float values (NaN for non-masked positions)
  - `attention_mask`: standard padding mask
- `build_dataloader_matrix_completion()`: FLAME-compatible dataloader builder
Batch output format:

```python
{
    "input_ids": torch.LongTensor,      # [batch, seq_len]
    "labels": torch.FloatTensor,        # [batch, seq_len], NaN for non-masked
    "attention_mask": torch.LongTensor, # [batch, seq_len]
}
```
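As an illustration of how such a batch can be produced, here is a minimal, hypothetical masking step; the function name, signature, and defaults are assumptions, not the actual `DataCollatorForMatrixCompletion` code:

```python
import torch

def mask_batch(input_ids, values, mask_token_id, special_ids,
               mask_prob=0.15, loss_masked_only=False):
    # Only ordinary value tokens are eligible for masking, never special tokens.
    eligible = ~torch.isin(input_ids, torch.tensor(special_ids))
    masked = eligible & (torch.rand(values.shape) < mask_prob)
    input_ids = input_ids.masked_fill(masked, mask_token_id)
    if loss_masked_only:
        # Standard MLM-style denoising: targets only at masked positions.
        labels = torch.where(masked, values, torch.full_like(values, float("nan")))
    else:
        # Default: targets at every real value position, masked or not.
        labels = torch.where(eligible, values, torch.full_like(values, float("nan")))
    attention_mask = torch.ones_like(input_ids)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
```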
CLI usage:

```bash
python -m data.matrix_completion --m 20 --n 20 --r 5 --n_samples 1000 --output_dir data/generated/my_dataset
```
Expected structure and content of data/generated/:

```
data/generated/
└── <dataset_name>/                    # e.g., matrix_m20_n20_r5_seed42/
    ├── config.json                    # Generation hyperparameters
    ├── dataset_stats.json             # Aggregate statistics over all matrices
    ├── tokenizer_config.json          # Tokenizer settings
    └── train/
        ├── data-00000-of-00001.arrow  # HuggingFace Dataset (binary)
        ├── dataset_info.json          # Dataset schema
        └── state.json                 # HuggingFace internal state
```
`train/data-*.arrow`: the actual dataset in Apache Arrow format. Each sample contains:

- `input_ids` (List[int]): tokenized matrix with [CLS], bin tokens, [SEP]
- `values` (List[float]): continuous float values aligned with `input_ids` (NaN for special tokens)

Example for a 3×3 matrix without row separators:
```
input_ids: [2, 505, 612, 498, 523, 501, 489, 510, 507, 515, 3]
            ^   ^----- 9 matrix entries (bin indices) ----^   ^
          [CLS]                                             [SEP]

values:    [NaN, 0.12, 0.91, -0.05, 0.23, 0.01, -0.14, 0.08, 0.04, 0.19, NaN]
             ^    ^---------- 9 continuous float values --------------^    ^
           [CLS]                                                         [SEP]
```
`train/dataset_info.json`: HuggingFace schema describing the features:

```json
{
  "features": {
    "input_ids": {"feature": {"dtype": "int32"}, "_type": "List"},
    "values": {"feature": {"dtype": "float64"}, "_type": "List"}
  }
}
```
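To inspect a generated split, it can be loaded back with the standard HuggingFace datasets API; the path below uses the example dataset name from the tree above:

```python
from datasets import load_from_disk

ds = load_from_disk("data/generated/matrix_m20_n20_r5_seed42/train")
sample = ds[0]
print(len(sample["input_ids"]))  # m * n entries plus special tokens
print(sample["values"][:5])      # continuous floats, NaN at special tokens
```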
Train a BERT model (`BertForMatrixRegression` in `models/bert_regression.py`) to predict continuous matrix values from masked inputs (with online data generation):

```bash
CONFIG_FILE=configs/bert_matcomp.toml bash train_bert.sh --matrix.prec 3
```

Input (from `DataCollatorForMatrixCompletion`):

- `input_ids`: (B, L) LongTensor, the tokenized matrix with some percentage of value tokens replaced by [MASK]
- `labels`: (B, L) FloatTensor, continuous float values; NaN only at special tokens ([CLS], [SEP]) and padding
- `attention_mask`: (B, L) LongTensor, 1 for real tokens, 0 for padding
Output (from `BertForMatrixRegression`; see the sketch after this list):

- `logits`: (B, L, 1) FloatTensor, predicted continuous value per position
- `loss`: scalar, MSE on all non-NaN positions
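As a rough illustration of this input/output contract, here is a hypothetical sketch of a regression head with NaN-masked MSE; the class name and config values are assumptions, not the actual `BertForMatrixRegression` implementation:

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

class BertRegressionSketch(nn.Module):
    def __init__(self, vocab_size=8005, hidden_size=256):
        super().__init__()
        self.bert = BertModel(BertConfig(vocab_size=vocab_size, hidden_size=hidden_size,
                                         num_attention_heads=4, intermediate_size=1024))
        self.head = nn.Linear(hidden_size, 1)  # one continuous value per position

    def forward(self, input_ids, attention_mask, labels=None):
        hidden = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state
        logits = self.head(hidden)                # (B, L, 1)
        loss = None
        if labels is not None:
            valid = ~torch.isnan(labels)          # positions that carry a target
            loss = nn.functional.mse_loss(logits.squeeze(-1)[valid], labels[valid])
        return logits, loss
```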
Training objective: The model learns to output the full continuous matrix. At masked positions it must predict from context; at non-masked positions it learns to refine the quantized input back to continuous values.
With `loss_masked_only=True`: labels become NaN at non-masked positions, reverting to standard MLM-style denoising (loss only on masked tokens).
Matrices `M = U @ V.T` are normalized post-generation to have variance 1, so changing the initial range of the entries of `U` and `V` does not change the vocabulary size. The vocabulary size is controlled by `prec`, the precision of the entries (the number of digits after the decimal point in the entries of M).
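A minimal sketch of generation plus variance normalization, assuming uniform factor entries (the project's actual sampling distribution may differ):

```python
import numpy as np

def generate_matrix(m, n, r, rng):
    # Low-rank factors; the entry distribution here is an assumption.
    U = rng.uniform(-1.0, 1.0, size=(m, r))
    V = rng.uniform(-1.0, 1.0, size=(n, r))
    M = U @ V.T
    return M / M.std()  # variance normalization: var(M) = 1, whatever the initial range

rng = np.random.default_rng(42)
print(generate_matrix(20, 20, 5, rng).var())  # ~1.0
```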
Sequence length:
- Tokens per matrix: m * n + 1 (entries + [CLS])
| m (square) | tokens (m² + 1) |
|---|---|
| 10 | 101 |
| 15 | 226 |
| 20 | 401 |
| 22 | 485 |
Vocabulary size (with n_sigma=4; reproduced programmatically below the table):

- `n_bins = (val_max - val_min) * 10^prec + 1 = 8 * 10^prec + 1`
- `vocab_size = n_bins + 4` (4 special tokens)
- A precision of 2 means a resolution of 1e-2, and so on.
- `prec = 0` gives integer quantization, with tokens representing the integers from -4 to 4.
| prec | n_bins | vocab_size |
|---|---|---|
| 0 | 9 | 13 |
| 1 | 81 | 85 |
| 2 | 801 | 805 |
| 3 | 8001 | 8005 |
| 4 | 80001 | 80005 |
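As a self-check (not project code), the table follows directly from the formulas above:

```python
def vocab_size(prec, n_sigma=4, n_special=4):
    n_bins = 2 * n_sigma * 10**prec + 1  # 8 * 10^prec + 1
    return n_bins, n_bins + n_special

for prec in range(5):
    print(prec, *vocab_size(prec))  # 0 9 13, 1 81 85, ..., 4 80001 80005
```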