TiDHy provides two conda environments optimized for different use cases:
The TiDHy environment includes JAX/Flax, RAPIDS for GPU-accelerated operations, and all necessary dependencies.
One-command setup:
```bash
bash setup_tidhy_env.sh
```

This script will:
- Create the conda environment with Python 3.13, RAPIDS 25.10, and CUDA 12.x support
- Install all Python packages using uv (a fast package installer and dependency resolver)
- Install TiDHy as an editable package
- Verify the installation and check GPU/CUDA availability
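If you want to repeat the import part of the verification step yourself, you can run a short Python check inside the activated `tidhy` environment. This is a minimal sketch and not the contents of `setup_tidhy_env.sh`. The module list is an assumption based on the packages installed in the manual setup below, and `check_imports` is a hypothetical helper. It uses `importlib.util.find_spec`, which locates top-level modules without importing them.

```python
import importlib.util

# Assumed top-level import names for the TiDHy dependency set.
# Import names can differ from pip names (e.g. scikit-learn -> sklearn, hydra-core -> hydra).
REQUIRED_MODULES = [
    "jax", "flax", "optax", "orbax", "chex", "dynamax",
    "sklearn", "hydra", "omegaconf", "wandb", "tqdm", "natsort", "TiDHy",
]


def check_imports(modules):
    """Map each top-level module name to True if Python can locate it, else False."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    status = check_imports(REQUIRED_MODULES)
    missing = [name for name, ok in status.items() if not ok]
    print("Missing modules:", missing or "none")
```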
Activate the environment:
```bash
conda activate tidhy
```

For running SSM baseline comparisons (ARHMM, SLDS, etc.), set up the separate `ssm` environment:
```bash
bash setup_ssm_env.sh
```

Activate the environment:
```bash
conda activate ssm
```

If you prefer manual setup:
- Create the conda environment:

  ```bash
  conda env create -f environment.yaml     # For TiDHy
  # OR
  conda env create -f ssm_environment.yml  # For SSM baselines
  ```

- Activate the environment:

  ```bash
  conda activate tidhy  # or 'ssm'
  ```

- Install Python packages with uv:

  ```bash
  uv pip install --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
      'jax[cuda12]>=0.4.20' 'jaxlib>=0.4.20' 'flax>=0.8.0' \
      'optax>=0.1.7' 'orbax-checkpoint>=0.4.0' 'chex>=0.1.8' \
      'dynamax>=1.0.0' 'scikit-learn>=1.3.0' \
      'hydra-core>=1.3.0' 'omegaconf>=2.3.0' 'wandb' \
      'tqdm>=4.65.0' 'natsort' -e .
  ```
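After the manual install, you can confirm that the installed distributions meet the lower bounds from the `uv pip install` command. The sketch below is illustrative: `numeric_version`, `meets_minimum`, and `check_min_versions` are hypothetical helpers, and the comparison deliberately keeps only the leading numeric part of each version (so `1.3.0rc1` is treated like `1.3.0`). For exact PEP 440 semantics, `packaging.version.Version` is the more robust choice.

```python
from importlib.metadata import PackageNotFoundError, version

# Lower bounds copied from the `uv pip install` command (distribution names, not import names).
MIN_VERSIONS = {
    "jax": "0.4.20", "jaxlib": "0.4.20", "flax": "0.8.0", "optax": "0.1.7",
    "orbax-checkpoint": "0.4.0", "chex": "0.1.8", "dynamax": "1.0.0",
    "scikit-learn": "1.3.0", "hydra-core": "1.3.0", "omegaconf": "2.3.0",
    "tqdm": "4.65.0",
}


def numeric_version(ver):
    """Parse the leading numeric part of a version string, e.g. '1.3.0rc1' -> (1, 3, 0)."""
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:  # stop at the first pre-release or local suffix
            break
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare numeric versions after zero-padding them to the same length."""
    a, b = numeric_version(installed), numeric_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_min_versions(min_versions):
    """Return {name: (installed_version_or_None, ok)} for each distribution."""
    report = {}
    for name, minimum in min_versions.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_min_versions(MIN_VERSIONS).items():
        print(f"{name:18} {installed or 'not installed':16} {'OK' if ok else 'CHECK'}")
```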
- Conda/Miniconda: Required for environment management
- CUDA 12.x: For GPU acceleration (check with `nvidia-smi`)
- Python 3.13: Installed automatically by conda
- TensorFlow Probability + JAX 0.8+: TFP 0.25.0 requires a compatibility patch for JAX 0.8+. The patch is applied automatically in all entry-point scripts (`Run_TiDHy_NNX_vmap.py`, etc.) via `TiDHy.utils.tfp_jax_patch.apply_tfp_jax_patch()`. If you import TiDHy modules directly, call `apply_tfp_jax_patch()` before importing any other TiDHy modules.
- JAX: Version 0.8+ recommended for Python 3.13 support
- RAPIDS: Version 25.10 for latest features and Python 3.13 support
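For a custom script that imports TiDHy modules directly, the patch call should come first. The sketch below shows one way to order this. `ensure_tfp_patch` is a hypothetical wrapper; it assumes that `apply_tfp_jax_patch()` takes no arguments (as written in the note above) and that importing `TiDHy.utils.tfp_jax_patch` on its own does not pull in TFP-dependent code. It returns `False` instead of failing when TiDHy is not installed.

```python
def ensure_tfp_patch() -> bool:
    """Apply the TFP/JAX compatibility patch if TiDHy is importable.

    Returns True if the patch was applied, False if TiDHy could not be imported.
    Assumes apply_tfp_jax_patch() takes no arguments.
    """
    try:
        from TiDHy.utils.tfp_jax_patch import apply_tfp_jax_patch
    except ImportError:
        return False
    apply_tfp_jax_patch()
    return True


if __name__ == "__main__":
    # Patch first; import any other TiDHy modules only after this line.
    print("TFP/JAX patch applied:", ensure_tfp_patch())
```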
Check if JAX can detect your GPU:
```bash
python -c "import jax; print(jax.devices())"
```

The output should list CUDA/GPU devices if the environment is configured correctly.
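The same check can be made slightly more explicit from Python by also printing JAX's default backend, which is `"gpu"` when JAX is running on a CUDA device and `"cpu"` when no GPU is visible. This is a small sketch; `describe_jax_backend` is a hypothetical helper that returns `None` instead of raising if JAX is not installed. It uses only `jax.default_backend()` and `jax.devices()`.

```python
def describe_jax_backend():
    """Return (default_backend, [device strings]), or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend(), [str(d) for d in jax.devices()]


if __name__ == "__main__":
    info = describe_jax_backend()
    if info is None:
        print("JAX is not installed in this environment.")
    else:
        backend, devices = info
        print("Default backend:", backend)
        print("Devices:", devices)
        if backend != "gpu":
            print("No GPU is visible to JAX; check the CUDA installation and jax[cuda12].")
```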
Run the main training script with Hydra configuration overrides:
```bash
python Run_TiDHy_NNX_vmap.py dataset=SLDS model=sparsity
```

Available datasets: SLDS, SSM, Rossler, AnymalTerrain, CalMS21
Available model configs: default_model, sparsity, r2_sparse,
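Hydra overrides are plain `key=value` command-line arguments, so running the training script over several datasets amounts to building one argument list per dataset. The sketch below only builds and prints the commands; `build_commands` is a hypothetical helper, and each resulting list can be passed to `subprocess.run` if you want to launch the runs sequentially. Hydra also provides a built-in `--multirun` mode for sweeps, which may work here if the script uses `@hydra.main`, but that is not confirmed in this README.

```python
import shlex


def build_commands(script, datasets, overrides=None):
    """Build one argv list per dataset, appending extra Hydra key=value overrides."""
    overrides = overrides or {}
    extra = [f"{key}={value}" for key, value in overrides.items()]
    return [["python", script, f"dataset={name}", *extra] for name in datasets]


if __name__ == "__main__":
    commands = build_commands(
        "Run_TiDHy_NNX_vmap.py", ["SLDS", "SSM", "Rossler"], {"model": "sparsity"}
    )
    for argv in commands:
        # Print a shell-safe version of each command; pass argv to subprocess.run to execute it.
        print(shlex.join(argv))
```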
To run an SSM baseline (requires the `ssm` environment):

```bash
conda activate ssm
python Run_SSM.py dataset=SLDS
```

To add a custom dataset, you can load the data however you like. The final arrays should follow this convention:
- `train_data`: (time x features)
- `val_data`: (time x features)
- `test_data`: (time x features)
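If your custom data arrives as a single recording, one way to produce the three arrays is a contiguous, chronological split. This is a sketch under assumptions: `split_time_series` is a hypothetical helper, and the 15% validation and test fractions are arbitrary defaults. Splitting without shuffling keeps later time steps out of the training set, which matters once the data are cut into overlapping windows.

```python
import numpy as np


def split_time_series(data, val_frac=0.15, test_frac=0.15):
    """Split a (time x features) array into contiguous train/val/test blocks.

    The split is chronological (no shuffling): train comes first, then val, then test.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"expected a 2-D (time x features) array, got shape {data.shape}")
    n_time = data.shape[0]
    n_test = int(round(n_time * test_frac))
    n_val = int(round(n_time * val_frac))
    n_train = n_time - n_val - n_test
    if n_train <= 0:
        raise ValueError("not enough time steps for the requested split")
    train_data = data[:n_train]
    val_data = data[n_train:n_train + n_val]
    test_data = data[n_train + n_val:]
    return train_data, val_data, test_data


if __name__ == "__main__":
    recording = np.random.randn(1000, 3)  # synthetic (time x features) data
    train_data, val_data, test_data = split_time_series(recording)
    print(train_data.shape, val_data.shape, test_data.shape)
```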
The data can then be stacked into overlapping windows using the `stack_data` function:

```python
train_inputs = stack_data(
    train_inputs,
    cfg.train.sequence_length,
    overlap=cfg.train.sequence_length // cfg.train.overlap_factor,
)
```
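To illustrate what overlapping windows look like, here is a standalone sketch of the idea. It is not TiDHy's `stack_data`: `stack_windows` is a hypothetical function, and because the README does not say whether `overlap` is the step between window starts or the number of shared time steps, the sketch uses an explicit `stride` (the step between window starts) instead.

```python
import numpy as np


def stack_windows(data, sequence_length, stride):
    """Cut a (time x features) array into windows of shape (sequence_length x features).

    Consecutive windows start `stride` steps apart, so they share
    sequence_length - stride time steps when stride < sequence_length.
    Trailing steps that do not fill a complete window are dropped.
    Returns an array of shape (n_windows, sequence_length, features).
    """
    data = np.asarray(data)
    n_time = data.shape[0]
    if not 0 < sequence_length <= n_time:
        raise ValueError("sequence_length must be between 1 and the number of time steps")
    if stride <= 0:
        raise ValueError("stride must be positive")
    starts = range(0, n_time - sequence_length + 1, stride)
    return np.stack([data[s:s + sequence_length] for s in starts])


if __name__ == "__main__":
    data = np.random.randn(1000, 3)
    windows = stack_windows(data, sequence_length=200, stride=200 // 2)
    print(windows.shape)  # (9, 200, 3)
```

With `stride = sequence_length // 2`, each window shares half of its time steps with the next one; a larger divisor gives more overlap and more training windows from the same recording.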