
Samyamr/full precision for ZeRO Stage2 and Stage3 #1004

Merged
jeffra merged 21 commits into master from samyamr/full-precision-for-stage3 on Apr 29, 2021
Conversation

@samyam (Contributor) commented Apr 23, 2021

No description provided.

Comment thread deepspeed/runtime/zero/partition_parameters.py
Comment thread deepspeed/runtime/zero/stage2.py
Comment thread deepspeed/runtime/engine.py
@stas00 (Collaborator) commented Apr 29, 2021

And just to indicate the priority of this PR: we have all those bfloat16 models that won't train under fp16/mixed precision, and users who want to use DeepSpeed to overcome GPU memory limitations, so they badly need this. Thank you!

@samyam changed the title from "Samyamr/full precision for stage3" to "Samyamr/full precision for ZeRO Stage2 and Stage3" on Apr 29, 2021
@stas00 (Collaborator) commented Apr 29, 2021

When you feel this looks good enough to test please let me know and I will start testing this branch on the transformers side. Thank you.

samyam and others added 3 commits April 29, 2021 13:24
The assert checking that param.dtype is torch.half for ZeRO-3 should only fire if the model was initialized in a ZeRO-3 context.
@jeffra jeffra merged commit dad2642 into master Apr 29, 2021
@jeffra jeffra deleted the samyamr/full-precision-for-stage3 branch April 29, 2021 22:06
@stas00 (Collaborator) commented Apr 30, 2021

This is awesome - thank you!

I encountered only one issue:

As I am writing HF transformers tests for fp32, I found that zero.Init doesn't get the dtype from the config file; I have to do it explicitly:

```python
ds_config = deepspeed_config()
# XXX: Fixme - we shouldn't need to figure dtype out, it should be in the config file
dtype = torch.float16 if ds_config.get("fp16", {}).get("enabled", True) else torch.float
with deepspeed.zero.Init(dtype=dtype, config=ds_config):
    model = cls(config, *model_args, **model_kwargs)
```

I thought the whole point of passing the config to zero.Init is so that we don't need to manually parse the file in multiple places; we were discussing making this work:

```python
ds_config = deepspeed_config()
with deepspeed.zero.Init(config=ds_config):
    model = cls(config, *model_args, **model_kwargs)
```
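The workaround above boils down to reading fp16.enabled out of the config dict. A minimal, torch-free sketch of that selection logic (the helper name is hypothetical, and strings stand in for torch.float16 / torch.float):

```python
def dtype_from_ds_config(ds_config):
    """Pick the training dtype from a DeepSpeed config dict.

    Hypothetical helper mirroring the workaround above: fp16.enabled
    (defaulting to True, as in the snippet) selects half precision.
    Strings stand in for the real torch dtypes.
    """
    fp16_enabled = ds_config.get("fp16", {}).get("enabled", True)
    return "float16" if fp16_enabled else "float32"
```

With this, `dtype_from_ds_config({"fp16": {"enabled": False}})` gives full precision, and a config with no fp16 section falls back to half precision, matching the default in the snippet above.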

stas00 added a commit to stas00/DeepSpeed that referenced this pull request Apr 30, 2021
I'm not sure if this is the best approach but with deepspeedai#1004 I still have to pass `zero.Init(dtype)` because this branch never gets executed:
```
    def _set_dtype(self, ds_config, dtype):
        if ds_config is not None and dtype is None:
            _ds_config = DeepSpeedConfig(ds_config)
            self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
```
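A standalone sketch (names hypothetical, strings standing in for torch.half / torch.float) of why that branch can be dead: the config is consulted only when dtype is None, so any caller-supplied or defaulted non-None dtype bypasses the config entirely:

```python
def resolve_dtype(ds_config, dtype, fp16_enabled):
    """Mirror the guard in _set_dtype above (hypothetical sketch).

    The config-driven branch runs only when a config is given AND
    dtype is None; a non-None dtype short-circuits the config.
    """
    if ds_config is not None and dtype is None:
        return "half" if fp16_enabled else "float"
    return dtype

# The config path is taken only with dtype=None:
resolve_dtype({"fp16": {"enabled": False}}, None, False)   # -> "float"
# An explicit dtype wins, even if the config disagrees:
resolve_dtype({"fp16": {"enabled": False}}, "half", False)  # -> "half"
```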
@stas00 stas00 mentioned this pull request Apr 30, 2021