From 04d656c6cfdddb5827f76b1ce0db62a7f8f12de9 Mon Sep 17 00:00:00 2001
From: Eby Elanjikal
Date: Thu, 16 Oct 2025 19:23:18 +0530
Subject: [PATCH 1/4] WIP: Add rewrite to fuse nested BlockDiag Ops

From c1ae3caf761898e495d716b0eea070c5fe3db7c5 Mon Sep 17 00:00:00 2001
From: Eby Elanjikal
Date: Thu, 16 Oct 2025 22:27:10 +0530
Subject: [PATCH 2/4] Add fuse_blockdiagonal rewrite and corresponding test
 for nested BlockDiagonal

---
 pytensor/tensor/rewriting/linalg.py   | 25 ++++++++++++++++
 tests/tensor/rewriting/test_linalg.py | 43 +++++++++++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 17a3ce9165..9b51f0593d 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -60,11 +60,36 @@
     solve_triangular,
 )
 
+from pytensor.tensor.slinalg import BlockDiagonal
 
 logger = logging.getLogger(__name__)
 
 MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
 
+from pytensor.tensor.slinalg import BlockDiagonal
+from pytensor.graph import Apply
+
+def fuse_blockdiagonal(node):
+    # Only process if this node is a BlockDiagonal
+    if not isinstance(node.owner.op, BlockDiagonal):
+        return node
+
+    new_inputs = []
+    changed = False
+    for inp in node.owner.inputs:
+        # If input is itself a BlockDiagonal, flatten its inputs
+        if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
+            new_inputs.extend(inp.owner.inputs)
+            changed = True
+        else:
+            new_inputs.append(inp)
+
+    if changed:
+        # Return a new fused BlockDiagonal with all inputs
+        return BlockDiagonal(len(new_inputs))(*new_inputs)
+    return node
+
+
 def is_matrix_transpose(x: TensorVariable) -> bool:
     """Check if a variable corresponds to a transpose of the last two axes"""
     node = x.owner
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 6b6f92f292..fbe4db5b4c 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -43,7 +43,50 @@
 from tests import unittest_tools as utt
 from tests.test_rop import break_op
 
+from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal
+
+def test_nested_blockdiag_fusion():
+    # Create matrix variables
+    x = pt.matrix("x")
+    y = pt.matrix("y")
+    z = pt.matrix("z")
+
+    # Nested BlockDiagonal
+    inner = BlockDiagonal(2)(x, y)
+    outer = BlockDiagonal(2)(inner, z)
+
+    # Count number of BlockDiagonal ops before fusion
+    nodes_before = ancestors([outer])
+    initial_count = sum(
+        1 for node in nodes_before
+        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
+    )
+    assert initial_count > 1, "Setup failed: should have nested BlockDiagonal"
+
+    # Apply the rewrite
+    fused = fuse_blockdiagonal(outer)
+
+    # Count number of BlockDiagonal ops after fusion
+    nodes_after = ancestors([fused])
+    fused_count = sum(
+        1 for node in nodes_after
+        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
+    )
+    assert fused_count == 1, "Nested BlockDiagonal ops were not fused"
+
+    # Check that all original inputs are preserved
+    fused_inputs = [
+        inp
+        for node in ancestors([fused])
+        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
+        for inp in node.owner.inputs
+    ]
+    assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused"
+
+
+
+
 def test_matrix_inverse_rop_lop():
     rtol = 1e-7 if config.floatX == "float64" else 1e-5
     mx = matrix("mx")
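The rewrite in PATCH 2/4 relies on block-diagonal construction being associative: building a block-diagonal matrix out of other block-diagonal matrices yields the same result as building it from all the leaf blocks at once. A minimal NumPy/SciPy sketch of that identity (illustration only, not part of the patch; block values are arbitrary):

    import numpy as np
    from scipy.linalg import block_diag

    # Arbitrary blocks; the identity holds for any block shapes.
    A = np.arange(4.0).reshape(2, 2)
    B = np.arange(9.0).reshape(3, 3)
    C = np.eye(2)

    # Nesting block_diag adds no information, so flattening is
    # value-preserving -- exactly what fuse_blockdiagonal exploits
    # at the graph level.
    assert np.array_equal(block_diag(block_diag(A, B), C), block_diag(A, B, C))
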
From f0b2797ebf6b3f2160fadf65d615b05be898f35a Mon Sep 17 00:00:00 2001
From: Eby Elanjikal
Date: Tue, 4 Nov 2025 22:52:26 +0530
Subject: [PATCH 3/4] linalg: fuse nested BlockDiagonal ops and add
 corresponding tests

---
 pytensor/tensor/rewriting/linalg.py   | 25 ++++-----
 tests/tensor/rewriting/test_linalg.py | 78 +++++++++++++++++----------
 2 files changed, 63 insertions(+), 40 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 9b51f0593d..3960a396cf 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -60,24 +60,23 @@
     solve_triangular,
 )
 
-from pytensor.tensor.slinalg import BlockDiagonal
 
 logger = logging.getLogger(__name__)
 
 MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
 
-from pytensor.tensor.slinalg import BlockDiagonal
-from pytensor.graph import Apply
+@register_canonicalize
+@node_rewriter([BlockDiagonal])
+def fuse_blockdiagonal(fgraph, node):
+    """Fuse nested BlockDiagonal ops into a single BlockDiagonal."""
 
-def fuse_blockdiagonal(node):
-    # Only process if this node is a BlockDiagonal
-    if not isinstance(node.owner.op, BlockDiagonal):
-        return node
+    if not isinstance(node.op, BlockDiagonal):
+        return None
 
     new_inputs = []
     changed = False
-    for inp in node.owner.inputs:
-        # If input is itself a BlockDiagonal, flatten its inputs
+
+    for inp in node.inputs:
         if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
             new_inputs.extend(inp.owner.inputs)
             changed = True
@@ -85,9 +84,11 @@ def fuse_blockdiagonal(node):
             new_inputs.append(inp)
 
     if changed:
-        # Return a new fused BlockDiagonal with all inputs
-        return BlockDiagonal(len(new_inputs))(*new_inputs)
-    return node
+        fused_op = BlockDiagonal(len(new_inputs))
+        new_output = fused_op(*new_inputs)
+        return [new_output]
+
+    return None
 
 
 def is_matrix_transpose(x: TensorVariable) -> bool:
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index fbe4db5b4c..35a3c8f0d7 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -43,50 +43,72 @@
 from tests import unittest_tools as utt
 from tests.test_rop import break_op
 
-from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal
-
 
 def test_nested_blockdiag_fusion():
-    # Create matrix variables
-    x = pt.matrix("x")
-    y = pt.matrix("y")
-    z = pt.matrix("z")
+    x = pt.tensor("x", shape=(3, 3))
+    y = pt.tensor("y", shape=(3, 3))
+    z = pt.tensor("z", shape=(3, 3))
 
-    # Nested BlockDiagonal
-    inner = BlockDiagonal(2)(x, y)
+    inner = BlockDiagonal(2)(x, y)
     outer = BlockDiagonal(2)(inner, z)
 
-    # Count number of BlockDiagonal ops before fusion
     nodes_before = ancestors([outer])
     initial_count = sum(
-        1 for node in nodes_before
+        1
+        for node in nodes_before
         if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
    )
-    assert initial_count > 1, "Setup failed: should have nested BlockDiagonal"
+    assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"
 
-    # Apply the rewrite
-    fused = fuse_blockdiagonal(outer)
+    f = pytensor.function([x, y, z], outer)
+    fgraph = f.maker.fgraph
 
-    # Count number of BlockDiagonal ops after fusion
-    nodes_after = ancestors([fused])
-    fused_count = sum(
-        1 for node in nodes_after
-        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
-    )
-    assert fused_count == 1, "Nested BlockDiagonal ops were not fused"
+    nodes_after = fgraph.apply_nodes
+    fused_nodes = [node for node in nodes_after if isinstance(node.op, BlockDiagonal)]
+    assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"
 
-    # Check that all original inputs are preserved
-    fused_inputs = [
-        inp
-        for node in ancestors([fused])
-        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
-        for inp in node.owner.inputs
+    fused_op = fused_nodes[0].op
+
+    assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"
+
+    out_shape = fgraph.outputs[0].type.shape
+    assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"
+
+
+def test_deeply_nested_blockdiag_fusion():
+    x = pt.tensor("x", shape=(3, 3))
+    y = pt.tensor("y", shape=(3, 3))
+    z = pt.tensor("z", shape=(3, 3))
+    w = pt.tensor("w", shape=(3, 3))
+
+    inner1 = BlockDiagonal(2)(x, y)
+    inner2 = BlockDiagonal(2)(inner1, z)
+    outer = BlockDiagonal(2)(inner2, w)
+
+    f = pytensor.function([x, y, z, w], outer)
+    fgraph = f.maker.fgraph
+
+    fused_nodes = [
+        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
     ]
-    assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused"
+    assert len(fused_nodes) == 1, (
+        f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
+    )
+
+    fused_op = fused_nodes[0].op
+
+    assert fused_op.n_inputs == 4, (
+        f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
+    )
+
+    out_shape = fgraph.outputs[0].type.shape
+    expected_shape = (12, 12)  # 4 blocks of (3x3)
+    assert out_shape == expected_shape, (
+        f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
+    )
 
-
 def test_matrix_inverse_rop_lop():
     rtol = 1e-7 if config.floatX == "float64" else 1e-5
     mx = matrix("mx")
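Because PATCH 3/4 registers the rewrite with @register_canonicalize, it now runs automatically when a graph is compiled; no manual call is needed. A quick way to watch it fire, along the lines of the updated tests (a sketch, not part of the patch):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import BlockDiagonal

    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))
    outer = BlockDiagonal(2)(BlockDiagonal(2)(x, y), z)

    # Compilation applies the registered canonicalization: the printed
    # graph should contain a single BlockDiagonal node with three inputs
    # instead of two nested ones.
    f = pytensor.function([x, y, z], outer)
    pytensor.dprint(f)

Note that a single application of the rewrite only splices in the immediate BlockDiagonal inputs; fully flattening deeper nests (as in test_deeply_nested_blockdiag_fusion) relies on the canonicalization pass re-applying rewrites until the graph stops changing.
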
From 13b71f0c83afb526dd83dbca8695ace5b5c645b6 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Wed, 7 Jan 2026 21:47:49 -0600
Subject: [PATCH 4/4] Respond to feedback

---
 pytensor/tensor/rewriting/linalg.py   |  4 +---
 tests/tensor/rewriting/test_linalg.py | 32 +++++++++++++--------------
 2 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 3960a396cf..4a2b6cec44 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -70,9 +70,6 @@ def fuse_blockdiagonal(fgraph, node):
     """Fuse nested BlockDiagonal ops into a single BlockDiagonal."""
 
-    if not isinstance(node.op, BlockDiagonal):
-        return None
-
     new_inputs = []
     changed = False
 
@@ -86,6 +83,7 @@ def fuse_blockdiagonal(fgraph, node):
     if changed:
         fused_op = BlockDiagonal(len(new_inputs))
         new_output = fused_op(*new_inputs)
+        copy_stack_trace(node.outputs[0], new_output)
         return [new_output]
 
     return None
 
 
 def is_matrix_transpose(x: TensorVariable) -> bool:
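For reference, assembling the hunks above gives the rewrite in its final form (reconstructed from the diffs, not a separate file in the series). The isinstance guard is gone because @node_rewriter([BlockDiagonal]) already guarantees the node's op type, and copy_stack_trace carries debugging information over to the replacement output:

    @register_canonicalize
    @node_rewriter([BlockDiagonal])
    def fuse_blockdiagonal(fgraph, node):
        """Fuse nested BlockDiagonal ops into a single BlockDiagonal."""
        new_inputs = []
        changed = False

        for inp in node.inputs:
            # Splice a nested BlockDiagonal's inputs directly into this node.
            if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
                new_inputs.extend(inp.owner.inputs)
                changed = True
            else:
                new_inputs.append(inp)

        if changed:
            fused_op = BlockDiagonal(len(new_inputs))
            new_output = fused_op(*new_inputs)
            copy_stack_trace(node.outputs[0], new_output)
            return [new_output]

        return None
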
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 35a3c8f0d7..665338e9f1 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -10,7 +10,7 @@
 from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
-from pytensor.graph import ancestors
+from pytensor.graph import FunctionGraph, ancestors
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor import swapaxes
 from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
@@ -52,23 +52,22 @@ def test_nested_blockdiag_fusion():
     inner = BlockDiagonal(2)(x, y)
     outer = BlockDiagonal(2)(inner, z)
 
-    nodes_before = ancestors([outer])
     initial_count = sum(
         1
-        for node in nodes_before
+        for node in ancestors([outer])
         if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
     )
     assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"
 
-    f = pytensor.function([x, y, z], outer)
-    fgraph = f.maker.fgraph
+    fgraph = FunctionGraph(inputs=[x, y, z], outputs=[outer])
+    rewrite_graph(fgraph, include=("fast_run", "blockdiag_fusion"))
 
-    nodes_after = fgraph.apply_nodes
-    fused_nodes = [node for node in nodes_after if isinstance(node.op, BlockDiagonal)]
+    fused_nodes = [
+        node for node in fgraph.toposort() if isinstance(node.op, BlockDiagonal)
+    ]
     assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"
 
     fused_op = fused_nodes[0].op
-
     assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"
 
     out_shape = fgraph.outputs[0].type.shape
@@ -85,21 +84,20 @@ def test_deeply_nested_blockdiag_fusion():
     inner2 = BlockDiagonal(2)(inner1, z)
     outer = BlockDiagonal(2)(inner2, w)
 
-    f = pytensor.function([x, y, z, w], outer)
-    fgraph = f.maker.fgraph
+    fgraph = FunctionGraph(inputs=[x, y, z, w], outputs=[outer])
+    rewrite_graph(fgraph, include=("fast_run", "blockdiag_fusion"))
 
-    fused_nodes = [
+    fused_block_diag_nodes = [
         node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
     ]
-
-    assert len(fused_nodes) == 1, (
-        f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
+    assert len(fused_block_diag_nodes) == 1, (
+        f"Expected 1 fused BlockDiagonal, got {len(fused_block_diag_nodes)}"
     )
 
-    fused_op = fused_nodes[0].op
+    fused_block_diag_op = fused_block_diag_nodes[0].op
 
-    assert fused_op.n_inputs == 4, (
-        f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
+    assert fused_block_diag_op.n_inputs == 4, (
+        f"Expected n_inputs=4 after fusion, got {fused_block_diag_op.n_inputs}"
     )
 
     out_shape = fgraph.outputs[0].type.shape
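End to end, fusion must change only the graph, never the values it computes. A usage sketch (illustrative, outside the PR's test suite) that compiles a nested graph and checks the compiled result against scipy.linalg.block_diag:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from scipy.linalg import block_diag
    from pytensor.tensor.slinalg import BlockDiagonal

    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))
    f = pytensor.function([x, y, z], BlockDiagonal(2)(BlockDiagonal(2)(x, y), z))

    rng = np.random.default_rng(0)
    vals = [rng.normal(size=(3, 3)).astype(pytensor.config.floatX) for _ in range(3)]
    # The compiled (fused) graph and the reference construction must agree.
    np.testing.assert_allclose(f(*vals), block_diag(*vals), rtol=1e-6)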