From cbe51e323d9075c359a5277ae96c7351c230d2f5 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 3 May 2023 21:09:26 -0600 Subject: [PATCH 1/4] First draft notebook --- .../2d_controlnet/2d_controlnet.ipynb | 975 ++++++++++++++++++ 1 file changed, 975 insertions(+) create mode 100644 tutorials/generative/2d_controlnet/2d_controlnet.ipynb diff --git a/tutorials/generative/2d_controlnet/2d_controlnet.ipynb b/tutorials/generative/2d_controlnet/2d_controlnet.ipynb new file mode 100644 index 00000000..bbff623f --- /dev/null +++ b/tutorials/generative/2d_controlnet/2d_controlnet.ipynb @@ -0,0 +1,975 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "70eef519", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) MONAI Consortium\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "63d95da6", + "metadata": {}, + "source": [ + "# Weakly Supervised Anomaly Detection with Implicit Guidance\n", + "\n", + "This tutorial illustrates how to use MONAI Generative Models for training a 2D anomaly detection using DDIMs [1]. By leveraging recent advances in generative diffusion probabilistic models, we synthesize counterfactuals of \"How would a patient appear if X pathology was not present?\". The difference image between the observed patient state and the healthy counterfactual can be used for inferring the location of pathology. We generate counterfactuals that correspond to the minimal change of the input such that it is transformed to healthy domain. We create these counterfactual diffusion models by manipulating the generation process with implicit guidance.\n", + "\n", + "In summary, the tutorial will cover the following:\n", + "1. Loading and preprocessing a dataset (we extract the brain MRI dataset 2D slices from 3D volumes from the BraTS dataset)\n", + "2. Training a 2D diffusion model\n", + "3. Anomaly detection with the trained model\n", + "\n", + "This method results in anomaly heatmaps. It is weakly supervised. The information about labels is not fed to the model as segmentation masks but as a scalar signal (is there an anomaly or not), which is used to guide the diffusion process.\n", + "\n", + "During inference, the model generates a counterfactual image, which is then compared to the original image. The difference between the two images is used to generate an anomaly heatmap.\n", + "\n", + "[1] - Sanchez et al. [What is Healthy? Generative Counterfactual Diffusion for Lesion Localization](https://arxiv.org/abs/2207.12268). DGM 4 MICCAI 2022" + ] + }, + { + "cell_type": "markdown", + "id": "6b766027", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "972ed3f3", + "metadata": { + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MONAI version: 1.1.0\n", + "Numpy version: 1.23.5\n", + "Pytorch version: 1.13.1+cu117\n", + "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", + "MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3\n", + "MONAI __file__: /remote/rds/users/s2086085/miniconda3/envs/pytorch_monai/lib/python3.10/site-packages/monai/__init__.py\n", + "\n", + "Optional dependencies:\n", + "Pytorch Ignite version: 0.4.10\n", + "Nibabel version: 5.0.1\n", + "scikit-image version: 0.19.3\n", + "Pillow version: 9.4.0\n", + "Tensorboard version: 2.12.0\n", + "gdown version: 4.6.4\n", + "TorchVision version: 0.14.1+cu117\n", + "tqdm version: 4.64.1\n", + "lmdb version: 1.4.0\n", + "psutil version: 5.9.4\n", + "pandas version: 1.5.3\n", + "einops version: 0.6.0\n", + "transformers version: 4.21.3\n", + "mlflow version: 2.1.1\n", + "pynrrd version: 1.0.0\n", + "\n", + "For details about installing the optional dependencies, please visit:\n", + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", + "\n" + ] + } + ], + "source": [ + "import os\n", + "import tempfile\n", + "import time\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import sys\n", + "from monai import transforms\n", + "from monai.apps import DecathlonDataset\n", + "from monai.config import print_config\n", + "from monai.data import DataLoader\n", + "from monai.utils import set_determinism\n", + "from torch.cuda.amp import GradScaler, autocast\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "from generative.inferers import DiffusionInferer\n", + "from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet\n", + "from generative.networks.schedulers.ddim import DDIMScheduler\n", + "\n", + "torch.multiprocessing.set_sharing_strategy(\"file_system\")\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "id": "7d4ff515", + "metadata": {}, + "source": [ + "## Setup data directory" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8b4323e7", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory" + ] + }, + { + "cell_type": "markdown", + "id": "99175d50", + "metadata": {}, + "source": [ + "## Set deterministic training for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "34ea510f", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "set_determinism(42)" + ] + }, + { + "cell_type": "markdown", + "id": "c3f70dd1-236a-47ff-a244-575729ad92ba", + "metadata": {}, + "source": [ + "## Setup BRATS Dataset - Transforms for extracting 2D slices from 3D volumes\n", + "\n", + "We now download the BraTS dataset and extract the 2D slices from the 3D volumes. The `slice_label` is used to indicate whether the slice contains an anomaly or not." + ] + }, + { + "cell_type": "markdown", + "id": "6986f55c", + "metadata": {}, + "source": [ + "Here we use transforms to augment the training dataset, as usual:\n", + "\n", + "1. `LoadImaged` loads the brain images from files.\n", + "2. `EnsureChannelFirstd` ensures the original data to construct \"channel first\" shape.\n", + "3. The first `Lambdad` transform chooses the first channel of the image, which is the T1-weighted image.\n", + "4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm to match the original paper.\n", + "5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n", + "6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n", + "6. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2` )." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c68d2d91-9a0b-4ac1-ae49-f4a64edbd82a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ": Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.\n" + ] + } + ], + "source": [ + "channel = 0 # 0 = Flair\n", + "assert channel in [0, 1, 2, 3], \"Choose a valid channel\"\n", + "\n", + "train_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\", \"label\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", + " transforms.Lambdad(keys=[\"image\"], func=lambda x: x[channel, :, :, :]),\n", + " transforms.AddChanneld(keys=[\"image\"]),\n", + " transforms.EnsureTyped(keys=[\"image\", \"label\"]),\n", + " transforms.Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " transforms.Spacingd(keys=[\"image\", \"label\"], pixdim=(3.0, 3.0, 2.0), mode=(\"bilinear\", \"nearest\")),\n", + " transforms.CenterSpatialCropd(keys=[\"image\", \"label\"], roi_size=(64, 64, 44)),\n", + " transforms.ScaleIntensityRangePercentilesd(keys=\"image\", lower=0, upper=99.5, b_min=0, b_max=1),\n", + " transforms.RandSpatialCropd(keys=[\"image\", \"label\"], roi_size=(64, 64, 1), random_size=False),\n", + " transforms.Lambdad(keys=[\"image\", \"label\"], func=lambda x: x.squeeze(-1)),\n", + " transforms.CopyItemsd(keys=[\"label\"], times=1, names=[\"slice_label\"]),\n", + " transforms.Lambdad(keys=[\"slice_label\"], func=lambda x: 2.0 if x.sum() > 0 else 1.0),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9d378ac6", + "metadata": {}, + "source": [ + "### Load Training and Validation Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "da1927b0", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████| 388/388 [02:41<00:00, 2.40it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Lenght of training data: 388\n", + "Train image shape torch.Size([1, 64, 64])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████| 96/96 [00:39<00:00, 2.42it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Lenght of training data: 96\n", + "Validation Image shape torch.Size([1, 64, 64])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "train_ds = DecathlonDataset(\n", + " root_dir=root_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"training\",\n", + " cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n", + " num_workers=4,\n", + " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", + " seed=0,\n", + " transform=train_transforms,\n", + ")\n", + "print(f\"Length of training data: {len(train_ds)}\")\n", + "print(f'Train image shape {train_ds[0][\"image\"].shape}')\n", + "\n", + "val_ds = DecathlonDataset(\n", + " root_dir=root_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"validation\",\n", + " cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n", + " num_workers=4,\n", + " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", + " seed=0,\n", + " transform=train_transforms,\n", + ")\n", + "print(f\"Length of training data: {len(val_ds)}\")\n", + "print(f'Validation Image shape {val_ds[0][\"image\"].shape}')" + ] + }, + { + "cell_type": "markdown", + "id": "08428bc6", + "metadata": {}, + "source": [ + "## Define network, scheduler, optimizer, and inferer\n", + "\n", + "At this step, we instantiate the MONAI components to create a DDIM, the UNET with conditioning, the noise scheduler, and the inferer used for training and sampling. We are using\n", + "the deterministic DDIM scheduler containing 1000 timesteps, and a 2D UNET with attention mechanisms.\n", + "\n", + "The `attention` mechanism is essential for ensuring good conditioning and images manipulation here.\n", + "\n", + "An `embedding layer`, which is also optimised during training, is used in the original work because it was empirically shown to improve conditioning compared to a single scalar information.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bee5913e", + "metadata": { + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\")\n", + "embedding_dimension = 64\n", + "model = DiffusionModelUNet(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " num_channels=(64, 64, 64),\n", + " attention_levels=(False, True, True),\n", + " num_res_blocks=1,\n", + " num_head_channels=16,\n", + " with_conditioning=True,\n", + " cross_attention_dim=embedding_dimension,\n", + ").to(device)\n", + "embed = torch.nn.Embedding(num_embeddings=3, embedding_dim=embedding_dimension, padding_idx=0).to(device)\n", + "\n", + "scheduler = DDIMScheduler(num_train_timesteps=1000)\n", + "optimizer = torch.optim.Adam(params=list(model.parameters()) + list(embed.parameters()), lr=1e-5)\n", + "\n", + "inferer = DiffusionInferer(scheduler)" + ] + }, + { + "cell_type": "markdown", + "id": "f815ff34", + "metadata": {}, + "source": [ + "## Training a diffusion model with classifier-free guidance" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9a4fc901", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss 0.8078, Interval Loss 0.9115, Interval Loss Val 0.8100\n", + "Train Loss 0.6174, Interval Loss 0.7103, Interval Loss Val 0.6154\n", + "Train Loss 0.4603, Interval Loss 0.5313, Interval Loss Val 0.4571\n", + "Train Loss 0.3323, Interval Loss 0.3887, Interval Loss Val 0.3288\n", + "Train Loss 0.2358, Interval Loss 0.2822, Interval Loss Val 0.2356\n", + "Train Loss 0.1771, Interval Loss 0.2034, Interval Loss Val 0.1660\n", + "Train Loss 0.1237, Interval Loss 0.1467, Interval Loss Val 0.1185\n", + "Train Loss 0.0802, Interval Loss 0.1054, Interval Loss Val 0.0855\n", + "Train Loss 0.0731, Interval Loss 0.0770, Interval Loss Val 0.0699\n", + "Train Loss 0.0528, Interval Loss 0.0570, Interval Loss Val 0.0512\n", + "Train Loss 0.0383, Interval Loss 0.0434, Interval Loss Val 0.0311\n", + "Train Loss 0.0268, Interval Loss 0.0343, Interval Loss Val 0.0483\n", + "Train Loss 0.0330, Interval Loss 0.0292, Interval Loss Val 0.0255\n", + "Train Loss 0.0359, Interval Loss 0.0250, Interval Loss Val 0.0293\n", + "Train Loss 0.0235, Interval Loss 0.0233, Interval Loss Val 0.0318\n", + "Train Loss 0.0241, Interval Loss 0.0224, Interval Loss Val 0.0303\n", + "Train Loss 0.0171, Interval Loss 0.0211, Interval Loss Val 0.0217\n", + "Train Loss 0.0304, Interval Loss 0.0207, Interval Loss Val 0.0098\n", + "Train Loss 0.0269, Interval Loss 0.0211, Interval Loss Val 0.0128\n", + "Train Loss 0.0163, Interval Loss 0.0194, Interval Loss Val 0.0233\n", + "Train Loss 0.0319, Interval Loss 0.0195, Interval Loss Val 0.0389\n", + "Train Loss 0.0200, Interval Loss 0.0194, Interval Loss Val 0.0214\n", + "Train Loss 0.0190, Interval Loss 0.0203, Interval Loss Val 0.0235\n", + "Train Loss 0.0291, Interval Loss 0.0194, Interval Loss Val 0.0178\n", + "Train Loss 0.0263, Interval Loss 0.0196, Interval Loss Val 0.0248\n", + "Train Loss 0.0229, Interval Loss 0.0198, Interval Loss Val 0.0246\n", + "Train Loss 0.0183, Interval Loss 0.0195, Interval Loss Val 0.0239\n", + "Train Loss 0.0147, Interval Loss 0.0188, Interval Loss Val 0.0216\n", + "Train Loss 0.0255, Interval Loss 0.0181, Interval Loss Val 0.0150\n", + "Train Loss 0.0161, Interval Loss 0.0183, Interval Loss Val 0.0211\n", + "Train Loss 0.0239, Interval Loss 0.0182, Interval Loss Val 0.0097\n", + "Train Loss 0.0163, Interval Loss 0.0188, Interval Loss Val 0.0173\n", + "Train Loss 0.0188, Interval Loss 0.0177, Interval Loss Val 0.0121\n", + "Train Loss 0.0120, Interval Loss 0.0189, Interval Loss Val 0.0191\n", + "Train Loss 0.0201, Interval Loss 0.0182, Interval Loss Val 0.0130\n", + "Train Loss 0.0118, Interval Loss 0.0187, Interval Loss Val 0.0309\n", + "Train Loss 0.0114, Interval Loss 0.0181, Interval Loss Val 0.0221\n", + "Train Loss 0.0197, Interval Loss 0.0180, Interval Loss Val 0.0118\n", + "Train Loss 0.0229, Interval Loss 0.0176, Interval Loss Val 0.0278\n", + "Train Loss 0.0242, Interval Loss 0.0188, Interval Loss Val 0.0126\n", + "Train Loss 0.0166, Interval Loss 0.0182, Interval Loss Val 0.0157\n", + "Train Loss 0.0162, Interval Loss 0.0195, Interval Loss Val 0.0170\n", + "Train Loss 0.0124, Interval Loss 0.0179, Interval Loss Val 0.0261\n", + "Train Loss 0.0151, Interval Loss 0.0183, Interval Loss Val 0.0223\n", + "Train Loss 0.0308, Interval Loss 0.0188, Interval Loss Val 0.0151\n", + "Train Loss 0.0210, Interval Loss 0.0177, Interval Loss Val 0.0193\n", + "Train Loss 0.0175, Interval Loss 0.0184, Interval Loss Val 0.0232\n", + "Train Loss 0.0270, Interval Loss 0.0184, Interval Loss Val 0.0125\n", + "Train Loss 0.0128, Interval Loss 0.0181, Interval Loss Val 0.0224\n", + "Train Loss 0.0170, Interval Loss 0.0188, Interval Loss Val 0.0199\n", + "Train Loss 0.0203, Interval Loss 0.0176, Interval Loss Val 0.0145\n", + "Train Loss 0.0248, Interval Loss 0.0176, Interval Loss Val 0.0149\n", + "Train Loss 0.0213, Interval Loss 0.0188, Interval Loss Val 0.0120\n", + "Train Loss 0.0324, Interval Loss 0.0181, Interval Loss Val 0.0310\n", + "Train Loss 0.0302, Interval Loss 0.0183, Interval Loss Val 0.0119\n", + "Train Loss 0.0206, Interval Loss 0.0168, Interval Loss Val 0.0060\n", + "Train Loss 0.0117, Interval Loss 0.0176, Interval Loss Val 0.0208\n", + "Train Loss 0.0213, Interval Loss 0.0172, Interval Loss Val 0.0178\n", + "Train Loss 0.0156, Interval Loss 0.0176, Interval Loss Val 0.0201\n", + "Train Loss 0.0134, Interval Loss 0.0180, Interval Loss Val 0.0116\n", + "Train Loss 0.0305, Interval Loss 0.0176, Interval Loss Val 0.0194\n", + "Train Loss 0.0253, Interval Loss 0.0171, Interval Loss Val 0.0230\n", + "Train Loss 0.0198, Interval Loss 0.0173, Interval Loss Val 0.0246\n", + "Train Loss 0.0139, Interval Loss 0.0175, Interval Loss Val 0.0146\n", + "Train Loss 0.0154, Interval Loss 0.0168, Interval Loss Val 0.0269\n", + "Train Loss 0.0086, Interval Loss 0.0174, Interval Loss Val 0.0132\n", + "Train Loss 0.0156, Interval Loss 0.0171, Interval Loss Val 0.0170\n", + "Train Loss 0.0292, Interval Loss 0.0174, Interval Loss Val 0.0165\n", + "Train Loss 0.0220, Interval Loss 0.0179, Interval Loss Val 0.0171\n", + "Train Loss 0.0132, Interval Loss 0.0175, Interval Loss Val 0.0116\n", + "Train Loss 0.0158, Interval Loss 0.0174, Interval Loss Val 0.0165\n", + "Train Loss 0.0325, Interval Loss 0.0172, Interval Loss Val 0.0235\n", + "Train Loss 0.0105, Interval Loss 0.0169, Interval Loss Val 0.0171\n", + "Train Loss 0.0082, Interval Loss 0.0173, Interval Loss Val 0.0146\n", + "Train Loss 0.0232, Interval Loss 0.0180, Interval Loss Val 0.0132\n", + "Train Loss 0.0120, Interval Loss 0.0170, Interval Loss Val 0.0275\n", + "Train Loss 0.0211, Interval Loss 0.0182, Interval Loss Val 0.0107\n", + "Train Loss 0.0247, Interval Loss 0.0169, Interval Loss Val 0.0162\n", + "Train Loss 0.0196, Interval Loss 0.0178, Interval Loss Val 0.0365\n", + "Train Loss 0.0247, Interval Loss 0.0173, Interval Loss Val 0.0185\n", + "Train Loss 0.0174, Interval Loss 0.0180, Interval Loss Val 0.0153\n", + "Train Loss 0.0203, Interval Loss 0.0185, Interval Loss Val 0.0208\n", + "Train Loss 0.0112, Interval Loss 0.0172, Interval Loss Val 0.0296\n", + "Train Loss 0.0215, Interval Loss 0.0165, Interval Loss Val 0.0155\n", + "Train Loss 0.0144, Interval Loss 0.0172, Interval Loss Val 0.0194\n", + "Train Loss 0.0192, Interval Loss 0.0179, Interval Loss Val 0.0195\n", + "Train Loss 0.0175, Interval Loss 0.0178, Interval Loss Val 0.0092\n", + "Train Loss 0.0082, Interval Loss 0.0180, Interval Loss Val 0.0323\n", + "Train Loss 0.0234, Interval Loss 0.0168, Interval Loss Val 0.0118\n", + "Train Loss 0.0234, Interval Loss 0.0172, Interval Loss Val 0.0192\n", + "Train Loss 0.0088, Interval Loss 0.0172, Interval Loss Val 0.0262\n", + "Train Loss 0.0189, Interval Loss 0.0179, Interval Loss Val 0.0313\n", + "Train Loss 0.0081, Interval Loss 0.0182, Interval Loss Val 0.0181\n", + "Train Loss 0.0195, Interval Loss 0.0168, Interval Loss Val 0.0164\n", + "Train Loss 0.0280, Interval Loss 0.0166, Interval Loss Val 0.0132\n", + "Train Loss 0.0198, Interval Loss 0.0179, Interval Loss Val 0.0125\n", + "Train Loss 0.0182, Interval Loss 0.0167, Interval Loss Val 0.0208\n", + "Train Loss 0.0099, Interval Loss 0.0171, Interval Loss Val 0.0119\n", + "Train Loss 0.0271, Interval Loss 0.0169, Interval Loss Val 0.0156\n", + "Train Loss 0.0119, Interval Loss 0.0164, Interval Loss Val 0.0114\n", + "Train Loss 0.0172, Interval Loss 0.0165, Interval Loss Val 0.0228\n", + "Train Loss 0.0198, Interval Loss 0.0167, Interval Loss Val 0.0179\n", + "Train Loss 0.0165, Interval Loss 0.0164, Interval Loss Val 0.0110\n", + "Train Loss 0.0172, Interval Loss 0.0167, Interval Loss Val 0.0076\n", + "Train Loss 0.0140, Interval Loss 0.0171, Interval Loss Val 0.0172\n", + "Train Loss 0.0090, Interval Loss 0.0176, Interval Loss Val 0.0115\n", + "Train Loss 0.0120, Interval Loss 0.0167, Interval Loss Val 0.0202\n", + "Train Loss 0.0137, Interval Loss 0.0166, Interval Loss Val 0.0169\n", + "Train Loss 0.0113, Interval Loss 0.0171, Interval Loss Val 0.0131\n", + "Train Loss 0.0187, Interval Loss 0.0171, Interval Loss Val 0.0072\n", + "Train Loss 0.0302, Interval Loss 0.0158, Interval Loss Val 0.0169\n", + "Train Loss 0.0220, Interval Loss 0.0172, Interval Loss Val 0.0233\n", + "Train Loss 0.0238, Interval Loss 0.0169, Interval Loss Val 0.0193\n", + "Train Loss 0.0077, Interval Loss 0.0173, Interval Loss Val 0.0176\n", + "Train Loss 0.0221, Interval Loss 0.0172, Interval Loss Val 0.0221\n", + "Train Loss 0.0106, Interval Loss 0.0161, Interval Loss Val 0.0170\n", + "Train Loss 0.0179, Interval Loss 0.0167, Interval Loss Val 0.0281\n", + "Train Loss 0.0147, Interval Loss 0.0164, Interval Loss Val 0.0167\n", + "Train Loss 0.0223, Interval Loss 0.0171, Interval Loss Val 0.0150\n", + "Train Loss 0.0206, Interval Loss 0.0176, Interval Loss Val 0.0159\n", + "Train Loss 0.0138, Interval Loss 0.0168, Interval Loss Val 0.0175\n", + "Train Loss 0.0202, Interval Loss 0.0165, Interval Loss Val 0.0108\n", + "Train Loss 0.0154, Interval Loss 0.0166, Interval Loss Val 0.0123\n", + "Train Loss 0.0225, Interval Loss 0.0175, Interval Loss Val 0.0133\n", + "Train Loss 0.0146, Interval Loss 0.0161, Interval Loss Val 0.0172\n", + "Train Loss 0.0166, Interval Loss 0.0166, Interval Loss Val 0.0120\n", + "Train Loss 0.0140, Interval Loss 0.0169, Interval Loss Val 0.0137\n", + "Train Loss 0.0165, Interval Loss 0.0171, Interval Loss Val 0.0142\n", + "Train Loss 0.0245, Interval Loss 0.0162, Interval Loss Val 0.0301\n", + "Train Loss 0.0207, Interval Loss 0.0162, Interval Loss Val 0.0135\n", + "Train Loss 0.0183, Interval Loss 0.0162, Interval Loss Val 0.0112\n", + "Train Loss 0.0124, Interval Loss 0.0161, Interval Loss Val 0.0128\n", + "Train Loss 0.0168, Interval Loss 0.0169, Interval Loss Val 0.0119\n", + "Train Loss 0.0042, Interval Loss 0.0180, Interval Loss Val 0.0175\n", + "Train Loss 0.0207, Interval Loss 0.0161, Interval Loss Val 0.0143\n", + "Train Loss 0.0195, Interval Loss 0.0157, Interval Loss Val 0.0217\n", + "Train Loss 0.0231, Interval Loss 0.0169, Interval Loss Val 0.0113\n", + "Train Loss 0.0193, Interval Loss 0.0152, Interval Loss Val 0.0221\n", + "Train Loss 0.0206, Interval Loss 0.0166, Interval Loss Val 0.0152\n", + "Train Loss 0.0133, Interval Loss 0.0161, Interval Loss Val 0.0078\n", + "Train Loss 0.0172, Interval Loss 0.0162, Interval Loss Val 0.0109\n", + "Train Loss 0.0204, Interval Loss 0.0160, Interval Loss Val 0.0247\n", + "Train Loss 0.0113, Interval Loss 0.0159, Interval Loss Val 0.0142\n", + "Train Loss 0.0196, Interval Loss 0.0167, Interval Loss Val 0.0183\n", + "Train Loss 0.0088, Interval Loss 0.0156, Interval Loss Val 0.0139\n", + "Train Loss 0.0205, Interval Loss 0.0170, Interval Loss Val 0.0124\n", + "Train Loss 0.0149, Interval Loss 0.0175, Interval Loss Val 0.0309\n", + "Train Loss 0.0115, Interval Loss 0.0165, Interval Loss Val 0.0147\n", + "Train Loss 0.0133, Interval Loss 0.0166, Interval Loss Val 0.0148\n", + "Train Loss 0.0225, Interval Loss 0.0160, Interval Loss Val 0.0166\n", + "Train Loss 0.0147, Interval Loss 0.0163, Interval Loss Val 0.0147\n", + "Train Loss 0.0107, Interval Loss 0.0169, Interval Loss Val 0.0105\n", + "Train Loss 0.0108, Interval Loss 0.0159, Interval Loss Val 0.0139\n", + "Train Loss 0.0079, Interval Loss 0.0161, Interval Loss Val 0.0176\n", + "Train Loss 0.0101, Interval Loss 0.0168, Interval Loss Val 0.0085\n", + "Train Loss 0.0210, Interval Loss 0.0178, Interval Loss Val 0.0146\n", + "Train Loss 0.0146, Interval Loss 0.0174, Interval Loss Val 0.0203\n", + "Train Loss 0.0151, Interval Loss 0.0156, Interval Loss Val 0.0164\n", + "Train Loss 0.0092, Interval Loss 0.0163, Interval Loss Val 0.0126\n", + "Train Loss 0.0140, Interval Loss 0.0160, Interval Loss Val 0.0118\n", + "Train Loss 0.0164, Interval Loss 0.0168, Interval Loss Val 0.0174\n", + "Train Loss 0.0089, Interval Loss 0.0166, Interval Loss Val 0.0143\n", + "Train Loss 0.0165, Interval Loss 0.0177, Interval Loss Val 0.0122\n", + "Train Loss 0.0244, Interval Loss 0.0171, Interval Loss Val 0.0151\n", + "Train Loss 0.0148, Interval Loss 0.0163, Interval Loss Val 0.0209\n", + "Train Loss 0.0162, Interval Loss 0.0160, Interval Loss Val 0.0166\n", + "Train Loss 0.0175, Interval Loss 0.0173, Interval Loss Val 0.0167\n", + "Train Loss 0.0246, Interval Loss 0.0172, Interval Loss Val 0.0165\n", + "Train Loss 0.0150, Interval Loss 0.0168, Interval Loss Val 0.0151\n", + "Train Loss 0.0307, Interval Loss 0.0169, Interval Loss Val 0.0063\n", + "Train Loss 0.0121, Interval Loss 0.0170, Interval Loss Val 0.0179\n", + "Train Loss 0.0113, Interval Loss 0.0156, Interval Loss Val 0.0127\n", + "Train Loss 0.0139, Interval Loss 0.0156, Interval Loss Val 0.0155\n", + "Train Loss 0.0109, Interval Loss 0.0165, Interval Loss Val 0.0148\n", + "Train Loss 0.0090, Interval Loss 0.0168, Interval Loss Val 0.0269\n", + "Train Loss 0.0246, Interval Loss 0.0169, Interval Loss Val 0.0161\n", + "Train Loss 0.0228, Interval Loss 0.0160, Interval Loss Val 0.0166\n", + "Train Loss 0.0124, Interval Loss 0.0156, Interval Loss Val 0.0124\n", + "Train Loss 0.0120, Interval Loss 0.0160, Interval Loss Val 0.0134\n", + "Train Loss 0.0214, Interval Loss 0.0167, Interval Loss Val 0.0192\n", + "Train Loss 0.0194, Interval Loss 0.0165, Interval Loss Val 0.0142\n", + "Train Loss 0.0092, Interval Loss 0.0153, Interval Loss Val 0.0148\n", + "Train Loss 0.0198, Interval Loss 0.0165, Interval Loss Val 0.0115\n", + "Train Loss 0.0139, Interval Loss 0.0172, Interval Loss Val 0.0144\n", + "Train Loss 0.0069, Interval Loss 0.0165, Interval Loss Val 0.0155\n", + "Train Loss 0.0109, Interval Loss 0.0160, Interval Loss Val 0.0197\n", + "Train Loss 0.0243, Interval Loss 0.0165, Interval Loss Val 0.0154\n", + "Train Loss 0.0124, Interval Loss 0.0172, Interval Loss Val 0.0245\n", + "Train Loss 0.0234, Interval Loss 0.0160, Interval Loss Val 0.0198\n", + "Train Loss 0.0129, Interval Loss 0.0163, Interval Loss Val 0.0144\n", + "Train Loss 0.0201, Interval Loss 0.0168, Interval Loss Val 0.0098\n", + "Train Loss 0.0203, Interval Loss 0.0157, Interval Loss Val 0.0135\n", + "Train Loss 0.0161, Interval Loss 0.0159, Interval Loss Val 0.0194\n", + "Train Loss 0.0147, Interval Loss 0.0160, Interval Loss Val 0.0232\n", + "Train Loss 0.0165, Interval Loss 0.0166, Interval Loss Val 0.0133\n", + "Train Loss 0.0175, Interval Loss 0.0165, Interval Loss Val 0.0261\n", + "Train Loss 0.0184, Interval Loss 0.0161, Interval Loss Val 0.0124\n", + "Train Loss 0.0114, Interval Loss 0.0170, Interval Loss Val 0.0103\n", + "Train Loss 0.0224, Interval Loss 0.0160, Interval Loss Val 0.0199\n", + "Train Loss 0.0227, Interval Loss 0.0154, Interval Loss Val 0.0085\n", + "train diffusion completed, total time: 5660.153350830078.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The seaborn styles shipped by Matplotlib are deprecated since 3.6, as they no longer correspond to the styles shipped by seaborn. However, they will remain available as 'seaborn-v0_8-