Closed
Description
Motivation: Using MatmulOp/LinearOp in multi-GPU work.
Issue: We currently do not check for the presence of a DID axis in the inputs to evaluate for any of the nodes.
Example: Consider A (M x K) and B (K x N) sharded across 2 devices, with rank 0 computing the first M/2 rows of C (M x N):
In the current approach, the inputs to evaluate will be A = [1 x M/2 x K] and B = [2 x K x N/2] (assuming there is an all-gather step to get the complete B matrix, so that we compute the first half of the output instead of only the first quadrant). The output will then be C = [2 x M/2 x N/2], since the leading DID axis is treated as a batch dimension.
Correct computation: reorder and reshape B to [K, N] before the matmul, so that the output is C = [1 x M/2 x N].
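
The shape mismatch can be reproduced outside the fusion with a minimal NumPy sketch. This is only an illustration: the sizes, the variable names, and the use of np.matmul as a stand-in for what evaluate does are assumptions, not the actual implementation.

```python
import numpy as np

# Hypothetical sizes; D is the number of devices (the DID axis extent).
M, K, N, D = 4, 3, 6, 2

rng = np.random.default_rng(0)
A_full = rng.standard_normal((M, K))
B_full = rng.standard_normal((K, N))
C_ref = A_full @ B_full  # unsharded reference result, [M, N]

# Rank 0's shard of A with a leading DID axis of extent 1: [1, M/2, K].
A_local = A_full[: M // D][None]

# B after the all-gather: column shards stacked along the DID axis, [2, K, N/2].
B_gathered = np.stack(np.split(B_full, D, axis=1))

# Current behaviour (assumed): the DID axis is broadcast as a batch dimension.
C_naive = np.matmul(A_local, B_gathered)  # [2, M/2, N/2]

# Proposed fix: move the DID axis next to the N/2 axis and merge them,
# which recovers B as [K, N] before the matmul.
B_merged = B_gathered.transpose(1, 0, 2).reshape(K, N)
C_fixed = np.matmul(A_local, B_merged)  # [1, M/2, N]

print(C_naive.shape, C_fixed.shape)
```

In this sketch the naive result holds the right values but in the wrong layout: concatenating its two slices along the last axis gives the same matrix as C_fixed[0], which is the first M/2 rows of the reference C.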
CC: @wujingyue @cowanmeg. Please correct me if my understanding is incorrect.