class DecoderOnlyTransformer(nn.Module):
    """Decoder-only (autoregressive) Transformer model.

    Thin wrapper around ``x_transformers.TransformerWrapper`` configured with a
    ``Decoder`` attention stack.

    Args:
        num_tokens: Number of tokens in the vocabulary.
        max_seq_len: Maximum sequence length.
        attn_layers_dim: Dimensionality of the attention layers.
        attn_layers_depth: Number of attention layers.
        attn_layers_heads: Number of attention heads.
        with_cross_attention: Whether to use cross attention for conditioning.
    """

    def __init__(
        self,
        num_tokens: int,
        max_seq_len: int,
        attn_layers_dim: int,
        attn_layers_depth: int,
        attn_layers_heads: int,
        with_cross_attention: bool = False,
    ) -> None:
        super().__init__()
        self.num_tokens = num_tokens
        self.max_seq_len = max_seq_len
        self.attn_layers_dim = attn_layers_dim
        self.attn_layers_depth = attn_layers_depth
        self.attn_layers_heads = attn_layers_heads
        # Kept alongside the other hyper-parameters so callers can tell whether
        # this instance was built for conditioning.
        self.with_cross_attention = with_cross_attention

        self.model = TransformerWrapper(
            num_tokens=self.num_tokens,
            max_seq_len=self.max_seq_len,
            attn_layers=Decoder(
                dim=self.attn_layers_dim,
                depth=self.attn_layers_depth,
                heads=self.attn_layers_heads,
                cross_attend=self.with_cross_attention,
            ),
        )

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the wrapped transformer on a batch of token sequences.

        Args:
            x: Integer token indices; presumably shaped (batch, seq_len) with
                seq_len <= ``max_seq_len`` -- confirm against x_transformers.
            context: Optional conditioning tensor forwarded unchanged to the
                wrapped model as ``context``. Presumably shaped
                (batch, context_len, attn_layers_dim) and only meaningful when
                the model was built with ``with_cross_attention=True``.

        Returns:
            Whatever the wrapped ``TransformerWrapper`` returns for these inputs
            (presumably logits over the vocabulary).
        """
        return self.model(x, context=context)
class TestDecoderOnlyTransformer(unittest.TestCase):
    """Smoke tests for ``DecoderOnlyTransformer`` with and without conditioning."""

    def test_unconditioned_models(self):
        """An unconditioned model yields one logit vector per input token."""
        net = DecoderOnlyTransformer(
            num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2
        )
        with eval_mode(net):
            # Call the module rather than ``forward`` so registered hooks run too.
            result = net(torch.randint(0, 10, (1, 16)))
        # (batch, seq_len, num_tokens)
        self.assertEqual(tuple(result.shape), (1, 16, 10))

    def test_conditioned_models(self):
        """A cross-attending model accepts a context tensor and keeps the output shape."""
        net = DecoderOnlyTransformer(
            num_tokens=10,
            max_seq_len=16,
            attn_layers_dim=8,
            attn_layers_depth=2,
            attn_layers_heads=2,
            with_cross_attention=True,
        )
        with eval_mode(net):
            # The context's last dimension matches attn_layers_dim (8).
            result = net(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8))
        # (batch, seq_len, num_tokens)
        self.assertEqual(tuple(result.shape), (1, 16, 10))