
[autoparallel] modify construct chain in auto activation checkpoint solver rotor#2254

Merged
super-dainiu merged 53 commits into hpcaitech:debug/ckpt-autoparallel from Cypher30:hotfix/fix_construct_chain
Jan 2, 2023
Conversation

Contributor

@Cypher30 Cypher30 commented Jan 2, 2023

What's New?

In this PR, I modify `_extract_node_info` in `_construct_chain` of `CheckpointSolverRotor` so that it correctly computes the memory peak of the forward phase.

The new forward peak is calculated with `xbar`. As we iterate over the list of `torch.fx.Node`, we first update `xbar` with the buffer memory (e.g. the running mean and running variance in batch normalization) and with `calculate_fwd_out()`, which computes the true output, i.e. the output consumed by users of the node. We then update the memory peak statistic with `max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))`, where `cls._extract_unused_output()` replaces the former `cls._extract_ftmp()`. I modified it to operate on a single `torch.fx.Node` rather than only the last one in the list, so that it accounts for the discarded output of every `torch.fx.Node` inside the linearized node.
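The update rule above can be sketched as follows. This is a minimal illustration, not the actual ColossalAI implementation: the node-info dicts and keys (`buffer_mem`, `fwd_out`, `fwd_mem_tmp`, `unused_out`) are hypothetical stand-ins for the `torch.fx.Node` metadata the solver reads.

```python
def compute_forward_peak(nodes):
    """Sketch of the forward-peak estimate described in the PR.

    `nodes` is a linearized list of dicts standing in for torch.fx.Node
    metadata (hypothetical keys, not the real node.meta layout):
      buffer_mem  -- persistent buffers (e.g. BN running stats)
      fwd_out     -- the "true" output kept for users of the node
      fwd_mem_tmp -- temporary memory live only during this node
      unused_out  -- outputs produced but discarded (no users)
    """
    xbar = 0          # memory that stays alive after each node
    fwd_mem_peak = 0  # running maximum over the forward pass
    for n in nodes:
        # xbar accumulates buffers plus the output actually consumed
        xbar += n.get("buffer_mem", 0) + n.get("fwd_out", 0)
        # the peak additionally counts temporaries and discarded
        # outputs of *this* node, mirroring
        # max(fwd_mem_peak, xbar + fwd_mem_tmp + unused_output)
        fwd_mem_peak = max(
            fwd_mem_peak,
            xbar + n.get("fwd_mem_tmp", 0) + n.get("unused_out", 0),
        )
    return xbar, fwd_mem_peak
```

Note that because every node contributes its own `unused_out` term, a large discarded intermediate anywhere in the linearized node raises the peak, which is exactly what the old last-node-only `_extract_ftmp` missed.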

Cypher30 and others added 30 commits July 14, 2022 16:07
@Cypher30 Cypher30 requested a review from super-dainiu January 2, 2023 02:34
@super-dainiu super-dainiu changed the base branch from main to debug/ckpt-autoparallel January 2, 2023 08:25
@super-dainiu super-dainiu merged commit ac37399 into hpcaitech:debug/ckpt-autoparallel Jan 2, 2023