!build
wujingyue left a comment
LGTM with comments! Make sure you address Jacob's as well.
csrc/ir/utils.cpp
Outdated
// 2. The inputs to MmaOp are broadcasted as the last dim for the first operand
// and the first dim for the second operand.
// The inputs of MmaOp will be [M, K, 1] x [1, K, N].
// Additionally, the inputs to the MmaOp should be of the `expected_input_dtype`.
For posterity, can you explain why this is required?
There is a comment on line 1351 reasoning about the dtypes. Do you want to move it upwards?
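For context, the broadcast layout in the quoted comment can be sketched in NumPy: broadcasting the operands to [M, K, 1] and [1, K, N] and reducing over K reproduces a plain matmul. This is an illustrative sketch, not nvFuser code:

```python
import numpy as np

M, K, N = 4, 8, 3
a = np.random.rand(M, K).astype(np.float32)
b = np.random.rand(K, N).astype(np.float32)

# First operand gets a broadcast as the last dim:  [M, K] -> [M, K, 1]
# Second operand gets a broadcast as the first dim: [K, N] -> [1, K, N]
a_bcast = a[:, :, np.newaxis]   # [M, K, 1]
b_bcast = b[np.newaxis, :, :]   # [1, K, N]

# Elementwise multiply broadcasts to [M, K, N]; summing over the
# K axis (axis=1) yields the usual [M, N] matmul result.
mma_like = (a_bcast * b_bcast).sum(axis=1)

assert np.allclose(mma_like, a @ b, atol=1e-5)
```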
csrc/ir/utils.cpp
Outdated
Val* bcast_bias = binary->input(1);
// Bias is casted to fp32 and broadcasted from shape [M,] to [M, 1] in
// biasEpilogue.
NVF_ERROR(
I'm not sure about NVF_CHECK vs NVF_ERROR. But this check and several others down the road seem to be of the same nature as those in verifyMmaOpForEvaluation. Should all of them use the same macro?
I am using NVF_ERROR for all. I will make sure to replace any NVF_CHECK with NVF_ERROR. My reasoning is that this is a case that should not have occurred under our current matmul assumptions.
!build
!build
Val* bias = nullptr;
// Case 2: Matmul + Bias
if (MmaOpUtils::matchMatmulBiasCast(
I am not sure how to address this here, but what worries me is that torch.nn.functional.linear computes x @ A^T + bias. So unless we transpose A, we won't have [M, K, 1] x [1, N, K].
https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html
We may need to write/use a routine to detect the shape.
I'll address this in the next PR. The assumed shapes are the same as those in matmul, so I was assuming that the transpose would already have been done when needed. We can discuss more today.
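To make the transpose concern concrete: per the linked documentation, torch.nn.functional.linear computes y = x @ weight^T + bias, so the weight arrives as [N, K] and needs a transpose before the [M, K, 1] x [1, K, N] broadcast pattern applies. A NumPy sketch of the documented semantics (illustrative only):

```python
import numpy as np

M, K, N = 4, 8, 3
x = np.random.rand(M, K).astype(np.float32)
weight = np.random.rand(N, K).astype(np.float32)  # linear stores weight as [N, K]
bias = np.random.rand(N).astype(np.float32)

# torch.nn.functional.linear semantics: y = x @ weight.T + bias
linear_out = x @ weight.T + bias

# The broadcast-matmul pattern only matches after transposing weight
# to [K, N], giving operands [M, K, 1] x [1, K, N].
a_bcast = x[:, :, np.newaxis]         # [M, K, 1]
b_bcast = weight.T[np.newaxis, :, :]  # [1, K, N]
bcast_out = (a_bcast * b_bcast).sum(axis=1) + bias

assert np.allclose(linear_out, bcast_out, atol=1e-5)
```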
Adds ATen evaluation for Matmul and Matmul + Bias. Building on PR #1921, when evaluating a castOp, we look back to see if there is a preceding MmaOp and evaluate them together. Issue #1775.
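The evaluation pattern described above (evaluating the MmaOp and the trailing castOp together) can be emulated in NumPy, under the assumption that the MmaOp accumulates in fp32 and the bias epilogue casts and broadcasts the bias as in the quoted comment. This is an illustrative sketch, not the actual nvFuser implementation:

```python
import numpy as np

M, K, N = 4, 8, 3
a = np.random.rand(M, K).astype(np.float16)
b = np.random.rand(K, N).astype(np.float16)
bias = np.random.rand(M).astype(np.float16)

# The MmaOp accumulates in fp32 ...
mma_out = a.astype(np.float32) @ b.astype(np.float32)

# ... the bias is cast to fp32 and broadcast from [M,] to [M, 1] ...
biased = mma_out + bias.astype(np.float32)[:, np.newaxis]

# ... and the trailing castOp brings the result back to fp16.
result = biased.astype(np.float16)

assert result.dtype == np.float16
```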