DIDx-aware ATen evaluation for matmul and linear #2372

@Priya2698


Motivation: using MatmulOp/LinearOp in multi-GPU workloads.

Issue: currently, no node's `evaluate` checks whether its inputs contain a DID axis.

Example: consider A (M x K) and B (K x N) sharded across 2 devices, with rank 0 computing the first M/2 rows of C (M x N).
In the current approach, the inputs to `evaluate` are A = [1 x M/2 x K] and B = [2 x K x N/2] (assuming an all-gather step collects the complete B matrix, so that rank 0 computes the first half of the output rather than only the first quadrant). Because the leading DID axes are broadcast as batch dimensions, the output is C = [2 x M/2 x N/2] instead of the intended [1 x M/2 x N].
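A minimal sketch of the shape problem, using NumPy as a stand-in for ATen (`np.matmul` broadcasts leading batch axes the same way `torch.matmul` does). The toy sizes, the random data, and modeling the DID axis as a plain leading axis are assumptions for illustration only; this is not nvFuser's evaluation code.

```python
import numpy as np

# Toy sizes (assumed for illustration); M and N must split evenly across D devices.
M, K, N = 4, 3, 6
D = 2  # number of devices

rng = np.random.default_rng(0)
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))

# Rank 0's shard of A: the first M/2 rows, with a leading DID axis of size 1.
a_local = A[: M // D][np.newaxis]               # [1, M/2, K]

# B sharded along N, then all-gathered: the shards are stacked on a leading DID axis.
b_gathered = np.stack(np.split(B, D, axis=1))   # [D, K, N/2]

# Passing both tensors straight to matmul treats the DID axis as a batch axis,
# so the size-1 axis of a_local is broadcast against the size-D axis of b_gathered.
c_naive = np.matmul(a_local, b_gathered)
print(c_naive.shape)  # (2, 2, 3), i.e. [D, M/2, N/2] rather than [1, M/2, N]
```

Each `c_naive[d]` is one M/2 x N/2 block of the desired rows, but the blocks end up on a batch axis instead of being laid out along N.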

Correct computation: reorder B from [2 x K x N/2] to [K x 2 x N/2] and reshape it to [K x N] before the matmul, which yields C = [1 x M/2 x N].
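The proposed fix can be sketched the same way, again with NumPy standing in for ATen and toy sizes chosen for illustration (assumptions, not nvFuser code). In ATen the equivalent reorder would be `permute(1, 0, 2)` followed by `reshape(K, N)`.

```python
import numpy as np

# Toy sizes (assumed for illustration).
M, K, N = 4, 3, 6
D = 2  # number of devices

rng = np.random.default_rng(0)
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))

a_local = A[: M // D][np.newaxis]               # [1, M/2, K]
b_gathered = np.stack(np.split(B, D, axis=1))   # [D, K, N/2]

# Reorder [D, K, N/2] -> [K, D, N/2] so the device axis sits directly outside
# the N/2 axis it partitions, then merge the two back into the full N axis.
b_full = b_gathered.transpose(1, 0, 2).reshape(K, N)   # [K, N]

c_local = np.matmul(a_local, b_full)            # [1, M/2, N]

# Rank 0 now holds exactly the first M/2 rows of C = A @ B.
assert np.allclose(b_full, B)
assert np.allclose(c_local[0], (A @ B)[: M // D])
print(c_local.shape)  # (1, 2, 6)
```

Putting the device axis immediately before the N/2 axis is what makes the reshape stitch the gathered column blocks back together in their original order; reshaping without the reorder would interleave K and N elements incorrectly.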

CC: @wujingyue @cowanmeg. Please correct me if my understanding is incorrect.
