Optimize consecutive DUS operations with slices from same source to single DUS with extend #1768
Implements a new optimization that detects when two consecutive dynamic_update_slice operations have updates from slices of the same source tensor at the beginning and end of a dimension, and converts them to a single DUS with an extended middle section. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
Use the start indices from whichever DUS operation has the slice from the beginning of the source tensor, ensuring the extended tensor is placed at the correct position. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
Register the new pattern in the transform dialect and add it to the list of available primitives. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
- Add stride checking to ensure all strides are 1
- Simplify DUS index matching logic
- Improve documentation with concrete example
- Remove unnecessary conditional logic
Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
Correctly determine lhsPad and rhsPad based on which slice is at the start vs end of the dimension, and use the appropriate DUS indices. This ensures the extended tensor is placed at the correct position. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
Clean up dead code identified in code review. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
Update comments to use clearer naming (inner_dus/outer_dus) and add a note explaining the relationship to implementation variables. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
diffDim = i;
}
}
@copilot take this comment into account and make sure it is properly optimized:
func.func @f(%arg0: tensor<20x1536x3072xf32>, %arg1: tensor<20x1536x3072xf32>) -> (tensor<20x1536x3072xf32>) {
%c = stablehlo.constant dense<1528> : tensor<i32>
%c_0 = stablehlo.constant dense<8> : tensor<i32>
%c_1 = stablehlo.constant dense<7> : tensor<i32>
%0 = stablehlo.slice %arg1 [8:12, 8:1528, 8:3064] : (tensor<20x1536x3072xf32>) -> tensor<4x1520x3056xf32>
%1 = "enzymexla.extend"(%0) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 1 : i64}> : (tensor<4x1520x3056xf32>) -> tensor<6x1520x3056xf32>
// %2 := %arg0[7:13, 8:1528, 8:3064] = %1 = extend(%0, lhs=1, rhs=1) = extend([8:12, 8:1528, 8:3064], lhs=1, rhs=1, dim=0)
%2 = stablehlo.dynamic_update_slice %arg0, %1, %c_1, %c_0, %c_0 : (tensor<20x1536x3072xf32>, tensor<6x1520x3056xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<20x1536x3072xf32>
%3 = stablehlo.slice %arg1 [8:12, 8:9, 8:3064] : (tensor<20x1536x3072xf32>) -> tensor<4x1x3056xf32>
%4 = stablehlo.slice %arg1 [8:12, 1527:1528, 8:3064] : (tensor<20x1536x3072xf32>) -> tensor<4x1x3056xf32>
// %5 := %2[8:12, 7:8, 8:3064] = %3 = %arg1 [8:12, 8:9, 8:3064]
%5 = stablehlo.dynamic_update_slice %2, %3, %c_0, %c_1, %c_0 : (tensor<20x1536x3072xf32>, tensor<4x1x3056xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<20x1536x3072xf32>
// %6 := %5[8:12, 1528:1529, 8:3064] = %4 = %arg1 [8:12, 1527:1528, 8:3064]
%6 = stablehlo.dynamic_update_slice %5, %4, %c_0, %c, %c_0 : (tensor<20x1536x3072xf32>, tensor<4x1x3056xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<20x1536x3072xf32>
return %6 : tensor<20x1536x3072xf32>
}
Specifically here we want to combine %5 and %6 to become a single DUS(%2, extend(slice(arg1))). To do this we have to:
a) Prove that in a DUS of a DUS, both update operands are sliced from the same argument, and the corresponding update within each DUS has an offset such that an extended operand could fit between the slice indices.
b) Prove that the operand of the innermost DUS contains equivalent data to the extend we would otherwise be applying. For this, we probably want a helper function isEquivalentTo(Value v1, ArrayRef<int64_t> slicestart1, ArrayRef<int64_t> slicelimit1, Value v2, ArrayRef<int64_t> slicestart2, ArrayRef<int64_t> slicelimit2). If v1 is itself a DUS/slice/extend/etc., we would recurse into the data to see if it is equivalent.
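The two conditions above can be checked on a small 1-D model. This is only a sketch in Python rather than MLIR: it scales dimension 1 of the example down from 1536 to 16 elements, and it assumes that enzymexla.extend repeats the first lhs and last rhs elements of its input, which is what the example implies but is not taken from the op's documentation. The helper names dus and extend are hypothetical.

```python
def dus(operand, update, start):
    """1-D model of dynamic_update_slice: overwrite operand[start:start+len(update)]."""
    out = list(operand)
    out[start:start + len(update)] = update
    return out


def extend(x, lhs, rhs):
    """Assumed 1-D model of enzymexla.extend: repeat the first lhs and last rhs elements."""
    return x[:lhs] + x + x[len(x) - rhs:]


# Scaled-down stand-ins for %arg0 and %arg1 along dimension 1.
arg0 = list(range(100, 116))
arg1 = list(range(16))

# Like %2: the operand already holds the middle region arg1[3:13] at offset 3.
t2 = dus(arg0, arg1[3:13], 3)
# Like %5 and %6: write the first and last middle elements just outside the middle.
t5 = dus(t2, arg1[3:4], 2)
t6 = dus(t5, arg1[12:13], 13)

# Proposed rewrite: a single DUS of the extended middle slice.
fused = dus(t2, extend(arg1[3:13], 1, 1), 2)
same_when_middle_present = (t6 == fused)

# Without condition (b): the operand does not hold the middle, so the rewrite is unsound.
u6 = dus(dus(arg0, arg1[3:4], 2), arg1[12:13], 13)
u_fused = dus(arg0, extend(arg1[3:13], 1, 1), 2)
same_when_middle_absent = (u6 == u_fused)
```

In this model the first comparison holds and the second does not, which is why the rewrite needs the equivalence proof in (b) and not just the offset check in (a).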
I've implemented the requested enhancement in commit 12ac468:
a) Added isEquivalentTo helper function that recursively checks through DUS, slice, and extend operations to determine if values with different slice regions are equivalent.
b) Enhanced the pattern to check if dus2's operand already contains the middle region. When it does (e.g., from a previous extend operation), the pattern now optimizes by creating a single DUS with extend that references the existing data.
The pattern now handles both:
- Simple case: Two DUS ops directly slicing from the same source
- Complex case: Two DUS ops where the operand already contains the middle via a previous extend/DUS
Added test case in dusdus_to_extend_complex.mlir covering the scenario from your example.
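As a rough illustration of the kind of recursion described in this reply, here is a 1-D Python sketch. It is not the code from the commit: the node classes, resolve, and is_equivalent_to are hypothetical names, strides are assumed to be 1, and instead of recursing pairwise it resolves each region to a (source, start, limit) triple and compares the triples. Partial overlaps are rejected rather than clamped, and only regions fully inside the middle of an extend are accepted.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class Source:          # a function argument such as %arg1
    name: str
    size: int


@dataclass(frozen=True)
class Slice:           # operand[start:limit], stride 1
    operand: "Node"
    start: int
    limit: int


@dataclass(frozen=True)
class DUS:             # operand with `update` written at `start`
    operand: "Node"
    update: "Node"
    start: int


@dataclass(frozen=True)
class Extend:          # assumed: lhs/rhs boundary elements repeated around operand
    operand: "Node"
    lhs: int
    rhs: int


Node = Union[Source, Slice, DUS, Extend]


def size(v: Node) -> int:
    if isinstance(v, Source):
        return v.size
    if isinstance(v, Slice):
        return v.limit - v.start
    if isinstance(v, DUS):
        return size(v.operand)
    return v.lhs + size(v.operand) + v.rhs


def resolve(v: Node, start: int, limit: int) -> Optional[Tuple[str, int, int]]:
    """Trace v[start:limit] back to a source region, or None if it cannot be proven."""
    if isinstance(v, Source):
        return (v.name, start, limit)
    if isinstance(v, Slice):
        return resolve(v.operand, v.start + start, v.start + limit)
    if isinstance(v, DUS):
        lo, hi = v.start, v.start + size(v.update)
        if lo <= start and limit <= hi:      # region lies entirely in the update
            return resolve(v.update, start - lo, limit - lo)
        if limit <= lo or hi <= start:       # region is untouched by the update
            return resolve(v.operand, start, limit)
        return None                          # partial overlap: give up
    lo, hi = v.lhs, v.lhs + size(v.operand)  # middle region of the extend
    if lo <= start and limit <= hi:
        return resolve(v.operand, start - lo, limit - lo)
    return None


def is_equivalent_to(v1, start1, limit1, v2, start2, limit2) -> bool:
    r1 = resolve(v1, start1, limit1)
    r2 = resolve(v2, start2, limit2)
    return r1 is not None and r1 == r2


arg0 = Source("arg0", 16)
arg1 = Source("arg1", 16)
t2 = DUS(arg0, Extend(Slice(arg1, 4, 12), 1, 1), 3)
middle_ok = is_equivalent_to(t2, 4, 12, arg1, 4, 12)
edge_rejected = is_equivalent_to(t2, 3, 12, arg1, 3, 11)
```

The check is deliberately conservative: a None result means "not proven", so the pattern would simply not fire in that case.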
Add isEquivalentTo helper function to recursively check if values are equivalent through DUS/slice/extend operations. This enables the pattern to optimize cases where the operand already contains the middle region from a previous extend operation. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
Calculate lhsPad and rhsPad once and reuse in both code paths to avoid duplication and potential inconsistencies. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
- Rename dus1Starts/dus2Starts to dusOuterStarts/dusInnerStarts for clarity
- Fix extend slice handling to reject partial overlaps instead of clamping
- Only accept slices completely within the middle region of extends
Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
Explain padding calculation and middle region positioning logic to improve code readability. Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
When two consecutive dynamic_update_slice operations insert slices from the same source tensor at the beginning and end of a dimension, they can be replaced with a single DUS operation using an ExtendOp on the middle section.
Pattern matched (simple case):
Optimized to:
Pattern matched (complex case with existing middle region):
Also optimized when the operand already contains the middle region from a previous extend/DUS operation.
Implementation:
- DUSDUSToExtend pattern in EnzymeHLOOpt.cpp with constraints:
- isEquivalentTo helper function that recursively checks through DUS, slice, and extend operations to determine if values with different slice regions are equivalent
Benefits: Reduces operation count and memory traffic. The resulting extend operation can be further optimized by existing passes. Handles complex scenarios where previous operations have already placed the middle region data.