From 86b2fd3fe9176110edaa44a74b5707893cb2e643 Mon Sep 17 00:00:00 2001 From: yurekami Date: Sat, 25 Apr 2026 17:29:21 +0800 Subject: [PATCH] fix(mhc): MHCPreNormFn.backward returns wrong number of gradients `MHCPreNormFn.forward` takes 5 user inputs (x, fn, norm_eps, fuse_grad_acc, n_splits) but `backward` returned 6 elements, which causes torch.autograd to raise: RuntimeError: function MHCPreNormFnBackward returned an incorrect number of gradients (expected 5, got 6) Drop the trailing `None` so each return matches the 5 inputs, and update the type annotation to reflect the actual return shape (x_grad may be None when fuse_grad_acc is set). This path is exercised by `tests/mhc/test_norm_fn.py::test_correctness`. --- tile_kernels/modeling/mhc/ops/norm_fn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tile_kernels/modeling/mhc/ops/norm_fn.py b/tile_kernels/modeling/mhc/ops/norm_fn.py index b0ec6cd..272d957 100644 --- a/tile_kernels/modeling/mhc/ops/norm_fn.py +++ b/tile_kernels/modeling/mhc/ops/norm_fn.py @@ -129,7 +129,7 @@ def forward( def backward( ctx: 'MHCPreNormFn', out_grad: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]: + ) -> tuple[torch.Tensor | None, torch.Tensor, None, None, None]: x, fn, out_mul, sqrsum = ctx.saved_tensors norm_eps = ctx.norm_eps @@ -166,8 +166,8 @@ def backward( if ctx.fuse_grad_acc: del x.untyped_storage().grad_from_mhc_post - return None, fn_grad, None, None, None, None - return x_grad, fn_grad, None, None, None, None + return None, fn_grad, None, None, None + return x_grad, fn_grad, None, None, None def mhc_pre_norm_fn(