
bert-network

Generic BERT training.

Usage:

export HF_OFFLINE=1
export HF_DATASETS_OFFLINE=1
NGPU=2 bash train_bert.sh --training.steps 1000

Or for a custom config:

WANDB_NAME="my-run-name" CONFIG_FILE=configs/my_config.toml NGPU=2 bash train_bert.sh

Model architecture is defined in configs/bert_arch.json. Training hyperparameters are set in configs/bert_training.toml.

Note that the trainer won't automatically try to load from a checkpoint. If you want to resume from a checkpoint:

  • Set load_step = -1 to load the latest checkpoint
  • Set load_step = 1000 to load a specific step
  • Or pass it via command line: --checkpoint.load_step -1
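For example, in a TOML config (assuming the same [checkpoint] table layout as the configs in configs/):

```toml
[checkpoint]
# -1 loads the latest checkpoint; a positive value loads that specific step
load_step = -1
```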

To debug flame/train.py with specific arguments, run:

python -m debugpy --listen 5678 --wait-for-client -m torch.distributed.run \
  --nproc_per_node=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 \
  -m flame.train --job.config_file configs/bert_training.toml

Then attach using "Debug Training (Attach to Process)" in VSCode.

Matrix completion data generation module

Components:

  1. MatrixTokenizer - Discretizes values into n_bins bins with special tokens [PAD], [MASK], [CLS], [SEP], [ROW_SEP]
  2. generate_matrix_dataset() - Generates low-rank matrices offline with:
       - Variance normalization (all matrices have var=1)
       - Stores both input_ids (tokenized) and values (continuous floats)
       - Logs stats: nuclear norm, spectral gap, mean, variance, trace, range
  3. DataCollatorForMatrixCompletion - MLM-style masking that outputs:
       - input_ids: with [MASK] tokens at masked positions
       - labels: continuous float values (NaN for non-masked positions)
       - attention_mask: standard padding mask
  4. build_dataloader_matrix_completion() - FLAME-compatible dataloader builder
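To make the tokenizer's role concrete, here is a minimal sketch of the value binning it performs. The helper names, special-token ids, and constants below are assumptions for illustration, not the repo's actual API; only the binning scheme (uniform bins over [-n_sigma, n_sigma] at resolution 10^-prec, offset past the special tokens) follows the description above.

```python
# Assumptions: values lie in [-N_SIGMA, N_SIGMA]; these helpers and the
# special-token ids are hypothetical, following the description above.
N_SIGMA, PREC = 4, 2
N_BINS = 2 * N_SIGMA * 10**PREC + 1       # 801 bins for prec=2
SPECIAL = {"[PAD]": 0, "[MASK]": 1, "[CLS]": 2, "[SEP]": 3}
OFFSET = len(SPECIAL)                      # bin tokens start after the specials

def value_to_id(v: float) -> int:
    """Clip to [-N_SIGMA, N_SIGMA] and map to the nearest bin-token id."""
    v = max(-N_SIGMA, min(N_SIGMA, v))
    return round((v + N_SIGMA) * 10**PREC) + OFFSET

def id_to_value(token_id: int) -> float:
    """Invert the binning: token id -> bin-center value."""
    return (token_id - OFFSET) / 10**PREC - N_SIGMA
```

Round-tripping a value through value_to_id and id_to_value recovers it up to the quantization resolution 10^-prec.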

Batch output format:

{
    "input_ids": torch.LongTensor,      # [batch, seq_len]
    "labels": torch.FloatTensor,        # [batch, seq_len], NaN for non-masked
    "attention_mask": torch.LongTensor, # [batch, seq_len]
}
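A rough sketch of the MLM-style masking behind this format, using plain lists in place of tensors. The mask probability, token ids, and function name are assumptions, not the collator's actual API; it only illustrates the labels convention described above (continuous values at masked positions, NaN elsewhere).

```python
import random

MASK_ID = 1                  # assumed id of [MASK]
SPECIAL_IDS = {0, 1, 2, 3}   # assumed ids of [PAD], [MASK], [CLS], [SEP]

def mask_sample(input_ids, values, mask_prob=0.15, rng=random.Random(0)):
    """Return (masked input_ids, labels) for one sample.

    Labels carry the continuous value at masked positions and NaN
    everywhere else, matching the batch format above."""
    masked, labels = [], []
    for tok, val in zip(input_ids, values):
        if tok not in SPECIAL_IDS and rng.random() < mask_prob:
            masked.append(MASK_ID)
            labels.append(val)
        else:
            masked.append(tok)
            labels.append(float("nan"))
    return masked, labels
```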

CLI usage:

python -m data.matrix_completion --m 20 --n 20 --r 5 --n_samples 1000 --output_dir data/generated/my_dataset

Expected structure and content of data/generated/:

data/generated/
└── <dataset_name>/                    # e.g., matrix_m20_n20_r5_seed42/
    ├── config.json                    # Generation hyperparameters
    ├── dataset_stats.json             # Aggregate statistics over all matrices
    ├── tokenizer_config.json          # Tokenizer settings
    └── train/
        ├── data-00000-of-00001.arrow  # HuggingFace Dataset (binary)
        ├── dataset_info.json          # Dataset schema
        └── state.json                 # HuggingFace internal state

train/data-*.arrow — The actual dataset in Apache Arrow format. Each sample contains:

  • input_ids: List[int] — Tokenized matrix with [CLS], bin tokens, [SEP]

  • values: List[float] — Continuous float values aligned with input_ids (NaN for special tokens)

  • Example for a 3×3 matrix without row separators:

input_ids: [2, 505, 612, 498, 523, 501, 489, 510, 507, 515, 3]
            ^  ^----- 9 matrix entries (bin indices) -----^  ^
          [CLS]                                            [SEP]

values:    [NaN, 0.12, 0.91, -0.05, 0.23, 0.01, -0.14, 0.08, 0.04, 0.19, NaN]
            ^   ^---------- 9 continuous float values ------------^     ^
          [CLS]                                                       [SEP]
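Given such a sample, the original matrix can be recovered by dropping the special-token positions (where values is NaN) and reshaping. A sketch, assuming no row separators are used (the helper name is hypothetical):

```python
import math

def sample_to_matrix(values, m, n):
    """Drop NaN entries (special-token positions) and reshape to m x n."""
    entries = [v for v in values if not math.isnan(v)]
    assert len(entries) == m * n, "unexpected number of matrix entries"
    return [entries[i * n:(i + 1) * n] for i in range(m)]
```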

train/dataset_info.json — HuggingFace schema describing the features:

{
  "features": {
    "input_ids": {"feature": {"dtype": "int32"}, "_type": "List"},
    "values": {"feature": {"dtype": "float64"}, "_type": "List"}
  }
}

Training BERT for matrix completion

Train a BERT model (BertForMatrixRegression in models/bert_regression.py) to predict continuous matrix values from masked inputs (with online data generation):

CONFIG_FILE=configs/bert_matcomp.toml bash train_bert.sh --matrix.prec 3

Input (from DataCollatorForMatrixCompletion):

  • input_ids: (B, L) LongTensor — tokenized matrix with some percentage of value tokens replaced by [MASK]
  • labels: (B, L) FloatTensor — continuous float values; NaN only at special-token and padding positions
  • attention_mask: (B, L) LongTensor — 1 for real tokens, 0 for padding

Output (from BertForMatrixRegression):

  • logits: (B, L, 1) FloatTensor — predicted continuous value per position
  • loss: scalar — MSE on all non-NaN positions

Training objective: The model learns to output the full continuous matrix. At masked positions it must predict from context; at non-masked positions it learns to refine the quantized input back to continuous values.

With loss_masked_only=True: Labels become NaN at non-masked positions, reverting to standard MLM-style denoising (loss only on masked tokens).
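In either mode, the loss reduces to an MSE over the positions where labels are finite. A NumPy sketch of that reduction (the actual model computes the same thing in PyTorch):

```python
import numpy as np

def masked_mse(preds, labels):
    """MSE over positions with finite labels; NaN positions are ignored."""
    mask = ~np.isnan(labels)
    diff = preds[mask] - labels[mask]
    return float(np.mean(diff ** 2))
```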

Notes about tokenization and vocabulary size

Matrices M = U @ V.T are normalized post-generation to have variance 1, so changing the initial range of the entries of U and V does not change the vocabulary size. The vocabulary size is controlled by prec, the precision of the entries (the number of digits after the decimal point in the entries of M).

Sequence length:

  • Tokens per matrix: m * n + 1 (entries + [CLS])
    m (square)   tokens (m² + 1)
    10           101
    15           226
    20           401
    22           485

Vocabulary size (with n_sigma=4):

  • n_bins = (val_max - val_min) * 10^prec + 1 = 8 * 10^prec + 1
  • vocab_size = n_bins + 4 (4 special tokens)
  • A precision of 2 means a resolution of 1e-2, and so on.
  • prec = 0 gives integer quantization, with tokens representing integers from -4 to 4.
    prec   n_bins   vocab_size
    0      9        13
    1      81       85
    2      801      805
    3      8001     8005
    4      80001    80005
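Both tables follow directly from these formulas; a quick sanity-check script:

```python
def tokens_per_matrix(m, n):
    """Sequence length: m*n bin tokens plus [CLS]."""
    return m * n + 1

def n_bins(prec, n_sigma=4):
    """Bins covering [-n_sigma, n_sigma] at resolution 10^-prec."""
    return 2 * n_sigma * 10**prec + 1

def vocab_size(prec, n_sigma=4):
    """Bin tokens plus the 4 special tokens."""
    return n_bins(prec, n_sigma) + 4
```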
