Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions generative/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@

from __future__ import annotations

from .encoder_modules import SpatialRescaler
from .selfattention import SABlock
from .transformerblock import TransformerBlock
83 changes: 83 additions & 0 deletions generative/networks/blocks/encoder_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from functools import partial

import torch
import torch.nn as nn
from monai.networks.blocks import Convolution

__all__ = ["SpatialRescaler"]


class SpatialRescaler(nn.Module):
"""
SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py

Args:
spatial_dims: number of spatial dimensions.
n_stages: number of interpolation stages.
size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]).
method: algorithm used for sampling.
multiplier: multiplier for spatial size. If `multiplier` is a sequence,
its length has to match the number of spatial dimensions; `input.dim() - 2`.
in_channels: number of input channels.
out_channels: number of output channels.
bias: whether to have a bias term.
"""

def __init__(
self,
spatial_dims: int = 2,
n_stages: int = 1,
size: Sequence[int] | int | None = None,
method: str = "bilinear",
multiplier: Sequence[float] | float | None = None,
in_channels: int = 3,
out_channels: int = None,
bias: bool = False,
):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ["nearest", "linear", "bilinear", "trilinear", "bicubic", "area"]
if size is not None and n_stages != 1:
raise ValueError("when size is not None, n_stages should be 1.")
if size is not None and multiplier is not None:
raise ValueError("only one of size or multiplier should be defined.")
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method, size=size)
self.remap_output = out_channels is not None
if self.remap_output:
print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.")
self.channel_mapper = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
conv_only=True,
bias=bias,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.remap_output:
x = self.channel_mapper(x)

for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)

return x

def encode(self, x: torch.Tensor) -> torch.Tensor:
return self(x)
130 changes: 130 additions & 0 deletions tests/test_encoder_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from generative.networks.blocks import SpatialRescaler

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

CASES = [
[
{
"spatial_dims": 2,
"n_stages": 1,
"method": "bilinear",
"multiplier": 0.5,
"in_channels": None,
"out_channels": None,
},
(1, 1, 16, 16),
(1, 1, 8, 8),
],
[
{
"spatial_dims": 2,
"n_stages": 1,
"method": "bilinear",
"multiplier": 0.5,
"in_channels": 3,
"out_channels": 2,
},
(1, 3, 16, 16),
(1, 2, 8, 8),
],
[
{
"spatial_dims": 3,
"n_stages": 1,
"method": "trilinear",
"multiplier": 0.5,
"in_channels": None,
"out_channels": None,
},
(1, 1, 16, 16, 16),
(1, 1, 8, 8, 8),
],
[
{
"spatial_dims": 3,
"n_stages": 1,
"method": "trilinear",
"multiplier": 0.5,
"in_channels": 3,
"out_channels": 2,
},
(1, 3, 16, 16, 16),
(1, 2, 8, 8, 8),
],
[
{
"spatial_dims": 3,
"n_stages": 1,
"method": "trilinear",
"multiplier": (0.25, 0.5, 0.75),
"in_channels": 3,
"out_channels": 2,
},
(1, 3, 20, 20, 20),
(1, 2, 5, 10, 15),
],
[
{"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2},
(1, 3, 16, 16),
(1, 2, 8, 8),
],
[
{
"spatial_dims": 3,
"n_stages": 1,
"size": (8, 8, 8),
"method": "trilinear",
"in_channels": None,
"out_channels": None,
},
(1, 1, 16, 16, 16),
(1, 1, 8, 8, 8),
],
]


class TestSpatialRescaler(unittest.TestCase):
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape):
module = SpatialRescaler(**input_param).to(device)

result = module(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

def test_method_not_in_available_options(self):
with self.assertRaises(AssertionError):
SpatialRescaler(method="none")

def test_n_stages_is_negative(self):
with self.assertRaises(AssertionError):
SpatialRescaler(n_stages=-1)

def test_use_size_but_n_stages_is_not_one(self):
with self.assertRaises(ValueError):
SpatialRescaler(n_stages=2, size=[8, 8, 8])

def test_both_size_and_multiplier_defined(self):
with self.assertRaises(ValueError):
SpatialRescaler(size=[1, 2, 3], multiplier=0.5)


if __name__ == "__main__":
unittest.main()