A record of the code and experiments for the paper:
C. Durkan, A. Bekasov, I. Murray, G. Papamakarios, Neural Spline Flows, NeurIPS 2019.
Work on this repository has stopped. Please use nflows instead: it is an updated, pip-installable normalizing-flows framework for PyTorch.
See environment.yml for the required Conda/pip packages, or use it to create a Conda environment with all dependencies:
conda env create -f environment.yml
Tested with Python 3.5 and PyTorch 1.1.
Data for density-estimation experiments is available at https://zenodo.org/record/1161203#.Wmtf_XVl8eN.
Data for VAE and image-modeling experiments is downloaded automatically using either torchvision or custom
data providers (NB: see the update below).
UPDATE (19 Oct 2025): Data for image-modeling experiments has been moved off Google Drive and uploaded to Zenodo (https://zenodo.org/records/17390372). Custom data providers will no longer work out of the box. To fix them, note that filenames on Zenodo start with the original Google Drive IDs as used in the data provider definitions (for example, here).
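Because each Zenodo filename begins with the Google Drive ID that a data provider already references, one way to patch a provider is to look the file up by that prefix in a local download directory. The sketch below assumes the Zenodo files have been downloaded into a single directory; `find_by_drive_id` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path


def find_by_drive_id(directory, drive_id):
    """Return the single file in `directory` whose name starts with `drive_id`.

    Hypothetical helper: assumes the Zenodo files were downloaded flat into
    `directory` and keep the original Google Drive ID as a filename prefix.
    """
    matches = sorted(
        p for p in Path(directory).iterdir()
        if p.is_file() and p.name.startswith(drive_id)
    )
    if not matches:
        raise FileNotFoundError(
            "no file starting with {!r} in {}".format(drive_id, directory))
    if len(matches) > 1:
        # Refuse to guess if the prefix is not unique.
        raise RuntimeError("ambiguous Drive ID {!r}: {}".format(
            drive_id, [m.name for m in matches]))
    return matches[0]
```

A provider that previously fetched a file by its Drive ID could then read `find_by_drive_id(download_dir, drive_id)` instead; `download_dir` and the exact call site are up to you.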
The DATAROOT environment variable must be set before running any experiments.
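A minimal sketch of failing early when DATAROOT is missing, assuming datasets live in subdirectories of DATAROOT. The `dataset_dir` helper and the subdirectory name are illustrative assumptions, not the repository's own code or layout.

```python
import os
from pathlib import Path


def dataset_dir(name):
    """Resolve a dataset directory under $DATAROOT.

    Sketch only: the `<DATAROOT>/<name>` layout is a hypothetical example.
    """
    root = os.environ.get("DATAROOT")
    if not root:
        # Fail before any training starts rather than deep inside a data loader.
        raise EnvironmentError(
            "DATAROOT is not set; run e.g. `export DATAROOT=/path/to/data` first")
    return Path(root) / name
```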
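For the synthetic plane and face experiments, use experiments/plane.py or experiments/face.py.
For the UCI density-estimation experiments, use experiments/uci.py.
For the VAE experiments, use experiments/vae_.py.
For the image-modeling experiments, use experiments/images.py.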
Image experiments are organized with Sacred; see the Sacred documentation for details.
experiments/image_configs contains the .json configurations used for the RQ-NSF (C) experiments. For the affine baselines, additionally set coupling_layer_type='affine'.
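Conceptually, an affine baseline is the same JSON configuration with one entry replaced. The sketch below mimics that `with <config>.json key=value` behaviour for top-level keys only; it is an illustration of the idea, not Sacred's implementation (Sacred also handles nested and dotted keys), and `load_config` is a hypothetical name.

```python
import json


def load_config(path, **overrides):
    """Read a JSON experiment config, then let keyword overrides replace
    top-level entries, as a command-line `key=value` update would."""
    with open(path) as f:
        config = json.load(f)
    config.update(overrides)  # later updates win over values from the file
    return config
```

For example, `load_config("experiments/image_configs/cifar-10-8bit.json", coupling_layer_type="affine")` yields the baseline configuration, identical to the RQ-NSF (C) one except for the coupling layer type.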
For example, to run RQ-NSF (C) on CIFAR-10 8-bit:
python experiments/images.py with experiments/image_configs/cifar-10-8bit.json
Corresponding affine baseline run:
python experiments/images.py with experiments/image_configs/cifar-10-8bit.json coupling_layer_type='affine'
To evaluate on the test set:
python experiments/images.py eval_on_test with experiments/image_configs/cifar-10-8bit.json flow_checkpoint='<saved_checkpoint>'
To sample:
python experiments/images.py sample with experiments/image_configs/cifar-10-8bit.json flow_checkpoint='<saved_checkpoint>'
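All four commands above share the pattern `python experiments/images.py [command] with <config> [key=value ...]`. If you script several runs, a small builder like the hypothetical `images_command` below keeps them consistent. It passes `key=value` without quotes, which mirrors what the shell hands to the script after stripping the single quotes in the commands above; treat it as a sketch rather than part of the repository.

```python
CONFIG = "experiments/image_configs/cifar-10-8bit.json"


def images_command(config, command=None, **updates):
    """Build an argv list for experiments/images.py in Sacred's
    `[command] with <config> key=value ...` form."""
    argv = ["python", "experiments/images.py"]
    if command is not None:
        argv.append(command)  # e.g. "eval_on_test" or "sample"
    argv += ["with", config]
    # Sorted for a deterministic argument order.
    argv += ["{}={}".format(k, v) for k, v in sorted(updates.items())]
    return argv
```

For example, `subprocess.run(images_command(CONFIG, "eval_on_test", flow_checkpoint="<saved_checkpoint>"), check=True)` would run the test-set evaluation, with the placeholder replaced by a real checkpoint path.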