Evaluate Matmul+Bias #1993

Merged
Priya2698 merged 12 commits into main from pm/matmul_single_eval
Apr 2, 2024

Conversation

@Priya2698 Priya2698 (Collaborator) commented Mar 25, 2024

Adds ATen evaluation for Matmul and Matmul + Bias. Building on PR #1921, when evaluating a castOp we now look back for a preceding MmaOp and evaluate the two together.

Issue #1775.
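To illustrate the idea, here is a minimal numpy sketch (not nvFuser code; `evaluate_cast` and its arguments are hypothetical names): when the evaluator reaches a cast whose producer is an MmaOp, it can evaluate matmul + cast as a single library call, accumulating in fp32 and then applying the cast, instead of evaluating the two ops separately.

```python
import numpy as np

def evaluate_cast(cast_dtype, producer_op, producer_inputs):
    # Hypothetical evaluator hook: if the cast's producer is a matmul
    # ("MmaOp"), evaluate matmul and cast together in one step.
    if producer_op == "MmaOp":
        a, b = producer_inputs
        # MmaOp accumulates in fp32; the cast then converts the result.
        out = np.matmul(a.astype(np.float32), b.astype(np.float32))
        return out.astype(cast_dtype)
    raise NotImplementedError(producer_op)

a = np.ones((2, 3), dtype=np.float16)
b = np.ones((3, 4), dtype=np.float16)
out = evaluate_cast(np.float16, "MmaOp", (a, b))
```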

@Priya2698 Priya2698 (Collaborator Author) commented:

!build

@Priya2698 Priya2698 marked this pull request as ready for review March 25, 2024 21:03
@wujingyue wujingyue (Collaborator) left a comment


LGTM with comments! Make sure you address Jacob's comments as well.

// 2. The inputs to MmaOp are broadcast along the last dim for the first operand
// and the first dim for the second operand.
// The inputs of MmaOp will be [M, K, 1] x [1, K, N].
// Additionally, the inputs to the MmaOp should be of the `expected_input_dtype`.
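The shapes described in this comment can be checked with a small numpy sketch (assumed shapes for illustration, not nvFuser code): the broadcast dims are trivial, so evaluating the MmaOp amounts to dropping them and performing a plain `[M, K] @ [K, N]` matmul.

```python
import numpy as np

M, K, N = 4, 8, 5
a = np.random.rand(M, K).astype(np.float16)
b = np.random.rand(K, N).astype(np.float16)

# Operands as MmaOp sees them: a broadcast dim appended to the first
# operand and prepended to the second.
a_bcast = a[:, :, np.newaxis]   # [M, K, 1]
b_bcast = b[np.newaxis, :, :]   # [1, K, N]

# Evaluation drops the broadcast dims and does an ordinary matmul.
out = np.matmul(np.squeeze(a_bcast, axis=2),
                np.squeeze(b_bcast, axis=0))
```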
Collaborator:

For posterity, can you explain why this is required?

Collaborator Author:

There is a comment on line 1351 reasoning about the dtypes. Do you want to move it upwards?

Val* bcast_bias = binary->input(1);
// Bias is cast to fp32 and broadcast from shape [M] to [M, 1] in
// biasEpilogue.
NVF_ERROR(
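The bias transformation in the excerpt above can be sketched in numpy (assumed shapes for illustration, not nvFuser code): casting the `[M]` bias to fp32 and broadcasting it to `[M, 1]` lets it add per row of the `[M, N]` matmul output.

```python
import numpy as np

M, N = 4, 5
matmul_out = np.zeros((M, N), dtype=np.float32)  # stand-in MmaOp output
bias = np.arange(M, dtype=np.float16)            # shape [M]

# As in biasEpilogue: cast to fp32, broadcast [M] -> [M, 1] so the
# bias is added along each row of the [M, N] matmul output.
bcast_bias = bias.astype(np.float32)[:, np.newaxis]  # [M, 1]
out = matmul_out + bcast_bias                        # [M, N]
```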
Collaborator:

I'm not sure about NVF_CHECK vs NVF_ERROR. But this check and several others down the road seem to be of the same nature as those in verifyMmaOpForEvaluation. Should all of them use the same macro?

Collaborator Author:

I am using NVF_ERROR for all. I will make sure to replace any NVF_CHECK with NVF_ERROR. My reasoning is that this case should not occur under our current matmul assumptions.

@Priya2698 Priya2698 (Collaborator Author) commented:

!build

@Priya2698 Priya2698 (Collaborator Author) commented:

!build


Val* bias = nullptr;
// Case 2: Matmul + Bias
if (MmaOpUtils::matchMatmulBiasCast(
Collaborator:

I am not sure how to address this here, but what worries me is that torch.nn.functional.linear computes x @ A.T + bias. So unless we transpose A, we won't have [M, K, 1] x [1, K, N].

https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html

We may need to write/use a routine to detect the shape.
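The concern can be demonstrated with a numpy sketch (shapes chosen for illustration): `torch.nn.functional.linear` stores its weight as `[N, K]`, so the second operand only matches the `[M, K] @ [K, N]` pattern expected above after a transpose.

```python
import numpy as np

M, K, N = 4, 8, 5
x = np.random.rand(M, K)
A = np.random.rand(N, K)        # weight layout used by linear: [N, K]
bias = np.random.rand(N)

# Without a transpose the shapes do not line up with [M, K] @ [K, N]:
try:
    np.matmul(x, A)             # [4, 8] @ [5, 8] -> ValueError
except ValueError:
    pass

# With the transpose we recover the expected pattern:
out = np.matmul(x, A.T) + bias  # [M, K] @ [K, N] + [N] -> [M, N]
```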

Collaborator Author:

I'll address this in the next PR. The assumed shapes are the same as those in matmul, so I was assuming that the transpose would already have been done when needed. We can discuss more today.

@Priya2698 Priya2698 merged commit 9c2359f into main Apr 2, 2024
@Priya2698 Priya2698 deleted the pm/matmul_single_eval branch April 2, 2024 17:09