diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3b5bf9acfa42..642d782a7870 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -53,6 +53,7 @@ infer_value, lstm_cell, new_var, + rnn_cell, shape_of, try_resolve_var_to_const, unbind, @@ -2723,7 +2724,7 @@ def expand_shape(in_shape, shape): class RNN(OnnxOpConverter): - """Operator converter for RNNs such as LSTM and GRU.""" + """Operator converter for RNNs such as RNN, LSTM and GRU.""" @classmethod def _activation_helper(cls, activation, alpha, beta): @@ -2756,35 +2757,27 @@ def _activation_needs_beta(cls, activation): ] return activation.decode("utf-8") in needs_beta - -class LSTM(RNN): - """Operator converter for LSTM""" - @classmethod - def bidir_lstm_cell( + def bidir_rnn_cell( cls, input_seqs, weight_dicts, acts, ): """ - Bidirectional LSTM cell + Bidirectional RNN cell """ seq_len = len(input_seqs) - forward_outputs, fw_H_t, fw_C_t = lstm_cell( + forward_outputs, fw_H_t = rnn_cell( input_seqs, **weight_dicts[0], - f_act=acts[0], - g_act=acts[1], - h_act=acts[2], + act=acts[0], ) - reverse_outputs, rev_H_t, rev_C_t = lstm_cell( + reverse_outputs, rev_H_t = rnn_cell( input_seqs, **weight_dicts[1], - f_act=acts[3], - g_act=acts[4], - h_act=acts[5], + act=acts[1], backwards=True, ) @@ -2797,44 +2790,24 @@ def bidir_lstm_cell( return ( _op.stack(final_outputs, axis=0), _op.stack([fw_H_t, rev_H_t], axis=0), - _op.stack([fw_C_t, rev_C_t], axis=0), ) @classmethod - def _impl_v7(cls, inputs, attr, params): - # Unpack inputs, note that if optional and not provided then value will be None. - X = inputs[0] - Wp = inputs[1] - Rp = inputs[2] - Bp = inputs[3] - # Sequence length currently unused as it can be inferred from shapes. - # sequence_lens = inputs['sequence_lens'] - Hp_0 = inputs[5] - Cp_0 = inputs[6] - Pp = inputs[7] - - num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype - - if num_directions not in [1, 2]: - raise ValueError("num_directions must be either 1 or 2!") - - X_shape = infer_shape(X) - hidden_size = infer_shape(Rp)[-1] - batch_size = X_shape[1] - - # Initialize state if not provided. - # Otherwise remove bidirectional axis. 
- if Hp_0 is None: - Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - if Cp_0 is None: - Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + def _default_activations(cls, num_directions): + return [_op.tanh] * num_directions + @classmethod + def _get_activations(cls, attr, multiplier, num_directions, rnn_type): + """ + Activation functions + """ if "activations" in attr: activations = attr["activations"] - if len(activations) != 3 * num_directions: + if len(activations) != multiplier * num_directions: raise NotImplementedError( - f"LSTM assumes 3 * num_directions activation functions are provided" + "{} assumes {} * num_directions activation functions are provided".format( + rnn_type, multiplier + ) ) alpha_loc = 0 alphas = attr.get("activation_alpha", []) @@ -2845,7 +2818,7 @@ def _impl_v7(cls, inputs, attr, params): if isinstance(betas, float): betas = [betas] acts = [] - for i in range(3 * num_directions): + for i in range(multiplier * num_directions): alpha = None beta = None activation = activations[i] @@ -2857,18 +2830,171 @@ def _impl_v7(cls, inputs, attr, params): beta_loc += 1 acts.append(cls._activation_helper(activation, alpha, beta)) else: - acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions + acts = cls._default_activations(num_directions) + return acts + + @classmethod + def _inputs_helper(cls, inputs, layout): + """ + Process inputs + """ + # Unpack inputs, note that if optional and not provided then value will be None. + X = inputs[0] + Wp = inputs[1] + Rp = inputs[2] + Bp = inputs[3] + # Sequence length currently unused as it can be inferred from shapes. + # sequence_lens = inputs['sequence_lens'] + Hp_0 = inputs[5] + + num_directions = infer_shape(Wp)[0] + + if num_directions not in [1, 2]: + raise ValueError("num_directions must be either 1 or 2!") + + if layout == 1: + X = _op.transpose(X, axes=(1, 0)) + + # Initialize state if not provided. 
+ if Hp_0 is None: + W_dtype = infer_type(Wp).checked_type.dtype + X_shape = infer_shape(X) + hidden_size = infer_shape(Rp)[-1] + batch_size = X_shape[1] + Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + elif layout == 1: + Hp_0 = _op.transpose(Hp_0, axes=(1, 0)) # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved X_steps = unbind(X, axis=0) H_ts = _op.split(Hp_0, num_directions) - C_ts = _op.split(Cp_0, num_directions) Ws = _op.split(Wp, num_directions) Rs = _op.split(Rp, num_directions) + Bs = None if Bp is not None: Bs = _op.split(Bp, num_directions) + return X_steps, H_ts, Ws, Rs, Bs, num_directions + + @classmethod + def _impl_common(cls, inputs, attr, layout): + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout) + acts = cls._get_activations(attr, 1, num_directions, "RNN") + + weights_dicts = [] + for i in range(num_directions): + weights_dict = {} + + weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0]) + + weights_dict["w_inp"] = _op.squeeze(Ws[i], axis=[0]) + weights_dict["w_hid"] = _op.squeeze(Rs[i], axis=[0]) + if Bs is not None: + Bi, Bh = _op.split(Bs[i], 2, -1) + weights_dict["b_inp"] = _op.squeeze(Bi, axis=[0]) + weights_dict["b_hid"] = _op.squeeze(Bh, axis=[0]) + weights_dicts.append(weights_dict) + + if num_directions == 2: + output, H = RNN.bidir_rnn_cell( + input_seqs=X_steps, + weight_dicts=weights_dicts, + acts=acts, + ) + else: + # outputs shape = [seqs_num, (batch_size, hidden_size)] + outputs, H = rnn_cell( + input_seqs=X_steps, + **weights_dicts[0], + act=acts[0], + ) + + # output shape = (seqs_num, num_directions, batch_size, hidden_size) + output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) + H = _op.expand_dims(H, axis=0) + + if layout == 1: + output = _op.transpose(output, axes=(1, 0)) + H = _op.transpose(H, axes=(1, 0)) + return _expr.TupleWrapper(_expr.Tuple((output, H)), 2) + + @classmethod + def _impl_v7(cls, inputs, attr, params): + return cls._impl_common(inputs, attr, 0) + + @classmethod + def _impl_v14(cls, inputs, attr, params): + layout = attr.get("layout", 0) + return cls._impl_common(inputs, attr, layout) + + +class LSTM(RNN): + """Operator converter for LSTM""" + + @classmethod + def bidir_lstm_cell( + cls, + input_seqs, + weight_dicts, + acts, + ): + """ + Bidirectional LSTM cell + """ + seq_len = len(input_seqs) + forward_outputs, fw_H_t, fw_C_t = lstm_cell( + input_seqs, + **weight_dicts[0], + f_act=acts[0], + g_act=acts[1], + h_act=acts[2], + ) + + reverse_outputs, rev_H_t, rev_C_t = lstm_cell( + input_seqs, + **weight_dicts[1], + f_act=acts[3], + g_act=acts[4], + h_act=acts[5], + backwards=True, + ) + + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0) + ) + + return ( + _op.stack(final_outputs, axis=0), + _op.stack([fw_H_t, rev_H_t], axis=0), + _op.stack([fw_C_t, rev_C_t], axis=0), + ) + + @classmethod + def _default_activations(cls, num_directions): + return [_op.sigmoid, _op.tanh, _op.tanh] * num_directions + + @classmethod + def _impl_common(cls, inputs, attr, layout): + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout) + acts = cls._get_activations(attr, 3, num_directions, "LSTM") + + # cell state + Cp_0 = inputs[6] + if Cp_0 is None: + C_ts = _expr.TupleWrapper( + _expr.Tuple([_op.zeros_like(H_ts[i]) for i in range(num_directions)]), + num_directions, + ) + else: + if layout == 1: + Cp_0 = _op.transpose(Cp_0, axes=(1, 
0)) + C_ts = _op.split(Cp_0, num_directions) + + # peepholes + Pp = inputs[7] if Pp is not None: p_i, p_o, p_f = _op.split(Pp, 3, axis=1) @@ -2888,7 +3014,7 @@ def _impl_v7(cls, inputs, attr, params): weights_dict["w_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0) mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4) weights_dict["w_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0) - if Bp is not None: + if Bs is not None: Bi, Bh = _op.split(Bs[i], 2, -1) mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4) weights_dict["b_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0) @@ -2921,6 +3047,10 @@ def _impl_v7(cls, inputs, attr, params): H = _op.expand_dims(H, axis=0) C = _op.expand_dims(C, axis=0) + if layout == 1: + output = _op.transpose(output, axes=(1, 0)) + H = _op.transpose(H, axes=(1, 0)) + C = _op.transpose(C, axes=(1, 0)) return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) @@ -2965,68 +3095,14 @@ def bidir_gru_cell( ) @classmethod - def _impl_v7(cls, inputs, attr, params): - # Unpack inputs, note that if optional and not provided then value will be None. - X = inputs[0] - Wp = inputs[1] - Rp = inputs[2] - Bp = inputs[3] - # Sequence length currently unused as it can be inferred from shapes. - # sequence_lens = inputs['sequence_lens'] - Hp_0 = inputs[5] - linear_before_reset = attr.get("linear_before_reset", 0) - - num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype - - if num_directions not in [1, 2]: - raise ValueError("num_directions must be either 1 or 2!") - - X_shape = infer_shape(X) - hidden_size = infer_shape(Rp)[-1] - batch_size = X_shape[1] - - if Hp_0 is None: - Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + def _default_activations(cls, num_directions): + return [_op.sigmoid, _op.tanh] * num_directions - if "activations" in attr: - activations = attr["activations"] - if len(activations) != 2 * num_directions: - raise NotImplementedError( - "GRU assumes 2 * num_directions activation functions are provided" - ) - alpha_loc = 0 - alphas = attr.get("activation_alpha", []) - if isinstance(alphas, float): - alphas = [alphas] - beta_loc = 0 - betas = attr.get("activation_beta", []) - if isinstance(betas, float): - betas = [betas] - acts = [] - for i in range(2 * num_directions): - alpha = None - beta = None - activation = activations[i] - if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: - alpha = alphas[alpha_loc] - alpha_loc += 1 - if cls._activation_needs_beta(activation) and len(betas) > beta_loc: - beta = betas[beta_loc] - beta_loc += 1 - acts.append(cls._activation_helper(activation, alpha, beta)) - else: - acts = [_op.sigmoid, _op.tanh] * 2 - - # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved - X_steps = unbind(X, axis=0) - - H_ts = _op.split(Hp_0, num_directions) - Ws = _op.split(Wp, num_directions) - Rs = _op.split(Rp, num_directions) - - if Bp is not None: - Bs = _op.split(Bp, num_directions) + @classmethod + def _impl_common(cls, inputs, attr, layout): + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout) + acts = cls._get_activations(attr, 2, num_directions, "GRU") + linear_before_reset = attr.get("linear_before_reset", 0) weights_dicts = [] for i in range(num_directions): @@ -3040,7 +3116,7 @@ def _impl_v7(cls, inputs, attr, params): weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0) matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3) 
weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0) - if Bp is not None: + if Bs is not None: Bi, Bh = _op.split(Bs[i], 2, -1) matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3) weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0) @@ -3067,6 +3143,9 @@ def _impl_v7(cls, inputs, attr, params): output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) H = _op.expand_dims(H, axis=0) + if layout == 1: + output = _op.transpose(output, axes=(1, 0)) + H = _op.transpose(H, axes=(1, 0)) return _expr.TupleWrapper(_expr.Tuple((output, H)), 2) @@ -5287,6 +5366,7 @@ def _get_convert_map(opset): "Flatten": Flatten.get_converter(opset), "LRN": LRN.get_converter(opset), # Recurrent Layers + "RNN": RNN.get_converter(opset), "LSTM": LSTM.get_converter(opset), "GRU": GRU.get_converter(opset), # defs/vision diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d68b76751184..9d1817b7a310 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3748,12 +3748,15 @@ def verify_rnn( use_peep=False, linear_before_reset=False, directions=1, + layout=0, rtol=1e-5, atol=1e-5, target=None, dev=None, ): - if rnn_type == "LSTM": + if rnn_type == "RNN": + multiplier = 1 + elif rnn_type == "LSTM": multiplier = 4 elif rnn_type == "GRU": multiplier = 3 @@ -3786,7 +3789,10 @@ def register(np_arr, name, shape=None): proto_type = dtype_map[np_arr.dtype.name] input_tensors.append(helper.make_tensor_value_info(name, proto_type, shape)) - x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32") + if layout == 1: + x_np = np.random.uniform(size=(batch_size, seq_length, input_size)).astype("float32") + else: + x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32") w_np = np.random.uniform(size=(directions, multiplier * hidden_size, input_size)).astype( "float32" ) @@ -3808,15 +3814,25 @@ def register(np_arr, name, shape=None): sequence_np = np.repeat(seq_length, batch_size).astype("int32") register(sequence_np, "sequence_lens") - initial_h_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype( - "float32" - ) + if layout == 1: + initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype( + "float32" + ) + else: + initial_h_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype( + "float32" + ) register(initial_h_np, "initial_h") if rnn_type == "LSTM": - initial_c_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype( - "float32" - ) + if layout == 1: + initial_c_np = np.random.uniform( + size=(batch_size, directions, hidden_size) + ).astype("float32") + else: + initial_c_np = np.random.uniform( + size=(directions, batch_size, hidden_size) + ).astype("float32") register(initial_c_np, "initial_c") if use_peep and rnn_type == "LSTM": @@ -3838,11 +3854,18 @@ def register(name, shape, proto_type): graph_outputs.append(helper.make_tensor_value_info(name, proto_type, list(shape))) output_shapes.append(list(shape)) - register("Y", [seq_length, directions, batch_size, hidden_size], TensorProto.FLOAT) - register("Y_h", [directions, batch_size, hidden_size], TensorProto.FLOAT) + if layout == 1: + register("Y", [directions, seq_length, batch_size, hidden_size], TensorProto.FLOAT) + register("Y_h", [batch_size, directions, hidden_size], TensorProto.FLOAT) + else: + register("Y", [seq_length, directions, batch_size, hidden_size], TensorProto.FLOAT) + 
register("Y_h", [directions, batch_size, hidden_size], TensorProto.FLOAT) if rnn_type == "LSTM": - register("Y_c", [directions, batch_size, hidden_size], TensorProto.FLOAT) + if layout == 1: + register("Y_c", [batch_size, directions, hidden_size], TensorProto.FLOAT) + else: + register("Y_c", [directions, batch_size, hidden_size], TensorProto.FLOAT) return output_names, graph_outputs, output_shapes @@ -3866,6 +3889,9 @@ def register(name, shape, proto_type): if linear_before_reset and rnn_type == "GRU": lbr_attr = helper.make_attribute("linear_before_reset", 1) rnn_node.attribute.append(lbr_attr) + if layout == 1: + layout_attr = helper.make_attribute("layout", 1) + rnn_node.attribute.append(layout_attr) graph = helper.make_graph([rnn_node], "rnn_test", inputs=input_tensors, outputs=graph_outputs) @@ -3876,8 +3902,13 @@ def register(name, shape, proto_type): ) -@tvm.testing.parametrize_targets -def test_lstm(target, dev): +def verify_rnn_helper(target, dev, rnn_type): + num_activations = 1 + if rnn_type == "GRU": + num_activations = 2 + elif rnn_type == "LSTM": + num_activations = 3 + for directions in [1, 2]: # No bias. verify_rnn( @@ -3886,7 +3917,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=32, use_bias=False, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -3898,7 +3929,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=32, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -3910,7 +3941,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=40, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -3922,7 +3953,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=32, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -3934,7 +3965,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=128, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -3946,7 +3977,7 @@ def test_lstm(target, dev): input_size=64, hidden_size=32, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -3954,50 +3985,59 @@ def test_lstm(target, dev): # Different activation testing. # Default value hardsigmoid. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Tanh", "Tanh"] * directions, - rnn_type="LSTM", - directions=directions, - target=target, - dev=dev, - ) + # TODO: onnxruntime <= v1.12.0 has wrong default value of all activation functions + if rnn_type != "RNN": + activations = ["HardSigmoid", "Tanh", "Tanh"][0:num_activations] * directions + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=activations, + rnn_type=rnn_type, + directions=directions, + target=target, + dev=dev, + ) # Multiple parametrized activations. 
+ activations = ["HardSigmoid", "LeakyRelu", "Tanh"][0:num_activations] * directions + alphas = [2.0, 0.5, 0.0][0:num_activations] * directions + betas = [0.3, 0.0, 0.0][0:num_activations] * directions verify_rnn( seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, - activations=["HardSigmoid", "LeakyRelu", "Tanh"] * directions, - alphas=[2.0, 0.5, 0.0] * directions, - betas=[0.3, 0.0, 0.0] * directions, - rnn_type="LSTM", + activations=activations, + alphas=alphas, + betas=betas, + rnn_type=rnn_type, directions=directions, target=target, dev=dev, ) # All parametrized with new Affine activation. + activations = ["Affine", "LeakyRelu", "HardSigmoid"][0:num_activations] * directions + alphas = [0.8, 2.0, 0.5][0:num_activations] * directions + betas = [0.0, 0.3, 0.0][0:num_activations] * directions verify_rnn( seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, - activations=["HardSigmoid", "LeakyRelu", "Affine"] * directions, - alphas=[2.0, 0.5, 0.8] * directions, - betas=[0.3, 0.1, 0.0] * directions, - rnn_type="LSTM", + activations=activations, + alphas=alphas, + betas=betas, + rnn_type=rnn_type, directions=directions, target=target, dev=dev, ) - # Testing with initial state and peepholes + # Testing with initial state verify_rnn( seq_length=2, batch_size=1, @@ -4005,182 +4045,57 @@ def test_lstm(target, dev): hidden_size=32, use_bias=True, use_initial_state=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, ) - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - use_initial_state=True, - use_peep=True, - rnn_type="LSTM", - directions=directions, - target=target, - dev=dev, - ) + # Testing layout + # TODO: onnxruntime <= 1.12.0 doesn't support layout == 1 + # verify_rnn( + # seq_length=2, + # batch_size=1, + # input_size=16, + # hidden_size=32, + # use_bias=True, + # rnn_type="RNN", + # directions=directions, + # layout=1, + # target=target, + # dev=dev, + # ) + + # Testing with peepholes + if rnn_type == "LSTM": + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + use_initial_state=True, + use_peep=True, + rnn_type="LSTM", + directions=directions, + target=target, + dev=dev, + ) @tvm.testing.parametrize_targets -def test_gru(target, dev): - # Set seed for test reproduction - np.random.seed(137) - for directions in [1, 2]: - # No bias. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # large batch. linear before reset - verify_rnn( - seq_length=4, - batch_size=8, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type="GRU", - linear_before_reset=True, - directions=directions, - target=target, - dev=dev, - ) - # Non power of two. - verify_rnn( - seq_length=3, - batch_size=3, - input_size=16, - hidden_size=40, - use_bias=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # Long sequence. - verify_rnn( - seq_length=8, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # Large hidden. 
- verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=128, - use_bias=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # Large input. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=64, - hidden_size=32, - use_bias=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) +def test_rnn(target, dev): + verify_rnn_helper(target, dev, "RNN") - # Different activation testing. - # Default value hardsigmoid. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Softsign"] * directions, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # Multiple parametrized activations. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "LeakyRelu"] * directions, - alphas=[2.0, 0.5] * directions, - betas=[0.3, 0.0] * directions, - rnn_type="GRU", - directions=directions, - rtol=1e-8, - atol=1e-8, - target=target, - dev=dev, - ) - # All parametrized with new Affine activation. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Affine"] * directions, - alphas=[2.0, 0.8] * directions, - betas=[0.3, 0.1] * directions, - rnn_type="GRU", - directions=directions, - rtol=1e-8, - atol=1e-8, - target=target, - dev=dev, - ) - # Testing with initial state - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - use_initial_state=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) +@tvm.testing.parametrize_targets +def test_lstm(target, dev): + verify_rnn_helper(target, dev, "LSTM") + + +@tvm.testing.parametrize_targets +def test_gru(target, dev): + verify_rnn_helper(target, dev, "GRU") @tvm.testing.parametrize_targets @@ -5212,7 +5127,6 @@ def verify_eyelike(indata, dynamic=False): "test_reduce_sum_keepdims_random", "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", - "test_rnn_seq_length", "test_sequence_insert_at_back", "test_sequence_insert_at_front", "test_simple_rnn_batchwise",
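The new RNN converter introduced above lowers the ONNX "RNN" operator onto relay's rnn_cell. As a reference for what that single-gate recurrence computes, the NumPy sketch below follows the ONNX spec for one direction with layout == 0: H_t = f(X_t·W^T + Wb + H_{t-1}·R^T + Rb), with f defaulting to tanh, and it mirrors the Y / Y_h output shapes the converter stacks together. The function and argument names are illustrative only, not the converter's own variables.

import numpy as np

def simple_rnn_reference(x, w, r, b_i, b_h, h_0, act=np.tanh):
    """Single-direction ONNX RNN reference (layout == 0).

    x:   (seq_length, batch_size, input_size)
    w:   (hidden_size, input_size)   -- input weights W, direction axis squeezed out
    r:   (hidden_size, hidden_size)  -- recurrence weights R
    b_i, b_h: (hidden_size,)         -- the two halves of B (Wb and Rb)
    h_0: (batch_size, hidden_size)   -- initial hidden state
    """
    h_t = h_0
    outputs = []
    for x_t in x:  # walk the sequence axis, like unbind(X, axis=0) in the converter
        h_t = act(x_t @ w.T + b_i + h_t @ r.T + b_h)
        outputs.append(h_t)
    # Y:   (seq_length, num_directions=1, batch_size, hidden_size)
    # Y_h: (num_directions=1, batch_size, hidden_size)
    y = np.stack(outputs)[:, np.newaxis]
    y_h = h_t[np.newaxis]
    return y, y_h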
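For completeness, here is a hedged end-to-end sketch of exercising the newly registered "RNN" converter outside the test harness: build a minimal single-direction ONNX RNN model with onnx.helper, pin opset 14 so the _impl_v14 path is taken, and import it through relay.frontend.from_onnx. The graph name, shapes, and graph-executor settings are illustrative choices, not requirements of the patch.

import numpy as np
from onnx import TensorProto, helper, numpy_helper

import tvm
from tvm import relay

seq_length, batch_size, input_size, hidden_size = 2, 1, 16, 32

# Per the ONNX spec: W is (num_directions, hidden_size, input_size),
# R is (num_directions, hidden_size, hidden_size).
w_np = np.random.uniform(size=(1, hidden_size, input_size)).astype("float32")
r_np = np.random.uniform(size=(1, hidden_size, hidden_size)).astype("float32")

rnn_node = helper.make_node(
    "RNN", inputs=["X", "W", "R"], outputs=["Y", "Y_h"], hidden_size=hidden_size
)
graph = helper.make_graph(
    [rnn_node],
    "rnn_example",
    inputs=[
        helper.make_tensor_value_info(
            "X", TensorProto.FLOAT, [seq_length, batch_size, input_size]
        )
    ],
    outputs=[
        helper.make_tensor_value_info(
            "Y", TensorProto.FLOAT, [seq_length, 1, batch_size, hidden_size]
        ),
        helper.make_tensor_value_info(
            "Y_h", TensorProto.FLOAT, [1, batch_size, hidden_size]
        ),
    ],
    initializer=[numpy_helper.from_array(w_np, "W"), numpy_helper.from_array(r_np, "R")],
)
model = helper.make_model(
    graph, producer_name="rnn_example", opset_imports=[helper.make_opsetid("", 14)]
)

# "RNN" is now resolved through RNN.get_converter(opset) in _get_convert_map.
mod, params = relay.frontend.from_onnx(
    model, shape={"X": (seq_length, batch_size, input_size)}
)

x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32")
result = relay.create_executor(
    "graph", mod=mod, device=tvm.cpu(), target="llvm"
).evaluate()(x_np, **params)
y, y_h = result[0].numpy(), result[1].numpy()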