diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index deb1be9e9e..3b91e378d2 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -874,7 +874,8 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                     e,
                 )
 
-    def new_constant(self, irvalue: ir.Value, value):
+    def new_constant(self, node: ir.Node, value):
+        irvalue = node.outputs[0]
         if not isinstance(value, np.ndarray):
             # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
             # So, a constant-value of type sequence is not folded, but it can be used
@@ -891,12 +892,22 @@ def new_constant(self, irvalue: ir.Value, value):
         irvalue.const_value = tensor
 
         if value.nbytes > self._output_size_limit:
-            logger.info(
-                "Skip storing constant folded nvalue %s due to large size %s.",
-                irvalue.name,
-                value.nbytes,
-            )
-            return None
+            # Handle examples like Transpose(weight) to be folded even if the size is large,
+            # as long as weight has no other uses. This won't increase model size.
+            removed_input_size = 0
+            for input_value in node.inputs:
+                if (input_value is not None) and (len(input_value.uses()) == 1):
+                    array = _get_numpy_value(input_value)
+                    if array is not None:
+                        removed_input_size += array.nbytes
+            increased_size = value.nbytes - removed_input_size
+            if increased_size > 0:
+                logger.info(
+                    "Skip storing constant folded nvalue %s due to large size %s.",
+                    irvalue.name,
+                    value.nbytes,
+                )
+                return None
 
         logger.debug(
             "New constant for value %s dtype: %s shape: %s",
@@ -979,7 +990,7 @@ def convert(av):
         if outputs is None:
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
-            replacement = self.new_constant(node.outputs[0], outputs)
+            replacement = self.new_constant(node, outputs)
             if is_onnx_op(node, "ConstantOfShape") or replacement is None:
                 return None
             return Replacement(replacement.outputs, [replacement])
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index b0df4dd546..d4124d3b21 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -1,7 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+from __future__ import annotations
+
 import unittest
 
+import numpy as np
 import onnx
 import parameterized
 import pytest
@@ -397,10 +400,12 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
 
 
 class FoldConstantsIrTest(unittest.TestCase):
-    def _fold(self, model_text: str, onnx_shape_inference=False) -> ir.Model:
-        model_proto = onnx.parser.parse_model(model_text)
-        model = serde.deserialize_model(model_proto)
-        _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
+    def _fold(self, model: str | onnx.ModelProto | ir.Model, **kwargs) -> ir.Model:
+        if isinstance(model, str):
+            model = onnx.parser.parse_model(model)
+        if isinstance(model, onnx.ModelProto):
+            model = serde.deserialize_model(model)
+        _constant_folding.fold_constants(model, **kwargs)
         optimizer.remove_unused_nodes(model)
         return model
 
@@ -557,6 +562,33 @@ def test_gather_symdim(self):
         optimized = self._fold(model)
         self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
 
+    def test_large_transpose(self):
+        """Transpose of a single-use initializer folds once input_size_limit admits it."""
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[M, 256] x) => (float[M, 512] z)
+            <float[1, 1] w = {1.0}> # placeholder for large initializer of shape [512, 256]
+            {
+                wt = Transpose (w)
+                z = MatMul (x, wt)
+            }
+        """
+        irmodel = serde.deserialize_model(onnx.parser.parse_model(model))
+        w = irmodel.graph.initializers["w"]
+        w.shape = ir.Shape([512, 256])
+        w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))
+
+        # Input size limit will prevent folding of Transpose op
+        optimized = self._fold(irmodel, input_size_limit=3 * 512 * 256)
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["Transpose", "MatMul"])
+
+        # Input size limit will allow folding of Transpose op
+        # Since there is no increase in model-size, output-size is not a concern.
+        optimized = self._fold(irmodel, input_size_limit=4 * 512 * 256)
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["Constant", "MatMul"])
+
 
 if __name__ == "__main__":
     unittest.main()