Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 119 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11663,6 +11663,124 @@ template <typename T> struct CSE final : CheckedOpRewritePattern<T, CSE<T>> {
}
};

// Specialized CSE for DotGeneralOp that, in addition to exact duplicates,
// recognizes operand-swapped duplicates: dot_general(A, B) and
// dot_general(B, A) with correspondingly swapped dimension numbers compute
// the same value — but only when the swap does not permute the result
// layout (see isSwappedEquivalent for the precise condition).
struct CSEDotGeneral final
    : CheckedOpRewritePattern<stablehlo::DotGeneralOp, CSEDotGeneral> {
  using CheckedOpRewritePattern<stablehlo::DotGeneralOp,
                                CSEDotGeneral>::CheckedOpRewritePattern;

  bool supportsDynamicShapes() { return true; }

  // True iff op2's per-operand dimension lists are exactly op1's with the
  // lhs/rhs roles exchanged, i.e. (lhs1, rhs1) == (rhs2, lhs2).
  // ArrayRef::equals checks both length and element-wise equality.
  bool areDimensionsSwapped(ArrayRef<int64_t> lhs1, ArrayRef<int64_t> rhs1,
                            ArrayRef<int64_t> lhs2,
                            ArrayRef<int64_t> rhs2) const {
    return lhs1.equals(rhs2) && rhs1.equals(lhs2);
  }

  // Check whether op2 computes the same value as op1 with its operands (and
  // dimension numbers) swapped.
  //
  // Soundness: the dot_general result layout is
  //   [batch dims..., lhs free dims..., rhs free dims...]
  // so swapping operands swaps the two free-dimension groups. The swapped op
  // is therefore a *transpose* of the original unless at least one operand
  // contributes no free (non-batching, non-contracting) dimensions — merely
  // comparing result types is insufficient, since the transposed result can
  // have an identical shape (e.g. square matmul A*B vs B*A).
  bool isSwappedEquivalent(stablehlo::DotGeneralOp op1,
                           stablehlo::DotGeneralOp op2) const {
    // Operands must be swapped: op1(A, B) vs op2(B, A).
    if (op1.getLhs() != op2.getRhs() || op1.getRhs() != op2.getLhs())
      return false;

    // Result types must match exactly.
    if (op1.getType() != op2.getType())
      return false;

    auto dims1 = op1.getDotDimensionNumbers();
    auto dims2 = op2.getDotDimensionNumbers();

    // Batching dimensions must be swapped. The pairing order (element i of
    // the lhs list matches element i of the rhs list) is preserved by the
    // swap, so the result batch-dimension order is unchanged.
    if (!areDimensionsSwapped(dims1.getLhsBatchingDimensions(),
                              dims1.getRhsBatchingDimensions(),
                              dims2.getLhsBatchingDimensions(),
                              dims2.getRhsBatchingDimensions()))
      return false;

    // Contracting dimensions must be swapped.
    if (!areDimensionsSwapped(dims1.getLhsContractingDimensions(),
                              dims1.getRhsContractingDimensions(),
                              dims2.getLhsContractingDimensions(),
                              dims2.getRhsContractingDimensions()))
      return false;

    // Reject the transpose case: at least one operand must contribute zero
    // free dimensions. Requires ranked operands (dynamic extents are fine;
    // only the rank matters here).
    auto lhsTy = dyn_cast<RankedTensorType>(op1.getLhs().getType());
    auto rhsTy = dyn_cast<RankedTensorType>(op1.getRhs().getType());
    if (!lhsTy || !rhsTy)
      return false;
    int64_t lhsFree =
        lhsTy.getRank() -
        static_cast<int64_t>(dims1.getLhsBatchingDimensions().size()) -
        static_cast<int64_t>(dims1.getLhsContractingDimensions().size());
    int64_t rhsFree =
        rhsTy.getRank() -
        static_cast<int64_t>(dims1.getRhsBatchingDimensions().size()) -
        static_cast<int64_t>(dims1.getRhsContractingDimensions().size());
    if (lhsFree != 0 && rhsFree != 0)
      return false;

    // Precision config is per-operand [lhs_precision, rhs_precision]; with
    // swapped operands the two ops are equivalent only when op2's config is
    // op1's reversed (plain equality would pair each precision with the
    // wrong operand).
    auto pc1 = op1.getPrecisionConfig();
    auto pc2 = op2.getPrecisionConfig();
    if (pc1.has_value() != pc2.has_value())
      return false;
    if (pc1) {
      ArrayRef<Attribute> a = pc1->getValue();
      ArrayRef<Attribute> b = pc2->getValue();
      if (a.size() != b.size() || !std::equal(a.begin(), a.end(), b.rbegin()))
        return false;
    }
    // NOTE(review): if this StableHLO version carries a DotAlgorithm
    // attribute, it should be compared here as well — confirm against the
    // op definition in use.

    return true;
  }

  LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op,
                                    PatternRewriter &rewriter) const {
    // Scanning the users of the lhs alone finds every candidate: an exact
    // duplicate uses op's lhs as its lhs, and a swapped duplicate uses op's
    // lhs as its rhs. (A separate scan of the rhs users would be redundant.)
    for (auto nop : op.getLhs().getUsers()) {
      if (nop == op)
        continue;
      auto dotOp = dyn_cast<stablehlo::DotGeneralOp>(nop);
      if (!dotOp)
        continue;
      // Only merge ops within the same block so isBeforeInBlock is valid and
      // dominance is preserved.
      if (dotOp->getBlock() != op->getBlock())
        continue;

      // Standard equivalence (mirrors the generic CSE template) plus the
      // DotGeneral-specific swapped equivalence.
      if (OperationEquivalence::isEquivalentTo(
              op, dotOp, OperationEquivalence::IgnoreLocations) ||
          isSwappedEquivalent(op, dotOp)) {
        // Keep whichever op dominates; replace the other with it.
        if (dotOp->isBeforeInBlock(op))
          rewriter.replaceOp(op, dotOp);
        else
          rewriter.replaceOp(dotOp, op);
        return success();
      }
    }

    return failure();
  }
};

