Skip to content

TypeError: unsupported format string passed to MetaTensor.__format__ #6522

@MathijsdeBoer

Description

@MathijsdeBoer

Describe the bug
I'm not sure whether this bug lies in MONAI or in PyTorch Lightning, but here it goes:

I'm using PyTorch Lightning to set up a simple training pipeline. When I use pl.callbacks.EarlyStopping with a CacheDataset and associated transforms, I get:

(... snip for brevity ...)
File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\pytorch_lightning\callbacks\early_stopping.py", line 184, in on_train_epoch_end
    self._run_early_stopping_check(trainer)
  File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\pytorch_lightning\callbacks\early_stopping.py", line 201, in _run_early_stopping_check
    should_stop, reason = self._evaluate_stopping_criteria(current)
  File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\pytorch_lightning\callbacks\early_stopping.py", line 236, in _evaluate_stopping_criteria
    reason = self._improvement_message(current)
  File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\pytorch_lightning\callbacks\early_stopping.py", line 258, in _improvement_message
    msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
  File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\torch\_tensor.py", line 870, in __format__
    return handle_torch_function(Tensor.__format__, (self,), self, format_spec)
  File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\torch\overrides.py", line 1551, in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
  File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\monai\data\meta_tensor.py", line 276, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\torch\_tensor.py", line 1295, in __torch_function__
    ret = func(*args, **kwargs)
  File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\torch\_tensor.py", line 873, in __format__
    return object.__format__(self, format_spec)
TypeError: unsupported format string passed to MetaTensor.__format__

I reckon this line is where the issue occurs:

File "E:\PythonPoetry\virtualenvs\evdkp-3rCk5jn4-py3.10\lib\site-packages\pytorch_lightning\callbacks\early_stopping.py", line 258, in _improvement_message
    msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"

To Reproduce

I've tried to extract a minimal example of the cause of the issue.

# main.py

import pytorch_lightning as pl

from ... import MyDataModule, MyModel

# Early-stopping callback watching the "val_loss" metric; its improvement
# message is the f-string that fails in the traceback quoted in this report.
early_stopping = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=3,
    verbose=False,
)
trainer = pl.Trainer(callbacks=[early_stopping])

# NOTE(review): `path_to_dataset` is not defined anywhere in this snippet;
# it must be supplied by the reader before the script can run.
data_module = MyDataModule(path_to_dataset)

# NOTE(review): MyModel.__init__ as posted in model.py requires `n_channels`
# and `n_classes`; this call passes neither - confirm the intended values.
model = MyModel()

trainer.fit(model, datamodule=data_module)
# mydatamodule.py

from pathlib import Path

import monai.transforms as mt
from monai.data import CacheDataset
from pytorch_lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

