
Fix MatmulOp IterDomain mapping#2246

Merged
jacobhinkle merged 6 commits into main from matmul_op_id_mapping on May 16, 2024

Conversation

@jacobhinkle (Collaborator) commented May 14, 2024

This PR does the following:

  1. Adds MatmulOp to ir_utils::isTvOp so that its IterDomains will be automatically propagated by IdModel.
  2. Updates the tests to check that all non-Broadcast axes are properly mapped by IdModel through the MatmulOp.
  3. Changes the output of MatmulOp to have an IterType::Reduction axis in the last position of its root domain to represent the K dimension. This change was motivated by needing a way to have both operand K dimensions exact mapped together, as they would be if the op were translated to a mul+sum+cast.
  4. Updates the matmul op to translate trivial cases where K=1 to simple multiply+cast patterns.

Fixes #1707. In fact, that test was actually fixed by #2175, but the test validation was failing because isTvOp was not picking up the matmul as a reduction.

This also adds testing of exact mapping to the node tests (WIP).
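As a quick sanity check of the K=1 translation in item 4, the pattern can be illustrated outside nvFuser with a NumPy sketch (not nvFuser code): when K=1 the reduction sums a single term, so the matmul degenerates to a broadcast elementwise multiply.

```python
import numpy as np

# A is {M, K} with K = 1; B is {K, N} with K = 1.
a = np.array([[1.0], [2.0], [3.0]])  # shape (3, 1)
b = np.array([[4.0, 5.0]])           # shape (1, 2)

# With K = 1, the reduction over K sums exactly one element, so matmul
# is equivalent to a broadcast multiply (the multiply+cast pattern).
assert np.allclose(a @ b, a * b)
```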
out_domain[idx] = ops::newOutputIterDomain(input_ids);
}

out_domain[ndims_out - 1] = IterDomainBuilder(mapping_a.back())
@jacobhinkle (Author):

I placed the rK dimension last in the output, just because that makes it a bit easier to handle the many cases we can encounter for matmul. Note that this does not need to match the position of the K dimension in a translated MmaOp.
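The translated mul+sum pattern alluded to here can be sketched in NumPy (for illustration only, not nvFuser code): broadcasting the operands to {M, N, K} naturally puts K in the last position before the reduction, mirroring the rK placement in the output root domain.

```python
import numpy as np

m, k, n = 2, 4, 3
a = np.arange(m * k, dtype=float).reshape(m, k)  # {M, K}
b = np.arange(k * n, dtype=float).reshape(k, n)  # {K, N}

# Broadcast both operands to {M, N, K}: K lands in the last position.
prod = a[:, None, :] * b.T[None, :, :]   # shape (m, n, k)
out = prod.sum(axis=-1)                  # reduce over the trailing K axis (rK)
assert np.allclose(out, a @ b)
```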

Collaborator:

Should we use both mapping_a and mapping_b here?

@jacobhinkle (Author):

Good question. I was assuming the K dimension would be IterType::Iteration in both operands. They will be exact mapped, so they should be equivalent to one another; I just took the ID from A.

Collaborator:

What happens if one is a Symbolic tensor and the other is Concrete?

@jacobhinkle (Author):

Good point. I updated ops::newOutputIterDomain to have a std::optional<IterType> force_iter_type argument that I am now using to create the reduction domain. That way, if the A ID is Symbolic but the B ID is not, then we'll use the B extent.
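The extent-resolution rule described here can be restated as a minimal sketch (hypothetical Python stand-ins, not nvFuser code): prefer a concrete extent whenever one of the exact-mapped input IDs is Symbolic.

```python
# Hypothetical stand-ins for IterDomain extents: None means Symbolic
# (extent not yet known), an int means a Concrete extent.
def resolve_extent(extent_a, extent_b):
    # Mirror the behavior described above: normally take A's extent,
    # but fall back to B's when A is Symbolic and B is Concrete.
    if extent_a is None:
        return extent_b
    return extent_a

assert resolve_extent(None, 32) == 32  # A Symbolic, B Concrete -> use B
assert resolve_extent(16, None) == 16  # A Concrete -> use A
assert resolve_extent(16, 16) == 16    # exact mapped, extents agree
```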

" and ",
tv_b->dtype());

// Check for K=1 i.e. reduction of broadcast. In these cases we don't need a
@jacobhinkle (Author):

I considered moving this translation to mul to another PR, but it is relatively compact, and it is difficult to work around if we want to exhaustively check ID mappings with all different combinations of inputs.

Collaborator:

Thanks for adding these cases!


// Input A to matmul: {*, M, K}
// Input B to matmul: {*, K, N}
auto kpos = input_role == MatmulRole::INPUT_A ? inp_size - 1 : inp_size - 2;
@jacobhinkle (Author):

Changes in this file are to accommodate the new Reduction output domain by mapping it to the K position in each operand.
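The kpos computation in the excerpt above can be restated as a tiny Python sketch (hypothetical, for illustration): K is the last axis of A ({*, M, K}) and the second-to-last axis of B ({*, K, N}).

```python
def k_position(input_role, inp_size):
    # Input A to matmul: {*, M, K} -> K is the last axis.
    # Input B to matmul: {*, K, N} -> K is the second-to-last axis.
    return inp_size - 1 if input_role == "INPUT_A" else inp_size - 2

assert k_position("INPUT_A", 2) == 1  # {M, K}
assert k_position("INPUT_B", 2) == 0  # {K, N}
assert k_position("INPUT_A", 3) == 2  # batched {B, M, K}
assert k_position("INPUT_B", 3) == 1  # batched {B, K, N}
```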

ReductionAxisIsOne,
ATenNodesParametrizedTest,
testing::Values(std::make_tuple(Sizes({m, 1}), Sizes({1, n}))));
testing::Combine(
@jacobhinkle (Author):

Off-topic: it's not needed here, but for new tests I wonder if we could combine these kinds of parametrizations. For example, since we use k=32, we could parametrize the test with a boolean flag k_is_one; if the flag is true, we would replace any 32s in the shapes with 1 before running the test. That would avoid repeating all the inputs for each case, but it would make it harder to filter to only the concrete-K tests, since the tests are simply numbered. It would be great if the parameters were reflected in the test name, as they are in pytest...
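The k_is_one idea floated above could look roughly like this (hypothetical Python sketch; the helper name and the k=32 default are illustrative, not from the codebase):

```python
def apply_k_is_one(shapes, k=32):
    # When the hypothetical k_is_one flag is set, replace every
    # occurrence of the K extent in the shape tuples with 1 before
    # the test runs, turning concrete-K cases into K=1 cases.
    return [tuple(1 if d == k else d for d in shape) for shape in shapes]

# A {M, K} x {K, N} pair collapses to the reduction-of-broadcast case.
assert apply_k_is_one([(16, 32), (32, 8)]) == [(16, 1), (1, 8)]
```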

Collaborator:

We could add another flag and replace the value of K within the tests. I am not sure if there is a way to replace the values before they are passed to a test, i.e. so that the test directly sees K=1. I'll look into it.

- // > 1D.
- auto ndims_out = std::max(ndims_a, ndims_b);
+ // > 1D, but with 1 additional IterType::Reduction axis rK.
+ auto ndims_out = std::max(ndims_a, ndims_b) + 1;
@jacobhinkle (Author):

Changes in this file are for adding a new Reduction output domain to facilitate easier mapping of K axes.
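The rank computation in the diff above can be restated as a tiny sketch (hypothetical Python, for illustration): the output rank follows the broadcast of the operand ranks, plus one appended IterType::Reduction axis rK.

```python
def matmul_out_ndims(ndims_a, ndims_b):
    # Output rank = broadcasted operand rank, plus one extra
    # IterType::Reduction axis rK appended for the K dimension.
    return max(ndims_a, ndims_b) + 1

assert matmul_out_ndims(2, 2) == 3  # {M, N, rK}
assert matmul_out_ndims(4, 2) == 5  # batched A: {B1, B2, M, N, rK}
```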

@jacobhinkle jacobhinkle changed the title from "[WIP] Add MatmulOp to ir_utils::isTvOp" to "Fix MatmulOp IterDomain mapping" on May 15, 2024
@jacobhinkle jacobhinkle requested a review from Priya2698 May 15, 2024 12:43
@jacobhinkle jacobhinkle marked this pull request as ready for review May 15, 2024 12:43
@jacobhinkle (Author):

!build --diff

tv_b->dtype());

// Check for K=1 i.e. reduction of broadcast. In these cases we don't need a
// matmul so we translate it to a multiplication+cast
Collaborator:

Can we add a comment at the matmul definition, or in the header, enumerating the different cases and how they are handled, for an easier summary?

@Priya2698 (Collaborator) left a review:

LGTM overall.

My main question is about using IterDomains from both A and B for the mapping. While they will be exact mapped, what happens if we have one Symbolic and one Concrete tensor (although rare, and we do not test this)? If we use both, the newOutputIterDomain call will resolve the extents using one of the tensors. Wdyt?

@jacobhinkle (Author):

I agree about symbolic domains, and I think we might need to handle this at concretization. Namely, if K is a broadcast dim then we should translate to a mul op before segmentation. I'll work on an example for a case like that.

This lets us create a new Reduction IterDomain from the inputs.
Typically this will just use the mapping_a.back() extent, but if that ID
is Symbolic and the B ID is not, then it will use the B extent instead.
@jacobhinkle jacobhinkle mentioned this pull request May 16, 2024
@jacobhinkle jacobhinkle merged commit a6ce3e1 into main May 16, 2024
@jacobhinkle jacobhinkle deleted the matmul_op_id_mapping branch May 16, 2024 13:11

Successfully merging this pull request may close these issues.

Segmenter failed to propose the right segment to MatmulScheduler.
