diff --git a/nemo/collections/tts/modules/waveglow.py b/nemo/collections/tts/modules/waveglow.py index b9c9aacc05df..daa5405298db 100644 --- a/nemo/collections/tts/modules/waveglow.py +++ b/nemo/collections/tts/modules/waveglow.py @@ -131,13 +131,19 @@ def forward(self, spec, z=None, audio=None, run_inverse=True, sigma=1.0): @property def input_types(self): - return { - "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), - "z": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), - "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True), - "run_inverse": NeuralType(elements_type=IntType(), optional=True), - "sigma": NeuralType(optional=True), - } + if self.mode == OperationMode.infer: + return { + "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "z": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + } + else: + return { + "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "z": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True), + "run_inverse": NeuralType(elements_type=IntType(), optional=True), + "sigma": NeuralType(optional=True), + } @property def output_types(self): diff --git a/tests/collections/nlp/test_huggingface.py b/tests/collections/nlp/test_huggingface.py index 68b742470d82..cfe2845caa9b 100644 --- a/tests/collections/nlp/test_huggingface.py +++ b/tests/collections/nlp/test_huggingface.py @@ -49,8 +49,7 @@ def test_get_pretrained_bert_model(self): self.omega_conf.language_model.pretrained_model_name = 'bert-base-uncased' model = nemo_nlp.modules.get_lm_model(cfg=self.omega_conf) assert isinstance(model, nemo_nlp.modules.BertEncoder) - # TODO: Fix - # do_export(model, "bert-base-uncased") + do_export(model, "bert-base-uncased") @pytest.mark.with_downloads() @pytest.mark.unit @@ -74,8 +73,7 @@ def test_get_pretrained_albert_model(self): self.omega_conf.language_model.pretrained_model_name = 'albert-base-v1' model = nemo_nlp.modules.get_lm_model(cfg=self.omega_conf) assert isinstance(model, nemo_nlp.modules.AlbertEncoder) - # TODO: fix - # do_export(model, "albert-base-v1") + do_export(model, "albert-base-v1") @pytest.mark.with_downloads() @pytest.mark.unit diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index e0043cdced7b..21f65ec5d94b 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -99,7 +99,6 @@ def test_IntentSlotClassificationModel_export_to_onnx(self, dummy_data): assert onnx_model.graph.output[0].name == 'intent_logits' assert onnx_model.graph.output[1].name == 'slot_logits' - @pytest.mark.pleasefixme @pytest.mark.with_downloads() @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -130,7 +129,6 @@ def test_PunctuationCapitalizationModel_export_to_onnx(self): assert onnx_model.graph.output[0].name == 'punct_logits' assert onnx_model.graph.output[1].name == 'capit_logits' - @pytest.mark.pleasefixme @pytest.mark.with_downloads() @pytest.mark.run_only_on('GPU') @pytest.mark.unit diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 40318173bd62..d847c3cf95e0 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -23,23 +23,17 @@ @pytest.fixture() def fastpitch_model(): - test_root = os.path.dirname(os.path.abspath(__file__)) - conf = OmegaConf.load(os.path.join(test_root, '../../../examples/tts/conf/fastpitch_align_v1.05.yaml')) - conf.train_dataset = conf.validation_datasets = '.' - conf.model.train_ds = conf.model.test_ds = conf.model.validation_ds = None - model = FastPitchModel(cfg=conf.model) + model = FastPitchModel.from_pretrained(model_name="tts_en_fastpitch") return model @pytest.fixture() def hifigan_model(): - test_root = os.path.dirname(os.path.abspath(__file__)) model = HifiGanModel.from_pretrained(model_name="tts_hifigan") return model class TestExportable: - @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_FastPitchModel_export_to_onnx(self, fastpitch_model): diff --git a/tests/collections/tts/test_waveglow.py b/tests/collections/tts/test_waveglow.py index 0d2388b9a124..9198f01c4226 100644 --- a/tests/collections/tts/test_waveglow.py +++ b/tests/collections/tts/test_waveglow.py @@ -73,7 +73,6 @@ def forward_wrapper(self, spec, z=None): class TestWaveGlow: - @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_export_to_onnx(self):