[FEATURE]: Meta information patch for torch.matmul #2582

@Cypher30

Describe the feature

Meta information, namely the compute cost and memory cost of each operation during training, is the key input our auto-parallel system needs to generate a good strategy. Much of the groundwork is already in place, including the MetaTensor mechanism (#1515), which gives us access to the lower-level operations executed during training. We then built a profiler (#1587) on top of this mechanism. However, we found that the profiler can be time-consuming when it has to inspect every operation in the model. We therefore decided to use MetaTensor to retrieve the ATen-level graph and turn the meta information calculation into static formulas, so that we no longer need to call torch.autograd.backward every time we want to profile a PyTorch operation. So far, this has allowed us to combine the SPMD solver and the auto activation checkpoint solver on ResNet (#2258).

To combine the SPMD solver and the auto activation checkpoint solver on transformer-based networks as well, we need to add a meta information patch for the torch.matmul operation.
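To illustrate what a static formula for torch.matmul could look like, here is a minimal sketch in plain Python that derives the output shape, FLOP counts, and output size from the input shapes alone. The names `matmul_meta` and `broadcast_batch` are hypothetical and not part of any existing API; the sketch assumes a multiply-add counts as 2 FLOPs, that both inputs require gradients (so the backward pass is roughly two matmuls of the same size), and a 4-byte dtype by default.

```python
from math import prod
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def broadcast_batch(a: Shape, b: Shape) -> Shape:
    """Broadcast two tuples of batch dimensions (hypothetical helper)."""
    out = []
    # Walk from the trailing dimension backwards; a missing dim behaves like 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"batch dims {a} and {b} are not broadcastable")
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def matmul_meta(a: Shape, b: Shape, dtype_bytes: int = 4) -> Dict[str, object]:
    """Estimate meta information for matmul(a, b) from shapes only.

    Assumptions (not taken from the issue): 2 FLOPs per multiply-add,
    both inputs require grad, and every element takes `dtype_bytes` bytes.
    """
    if len(a) == 0 or len(b) == 0:
        raise ValueError("matmul operands must have at least one dimension")

    # A 1-D operand is treated as a matrix; the added dim is dropped afterwards.
    a2 = (1,) + a if len(a) == 1 else a
    b2 = b + (1,) if len(b) == 1 else b

    m, k = a2[-2], a2[-1]
    k_b, n = b2[-2], b2[-1]
    if k != k_b:
        raise ValueError(f"inner dims do not match: {k} vs {k_b}")

    batch = broadcast_batch(a2[:-2], b2[:-2])
    out_shape = batch + (m, n)
    if len(a) == 1:
        out_shape = out_shape[:-2] + out_shape[-1:]  # drop the prepended row dim
    if len(b) == 1:
        out_shape = out_shape[:-1]  # drop the appended column dim

    fwd_flops = 2 * prod(batch) * m * k * n
    # grad_a and grad_b each cost about one matmul of the same size.
    bwd_flops = 2 * fwd_flops
    out_bytes = prod(out_shape) * dtype_bytes

    return {
        "out_shape": out_shape,
        "fwd_flops": fwd_flops,
        "bwd_flops": bwd_flops,
        "out_bytes": out_bytes,
    }


# Example: a batched activation of shape (8, 16, 32) times a (32, 64) weight.
meta = matmul_meta((8, 16, 32), (32, 64))
print(meta)
```

Because everything is computed from shapes, no tensor is allocated and no backward pass is run. A real patch would additionally need to account for the buffers saved for backward and for the gradient reduction over broadcast dimensions, which this sketch leaves out.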

Labels: enhancement (New feature or request)