ATen scheduler for the new Matmul/LinearOp IR nodes#2209
Conversation
jacobhinkle
left a comment
There was a problem hiding this comment.
Looking great. Sprinkle in a few tests once you merge #2175 and we'll be on our way.
2a2786e to
314ab22
Compare
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
314ab22 to
1ad32ea
Compare
|
!build |
jacobhinkle
left a comment
There was a problem hiding this comment.
Looks good overall. I am just slightly concerned about removing stuff from PairwiseRootDomainMap::map.
| //! Define a schedule table to loop over all the heuristics in priority order. | ||
| constexpr std::array<ScheduleHeuristic, 8> all_heuristics_in_priority_order = { | ||
| constexpr std::array<ScheduleHeuristic, 9> all_heuristics_in_priority_order = { | ||
| ScheduleHeuristic::ExprEval, |
There was a problem hiding this comment.
Should NoOp come before ExprEval?
There was a problem hiding this comment.
Some cases get accepted by NoOp scheduler, which is why I prioritized ExprEval scheduler.
We may need to change the heuristics of NoOp if we want to switch the order.
There was a problem hiding this comment.
Oh! Thanks for mentioning that. Does NoOp scheduler accept the cases where you have a single scalar output? Because it seems to me that it would do so based on this code:
Lines 341 to 359 in 8c18701
size_zero would be false in the case that root_dom.empty(). However, the code below might not properly handle zero-dimensional outputs: Fuser/csrc/scheduler/no_op.cpp
Lines 71 to 80 in 8c18701
There was a problem hiding this comment.
[ FAILED ] 6 tests, listed below:
[ FAILED ] ATenNodesParametrizedTest.MatmulNodeConcrete/2, where GetParam() = ({ 32 }, { 32, 1 })
[ FAILED ] ATenNodesParametrizedTest.MatmulNodeConcrete/8, where GetParam() = ({ 1, 32 }, { 32 })
[ FAILED ] ATenNodesParametrizedTest.MatmulNodeConcrete/10, where GetParam() = ({ 1, 32 }, { 32, 1 })
[ FAILED ] ATenNodesParametrizedTest.MatmulNodeSymbolic/2, where GetParam() = ({ 32 }, { 32, 1 })
[ FAILED ] ATenNodesParametrizedTest.MatmulNodeSymbolic/8, where GetParam() = ({ 1, 32 }, { 32 })
[ FAILED ] ATenNodesParametrizedTest.MatmulNodeSymbolic/10, where GetParam() = ({ 1, 32 }, { 32, 1 })
It is likely because there no reductions identified since we use ATen, and all the dimensions in the output are broadcast dimensions. So the cases where M/N = 1 get picked by NoOp
jacobhinkle
left a comment
There was a problem hiding this comment.
LGTM after tests pass and you add some broadcasts in test.
|
!build |
|
🚀 |
ExprEvalSchedulerthat accepts the MatmulOp and LinearOp (next PR) for ATen evaluation.eagerMatmulAPI is renamed and replaces the existingmatmulAPI.fd.ops.matmulnow creates aMatmulOp(except in a few special cases such as scalar dot product, for eg:[K] x [K].Issue #2149, #2092.