class MyDataModule(LightningDataModule):
    """Lightning data module serving MONAI ``CacheDataset`` loaders.

    The collected samples are split into a training and a validation subset
    in :meth:`setup`; both subsets go through the same dictionary-based
    transform chain (load, channel-first, intensity normalisation, tensor
    conversion).

    Args:
        path: Location of the dataset. It is only stored here; the sample
            collection code is elided (``...``) in this snippet.
    """

    def __init__(self, path: Path) -> None:
        super().__init__()
        self._path = path

        # Stored under the private name that ``setup`` reads. The original
        # snippet assigned ``self.samples`` here while ``setup`` used
        # ``self._samples``, which raised AttributeError on first use.
        self._samples = []
        # collect samples
        ...

        self._train_samples = None
        self._val_samples = None

        base_transforms = [
            mt.LoadImaged(keys=["image", "label"], image_only=True),
            mt.EnsureChannelFirstd(keys=["image", "label"]),
            mt.NormalizeIntensityd(keys=["image"]),
            mt.ToTensord(keys=["image", "label"]),
        ]

        # Both pipelines are built from the very same list of transform objects.
        self._train_transform = mt.Compose(base_transforms)
        self._val_transform = mt.Compose(base_transforms)

    @property
    def samples(self) -> list:
        """Public alias of the sample list, kept for backward compatibility."""
        return self._samples

    @samples.setter
    def samples(self, value: list) -> None:
        self._samples = value

    def setup(self, stage: str) -> None:
        """Split the collected samples with ``test_size=0.2`` for validation.

        Args:
            stage: Stage name passed in by the caller; not used here.
        """
        # NOTE(review): the sample collection is elided above, so the list is
        # empty as posted - presumably train_test_split rejects that; confirm.
        self._train_samples, self._val_samples = train_test_split(self._samples, test_size=0.2)

    def _build_loader(self, samples, transform) -> DataLoader:
        """Wrap ``samples`` in a ``CacheDataset`` and a single-item-batch ``DataLoader``."""
        return DataLoader(
            CacheDataset(
                samples,
                transform=transform,
                cache_rate=1.0,
                num_workers=0,
            ),
            batch_size=1,
            num_workers=0,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the loader over the training subset."""
        return self._build_loader(self._train_samples, self._train_transform)

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the validation subset."""
        return self._build_loader(self._val_samples, self._val_transform)
# model.py

import math
from typing import Callable

import pytorch_lightning as pl
import torch
from monai.networks.nets import UNet
from torch.nn import MSELoss

class MyModel(pl.LightningModule):
    """2D MONAI ``UNet`` wrapped as a Lightning module.

    The train, validation and test steps all run the network, apply
    ``final_activation``, evaluate ``loss_function`` against the label and
    log the result under ``"loss"``, ``"val_loss"`` and ``"test_loss"``
    respectively.

    Args:
        n_channels: Number of input channels of the network.
        n_classes: Number of output channels of the network.
        initial_filters: Channel count of the first level; it doubles at
            every further level.
        max_filters: Optional upper bound applied to every level's channel
            count.
        depth: Number of levels. When ``None`` it is derived as
            ``round(log2(min(resolution)))``.
        n_residual_units: Forwarded to ``UNet`` as ``num_res_units``.
        final_activation: Callable applied to the raw network output in every
            step. The default instance is created once, when the class is
            defined, and is therefore shared between model instances.
        loss_function: Loss callable; ``MSELoss()`` when ``None``.
        metrics: Optional list of metric callables; only stored.
        resolution: Either an int ``r``, giving an input size of
            ``(2 * r, r)``, or an explicit 2-tuple.
        learning_rate: Learning rate handed to the Adam optimizer.

    Raises:
        ValueError: If ``resolution`` is neither an int nor a tuple.
    """

    def __init__(
        self,
        n_channels: int,
        n_classes: int,
        initial_filters: int = 32,
        max_filters: int | None = None,
        depth: int | None = None,
        n_residual_units: int = 0,
        final_activation: torch.nn.Module | Callable = torch.nn.Sigmoid(),
        loss_function: torch.nn.Module | Callable | None = None,
        metrics: list[torch.nn.Module | Callable] | None = None,
        resolution: tuple[int, int] | int = 256,
        learning_rate: float = 1e-3,
    ) -> None:
        super().__init__()
        # Kept as the first statement after super().__init__(), before any
        # constructor argument is rebound below (presumably it records the
        # arguments as they were passed - confirm against the Lightning docs).
        self.save_hyperparameters(ignore=["final_activation", "loss_function", "metrics"])

        if isinstance(resolution, int):
            # A single int expands to a (2 * resolution, resolution) input size.
            self.resolution: tuple[int, int] = (resolution * 2, resolution)
        elif isinstance(resolution, tuple):
            self.resolution = resolution
        else:
            raise ValueError("resolution must be an int or a tuple of ints.")

        self.example_input_array = torch.zeros((1, n_channels, *self.resolution))

        if depth is None:
            # Derive the number of levels from the smaller spatial dimension.
            depth = int(round(math.log2(min(self.resolution))))

        if max_filters is None:
            channels = [initial_filters * 2**i for i in range(depth)]
        else:
            channels = [min(initial_filters * 2**i, max_filters) for i in range(depth)]

        # One stride-2 step between each pair of consecutive levels.
        strides = [2] * (depth - 1)

        self.model = UNet(
            spatial_dims=2,
            in_channels=n_channels,
            out_channels=n_classes,
            channels=channels,
            strides=strides,
            num_res_units=n_residual_units,
        )
        self.final_activation = final_activation

        self.loss_function = MSELoss() if loss_function is None else loss_function
        self.metrics = [] if metrics is None else metrics

        self.lr = learning_rate

    def forward(self, x):
        """Return the raw output of the wrapped ``UNet`` (no final activation)."""
        return self.model(x)

    def configure_optimizers(
        self,
    ) -> dict[str, torch.optim.Optimizer | torch.optim.lr_scheduler.ReduceLROnPlateau | str]:
        """Return Adam plus a ``ReduceLROnPlateau`` scheduler monitoring ``"val_loss"``."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=50, verbose=True
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }

    def _compute_batch_loss(self, batch: dict[str, torch.Tensor]):
        """Run the network on ``batch["image"]`` and score it against ``batch["label"]``."""
        x = batch["image"]
        y = batch["label"]

        y_hat = self.final_activation(self(x))

        return self.loss_function(y_hat, y)

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict:
        """Compute the batch loss, log it as ``"loss"`` and return the logged dict."""
        output = {
            "loss": self._compute_batch_loss(batch),
        }

        self.log_dict(output, prog_bar=True)
        return output

    def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict:
        """Compute the batch loss, log it as ``"val_loss"`` and return the logged dict."""
        # NOTE(review): with the MONAI transforms from mydatamodule.py upstream,
        # this loss is presumably a MetaTensor, which is the type the traceback
        # in this report fails to format - confirm. Left unconverted on purpose
        # so the snippet keeps reproducing the reported error.
        output = {
            "val_loss": self._compute_batch_loss(batch),
        }

        self.log_dict(output, prog_bar=True)
        return output

    def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> dict:
        """Compute the batch loss, log it as ``"test_loss"`` and return the logged dict."""
        output = {
            "test_loss": self._compute_batch_loss(batch),
        }

        self.log_dict(output)
        return output

Expected behavior
The `EarlyStopping` callback should work without raising a `TypeError`.

Environment

Tried this on v1.1.0 and v1.2.0rc7

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions