[hotfix/rotor] fix variable names #1597

Merged
FrankLeeeee merged 27 commits into hpcaitech:main from super-dainiu:hotfix/rotor_variable_names
Sep 14, 2022

@super-dainiu
Contributor

What's fixed?

In the previous PR (#1587), we renamed some variables. This hotfix corrects those names.

Testing

image



-class Stage(Enum):
+class Phase(Enum):
Contributor Author

@super-dainiu super-dainiu Sep 14, 2022

Stage should be Phase with respect to RPC phase.

Comment thread colossalai/fx/passes/algorithms/ckpt_solver_chen.py
 setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
 for par in n.all_input_nodes:
-    par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
+    par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
Contributor Author

max is more plausible for this calculation.
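
To illustrate the difference between the two accumulation rules, here is a minimal sketch. It assumes, as one possible reading of this comment, that several consumers of the same producer should not each add their `fwd_mem_in` to the producer's `fwd_mem_out`. `FakeNode` and the two `propagate_*` helpers are hypothetical stand-ins, not ColossalAI or `torch.fx` code.

```python
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node: only `meta` and the inputs matter here."""

    def __init__(self, name, meta, inputs=()):
        self.name = name
        self.meta = dict(meta)
        self.all_input_nodes = list(inputs)


def propagate_sum(nodes):
    # Old rule from the diff: each consumer's fwd_mem_in is added to the producer.
    for n in nodes:
        for par in n.all_input_nodes:
            par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)


def propagate_max(nodes):
    # New rule from the diff: the producer keeps the largest fwd_mem_in seen so far.
    for n in nodes:
        for par in n.all_input_nodes:
            par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))


def build_graph():
    # One producer feeding two consumers that each report fwd_mem_in == 4.
    producer = FakeNode('producer', {'fwd_mem_out': 0})
    consumer_a = FakeNode('consumer_a', {'fwd_mem_in': 4}, [producer])
    consumer_b = FakeNode('consumer_b', {'fwd_mem_in': 4}, [producer])
    return producer, [producer, consumer_a, consumer_b]


producer_sum, nodes = build_graph()
propagate_sum(nodes)
producer_max, nodes = build_graph()
propagate_max(nodes)

print(producer_sum.meta['fwd_mem_out'])  # 8: counted once per consumer
print(producer_max.meta['fwd_mem_out'])  # 4: counted once in total
```

With the sum rule the producer's output memory grows with the number of consumers; with `max` it stays at the size of the largest single consumer input.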

 n.meta['tensor_meta'] = tensor_meta
-n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
+n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
Contributor Author

Reset fwd_mem_out to 0 so that running MetaInfoProp twice does not double fwd_mem_out.
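
The reset works because of how dict unpacking resolves duplicate keys: the last entry wins. The sketch below shows only that mechanism. The `MetaInfo` dataclass here is a hypothetical, reduced stand-in with two fields, and `extend_meta` is an illustrative helper rather than a function from the PR.

```python
from dataclasses import asdict, dataclass


@dataclass
class MetaInfo:
    # Hypothetical, reduced field set used only for this illustration.
    fwd_mem_tmp: int = 0
    fwd_mem_out: int = 0


def extend_meta(meta: dict, meta_info: MetaInfo) -> dict:
    # Later entries win in a dict display, so the trailing literal key
    # overrides `fwd_mem_out` from both `meta` and `asdict(meta_info)`.
    return {**meta, **asdict(meta_info), 'fwd_mem_out': 0}


# `stale_meta` mimics a node's meta after an earlier propagation pass.
stale_meta = {'fwd_mem_out': 4, 'tensor_meta': 'unchanged'}
fresh_meta = extend_meta(stale_meta, MetaInfo(fwd_mem_tmp=2, fwd_mem_out=4))

print(fresh_meta['fwd_mem_out'])  # 0
```

Unrelated keys such as `tensor_meta` pass through untouched, and the original dict is not mutated because a new dict is built.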

Comment thread colossalai/fx/passes/algorithms/ckpt_solver_rotor.py Outdated
Comment thread colossalai/fx/passes/algorithms/utils.py
Comment on lines +410 to +411
print(chain)
print(node_list)
Contributor Author

oops

Comment thread colossalai/fx/profiler/dataflow.py Outdated
Comment on lines +51 to +68
 def is_forward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.FORWARD
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == Phase.FORWARD


 def is_loss(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.LOSS
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == Phase.LOSS


 def is_placeholder(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.PLACEHOLDER
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == Phase.PLACEHOLDER


 def is_backward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.BACKWARD
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == Phase.BACKWARD
Contributor

@FrankLeeeee FrankLeeeee Sep 14, 2022

These can be merged into one function e.g. is_stage(node: Node, stage: Phase)

Contributor Author

I see

Contributor

The var name in my code should be phase for consistency
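
The merged helper suggested in this thread could look roughly like the sketch below, with the parameter named `phase` as the follow-up comment asks. The name `is_phase`, the enum values, and the `FakeNode` class are assumptions made for illustration; only the member names `FORWARD`, `LOSS`, `BACKWARD`, and `PLACEHOLDER` and the assertion message come from the diff above.

```python
from enum import Enum


class Phase(Enum):
    # Member names follow the diff; the integer values are arbitrary placeholders.
    FORWARD = 0
    LOSS = 1
    BACKWARD = 2
    PLACEHOLDER = 3


class FakeNode:
    """Hypothetical stand-in for torch.fx.Node that only carries a `meta` dict."""

    def __init__(self, name, meta):
        self.name = name
        self.meta = meta

    def __repr__(self):
        return self.name


def is_phase(n, phase: Phase) -> bool:
    # One check replaces is_forward / is_loss / is_placeholder / is_backward.
    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
    return n.meta['phase'] == phase


fwd_node = FakeNode('linear', {'phase': Phase.FORWARD})
print(is_phase(fwd_node, Phase.FORWARD))   # True
print(is_phase(fwd_node, Phase.BACKWARD))  # False
```

Call sites that read better with a named predicate can still wrap this, for example with `functools.partial(is_phase, phase=Phase.FORWARD)`.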

Contributor

@Cypher30 Cypher30 left a comment

Great work!

@FrankLeeeee FrankLeeeee merged commit c8e9b2a into hpcaitech:main Sep 14, 2022
@super-dainiu super-dainiu deleted the hotfix/rotor_variable_names branch September 23, 2022 06:04
3 participants