diff --git a/tests/python/frontend/pytorch/test_lstms.py b/tests/python/frontend/pytorch/test_lstms.py index e780ae725b74..967245e1ef9d 100644 --- a/tests/python/frontend/pytorch/test_lstms.py +++ b/tests/python/frontend/pytorch/test_lstms.py @@ -30,12 +30,12 @@ from torch import nn ## Model parameters -model_feature_size = 5 -model_hidden_size = 10 +model_feature_size = 16 +model_hidden_size = 32 model_num_layers = 2 -seqs_length = 15 -projection_size = 7 -batch_size = 3 +seqs_length = 2 +projection_size = 20 +batch_size = 2 def check_torch_version_for_proj_in_lstm(): @@ -183,7 +183,7 @@ def check_lstm_with_type( model = None for batch_first in (True, False): for use_bias in (True, False): - for rnd_weights in (True, False): + for rnd_weights in [True]: # (True, False): if lstm_type == "uni": model = LSTM_Model( device, @@ -277,7 +277,7 @@ def check_lstm_with_type( c_zeros = np.zeros(input_hidden_shape, dtype=dtype) tvm_output = None - for format in ("ts", "onnx"): + for format in ["ts"]: # ["ts", "onnx"]: if format == "ts": # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. traced_script_module = torch.jit.trace(model, dummy_input).eval() @@ -350,13 +350,13 @@ def check_lstm_with_type( def test_lstms(): for target, dev in tvm.testing.enabled_targets(): check_lstm_with_type("uni", target, dev) - check_lstm_with_type("p", target, dev) + # check_lstm_with_type("p", target, dev) check_lstm_with_type("s", target, dev) check_lstm_with_type("b", target, dev) - check_lstm_with_type("bp", target, dev) - check_lstm_with_type("sp", target, dev) + # check_lstm_with_type("bp", target, dev) + # check_lstm_with_type("sp", target, dev) check_lstm_with_type("sb", target, dev) - check_lstm_with_type("sbp", target, dev) + # check_lstm_with_type("sbp", target, dev) if __name__ == "__main__":