diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py index 0049e597f8d..cf864153ba8 100644 --- a/exir/passes/constant_prop_pass.py +++ b/exir/passes/constant_prop_pass.py @@ -61,6 +61,8 @@ def is_const( ) elif isinstance(arg, _PRIMITIVE_TYPES): return True + elif arg is None: + return True elif not isinstance(arg, torch.fx.Node): return False elif arg in const_node_to_tensor: diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index d3c2d0a0936..6618c729987 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1823,3 +1823,34 @@ def _do_checks( self.assertTrue( torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0]) ) + + def test_constant_prop_pass_none(self) -> None: + """ + This checks that None arguments are treated as constants in constant_prop_pass. + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.cst = torch.ones(3, 3, 3, dtype=torch.int8) + self.w = torch.ones(3, 3, 3, dtype=torch.int8) + + def forward(self, x): + # Note: using e.g aten.linear would not work as None is not in the graph + a = torch.ops.aten.convolution.default( + self.cst, self.w, None, [1], [0], [1], False, [0], 1 + ) + return a + x + + mod = M() + x = torch.randn([3, 3, 3]) + mod(x) + edge = to_edge( + export(mod, (x,), strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + # 2 constants: self.w and self.cst + self.assertEqual(2, len(edge.exported_program().constants)) + pass_result = constant_prop_pass(edge.exported_program()) + # 1 constant: a (= self.w @ self.cst) + self.assertEqual(1, len(pass_result.constants))