
Make dataloader use another random generator #276

Merged

stas00 merged 1 commit into layer-norm-auto-sync from thomas/layer-norm-auto-sync on Apr 6, 2022

Conversation

@thomasw21
Member

This is in order to start synchronizing the dropout across TP.

thomasw21 marked this pull request as ready for review on April 4, 2022 09:21
thomasw21 requested a review from stas00 on April 4, 2022 09:21
@stas00
Contributor

stas00 commented Apr 4, 2022

From https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading:

By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by main process using its RNG (thereby, consuming a RNG state mandatorily) or a specified generator. However, seeds for other libraries may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See this section in FAQ.).

How do you interpret this writeup?

  1. Will it still do + worker_id (and ours is 2), which would make the generators different,
  2. or is the doc vague and will it actually use the same generator for all workers?
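
For reference, a minimal, hedged sketch (toy dataset, illustrative only) of the two behaviours the docs describe: without a `generator` argument, spawning the worker iterator draws a base_seed from the global torch RNG; with one supplied, it should not.

```python
# Hedged sketch: check whether creating the worker iterator advances the global RNG.
import torch
from torch.utils.data import DataLoader, TensorDataset

def global_rng_advances(**dl_kwargs):
    dataset = TensorDataset(torch.arange(8))
    before = torch.get_rng_state()
    next(iter(DataLoader(dataset, **dl_kwargs)))  # creating the iterator draws a base_seed
    after = torch.get_rng_state()
    return not torch.equal(before, after)

if __name__ == "__main__":
    # Expected (not verified here): True without a generator, False with one.
    print(global_rng_advances(num_workers=2))
    print(global_rng_advances(num_workers=2, generator=torch.Generator()))
```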

@thomasw21
Member Author

  1. Yes, we don't actually care about the seed of each worker since data loading should not be random.

The issue was "base_seed is a long generated by main process using its RNG (thereby, consuming a RNG state mandatorily)", which means that tp_rank=0 would consume an RNG state, while other ranks wouldn't.
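
Roughly, the change is to give the dataloader its own generator so that building and iterating it no longer touches the global torch RNG on tp_rank=0. A sketch under that assumption (the function and argument names here are illustrative, not the exact Megatron-DeepSpeed code):

```python
import torch
from torch.utils.data import DataLoader

def build_pretraining_data_loader(dataset, batch_sampler, num_workers, seed=1234):
    # Dedicated generator: deterministic, but independent of the global torch RNG.
    generator = torch.Generator()
    generator.manual_seed(seed)
    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
        generator=generator,  # worker base_seed is drawn from here instead of the global RNG
    )
```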

@stas00
Contributor

stas00 commented Apr 4, 2022

tp_rank=0 would consume an RNG state, while other ranks wouldn't.

why won't the other ranks?

I think "main process" here refers to the process that spawns the dataloader workers. Won't that main process be each rank's process in relation to its workers? i.e. there are multiple main processes in this context.

Since I haven't investigated this myself: have you validated that this fix actually changes something? I assume you have, but you haven't described it in the OP, hence the question.

@thomasw21
Member Author

thomasw21 commented Apr 5, 2022

why won't the other ranks?

    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(train_dataloader))
    else:
        train_data_iterator = None
This block only runs on tp_rank == 0 if you go up the file a bit, so the "main process" here refers to tp_rank=0's process.
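
For context, a simplified sketch of the pattern being referred to (the surrounding code differs; `mpu.get_tensor_model_parallel_rank()` is the Megatron helper assumed here):

```python
from megatron import mpu  # assumption: Megatron's model-parallel utilities

def build_train_data_iterator(train_dataloader):
    # Only tensor-parallel rank 0 builds and consumes the iterator; the other tp
    # ranks receive their batch via broadcast, so only tp_rank=0 touches the
    # DataLoader machinery (and, before this PR, the global RNG).
    if mpu.get_tensor_model_parallel_rank() == 0:
        return iter(train_dataloader)
    return None
```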

Since I haven't investigated this - have you validated that this fix actually changes something?

Yes, I was able to reproduce the discrepancy in torch_rng_state across tp ranks, and this fixed it.
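
The PR doesn't show how the check was done; one hedged way to compare torch_rng_state across tp ranks (assuming torch.distributed is initialized with NCCL and using Megatron's mpu group helper) would be:

```python
import torch
import torch.distributed as dist
from megatron import mpu  # assumption: Megatron's model-parallel utilities

def tp_rng_states_match():
    # Gather every tensor-parallel rank's CPU RNG state and compare against tp_rank 0.
    state = torch.get_rng_state().cuda()  # move to GPU so NCCL all_gather works
    group = mpu.get_tensor_model_parallel_group()
    states = [torch.empty_like(state) for _ in range(dist.get_world_size(group=group))]
    dist.all_gather(states, state, group=group)
    return all(torch.equal(states[0], s) for s in states)
```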

I assume you do, but you haven't described this in the OP, hence the asking.

My bad; since we talked about this on Slack I didn't copy-paste the thread.

@stas00
Contributor

stas00 left a comment

Thank you for adding the missing sync, Thomas

stas00 merged commit a9fb317 into layer-norm-auto-sync on Apr 6, 2022
stas00 deleted the thomas/layer-norm-auto-sync branch on April 6, 2022 17:10
stas00 added a commit that referenced this pull request Jul 4, 2022
* sync layer norms

* all_reduce is an in_place operation

* Make dataloader use another random generator (#276)

* do all_reduce op.AVG directly

* add eval dataloader deadlock workaround

* revert generator sync

* make auto-sync configurable; basic test; cleanup

* test with updated AMI image

* fix unrelated test

Co-authored-by: thomasw21 <24695242+thomasw21@users.noreply.github.com>
younesbelkada pushed a commit to younesbelkada/Megatron-DeepSpeed that referenced this pull request Sep 28, 2022
* sync layer norms

* all_reduce is an in_place operation

* Make dataloader use another random generator (bigscience-workshop#276)

* do all_reduce op.AVG directly

* add eval dataloader deadlock workaround

* revert generator sync

* make auto-sync configurable; basic test; cleanup

* test with updated AMI image

* fix unrelated test

Co-authored-by: thomasw21 <24695242+thomasw21@users.noreply.github.com>
adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request Dec 18, 2023
* universal-ckp: fix gpt model param names

Signed-off-by: Moshe Island <misland@habana.ai>

* universal-ckp: reconfigure model parameter rng tracker

When loading from universal checkpoint with a different model parameter
configuration, the loaded tensor parallel RNG tracker states are incorrect.
In this case, we reconfigure the tensor parallel RNG tracker states with new
seed values (each tp rank with a unique seed).
We add an offset=iteration to the base seed. This is to ensure that when we
load multiple times from universal checkpoint, we will use a different random
sequence at each run.
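
A hedged sketch of the seeding arithmetic described above (the function name is illustrative, and exactly how the per-rank uniqueness is combined with the iteration offset is an assumption; the real reconfiguration lives inside the tensor-parallel RNG tracker):

```python
def reseed_tp_rng_tracker(base_seed: int, iteration: int, tp_rank: int) -> int:
    # Unique seed per tensor-parallel rank, offset by the checkpoint iteration so
    # that repeated loads from the universal checkpoint use different sequences.
    return base_seed + iteration + tp_rank
```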

This commit requires a counter change in DeepSpeed repo.

Signed-off-by: Moshe Island <misland@habana.ai>

* universal-ckp: remove embedding norm patterns

Embedding norm patterns originate from Bloom, but are not in vanilla GPT.
Therefore, remove the patterns.

Signed-off-by: Moshe Island <misland@habana.ai>

---------

Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>