From d36ecdf0122c0d653e1551137797933277f4eaea Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 May 2024 20:07:05 +0000 Subject: [PATCH 1/8] adding hooks context manager --- deepspeed/runtime/zero/__init__.py | 2 ++ deepspeed/runtime/zero/utils.py | 33 ++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/deepspeed/runtime/zero/__init__.py b/deepspeed/runtime/zero/__init__.py index 1ccca09a9e69..0cbab6a36d5c 100644 --- a/deepspeed/runtime/zero/__init__.py +++ b/deepspeed/runtime/zero/__init__.py @@ -13,3 +13,5 @@ from .tiling import TiledLinearReturnBias from .mics import MiCS_Init + +from .utils import unwrap_model_for_generation diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index f61715bd4387..0a201556c73e 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -15,6 +15,7 @@ from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.accelerator import get_accelerator +from contextlib import contextmanager def _initialize_parameter_parallel_groups(parameter_parallel_size=None): @@ -158,3 +159,35 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None): logger.warning(warning_msg_fn(value)) warned = True return value + +@contextmanager +def unwrap_model_for_generation(model): + """ + For ZeRO-3 models, we gather the weights once to speed up generation. + """ + with deepspeed.zero.GatheredParameters(model.parameters()): + # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model. + + # Remove hooks + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + + for hook in optimizer_offload.forward_hooks: + hook.remove() + for hook in optimizer_offload.backward_hooks: + hook.remove() + + optimizer_offload.forward_hooks = [] + optimizer_offload.backward_hooks = [] + + yield model + + # Adds the optimizer hooks from a DeepSpeed ZeRO-3 model. + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + optimizer_offload._register_hooks_recursively(optimizer_offload.module) + return From 318929c8e1a9364265ca6250b4ee1ad501b98dd9 Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Fri, 14 Jun 2024 14:30:05 -0700 Subject: [PATCH 2/8] format changes --- deepspeed/runtime/zero/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 0a201556c73e..dcc7ea42ba01 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -160,6 +160,7 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None): warned = True return value + @contextmanager def unwrap_model_for_generation(model): """ From 63fa13ca96d10f457bcc7e78420569b690148f37 Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Thu, 27 Jun 2024 11:59:31 -0700 Subject: [PATCH 3/8] running precommit checks --- csrc/aio/py_lib/deepspeed_py_copy.cpp | 2 +- .../predicated_tile_access_iterator_residual_last.h | 8 ++++---- csrc/includes/simd.h | 2 +- csrc/xpu/includes/simd.h | 2 +- csrc/xpu/includes/type_shim.h | 10 +++++----- deepspeed/runtime/zero/utils.py | 3 ++- 6 files changed, 14 insertions(+), 13 deletions(-) mode change 100755 => 100644 csrc/xpu/includes/simd.h diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp index 8a59107dd347..c597b91d05c9 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.cpp +++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp @@ -10,7 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include "deepspeed_py_copy.h" #include -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #if defined(__AVX512__) or defined(__AVX256__) union AVX_Data { diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h index 7f6a2430845a..dcbdc11c27ad 100644 --- a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h +++ b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h @@ -488,7 +488,7 @@ class PredicatedTileAccessIteratorResidualLast tensor's layout CUTLASS_HOST_DEVICE Params(Layout const& layout) - : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {}; }; private: @@ -1413,7 +1413,7 @@ class PredicatedTileAccessIteratorResidualLast tensor's layout CUTLASS_HOST_DEVICE Params(Layout const& layout) - : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}; }; private: diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h index f5bfb45dd2e2..a205026ec7c1 100644 --- a/csrc/includes/simd.h +++ b/csrc/includes/simd.h @@ -27,7 +27,7 @@ inline void writeAs(void* dst, const T& val) std::memcpy(dst, &val, sizeof(T)); } -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #if defined(__AVX512__) #define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) diff --git a/csrc/xpu/includes/simd.h b/csrc/xpu/includes/simd.h old mode 100755 new mode 100644 index f77568be7835..097e2d8585cc --- a/csrc/xpu/includes/simd.h +++ b/csrc/xpu/includes/simd.h @@ -13,7 +13,7 @@ #define TILE (128 * 1024 * 1024) #if defined(__AVX512__) or defined(__AVX256__) -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #if defined(__AVX512__) #define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) diff --git a/csrc/xpu/includes/type_shim.h b/csrc/xpu/includes/type_shim.h index fa41757c895b..1897afd1fea2 100644 --- a/csrc/xpu/includes/type_shim.h +++ b/csrc/xpu/includes/type_shim.h @@ -82,11 +82,11 @@ } template -__inline__ __attribute__((always_inline)) T reduce_block_into_lanes( - T* x, - T val, - int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. +__inline__ __attribute__((always_inline)) T +reduce_block_into_lanes(T* x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); int tid = item_ct1.get_local_id(2) + item_ct1.get_local_id(1) * item_ct1.get_local_range(2); diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index dcc7ea42ba01..6b594569c699 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -15,6 +15,7 @@ from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.accelerator import get_accelerator +from deepspeed.zero import GatheredParameters from contextlib import contextmanager @@ -166,7 +167,7 @@ def unwrap_model_for_generation(model): """ For ZeRO-3 models, we gather the weights once to speed up generation. """ - with deepspeed.zero.GatheredParameters(model.parameters()): + with GatheredParameters(model.parameters()): # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model. # Remove hooks From 00870e32a120e729089c7057887faac634e63bb6 Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Tue, 16 Jul 2024 19:46:40 +0000 Subject: [PATCH 4/8] remove circular dependency --- deepspeed/runtime/zero/__init__.py | 2 +- deepspeed/runtime/zero/stage3.py | 33 ++++++++++++++++++ deepspeed/runtime/zero/utils.py | 36 -------------------- tests/unit/runtime/zero/test_zero_context.py | 4 +++ 4 files changed, 38 insertions(+), 37 deletions(-) diff --git a/deepspeed/runtime/zero/__init__.py b/deepspeed/runtime/zero/__init__.py index 0cbab6a36d5c..23fcf9ec13fb 100644 --- a/deepspeed/runtime/zero/__init__.py +++ b/deepspeed/runtime/zero/__init__.py @@ -14,4 +14,4 @@ from .mics import MiCS_Init -from .utils import unwrap_model_for_generation +from .stage3 import unwrap_model_for_generation diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 37b81d42c0d6..31059ce8b6d4 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -7,6 +7,7 @@ import gc import collections from typing import Deque, Dict, Tuple +from contextlib import contextmanager from deepspeed import comm as dist from deepspeed.utils import groups @@ -68,6 +69,38 @@ def move_to_cpu(tensor_list): for tensor in tensor_list: tensor.data = tensor.data.cpu() +@contextmanager +def unwrap_model_for_generation(model): + """ + For ZeRO-3 models, we gather the weights once to speed up generation. + """ + with GatheredParameters(model.parameters()): + # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model. + + # Remove hooks + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + + for hook in optimizer_offload.forward_hooks: + hook.remove() + for hook in optimizer_offload.backward_hooks: + hook.remove() + + optimizer_offload.forward_hooks = [] + optimizer_offload.backward_hooks = [] + + yield model + + # Adds the optimizer hooks from a DeepSpeed ZeRO-3 model. + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + optimizer_offload._register_hooks_recursively(optimizer_offload.module) + return + INITIAL_MICRO_STEP_ID = -1 diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 6b594569c699..0febf17daa1a 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -15,9 +15,6 @@ from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.accelerator import get_accelerator -from deepspeed.zero import GatheredParameters -from contextlib import contextmanager - def _initialize_parameter_parallel_groups(parameter_parallel_size=None): data_parallel_size = int(dist.get_world_size()) @@ -160,36 +157,3 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None): logger.warning(warning_msg_fn(value)) warned = True return value - - -@contextmanager -def unwrap_model_for_generation(model): - """ - For ZeRO-3 models, we gather the weights once to speed up generation. - """ - with GatheredParameters(model.parameters()): - # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model. - - # Remove hooks - if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): - optimizer_offload = model.optimizer.parameter_offload - elif model.optimizer is not None: - optimizer_offload = model.optimizer - - for hook in optimizer_offload.forward_hooks: - hook.remove() - for hook in optimizer_offload.backward_hooks: - hook.remove() - - optimizer_offload.forward_hooks = [] - optimizer_offload.backward_hooks = [] - - yield model - - # Adds the optimizer hooks from a DeepSpeed ZeRO-3 model. - if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): - optimizer_offload = model.optimizer.parameter_offload - elif model.optimizer is not None: - optimizer_offload = model.optimizer - optimizer_offload._register_hooks_recursively(optimizer_offload.module) - return diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index ec9e9e94aeaf..dc9dcba0accb 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -9,6 +9,7 @@ import pytest import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape +from deepspeed.runtime.zero import unwrap_model_for_generation import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator @@ -299,3 +300,6 @@ def test(self): with deepspeed.zero.GatheredParameters(l.weight): # all ranks compare assert torch.equal(l.weight, torch.zeros_like(l.weight)) + +#class TestUnwrapModel(DistributedTest): +# world_size = 2 From 3dd6ebe876980410e2bec58e76a2ef676ef66e2d Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Wed, 17 Jul 2024 18:19:55 +0000 Subject: [PATCH 5/8] adding unwrap unittest --- tests/unit/runtime/zero/test_unwrap_model.py | 69 ++++++++++++++++++++ tests/unit/runtime/zero/test_zero_context.py | 4 -- 2 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 tests/unit/runtime/zero/test_unwrap_model.py diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py new file mode 100644 index 000000000000..eba3d2819746 --- /dev/null +++ b/tests/unit/runtime/zero/test_unwrap_model.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from types import SimpleNamespace + +import torch +import pytest +import deepspeed +from deepspeed.runtime.zero import unwrap_model_for_generation +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest, preferred_dtype +from unit.simple_model import SimpleModel + +config = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 1, + "offload_param": { + "device": "cpu", + "pin_memory": True + } + } +} + +if get_accelerator().is_fp16_supported(): + config["fp16"] = {"enabled": True, "loss_scale": 138.} +elif get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} + +class TestUnwrapModel(DistributedTest): + # gather across more than 1 gpu + world_size = 2 + + def test(self): + def hooks_exist(engine): + if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"): + optimizer_offload = engine.optimizer.parameter_offload + elif engine.optimizer is not None: + optimizer_offload = engine.optimizer + + hooks = 0 + for hook in optimizer_offload.forward_hooks: + hooks += 1 + if hooks > 0: + return True + return False + + model = SimpleModel(hidden_dim=100) + engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config) + + with unwrap_model_for_generation(engine): + # assert no hooks + assert not hooks_exist(engine) + # assert parameters gathered + assert model.linears[0].weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor" + + # assert hooks + assert hooks_exist(engine) diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index dc9dcba0accb..ec9e9e94aeaf 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -9,7 +9,6 @@ import pytest import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape -from deepspeed.runtime.zero import unwrap_model_for_generation import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator @@ -300,6 +299,3 @@ def test(self): with deepspeed.zero.GatheredParameters(l.weight): # all ranks compare assert torch.equal(l.weight, torch.zeros_like(l.weight)) - -#class TestUnwrapModel(DistributedTest): -# world_size = 2 From b3405f0e3b170cd9331fd7e2a9fad99f0e55761e Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Tue, 30 Jul 2024 04:18:34 +0000 Subject: [PATCH 6/8] tests and precommit --- deepspeed/runtime/zero/stage3.py | 1 + tests/unit/runtime/zero/test_unwrap_model.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 1768c5ad6e31..afe836adc0a4 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -69,6 +69,7 @@ def move_to_cpu(tensor_list): for tensor in tensor_list: tensor.data = tensor.data.cpu() + @contextmanager def unwrap_model_for_generation(model): """ diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py index eba3d2819746..b790f88dbbfb 100644 --- a/tests/unit/runtime/zero/test_unwrap_model.py +++ b/tests/unit/runtime/zero/test_unwrap_model.py @@ -38,11 +38,13 @@ elif get_accelerator().is_bf16_supported(): config["bf16"] = {"enabled": True} + class TestUnwrapModel(DistributedTest): # gather across more than 1 gpu world_size = 2 def test(self): + def hooks_exist(engine): if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"): optimizer_offload = engine.optimizer.parameter_offload From 153ccd6ecab26d0894bb3a5ff95c44bd3a94af10 Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Mon, 5 Aug 2024 12:00:07 -0700 Subject: [PATCH 7/8] format changes --- tests/unit/runtime/zero/test_unwrap_model.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py index b790f88dbbfb..d75519b67f68 100644 --- a/tests/unit/runtime/zero/test_unwrap_model.py +++ b/tests/unit/runtime/zero/test_unwrap_model.py @@ -3,15 +3,11 @@ # DeepSpeed Team -from types import SimpleNamespace - -import torch -import pytest import deepspeed from deepspeed.runtime.zero import unwrap_model_for_generation from deepspeed.accelerator import get_accelerator -from unit.common import DistributedTest, preferred_dtype +from unit.common import DistributedTest from unit.simple_model import SimpleModel config = { From 5e0be490480b0f558556b536de5bf369b417d1e3 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Tue, 6 Aug 2024 09:18:50 -0700 Subject: [PATCH 8/8] Update formatting --- .../predicated_tile_access_iterator_residual_last.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h index dcbdc11c27ad..7f6a2430845a 100644 --- a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h +++ b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h @@ -488,7 +488,7 @@ class PredicatedTileAccessIteratorResidualLast tensor's layout CUTLASS_HOST_DEVICE Params(Layout const& layout) - : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {}; + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; }; private: @@ -1413,7 +1413,7 @@ class PredicatedTileAccessIteratorResidualLast tensor's layout CUTLASS_HOST_DEVICE Params(Layout const& layout) - : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}; + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; }; private: