@Lunderberg
Contributor

When determining whether to evaluate matrix multiplications as (A*B)*C or as A*(B*C), dynamic shapes may occur (e.g. a dynamic LoRA rank). This commit tests for these cases, and improves the arithmetic bounds used to prove which order of evaluation is preferred.
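To make the trade-off concrete, here is a minimal sketch of the operation counts being compared, assuming `A: [m, k]`, `B: [k, n]`, `C: [n, p]` and a cost of `m*k*n` multiply-adds per `[m, k] @ [k, n]` product. The helper names `matmul_flops` and `cheaper_order` are hypothetical and are not part of this PR or of TVM.

```python
def matmul_flops(m, k, n):
    """Multiply-add count for an [m, k] @ [k, n] product."""
    return m * k * n


def cheaper_order(m, k, n, p):
    """Compare (A*B)*C with A*(B*C) for A:[m,k], B:[k,n], C:[n,p]."""
    # (A*B) is [m, n], then [m, n] @ [n, p]
    left_first = matmul_flops(m, k, n) + matmul_flops(m, n, p)
    # (B*C) is [k, p], then [m, k] @ [k, p]
    right_first = matmul_flops(k, n, p) + matmul_flops(m, k, p)
    order = "(A*B)*C" if left_first <= right_first else "A*(B*C)"
    return order, left_first, right_first


# LoRA-like case: activations [16, 4096], then a rank-8 pair [4096, 8] and [8, 4096].
print(cheaper_order(m=16, k=4096, n=8, p=4096))
```

With static shapes this is a plain integer comparison. With a dynamic rank or batch size the same inequality has to be proven symbolically, which is where the arithmetic bounds come in.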

As part of the implementation, this commit also adds a utility CollectNonNegativeExpressions, exposed to the Python API as relax.analysis.collect_non_negative_expressions. This utility collects expressions within a StructInfo which must be non-negative, based on the location where they appear. For example, the size of a tensor along each dimension must be non-negative. Unlike the existing defineable_tir_vars_in_struct_info, this will include the N-2 expression in R.Tensor([N-2]).
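The distinction can be illustrated with a toy model in plain Python. This is not TVM code and does not call the utility added here: shapes are lists of strings and integers, and both helper functions are hypothetical stand-ins that only mimic the behaviour described above.

```python
def definable_vars(shape):
    """Toy stand-in: bare symbolic variables that a shape match could bind directly."""
    return [d for d in shape if isinstance(d, str) and d.isidentifier()]


def non_negative_exprs(shape):
    """Toy stand-in: every symbolic extent, since a tensor extent must be >= 0."""
    return [d for d in shape if isinstance(d, str)]


# Toy analogue of R.Tensor([N-2, M, 16]); the static extent 16 needs no proof.
shape = ["N-2", "M", 16]
print(definable_vars(shape))      # only the bare variable
print(non_negative_exprs(shape))  # also includes the compound expression N-2
```

Knowing that `N-2 >= 0` gives the analyzer the fact `N >= 2`, which a scan for bare variables alone would not provide.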

@Lunderberg
Contributor Author

Lunderberg commented Feb 16, 2024

The unit test tests/python/relax/test_transform_adjust_matmul_order.py::TestRHSPermuteDimsWithDynamicBatch is currently failing for this PR. To determine whether (A*B)*C or A*(B*C) results in fewer operations, the transform must make use of the improved ConstIntBounds in #16588. After that PR lands, TestRHSPermuteDimsWithDynamicBatch should pass.
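As a rough illustration of why bounds matter here, the sketch below proves an ordering from per-dimension integer intervals. It is a deliberately crude model and not TVM's ConstIntBound analysis: each dimension is a `(lo, hi)` pair with `lo >= 1`, the function names are hypothetical, and repeated occurrences of the same variable are bounded independently, so the check can fail to prove an ordering that does in fact hold.

```python
import math


def prod_bounds(*dims):
    """Bounds on a product of positive extents, each given as (lo, hi)."""
    lo = math.prod(d[0] for d in dims)
    hi = math.prod(d[1] for d in dims)
    return lo, hi


def left_first_provably_cheaper(m, k, n, p):
    """True if (A*B)*C costs no more than A*(B*C) for every size within the bounds."""
    # Upper bound on the cost of (A*B)*C: m*k*n + m*n*p
    left_hi = prod_bounds(m, k, n)[1] + prod_bounds(m, n, p)[1]
    # Lower bound on the cost of A*(B*C): k*n*p + m*k*p
    right_lo = prod_bounds(k, n, p)[0] + prod_bounds(m, k, p)[0]
    return left_hi <= right_lo


hidden = (4096, 4096)
rank = (1, 16)  # dynamic LoRA rank with a known upper bound
print(left_first_provably_cheaper((1, 64), hidden, rank, hidden))        # bounded batch
print(left_first_provably_cheaper((1, math.inf), hidden, rank, hidden))  # unbounded batch
```

With a bounded batch the ordering is provable. With an unbounded batch this naive check cannot prove anything, which is the kind of case where tighter bounds are needed.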

@Lunderberg Lunderberg force-pushed the relax_adjust_matmul_order_with_permute_dims branch from 7dec543 to 517a93d Compare March 25, 2024 19:33
@Lunderberg
Contributor Author

Rebased on top of #16735, all new unit tests passing.

@masahi masahi merged commit 0dfc5f9 into apache:main May 13, 2024
@Lunderberg Lunderberg deleted the relax_adjust_matmul_order_with_permute_dims branch May 13, 2024 20:17