diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py
index dc96c12523..7c638032f1 100644
--- a/transformer_engine/pytorch/ops/fuser.py
+++ b/transformer_engine/pytorch/ops/fuser.py
@@ -192,7 +192,7 @@ def forward(
         func_ctx.backward_ops = backward_ops
         func_ctx.basic_ops = basic_ops
         func_ctx.basic_op_ctxs = basic_op_ctxs
-        func_ctx.num_params = num_params
+        func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
         func_ctx.num_extra_inputs = num_extra_inputs
         func_ctx.num_extra_outputs = len(extra_outputs_flat)
         func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
@@ -258,14 +258,14 @@ def backward(
         # Flatten list of parameter gradients
         grad_params_flat = []
         for idx, dparams in enumerate(grad_params):
-            params = list(basic_ops[idx].parameters())
+            num_params = func_ctx.basic_op_num_params[idx]
             if dparams is None:
-                dparams = [None for _ in range(len(params))]
+                dparams = [None for _ in range(num_params)]
             else:
                 dparams = list(dparams)
-            if len(dparams) != len(params):
+            if len(dparams) != num_params:
                 raise RuntimeError(
-                    f"Expected op {idx} to generate {len(params)} param grads, "
+                    f"Expected op {idx} to generate {num_params} param grads, "
                     f"but got {len(dparams)}"
                 )
             grad_params_flat.extend(dparams)
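
For reference, a minimal standalone sketch of the pattern this diff applies: count each op's parameters once in forward, stash the counts on the autograd context, and let backward size and validate the per-op gradient lists against those counts instead of calling op.parameters() again. The names _Ctx and flatten_param_grads are hypothetical; in fuser.py this logic lives inside the fused autograd function's forward and backward.

# Standalone sketch with assumed names; not the actual fuser.py code.
import torch


class _Ctx:
    """Stand-in for the autograd function context (func_ctx)."""


def forward(func_ctx, basic_ops):
    # Cache per-op parameter counts while the ops are at hand.
    func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]


def flatten_param_grads(func_ctx, grad_params):
    # Flatten the per-op gradient lists, validating each against the cached count.
    grad_params_flat = []
    for idx, dparams in enumerate(grad_params):
        num_params = func_ctx.basic_op_num_params[idx]
        dparams = [None] * num_params if dparams is None else list(dparams)
        if len(dparams) != num_params:
            raise RuntimeError(
                f"Expected op {idx} to generate {num_params} param grads, "
                f"but got {len(dparams)}"
            )
        grad_params_flat.extend(dparams)
    return grad_params_flat


if __name__ == "__main__":
    ctx = _Ctx()
    ops = [torch.nn.Linear(4, 4), torch.nn.LayerNorm(4)]  # 2 parameters each
    forward(ctx, ops)
    assert len(flatten_param_grads(ctx, [None, None])) == 4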