diff --git a/generative/networks/blocks/__init__.py b/generative/networks/blocks/__init__.py index 49c4a32c..b7931237 100644 --- a/generative/networks/blocks/__init__.py +++ b/generative/networks/blocks/__init__.py @@ -11,5 +11,6 @@ from __future__ import annotations +from .encoder_modules import SpatialRescaler from .selfattention import SABlock from .transformerblock import TransformerBlock diff --git a/generative/networks/blocks/encoder_modules.py b/generative/networks/blocks/encoder_modules.py new file mode 100644 index 00000000..62eab739 --- /dev/null +++ b/generative/networks/blocks/encoder_modules.py @@ -0,0 +1,83 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from functools import partial + +import torch +import torch.nn as nn +from monai.networks.blocks import Convolution + +__all__ = ["SpatialRescaler"] + + +class SpatialRescaler(nn.Module): + """ + SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py + + Args: + spatial_dims: number of spatial dimensions. + n_stages: number of interpolation stages. + size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]). + method: algorithm used for sampling. + multiplier: multiplier for spatial size. If `multiplier` is a sequence, + its length has to match the number of spatial dimensions; `input.dim() - 2`. + in_channels: number of input channels. + out_channels: number of output channels. + bias: whether to have a bias term. + """ + + def __init__( + self, + spatial_dims: int = 2, + n_stages: int = 1, + size: Sequence[int] | int | None = None, + method: str = "bilinear", + multiplier: Sequence[float] | float | None = None, + in_channels: int = 3, + out_channels: int = None, + bias: bool = False, + ): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ["nearest", "linear", "bilinear", "trilinear", "bicubic", "area"] + if size is not None and n_stages != 1: + raise ValueError("when size is not None, n_stages should be 1.") + if size is not None and multiplier is not None: + raise ValueError("only one of size or multiplier should be defined.") + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method, size=size) + self.remap_output = out_channels is not None + if self.remap_output: + print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.") + self.channel_mapper = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + conv_only=True, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.remap_output: + x = self.channel_mapper(x) + + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + return x + + def encode(self, x: torch.Tensor) -> torch.Tensor: + return self(x) diff --git a/tests/test_encoder_modules.py b/tests/test_encoder_modules.py new file mode 100644 index 00000000..04639177 --- /dev/null +++ b/tests/test_encoder_modules.py @@ -0,0 +1,130 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from generative.networks.blocks import SpatialRescaler + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +CASES = [ + [ + { + "spatial_dims": 2, + "n_stages": 1, + "method": "bilinear", + "multiplier": 0.5, + "in_channels": None, + "out_channels": None, + }, + (1, 1, 16, 16), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 2, + "n_stages": 1, + "method": "bilinear", + "multiplier": 0.5, + "in_channels": 3, + "out_channels": 2, + }, + (1, 3, 16, 16), + (1, 2, 8, 8), + ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "method": "trilinear", + "multiplier": 0.5, + "in_channels": None, + "out_channels": None, + }, + (1, 1, 16, 16, 16), + (1, 1, 8, 8, 8), + ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "method": "trilinear", + "multiplier": 0.5, + "in_channels": 3, + "out_channels": 2, + }, + (1, 3, 16, 16, 16), + (1, 2, 8, 8, 8), + ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "method": "trilinear", + "multiplier": (0.25, 0.5, 0.75), + "in_channels": 3, + "out_channels": 2, + }, + (1, 3, 20, 20, 20), + (1, 2, 5, 10, 15), + ], + [ + {"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2}, + (1, 3, 16, 16), + (1, 2, 8, 8), + ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "size": (8, 8, 8), + "method": "trilinear", + "in_channels": None, + "out_channels": None, + }, + (1, 1, 16, 16, 16), + (1, 1, 8, 8, 8), + ], +] + + +class TestSpatialRescaler(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape): + module = SpatialRescaler(**input_param).to(device) + + result = module(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_method_not_in_available_options(self): + with self.assertRaises(AssertionError): + SpatialRescaler(method="none") + + def test_n_stages_is_negative(self): + with self.assertRaises(AssertionError): + SpatialRescaler(n_stages=-1) + + def test_use_size_but_n_stages_is_not_one(self): + with self.assertRaises(ValueError): + SpatialRescaler(n_stages=2, size=[8, 8, 8]) + + def test_both_size_and_multiplier_defined(self): + with self.assertRaises(ValueError): + SpatialRescaler(size=[1, 2, 3], multiplier=0.5) + + +if __name__ == "__main__": + unittest.main()