diff --git a/docs/requirements.txt b/docs/requirements.txt index 00dd4d2c1e..3530d63c49 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -20,3 +20,4 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops +transformers==4.10.2 diff --git a/docs/source/installation.md b/docs/source/installation.md index 08ab109142..902f596dfc 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` , `einops` and `transformers`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 273431fc72..ff45b29531 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -73,6 +73,7 @@ def get_optional_config_values(): output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") + output["transformers"] = get_package_version("transformers") return output diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py new file mode 100644 index 0000000000..f51a5c2913 --- /dev/null +++ b/monai/networks/nets/vltransformer.py @@ -0,0 +1,359 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tarfile +import tempfile +from typing import Sequence, Union + +import torch +from torch import nn + +from monai.utils import optional_import + +transformers = optional_import("transformers") +load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") +cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] +BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] + + +class BertPreTrainedModel(nn.Module): + """Module to load BERT pre-trained weights. + Based on: + LXMERT + https://github.com/airsplay/lxmert + BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, *inputs, **kwargs) -> None: + super(BertPreTrainedModel, self).__init__() + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained( + cls, + num_language_layers, + num_vision_layers, + num_mixed_layers, + bert_config, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs, + ): + archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + tempdir = tempfile.mkdtemp() + with tarfile.open(resolved_archive_file, "r:gz") as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, "pytorch_model.bin") + state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) + if tempdir: + shutil.rmtree(tempdir) + if from_tf: + weights_path = os.path.join(serialization_dir, "model.ckpt") + return load_tf_weights_in_bert(model, weights_path) + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + start_prefix = "" + if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): + start_prefix = "bert." + load(model, prefix=start_prefix) + return model + + +class BertAttention(nn.Module): + """BERT attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores)) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertOutput(nn.Module): + """BERT output layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertMixedLayer(nn.Module): + """BERT cross attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.att = BertAttention(config) + self.output = BertOutput(config) + + def forward(self, x, y): + output = self.att(x, y) + return self.output(output, x) + + +class Pooler(nn.Module): + """BERT pooler layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + hidden_size, + ) -> None: + super(Pooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MultiModal(BertPreTrainedModel): + """ + Multimodal Transformers From Pretrained BERT Weights" + """ + + def __init__( + self, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + bert_config: dict, # type: ignore + ) -> None: + """ + Args: + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + bert_config: configuration for bert language transformer encoder. + + """ + super().__init__() + self.config = type("obj", (object,), bert_config) + self.embeddings = BertEmbeddings(self.config) + self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)]) + self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)]) + self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)]) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): + language_features = self.embeddings(input_ids, token_type_ids) + for layer in self.vision_encoder: + hidden_state_vision = layer(vision_feats, None)[0] + for layer in self.language_encoder: + hidden_state_language = layer(language_features, attention_mask)[0] + for layer in self.mixed_encoder: + hidden_state_mixed = layer(hidden_state_language, hidden_state_vision) + return hidden_state_mixed + + +class VLTransformers(torch.nn.Module): + """ + Vision Language Multimodal Transformers" + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], # type: ignore + patch_size: Union[Sequence[int], int], # type: ignore + num_classes: int, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + drop_out: float = 0.0, + bert_config: dict = { + "attention_probs_dropout_prob": 0.1, + "classifier_dropout": None, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.10.2", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522, + "chunk_size_feed_forward": 0, + "is_decoder": False, + "add_cross_attention": False, + }, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + num_classes: number of classes if classification is used. + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + num_mixed_layers: number of mixed transformer layers. + drop_out: faction of the input units to drop. + bert_config: configuration for bert language transformer encoder. + Examples:: + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, + 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head + >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, + num_vision_layers=2, num_mixed_layers=2, drop_out=0.2) + """ + super(VLTransformers, self).__init__() + + if not (0 <= drop_out <= 1): +<<<<<<< HEAD + raise ValueError("dropout_rate should be in the range of 0 and 1.") +======= + raise ValueError("dropout_rate should be between 0 and 1.") +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd + + if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore + raise ValueError("img_size should be divisible by patch_size.") + + self.multimodal = MultiModal.from_pretrained( + num_language_layers=num_language_layers, + num_vision_layers=num_vision_layers, + num_mixed_layers=num_mixed_layers, + bert_config=bert_config, + ) + + self.embed_dim = 768 + self.patch_size = patch_size + self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore + self.vision_proj = nn.Conv2d( + in_channels=in_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + self.norm_vision_pos = nn.LayerNorm(self.embed_dim) + self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) + self.pooler = Pooler(hidden_size=self.embed_dim) + self.drop = torch.nn.Dropout(drop_out) + self.cls_head = torch.nn.Linear(self.embed_dim, num_classes) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None): + attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) + attention_mask = (1.0 - attention_mask) * -10000.0 + vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2) + vision_feats = self.norm_vision_pos(vision_feats) + vision_feats = vision_feats + self.pos_embed_vis + hidden_state_mixed = self.multimodal( + input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask + ) + pooled_features = self.pooler(hidden_state_mixed) + logits = self.cls_head(self.drop(pooled_features)) + return logits diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index add47e27ca..e0629847d5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -32,10 +32,18 @@ map_classes_to_indices, ) from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis -from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import +from monai.utils import ( + convert_data_type, + convert_to_numpy, + convert_to_tensor, + ensure_tuple, + get_equivalent_dtype, + look_up_option, + min_version, + optional_import, +) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_data_type PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -334,15 +342,25 @@ class ToTensor(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] +<<<<<<< HEAD + def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: + super().__init__() + self.dtype = dtype +======= def __init__(self, device: Optional[torch.device] = None) -> None: super().__init__() +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd self.device = device def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ +<<<<<<< HEAD + return convert_to_tensor(img, dtype=self.dtype, device=self.device, wrap_sequence=True) # type: ignore +======= return convert_to_tensor(img, wrap_sequence=True, device=self.device) # type: ignore +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd class EnsureType(Transform): @@ -354,19 +372,24 @@ class EnsureType(Transform): Args: data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. + device: for Tensor data type, specify the target device. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, data_type: str = "tensor") -> None: - data_type = data_type.lower() - if data_type not in ("tensor", "numpy"): - raise ValueError("`data type` must be 'tensor' or 'numpy'.") - - self.data_type = data_type + def __init__( + self, + data_type: str = "tensor", + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + device: Optional[torch.device] = None, + ) -> None: + self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"}) + self.dtype = dtype + self.device = device - def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, data: NdarrayOrTensor): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. @@ -375,7 +398,12 @@ def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: if applicable. """ - return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore + if self.data_type == "tensor": + dtype_ = get_equivalent_dtype(self.dtype, torch.Tensor) + return convert_to_tensor(data, dtype=dtype_, device=self.device) + else: + dtype_ = get_equivalent_dtype(self.dtype, np.ndarray) + return convert_to_numpy(data, dtype=dtype_) class ToNumpy(Transform): @@ -385,25 +413,40 @@ class ToNumpy(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, dtype: Optional[DtypeLike] = None) -> None: + super().__init__() + self.dtype = dtype + def __call__(self, img: NdarrayOrTensor) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - return convert_to_numpy(img) # type: ignore + return convert_to_numpy(img, dtype=self.dtype) # type: ignore class ToCupy(Transform): """ Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. + + Args: + dtype: data type specifier. It is inferred from the input by default. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __init__(self, dtype=None) -> None: + super().__init__() + self.dtype = dtype + + def __call__(self, data: NdarrayOrTensor): """ - Apply the transform to `img` and make it contiguous. + Create a CuPy array from `data` and make it contiguous """ +<<<<<<< HEAD + return convert_to_cupy(data, self.dtype) +======= return cp.ascontiguousarray(cp.asarray(img)) # type: ignore +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd class ToPIL(Transform): @@ -779,6 +822,9 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) # type: ignore if output_shape is None: output_shape = self.output_shape fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) @@ -828,6 +874,10 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) # type: ignore + if output_shape is None: output_shape = self.output_shape indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) @@ -848,6 +898,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ def __call__(self, img: np.ndarray) -> np.ndarray: + img, *_ = convert_data_type(img, np.ndarray) # type: ignore # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: img = np.squeeze(img, axis=0) @@ -914,6 +965,9 @@ def __call__( if label.shape[0] != 1: raise ValueError("Only supports single channel labels!") + img, *_ = convert_data_type(img, np.ndarray) # type: ignore + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + # Generate extreme points self.randomize(label[0, :]) @@ -950,6 +1004,7 @@ def __call__(self, img: torch.Tensor): img: PyTorch Tensor data for the TorchVision transform. """ + img, *_ = convert_data_type(img, torch.Tensor) # type: ignore return self.trans(img) @@ -980,7 +1035,7 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL self.dtype = dtype def __call__(self, img: np.ndarray): - img = np.asarray(img) + img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_flat = img.flatten() try: out_flat = np.copy(img_flat).astype(self.dtype) @@ -1036,6 +1091,7 @@ def __call__( mask must have the same shape as input `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if meta_data is None: meta_data = {} diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index b51ff6a9c8..03d5380ee8 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -16,6 +16,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", "convert_to_dst_type", @@ -60,6 +61,8 @@ def get_equivalent_dtype(dtype, data_type): im = torch.tensor(1) dtype = get_equivalent_dtype(np.float32, type(im)) """ + if dtype is None: + return None if data_type is torch.Tensor: if type(dtype) is torch.dtype: return dtype @@ -83,7 +86,16 @@ def get_dtype(data: Any): return type(data) +<<<<<<< HEAD +def convert_to_tensor( + data, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = False, +): +======= def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch.device] = None): +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. @@ -92,18 +104,42 @@ def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch. data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable. - wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. - If `True`, then `[1, 2]` -> `tensor([1, 2])`. + dtype: target data type to when converting to Tensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. """ if isinstance(data, torch.Tensor): +<<<<<<< HEAD + return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore +======= return data.contiguous().to(device) +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 if re.search(r"[SaUO]", data.dtype.str) is None: # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims +<<<<<<< HEAD + if data.ndim > 0: + data = np.ascontiguousarray(data) + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore + elif ( + has_cp + and isinstance(data, cp_ndarray) + or isinstance(data, (float, int, bool)) + or (isinstance(data, Sequence) and wrap_sequence) + ): + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore + elif isinstance(data, list): + return [convert_to_tensor(i, dtype=dtype, device=device) for i in data] + elif isinstance(data, tuple): + return tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data) + elif isinstance(data, dict): + return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()} +======= return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data), device=device) elif has_cp and isinstance(data, cp_ndarray): return torch.as_tensor(data, device=device) @@ -117,11 +153,12 @@ def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch. return tuple(convert_to_tensor(i, device=device) for i in data) elif isinstance(data, dict): return {k: convert_to_tensor(v, device=device) for k, v in data.items()} +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd return data -def convert_to_numpy(data, wrap_sequence: bool = False): +def convert_to_numpy(data, dtype: Optional[DtypeLike] = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -130,23 +167,22 @@ def convert_to_numpy(data, wrap_sequence: bool = False): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to numpy array. wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. """ if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() + data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy() elif has_cp and isinstance(data, cp_ndarray): - data = cp.asnumpy(data) - elif isinstance(data, (float, int, bool)): - data = np.asarray(data) - elif isinstance(data, Sequence) and wrap_sequence: - return np.asarray(data) + data = cp.asnumpy(data).astype(dtype) + elif isinstance(data, (np.ndarray, float, int, bool)) or (isinstance(data, Sequence) and wrap_sequence): + data = np.asarray(data, dtype=dtype) elif isinstance(data, list): - return [convert_to_numpy(i) for i in data] + return [convert_to_numpy(i, dtype=dtype) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_numpy(i) for i in data) + return tuple(convert_to_numpy(i, dtype=dtype) for i in data) elif isinstance(data, dict): - return {k: convert_to_numpy(v) for k, v in data.items()} + return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()} if isinstance(data, np.ndarray) and data.ndim > 0: data = np.ascontiguousarray(data) @@ -154,6 +190,42 @@ def convert_to_numpy(data, wrap_sequence: bool = False): return data +def convert_to_cupy(data, dtype, wrap_sequence: bool = True): + """ + Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple, + recursively check every item and convert it to cupy array. + + Args: + data: input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc. + Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays + + for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to Cupy array. + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. + If `True`, then `[1, 2]` -> `array([1, 2])`. + """ + + # direct calls + if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)) or ( + isinstance(data, Sequence) and wrap_sequence + ): + data = cp.asarray(data, dtype) + elif isinstance(data, list): + return [convert_to_cupy(i, dtype) for i in data] + elif isinstance(data, tuple): + return tuple(convert_to_cupy(i, dtype) for i in data) + elif isinstance(data, dict): + return {k: convert_to_cupy(v, dtype) for k, v in data.items()} + # make it contiguous + if isinstance(data, cp.ndarray): + if data.ndim > 0: + data = cp.ascontiguousarray(data) + else: + raise ValueError(f"The input data type [{type(data)}] cannot be converted into cupy arrays!") + + return data + + def convert_data_type( data: Any, output_type: Optional[type] = None, @@ -178,6 +250,8 @@ def convert_data_type( orig_type = torch.Tensor elif isinstance(data, np.ndarray): orig_type = np.ndarray + elif has_cp and isinstance(data, cp.ndarray): + orig_type = cp.ndarray else: orig_type = type(data) @@ -185,30 +259,33 @@ def convert_data_type( output_type = output_type or orig_type - dtype = get_equivalent_dtype(dtype or get_dtype(data), output_type) + dtype_ = get_equivalent_dtype(dtype or get_dtype(data), output_type) if output_type is torch.Tensor: - if orig_type is not torch.Tensor: - data = convert_to_tensor(data) - if dtype != data.dtype: - data = data.to(dtype) - if device is not None: - data = data.to(device) + data = convert_to_tensor(data, dtype=dtype_, device=device) elif output_type is np.ndarray: - if orig_type is not np.ndarray: - data = convert_to_numpy(data) - if data is not None and dtype != data.dtype: - data = data.astype(dtype) + data = convert_to_numpy(data, dtype=dtype_) + elif has_cp and output_type is cp.ndarray: + data = convert_to_cupy(data, dtype=dtype_) else: raise ValueError(f"Unsupported output type: {output_type}") return data, orig_type, orig_device -def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: +def convert_to_dst_type( + src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None +) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: """ +<<<<<<< HEAD + If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, + if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, + otherwise, convert to the type of `dst` directly. + `dtype` is an optional argument if the target `dtype` is different from the original `dst`'s data type. +======= If `dst` is `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, if `dst` is `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, otherwise, convert to the type of `dst` directly. +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd See Also: :func:`convert_data_type` @@ -217,6 +294,12 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor if isinstance(dst, torch.Tensor): device = dst.device +<<<<<<< HEAD + if dtype is None: + dtype = dst.dtype + +======= +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd output_type: Any if isinstance(dst, torch.Tensor): output_type = torch.Tensor @@ -224,4 +307,8 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor output_type = np.ndarray else: output_type = type(dst) +<<<<<<< HEAD + return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype) +======= return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype) +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd diff --git a/requirements-dev.txt b/requirements-dev.txt index 785454ad5d..ed8739ded8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,3 +36,4 @@ openslide-python==1.1.2 pandas requests einops +transformers diff --git a/setup.cfg b/setup.cfg index 6efe768a6f..f7ed90a14a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,7 @@ all = openslide-python==1.1.2 pandas einops + transformers nibabel = nibabel skimage = @@ -74,6 +75,8 @@ pandas = pandas einops = einops +transformers = + transformers [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 diff --git a/tests/min_tests.py b/tests/min_tests.py index 5b376d7b57..bac6521889 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -140,6 +140,7 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", + "test_vltransformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 3d187a1dba..4f618ccc1a 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -11,6 +11,7 @@ import unittest +import torch from parameterized import parameterized from torch import Tensor @@ -35,16 +36,27 @@ class TestToTensor(unittest.TestCase): @parameterized.expand(TESTS) def test_array_input(self, test_data, expected_shape): +<<<<<<< HEAD + result = ToTensor(dtype=torch.float32, device="cpu")(test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) +======= result = ToTensor()(test_data) self.assertTrue(isinstance(result, Tensor)) assert_allclose(result, test_data) +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand(TESTS_SINGLE) def test_single_input(self, test_data): result = ToTensor()(test_data) +<<<<<<< HEAD + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) +======= self.assertTrue(isinstance(result, Tensor)) assert_allclose(result, test_data) +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd self.assertEqual(result.ndim, 0) @unittest.skipUnless(has_cp, "CuPy is required.") @@ -52,8 +64,13 @@ def test_cupy(self): test_data = [[1, 2], [3, 4]] cupy_array = cp.ascontiguousarray(cp.asarray(test_data)) result = ToTensor()(cupy_array) +<<<<<<< HEAD + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) +======= self.assertTrue(isinstance(result, Tensor)) assert_allclose(result, test_data) +>>>>>>> 7be790dac0381cc7a3ed393d351f2a860570cbdd if __name__ == "__main__": diff --git a/tests/test_vltransformer.py b/tests/test_vltransformer.py new file mode 100644 index 0000000000..a92a9bf79a --- /dev/null +++ b/tests/test_vltransformer.py @@ -0,0 +1,80 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vltransformer import VLTransformers + +TEST_CASE_VLTransformers = [] +for drop_out in [0.4]: + for in_channels in [3]: + for img_size in [224]: + for patch_size in [16, 32]: + for num_language_layers in [2]: + for num_vision_layers in [4]: + for num_mixed_layers in [3]: + for num_classes in [8]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * 2, + "patch_size": (patch_size,) * 2, + "num_vision_layers": num_vision_layers, + "num_mixed_layers": num_mixed_layers, + "num_language_layers": num_language_layers, + "num_classes": num_classes, + "drop_out": drop_out, + }, + (2, num_classes), # type: ignore + ] + TEST_CASE_VLTransformers.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_VLTransformers) + def test_shape(self, input_param, expected_shape): + net = VLTransformers(**input_param) + with eval_mode(net): + result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224))) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + VLTransformers( + in_channels=3, + img_size=(128, 128), + patch_size=(16, 16), + num_language_layers=2, + num_mixed_layers=4, + num_vision_layers=2, + num_classes=2, + drop_out=5.0, + ) + + with self.assertRaises(ValueError): + VLTransformers( + in_channels=1, + img_size=(97, 97), + patch_size=(16, 16), + num_language_layers=6, + num_mixed_layers=6, + num_vision_layers=8, + num_classes=8, + drop_out=0.4, + ) + + +if __name__ == "__main__": + unittest.main()