diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py
index 3f4ac0c6..7f155dbb 100644
--- a/generative/networks/schedulers/ddim.py
+++ b/generative/networks/schedulers/ddim.py
@@ -225,6 +225,66 @@ def step(
         return pred_prev_sample, pred_original_sample
 
+    def reversed_step(
+        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+
+        Returns:
+            pred_post_sample: Predicted next sample
+            pred_original_sample: Predicted original sample
+        """
+        # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
+
+        # Notation (<variable name> -> <name in paper>)
+        # - model_output -> e_theta(x_t, t)
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+        # - pred_post_sample -> "x_t+1"
+
+        # 1. get next step value (=t+1)
+        prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+        # 2. compute alphas, betas at timestep t+1
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+        beta_prod_t = 1 - alpha_prod_t
+
+        # 3.
compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + + if self.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + return pred_post_sample, pred_original_sample + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: """ Add noise to the original samples. 
diff --git a/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb b/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb new file mode 100644 index 00000000..c092372b --- /dev/null +++ b/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb @@ -0,0 +1,960 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "63d95da6", + "metadata": {}, + "source": [ + "# Weakly Supervised Anomaly Detection with Classifier Guidance\n", + "\n", + "This tutorial illustrates how to use MONAI Generative Models for training a 2D gradient-guided anomaly detection using DDIMs [1].\n", + "\n", + "In summary, the tutorial will cover the following:\n", + "1. Loading and preprocessing a dataset (we extract the brain MRI dataset 2D slices from 3D volumes from the BraTS dataset)\n", + "2. Training a 2D diffusion model\n", + "3. Anomaly detection with the trained model\n", + "\n", + "This method results in anomaly heatmaps. It is weakly supervised. The information about labels is not fed to the model as segmentation masks but as a scalar signal (is there an anomaly or not), which is used to guide the diffusion process.\n", + "\n", + "During inference, the model generates a counterfactual image, which is then compared to the original image. The difference between the two images is used to generate an anomaly heatmap.\n", + "\n", + "[1] - Sanchez et al. [What is Healthy? Generative Counterfactual Diffusion for Lesion Localization](https://arxiv.org/abs/2207.12268). 
DGM 4 MICCAI 2022"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6b766027",
+   "metadata": {},
+   "source": [
+    "## Setup imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "972ed3f3",
+   "metadata": {
+    "collapsed": false,
+    "jupyter": {
+     "outputs_hidden": false
+    },
+    "lines_to_next_cell": 2
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "MONAI version: 1.1.0\n",
+      "Numpy version: 1.23.5\n",
+      "Pytorch version: 1.13.1+cu117\n",
+      "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
+      "MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3\n",
+      "MONAI __file__: /remote/rds/users/s2086085/miniconda3/envs/pytorch_monai/lib/python3.10/site-packages/monai/__init__.py\n",
+      "\n",
+      "Optional dependencies:\n",
+      "Pytorch Ignite version: 0.4.10\n",
+      "Nibabel version: 5.0.1\n",
+      "scikit-image version: 0.19.3\n",
+      "Pillow version: 9.4.0\n",
+      "Tensorboard version: 2.12.0\n",
+      "gdown version: 4.6.4\n",
+      "TorchVision version: 0.14.1+cu117\n",
+      "tqdm version: 4.64.1\n",
+      "lmdb version: 1.4.0\n",
+      "psutil version: 5.9.4\n",
+      "pandas version: 1.5.3\n",
+      "einops version: 0.6.0\n",
+      "transformers version: 4.21.3\n",
+      "mlflow version: 2.1.1\n",
+      "pynrrd version: 1.0.0\n",
+      "\n",
+      "For details about installing the optional dependencies, please visit:\n",
+      " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Copyright 2020 MONAI Consortium\n",
+    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "# you may not use this file except in compliance with the License.\n",
+    "# You may obtain a copy of the License at\n",
+    "# http://www.apache.org/licenses/LICENSE-2.0\n",
+    "# Unless required by applicable law or agreed to in writing, software\n",
+    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "import os\n", + "import tempfile\n", + "import time\n", + "from typing import Dict\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import sys\n", + "from monai import transforms\n", + "from monai.apps import DecathlonDataset\n", + "from monai.config import print_config\n", + "from monai.data import DataLoader\n", + "from monai.utils import first, set_determinism\n", + "from torch.cuda.amp import GradScaler, autocast\n", + "from tqdm import tqdm\n", + "torch.multiprocessing.set_sharing_strategy('file_system')\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", + "from generative.inferers import DiffusionInferer\n", + "\n", + "from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet\n", + "from generative.networks.schedulers.ddim import DDIMScheduler\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "id": "7d4ff515", + "metadata": {}, + "source": [ + "## Setup data directory" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8b4323e7", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory" + ] + }, + { + "cell_type": "markdown", + "id": "99175d50", + "metadata": {}, + "source": [ + "## Set deterministic training for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "34ea510f", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "set_determinism(42)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c3f70dd1-236a-47ff-a244-575729ad92ba", + 
"metadata": {
+    "tags": []
+   },
+   "source": [
+    "## Setup BRATS Dataset - Transforms for extracting 2D slices from 3D volumes\n",
+    "\n",
+    "We now download the BraTS dataset and extract the 2D slices from the 3D volumes. The `slice_label` is used to indicate whether the slice contains an anomaly or not."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "6986f55c",
+   "metadata": {},
+   "source": [
+    "Here we use transforms to augment the training dataset, as usual:\n",
+    "\n",
+    "1. `LoadImaged` loads the brain images from files.\n",
+    "2. `EnsureChannelFirstd` ensures the original data to construct \"channel first\" shape.\n",
+    "3. The first `Lambdad` transform chooses the first channel of the image, which is the FLAIR image.\n",
+    "4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm to match the original paper.\n",
+    "5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n",
+    "6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n",
+    "7. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2` )."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "c68d2d91-9a0b-4ac1-ae49-f4a64edbd82a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      ": Class `AddChannel` has been deprecated since version 0.8.
please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.\n" + ] + } + ], + "source": [ + "channel = 0 # 0 = Flair\n", + "assert channel in [0, 1, 2, 3], \"Choose a valid channel\"\n", + "\n", + "train_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\", \"label\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", + " transforms.Lambdad(keys=[\"image\"], func=lambda x: x[channel, :, :, :]),\n", + " transforms.AddChanneld(keys=[\"image\"]),\n", + " transforms.EnsureTyped(keys=[\"image\", \"label\"]),\n", + " transforms.Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " transforms.Spacingd(keys=[\"image\", \"label\"], pixdim=(3.0, 3.0, 2.0), mode=(\"bilinear\", \"nearest\")),\n", + " transforms.CenterSpatialCropd(keys=[\"image\", \"label\"], roi_size=(64, 64, 44)),\n", + " transforms.ScaleIntensityRangePercentilesd(keys=\"image\", lower=0, upper=99.5, b_min=0, b_max=1),\n", + " transforms.RandSpatialCropd(keys=[\"image\", \"label\"], roi_size=(64, 64, 1), random_size=False),\n", + " transforms.Lambdad(keys=[\"image\", \"label\"], func=lambda x: x.squeeze(-1)),\n", + " transforms.CopyItemsd(keys=[\"label\"], times=1, names=[\"slice_label\"]),\n", + " transforms.Lambdad(keys=[\"slice_label\"], func=lambda x: 2.0 if x.sum() > 0 else 1.0),\n", + " ]\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "9d378ac6", + "metadata": {}, + "source": [ + "### Load Training and Validation Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "da1927b0", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████| 388/388 [02:41<00:00, 2.40it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Lenght of training data: 388\n", + "Train image shape torch.Size([1, 
64, 64])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████| 96/96 [00:39<00:00, 2.42it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Lenght of training data: 96\n", + "Validation Image shape torch.Size([1, 64, 64])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "\n", + "train_ds = DecathlonDataset(\n", + " root_dir=root_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"training\", # validation\n", + " cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n", + " num_workers=4,\n", + " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", + " seed=0,\n", + " transform=train_transforms,\n", + ")\n", + "print(f\"Lenght of training data: {len(train_ds)}\")\n", + "print(f'Train image shape {train_ds[0][\"image\"].shape}')\n", + "\n", + "val_ds = DecathlonDataset(\n", + " root_dir=root_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"validation\", # validation\n", + " cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n", + " num_workers=4,\n", + " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", + " seed=0,\n", + " transform=train_transforms,\n", + ")\n", + "print(f\"Lenght of training data: {len(val_ds)}\")\n", + "print(f'Validation Image shape {val_ds[0][\"image\"].shape}')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "08428bc6", + "metadata": {}, + "source": [ + "## Define network, scheduler, optimizer, and inferer\n", + "\n", + "At this step, we instantiate the MONAI components to create a DDIM, the UNET with conditioning, the noise scheduler, and the inferer used for training and sampling. 
We are using\n", + "the deterministic DDIM scheduler containing 1000 timesteps, and a 2D UNET with attention mechanisms.\n", + "\n", + "The `attention` mechanism is essential for ensuring good conditioning and images manipulation here. \n", + "\n", + "An `embedding layer`, which is also optimised during training, is used in the original work because it was empirically shown to improve conditioning compared to a single scalar information.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bee5913e", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\")\n", + "embedding_dimension = 64\n", + "model = DiffusionModelUNet(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " num_channels=(64, 64, 64),\n", + " attention_levels=(False, True, True),\n", + " num_res_blocks=1,\n", + " num_head_channels=16,\n", + " with_conditioning=True,\n", + " cross_attention_dim=embedding_dimension,\n", + " ).to(device)\n", + "embed = torch.nn.Embedding(num_embeddings=3, embedding_dim=embedding_dimension, padding_idx=0).to(device)\n", + "\n", + "scheduler = DDIMScheduler(\n", + " num_train_timesteps=1000,\n", + ")\n", + "optimizer = torch.optim.Adam(params=list(model.parameters()) + list(embed.parameters()), lr=1e-5)\n", + "\n", + "inferer = DiffusionInferer(scheduler)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f815ff34", + "metadata": {}, + "source": [ + "## Training a diffusion model with classifier-free guidance" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9a4fc901", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss 0.8078, Interval Loss 0.9115, Interval Loss Val 0.8100\n", + "Train Loss 0.6174, Interval Loss 0.7103, Interval Loss Val 0.6154\n", + "Train Loss 0.4603, Interval Loss 0.5313, Interval Loss 
Val 0.4571\n", + "Train Loss 0.3323, Interval Loss 0.3887, Interval Loss Val 0.3288\n", + "Train Loss 0.2358, Interval Loss 0.2822, Interval Loss Val 0.2356\n", + "Train Loss 0.1771, Interval Loss 0.2034, Interval Loss Val 0.1660\n", + "Train Loss 0.1237, Interval Loss 0.1467, Interval Loss Val 0.1185\n", + "Train Loss 0.0802, Interval Loss 0.1054, Interval Loss Val 0.0855\n", + "Train Loss 0.0731, Interval Loss 0.0770, Interval Loss Val 0.0699\n", + "Train Loss 0.0528, Interval Loss 0.0570, Interval Loss Val 0.0512\n", + "Train Loss 0.0383, Interval Loss 0.0434, Interval Loss Val 0.0311\n", + "Train Loss 0.0268, Interval Loss 0.0343, Interval Loss Val 0.0483\n", + "Train Loss 0.0330, Interval Loss 0.0292, Interval Loss Val 0.0255\n", + "Train Loss 0.0359, Interval Loss 0.0250, Interval Loss Val 0.0293\n", + "Train Loss 0.0235, Interval Loss 0.0233, Interval Loss Val 0.0318\n", + "Train Loss 0.0241, Interval Loss 0.0224, Interval Loss Val 0.0303\n", + "Train Loss 0.0171, Interval Loss 0.0211, Interval Loss Val 0.0217\n", + "Train Loss 0.0304, Interval Loss 0.0207, Interval Loss Val 0.0098\n", + "Train Loss 0.0269, Interval Loss 0.0211, Interval Loss Val 0.0128\n", + "Train Loss 0.0163, Interval Loss 0.0194, Interval Loss Val 0.0233\n", + "Train Loss 0.0319, Interval Loss 0.0195, Interval Loss Val 0.0389\n", + "Train Loss 0.0200, Interval Loss 0.0194, Interval Loss Val 0.0214\n", + "Train Loss 0.0190, Interval Loss 0.0203, Interval Loss Val 0.0235\n", + "Train Loss 0.0291, Interval Loss 0.0194, Interval Loss Val 0.0178\n", + "Train Loss 0.0263, Interval Loss 0.0196, Interval Loss Val 0.0248\n", + "Train Loss 0.0229, Interval Loss 0.0198, Interval Loss Val 0.0246\n", + "Train Loss 0.0183, Interval Loss 0.0195, Interval Loss Val 0.0239\n", + "Train Loss 0.0147, Interval Loss 0.0188, Interval Loss Val 0.0216\n", + "Train Loss 0.0255, Interval Loss 0.0181, Interval Loss Val 0.0150\n", + "Train Loss 0.0161, Interval Loss 0.0183, Interval Loss Val 0.0211\n", + "Train Loss 
0.0239, Interval Loss 0.0182, Interval Loss Val 0.0097\n", + "Train Loss 0.0163, Interval Loss 0.0188, Interval Loss Val 0.0173\n", + "Train Loss 0.0188, Interval Loss 0.0177, Interval Loss Val 0.0121\n", + "Train Loss 0.0120, Interval Loss 0.0189, Interval Loss Val 0.0191\n", + "Train Loss 0.0201, Interval Loss 0.0182, Interval Loss Val 0.0130\n", + "Train Loss 0.0118, Interval Loss 0.0187, Interval Loss Val 0.0309\n", + "Train Loss 0.0114, Interval Loss 0.0181, Interval Loss Val 0.0221\n", + "Train Loss 0.0197, Interval Loss 0.0180, Interval Loss Val 0.0118\n", + "Train Loss 0.0229, Interval Loss 0.0176, Interval Loss Val 0.0278\n", + "Train Loss 0.0242, Interval Loss 0.0188, Interval Loss Val 0.0126\n", + "Train Loss 0.0166, Interval Loss 0.0182, Interval Loss Val 0.0157\n", + "Train Loss 0.0162, Interval Loss 0.0195, Interval Loss Val 0.0170\n", + "Train Loss 0.0124, Interval Loss 0.0179, Interval Loss Val 0.0261\n", + "Train Loss 0.0151, Interval Loss 0.0183, Interval Loss Val 0.0223\n", + "Train Loss 0.0308, Interval Loss 0.0188, Interval Loss Val 0.0151\n", + "Train Loss 0.0210, Interval Loss 0.0177, Interval Loss Val 0.0193\n", + "Train Loss 0.0175, Interval Loss 0.0184, Interval Loss Val 0.0232\n", + "Train Loss 0.0270, Interval Loss 0.0184, Interval Loss Val 0.0125\n", + "Train Loss 0.0128, Interval Loss 0.0181, Interval Loss Val 0.0224\n", + "Train Loss 0.0170, Interval Loss 0.0188, Interval Loss Val 0.0199\n", + "Train Loss 0.0203, Interval Loss 0.0176, Interval Loss Val 0.0145\n", + "Train Loss 0.0248, Interval Loss 0.0176, Interval Loss Val 0.0149\n", + "Train Loss 0.0213, Interval Loss 0.0188, Interval Loss Val 0.0120\n", + "Train Loss 0.0324, Interval Loss 0.0181, Interval Loss Val 0.0310\n", + "Train Loss 0.0302, Interval Loss 0.0183, Interval Loss Val 0.0119\n", + "Train Loss 0.0206, Interval Loss 0.0168, Interval Loss Val 0.0060\n", + "Train Loss 0.0117, Interval Loss 0.0176, Interval Loss Val 0.0208\n", + "Train Loss 0.0213, Interval Loss 
0.0172, Interval Loss Val 0.0178\n", + "Train Loss 0.0156, Interval Loss 0.0176, Interval Loss Val 0.0201\n", + "Train Loss 0.0134, Interval Loss 0.0180, Interval Loss Val 0.0116\n", + "Train Loss 0.0305, Interval Loss 0.0176, Interval Loss Val 0.0194\n", + "Train Loss 0.0253, Interval Loss 0.0171, Interval Loss Val 0.0230\n", + "Train Loss 0.0198, Interval Loss 0.0173, Interval Loss Val 0.0246\n", + "Train Loss 0.0139, Interval Loss 0.0175, Interval Loss Val 0.0146\n", + "Train Loss 0.0154, Interval Loss 0.0168, Interval Loss Val 0.0269\n", + "Train Loss 0.0086, Interval Loss 0.0174, Interval Loss Val 0.0132\n", + "Train Loss 0.0156, Interval Loss 0.0171, Interval Loss Val 0.0170\n", + "Train Loss 0.0292, Interval Loss 0.0174, Interval Loss Val 0.0165\n", + "Train Loss 0.0220, Interval Loss 0.0179, Interval Loss Val 0.0171\n", + "Train Loss 0.0132, Interval Loss 0.0175, Interval Loss Val 0.0116\n", + "Train Loss 0.0158, Interval Loss 0.0174, Interval Loss Val 0.0165\n", + "Train Loss 0.0325, Interval Loss 0.0172, Interval Loss Val 0.0235\n", + "Train Loss 0.0105, Interval Loss 0.0169, Interval Loss Val 0.0171\n", + "Train Loss 0.0082, Interval Loss 0.0173, Interval Loss Val 0.0146\n", + "Train Loss 0.0232, Interval Loss 0.0180, Interval Loss Val 0.0132\n", + "Train Loss 0.0120, Interval Loss 0.0170, Interval Loss Val 0.0275\n", + "Train Loss 0.0211, Interval Loss 0.0182, Interval Loss Val 0.0107\n", + "Train Loss 0.0247, Interval Loss 0.0169, Interval Loss Val 0.0162\n", + "Train Loss 0.0196, Interval Loss 0.0178, Interval Loss Val 0.0365\n", + "Train Loss 0.0247, Interval Loss 0.0173, Interval Loss Val 0.0185\n", + "Train Loss 0.0174, Interval Loss 0.0180, Interval Loss Val 0.0153\n", + "Train Loss 0.0203, Interval Loss 0.0185, Interval Loss Val 0.0208\n", + "Train Loss 0.0112, Interval Loss 0.0172, Interval Loss Val 0.0296\n", + "Train Loss 0.0215, Interval Loss 0.0165, Interval Loss Val 0.0155\n", + "Train Loss 0.0144, Interval Loss 0.0172, Interval Loss Val 
0.0194\n", + "Train Loss 0.0192, Interval Loss 0.0179, Interval Loss Val 0.0195\n", + "Train Loss 0.0175, Interval Loss 0.0178, Interval Loss Val 0.0092\n", + "Train Loss 0.0082, Interval Loss 0.0180, Interval Loss Val 0.0323\n", + "Train Loss 0.0234, Interval Loss 0.0168, Interval Loss Val 0.0118\n", + "Train Loss 0.0234, Interval Loss 0.0172, Interval Loss Val 0.0192\n", + "Train Loss 0.0088, Interval Loss 0.0172, Interval Loss Val 0.0262\n", + "Train Loss 0.0189, Interval Loss 0.0179, Interval Loss Val 0.0313\n", + "Train Loss 0.0081, Interval Loss 0.0182, Interval Loss Val 0.0181\n", + "Train Loss 0.0195, Interval Loss 0.0168, Interval Loss Val 0.0164\n", + "Train Loss 0.0280, Interval Loss 0.0166, Interval Loss Val 0.0132\n", + "Train Loss 0.0198, Interval Loss 0.0179, Interval Loss Val 0.0125\n", + "Train Loss 0.0182, Interval Loss 0.0167, Interval Loss Val 0.0208\n", + "Train Loss 0.0099, Interval Loss 0.0171, Interval Loss Val 0.0119\n", + "Train Loss 0.0271, Interval Loss 0.0169, Interval Loss Val 0.0156\n", + "Train Loss 0.0119, Interval Loss 0.0164, Interval Loss Val 0.0114\n", + "Train Loss 0.0172, Interval Loss 0.0165, Interval Loss Val 0.0228\n", + "Train Loss 0.0198, Interval Loss 0.0167, Interval Loss Val 0.0179\n", + "Train Loss 0.0165, Interval Loss 0.0164, Interval Loss Val 0.0110\n", + "Train Loss 0.0172, Interval Loss 0.0167, Interval Loss Val 0.0076\n", + "Train Loss 0.0140, Interval Loss 0.0171, Interval Loss Val 0.0172\n", + "Train Loss 0.0090, Interval Loss 0.0176, Interval Loss Val 0.0115\n", + "Train Loss 0.0120, Interval Loss 0.0167, Interval Loss Val 0.0202\n", + "Train Loss 0.0137, Interval Loss 0.0166, Interval Loss Val 0.0169\n", + "Train Loss 0.0113, Interval Loss 0.0171, Interval Loss Val 0.0131\n", + "Train Loss 0.0187, Interval Loss 0.0171, Interval Loss Val 0.0072\n", + "Train Loss 0.0302, Interval Loss 0.0158, Interval Loss Val 0.0169\n", + "Train Loss 0.0220, Interval Loss 0.0172, Interval Loss Val 0.0233\n", + "Train Loss 
0.0238, Interval Loss 0.0169, Interval Loss Val 0.0193\n", + "Train Loss 0.0077, Interval Loss 0.0173, Interval Loss Val 0.0176\n", + "Train Loss 0.0221, Interval Loss 0.0172, Interval Loss Val 0.0221\n", + "Train Loss 0.0106, Interval Loss 0.0161, Interval Loss Val 0.0170\n", + "Train Loss 0.0179, Interval Loss 0.0167, Interval Loss Val 0.0281\n", + "Train Loss 0.0147, Interval Loss 0.0164, Interval Loss Val 0.0167\n", + "Train Loss 0.0223, Interval Loss 0.0171, Interval Loss Val 0.0150\n", + "Train Loss 0.0206, Interval Loss 0.0176, Interval Loss Val 0.0159\n", + "Train Loss 0.0138, Interval Loss 0.0168, Interval Loss Val 0.0175\n", + "Train Loss 0.0202, Interval Loss 0.0165, Interval Loss Val 0.0108\n", + "Train Loss 0.0154, Interval Loss 0.0166, Interval Loss Val 0.0123\n", + "Train Loss 0.0225, Interval Loss 0.0175, Interval Loss Val 0.0133\n", + "Train Loss 0.0146, Interval Loss 0.0161, Interval Loss Val 0.0172\n", + "Train Loss 0.0166, Interval Loss 0.0166, Interval Loss Val 0.0120\n", + "Train Loss 0.0140, Interval Loss 0.0169, Interval Loss Val 0.0137\n", + "Train Loss 0.0165, Interval Loss 0.0171, Interval Loss Val 0.0142\n", + "Train Loss 0.0245, Interval Loss 0.0162, Interval Loss Val 0.0301\n", + "Train Loss 0.0207, Interval Loss 0.0162, Interval Loss Val 0.0135\n", + "Train Loss 0.0183, Interval Loss 0.0162, Interval Loss Val 0.0112\n", + "Train Loss 0.0124, Interval Loss 0.0161, Interval Loss Val 0.0128\n", + "Train Loss 0.0168, Interval Loss 0.0169, Interval Loss Val 0.0119\n", + "Train Loss 0.0042, Interval Loss 0.0180, Interval Loss Val 0.0175\n", + "Train Loss 0.0207, Interval Loss 0.0161, Interval Loss Val 0.0143\n", + "Train Loss 0.0195, Interval Loss 0.0157, Interval Loss Val 0.0217\n", + "Train Loss 0.0231, Interval Loss 0.0169, Interval Loss Val 0.0113\n", + "Train Loss 0.0193, Interval Loss 0.0152, Interval Loss Val 0.0221\n", + "Train Loss 0.0206, Interval Loss 0.0166, Interval Loss Val 0.0152\n", + "Train Loss 0.0133, Interval Loss 
0.0161, Interval Loss Val 0.0078\n", + "Train Loss 0.0172, Interval Loss 0.0162, Interval Loss Val 0.0109\n", + "Train Loss 0.0204, Interval Loss 0.0160, Interval Loss Val 0.0247\n", + "Train Loss 0.0113, Interval Loss 0.0159, Interval Loss Val 0.0142\n", + "Train Loss 0.0196, Interval Loss 0.0167, Interval Loss Val 0.0183\n", + "Train Loss 0.0088, Interval Loss 0.0156, Interval Loss Val 0.0139\n", + "Train Loss 0.0205, Interval Loss 0.0170, Interval Loss Val 0.0124\n", + "Train Loss 0.0149, Interval Loss 0.0175, Interval Loss Val 0.0309\n", + "Train Loss 0.0115, Interval Loss 0.0165, Interval Loss Val 0.0147\n", + "Train Loss 0.0133, Interval Loss 0.0166, Interval Loss Val 0.0148\n", + "Train Loss 0.0225, Interval Loss 0.0160, Interval Loss Val 0.0166\n", + "Train Loss 0.0147, Interval Loss 0.0163, Interval Loss Val 0.0147\n", + "Train Loss 0.0107, Interval Loss 0.0169, Interval Loss Val 0.0105\n", + "Train Loss 0.0108, Interval Loss 0.0159, Interval Loss Val 0.0139\n", + "Train Loss 0.0079, Interval Loss 0.0161, Interval Loss Val 0.0176\n", + "Train Loss 0.0101, Interval Loss 0.0168, Interval Loss Val 0.0085\n", + "Train Loss 0.0210, Interval Loss 0.0178, Interval Loss Val 0.0146\n", + "Train Loss 0.0146, Interval Loss 0.0174, Interval Loss Val 0.0203\n", + "Train Loss 0.0151, Interval Loss 0.0156, Interval Loss Val 0.0164\n", + "Train Loss 0.0092, Interval Loss 0.0163, Interval Loss Val 0.0126\n", + "Train Loss 0.0140, Interval Loss 0.0160, Interval Loss Val 0.0118\n", + "Train Loss 0.0164, Interval Loss 0.0168, Interval Loss Val 0.0174\n", + "Train Loss 0.0089, Interval Loss 0.0166, Interval Loss Val 0.0143\n", + "Train Loss 0.0165, Interval Loss 0.0177, Interval Loss Val 0.0122\n", + "Train Loss 0.0244, Interval Loss 0.0171, Interval Loss Val 0.0151\n", + "Train Loss 0.0148, Interval Loss 0.0163, Interval Loss Val 0.0209\n", + "Train Loss 0.0162, Interval Loss 0.0160, Interval Loss Val 0.0166\n", + "Train Loss 0.0175, Interval Loss 0.0173, Interval Loss Val 
0.0167\n", + "Train Loss 0.0246, Interval Loss 0.0172, Interval Loss Val 0.0165\n", + "Train Loss 0.0150, Interval Loss 0.0168, Interval Loss Val 0.0151\n", + "Train Loss 0.0307, Interval Loss 0.0169, Interval Loss Val 0.0063\n", + "Train Loss 0.0121, Interval Loss 0.0170, Interval Loss Val 0.0179\n", + "Train Loss 0.0113, Interval Loss 0.0156, Interval Loss Val 0.0127\n", + "Train Loss 0.0139, Interval Loss 0.0156, Interval Loss Val 0.0155\n", + "Train Loss 0.0109, Interval Loss 0.0165, Interval Loss Val 0.0148\n", + "Train Loss 0.0090, Interval Loss 0.0168, Interval Loss Val 0.0269\n", + "Train Loss 0.0246, Interval Loss 0.0169, Interval Loss Val 0.0161\n", + "Train Loss 0.0228, Interval Loss 0.0160, Interval Loss Val 0.0166\n", + "Train Loss 0.0124, Interval Loss 0.0156, Interval Loss Val 0.0124\n", + "Train Loss 0.0120, Interval Loss 0.0160, Interval Loss Val 0.0134\n", + "Train Loss 0.0214, Interval Loss 0.0167, Interval Loss Val 0.0192\n", + "Train Loss 0.0194, Interval Loss 0.0165, Interval Loss Val 0.0142\n", + "Train Loss 0.0092, Interval Loss 0.0153, Interval Loss Val 0.0148\n", + "Train Loss 0.0198, Interval Loss 0.0165, Interval Loss Val 0.0115\n", + "Train Loss 0.0139, Interval Loss 0.0172, Interval Loss Val 0.0144\n", + "Train Loss 0.0069, Interval Loss 0.0165, Interval Loss Val 0.0155\n", + "Train Loss 0.0109, Interval Loss 0.0160, Interval Loss Val 0.0197\n", + "Train Loss 0.0243, Interval Loss 0.0165, Interval Loss Val 0.0154\n", + "Train Loss 0.0124, Interval Loss 0.0172, Interval Loss Val 0.0245\n", + "Train Loss 0.0234, Interval Loss 0.0160, Interval Loss Val 0.0198\n", + "Train Loss 0.0129, Interval Loss 0.0163, Interval Loss Val 0.0144\n", + "Train Loss 0.0201, Interval Loss 0.0168, Interval Loss Val 0.0098\n", + "Train Loss 0.0203, Interval Loss 0.0157, Interval Loss Val 0.0135\n", + "Train Loss 0.0161, Interval Loss 0.0159, Interval Loss Val 0.0194\n", + "Train Loss 0.0147, Interval Loss 0.0160, Interval Loss Val 0.0232\n", + "Train Loss 
0.0165, Interval Loss 0.0166, Interval Loss Val 0.0133\n", + "Train Loss 0.0175, Interval Loss 0.0165, Interval Loss Val 0.0261\n", + "Train Loss 0.0184, Interval Loss 0.0161, Interval Loss Val 0.0124\n", + "Train Loss 0.0114, Interval Loss 0.0170, Interval Loss Val 0.0103\n", + "Train Loss 0.0224, Interval Loss 0.0160, Interval Loss Val 0.0199\n", + "Train Loss 0.0227, Interval Loss 0.0154, Interval Loss Val 0.0085\n", + "train diffusion completed, total time: 5660.153350830078.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The seaborn styles shipped by Matplotlib are deprecated since 3.6, as they no longer correspond to the styles shipped by seaborn. However, they will remain available as 'seaborn-v0_8-