From a81b59854ab9aec692a0a549c611183cd537b7e9 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 13 Feb 2023 18:26:49 +0000 Subject: [PATCH] Use Tuple for typing Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/layers/vector_quantizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/generative/networks/layers/vector_quantizer.py b/generative/networks/layers/vector_quantizer.py index 661f2129..358b95f3 100644 --- a/generative/networks/layers/vector_quantizer.py +++ b/generative/networks/layers/vector_quantizer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from typing import Sequence, Tuple import torch from torch import nn @@ -84,7 +84,7 @@ def __init__( ) @torch.cuda.amp.autocast(enabled=False) - def quantize(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. @@ -158,7 +158,7 @@ def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Ten else: pass - def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: flat_input, encodings, encoding_indices = self.quantize(inputs) quantized = self.embed(encoding_indices) @@ -205,7 +205,7 @@ def __init__(self, quantizer: torch.nn.Module = None): self.perplexity: torch.Tensor = torch.rand(1) - def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: quantized, loss, encoding_indices = self.quantizer(inputs) # Perplexity calculations