diff --git a/generative/networks/layers/vector_quantizer.py b/generative/networks/layers/vector_quantizer.py
index 358b95f3..79cebeb7 100644
--- a/generative/networks/layers/vector_quantizer.py
+++ b/generative/networks/layers/vector_quantizer.py
@@ -83,7 +83,6 @@ def __init__(
             range(1, self.spatial_dims + 1)
         )
 
-    @torch.cuda.amp.autocast(enabled=False)
     def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.
@@ -100,28 +99,28 @@ def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to
         encoding_indices_view = list(inputs.shape)
         del encoding_indices_view[1]
 
-        inputs = inputs.float()
+        with torch.cuda.amp.autocast(enabled=False):
+            inputs = inputs.float()
 
-        # Converting to channel last format
-        flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)
+            # Converting to channel last format
+            flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)
 
-        # Calculate Euclidean distances
-        distances = (
-            (flat_input**2).sum(dim=1, keepdim=True)
-            + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
-            - 2 * torch.mm(flat_input, self.embedding.weight.t())
-        )
+            # Calculate Euclidean distances
+            distances = (
+                (flat_input**2).sum(dim=1, keepdim=True)
+                + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
+                - 2 * torch.mm(flat_input, self.embedding.weight.t())
+            )
 
-        # Mapping distances to indexes
-        encoding_indices = torch.max(-distances, dim=1)[1]
-        encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()
+            # Mapping distances to indexes
+            encoding_indices = torch.max(-distances, dim=1)[1]
+            encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()
 
-        # Quantize and reshape
-        encoding_indices = encoding_indices.view(encoding_indices_view)
+            # Quantize and reshape
+            encoding_indices = encoding_indices.view(encoding_indices_view)
 
         return flat_input, encodings, encoding_indices
 
-    @torch.cuda.amp.autocast(enabled=False)
     def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
         """
         Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space
@@ -135,7 +134,8 @@ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Quantize space representation of encoding_indices in channel first format.
         """
-        return self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
+        with torch.cuda.amp.autocast(enabled=False):
+            return self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
 
     @torch.jit.unused
     def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
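
For context (not part of the diff): the change replaces the `@torch.cuda.amp.autocast(enabled=False)` decorator with an equivalent `with torch.cuda.amp.autocast(enabled=False):` block, so the distance computation and codebook lookup still run in float32 under AMP. The presence of `@torch.jit.unused` elsewhere in the class suggests the likely motivation is TorchScript compatibility: the autocast decorator form cannot be scripted, while the context-manager form can. A minimal sketch of that difference, using illustrative function names (`f_dec`, `f_ctx`) that are not part of the PR:

```python
# Sketch only: decorator vs. context-manager autocast under torch.jit.script.
import torch


@torch.cuda.amp.autocast(enabled=False)
def f_dec(x: torch.Tensor) -> torch.Tensor:
    # Decorator form: the autocast wrapper is flagged as unsupported by the
    # JIT, so torch.jit.script(f_dec) raises
    # "@autocast() decorator is not supported in script mode".
    return x.float() * 2


def f_ctx(x: torch.Tensor) -> torch.Tensor:
    # Context-manager form: the same autocast-off region, expressed in a way
    # TorchScript can compile.
    with torch.cuda.amp.autocast(enabled=False):
        return x.float() * 2


try:
    torch.jit.script(f_dec)
except RuntimeError as e:
    print("decorator form fails to script:", e)

scripted = torch.jit.script(f_ctx)  # scripts cleanly
print(scripted(torch.ones(2)))      # tensor([2., 2.])
```

Note that in `quantize` the `with` block ends before `return flat_input, encodings, encoding_indices`; this is still behavior-preserving, since those tensors are fully computed in float32 inside the block.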