512 Simplification Rewrites for the Open Source Project Daphne #974
Thanks for this contribution, @dogakarakas. Adding more simplification rewrites to DAPHNE is great as it can significantly improve both the runtime and memory consumption of DaphneDSL scripts. Overall, the code of your simplification rewrites makes sense and already looks quite good. However, a few less obvious points about the simplification rewrites should be improved. Furthermore, testing needs special attention. It would be great if you could address the following points:
Required changes: (must be addressed before we can merge this PR)
- Use `CompilerUtils::isScaType()`/`hasScaType()` (see `src/compiler/utils/CompilerUtils.h`) instead of your own `isScalar()`. That's important for consistency. One may be tempted to classify every SSA value that is not a matrix, frame, or unknown as a scalar, but it's not that simple. There are other data types (e.g., columns and lists) and we may add more types in the future. The mentioned utility functions are the single point of truth for determining if a type is a scalar type.
- Avoid creating ops from MLIR's `arith` dialect; create ops from the `daphne` dialect instead. When rewriting `daphne::EwAddOp`/`EwMulOp`, you sometimes create the new ops as `arith::AddIOp`/`AddFOp`/`MulIOp`/`MulFOp`. The `arith` dialect is quite low-level, which is why you need to check if the operands are integers or floats; such checks can be avoided when you create DAPHNE's own ops. Furthermore, the introduction of `arith` ops may prevent subsequent simplification rewrites from DAPHNE, which target `EwAddOp`/`EwMulOp`.
- Use `CompilerUtils::isConstant<bool>()` to check if the transpose arguments of `MatMulOp` are known at DaphneDSL compile-time. There are multiple ops for constants in MLIR and it is easy to forget something. That's why we use this central utility function whenever we need the compile-time constant value of an `mlir::Value`.
- Fix issues with the types of newly created intermediate results. When creating a new MLIR op, one usually has to specify its result type. You made a good attempt at selecting the right result types for ops newly created by your rewrites. However, when mixing different value types for the inputs of the expressions, subtle problems can arise. For instance, consider your rewrite `sum(X+Y) -> sum(X) + sum(Y)`. Here, you create two new `sumAll` ops and give them the result type of the original `sumAll` op as their result type. Now assume `X` is a matrix with value type `si64` (signed 64-bit integer) and `Y` is a matrix with value type `f64` (64-bit floating-point). Then, the type of `X+Y` is `f64` (integer plus float is float) and the type of `sum(X+Y)` is `f64`, too. However, the type of `sum(X)` is `si64` (the same as the value type of `X`), while you would assign `f64`. This mismatch can lead to problems later, e.g., a kernel for the summation of an `si64` matrix with an `f64` result type might be missing. To circumvent this kind of problem, I recommend the following: Create ops producing new intermediate results with DAPHNE's unknown type; the type will be inferred in a later compiler pass (the inference pass). Create the final op replacing the original op with the original op's result type (as you already do), because the final result type should not be changed by a rewrite.
- Fix potential bugs related to transposed arguments of `MatMulOp`. DaphneIR's `MatMulOp` has four arguments: the two matrices to be multiplied and two booleans that indicate if the corresponding input matrix should be interpreted as transposed (inspired by the underlying BLAS routines, e.g., `dgemm()`). Looking at your code for the rewrite `(X @ Y)[i, j] -> X[i, ] @ Y[, j]`, I think I spot two bugs that need further investigation:
  - I think in the two cases when exactly one of the input matrices is transposed, `sliceRow` and `sliceCol` need to be swapped.
  - I think the newly created `MatMulOp` should retain the two transpose booleans of the original `MatMulOp` instead of setting them to `false`.

  Please add test cases to check if your rewrite behaves correctly in all four combinations of (non-)transposed inputs of the matrix multiplication.
- Revise the test cases. We need two kinds of test cases:
  - On the one hand, we need script-level test cases that check if the expressions addressed by your rewrites yield correct results (no matter how DAPHNE calculates them internally). For instance, add a DaphneDSL script as simple as `print(sum(t([1, 2, 3, 4])));` and check if it really prints `10\n`. These test cases should reside in `test/api/cli/expressions/`.
  - On the other hand, we need IR test cases which take an input IR, apply the canonicalization pass (through `daphne-opt`), and check if your rewrites have really been applied (by checking if certain operations exist or don't exist after the pass). To that end, we use LLVM's FileCheck tool. Examples of such test cases can be found in `test/codegen/` and `test/util/`. You added your test cases in `test/util/`; however, you made them call `daphne --explain ...` instead of `daphne-opt`. The way you test at the moment may be due to a misunderstanding, because in our meeting, I mentioned writing script-level test cases that use `--explain` to print the IR, plus checks on the IR output, as an alternative.
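For orientation, an IR test for a rewrite like `sum(t(X)) -> sum(X)` might be shaped roughly like the sketch below. The `RUN`/`CHECK` lines follow standard FileCheck conventions; the op names are purely illustrative placeholders, so the exact DaphneIR syntax should be copied from the existing tests in `test/codegen/`:

```
// RUN: daphne-opt --canonicalize %s | FileCheck %s

// CHECK-NOT: <transpose op>
// CHECK: <sumAll op>
// ... input IR computing sumAll(transpose(X)) goes here; after
// canonicalization, the transpose must be gone and the sumAll remain.
```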
- Undo changes unrelated to this PR. This PR proposes a few changes that are not related to the task and would not be useful on the main branch. They might originate from changes you made for your local setup. In detail, please undo the changes to `build.sh` (not setting `-j1` will also make the CI checks run faster), `daphne-opt/daphne-opt.cpp` (only formatting changes), `llvm.sh` (not needed), and `scripts/examples/extensions/myKernels/myKernels.cpp` (only formatting changes).
Optional changes: (recommended, but not required for merging)
You can further improve the implementation of your rewrites as follows:
- You don't need to "check if the results of the inner sums are indeed scalars"; the `sumAll` op always returns a scalar.
- Feel free to use `rewriter.replaceOpWithNewOp()` instead of `rewriter.create()` followed by `rewriter.replaceOp()`; it's shorter.
I hope this feedback helps you to improve your PR. Feel free to share your thoughts on these points.
…, hasScaType, isConstant} for consistency and correctness; replaced arith ops with DAPHNE dialect ops to preserve rewrite applicability. Fixed type handling in rewrites, preserved transpose flags in MatMulOp, and revised/extended both IR- and script-level tests while removing unrelated changes.
Thanks for the revisions so far, @dogakarakas. A short clarification regarding point 5 (transposed arguments of `MatMulOp`): […]

PS: Thinking further about it, I think your rewrite will usually not see a […]. We could add and test the remaining three cases later, once we have fixed the current limitations of the transpose args of `MatMulOp`.
This pull request adds a set of algebraic simplification rewrites to the DaphneIR canonicalizer. These rewrites target common algebraic patterns to improve optimization and simplify the intermediate representation. The following rewrites have been implemented and can be found in `src/ir/daphneir/Canonicalize.cpp`:
Static Rewrites:
1) `sumAll(ewAdd(X, Y)) → ewAdd(sumAll(X), sumAll(Y))`
2) `sumAll(transpose(X)) → sumAll(X)`
3) `sum(lambda * X) → lambda * sum(X)` (only when `lambda` is a scalar and `X` is a matrix)
4) `trace(X @ Y)` (i.e., `sum(diagVector(X @ Y))`) → `sum(X * transpose(Y))`
5) `(X @ Y)[i, j] → X[i, :] @ Y[:, j]` (only applicable if both inputs are not transposed)
Dynamic Rewrite:
1) `X[a:b, c:d] = Y → X = Y`, if `dims(X) == dims(Y)`*

*Only applicable when both matrices have the same element type.
All simplifications are tested using LLVM's FileCheck framework. The FileCheck tests verifying each rewrite are located in the `test/util/` directory, and script-level test cases can be found under `test/api/cli/expressions/`.
These changes improve canonicalization by recognizing and applying equivalent but more efficient patterns, both statically and dynamically.