diff --git a/tile_kernels/modeling/mhc/ops/norm_fn.py b/tile_kernels/modeling/mhc/ops/norm_fn.py index b0ec6cd..272d957 100644 --- a/tile_kernels/modeling/mhc/ops/norm_fn.py +++ b/tile_kernels/modeling/mhc/ops/norm_fn.py @@ -129,7 +129,7 @@ def forward( def backward( ctx: 'MHCPreNormFn', out_grad: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]: + ) -> tuple[torch.Tensor | None, torch.Tensor, None, None, None]: x, fn, out_mul, sqrsum = ctx.saved_tensors norm_eps = ctx.norm_eps @@ -166,8 +166,8 @@ def backward( if ctx.fuse_grad_acc: del x.untyped_storage().grad_from_mhc_post - return None, fn_grad, None, None, None, None - return x_grad, fn_grad, None, None, None, None + return None, fn_grad, None, None, None + return x_grad, fn_grad, None, None, None def mhc_pre_norm_fn(