From a110de2c296ca934808c0b1d45bb15b17eb97cf5 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Thu, 16 Jan 2025 18:15:49 -0800
Subject: [PATCH 1/8] new initializers

---
 onnxscript/ir/_tape.py                    | 9 +++++++++
 onnxscript/optimizer/_constant_folding.py | 6 +++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py
index 0a179af852..894b978425 100644
--- a/onnxscript/ir/_tape.py
+++ b/onnxscript/ir/_tape.py
@@ -18,6 +18,7 @@ class Tape(Iterable[ir.Node]):
 
     def __init__(self) -> None:
         self._nodes: list[ir.Node] = []
+        self._initializers: list[ir.Value] = []
 
     def __iter__(self) -> Iterator[ir.Node]:
         return iter(self._nodes)
@@ -26,6 +27,10 @@ def __iter__(self) -> Iterator[ir.Node]:
     def nodes(self) -> Sequence[ir.Node]:
         return tuple(self._nodes)
 
+    @property
+    def initializers(self) -> Sequence[ir.Value]:
+        return tuple(self._initializers)
+
     def op(
         self,
         op_type: str,
@@ -60,6 +65,10 @@ def op_multi_output(
 
         return node.outputs
 
+    def initializer(self, name: str, tensor: ir.TensorProtocol) -> ir.Value:
+        value = ir.Value(name=name, shape=tensor.shape, type=ir.TensorType(tensor.dtype), const_value=tensor)
+        self._initializers.append(value)
+        return value
 
 # A type representing the domains/versions used in creating nodes in IR.
 UsedOpsets = List[Tuple[str, Optional[int]]]
diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 8b4dbbfe55..7b15076ed5 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -119,7 +119,11 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any:
         evaluator = self.get_evaluator(domain, op, version)
         if evaluator is None:
             return None
-        return evaluator(*args, **kwargs)
+        try:
+            return evaluator(*args, **kwargs)
+        except Exception as e:
+            logger.debug("Evaluation failed: %s", e)
+            return None
 
 
 _reference_evaluator = ReferenceEvaluator()
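Note: a minimal sketch of how the new Tape.initializer can be driven on its own (illustrative, not part of the patch; the tape, tensor, and op names here are invented for the example, and the next commit reorders the parameters so the tensor comes first):

    import numpy as np

    from onnxscript import ir
    from onnxscript.ir._tape import Tape

    tape = Tape()
    # Wrap a numpy array as an IR tensor and record it as an initializer.
    weight = ir.tensor(np.ones((4, 4), dtype=np.float32), name="weight")
    w = tape.initializer("weight", weight)  # patch-1 signature: (name, tensor)
    # The returned ir.Value can feed a node like any other value.
    y = tape.op("Relu", inputs=[w])
    assert tape.initializers == (w,)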
From 066ae316bb87f416ee7db458eacddf21e6284745 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Fri, 17 Jan 2025 09:39:29 -0800
Subject: [PATCH 2/8] Add initializer support in rewrite rule

---
 onnxscript/ir/_tape.py              |  5 ++++-
 onnxscript/rewriter/pattern.py      |  9 +++++++-
 onnxscript/rewriter/pattern_test.py | 33 +++++++++++++++++++++++++++++
 3 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py
index 894b978425..fd9cc1dfbe 100644
--- a/onnxscript/ir/_tape.py
+++ b/onnxscript/ir/_tape.py
@@ -65,7 +65,10 @@ def op_multi_output(
 
         return node.outputs
 
-    def initializer(self, name: str, tensor: ir.TensorProtocol) -> ir.Value:
+    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
+        name = name or tensor.name
+        if name is None:
+            raise ValueError("Name must be provided for initializer.")
         value = ir.Value(name=name, shape=tensor.shape, type=ir.TensorType(tensor.dtype), const_value=tensor)
         self._initializers.append(value)
         return value
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 84ac42beb2..239c0d746f 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -900,6 +900,7 @@ class ReplacementSubgraph:
     match: MatchResult
     new_outputs: Sequence[ir.Value]
     new_nodes: Sequence[ir.Node]
+    new_initializers: Sequence[ir.Value]
     used_opsets: _tape.UsedOpsets
 
 
@@ -928,7 +929,7 @@ def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None:
             return None  # Failed to create replacement subgraph
         if not isinstance(new_outputs, Sequence):
             new_outputs = [new_outputs]
-        return ReplacementSubgraph(match, new_outputs, context.nodes, context.used_opsets)
+        return ReplacementSubgraph(match, new_outputs, context.nodes, context.initializers, context.used_opsets)
 
 
 def _update_opset_imports(
@@ -1566,6 +1567,12 @@ def _apply_to_graph_or_function(
             if delta is None or tracer is not None:
                 continue
             assert isinstance(delta, ReplacementSubgraph)
+            if delta.new_initializers and isinstance(graph_or_function, ir.Function):
+                # TODO(rama): Can't add initializers to functions. But currently this is not
+                # an issue, as we apply inlining before applying rewrite rules.
+                continue
+            for initializer in delta.new_initializers:
+                graph_or_function.initializers[initializer.name] = initializer
             # TODO: This does not yet handle the problem of determining the correct insertion point
             # for inserted nodes in the case of patterns with multiple output-nodes. The following
             # is sufficient for patterns with a single output-node "node", which can serve as the
diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py
index 1803ab6706..f584748743 100644
--- a/onnxscript/rewriter/pattern_test.py
+++ b/onnxscript/rewriter/pattern_test.py
@@ -3,6 +3,7 @@
 import contextlib
 import io
 import logging
+import numpy as np
 import unittest
 
 import onnx.checker
@@ -543,6 +544,38 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
         # Not a robust test. But test serves to ensure that debug mode is producing something.
         self.assertIn("OpType mismatch: expected Abs, got Neg", captured_output)
 
+    def test_new_initializer(self):
+        def source_pattern(op, x, y):
+            return op.Gemm(x, op.Transpose(y))
+
+        def check(context, x, y):
+            return y.const_value is not None
+
+        def replacement(op, x, y):
+            tensor = y.const_value
+            name = y.name + "_transposed"
+            transposed = ir.tensor(tensor.numpy().T, name=name)
+            initializer = op.initializer(transposed)
+            return op.Gemm(x, initializer)
+
+        rule = pattern.RewriteRule(source_pattern, replacement, check)
+
+        y_value = np.random.rand(8, 4).astype(np.float32)
+        @script()
+        def test_model(x: FLOAT[16, 8]) -> FLOAT[16, 4]:
+            y = op.Constant(value=y_value)
+            return op.Gemm(x, op.Transpose(y))
+
+        model_proto = test_model.to_model_proto()
+        model = ir.serde.deserialize_model(model_proto)
+        rule.apply_to_model(model)
+        self.assertEqual(len(model.graph.initializers), 1)
+        last_node = model.graph[-1]
+        self.assertEqual(len(last_node.inputs), 2)
+        init_name = last_node.inputs[1].name
+        self.assertIn(init_name, model.graph.initializers)
+        self.assertIs(last_node.inputs[1], model.graph.initializers[init_name])
+
 
 class PatternBuilderTest(unittest.TestCase):
     def test_pattern_builder_context(self):

From 0dfc56504491d05620ee3dc8d12754cbdf70ea0d Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Fri, 17 Jan 2025 10:12:53 -0800
Subject: [PATCH 3/8] Add warning messages

---
 onnxscript/rewriter/pattern.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 239c0d746f..52f820f4e0 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -1570,9 +1570,19 @@ def _apply_to_graph_or_function(
             if delta.new_initializers and isinstance(graph_or_function, ir.Function):
                 # TODO(rama): Can't add initializers to functions. But currently this is not
                 # an issue, as we apply inlining before applying rewrite rules.
+                if verbose:
+                    print(
+                        f"Rewrites adding initializers not supported for functions: {rule}"
+                    )
                 continue
+            initializers = graph_or_function.initializers
             for initializer in delta.new_initializers:
-                graph_or_function.initializers[initializer.name] = initializer
+                if initializer.name in initializers:
+                    if verbose:
+                        print(f"Initializer {initializer.name} already exists.")
+                    continue
+            for initializer in delta.new_initializers:
+                initializers[initializer.name] = initializer
             # TODO: This does not yet handle the problem of determining the correct insertion point
             # for inserted nodes in the case of patterns with multiple output-nodes. The following
             # is sufficient for patterns with a single output-node "node", which can serve as the
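Note: patch 3's warning logic is restructured by the next commit; the behavior the series converges on is easier to read as a standalone helper. The sketch below is a condensed restatement, not the literal patch text (the helper name is invented; in pattern.py this logic is inlined in _apply_to_graph_or_function): duplicate names are reported in verbose mode, and function bodies cannot receive initializers at all.

    from onnxscript import ir

    def add_new_initializers(graph_or_function, delta, verbose: bool = False) -> None:
        # Condensed restatement of the inlined rewrite logic above.
        if not delta.new_initializers:
            return
        if isinstance(graph_or_function, ir.Function):
            # Initializers cannot be attached to a function body; inlining runs
            # before rewriting, so this branch is not expected to trigger.
            if verbose:
                print("Rewrites adding initializers not supported for functions")
            return
        initializers = graph_or_function.initializers
        for initializer in delta.new_initializers:
            if verbose and initializer.name in initializers:
                print(f"Initializer {initializer.name} already exists.")
            initializers[initializer.name] = initializer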
From 6556646c47ed30efe36e691c3708e52ef3a95e04 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Fri, 17 Jan 2025 10:14:59 -0800
Subject: [PATCH 4/8] Fix

---
 onnxscript/rewriter/pattern.py | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 52f820f4e0..25a55c99aa 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -1567,22 +1567,23 @@ def _apply_to_graph_or_function(
             if delta is None or tracer is not None:
                 continue
             assert isinstance(delta, ReplacementSubgraph)
-            if delta.new_initializers and isinstance(graph_or_function, ir.Function):
-                # TODO(rama): Can't add initializers to functions. But currently this is not
-                # an issue, as we apply inlining before applying rewrite rules.
-                if verbose:
-                    print(
-                        f"Rewrites adding initializers not supported for functions: {rule}"
-                    )
-                continue
-            initializers = graph_or_function.initializers
-            for initializer in delta.new_initializers:
-                if initializer.name in initializers:
+            if delta.new_initializers:
+                if isinstance(graph_or_function, ir.Function):
+                    # TODO(rama): Can't add initializers to functions. But currently this is not
+                    # an issue, as we apply inlining before applying rewrite rules.
                     if verbose:
-                        print(f"Initializer {initializer.name} already exists.")
+                        print(
+                            f"Rewrites adding initializers not supported for functions: {rule}"
+                        )
                     continue
-                for initializer in delta.new_initializers:
-                    initializers[initializer.name] = initializer
+                initializers = graph_or_function.initializers
+                for initializer in delta.new_initializers:
+                    if initializer.name in initializers:
+                        if verbose:
+                            print(f"Initializer {initializer.name} already exists.")
+                        continue
+                for initializer in delta.new_initializers:
+                    initializers[initializer.name] = initializer
             # TODO: This does not yet handle the problem of determining the correct insertion point
             # for inserted nodes in the case of patterns with multiple output-nodes. The following
             # is sufficient for patterns with a single output-node "node", which can serve as the

From 1f213e62dafd1745976001d2cf90d159cb59d317 Mon Sep 17 00:00:00 2001
From: "G. Ramalingam"
Date: Fri, 17 Jan 2025 10:26:01 -0800
Subject: [PATCH 5/8] Update onnxscript/optimizer/_constant_folding.py

Co-authored-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 7b15076ed5..deb1be9e9e 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -122,7 +122,7 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any:
         try:
             return evaluator(*args, **kwargs)
         except Exception as e:
-            logger.debug("Evaluation failed: %s", e)
+            logger.warning("Evaluation failed: %s", e)
             return None
 
 
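Note: with patches 1 and 5 combined, constant folding treats a failing reference evaluation as "leave the node unfolded" and logs a warning instead of raising. The guarded-call pattern in isolation (a sketch, not the class from the diff; the evaluator argument stands in for whatever get_evaluator returns):

    import logging

    logger = logging.getLogger(__name__)

    def try_evaluate(evaluator, *args, **kwargs):
        # Returning None tells the folding pass to keep the original node untouched.
        if evaluator is None:
            return None
        try:
            return evaluator(*args, **kwargs)
        except Exception as e:
            logger.warning("Evaluation failed: %s", e)
            return None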
Ramalingam" Date: Fri, 17 Jan 2025 10:26:01 -0800 Subject: [PATCH 5/8] Update onnxscript/optimizer/_constant_folding.py Co-authored-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 7b15076ed5..deb1be9e9e 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -122,7 +122,7 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: try: return evaluator(*args, **kwargs) except Exception as e: - logger.debug("Evaluation failed: %s", e) + logger.warning("Evaluation failed: %s", e) return None From eed8ba5ec49b80f2ed9a77d2ee65accb273c638f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Jan 2025 10:36:51 -0800 Subject: [PATCH 6/8] Lint fixes --- onnxscript/rewriter/pattern.py | 2 +- onnxscript/rewriter/pattern_test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 25a55c99aa..20973cfd95 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1583,7 +1583,7 @@ def _apply_to_graph_or_function( print(f"Initializer {initializer.name} already exists.") continue for initializer in delta.new_initializers: - initializers[initializer.name] = initializer + initializers[initializer.name] = initializer # type: ignore[index] # TODO: This does not yet handle the problem of determining the correct insertion point # for inserted nodes in the case of patterns with multiple output-nodes. The following # is sufficient for patterns with a single output-node "node", which can serve as the diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index f584748743..4ce73bdea7 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -547,7 +547,7 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]: def test_new_initializer(self): def source_pattern(op, x, y): return op.Gemm(x, op.Transpose(y)) - + def check(context, x, y): return y.const_value is not None @@ -557,7 +557,7 @@ def replacement(op, x, y): transposed = ir.tensor(tensor.numpy().T, name=name) initializer = op.initializer(transposed) return op.Gemm(x, initializer) - + rule = pattern.RewriteRule(source_pattern, replacement, check) y_value = np.random.rand(8, 4).astype(np.float32) @@ -565,7 +565,7 @@ def replacement(op, x, y): def test_model(x: FLOAT[16, 8]) -> FLOAT[16, 4]: y = op.Constant(value=y_value) return op.Gemm(x, op.Transpose(y)) - + model_proto = test_model.to_model_proto() model = ir.serde.deserialize_model(model_proto) rule.apply_to_model(model) From 8e1079eff778832b31a1e91b8f75fd5b41f11595 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Jan 2025 11:03:24 -0800 Subject: [PATCH 7/8] mypy warning fix --- onnxscript/ir/_tape.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index fd9cc1dfbe..e94a3fe33a 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -69,7 +69,8 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir. 
From 05af7ffef62d6424222608c1453be38ba781b5be Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Fri, 17 Jan 2025 12:38:21 -0800
Subject: [PATCH 8/8] Address ruff

---
 onnxscript/ir/_tape.py              | 5 ++++-
 onnxscript/rewriter/pattern.py      | 4 +++-
 onnxscript/rewriter/pattern_test.py | 3 ++-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py
index e94a3fe33a..752a52a243 100644
--- a/onnxscript/ir/_tape.py
+++ b/onnxscript/ir/_tape.py
@@ -70,10 +70,13 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
         name = name or tensor.name
         if name is None:
             raise ValueError("Name must be provided for initializer.")
         shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
-        value = ir.Value(name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor)
+        value = ir.Value(
+            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
+        )
         self._initializers.append(value)
         return value
 
+
 # A type representing the domains/versions used in creating nodes in IR.
 UsedOpsets = List[Tuple[str, Optional[int]]]
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 20973cfd95..868da62443 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -929,7 +929,9 @@ def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None:
             return None  # Failed to create replacement subgraph
         if not isinstance(new_outputs, Sequence):
             new_outputs = [new_outputs]
-        return ReplacementSubgraph(match, new_outputs, context.nodes, context.initializers, context.used_opsets)
+        return ReplacementSubgraph(
+            match, new_outputs, context.nodes, context.initializers, context.used_opsets
+        )
 
 
 def _update_opset_imports(
diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py
index 4ce73bdea7..ca865ecde1 100644
--- a/onnxscript/rewriter/pattern_test.py
+++ b/onnxscript/rewriter/pattern_test.py
@@ -3,9 +3,9 @@
 import contextlib
 import io
 import logging
-import numpy as np
 import unittest
 
+import numpy as np
 import onnx.checker
 import onnx.parser
@@ -561,6 +561,7 @@ def replacement(op, x, y):
         rule = pattern.RewriteRule(source_pattern, replacement, check)
 
         y_value = np.random.rand(8, 4).astype(np.float32)
+
         @script()
         def test_model(x: FLOAT[16, 8]) -> FLOAT[16, 4]:
             y = op.Constant(value=y_value)
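For reference, the final form of Tape.initializer once the whole series is applied (assembled from patches 2, 7, and 8 above):

    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
        name = name or tensor.name
        if name is None:
            raise ValueError("Name must be provided for initializer.")
        shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
        value = ir.Value(
            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
        )
        self._initializers.append(value)
        return value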