Merged
114 changes: 91 additions & 23 deletions docs/_tutorials/zero.md
@@ -12,7 +12,9 @@ ZeRO leverages the aggregate computation and memory resources of data parallelis

* **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states.

* **Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes.

In addition, ZeRO-3 includes the *infinity offload engine* to form ZeRO-Infinity ([paper](https://arxiv.org/abs/2104.07857)), which can offload to both CPU and NVMe memory for huge memory savings.
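
As a rough illustration of what each stage saves, the sketch below applies the per-parameter accounting used in the ZeRO paper for mixed-precision Adam: 2 bytes of fp16 weights, 2 bytes of fp16 gradients, and 12 bytes of fp32 optimizer state per parameter. That accounting, the function name `zero_model_state_gb`, and the example sizes are assumptions made for illustration. They are not values reported by DeepSpeed, and activations, temporary buffers, and fragmentation are ignored.

```python
def zero_model_state_gb(num_params, num_gpus, stage):
    """Approximate per-GPU model-state memory in GB (1 GB = 1e9 bytes)."""
    weights = 2 * num_params   # fp16 parameters (assumed 2 bytes each)
    grads = 2 * num_params     # fp16 gradients (assumed 2 bytes each)
    optim = 12 * num_params    # fp32 weights + Adam momentum + variance
    if stage >= 1:
        optim /= num_gpus      # Stage 1 partitions the optimizer states
    if stage >= 2:
        grads /= num_gpus      # Stage 2 also partitions the gradients
    if stage >= 3:
        weights /= num_gpus    # Stage 3 also partitions the parameters
    return (weights + grads + optim) / 1e9


if __name__ == "__main__":
    # Hypothetical example: a 7.5B-parameter model on 64 GPUs.
    for stage in range(4):
        gb = zero_model_state_gb(7.5e9, 64, stage)
        print(f"stage {stage}: {gb:.1f} GB per GPU")
```

Under these assumptions, the per-GPU model state drops from roughly 120 GB without ZeRO to roughly 1.9 GB with Stage 3, which is why Stage 3 scales with the degree of data parallelism.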

## Training environment
We use the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM) GPT-2 code for this exercise. You can step through the Megatron-LM [tutorial](/tutorials/megatron/) to familiarize yourself with the code. We will train the models in this tutorial on [NVIDIA Tesla V100-SXM3 Tensor Core GPUs](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM.
@@ -108,37 +110,52 @@ Here is a screenshot of nvidia-smi showing GPU activity during training:

### Training trillion-scale models with ZeRO-Infinity

ZeRO-3, the third stage of ZeRO, partitions the full model state (i.e.,
weights, gradients, and optimizer states) to scale memory savings linearly
with the degree of data parallelism. ZeRO-3 can be enabled in the JSON
configuration. A full description of these configurations is available
[here](/docs/config-json/#zero-optimizations-for-fp16-training).


#### Offloading to CPU and NVMe with ZeRO-Infinity

ZeRO-Infinity uses DeepSpeed's infinity offload engine to offload the full
model state to CPU or NVMe memory, allowing for even larger model sizes. Offloading
can be enabled inside the DeepSpeed configuration:

```diff
@@ -6,5 +6,11 @@
     "zero_optimization": {
         "stage": 3,
-        "cpu_offload": true,
-        "cpu_offload_params": true,
         "contiguous_gradients": true,
         "stage3_max_live_parameters": 1e9,
         "stage3_max_reuse_distance": 1e9,
         "stage3_prefetch_bucket_size": 1e7,
         "stage3_param_persistence_threshold": 1e5,
         "reduce_bucket_size": 1e7,
-        "sub_group_size": 1e9
+        "sub_group_size": 1e9,
+        "offload_optimizer": {
+            "device": "cpu"
+        },
+        "offload_param": {
+            "device": "cpu"
+        }
     }
 }
```
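
The same change can be applied programmatically before the configuration is written to disk. The following is a minimal sketch, not part of DeepSpeed: `BASE_ZERO3` and `with_offload` are hypothetical names, and the `nvme_path` key and the `/local_nvme` path are assumptions that do not appear in this tutorial, so check them against the configuration reference before relying on them.

```python
import copy
import json

# A trimmed-down ZeRO-3 section using keys shown in the tutorial.
BASE_ZERO3 = {
    "zero_optimization": {
        "stage": 3,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
    }
}


def with_offload(config, device="cpu", nvme_path=None):
    """Return a copy of config with optimizer and parameter offload enabled."""
    if device not in ("cpu", "nvme"):
        raise ValueError(f"unsupported offload device: {device!r}")
    if device == "nvme" and nvme_path is None:
        raise ValueError("NVMe offload needs an nvme_path")
    target = {"device": device}
    if device == "nvme":
        # Assumed key name for the NVMe mount point.
        target["nvme_path"] = nvme_path
    result = copy.deepcopy(config)  # leave the caller's config untouched
    zero = result.setdefault("zero_optimization", {})
    zero["offload_optimizer"] = dict(target)
    zero["offload_param"] = dict(target)
    return result


if __name__ == "__main__":
    print(json.dumps(with_offload(BASE_ZERO3, "nvme", "/local_nvme"), indent=2))
```

Copying the base configuration keeps one ZeRO-3 baseline that can be reused for CPU and NVMe experiments.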




**ZeRO-Infinity vs ZeRO-Offload:**
DeepSpeed first included offloading capabilities with ZeRO-Offload,
a system for offloading optimizer and gradient states to CPU memory
within ZeRO-2. ZeRO-Infinity is the next generation of offloading
capabilities, accessible to ZeRO-3. ZeRO-Infinity can offload
more data than ZeRO-Offload, uses bandwidth more effectively, and
better overlaps computation with communication.
{: .notice--info}




#### Allocating Massive Megatron-LM Models

We make two further changes to model initialization in order to support models
@@ -158,7 +175,7 @@ for more details.
model = GPT2Model(num_tokentypes=0, parallel_output=True)
```

2. Gather the embeddings weight for initialization. DeepSpeed will automatically
gather a module's parameters during its constructor and for its forward and backward pass.
However, additional accesses must coordinate with DeepSpeed to ensure that parameter data
is gathered and subsequently partitioned. If the tensor is modified, the `modifier_rank`
@@ -173,23 +190,74 @@ for more details.
                                       modifier_rank=0):
    # Initialize the position embeddings.
    self.init_method(self.position_embeddings.weight)

...

self.tokentype_embeddings = torch.nn.Embedding(...)
with deepspeed.zero.GatheredParameters(self.tokentype_embeddings.weight,
                                       modifier_rank=0):
    # Initialize the token-type embeddings.
    self.init_method(self.tokentype_embeddings.weight)
```
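
The gather, modify, and re-partition pattern can be pictured with a toy, single-process model. The sketch below uses plain Python lists and is not DeepSpeed's implementation: `partition` and `gathered` are hypothetical names, and the rule that only the copy held by `modifier_rank` survives on exit is an assumption about how that argument behaves, so consult the API documentation for the authoritative description.

```python
from contextlib import contextmanager


def partition(values, world_size):
    """Split a flat list into world_size contiguous shards."""
    shard = -(-len(values) // world_size)  # ceiling division
    return [values[r * shard:(r + 1) * shard] for r in range(world_size)]


@contextmanager
def gathered(shards, modifier_rank=None):
    """Toy stand-in: yield one full copy of the parameter per simulated rank.

    On exit, only the copy held by modifier_rank is re-partitioned.
    With modifier_rank=None, every edit made inside the block is discarded.
    """
    world_size = len(shards)
    full = [v for shard in shards for v in shard]       # "gather"
    copies = [list(full) for _ in range(world_size)]    # one view per rank
    yield copies
    if modifier_rank is not None:
        shards[:] = partition(copies[modifier_rank], world_size)  # "partition"


if __name__ == "__main__":
    shards = partition([0.0] * 6, world_size=3)
    with gathered(shards, modifier_rank=0) as copies:
        copies[0][:] = [1.0] * 6  # rank 0 initializes the weight
        copies[1][:] = [9.0] * 6  # edits made on other ranks are dropped
    print(shards)
```

In this model, initializing the weight on a single rank and naming that rank as the modifier is what keeps every shard consistent after the block exits.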

#### Memory-centric tiling
ZeRO-Infinity includes a replacement for `Linear` layers that further reduces memory.
We optionally tile the model parallel linear layers found in each Transformer layer. Note
that model parallelism and tiling can be combined by specifying the corresponding
base class when building the layer.
The `deepspeed.zero.TiledLinear` module exploits the data fetch and release
pattern of ZeRO-3 to reduce the working memory requirements by breaking down
a large operator into smaller tiles that can be executed sequentially.

We include the changes for one example from Megatron-LM's [ParallelMLP](https://github.com/microsoft/DeepSpeedExamples/blob/bdf8e59aede8c8e0577e8d4d557298ca8515268f/Megatron-LM-v1.1.5-ZeRO3/megatron/model/transformer.py#L82). Three more
model-parallel layers in `transformer.py` proceed similarly.

The model parallel layers of Megatron-LM have a special form in which the
additive `bias` of the layer is delayed and instead returned from `forward()`
to be fused with a later operator. DeepSpeed's
`deepspeed.zero.TiledLinearReturnBias` subclass of `TiledLinear` simply
forwards the returned `bias` parameter as well, without accumulating it.

```diff
@@ -1,6 +1,9 @@
-self.dense_h_to_4h = mpu.ColumnParallelLinear(
+self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias(
     args.hidden_size,
     4 * args.hidden_size,
+    in_splits=args.tile_factor,
+    out_splits=4*args.tile_factor,
+    linear_cls=mpu.ColumnParallelLinear,
     gather_output=False,
     init_method=init_method,
     skip_bias_add=True)
```

Note that we scale `in_splits` and `out_splits` proportionally with `input_size` and `output_size`. This
results in tiles of fixed size `[hidden/tile_factor, hidden/tile_factor]`.
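
To make the tiling idea concrete, here is a small pure-Python sketch of a linear operator evaluated one tile at a time. It is not the `TiledLinear` implementation: `matvec` and `tiled_matvec` are hypothetical helpers, and the sizes are toy values chosen so that `in_splits=tile_factor` and `out_splits=4*tile_factor` give square tiles of side `hidden/tile_factor`.

```python
def matvec(weight, x):
    """Dense y = W x, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def tiled_matvec(weight, x, in_splits, out_splits):
    """Compute y = W x tile by tile, accumulating partial sums."""
    out_size, in_size = len(weight), len(x)
    assert out_size % out_splits == 0 and in_size % in_splits == 0
    tile_out, tile_in = out_size // out_splits, in_size // in_splits
    y = [0.0] * out_size
    for o in range(out_splits):
        rows = range(o * tile_out, (o + 1) * tile_out)
        for i in range(in_splits):
            cols = range(i * tile_in, (i + 1) * tile_in)
            # Only this tile_out x tile_in block of W is touched here.
            for r in rows:
                y[r] += sum(weight[r][c] * x[c] for c in cols)
    return y


if __name__ == "__main__":
    hidden, tile_factor = 4, 2
    # Toy [4 * hidden, hidden] weight, mirroring the h-to-4h projection.
    weight = [[(r + c) % 5 for c in range(hidden)] for r in range(4 * hidden)]
    x = [1.0, 2.0, 3.0, 4.0]
    tiled = tiled_matvec(weight, x, in_splits=tile_factor,
                         out_splits=4 * tile_factor)
    print(tiled == matvec(weight, x))
```

Because each step touches a single tile, only that tile's weights need to be resident at once, which is the property the fetch and release pattern of ZeRO-3 can exploit.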

#### Registering external parameters

**Deprecated:**
DeepSpeed version `0.3.15` introduced automatic external parameter
registration and this step is no longer needed.
{: .notice--info}


## Extracting weights

If you need to extract the pretrained weights from DeepSpeed, here is how to get the fp16 weights:

- Under ZeRO-2, `state_dict` contains the fp16 model weights, which can be saved normally with `torch.save`.
- Under ZeRO-3, `state_dict` contains only placeholders because the model weights are partitioned across multiple GPUs. To get these weights, enable:

```json
"zero_optimization": {
    "stage3_gather_fp16_weights_on_model_save": true
},
```
And then save the model using:

```python
if self.deepspeed:
    self.deepspeed.save_fp16_model(output_dir, output_file)
```
@@ -201,7 +269,7 @@ You can use this method to save ZeRO-2 weights as well.

If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage:

```bash
$ cd /path/to/checkpoints_dir
$ ./zero_to_fp32.py global_step1 pytorch_model.bin
Processing zero checkpoint at global_step1