diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py
index 9e7bae816ecd..54f7fd56abfd 100644
--- a/deepspeed/runtime/fp16/onebit/lamb.py
+++ b/deepspeed/runtime/fp16/onebit/lamb.py
@@ -10,6 +10,7 @@
 from deepspeed.utils.torch import required_torch_version
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime.utils import filter_empty_parameters
 
 
 class OnebitLamb(torch.optim.Optimizer):
@@ -82,6 +83,9 @@ def __init__(self,
         if amsgrad:
             raise RuntimeError('1-bit Lamb does not support the AMSGrad variant.')
 
+        # Filter out empty parameters (numel == 0) to avoid NaN in scaling calculations
+        filtered_params = filter_empty_parameters(params)
+
         defaults = dict(lr=lr,
                         bias_correction=bias_correction,
                         betas=betas,
@@ -91,7 +95,7 @@ def __init__(self,
                         max_coeff=max_coeff,
                         min_coeff=min_coeff)
 
-        super(OnebitLamb, self).__init__(params, defaults)
+        super(OnebitLamb, self).__init__(filtered_params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self.deepspeed = deepspeed
         self.lamb_freeze_key = False
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 13f2b2c1dc03..0443e003b79f 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -47,6 +47,58 @@ def __init__(self, params):
         self.param_groups.append({'params': params})
 
 
+def filter_empty_parameters(params):
+    """Filter out empty parameters (numel == 0) from optimizer params.
+
+    This is useful for optimizers that perform operations like division by numel,
+    which would produce NaNs for empty parameters.
+
+    Args:
+        params: An iterable of Parameters, or an iterable of parameter group dicts
+            (each dict has a 'params' key holding a Parameter or an iterable of
+            Parameters). One-shot iterables such as ``model.parameters()`` are accepted.
+
+    Returns:
+        A list of the non-empty Parameters, or a list of new parameter group dicts
+        restricted to their non-empty Parameters; groups left without any parameter
+        are dropped. ``None``, a bare tensor, a bare dict and empty lists/tuples are
+        returned unchanged so that the optimizer can report its usual error.
+    """
+
+    def is_tensor_like(obj):
+        # Duck-typed on purpose: anything exposing numel() counts as one parameter.
+        return callable(getattr(obj, 'numel', None))
+
+    # Not a collection of parameters: leave validation to the optimizer.
+    if params is None or isinstance(params, dict) or is_tensor_like(params):
+        return params
+
+    # Materialize first: generators (e.g. model.parameters()) can only be consumed once.
+    param_list = list(params)
+    if not param_list:
+        return params if isinstance(params, (list, tuple)) else param_list
+
+    if not isinstance(param_list[0], dict):
+        # params is a flat collection of Parameters
+        return [p for p in param_list if p.numel() > 0]
+
+    # params is a collection of parameter group dicts
+    filtered_params = []
+    for param_group in param_list:
+        if 'params' not in param_group:
+            # A group without parameters has nothing to optimize
+            continue
+        group_params = param_group['params']
+        if is_tensor_like(group_params):
+            # A group may hold a single tensor; iterating it would yield its rows
+            group_params = [group_params]
+        trainable_params = [p for p in group_params if p.numel() > 0]
+        # Only add group if it has non-empty parameters
+        if trainable_params:
+            filtered_params.append({**param_group, 'params': trainable_params})
+    return filtered_params
+
+
 graph_cache = {}
 
 
diff --git a/tests/unit/runtime/half_precision/onebit/test_onebit.py b/tests/unit/runtime/half_precision/onebit/test_onebit.py
index 014c29baeaa7..cc8180563db5 100644
--- a/tests/unit/runtime/half_precision/onebit/test_onebit.py
+++ b/tests/unit/runtime/half_precision/onebit/test_onebit.py
@@ -1244,3 +1244,95 @@ def torch_sim(a):
         if torch.sum(check_mag_mask) != 0:
             print("Fails at {} of positions".format(torch.sum(check_mag_mask)))
         assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0
+
+
+class TestOneBitLambEmptyParameters(DistributedTest):
+    world_size = 2
+
+    def test(self):
+        """Test that OnebitLamb correctly filters out empty parameters (numel=0)"""
+        if not get_accelerator().is_fp16_supported():
+            pytest.skip("fp16 is not supported")
+
+        # Create a model with normal and empty parameters
+        class ModelWithEmptyParam(torch.nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+                # Empty parameter (0 elements)
+                self.empty_param = torch.nn.Parameter(torch.empty(0, 10))
+
+            def forward(self, x, y):
+                return self.cross_entropy_loss(self.linear(x), y)
+
+        model = ModelWithEmptyParam()
+        model.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+
+        # Create parameter groups including empty parameter
+        param_groups = [
+            {
+                'params': [model.linear.weight, model.linear.bias],
+                'weight_decay': 0.01
+            },
+            {
+                'params': [model.empty_param],
+                'weight_decay': 0.0
+            }  # Empty parameter
+        ]
+
+        config_dict = {
+            "train_batch_size": 2,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "OneBitLamb",
+                "params": {
+                    "lr": 0.00015,
+                    "weight_decay": 0.01,
+                    "max_coeff": 0.3,
+                    "min_coeff": 0.01,
+                    "freeze_step": 2,
+                    "cuda_aware": False,
+                    "comm_backend_name": get_accelerator().communication_backend_name(),
+                    "coeff_beta": 0.9,
+                    "factor_max": 1.0,
+                    "factor_min": 0.5,
+                    "factor_threshold": 0.1,
+                },
+            },
+            "gradient_clipping": 1.0,
+            "fp16": {
+                "enabled": True,
+                "loss_scale": 0,
+                "initial_scale_power": 16,
+            },
+        }
+
+        # Verify empty parameter is filtered out
+        model, optimizer, _, _ = deepspeed.initialize(
+            config=config_dict,
+            model=model,
+            model_parameters=param_groups,
+        )
+
+        # Check that empty parameter is not in optimizer param_groups
+        for group in optimizer.optimizer.param_groups:
+            for p in group['params']:
+                assert p.numel() > 0, "Empty parameters should be filtered out"
+
+        # Run a few training steps to ensure no NaN
+        data_loader = random_dataloader(
+            model=model,
+            total_samples=20,
+            hidden_dim=10,
+            device=model.device,
+            dtype=torch.float16,
+        )
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+            # Verify no NaN in parameters
+            for group in optimizer.optimizer.param_groups:
+                for p in group['params']:
+                    assert not torch.isnan(p).any(), "Parameters should not contain NaN"