This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
1 change: 1 addition & 0 deletions generative/networks/nets/__init__.py
@@ -12,4 +12,5 @@
from .autoencoderkl import AutoencoderKL
from .diffusion_model_unet import DiffusionModelUNet
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
from .transformer import DecoderOnlyTransformer
from .vqvae import VQVAE
61 changes: 61 additions & 0 deletions generative/networks/nets/transformer.py
@@ -0,0 +1,61 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
from x_transformers import Decoder, TransformerWrapper

__all__ = ["DecoderOnlyTransformer"]


class DecoderOnlyTransformer(nn.Module):
"""Decoder-only (Autoregressive) Transformer model.

Args:
num_tokens: Number of tokens in the vocabulary.
max_seq_len: Maximum sequence length.
attn_layers_dim: Dimensionality of the attention layers.
attn_layers_depth: Number of attention layers.
attn_layers_heads: Number of attention heads.
with_cross_attention: Whether to use cross attention for conditioning.
"""

def __init__(
self,
num_tokens: int,
max_seq_len: int,
attn_layers_dim: int,
attn_layers_depth: int,
attn_layers_heads: int,
with_cross_attention: bool = False,
) -> None:
super().__init__()
self.num_tokens = num_tokens
self.max_seq_len = max_seq_len
self.attn_layers_dim = attn_layers_dim
self.attn_layers_depth = attn_layers_depth
self.attn_layers_heads = attn_layers_heads

self.model = TransformerWrapper(
num_tokens=self.num_tokens,
max_seq_len=self.max_seq_len,
attn_layers=Decoder(
dim=self.attn_layers_dim,
depth=self.attn_layers_depth,
heads=self.attn_layers_heads,
cross_attend=with_cross_attention,
),
)

def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
return self.model(x, context=context)
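
A minimal usage sketch of the new class (not part of this PR; it assumes `x_transformers`' `TransformerWrapper` returns per-token logits of shape `(batch, seq_len, num_tokens)`, which matches its documented behavior):

```python
# Illustrative only: next-token logits from the new DecoderOnlyTransformer.
# Hyperparameters mirror the test cases in tests/test_transformer.py below.
import torch

from generative.networks.nets import DecoderOnlyTransformer

model = DecoderOnlyTransformer(
    num_tokens=10,        # vocabulary size
    max_seq_len=16,       # maximum input sequence length
    attn_layers_dim=8,    # transformer width
    attn_layers_depth=2,  # number of decoder blocks
    attn_layers_heads=2,  # attention heads per block
)

tokens = torch.randint(0, 10, (1, 16))  # (batch, seq_len) integer token ids
logits = model(tokens)                  # assumed shape: (1, 16, 10)
```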
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -54,3 +54,4 @@ nni
optuna
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
lpips==0.1.4
x-transformers==1.8.1
1 change: 1 addition & 0 deletions tests/min_tests.py
@@ -33,6 +33,7 @@ def run_testsuit():
"test_integration_workflows_adversarial",
"test_latent_diffusion_inferer",
"test_diffusion_inferer",
"test_transformer",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

42 changes: 42 additions & 0 deletions tests/test_transformer.py
@@ -0,0 +1,42 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from monai.networks import eval_mode

from generative.networks.nets import DecoderOnlyTransformer


class TestDecoderOnlyTransformer(unittest.TestCase):
def test_unconditioned_models(self):
net = DecoderOnlyTransformer(
num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2
)
with eval_mode(net):
net.forward(torch.randint(0, 10, (1, 16)))

def test_conditioned_models(self):
net = DecoderOnlyTransformer(
num_tokens=10,
max_seq_len=16,
attn_layers_dim=8,
attn_layers_depth=2,
attn_layers_heads=2,
with_cross_attention=True,
)
with eval_mode(net):
net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8))


if __name__ == "__main__":
unittest.main()
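
Beyond the shape checks in the tests above, a hedged sketch of how a decoder-only model like this is typically sampled autoregressively (greedy decoding; illustrative, not part of this PR, and again assuming the forward pass returns `(batch, seq_len, num_tokens)` logits):

```python
# Illustrative greedy autoregressive sampling loop for the new model.
import torch

from generative.networks.nets import DecoderOnlyTransformer

net = DecoderOnlyTransformer(
    num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2
)
net.eval()

seq = torch.zeros(1, 1, dtype=torch.long)  # start from an arbitrary seed token
with torch.no_grad():
    while seq.shape[1] < 16:                       # stay within max_seq_len
        logits = net(seq)                          # (1, current_len, 10)
        next_token = logits[:, -1].argmax(dim=-1)  # greedy pick at the last position
        seq = torch.cat([seq, next_token[:, None]], dim=1)
```

For the conditioned variant (`with_cross_attention=True`), the same loop would pass `context=` to each call, with the context's last dimension matching `attn_layers_dim`, as the conditioned test case does with `torch.randn(1, 4, 8)`.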