
fix(mhc): MHCPreNormFn.backward returns wrong number of gradients #10

Open
yurekami wants to merge 1 commit into deepseek-ai:main from yurekami:fix/mhc-pre-norm-fn-backward-arity

Conversation

yurekami commented on Apr 25, 2026

Summary

`MHCPreNormFn.forward` (`tile_kernels/modeling/mhc/ops/norm_fn.py`) takes 5 user inputs (`x`, `fn`, `norm_eps`, `fuse_grad_acc`, `n_splits`), but `backward` returns 6 elements. PyTorch autograd requires `backward` to return exactly one gradient per `forward` input, so this raises at runtime:

```
RuntimeError: function MHCPreNormFnBackward returned an incorrect number of gradients (expected 5, got 6)
```
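
For context, here is a minimal sketch of the contract PyTorch enforces. The body below is illustrative only (`PreNormLike`, `fn_weight`, and the stand-in multiply are assumptions, not the real fused kernel): `backward` must return exactly one value per `forward` input, with `None` in the slots of non-differentiable inputs such as flags and hyperparameters.

```python
import torch

# Illustrative sketch only -- assumed stand-in op, not the real fused kernel.
# torch.autograd.Function requires backward to return exactly one value per
# forward input, with None in the slots of non-differentiable inputs.
class PreNormLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, fn_weight, norm_eps, fuse_grad_acc, n_splits):  # 5 inputs
        ctx.save_for_backward(x, fn_weight)
        ctx.fuse_grad_acc = fuse_grad_acc
        return x * fn_weight  # placeholder for the fused pre-norm computation

    @staticmethod
    def backward(ctx, grad_out):
        x, fn_weight = ctx.saved_tensors
        fn_grad = grad_out * x
        if ctx.fuse_grad_acc:
            # In this mode x's gradient is accumulated elsewhere, so its slot
            # is None -- but the tuple must still have exactly 5 elements.
            return None, fn_grad, None, None, None
        return grad_out * fn_weight, fn_grad, None, None, None
```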

Fix

Drop the trailing `None` from both return statements so the arity matches the 5 forward inputs, and update the type annotation to reflect the real return shape (`x_grad` is `None` when `fuse_grad_acc=True`).

```diff
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
+    ) -> tuple[torch.Tensor | None, torch.Tensor, None, None, None]:
...
-            return None, fn_grad, None, None, None, None
-        return x_grad, fn_grad, None, None, None, None
+            return None, fn_grad, None, None, None
+        return x_grad, fn_grad, None, None, None
```
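
The failure mode can be reproduced in isolation with a hypothetical standalone snippet (`SixGrads` below is invented for illustration, independent of the MHC code): a `backward` that returns 6 elements for 5 `forward` inputs raises the same error at `.backward()` time.

```python
import torch

class SixGrads(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, eps, flag, n):  # 5 inputs
        return x * w

    @staticmethod
    def backward(ctx, g):
        # Buggy on purpose: 6 elements for 5 inputs, mirroring the pre-fix code.
        return g, g, None, None, None, None

x = torch.ones(3, requires_grad=True)
w = torch.ones(3, requires_grad=True)
try:
    SixGrads.apply(x, w, 1e-6, False, 1).sum().backward()
except RuntimeError as e:
    print(e)  # ... returned an incorrect number of gradients (expected 5, got 6)
```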

Commit message:

`MHCPreNormFn.forward` takes 5 user inputs (x, fn, norm_eps,
fuse_grad_acc, n_splits) but `backward` returns 6 elements,
which causes torch.autograd to raise:

    RuntimeError: function MHCPreNormFnBackward returned an
    incorrect number of gradients (expected 5, got 6)

Drop the trailing `None` so each return matches the 5 inputs,
and update the type annotation to reflect the actual return
shape (x_grad may be None when fuse_grad_acc is set).

This path is exercised by `tests/mhc/test_norm_fn.py::test_correctness`.