From 503219d0a2306fc97c075c1c33a4122f0a369a89 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 20 Mar 2023 16:18:36 +0000
Subject: [PATCH 1/3] Add x_transformers as optional dependency

---
 generative/networks/nets/transformer.py | 32 ++++++++++++++++---
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/generative/networks/nets/transformer.py b/generative/networks/nets/transformer.py
index a6fc81ec..4619354a 100644
--- a/generative/networks/nets/transformer.py
+++ b/generative/networks/nets/transformer.py
@@ -11,9 +11,16 @@
 
 from __future__ import annotations
 
+import importlib.util
 import torch
 import torch.nn as nn
-from x_transformers import Decoder, TransformerWrapper
+
+if importlib.util.find_spec("x_transformers") is not None:
+    from x_transformers import Decoder, TransformerWrapper
+    has_x_transformers = True
+else:
+    has_x_transformers = False
+
 
 __all__ = ["DecoderOnlyTransformer"]
 
@@ -46,16 +53,19 @@ def __init__(
         self.attn_layers_depth = attn_layers_depth
         self.attn_layers_heads = attn_layers_heads
 
-        self.model = TransformerWrapper(
-            num_tokens=self.num_tokens,
-            max_seq_len=self.max_seq_len,
-            attn_layers=Decoder(
-                dim=self.attn_layers_dim,
-                depth=self.attn_layers_depth,
-                heads=self.attn_layers_heads,
-                cross_attend=with_cross_attention,
-            ),
-        )
+        if has_x_transformers:
+            self.model = TransformerWrapper(
+                num_tokens=self.num_tokens,
+                max_seq_len=self.max_seq_len,
+                attn_layers=Decoder(
+                    dim=self.attn_layers_dim,
+                    depth=self.attn_layers_depth,
+                    heads=self.attn_layers_heads,
+                    cross_attend=with_cross_attention,
+                ),
+            )
+        else:
+            raise ImportError("x-transformers is not installed.")
 
     def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
         return self.model(x, context=context)

From ecac2c4efcb72af9ba24b56bf82b7bdd3e481a1d Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 20 Mar 2023 16:36:51 +0000
Subject: [PATCH 2/3] Update utils, fix formating

---
 generative/networks/nets/transformer.py |  2 ++
 tests/test_vqvaetransformer_inferer.py  |  4 ++--
 tests/utils.py                          | 15 +++++++++++++--
 .../2d_vqvae_transformer_tutorial.py    |  4 +---
 4 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/generative/networks/nets/transformer.py b/generative/networks/nets/transformer.py
index 4619354a..4aefa417 100644
--- a/generative/networks/nets/transformer.py
+++ b/generative/networks/nets/transformer.py
@@ -12,11 +12,13 @@
 from __future__ import annotations
 
 import importlib.util
+
 import torch
 import torch.nn as nn
 
 if importlib.util.find_spec("x_transformers") is not None:
     from x_transformers import Decoder, TransformerWrapper
+
     has_x_transformers = True
 else:
     has_x_transformers = False
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
index 87766811..53378680 100644
--- a/tests/test_vqvaetransformer_inferer.py
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -103,7 +103,7 @@ def test_prediction_shape_shorter_sequence(
     ):
         stage_1 = VQVAE(**stage_1_params)
         max_seq_len = 3
-        stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
+        stage_2_params_shorter = dict(stage_2_params)
         stage_2_params_shorter["max_seq_len"] = max_seq_len
         stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
         ordering = Ordering(**ordering_params)
@@ -233,7 +233,7 @@ def test_get_likelihood_shorter_sequence(
     ):
         stage_1 = VQVAE(**stage_1_params)
         max_seq_len = 3
-        stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
+        stage_2_params_shorter = dict(stage_2_params)
         stage_2_params_shorter["max_seq_len"] = max_seq_len
         stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
         ordering = Ordering(**ordering_params)
diff --git a/tests/utils.py b/tests/utils.py
index 274f6aa9..c2c81dde 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from __future__ import annotations
 
 import copy
@@ -50,7 +51,7 @@
 from monai.utils.type_conversion import convert_data_type
 
 nib, _ = optional_import("nibabel")
-http_error, has_requests = optional_import("requests", name="HTTPError")
+http_error, has_req = optional_import("requests", name="HTTPError")
 
 quick_test_var = "QUICKTEST"
 _tf32_enabled = None
@@ -67,6 +68,16 @@ def testing_data_config(*keys):
     return reduce(operator.getitem, keys, _test_data_config)
 
 
+def get_testing_algo_template_path():
+    """
+    a local folder to the testing algorithm template or a url to the compressed template file.
+    Default to None, which effectively uses bundle_gen's ``default_algo_zip`` path.
+
+    https://github.com/Project-MONAI/MONAI/blob/1.1.0/monai/apps/auto3dseg/bundle_gen.py#L380-L381
+    """
+    return os.environ.get("MONAI_TESTING_ALGO_TEMPLATE", None)
+
+
 def clone(data: NdarrayTensor) -> NdarrayTensor:
     """
     Clone data independent of type.
@@ -127,7 +138,7 @@ def assert_allclose(
 def skip_if_downloading_fails():
     try:
         yield
-    except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_requests else () as e:
+    except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_req else () as e:  # noqa: B030
        raise unittest.SkipTest(f"error while downloading: {e}") from e
     except ssl.SSLError as ssl_e:
         if "decryption failed" in str(ssl_e):
diff --git a/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py b/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py
index 6bee24a2..7485e2e6 100644
--- a/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py
+++ b/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py
@@ -519,9 +519,7 @@ def generate(net, starting_tokens, seq_len, bos_token):
 samples = []
 for i in range(5):
     starting_token = vqvae_model.num_embeddings * torch.ones((1, 1), device=device)
-    generated_latent = generate(
-        transformer_model, starting_token, spatial_shape[0] * spatial_shape[1], bos_token
-    )
+    generated_latent = generate(transformer_model, starting_token, spatial_shape[0] * spatial_shape[1], bos_token)
     generated_latent = generated_latent[0]
     vqvae_latent = generated_latent[revert_sequence_ordering]
     vqvae_latent = vqvae_latent.reshape((1,) + spatial_shape)

From b9ec2f1722322d7683495b3d76ea7919da3ab422 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 20 Mar 2023 16:37:15 +0000
Subject: [PATCH 3/3] Fix formating

---
 .../anomaly_detection/anomaly_detection_with_transformers.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py b/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py
index ea85475b..613bdd6e 100644
--- a/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py
+++ b/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py
@@ -329,7 +329,6 @@
     progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
     progress_bar.set_description(f"Epoch {epoch}")
     for step, batch in progress_bar:
-
         images = batch["image"].to(device)
         optimizer.zero_grad(set_to_none=True)
 
@@ -353,7 +352,6 @@
         val_loss = 0
         with torch.no_grad():
            for val_step, batch in enumerate(val_loader, start=1):
-
                 images = batch["image"].to(device)
 
                 logits, quantizations_target, _ = inferer(
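
The guard introduced in PATCH 1/3 keeps x-transformers optional: importlib.util.find_spec only probes for the package, so importing generative.networks.nets.transformer succeeds without it, and the missing dependency surfaces only when DecoderOnlyTransformer is instantiated. The minimal standalone sketch below shows the same pattern; ToyDecoderOnlyTransformer is a hypothetical stand-in (not part of the patches) and the Decoder/TransformerWrapper arguments are illustrative placeholders:

# Minimal sketch, not part of the patch series: the same optional-import guard
# used by generative/networks/nets/transformer.py, applied to a toy class.
from __future__ import annotations

import importlib.util

# Probe for the optional package without importing it eagerly.
if importlib.util.find_spec("x_transformers") is not None:
    from x_transformers import Decoder, TransformerWrapper

    has_x_transformers = True
else:
    has_x_transformers = False


class ToyDecoderOnlyTransformer:
    """Hypothetical stand-in for DecoderOnlyTransformer illustrating the guard."""

    def __init__(self, num_tokens: int, max_seq_len: int) -> None:
        if not has_x_transformers:
            # Module import always succeeds; the dependency is only required
            # once a network is actually constructed, matching the patch.
            raise ImportError("x-transformers is not installed.")
        # Illustrative hyperparameters; the real class exposes them as arguments.
        self.model = TransformerWrapper(
            num_tokens=num_tokens,
            max_seq_len=max_seq_len,
            attn_layers=Decoder(dim=8, depth=1, heads=1),
        )


if __name__ == "__main__":
    try:
        ToyDecoderOnlyTransformer(num_tokens=16, max_seq_len=4)
        print("x-transformers available: model built")
    except ImportError as err:
        print(f"optional dependency missing: {err}")

Running the sketch in an environment without x-transformers prints the missing-dependency message rather than failing at import time, which mirrors how the patched module degrades.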