Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepspeed/runtime/zero/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
from .tiling import TiledLinearReturnBias

from .mics import MiCS_Init

from .stage3 import unwrap_model_for_generation
34 changes: 34 additions & 0 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import gc
import collections
from typing import Deque, Dict, Tuple
from contextlib import contextmanager
from deepspeed import comm as dist
from deepspeed.utils import groups

Expand Down Expand Up @@ -69,6 +70,39 @@ def move_to_cpu(tensor_list):
tensor.data = tensor.data.cpu()


@contextmanager
def unwrap_model_for_generation(model):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems my comment here was missed?

Suggested change
def unwrap_model_for_generation(model):
def unshard_and_remove_hooks(model):

"""
For ZeRO-3 models, we gather the weights once to speed up generation.
"""
with GatheredParameters(model.parameters()):
# Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.

# Remove hooks
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
optimizer_offload = model.optimizer

for hook in optimizer_offload.forward_hooks:
hook.remove()
for hook in optimizer_offload.backward_hooks:
hook.remove()

optimizer_offload.forward_hooks = []
optimizer_offload.backward_hooks = []

yield model

# Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
optimizer_offload = model.optimizer
optimizer_offload._register_hooks_recursively(optimizer_offload.module)
return


INITIAL_MICRO_STEP_ID = -1


Expand Down
67 changes: 67 additions & 0 deletions tests/unit/runtime/zero/test_unwrap_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation
from deepspeed.accelerator import get_accelerator

from unit.common import DistributedTest
from unit.simple_model import SimpleModel

config = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"zero_optimization": {
"stage": 3,
"stage3_param_persistence_threshold": 1,
"offload_param": {
"device": "cpu",
"pin_memory": True
}
}
}

if get_accelerator().is_fp16_supported():
config["fp16"] = {"enabled": True, "loss_scale": 138.}
elif get_accelerator().is_bf16_supported():
config["bf16"] = {"enabled": True}


class TestUnwrapModel(DistributedTest):
# gather across more than 1 gpu
world_size = 2

def test(self):

def hooks_exist(engine):
if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"):
optimizer_offload = engine.optimizer.parameter_offload
elif engine.optimizer is not None:
optimizer_offload = engine.optimizer

hooks = 0
for hook in optimizer_offload.forward_hooks:
hooks += 1
if hooks > 0:
return True
return False

model = SimpleModel(hidden_dim=100)
engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config)

with unwrap_model_for_generation(engine):
# assert no hooks
assert not hooks_exist(engine)
# assert parameters gathered
assert model.linears[0].weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"

# assert hooks
assert hooks_exist(engine)