diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c9c55eb8d61a..a84c35e62234 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -391,6 +391,300 @@ def _lstm(self, node: fx.Node) -> relax.Var:
         output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
         return output
 
+    def _gru(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        _train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+
+        if bidirectional:
+            raise NotImplementedError("Bidirectional GRU is not yet supported")
+
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            batch_size, seq_len, input_size = input_shape
+        else:
+            seq_len, batch_size, input_size = input_shape
+
+        if isinstance(seq_len, tvm.tir.IntImm):
+            seq_len = seq_len.value
+        if isinstance(batch_size, tvm.tir.IntImm):
+            batch_size = batch_size.value
+        if isinstance(input_size, tvm.tir.IntImm):
+            input_size = input_size.value
+
+        if params and len(params) >= 2:
+            # params[0] is the first layer's weight_ih, which has shape
+            # (3 * hidden_size, input_size); use it to determine the hidden size.
+            weight_ih = params[0]
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 3  # 3 gates: reset, update, new
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+
+        # Implement actual GRU computation using Relax operations
+        # GRU equations:
+        # r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
+        # z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
+        # n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
+        # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+        dtype = input_tensor.struct_info.dtype
+
+        # Reshape input for processing
+        if batch_first:
+            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+            input_reshaped = self.block_builder.emit(
+                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+            )
+        else:
+            input_reshaped = input_tensor
+
+        # Initialize hidden states for all layers
+        if hx is not None:
+            # hx shape: (num_layers, batch_size, hidden_size)
+            h_states = []
+            for layer in range(num_layers):
+                h_layer = self.block_builder.emit(
+                    relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip")
+                )
+                h_states.append(h_layer)
+        else:
+            h_states = []
+            for layer in range(num_layers):
+                h_layer = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+                )
+                h_states.append(h_layer)
+
+        outputs = []
+
+        for t in range(seq_len):
+            # Get input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+
+            # Process through each layer
+            current_input = x_t
+            new_h_states = []
+
+            for layer in range(num_layers):
+                # Get layer parameters
+                if params and len(params) >= 4 * num_layers:
+                    # Multi-layer case: params are organized as
+                    # [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, layer1_ih, ...]
+                    param_offset = layer * 4
+                    weight_ih = params[param_offset]
+                    weight_hh = params[param_offset + 1]
+                    bias_ih = params[param_offset + 2] if has_biases else None
+                    bias_hh = params[param_offset + 3] if has_biases else None
+                elif params and len(params) >= 4:
+                    # Single layer case
+                    weight_ih = params[0]
+                    weight_hh = params[1]
+                    bias_ih = params[2] if has_biases else None
+                    bias_hh = params[3] if has_biases else None
+                else:
+                    # Fallback: create zero weights
+                    weight_ih = self.block_builder.emit(
+                        relax.op.zeros(
+                            relax.ShapeExpr(
+                                (3 * hidden_size, input_size if layer == 0 else hidden_size)
+                            ),
+                            dtype,
+                        )
+                    )
+                    weight_hh = self.block_builder.emit(
+                        relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
+                    )
+                    bias_ih = None
+                    bias_hh = None
+
+                # Get previous hidden state for this layer
+                h_prev = h_states[layer]
+
+                # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
+                gate_size = hidden_size
+
+                # Reset gate weights
+                weight_ih_r = self.block_builder.emit(
+                    relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
+                )
+                weight_hh_r = self.block_builder.emit(
+                    relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
+                )
+
+                # Update gate weights
+                weight_ih_z = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                    )
+                )
+                weight_hh_z = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                    )
+                )
+
+                # New gate weights
+                weight_ih_n = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                    )
+                )
+                weight_hh_n = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                    )
+                )
+
+                # Transpose weights for matmul
+                weight_ih_r_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_r, axes=[1, 0])
+                )
+                weight_hh_r_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_r, axes=[1, 0])
+                )
+                weight_ih_z_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_z, axes=[1, 0])
+                )
+                weight_hh_z_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_z, axes=[1, 0])
+                )
+                weight_ih_n_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_n, axes=[1, 0])
+                )
+                weight_hh_n_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_n, axes=[1, 0])
+                )
+
+                # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
+                r_ih = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(current_input, weight_ih_r_t)
+                )
+                r_hh = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)
+                )
+                if bias_ih is not None and bias_hh is not None:
+                    bias_ih_r = self.block_builder.emit(
+                        relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
+                    )
+                    bias_hh_r = self.block_builder.emit(
+                        relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
+                    )
+                    r_t = self.block_builder.emit(
+                        relax.op.sigmoid(
+                            relax.op.add(
+                                relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r
+                            )
+                        )
+                    )
+                else:
+                    r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))
+
+                # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
+                z_ih = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(current_input, weight_ih_z_t)
+                )
+                z_hh = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)
+                )
+                if bias_ih is not None and bias_hh is not None:
+                    bias_ih_z = self.block_builder.emit(
+                        relax.op.strided_slice(
+                            bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                        )
+                    )
+                    bias_hh_z = self.block_builder.emit(
+                        relax.op.strided_slice(
+                            bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
+                        )
+                    )
+                    z_t = self.block_builder.emit(
+                        relax.op.sigmoid(
+                            relax.op.add(
+                                relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z
+                            )
+                        )
+                    )
+                else:
+                    z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))
+
+                # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
+                n_ih = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(current_input, weight_ih_n_t)
+                )
+                n_hh = self.block_builder.emit(
+                    relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)
+                )
+                if bias_ih is not None and bias_hh is not None:
+                    bias_ih_n = self.block_builder.emit(
+                        relax.op.strided_slice(
+                            bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                        )
+                    )
+                    bias_hh_n = self.block_builder.emit(
+                        relax.op.strided_slice(
+                            bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
+                        )
+                    )
+                    n_t = self.block_builder.emit(
+                        relax.op.tanh(
+                            relax.op.add(
+                                relax.op.add(n_ih, bias_ih_n),
+                                relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
+                            )
+                        )
+                    )
+                else:
+                    n_t = self.block_builder.emit(
+                        relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
+                    )
+
+                # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+                one_minus_z = self.block_builder.emit(
+                    relax.op.subtract(relax.const(1.0, dtype), z_t)
+                )
+                h_t = self.block_builder.emit(
+                    relax.op.add(
+                        relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)
+                    )
+                )
+
+                new_h_states.append(h_t)
+
+                current_input = h_t
+
+            # Update hidden states for next time step
+            h_states = new_h_states
+
+            # Store output (from the last layer)
+            outputs.append(h_states[-1])
+
+        # Stack outputs: (seq_len, batch_size, hidden_size)
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+
+        # Reshape back to batch_first if needed
+        if batch_first:
+            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
+            output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
+
+        return output
+
     ########## Manipulation ##########
 
     def _narrow(self, node: fx.Node) -> relax.Var:
