From 5d973cce22d8e41b3ea762781ed1dacc779ec5b9 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sun, 5 Feb 2023 11:26:26 +0000 Subject: [PATCH 1/2] Add AutoregressiveTransformer Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/__init__.py | 1 + .../nets/autoregressive_transformer.py | 46 +++++++++++++++++++ requirements-dev.txt | 1 + tests/min_tests.py | 1 + tests/test_autoregressive_transformer.py | 30 ++++++++++++ 5 files changed, 79 insertions(+) create mode 100644 generative/networks/nets/autoregressive_transformer.py create mode 100644 tests/test_autoregressive_transformer.py diff --git a/generative/networks/nets/__init__.py b/generative/networks/nets/__init__.py index 8f7b51d8..1c49c5b0 100644 --- a/generative/networks/nets/__init__.py +++ b/generative/networks/nets/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .autoencoderkl import AutoencoderKL +from .autoregressive_transformer import AutoregressiveTransformer from .diffusion_model_unet import DiffusionModelUNet from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator from .vqvae import VQVAE diff --git a/generative/networks/nets/autoregressive_transformer.py b/generative/networks/nets/autoregressive_transformer.py new file mode 100644 index 00000000..9780fa39 --- /dev/null +++ b/generative/networks/nets/autoregressive_transformer.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn +from x_transformers import Decoder, TransformerWrapper + +__all__ = ["AutoregressiveTransformer"] + + +class AutoregressiveTransformer(nn.Module): + """Autoregressive Transformer model. + + Args: + num_tokens: Number of tokens in the vocabulary. + max_seq_len: Maximum sequence length. + attn_layers_dim: Dimensionality of the attention layers. + attn_layers_depth: Number of attention layers. + attn_layers_heads: Number of attention heads. 
+ """ + + def __init__( + self, num_tokens: int, max_seq_len: int, attn_layers_dim: int, attn_layers_depth: int, attn_layers_heads: int + ) -> None: + super().__init__() + self.num_tokens = num_tokens + self.max_seq_len = max_seq_len + self.attn_layers_dim = attn_layers_dim + self.attn_layers_depth = attn_layers_depth + self.attn_layers_heads = attn_layers_heads + + self.model = TransformerWrapper( + num_tokens=self.num_tokens, + max_seq_len=self.max_seq_len, + attn_layers=Decoder(dim=self.attn_layers_dim, depth=self.attn_layers_depth, heads=self.attn_layers_heads), + ) + + def forward(self, x): + return self.model(x) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1a9ddd18..2c8e5a6d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -54,3 +54,4 @@ nni optuna git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded lpips==0.1.4 +x-transformers==1.8.1 diff --git a/tests/min_tests.py b/tests/min_tests.py index 82fb1130..240117f7 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -33,6 +33,7 @@ def run_testsuit(): "test_integration_workflows_adversarial", "test_latent_diffusion_inferer", "test_diffusion_inferer", + "test_autoregressive_transformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_autoregressive_transformer.py b/tests/test_autoregressive_transformer.py new file mode 100644 index 00000000..629c58a6 --- /dev/null +++ b/tests/test_autoregressive_transformer.py @@ -0,0 +1,30 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from monai.networks import eval_mode + +from generative.networks.nets import AutoregressiveTransformer + + +class TestAutoregressiveTransformer(unittest.TestCase): + def test_shape_unconditioned_models(self): + net = AutoregressiveTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2 + ) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16))) + + +if __name__ == "__main__": + unittest.main() From df52e62f13e7d6c13898ce208e8cb3af77768908 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 6 Feb 2023 14:11:27 +0000 Subject: [PATCH 2/2] Add cross-attention and rename model Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/__init__.py | 2 +- ...gressive_transformer.py => transformer.py} | 29 ++++++++++++++----- tests/min_tests.py | 2 +- ...ive_transformer.py => test_transformer.py} | 20 ++++++++++--- 4 files changed, 40 insertions(+), 13 deletions(-) rename generative/networks/nets/{autoregressive_transformer.py => transformer.py} (62%) rename tests/{test_autoregressive_transformer.py => test_transformer.py} (60%) diff --git a/generative/networks/nets/__init__.py b/generative/networks/nets/__init__.py index 1c49c5b0..ed15c2f8 100644 --- a/generative/networks/nets/__init__.py +++ b/generative/networks/nets/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. 
from .autoencoderkl import AutoencoderKL -from .autoregressive_transformer import AutoregressiveTransformer from .diffusion_model_unet import DiffusionModelUNet from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator +from .transformer import DecoderOnlyTransformer from .vqvae import VQVAE diff --git a/generative/networks/nets/autoregressive_transformer.py b/generative/networks/nets/transformer.py similarity index 62% rename from generative/networks/nets/autoregressive_transformer.py rename to generative/networks/nets/transformer.py index 9780fa39..84476aef 100644 --- a/generative/networks/nets/autoregressive_transformer.py +++ b/generative/networks/nets/transformer.py @@ -9,14 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +import torch import torch.nn as nn from x_transformers import Decoder, TransformerWrapper -__all__ = ["AutoregressiveTransformer"] +__all__ = ["DecoderOnlyTransformer"] -class AutoregressiveTransformer(nn.Module): - """Autoregressive Transformer model. +class DecoderOnlyTransformer(nn.Module): + """Decoder-only (Autoregressive) Transformer model. Args: num_tokens: Number of tokens in the vocabulary. @@ -24,10 +27,17 @@ class AutoregressiveTransformer(nn.Module): attn_layers_dim: Dimensionality of the attention layers. attn_layers_depth: Number of attention layers. attn_layers_heads: Number of attention heads. + with_cross_attention: Whether to use cross attention for conditioning. """ def __init__( - self, num_tokens: int, max_seq_len: int, attn_layers_dim: int, attn_layers_depth: int, attn_layers_heads: int + self, + num_tokens: int, + max_seq_len: int, + attn_layers_dim: int, + attn_layers_depth: int, + attn_layers_heads: int, + with_cross_attention: bool = False, ) -> None: super().__init__() self.num_tokens = num_tokens @@ -39,8 +49,13 @@ def __init__( self.model = TransformerWrapper( num_tokens=self.num_tokens, max_seq_len=self.max_seq_len, - attn_layers=Decoder(dim=self.attn_layers_dim, depth=self.attn_layers_depth, heads=self.attn_layers_heads), + attn_layers=Decoder( + dim=self.attn_layers_dim, + depth=self.attn_layers_depth, + heads=self.attn_layers_heads, + cross_attend=with_cross_attention, + ), ) - def forward(self, x): - return self.model(x) + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.model(x, context=context) diff --git a/tests/min_tests.py b/tests/min_tests.py index 240117f7..b4373dd8 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -33,7 +33,7 @@ def run_testsuit(): "test_integration_workflows_adversarial", "test_latent_diffusion_inferer", "test_diffusion_inferer", - "test_autoregressive_transformer", + "test_transformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_autoregressive_transformer.py b/tests/test_transformer.py similarity index 60% rename from tests/test_autoregressive_transformer.py rename to tests/test_transformer.py index 629c58a6..9ddb4ca7 100644 --- a/tests/test_autoregressive_transformer.py +++ b/tests/test_transformer.py @@ -14,17 +14,29 @@ import torch from monai.networks import eval_mode -from generative.networks.nets import AutoregressiveTransformer +from generative.networks.nets import DecoderOnlyTransformer -class TestAutoregressiveTransformer(unittest.TestCase): - def test_shape_unconditioned_models(self): - net = AutoregressiveTransformer( +class 
TestDecoderOnlyTransformer(unittest.TestCase): + def test_unconditioned_models(self): + net = DecoderOnlyTransformer( num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2 ) with eval_mode(net): net.forward(torch.randint(0, 10, (1, 16))) + def test_conditioned_models(self): + net = DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + with_cross_attention=True, + ) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8)) + if __name__ == "__main__": unittest.main()
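
Notes on usage follow; none of the code below is part of the patches.

A minimal sketch of driving the unconditioned model, reusing the
hyperparameters from test_unconditioned_models. The logits shape is an
assumption based on x-transformers' TransformerWrapper, which maps a
(batch, seq_len) tensor of integer token ids to per-position logits over
the vocabulary:

    import torch

    from generative.networks.nets import DecoderOnlyTransformer

    # Small vocabulary (10 tokens) and short sequences (16 positions),
    # matching the values used in the unit tests.
    net = DecoderOnlyTransformer(
        num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2
    )

    tokens = torch.randint(0, 10, (1, 16))  # (batch, seq_len) integer token ids
    logits = net(tokens)                    # expected shape: (1, 16, 10)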
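
With with_cross_attention=True, the forward pass also accepts a context
tensor for conditioning via the Decoder's cross-attention blocks. As in
test_conditioned_models, the last dimension of the context matches
attn_layers_dim (no separate context dimension is configured here), while
the context length (4 below) is arbitrary:

    cond_net = DecoderOnlyTransformer(
        num_tokens=10,
        max_seq_len=16,
        attn_layers_dim=8,
        attn_layers_depth=2,
        attn_layers_heads=2,
        with_cross_attention=True,
    )

    context = torch.randn(1, 4, 8)  # (batch, context_len, attn_layers_dim)
    logits = cond_net(tokens, context=context)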
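
Because the model is autoregressive, sampling is a loop that repeatedly
predicts a distribution over the next token and extends the sequence by one
position. A greedy-decoding sketch, reusing net from the first snippet
(greedy_decode is a hypothetical helper, not something these patches
provide):

    @torch.no_grad()
    def greedy_decode(net: DecoderOnlyTransformer, start: torch.Tensor, num_steps: int) -> torch.Tensor:
        # start: (batch, n0) token ids; n0 + num_steps must stay <= max_seq_len.
        seq = start
        for _ in range(num_steps):
            logits = net(seq)                                     # (batch, n, num_tokens)
            next_id = logits[:, -1].argmax(dim=-1, keepdim=True)  # most likely next token, (batch, 1)
            seq = torch.cat([seq, next_id], dim=1)                # append and feed back in
        return seq

    # Grow a single start token (id 0) out to the full 16-position sequence:
    samples = greedy_decode(net, torch.zeros(1, 1, dtype=torch.long), num_steps=15)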