diff --git a/nemo/collections/tts/modules/aligner.py b/nemo/collections/tts/modules/aligner.py index 2910602474fd..f044a86a52eb 100644 --- a/nemo/collections/tts/modules/aligner.py +++ b/nemo/collections/tts/modules/aligner.py @@ -98,7 +98,7 @@ def get_dist(self, keys, queries, mask=None): self._apply_mask(dist, mask, float("inf")) - return dist + return dist.squeeze(1) @staticmethod def get_euclidean_dist(queries_enc, keys_enc):