[Relax][PyTorch] Support gru op for ExportedProgram importer #18360
@@ -391,6 +391,300 @@ def _lstm(self, node: fx.Node) -> relax.Var:

```python
        output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
        return output

    def _gru(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        input_tensor = args[0]
        hx = args[1] if len(args) > 1 else None
        params = args[2] if len(args) > 2 else None
        has_biases = args[3] if len(args) > 3 else True
        num_layers = args[4] if len(args) > 4 else 1
        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
        _train = args[6] if len(args) > 6 else False  # Not used in inference
        bidirectional = args[7] if len(args) > 7 else False
        batch_first = args[8] if len(args) > 8 else False

        if bidirectional:
            raise NotImplementedError("Bidirectional GRU is not yet supported")

        input_shape = self.shape_of(input_tensor)
        if batch_first:
            batch_size, seq_len, input_size = input_shape
        else:
            seq_len, batch_size, input_size = input_shape

        if isinstance(seq_len, tvm.tir.IntImm):
            seq_len = seq_len.value
        if isinstance(batch_size, tvm.tir.IntImm):
            batch_size = batch_size.value
        if isinstance(input_size, tvm.tir.IntImm):
            input_size = input_size.value

        if params and len(params) >= 2:
            # For multi-layer, we need to extract the first layer's weights
            # to determine hidden size
            if num_layers > 1:
                # Multi-layer: params[0] is first layer's weight_ih
                weight_ih = params[0]
            else:
                # Single layer: params[0] is weight_ih
                weight_ih = params[0]
            # Extract hidden size from weight dimensions
            # weight_ih has shape (3 * hidden_size, input_size)
            weight_ih_shape = self.shape_of(weight_ih)
            hidden_size = weight_ih_shape[0] // 3  # 3 gates: reset, update, new
        else:
            # Fallback to a default hidden size
            hidden_size = 16

        # Implement actual GRU computation using Relax operations
        # GRU equations:
        # r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
        # z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
        # n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
        # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
        dtype = input_tensor.struct_info.dtype

        # Reshape input for processing
        if batch_first:
            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
            input_reshaped = self.block_builder.emit(
                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
            )
        else:
            input_reshaped = input_tensor

        # Initialize hidden states for all layers
        if hx is not None:
            # hx shape: (num_layers, batch_size, hidden_size)
            h_states = []
            for layer in range(num_layers):
                h_layer = self.block_builder.emit(
                    relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip")
                )
                h_states.append(h_layer)
        else:
            h_states = []
            for layer in range(num_layers):
                h_layer = self.block_builder.emit(
                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
                )
                h_states.append(h_layer)

        outputs = []

        for t in range(seq_len):
            # Get input at time t: (batch_size, input_size)
            x_t = self.block_builder.emit(
                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
            )

            # Process through each layer
            current_input = x_t
            new_h_states = []

            for layer in range(num_layers):
                # Get layer parameters
                if params and len(params) >= 4 * num_layers:
                    # Multi-layer case: params are organized as
                    # [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, layer1_ih, ...]
                    param_offset = layer * 4
                    weight_ih = params[param_offset]
                    weight_hh = params[param_offset + 1]
                    bias_ih = params[param_offset + 2] if has_biases else None
                    bias_hh = params[param_offset + 3] if has_biases else None
                elif params and len(params) >= 4:
                    # Single layer case
                    weight_ih = params[0]
                    weight_hh = params[1]
                    bias_ih = params[2] if has_biases else None
                    bias_hh = params[3] if has_biases else None
                else:
                    # Fallback: create zero weights
                    weight_ih = self.block_builder.emit(
                        relax.op.zeros(
                            relax.ShapeExpr(
                                (3 * hidden_size, input_size if layer == 0 else hidden_size)
                            ),
                            dtype,
                        )
                    )
                    weight_hh = self.block_builder.emit(
                        relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
                    )
                    bias_ih = None
                    bias_hh = None

                # Get previous hidden state for this layer
                h_prev = h_states[layer]

                # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
                gate_size = hidden_size

                # Reset gate weights
                weight_ih_r = self.block_builder.emit(
                    relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
                )
                weight_hh_r = self.block_builder.emit(
                    relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
                )

                # Update gate weights
                weight_ih_z = self.block_builder.emit(
                    relax.op.strided_slice(
                        weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
                    )
                )
                weight_hh_z = self.block_builder.emit(
                    relax.op.strided_slice(
                        weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
                    )
                )

                # New gate weights
                weight_ih_n = self.block_builder.emit(
                    relax.op.strided_slice(
                        weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                    )
                )
                weight_hh_n = self.block_builder.emit(
                    relax.op.strided_slice(
                        weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                    )
                )

                # Transpose weights for matmul
                weight_ih_r_t = self.block_builder.emit(
                    relax.op.permute_dims(weight_ih_r, axes=[1, 0])
                )
                weight_hh_r_t = self.block_builder.emit(
                    relax.op.permute_dims(weight_hh_r, axes=[1, 0])
                )
                weight_ih_z_t = self.block_builder.emit(
                    relax.op.permute_dims(weight_ih_z, axes=[1, 0])
                )
                weight_hh_z_t = self.block_builder.emit(
                    relax.op.permute_dims(weight_hh_z, axes=[1, 0])
                )
                weight_ih_n_t = self.block_builder.emit(
                    relax.op.permute_dims(weight_ih_n, axes=[1, 0])
                )
                weight_hh_n_t = self.block_builder.emit(
                    relax.op.permute_dims(weight_hh_n, axes=[1, 0])
                )
```
Comment on lines +523 to +573 (Contributor): The weight slicing and transposition operations are performed inside the time-step loop even though they do not depend on the time step; they could be computed once per layer before the loop so the same slice and transpose ops are not re-emitted at every step.
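A minimal NumPy sketch of the point above (illustrative only, not part of the diff): slicing and transposing a gate weight once before the time loop produces the same pre-activations as re-slicing it at every step, so the corresponding Relax ops can be hoisted.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, batch, input_size, hidden = 3, 2, 4, 8
x = rng.standard_normal((seq_len, batch, input_size)).astype("float32")
w_ih = rng.standard_normal((3 * hidden, input_size)).astype("float32")

# What the diff currently does: slice and transpose the reset-gate weight at every step
per_step = [x[t] @ w_ih[0:hidden].T for t in range(seq_len)]

# Hoisted: slice and transpose once, reuse across all time steps
w_ir_t = w_ih[0:hidden].T
hoisted = [x[t] @ w_ir_t for t in range(seq_len)]

for a, b in zip(per_step, hoisted):
    np.testing.assert_allclose(a, b)
```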
```python
                # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
                r_ih = self.block_builder.emit(
                    relax.op.linear_algebra.matmul(current_input, weight_ih_r_t)
                )
                r_hh = self.block_builder.emit(
                    relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)
                )
                if bias_ih is not None and bias_hh is not None:
                    bias_ih_r = self.block_builder.emit(
                        relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
                    )
                    bias_hh_r = self.block_builder.emit(
                        relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
                    )
                    r_t = self.block_builder.emit(
                        relax.op.sigmoid(
                            relax.op.add(
                                relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r
                            )
                        )
                    )
                else:
                    r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))

                # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
                z_ih = self.block_builder.emit(
                    relax.op.linear_algebra.matmul(current_input, weight_ih_z_t)
                )
                z_hh = self.block_builder.emit(
                    relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)
                )
                if bias_ih is not None and bias_hh is not None:
                    bias_ih_z = self.block_builder.emit(
                        relax.op.strided_slice(
                            bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
                        )
                    )
                    bias_hh_z = self.block_builder.emit(
                        relax.op.strided_slice(
                            bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
                        )
                    )
                    z_t = self.block_builder.emit(
                        relax.op.sigmoid(
                            relax.op.add(
                                relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z
                            )
                        )
                    )
                else:
                    z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))

                # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
                n_ih = self.block_builder.emit(
                    relax.op.linear_algebra.matmul(current_input, weight_ih_n_t)
                )
                n_hh = self.block_builder.emit(
                    relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)
                )
                if bias_ih is not None and bias_hh is not None:
                    bias_ih_n = self.block_builder.emit(
                        relax.op.strided_slice(
                            bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                        )
                    )
                    bias_hh_n = self.block_builder.emit(
                        relax.op.strided_slice(
                            bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                        )
                    )
                    n_t = self.block_builder.emit(
                        relax.op.tanh(
                            relax.op.add(
                                relax.op.add(n_ih, bias_ih_n),
                                relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
                            )
                        )
                    )
                else:
                    n_t = self.block_builder.emit(
                        relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
                    )

                # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
                one_minus_z = self.block_builder.emit(
                    relax.op.subtract(relax.const(1.0, dtype), z_t)
                )
                h_t = self.block_builder.emit(
                    relax.op.add(
                        relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)
                    )
                )

                new_h_states.append(h_t)

                current_input = h_t

            # Update hidden states for next time step
            h_states = new_h_states

            # Store output (from the last layer)
            outputs.append(h_states[-1])

        # Stack outputs: (seq_len, batch_size, hidden_size)
        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))

        # Reshape back to batch_first if needed
        if batch_first:
            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
            output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))

        return output

    ########## Manipulation ##########

    def _narrow(self, node: fx.Node) -> relax.Var:
```

@@ -652,6 +946,7 @@ def create_convert_map(

```python
            "layer_norm.default": self._layer_norm,
            "linear.default": self._linear,
            "lstm.input": self._lstm,
            "gru.input": self._gru,
            "max_pool1d.default": self._max_pool1d,
            "max_pool2d.default": self._max_pool2d,
            "max_pool3d.default": self._max_pool3d,
```
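For readers cross-checking the math: below is a minimal NumPy sketch of the single-layer, unidirectional GRU that the importer emits, using PyTorch's flat-parameter layout and gate order (reset, update, new). Following the review comment above, the per-gate splits and transposes are done once, outside the time loop. The function name and signature are illustrative only and not part of this PR. As a concrete shape check, `nn.GRU(input_size=4, hidden_size=8)` has `weight_ih_l0` of shape `(24, 4)`, which is why the importer recovers `hidden_size` as `24 // 3 = 8`.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_reference(x, w_ih, w_hh, b_ih, b_hh, h0=None):
    """x: (seq_len, batch, input); w_ih: (3*hidden, input); w_hh: (3*hidden, hidden)."""
    seq_len, batch, _ = x.shape
    hidden = w_hh.shape[1]
    h = np.zeros((batch, hidden), dtype=x.dtype) if h0 is None else h0
    # Split once, outside the time loop (PyTorch gate order: reset, update, new).
    w_ir, w_iz, w_in = np.split(w_ih, 3, axis=0)
    w_hr, w_hz, w_hn = np.split(w_hh, 3, axis=0)
    b_ir, b_iz, b_in = np.split(b_ih, 3)
    b_hr, b_hz, b_hn = np.split(b_hh, 3)
    outputs = []
    for t in range(seq_len):
        x_t = x[t]
        r = sigmoid(x_t @ w_ir.T + b_ir + h @ w_hr.T + b_hr)
        z = sigmoid(x_t @ w_iz.T + b_iz + h @ w_hz.T + b_hz)
        n = np.tanh(x_t @ w_in.T + b_in + r * (h @ w_hn.T + b_hn))
        h = (1.0 - z) * n + z * h
        outputs.append(h)
    return np.stack(outputs, axis=0)  # (seq_len, batch, hidden)
```

With the weights and biases taken from a torch module's `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, and `bias_hh_l0` (converted to NumPy), this should reproduce `nn.GRU`'s output for a single unidirectional layer with a zero initial state.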
@@ -6050,5 +6050,76 @@ def main(

```python
    verify_model(TensorNoneModel(), example_args, {}, Expected)


def test_gru():
    class BasicGRU(nn.Module):
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=4,
                hidden_size=8,
                num_layers=1,
                batch_first=True,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y

    torch.manual_seed(42)
    x = torch.randn(2, 3, 4, dtype=torch.float32)
    model = BasicGRU()
    with torch.no_grad():
        pytorch_output = model(x)
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    target = tvm.target.Target("llvm")
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    x_tvm = tvm.runtime.tensor(x.numpy())
    tvm_output = vm["main"](x_tvm)
    if hasattr(tvm_output, "numpy"):
        tvm_output_np = tvm_output.numpy()
    else:
        tvm_output_np = tvm_output[0].numpy()
    assert (
        pytorch_output.shape == tvm_output_np.shape
    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)

    class SeqFirstGRU(nn.Module):
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=3,
                hidden_size=6,
                num_layers=1,
                batch_first=False,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y

    torch.manual_seed(43)
    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
    model2 = SeqFirstGRU()
    with torch.no_grad():
        pytorch_output2 = model2(x2)
    exported_program2 = export(model2, args=(x2,))
    mod2 = from_exported_program(exported_program2)
    ex2 = relax.build(mod2, target)
    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
    x2_tvm = tvm.runtime.tensor(x2.numpy())
    tvm_output2 = vm2["main"](x2_tvm)
    if hasattr(tvm_output2, "numpy"):
        tvm_output2_np = tvm_output2.numpy()
    else:
        tvm_output2_np = tvm_output2[0].numpy()
    assert pytorch_output2.shape == tvm_output2_np.shape
    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
```
Comment on lines +6053 to +6121 (Contributor): The test function …

Comment on lines +6053 to +6121 (Contributor): The …
```python
if __name__ == "__main__":
    tvm.testing.main()
```
Comment (Contributor), on the hidden-size extraction in `_gru`: This if/else block is redundant, as both branches execute the same code (`weight_ih = params[0]`). It can be simplified to improve code clarity and maintainability.
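A minimal sketch of that simplification (a hypothetical helper, not code from this PR): the branch on `num_layers` can be dropped because `params[0]` is always the first layer's `weight_ih`.

```python
def infer_gru_hidden_size(params, shape_of, default=16):
    """Recover hidden_size from PyTorch GRU flat params; params[0] is weight_ih_l0."""
    if params and len(params) >= 2:
        # weight_ih has shape (3 * hidden_size, input_size); 3 gates: reset, update, new
        return shape_of(params[0])[0] // 3
    return default
```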