struct ConstPropThroughBarrier final
: CheckedOpRewritePattern<stablehlo::OptimizationBarrierOp,
ConstPropThroughBarrier> {
Expand Down Expand Up @@ -26575,7 +26693,7 @@ struct EnzymeHLOOptPass
if (cse) {
patterns.add<CSE<stablehlo::BroadcastInDimOp>, CSE<stablehlo::SliceOp>,
CSE<stablehlo::TransposeOp>, CSE<stablehlo::ConvertOp>,
CSE<stablehlo::PadOp>, CSE<stablehlo::DotGeneralOp>,
CSE<stablehlo::PadOp>, CSEDotGeneral,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot can we reuse more of the CSE struct

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to follow the structure of the generic CSE pattern more closely. The implementation now directly mirrors the CSE template's iteration and checking logic, only adding the DotGeneral-specific swapped equivalence check. Commit: 9dea61e

CSE<stablehlo::ReshapeOp>, CSE<stablehlo::MulOp>,
CSE<stablehlo::DivOp>, CSE<stablehlo::AddOp>,
CSE<stablehlo::SubtractOp>, CSE<stablehlo::MinOp>,
Expand Down
61 changes: 61 additions & 0 deletions test/lit_tests/dotgeneral_cse_swapped.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// Test that dot_general CSE recognizes swapped operands as equivalent
// when the dimension numbers are also appropriately swapped

// CHECK-LABEL: func.func @test_dotgeneral_cse_symmetric
func.func @test_dotgeneral_cse_symmetric(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
// Both operations have symmetric dimension specs: batching_dims = [0] x [0], contracting_dims = [1] x [1]
// This computes: for each b in [0,4), result[b] = sum_k arg0[b,k] * arg1[b,k]
// When operands are swapped with these symmetric specs, the operations are equivalent
// because multiplication is commutative: arg0[b,k] * arg1[b,k] = arg1[b,k] * arg0[b,k]
// Neither operand contributes a free (non-batching, non-contracting) dimension,
// so swapping the operands cannot permute the result layout.
%0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4xf32>
%1 = stablehlo.dot_general %arg1, %arg0, batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4xf32>
return %0, %1 : tensor<4xf32>, tensor<4xf32>
}

// CHECK: %[[V0:.+]] = stablehlo.dot_general %arg0, %arg1
// CHECK-NEXT: return %[[V0]], %[[V0]]

// CHECK-LABEL: func.func @test_dotgeneral_cse_with_batching
func.func @test_dotgeneral_cse_with_batching(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
// With batching_dims = [0] x [0] and contracting_dims = [1, 2] x [1, 2]
// This computes: for each b in [0,2), result[b] = sum_{i,j} arg0[b,i,j] * arg1[b,i,j]
// These are equivalent when operands are swapped (Frobenius inner product is commutative)
// Every dimension of each operand is either batching or contracting, so there
// are no free dimensions and the swap leaves the result layout unchanged.
%0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [1, 2] x [1, 2] : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2xf32>
%1 = stablehlo.dot_general %arg1, %arg0, batching_dims = [0] x [0], contracting_dims = [1, 2] x [1, 2] : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2xf32>
return %0, %1 : tensor<2xf32>, tensor<2xf32>
}

// CHECK: %[[V0:.+]] = stablehlo.dot_general %arg0, %arg1
// CHECK-NEXT: return %[[V0]], %[[V0]]

// CHECK-LABEL: func.func @test_dotgeneral_no_cse_different_shapes
func.func @test_dotgeneral_no_cse_different_shapes(%arg0: tensor<4x8xf32>, %arg1: tensor<8x4xf32>) -> (tensor<4x4xf32>, tensor<8x8xf32>) {
// These have different output shapes and should NOT be CSE'd
// First: A(4x8) × B(8x4) → (4x4)
// Second: B(8x4) × A(4x8) with same contracting dims → (8x8)
// The result-type mismatch alone is enough to reject the swapped-CSE match.
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<4x8xf32>, tensor<8x4xf32>) -> tensor<4x4xf32>
%1 = stablehlo.dot_general %arg1, %arg0, contracting_dims = [1] x [0] : (tensor<8x4xf32>, tensor<4x8xf32>) -> tensor<8x8xf32>
return %0, %1 : tensor<4x4xf32>, tensor<8x8xf32>
}

// CHECK: %[[V0:.+]] = stablehlo.dot_general %arg0, %arg1
// CHECK-NEXT: %[[V1:.+]] = stablehlo.dot_general %arg1, %arg0
// CHECK-NEXT: return %[[V0]], %[[V1]]

// CHECK-LABEL: func.func @test_dotgeneral_no_cse_different_computation
func.func @test_dotgeneral_no_cse_different_computation(%arg0: tensor<4x8xf32>, %arg1: tensor<8x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>) {
// These have the same output shape but compute different things
// First: A(4x8) × B(8x4) → C(4x4) where C[i,j] = sum_k A[i,k] * B[k,j]
// Second: B(8x4) × A(4x8) with contracting_dims=[0]x[1] → D(4x4) where D[i,j] = sum_k B[k,i] * A[j,k]
// These are NOT equivalent (D is not the same as C)
// Here D = transpose(C): both operands have a free dimension, so swapping
// the operands permutes the result layout even though the shapes coincide.
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<4x8xf32>, tensor<8x4xf32>) -> tensor<4x4xf32>
%1 = stablehlo.dot_general %arg1, %arg0, contracting_dims = [0] x [1] : (tensor<8x4xf32>, tensor<4x8xf32>) -> tensor<4x4xf32>
return %0, %1 : tensor<4x4xf32>, tensor<4x4xf32>
}

// CHECK: %[[V0:.+]] = stablehlo.dot_general %arg0, %arg1
// CHECK-NEXT: %[[V1:.+]] = stablehlo.dot_general %arg1, %arg0
// CHECK-NEXT: return %[[V0]], %[[V1]]