
[checkpoint] use gather_tensor in checkpoint and update its unit test#1339

Merged
1SAA merged 1 commit into hpcaitech:main from 1SAA:main on Jul 19, 2022

Conversation

@1SAA (Contributor) commented Jul 19, 2022

  • Gathers tensors only to rank 0 to reduce memory usage (a rough illustration of the pattern follows below)
  • Corrected and polished colo_tensor's checkpoint unit test
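A minimal sketch of the gather-to-rank-0 pattern described above, written with plain torch.distributed primitives rather than the PR's gather_tensor helper; the function name, the equally sized shards, and the dim-0 concatenation are assumptions for illustration only:

import torch
import torch.distributed as dist

def gather_to_rank0(shard: torch.Tensor) -> torch.Tensor:
    """Illustrative only: collect equally sized shards onto rank 0.

    Non-zero ranks keep just their local shard, so the full tensor is
    materialized on a single rank and peak checkpoint memory stays low.
    Requires a backend that implements gather (e.g. gloo, or NCCL in
    recent PyTorch releases).
    """
    world_size = dist.get_world_size()
    if dist.get_rank() == 0:
        buffers = [torch.empty_like(shard) for _ in range(world_size)]
        dist.gather(shard, gather_list=buffers, dst=0)
        return torch.cat(buffers, dim=0)   # reassembled full tensor
    dist.gather(shard, gather_list=None, dst=0)   # non-dst ranks only contribute
    return shard

Only rank 0 would then call torch.save on the gathered state dict.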

@1SAA 1SAA requested a review from feifeibear July 19, 2022 04:13
assert v.is_replicate()
delattr(v, 'save_ready')
# model saving
save_state = {'epoch': epoch, 'model': model_state}
Contributor

In the next PR, you can merge the model and optimizer states into a single file, e.g.

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

like https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
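For completeness, the loading side of such a combined checkpoint follows the same tutorial; a rough counterpart using the placeholder names from the snippet above (PATH, net, optimizer):

checkpoint = torch.load(PATH)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']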

old_dist_spec = colo_tensor.dist_spec
colo_tensor.to_replicate_()
if dist.get_rank() != 0:
    colo_tensor.set_dist_spec(old_dist_spec)
Contributor

This line triggers collective communication.
Will there be potential blocking if rank 0 is excluded?

Contributor Author

There is no communication, since old_dist_spec must be SHARD and the tensor is already replicated here; switching a replicated tensor back to a shard spec only keeps the local slice on each rank.
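A toy illustration of why that direction is communication-free (plain torch, not the ColoTensor API; sharding along dim 0 is an assumption):

import torch

def replicate_to_shard(full: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Toy version of switching a replicated tensor back to a row shard.

    Every rank already holds the full tensor, so each one simply keeps
    its own slice locally; no collective call is needed.
    """
    return full.chunk(world_size, dim=0)[rank].clone()

The opposite direction (SHARD -> REPLICATE) is the one that needs a collective, which is why to_replicate_() runs before rank 0 saves.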

colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
for model_name in ['simple_net', 'bert']:
    # TODO(haichen) add BERT in the test
Contributor

Inside a DP group, is the input replicated?

@1SAA (Contributor Author) commented Jul 19, 2022

It depends on which model is being used. We do not have a unified standard yet.

@1SAA 1SAA merged commit f92c100 into hpcaitech:main Jul 19, 2022