
[doc] a possible gradient_clipping default fix and questions #656

Merged
tjruwase merged 6 commits into deepspeedai:master from stas00:grad_clip
Apr 26, 2021

Conversation


@stas00 stas00 commented Jan 10, 2021

This PR fixes the documented gradient_clipping default to be 1.0 rather than 0, since in your code it defaults to 1.0.

But I'm unsure about several things, as this value appears to behave differently in different places.

In several places you call:

                torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                               max_norm=self.gradient_clipping())

and here it's the common max_grad_norm, which should default to 1.0. And indeed, you have:

deepspeed/runtime/constants.py:"gradient_clipping": 1.0

but then you forbid passing that value through the optimizer parameters:

        if 'max_grad_norm' in optimizer_parameters.keys():
            raise ValueError(
                "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
            )

and the doc you send the user to provides no details whatsoever.
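For reference, max_norm in that call behaves like stock PyTorch clipping: gradients whose total norm exceeds the threshold get rescaled down to it. A minimal standalone illustration (my own toy code, not from DeepSpeed):

```python
import torch

# a toy parameter with a deliberately large gradient
p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 10.0)  # total grad norm = sqrt(4 * 100) = 20.0

# clip to max_norm=1.0, as DeepSpeed does with gradient_clipping=1.0
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

print(float(total_norm))      # norm before clipping: 20.0
print(float(p.grad.norm()))   # norm after clipping: ~1.0
```

So with gradient_clipping=1.0 the gradients above would be scaled down by a factor of ~20.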

So in some parts of the code I see gradient_clipping used exactly as max_grad_norm would be, yet FP16_Optimizer uses clip_grad with a default of 0.0!

class FP16_Optimizer(object):
[...]
    def __init__(self,
[...]
                 clip_grad=0.0,

Yet it gets initialized from the same:

        clip_grad = self.gradient_clipping()

whose default is 1.0 everywhere in your code.

deepspeed/runtime/constants.py:"gradient_clipping": 1.0
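And if I read the code right, those two defaults don't even mean the same thing: inside FP16_Optimizer, clip_grad=0.0 appears to mean "clipping disabled", while 0.0 handed straight to torch as max_norm would rescale the gradients to zero. A toy sketch of that semantic difference (my own code, guessing at the convention, not DeepSpeed's actual implementation):

```python
import torch

def fp16_style_clip(params, clip_grad):
    # sketch of what I believe the FP16_Optimizer convention is:
    # 0.0 means "clipping disabled", guarded by clip_grad > 0
    if clip_grad > 0.0:
        torch.nn.utils.clip_grad_norm_(params, max_norm=clip_grad)

p = torch.nn.Parameter(torch.ones(2))
p.grad = torch.full((2,), 3.0)

fp16_style_clip([p], clip_grad=0.0)    # no-op under this convention
after_noop = p.grad.tolist()           # gradients untouched: [3.0, 3.0]

# whereas the same 0.0 passed straight to torch as max_norm
# rescales the gradients to (nearly) zero
torch.nn.utils.clip_grad_norm_([p], max_norm=0.0)
after_zero_norm = float(p.grad.norm())  # ~0.0
```

If so, the same literal value 0.0 disables clipping in one code path and annihilates the gradients in another, which is exactly the kind of thing the docs should spell out.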

Also, why is gradient_clipping a top-level entry and not part of the optimizer config?

Beyond the whys, the main question is whether it is safe for us to init deepspeed with:

  "gradient_clipping": args.max_grad_norm

which defaults to 1.0 in our setup.
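Concretely, the plan on our side is roughly the following (build_ds_config and the batch-size value are placeholders of mine, shown schematically, not the actual transformers integration code):

```python
# sketch: building the DeepSpeed config dict on the HF side,
# forwarding our trainer's max_grad_norm (which defaults to 1.0)
def build_ds_config(max_grad_norm=1.0):
    return {
        "train_batch_size": 8,               # placeholder value
        "gradient_clipping": max_grad_norm,  # forwarded as-is
    }

print(build_ds_config()["gradient_clipping"])     # 1.0
print(build_ds_config(0.5)["gradient_clipping"])  # 0.5
```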

And this doc https://www.deepspeed.ai/docs/config-json/#gradient-clipping could definitely use some disambiguation and perhaps a few more lines explaining what actually happens there.

Thanks


stas00 commented Jan 12, 2021

Would it be possible to have a quick peek at this PR?

We are ready to merge the DeepSpeed integration: huggingface/transformers#9211

This is just one last bit that I need to validate with you before merging it. Thank you!

@jeffra, @tjruwase


stas00 commented Mar 13, 2021

ping?


stas00 commented Mar 18, 2021

So is it 1.0 or 0?

@tjruwase tjruwase merged commit b7f9706 into deepspeedai:master Apr 26, 2021
@stas00 stas00 deleted the grad_clip branch April 26, 2021 19:13