
LLapDiff

LLapDiff is a Laplace-domain latent diffusion model for irregular, partially observed panel time series. The public mainline includes joint forecast/imputation training and generic training/evaluation entry points.

What is in this repo

  • Dataset/: dataset-specific cache builders, loaders, and dataset statistics tools.
  • Latent_Space/: the latent VAE implementation and utilities.
  • Model/: the summarizer, the LLapDiff backbone, and diffusion utilities.
  • config.py: generic global defaults only.
  • dataset_defaults.py: the per-dataset preset table.
  • train_val_pipeline.py: canonical end-to-end runner for VAE + summarizer + LLapDiff.
  • llapdiff_checkpoint_eval.py: generic checkpoint evaluation for forecast and imputation.
  • run_multidataset_artifact_prep.py: VAE/summarizer artifact preparation and health audit.
  • Viz/plot_llapdiff_poles.py: pole-visualization utility for trained checkpoints.

Public defaults

The canonical public recipe is:

  • PREDICT_TYPE="v"
  • PRIMARY_EVAL_METRIC="crps"
  • LOSS_WEIGHT_SCHEME="weighted_min_snr"
  • BASE_LR=1.5e-4
  • DATES_PER_BATCH=1
  • joint target-mask training enabled by default

Dataset-specific values are centralized in dataset_defaults.py; where a value is explicitly listed in that preset table, it takes precedence over the generic defaults in config.py.
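The public recipe above can be sketched as a small Python fragment. This is an illustrative mirror of the defaults, not the actual layout of config.py, whose variable names and structure may differ:

```python
# Illustrative mirror of the canonical public recipe listed above.
# config.py in the repo is the authoritative source; names beyond those
# listed in the README (e.g. JOINT_TARGET_MASK_TRAINING) are assumptions.
PREDICT_TYPE = "v"                       # v-prediction parameterization
PRIMARY_EVAL_METRIC = "crps"             # metric used for model selection
LOSS_WEIGHT_SCHEME = "weighted_min_snr"  # diffusion loss weighting scheme
BASE_LR = 1.5e-4                         # base learning rate
DATES_PER_BATCH = 1                      # panel dates per training batch
JOINT_TARGET_MASK_TRAINING = True        # joint forecast/imputation masking
```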

Supported datasets

The preset registry currently supports:

  • bms_air
  • uci_air
  • physionet
  • noaa_us
  • noaa_uk
  • us_equity
  • crypto

Default horizons and context lengths:

| Dataset   | Horizons        | Context |
|-----------|-----------------|---------|
| bms_air   | 24, 48, 96, 168 | 336     |
| uci_air   | 24, 48, 96, 168 | 336     |
| physionet | 4, 8, 10, 12    | 24      |
| noaa_us   | 24, 48, 96, 168 | 336     |
| noaa_uk   | 24, 48, 96, 168 | 336     |
| us_equity | 5, 20, 60, 100  | 200     |
| crypto    | 5, 20, 60, 100  | 200     |
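The preset table above can be mirrored programmatically. The sketch below is a hypothetical, simplified view of what dataset_defaults.py registers; the real table likely carries additional per-dataset fields:

```python
# Hypothetical mirror of the horizon/context presets listed above.
# dataset_defaults.py is the authoritative source; only the horizons and
# context lengths here are taken from the README table.
PRESETS = {
    "bms_air":   {"horizons": [24, 48, 96, 168], "context": 336},
    "uci_air":   {"horizons": [24, 48, 96, 168], "context": 336},
    "physionet": {"horizons": [4, 8, 10, 12],    "context": 24},
    "noaa_us":   {"horizons": [24, 48, 96, 168], "context": 336},
    "noaa_uk":   {"horizons": [24, 48, 96, 168], "context": 336},
    "us_equity": {"horizons": [5, 20, 60, 100],  "context": 200},
    "crypto":    {"horizons": [5, 20, 60, 100],  "context": 200},
}

def preset(dataset_key: str) -> dict:
    """Return the horizon/context preset for a dataset key (KeyError if unknown)."""
    return PRESETS[dataset_key]
```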

Financial ticker lists used by the cache builders live at:

Environment

The training stack expects Python 3.11 plus the standard scientific/PyTorch stack used throughout the repo, including:

  • torch
  • numpy
  • pandas
  • matplotlib
  • pyarrow
  • fastparquet
  • yfinance
  • requests
  • tqdm

Preparing dataset caches

Main cache-builder entrypoints:

The financial loaders, and several other dataset loaders, can also rebuild the window index on demand via run_experiment(..., reindex=True).

Checking dataset statistics

Use the public cache summary tool:

cd /path/to/LLapDiff
python Dataset/dataset_summary.py \
  --data-dir /path/to/LLapDiff/Dataset/fin_dataset/crypto \
  --coverage 0.0 \
  --per-asset

This reads the prepared cache_ratio_index/ tree and reports panel size, split counts, missingness, and coverage-sensitive step counts.

Preparing VAE and summarizer artifacts

To dry-run the multi-dataset artifact plan:

cd /path/to/LLapDiff
python run_multidataset_artifact_prep.py \
  --dry-run \
  --datasets physionet bms_air

To train or reuse VAE and summarizer artifacts and emit a health report:

cd /path/to/LLapDiff
python run_multidataset_artifact_prep.py \
  --datasets bms_air uci_air physionet noaa_us noaa_uk us_equity \
  --summary-json /tmp/multidataset_artifact_prep_summary.json

Artifacts are written under:

  • ldt/vae/saved_model/<dataset>/
  • ldt/summarizer/saved_model/<dataset>/

Canonical training

The canonical training entrypoint is train_val_pipeline.py.

Example: train the full crypto stack for all preset horizons:

cd /path/to/LLapDiff
python train_val_pipeline.py \
  --dataset-key crypto \
  --summary-json /tmp/crypto_pipeline_summary.json

Example: run only one horizon and force artifact recomputation:

cd /path/to/LLapDiff
python train_val_pipeline.py \
  --dataset-key us_equity \
  --preds 100 \
  --recompute-vae \
  --recompute-summarizer \
  --summary-json /tmp/us_equity_pred100.json

Evaluating a checkpoint

Use the generic evaluator:

cd /path/to/LLapDiff
python llapdiff_checkpoint_eval.py \
  --dataset-key crypto \
  --pred 100 \
  --checkpoint /path/to/LLapDiff/ldt/output/crypto/llapdiff_pred-100_best_raw.pt \
  --out-json /tmp/crypto_eval.json

The output includes:

  • forecast_test
  • regular_keep25
  • random_keep50
  • balanced_summary
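A consumer of the --out-json file can key off the sections listed above. The snippet below is a minimal sketch: the four top-level section names come from this README, but the inner structure (e.g. a "crps" entry per section) is an assumption for illustration, demonstrated on a stubbed report rather than a real evaluation file:

```python
import json

# Top-level sections the evaluator is documented to emit.
SECTIONS = ["forecast_test", "regular_keep25", "random_keep50", "balanced_summary"]

def summarize(report: dict) -> dict:
    """Extract the known sections from an evaluation report, skipping absent ones."""
    return {name: report[name] for name in SECTIONS if name in report}

# Stubbed report standing in for json.load(open(".../crypto_eval.json"));
# the inner "crps" values are hypothetical.
stub = {"forecast_test": {"crps": 0.12}, "balanced_summary": {"crps": 0.10}}
print(json.dumps(summarize(stub), indent=2))
```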

Pole visualization

Pole plotting is handled by Viz/plot_llapdiff_poles.py.

Example:

cd /path/to/LLapDiff
python Viz/plot_llapdiff_poles.py \
  --dataset-key crypto \
  --pred 100 \
  --checkpoint /path/to/LLapDiff/ldt/output/crypto_cov0_jointmix_vpred_dates1_lr15e4/llapdiff_pred-100_best_raw.pt \
  --output-dir /tmp/pole_plot_smoke

This writes a PDF into the requested output directory.

How to tune the model

Use this order when changing the training recipe:

  1. dataset/cache sanity
  2. normalization
  3. training objective / parameterization
  4. architecture last

For practical tuning in this repo:
