Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions nemo/collections/tts/modules/waveglow.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,19 @@ def forward(self, spec, z=None, audio=None, run_inverse=True, sigma=1.0):

@property
def input_types(self):
return {
"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"z": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"audio": NeuralType(('B', 'T'), AudioSignal(), optional=True),
"run_inverse": NeuralType(elements_type=IntType(), optional=True),
"sigma": NeuralType(optional=True),
}
if self.mode == OperationMode.infer:
return {
"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"z": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
}
else:
return {
"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"z": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"audio": NeuralType(('B', 'T'), AudioSignal(), optional=True),
"run_inverse": NeuralType(elements_type=IntType(), optional=True),
"sigma": NeuralType(optional=True),
}

@property
def output_types(self):
Expand Down
6 changes: 2 additions & 4 deletions tests/collections/nlp/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def test_get_pretrained_bert_model(self):
self.omega_conf.language_model.pretrained_model_name = 'bert-base-uncased'
model = nemo_nlp.modules.get_lm_model(cfg=self.omega_conf)
assert isinstance(model, nemo_nlp.modules.BertEncoder)
# TODO: Fix
# do_export(model, "bert-base-uncased")
do_export(model, "bert-base-uncased")

@pytest.mark.with_downloads()
@pytest.mark.unit
Expand All @@ -74,8 +73,7 @@ def test_get_pretrained_albert_model(self):
self.omega_conf.language_model.pretrained_model_name = 'albert-base-v1'
model = nemo_nlp.modules.get_lm_model(cfg=self.omega_conf)
assert isinstance(model, nemo_nlp.modules.AlbertEncoder)
# TODO: fix
# do_export(model, "albert-base-v1")
do_export(model, "albert-base-v1")

@pytest.mark.with_downloads()
@pytest.mark.unit
Expand Down
2 changes: 0 additions & 2 deletions tests/collections/nlp/test_nlp_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def test_IntentSlotClassificationModel_export_to_onnx(self, dummy_data):
assert onnx_model.graph.output[0].name == 'intent_logits'
assert onnx_model.graph.output[1].name == 'slot_logits'

@pytest.mark.pleasefixme
@pytest.mark.with_downloads()
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand Down Expand Up @@ -130,7 +129,6 @@ def test_PunctuationCapitalizationModel_export_to_onnx(self):
assert onnx_model.graph.output[0].name == 'punct_logits'
assert onnx_model.graph.output[1].name == 'capit_logits'

@pytest.mark.pleasefixme
@pytest.mark.with_downloads()
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand Down
8 changes: 1 addition & 7 deletions tests/collections/tts/test_tts_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,17 @@

@pytest.fixture()
def fastpitch_model():
test_root = os.path.dirname(os.path.abspath(__file__))
conf = OmegaConf.load(os.path.join(test_root, '../../../examples/tts/conf/fastpitch_align_v1.05.yaml'))
conf.train_dataset = conf.validation_datasets = '.'
conf.model.train_ds = conf.model.test_ds = conf.model.validation_ds = None
model = FastPitchModel(cfg=conf.model)
model = FastPitchModel.from_pretrained(model_name="tts_en_fastpitch")
return model


@pytest.fixture()
def hifigan_model():
test_root = os.path.dirname(os.path.abspath(__file__))
model = HifiGanModel.from_pretrained(model_name="tts_hifigan")
return model


class TestExportable:
@pytest.mark.pleasefixme
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
def test_FastPitchModel_export_to_onnx(self, fastpitch_model):
Expand Down
1 change: 0 additions & 1 deletion tests/collections/tts/test_waveglow.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def forward_wrapper(self, spec, z=None):


class TestWaveGlow:
@pytest.mark.pleasefixme
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
def test_export_to_onnx(self):
Expand Down