
Conversation

@eisene
Contributor

@eisene eisene commented Apr 12, 2023

This is a proposed fix for #3202: it makes the deepspeed.zero.Init() context ignore any nested enters, i.e., any additional enters that were not preceded by a matching exit.

It also fixes some whitespace and adds the forgotten cls._old_init assignment in _init_subclass in the Init() context entry.
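The nested-enter behavior described above can be sketched with a simple counter (a minimal, self-contained illustration; the class name, the `_depth` attribute, and the `events` list are hypothetical, not the actual DeepSpeed code):

```python
events = []  # records setup/teardown calls so the behavior is visible


class NestedInit:
    """Toy context manager that ignores nested enters.

    Only the outermost __enter__ performs setup, and only the matching
    outermost __exit__ performs teardown; inner enters and exits just
    adjust the counter.
    """
    _depth = 0  # class-level nesting counter (hypothetical name)

    def __enter__(self):
        if NestedInit._depth == 0:
            events.append("setup")      # e.g. patch nn.Module classes
        NestedInit._depth += 1
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        NestedInit._depth -= 1
        if NestedInit._depth == 0:
            events.append("teardown")   # e.g. restore original __init__s


with NestedInit():
    with NestedInit():   # nested enter: no second setup
        pass             # nested exit: no premature teardown
```

After both blocks exit, setup and teardown have each run exactly once.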

@tjruwase
Contributor

@eisene, thanks for this PR. Can you please check for similarities with #2989?

@eisene
Contributor Author

eisene commented Apr 13, 2023

Unless I'm misunderstanding something, this is what I see:

First of all, in #2989 the test TestNewClassDeclaredInsideInit defined in tests/unit/runtime/zero/test_zero_dynamic_class.py appears to have a bug. It should test whether model.fc.weight has the ds_id attribute without initializing the deepspeed_engine. As implemented right now, it always passes, even if I comment out the two occurrences of the line

            cls.__init__ = partition_after(cls.__init__)

in partition_parameters.py. However, when I check for the attribute on model.fc.weight without initializing the deepspeed_engine, I get the expected failure after commenting out those lines.

In my pull request, the test TestNewClassDeclaredInsideInitFailure.test_new_class_declared_inside_init_failure fails because my approach doesn't crash when you define a new class inside the Init() context. I believe this is the correct behavior, and there shouldn't be any problems with partitioning the parameters of such subclasses. This is because I add the line

        def _init_subclass(cls, **kwargs):
            cls._old_init = cls.__init__    # <-- THIS IS THE ADDED LINE
            cls.__init__ = partition_after(cls.__init__)

This activates partitioning for all subclasses that are defined inside an Init() context. This seems to have been the intended behavior all along, except that someone forgot to set _old_init?

In the other pull request, #2989, the all_wrapped_classes set boils down to enforcing that _init_subclass is only ever called together with _enable_class. That is why it never crashes on a missing _old_init: the attribute is always set inside _enable_class. But wouldn't it be better to simply set _old_init inside _init_subclass? That seems to guarantee that all subclasses, dynamic or not, wherever they're defined, are always partitioned.

So with my approach there should be no RuntimeError in TestNewClassDeclaredInsideInitFailure.test_new_class_declared_inside_init_failure; instead, everything should behave correctly in this case. I can add these tests to my pull request and change test_new_class_declared_inside_init_failure into a test_new_class_declared_inside_init that expects success, checking that the MyModel class was processed by DeepSpeed via the ds_id attribute on model.fc.weight?
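The argument above can be demonstrated with a toy version of the patching machinery (a self-contained sketch; `partition_after`, `_enable_class`, `_disable_class`, and the `log` list here are simplified stand-ins, not the real code in partition_parameters.py):

```python
log = []


def partition_after(init_fn):
    # toy stand-in: the real wrapper partitions parameters after __init__
    def wrapper(self, *args, **kwargs):
        init_fn(self, *args, **kwargs)
        log.append(init_fn.__qualname__)
    return wrapper


def _init_subclass(cls, **kwargs):
    cls._old_init = cls.__init__            # <-- the added line
    cls.__init__ = partition_after(cls.__init__)


def _enable_class(cls):
    cls._old_init = cls.__init__
    cls.__init__ = partition_after(cls.__init__)
    # hook subclass creation so classes declared inside the context
    # get wrapped too
    cls.__init_subclass__ = classmethod(_init_subclass)


def _disable_class(cls):
    # without _old_init being saved in _init_subclass, this line would
    # raise AttributeError for dynamically declared subclasses
    cls.__init__ = cls._old_init


class Base:
    def __init__(self):
        pass


_enable_class(Base)


class Child(Base):        # declared "inside the context": hook fires
    def __init__(self):
        super().__init__()


Child()                   # both wrappers run
_disable_class(Child)     # works because Child._old_init exists
_disable_class(Base)
```

Restoring Child's original `__init__` succeeds only because `_init_subclass` saved it, which is the point of the one-line fix.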

@eisene
Contributor Author

eisene commented Apr 13, 2023

This is the change I'm proposing to the test:

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from unit.common import DistributedTest

import deepspeed


class TestNewClassDeclaredInsideNestedInit(DistributedTest):
    world_size = 1

    def test_new_class_declared_inside_nested_init(self):
        ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

        with deepspeed.zero.Init(config_dict_or_path=ds_config):

            class MyModel(torch.nn.Module):

                def __init__(self):
                    super().__init__()
                    self.fc = torch.nn.Linear(4, 4)

            with deepspeed.zero.Init(config_dict_or_path=ds_config):
                model = MyModel()

        # deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
        # ensure that zero3 processed the parameter
        # assert hasattr(deepspeed_engine.fc.weight, "ds_id")
        assert hasattr(model.fc.weight, "ds_id")


class TestNewClassDeclaredInsideInit(DistributedTest):
    world_size = 1

    def test_new_class_declared_inside_init(self):
        ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

        with deepspeed.zero.Init(config_dict_or_path=ds_config):

            class MyModel(torch.nn.Module):

                def __init__(self):
                    super().__init__()
                    self.fc = torch.nn.Linear(1, 1)

            model = MyModel()
        # ensure that zero3 processed the parameter
        assert hasattr(model.fc.weight, "ds_id")

Both of these tests should pass in my PR.

@tjruwase
Contributor

@stas00, FYI

@tjruwase
Contributor

tjruwase commented Apr 14, 2023

@jeffra and @samyam, FYI

@eisene
Contributor Author

eisene commented Apr 14, 2023

The HF accelerate tests fail because deepspeed.initialize calls shutdown_init_context directly. It is then called again during __exit__, causing the counter to go negative. So I made shutdown_init_context idempotent as well.

This clearly seems to be the intended behavior: whoever wrote deepspeed.initialize knew there would be a second call to shutdown_init_context and expected it to do nothing.

@stas00
Collaborator

stas00 commented Apr 14, 2023

@tjruwase, one other thing to be concerned about, in the context of the discussion in huggingface/diffusers#3076, is all those global objects. These will be a problem when using multiple models that don't all necessarily use the same setup: one model may use zero.Init while another won't. I mention it here because there are a few globals used in the zero.Init code; these should probably be tied to the engine object instead.

@eisene
Contributor Author

eisene commented Apr 14, 2023

Agreed. I can make the engine object an optional argument to the context and have it store the state on the engine object if provided. If not provided it can raise a deprecation warning of some kind and store the state in the globals. Thoughts?

One concern is that we need to be careful with the shutdown_init_context call because it happens inside deepspeed.initialize before it creates a new engine object.

Maybe something like this: after partitioning, the model corresponds to a specific engine; it's no longer just a model in the wild. So the model can have an attribute pointing to the corresponding engine, and all these functions can check that attribute. This is a breaking change once it's enforced.

@stas00
Collaborator

stas00 commented Apr 14, 2023

The single-model concept has been used until now, and the problem is that zero.Init is integrated into from_pretrained so that it can automatically partition the model transparently for the user. At that point there is no engine; this happens long before deepspeed.initialize, so I'm not sure what engine object could be provided here.

I think this new situation of multiple models calls for a major redesign of the zero.Init + deepspeed.initialize combo and how they interplay. Clearly globals won't work here, unless they are somehow magically not shared between models. I will let @tjruwase comment on that, given his work on https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat, where multiple models are used.

@tjruwase
Contributor

> @tjruwase, one other thing to be concerned about in the context of the discussion of huggingface/diffusers#3076 is all those global objects - these will be a problem when using multiple models which don't all necessarily use the same setup. One model may use zero.Init whereas another won't. I mention it here since there are a few globals used in the zero.Init code. These probably should be tied to the engine object instead.

Regarding this, I had started the process of eliminating globals in zero.Init; this was partially done in PR #2953. There remains more global state to eliminate, and I will prioritize that.

The goal is to optionally associate an independent zero.Init context object with each model. The deepspeed engine object of a model is then naturally associated with the model's zero.Init context object by a subsequent deepspeed.initialize(). I would like to discuss this more to clarify the API and design.

@tjruwase
Contributor

> Agreed. I can make the engine object an optional argument to the context and have it store the state on the engine object if provided. If not provided it can raise a deprecation warning of some kind and store the state in the globals. Thoughts?
>
> One concern is that we need to be careful with the shutdown_init_context call because it happens inside deepspeed.initialize before it creates a new engine object.
>
> Maybe something like this: after partitioning, the model corresponds to a specific engine, it's no longer just a model in the wild. So the model can have an attribute pointing to the corresponding engine and all these functions can check the attribute. This is a breaking change once it's enforced.

@eisene, thanks for sharing these thoughts. Our design is different and has the following properties:

  1. zero.Init context is directly associated with models (more specifically with the parameters). Deepspeed engine only interacts with zero.Init context through the model.
  2. This means a model constructed using zero.Init() can be used without a deepspeed engine. In such scenarios, the client is responsible for using the Gather API to assemble the needed parameters. For example, one might want to compute the weight norms of a pretrained model right after construction but before deepspeed.initialize().
  3. If deepspeed.initialize() is called on a model with associated zero.Init context object, the resulting deepspeed engine will be associated with that context object.
  4. If deepspeed.initialize() is called on a model without zero.Init context, but zero stage 3 is specified in ds_config, then a zero.Init context will be automatically created and associated with the model. In this case, the model parameters will also automatically be partitioned by deepspeed.initialize(). The resulting deepspeed engine will be associated with the zero.Init as normal.

While there might still be unresolved coding issues, this is the current goal. Please share any questions or feedback. Thanks!
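A toy sketch of properties 1, 3, and 4 above (attribute and function names here are hypothetical, chosen only to illustrate the association between model, context, and engine; this is not the DeepSpeed API):

```python
class ZeroInitContext:
    """Stand-in for a per-model zero.Init context object."""
    def __init__(self, config):
        self.config = config


class Model:
    """Stand-in for an nn.Module."""
    pass


def zero_init(model, config):
    # property 1: the context is stored on the model, not in globals
    model._zero_init_ctx = ZeroInitContext(config)   # hypothetical attr
    return model


def initialize(model, config):
    # property 3: reuse the model's context if it has one
    ctx = getattr(model, "_zero_init_ctx", None)
    # property 4: create and associate a context automatically for stage 3
    if ctx is None and config.get("zero_optimization", {}).get("stage") == 3:
        zero_init(model, config)
        ctx = model._zero_init_ctx
    return {"model": model, "ctx": ctx}   # toy "engine"


config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

m1 = zero_init(Model(), config)   # constructed under zero.Init
engine1 = initialize(m1, config)

m2 = Model()                      # no zero.Init, but stage 3 config
engine2 = initialize(m2, config)
```

In both cases the resulting engine ends up associated with the model's context object, matching the stated goal.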

@eisene
Contributor Author

eisene commented Apr 19, 2023

At a high level this looks reasonable to me. I think the most helpful thing I can do here is outline how I'm using DS and which features I would find helpful that may be relevant to the design principles.

My use case for DS is training relatively large models, ~3B-7B, on consumer hardware. I'm working on a side project to build a machine translation system for a low-resource language. The idea is to translate long texts, so I'm experimenting with different ways to modify existing model architectures to handle very long texts better, and then fine-tuning them with DS. I'm leaning heavily on modular model design, using available pre-trained models in each slot. In this way, what I'm doing is similar to the other discussions; however, in my case all the models are trained together on a single task, so my situation is simpler, at least for now. I also have ideas that will involve outright separate models.

For example, I have a pre-trained encoder for my source language that I pre-trained myself from scratch, I have a large decoder into the target language that I take from the HF Hub, and I have a helper model that encodes some long context for the current sentence being translated and injects this into the decoder in some way. I'm experimenting with different architectures for this high level idea. This can create fairly complex model creation and data loading code.

ZeRO stage 3, and especially NVMe offloading, has been a total game changer for me here. Unlike people working with large clusters and large budgets, I'm trying to squeeze every drop of performance from whatever hardware I happen to have lying around. I think being able to use DS to fit large models onto small hardware is very important, and not just for inference but for fine-tuning as well.

It may be useful to be able to manually pin some parameters/gradients/optimizer state to a specific device, or at least keep them together. For example, I could have a smaller helper model that I want to place on a cheaper machine, ensure that all of its parameters/gradients/optimizer state live on that machine, and only send the helper model's output and final gradients across the network. Depending on the size of the helper model, it seems possible to me that this could matter, even with all the optimization DS is already doing.

I also need NVMe checkpointing, but that's a separate conversation.

@tjruwase
Contributor

> I also need NVMe checkpointing, but that's a separate conversation.

@eisene, thanks for sharing this. It is very helpful for us to know which features are important to our users. I think our increased emphasis on modularity should help with flexibly combining models and features. You may be aware of our recent DeepSpeed Chat release, where we support up to 5 models on a single rank to implement RLHF finetuning.

@tjruwase
Contributor

@eisene, regarding this PR, can you confirm that your proposed test fails with master branch? Thanks!

@eisene
Contributor Author

eisene commented Apr 26, 2023

The test that's supposed to fail is test_new_class_declared_inside_init, declared in tests/unit/runtime/zero/test_zero_dynamic_class.py. I can confirm that it fails on the master branch for me.

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    assert (deepspeed.zero.partition_parameters.zero_init_enabled == 1)
    model = torch.nn.Linear(4, 4)
    _, *_ = deepspeed.initialize(model=model, config_params=ds_config)
@tjruwase
Contributor

I see two issues here:

  1. deepspeed.initialize() should not modify init_context in an externally visible way. For example, deepspeed.zero.partition_parameters.zero_init_enabled should be the same before and after.
  2. The following test fails with infinite recursion error.
import torch 
import deepspeed 

ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model = torch.nn.Linear(4, 4)
        _, *_ = deepspeed.initialize(model=model, config_params=ds_config)

Ideally, deepspeed.initialize should temporarily disable any existing init_context and re-enable afterwards.

@eisene
Contributor Author

eisene commented May 4, 2023

I think I understand. I see at least two ways to deal with this:

  1. The way it works now: hard shut down the Init() context in deepspeed.initialize(). This way the partitioning hook is installed in DeepSpeedEngine when the class is declared as a subclass of nn.Module, and then removed in initialize(). This seems brittle to me: if DeepSpeedEngine ever needs to be created outside initialize(), it always needs to be accompanied by shutdown_init_context() for reasons that are not immediately obvious. Perhaps this is fine; I defer to your expertise here.
  2. I tried explicitly excluding DeepSpeedEngine from the partitioning hook in partition_parameters.py:
    def _enable_class(cls):
        if cls == deepspeed.runtime.engine.DeepSpeedEngine:
            return
        cls._old_init = cls.__init__
        cls.__init__ = partition_after(cls.__init__)

    def _disable_class(cls):
        if cls == deepspeed.runtime.engine.DeepSpeedEngine:
            return
        cls.__init__ = cls._old_init

This looks like it works. Maybe it would be better to make a class decorator that excludes classes from partitioning and decorate DeepSpeedEngine with it? Then you never have to call shutdown_init_context() explicitly.

Collaborator

Hi @eisene, thank you for your great improvements. This PR significantly simplifies the management of zero.Init() contexts while also allowing us to flexibly declare new model classes inside zero.Init().

Since the current code in this PR throws in some cases, as pointed out by @tjruwase, I submitted a separate PR, #3592. It can run the code shown by @tjruwase without an error, and it also restores the zero.Init() context after deepspeed.initialize().

I would be happy to discuss this with you all, including @tjruwase and @stas00.

@eisene
Contributor Author

eisene commented Jul 3, 2023

The issue has been resolved by #3592. The decorator is introduced separately in #3865.

@eisene eisene closed this Jul 3, 2023