@@ -652,6 +946,7 @@ def create_convert_map(
             "layer_norm.default": self._layer_norm,
             "linear.default": self._linear,
             "lstm.input": self._lstm,
+            "gru.input": self._gru,
             "max_pool1d.default": self._max_pool1d,
             "max_pool2d.default": self._max_pool2d,
             "max_pool3d.default": self._max_pool3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b35af088b530..657ade455bd7 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6050,5 +6050,76 @@ def main(
         verify_model(TensorNoneModel(), example_args, {}, Expected)
 
 
+def test_gru():
+    class BasicGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=4,
+                hidden_size=8,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(42)
+    x = torch.randn(2, 3, 4, dtype=torch.float32)
+    model = BasicGRU()
+    with torch.no_grad():
+        pytorch_output = model(x)
+    exported_program = export(model, args=(x,))
+    mod = from_exported_program(exported_program)
+    target = tvm.target.Target("llvm")
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_tvm = tvm.runtime.tensor(x.numpy())
+    tvm_output = vm["main"](x_tvm)
+    if hasattr(tvm_output, "numpy"):
+        tvm_output_np = tvm_output.numpy()
+    else:
+        tvm_output_np = tvm_output[0].numpy()
+    assert (
+        pytorch_output.shape == tvm_output_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
+    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)
+
+    class SeqFirstGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=3,
+                hidden_size=6,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(43)
+    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+    model2 = SeqFirstGRU()
+    with torch.no_grad():
+        pytorch_output2 = model2(x2)
+    exported_program2 = export(model2, args=(x2,))
+    mod2 = from_exported_program(exported_program2)
+    ex2 = relax.build(mod2, target)
+    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
+    x2_tvm = tvm.runtime.tensor(x2.numpy())
+    tvm_output2 = vm2["main"](x2_tvm)
+    if hasattr(tvm_output2, "numpy"):
+        tvm_output2_np = tvm_output2.numpy()
+    else:
+        tvm_output2_np = tvm_output2[0].numpy()
+    assert pytorch_output2.shape == tvm_output2_np.shape
+    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()