This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
generative/networks/nets/transformer.py (34 changes: 23 additions & 11 deletions)

@@ -11,9 +11,18 @@
 
 from __future__ import annotations
 
+import importlib.util
+
 import torch
 import torch.nn as nn
-from x_transformers import Decoder, TransformerWrapper
+
+if importlib.util.find_spec("x_transformers") is not None:
+    from x_transformers import Decoder, TransformerWrapper
+
+    has_x_transformers = True
+else:
+    has_x_transformers = False
 
 
 __all__ = ["DecoderOnlyTransformer"]
 
@@ -46,16 +55,19 @@ def __init__(
         self.attn_layers_depth = attn_layers_depth
         self.attn_layers_heads = attn_layers_heads
 
-        self.model = TransformerWrapper(
-            num_tokens=self.num_tokens,
-            max_seq_len=self.max_seq_len,
-            attn_layers=Decoder(
-                dim=self.attn_layers_dim,
-                depth=self.attn_layers_depth,
-                heads=self.attn_layers_heads,
-                cross_attend=with_cross_attention,
-            ),
-        )
+        if has_x_transformers:
+            self.model = TransformerWrapper(
+                num_tokens=self.num_tokens,
+                max_seq_len=self.max_seq_len,
+                attn_layers=Decoder(
+                    dim=self.attn_layers_dim,
+                    depth=self.attn_layers_depth,
+                    heads=self.attn_layers_heads,
+                    cross_attend=with_cross_attention,
+                ),
+            )
+        else:
+            raise ImportError("x-transformers is not installed.")
 
     def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
         return self.model(x, context=context)
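
This change makes x-transformers an optional dependency: the import is guarded at module load time, and the failure is deferred until someone actually instantiates DecoderOnlyTransformer. A condensed, runnable sketch of the same pattern, with a hypothetical package name standing in for x_transformers:

# Condensed sketch of the optional-dependency pattern used above;
# "some_optional_pkg" is a hypothetical stand-in for x_transformers.
import importlib.util

if importlib.util.find_spec("some_optional_pkg") is not None:
    import some_optional_pkg

    has_pkg = True
else:
    has_pkg = False


class NeedsOptionalPkg:
    def __init__(self) -> None:
        if not has_pkg:
            # Importing this module always succeeds; only construction fails.
            raise ImportError("some_optional_pkg is not installed.")
        self.backend = some_optional_pkg  # stand-in for real model construction

Deferring the raise to __init__ keeps "import generative.networks.nets.transformer" working for users who never touch this class, while still failing with a clear message when the class is actually needed.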
tests/test_vqvaetransformer_inferer.py (4 changes: 2 additions & 2 deletions)

@@ -103,7 +103,7 @@ def test_prediction_shape_shorter_sequence(
     ):
         stage_1 = VQVAE(**stage_1_params)
         max_seq_len = 3
-        stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
+        stage_2_params_shorter = dict(stage_2_params)
         stage_2_params_shorter["max_seq_len"] = max_seq_len
         stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
         ordering = Ordering(**ordering_params)
@@ -233,7 +233,7 @@ def test_get_likelihood_shorter_sequence(
     ):
         stage_1 = VQVAE(**stage_1_params)
         max_seq_len = 3
-        stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
+        stage_2_params_shorter = dict(stage_2_params)
         stage_2_params_shorter["max_seq_len"] = max_seq_len
         stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
         ordering = Ordering(**ordering_params)
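
Both the old comprehension and dict() build a shallow copy, so the change is purely idiomatic: the copy gets its own key slots while still sharing the value objects. A quick illustration with made-up parameters:

# Hypothetical parameter dict; dict() copies the mapping, not the values.
stage_2_params = {"num_tokens": 16, "max_seq_len": 16}
stage_2_params_shorter = dict(stage_2_params)
stage_2_params_shorter["max_seq_len"] = 3

assert stage_2_params["max_seq_len"] == 16  # the original is untouched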
tests/utils.py (15 changes: 13 additions & 2 deletions)

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from __future__ import annotations
 
 import copy
@@ -50,7 +51,7 @@
 from monai.utils.type_conversion import convert_data_type
 
 nib, _ = optional_import("nibabel")
-http_error, has_requests = optional_import("requests", name="HTTPError")
+http_error, has_req = optional_import("requests", name="HTTPError")
 
 quick_test_var = "QUICKTEST"
 _tf32_enabled = None
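
The rename from has_requests to has_req only shortens the flag used in the except clause further down. For context, MONAI's optional_import returns a pair: the requested object (or a placeholder that raises when actually used) and a boolean indicating whether the import succeeded. A small sketch; the output depends on whether the requests package is installed in the current environment:

# Sketch of the optional_import contract from monai.utils.
from monai.utils import optional_import

http_error, has_req = optional_import("requests", name="HTTPError")

if has_req:
    print(http_error)  # the real requests HTTPError class
else:
    print("requests missing; http_error raises only if actually used")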
@@ -67,6 +68,16 @@ def testing_data_config(*keys):
     return reduce(operator.getitem, keys, _test_data_config)
 
 
+def get_testing_algo_template_path():
+    """
+    Return a local folder containing the testing algorithm template, or a URL to the compressed template file.
+    Defaults to None, which effectively uses bundle_gen's ``default_algo_zip`` path.
+
+    https://github.com/Project-MONAI/MONAI/blob/1.1.0/monai/apps/auto3dseg/bundle_gen.py#L380-L381
+    """
+    return os.environ.get("MONAI_TESTING_ALGO_TEMPLATE", None)
+
+
 def clone(data: NdarrayTensor) -> NdarrayTensor:
     """
     Clone data independent of type.
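
The new helper is a thin wrapper over an environment variable, so CI or a developer can point Auto3DSeg tests at a local algorithm template instead of downloading one. A hypothetical usage, with a made-up path:

import os

# Hypothetical path; set before the tests resolve the template.
os.environ["MONAI_TESTING_ALGO_TEMPLATE"] = "/data/algorithm_templates"

from tests.utils import get_testing_algo_template_path

print(get_testing_algo_template_path())  # /data/algorithm_templates
# When the variable is unset, the helper returns None and bundle_gen
# falls back to its built-in default_algo_zip download.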
@@ -127,7 +138,7 @@ def assert_allclose(
 def skip_if_downloading_fails():
     try:
         yield
-    except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_requests else () as e:
+    except (ContentTooShortError, HTTPError, ConnectionError) + ((http_error,) if has_req else ()) as e:  # noqa: B030
         raise unittest.SkipTest(f"error while downloading: {e}") from e
     except ssl.SSLError as ssl_e:
         if "decryption failed" in str(ssl_e):
@@ -519,9 +519,7 @@ def generate(net, starting_tokens, seq_len, bos_token):
 samples = []
 for i in range(5):
     starting_token = vqvae_model.num_embeddings * torch.ones((1, 1), device=device)
-    generated_latent = generate(
-        transformer_model, starting_token, spatial_shape[0] * spatial_shape[1], bos_token
-    )
+    generated_latent = generate(transformer_model, starting_token, spatial_shape[0] * spatial_shape[1], bos_token)
     generated_latent = generated_latent[0]
     vqvae_latent = generated_latent[revert_sequence_ordering]
     vqvae_latent = vqvae_latent.reshape((1,) + spatial_shape)
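
Here generate is the tutorial's autoregressive sampler: starting from a beginning-of-sequence token (the VQ-VAE's num_embeddings index), it repeatedly feeds the sequence through the transformer and samples the next latent index until the flattened latent grid (spatial_shape[0] * spatial_shape[1] tokens) is filled. A minimal sketch of such a loop, assuming multinomial sampling and a model that returns per-position logits; the tutorial's real function may differ:

import torch

@torch.no_grad()
def generate(net, starting_tokens, seq_len, bos_token):
    # Autoregressive sampling sketch (no temperature or top-k filtering).
    seq = starting_tokens.long()
    for _ in range(seq_len):
        logits = net(seq)                      # (batch, cur_len, num_tokens)
        probs = torch.softmax(logits[:, -1], dim=-1)
        probs[:, bos_token] = 0                # never re-sample the BOS token
        next_token = torch.multinomial(probs, num_samples=1)
        seq = torch.cat([seq, next_token], dim=1)
    return seq[:, 1:]                          # drop the leading BOS token

The revert_sequence_ordering indexing above then maps the 1D token sequence back onto the 2D latent grid before VQ-VAE decoding.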
@@ -329,7 +329,6 @@
 progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
 progress_bar.set_description(f"Epoch {epoch}")
 for step, batch in progress_bar:
-
     images = batch["image"].to(device)
 
     optimizer.zero_grad(set_to_none=True)
@@ -353,7 +352,6 @@
 val_loss = 0
 with torch.no_grad():
     for val_step, batch in enumerate(val_loader, start=1):
-
         images = batch["image"].to(device)
 
         logits, quantizations_target, _ = inferer(