Merged
30 commits
f8b9514 CPU offload changes (yfw, Apr 2, 2025)
c02d454 Fix unit tests (yfw, Apr 3, 2025)
7229b87 Add memory tracking in unit test (yfw, Apr 4, 2025)
c299a75 Merge branch 'main' into yifu/cpu_offload2 (yfw, Apr 4, 2025)
3b13a9a ruff (yfw, Apr 4, 2025)
95c4508 Remove expandable_segments_enabled (yfw, Apr 4, 2025)
44edbe8 whitespace (yfw, Apr 4, 2025)
412f57d Offload/onload buffers (yfw, Apr 5, 2025)
ba48445 Activation checkpointing (yfw, Apr 5, 2025)
ef25d4b ruff + tests (yfw, Apr 5, 2025)
b26b86a No fsdp for single gpu (yfw, Apr 10, 2025)
65d47de Cleanup (yfw, Apr 10, 2025)
d511e0b Merge remote-tracking branch 'origin' into yifu/cpu_offload2 (yfw, Apr 10, 2025)
45f6de4 ruff (yfw, Apr 10, 2025)
b087ae2 ruff (yfw, Apr 10, 2025)
9194a57 Fix unit tests (yfw, Apr 10, 2025)
baf8396 Fix checkpoint (yfw, Apr 10, 2025)
3faa60e Increase cicd timeout (yfw, Apr 11, 2025)
2eef1a5 Update nemo_reinforcer/models/policy/hf_policy.py (yfw, Apr 11, 2025)
8df247c Merge branch 'main' into yifu/cpu_offload2 (terrykong, Apr 11, 2025)
e926f92 Fix configs (yfw, Apr 12, 2025)
2887c56 ruff (yfw, Apr 12, 2025)
949b1ae Merge remote-tracking branch 'origin' into yifu/cpu_offload2 (yfw, Apr 15, 2025)
be4a3fa Comment about held params (yfw, Apr 15, 2025)
30b826e Merge remote-tracking branch 'origin' into yifu/cpu_offload2 (yfw, Apr 16, 2025)
989ea20 Fix test (yfw, Apr 16, 2025)
6dd553c Longer cicd timeout (yfw, Apr 16, 2025)
31223ac 30 min timeout (yfw, Apr 16, 2025)
dd3e713 Merge remote-tracking branch 'origin' into yifu/cpu_offload2 (yfw, Apr 16, 2025)
75c0d2a Merge remote-tracking branch 'origin' into yifu/cpu_offload2 (yfw, Apr 16, 2025)
2 changes: 1 addition & 1 deletion .github/workflows/cicd-main.yml
@@ -150,7 +150,7 @@ jobs:
     if: ${{ needs.pre-flight.outputs.test_level != 'none' }}
     with:
       RUNNER: self-hosted-azure
-      TIMEOUT: 20
+      TIMEOUT: 30
       UNIT_TEST_SCRIPT: |
         cd /opt/reinforcer
         if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L0|L1|L2)$ ]]; then
2 changes: 2 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -33,6 +33,8 @@ policy:
   logprob_batch_size: 4
   max_total_sequence_length: 512
   precision: "bfloat16"
+  fsdp_offload_enabled: false
+  activation_checkpointing_enabled: false

   optimizer:
     name: "torch.optim.AdamW"
2 changes: 2 additions & 0 deletions examples/configs/grpo_math_8B.yaml
@@ -15,6 +15,8 @@ policy:
   logprob_batch_size: 2
   max_total_sequence_length: 4096
   precision: "bfloat16"
+  fsdp_offload_enabled: false
+  activation_checkpointing_enabled: false

   optimizer:
     name: "torch.optim.AdamW"
2 changes: 2 additions & 0 deletions examples/configs/sft.yaml
@@ -25,6 +25,8 @@ policy:
   train_micro_batch_size: 1
   max_total_sequence_length: 1024
   precision: "float32"
+  fsdp_offload_enabled: false
+  activation_checkpointing_enabled: false

   optimizer:
     name: "torch.optim.AdamW"
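All three example configs gain the same two policy flags, defaulted off so existing runs behave unchanged. A minimal sketch of reading them back out of one of these files — the helper name and path handling are ours, not the repo's; only the two keys come from this diff:

    import yaml  # assumes PyYAML is available

    def read_offload_flags(path):
        # Hypothetical helper: pull the two new flags out of an example config.
        with open(path) as f:
            policy = yaml.safe_load(f)["policy"]
        return (
            policy.get("fsdp_offload_enabled", False),
            policy.get("activation_checkpointing_enabled", False),
        )

    fsdp_offload, act_ckpt = read_offload_flags("examples/configs/sft.yaml")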
4 changes: 2 additions & 2 deletions nemo_reinforcer/distributed/worker_groups.py
@@ -135,7 +135,8 @@ def __call__(
         if env_vars:
             if "runtime_env" not in options:
                 options["runtime_env"] = {}
-            options["runtime_env"]["env_vars"] = env_vars
+            for k, v in env_vars.items():
+                options["runtime_env"]["env_vars"][k] = v

         # Apply initialization parameters
         if init_kwargs:
@@ -207,7 +208,6 @@ def __init__(
         self.cluster = cluster
         self.name_prefix = name_prefix
         self.tied_workers_groups = []
-
         # Maps worker indices to their corresponding tied group index
         # For example, if worker with index 3 belongs to tied worker group 1,
         # then worker_to_tied_group_index[3] = 1
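The per-key loop in the first hunk merges new env vars into runtime_env instead of overwriting whatever was already there; note it indexes options["runtime_env"]["env_vars"] directly, so that inner dict must already exist. A standalone sketch of the two behaviors on plain dicts (toy values, not the repo's code):

    options = {"runtime_env": {"env_vars": {"EXISTING": "1"}}}
    env_vars = {"NEW": "2"}

    # Wholesale replacement (old behavior) drops EXISTING.
    replaced = dict(env_vars)  # {'NEW': '2'}

    # Per-key assignment (new behavior) keeps EXISTING alongside NEW.
    merged = dict(options["runtime_env"]["env_vars"])
    for k, v in env_vars.items():
        merged[k] = v  # {'EXISTING': '1', 'NEW': '2'}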
2 changes: 2 additions & 0 deletions nemo_reinforcer/models/policy/__init__.py
@@ -31,3 +31,5 @@ class PolicyConfig(TypedDict):
     logprob_batch_size: int
     generation: GenerationConfig
     precision: str
+    fsdp_offload_enabled: bool
+    activation_checkpointing_enabled: bool
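Both new PolicyConfig fields are plain bools, so any constructed config dict must now carry them. A self-contained mirror of just the two added fields (the real PolicyConfig has more keys than this hunk shows):

    from typing import TypedDict

    class OffloadFlags(TypedDict):
        # Mirrors only the two fields this diff adds to PolicyConfig.
        fsdp_offload_enabled: bool
        activation_checkpointing_enabled: bool

    flags: OffloadFlags = {
        "fsdp_offload_enabled": False,
        "activation_checkpointing_enabled": False,
    }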
114 changes: 84 additions & 30 deletions nemo_reinforcer/models/policy/hf_policy.py
@@ -22,6 +22,7 @@
 import torch
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import (
+    CPUOffload,
     FullyShardedDataParallel,
     MixedPrecision,
 )
@@ -110,6 +111,12 @@ def __init__(
         # ------------------------------------------------

         def do_fsdp(model):
+            if world_size == 1:
+                print(
+                    "[INFO] Using a single GPU - skipping FSDP wrapper to avoid GPU memory offloading issues"
+                )
+                return model
+
             # Create a device mesh with 'world_size' GPUs in a 1D arrangement.
             mesh = init_device_mesh("cuda", (world_size,))
             mp_policy = MixedPrecision(
@@ -118,21 +125,32 @@ def do_fsdp(model):
                 buffer_dtype=torch.float32,
             )

+            cpu_offload = (
+                CPUOffload(offload_params=True)
+                if self.cfg["fsdp_offload_enabled"]
+                else None
+            )
+
             return FullyShardedDataParallel(
                 model,
                 device_mesh=mesh,
                 auto_wrap_policy=size_based_auto_wrap_policy,
                 mixed_precision=mp_policy,
+                cpu_offload=cpu_offload,
             )

         self.model.to("cuda")
+        if self.cfg["activation_checkpointing_enabled"]:
+            self.model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs={"use_reentrant": False}
+            )
         self.model = do_fsdp(self.model)
-        self.model = self.move_to_cpu(self.model)
+        self.model = self.manual_offload_to_cpu(self.model)
         if self.reference_model is not None:
             self.reference_model.to("cuda")
             self.reference_model = do_fsdp(self.reference_model)
-            self.reference_model = self.move_to_cpu(self.reference_model)
-        self.model.to("cuda")
+            self.reference_model = self.manual_offload_to_cpu(self.reference_model)
+        self.model = self.manual_load_to_gpu(self.model)
         self._held_reference_model_params = None
         # register_fsdp_forward_method(self.model, "generate")
         if init_optimizer:
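For reference, CPUOffload(offload_params=True) is FSDP1's built-in offload: sharded parameters (and their gradients) live on CPU between uses, which is why the manual offload helpers later in this file no-op when the flag is set. A standalone sketch of the wrap under stated assumptions (dtypes are illustrative, torch.distributed already initialized, CUDA available):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import (
        CPUOffload,
        FullyShardedDataParallel,
        MixedPrecision,
    )
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    def wrap_with_optional_offload(model: nn.Module, offload: bool) -> nn.Module:
        # Sketch only: not the repo's exact mixed-precision policy.
        return FullyShardedDataParallel(
            model.cuda(),
            auto_wrap_policy=size_based_auto_wrap_policy,
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
            ),
            # offload_params=True also offloads gradients to CPU in FSDP1.
            cpu_offload=CPUOffload(offload_params=True) if offload else None,
        )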
@@ -440,8 +458,8 @@ def use_reference_model(self):
         original_model = self.model
         original_reference_model = self.reference_model

-        self.model = self.move_to_cpu(self.model)
-        self.reference_model = self.reference_model.to("cuda")
+        self.model = self.manual_offload_to_cpu(self.model)
+        self.reference_model = self.manual_load_to_gpu(self.reference_model)

         # Swap the references
         self.model, self.reference_model = self.reference_model, self.model
@@ -454,8 +472,8 @@

         finally:
             # Restore original references and device placement
-            self.reference_model = self.move_to_cpu(original_reference_model)
-            self.model = original_model.to("cuda")
+            self.reference_model = self.manual_offload_to_cpu(original_reference_model)
+            self.model = self.manual_load_to_gpu(original_model)
             gc.collect()
             torch.cuda.empty_cache()

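The two hunks above form a swap-and-restore pattern: offload the policy, onload the reference model, swap the two attributes, then undo both in finally regardless of what the body raised. The same shape in isolation, as a sketch (offload/onload stand in for the manual_* helpers; holder is any object with the two attributes):

    from contextlib import contextmanager

    @contextmanager
    def borrowed_reference_model(holder, offload, onload):
        # Sketch: temporarily expose holder.reference_model as holder.model.
        original_model = holder.model
        original_reference = holder.reference_model
        try:
            holder.model = offload(holder.model)                    # policy -> CPU
            holder.reference_model = onload(holder.reference_model) # reference -> GPU
            holder.model, holder.reference_model = (
                holder.reference_model,
                holder.model,
            )
            yield
        finally:
            holder.reference_model = offload(original_reference)
            holder.model = onload(original_model)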
@@ -537,7 +555,13 @@ def generate(
             # Set attention mask for the actual tokens (at the end for left padding)
             left_padded_attention_mask[i, seq_len - length :] = 1

-        outputs = self.model.module.generate(
+        if isinstance(
+            self.model, torch.distributed.fsdp.FullyShardedDataParallel
+        ):
+            generation_module = self.model.module
+        else:
+            generation_module = self.model
+        outputs = generation_module.generate(
             input_ids=left_padded_input_ids,
             attention_mask=left_padded_attention_mask,
             max_new_tokens=gen_cfg["max_new_tokens"],
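Because the single-GPU path no longer wraps the model, generate() can no longer assume a .module attribute. The unwrap step in isolation, as a sketch:

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel

    def generation_module_of(model: nn.Module) -> nn.Module:
        # FSDP1 exposes the wrapped HF model as .module; a bare model is used as-is.
        return model.module if isinstance(model, FullyShardedDataParallel) else model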
@@ -724,6 +748,11 @@ def report_device_id(self) -> str:
     def get_weight_ipc_handles(self, offload_model=True):
         from torch.multiprocessing.reductions import reduce_tensor

+        # If the model is not FSDP, then we need to manually move it to the GPU
+        # For an FSDP model, model.state_dict() will move the params to the GPU
+        if not isinstance(self.model, torch.distributed.fsdp.FullyShardedDataParallel):
+            self.model = self.manual_load_to_gpu(self.model)
+
         # TODO @sahilj: do this without an allgather (maybe FSDP2)
         params = self.model.state_dict()

@@ -735,49 +764,50 @@

         # Replace the original params with the converted ones
         params = dtype_params
+        # For FSDP1, params may get GC'ed before sending to vllm,
+        # so we need to hold a reference to them
+        self._held_reference_model_params = params
         data = {}
         device_uuid = self.report_device_id()
         for name, p in params.items():
             data[name] = reduce_tensor(p.detach())

         if offload_model:
-            self.model = self.move_to_cpu(self.model)
+            self.model = self.manual_offload_to_cpu(self.model)
             gc.collect()
             torch.cuda.empty_cache()
         return {device_uuid: data}

     def prepare_for_lp_inference(self):
-        self.model.to("cuda")
+        self.model = self.manual_load_to_gpu(self.model)
         self.model.eval()
         self.offload_before_refit()

     def prepare_for_training(self, *args, **kwargs):
         # onload models and optimizer state to cuda
-        self.model.to("cuda")
+        self.model = self.manual_load_to_gpu(self.model)
         self.model.train()

-        # Move optimizer state to CUDA if it exists
-        if hasattr(self, "optimizer") and self.optimizer is not None:
-            for state in self.optimizer.state.values():
-                for k, v in state.items():
-                    if torch.is_tensor(v) and not v.is_cuda:
-                        state[k] = v.to("cuda")
+        if not self.cfg["fsdp_offload_enabled"]:
+            # Move optimizer state to CUDA if it exists
+            if hasattr(self, "optimizer") and self.optimizer is not None:
+                for state in self.optimizer.state.values():
+                    for k, v in state.items():
+                        if torch.is_tensor(v) and not v.is_cuda:
+                            state[k] = v.to("cuda")

         torch.cuda.empty_cache()

     @torch.no_grad()
     def offload_before_refit(self):
         """Offload the optimizer and buffers to the CPU."""
         torch.randn(1).cuda()  # wake up torch allocator
-        if hasattr(self, "optimizer") and self.optimizer is not None:
-            for state in self.optimizer.state.values():
-                for k, v in state.items():
-                    if torch.is_tensor(v):
-                        state[k] = v.to("cpu")
-
-        for buffer in self.model.buffers():
-            buffer.data = buffer.data.to("cpu")
+        if not self.cfg["fsdp_offload_enabled"]:
+            if hasattr(self, "optimizer") and self.optimizer is not None:
+                for state in self.optimizer.state.values():
+                    for k, v in state.items():
+                        if torch.is_tensor(v):
+                            state[k] = v.to("cpu")

         gc.collect()
         torch.cuda.empty_cache()
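The held-params comment matters because reduce_tensor() serializes a CUDA IPC handle, not the data: if the producer drops its last reference before the consumer (here vLLM) maps the handle, the storage can be freed underneath it. A minimal sketch of the handle-plus-reference pattern (single process shown; real use crosses process boundaries, and params is assumed to map names to CUDA tensors):

    from torch.multiprocessing.reductions import reduce_tensor

    def export_ipc_handles(params):
        # Keep a strong reference alive for as long as the handles may be used.
        held = dict(params)
        handles = {name: reduce_tensor(p.detach()) for name, p in params.items()}
        return handles, held  # caller must retain `held`, not just `handles`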
@@ -792,7 +822,7 @@ def offload_before_refit(self):
     @torch.no_grad()
     def offload_after_refit(self):
         # Offload as much as possible on the CPU
-        self.model = self.move_to_cpu(self.model)
+        self.model = self.manual_offload_to_cpu(self.model)
         self.model.eval()
         torch.randn(1).cuda()  # wake up torch allocator
         self.offload_before_refit()  # rerun the old offload function
@@ -810,15 +840,39 @@ def offload_after_refit(self):
             f"GPU Memory after refit complete: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved"
         )

-    def move_to_cpu(self, model):
+    def manual_offload_to_cpu(self, model):
+        if self.cfg["fsdp_offload_enabled"]:
+            return model
+
         for param in model.parameters():
-            param.data = param.data.to("cpu")
+            param.data = param.data.to("cpu", non_blocking=True)
+            if hasattr(param, "_local_shard"):
+                param._local_shard = param.data
+            if param.grad is not None:
+                param.grad = param.grad.to("cpu", non_blocking=True)
+        for buffer in model.buffers():
+            buffer.data = buffer.data.to("cpu", non_blocking=True)
+
+        if hasattr(model, "_fsdp_wrapped_module"):
+            self.manual_offload_to_cpu(model._fsdp_wrapped_module)
+
+        return model
+
+    def manual_load_to_gpu(self, model):
+        if self.cfg["fsdp_offload_enabled"]:
+            return model
+
+        for param in model.parameters():
+            param.data = param.data.to("cuda", non_blocking=True)
+            if hasattr(param, "_local_shard"):
+                param._local_shard = param.data
+            if param.grad is not None:
+                param.grad = param.grad.to("cuda", non_blocking=True)
         for buffer in model.buffers():
-            buffer.data = buffer.data.to("cpu")
+            buffer.data = buffer.data.to("cuda", non_blocking=True)

         if hasattr(model, "_fsdp_wrapped_module"):
-            model._fsdp_wrapped_module.to("cpu")
+            self.manual_load_to_gpu(model._fsdp_wrapped_module)

         return model

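One caveat on the helpers above: non_blocking=True on a device-to-host copy only actually overlaps when the destination is pinned, and the bytes are not safe to read until the copy stream has synchronized. A hedged sketch of the same move with both details made explicit (this is not what the merged code does; it assumes the param currently lives on a CUDA device):

    import torch

    def offload_param_safely(param: torch.nn.Parameter) -> None:
        # Stage through pinned host memory so the D2H copy can be asynchronous.
        cpu_copy = torch.empty(
            param.data.shape, dtype=param.data.dtype, device="cpu", pin_memory=True
        )
        cpu_copy.copy_(param.data, non_blocking=True)
        torch.cuda.synchronize()  # make the async D2H copy visible on the host
        param.data = cpu_copy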
5 changes: 4 additions & 1 deletion nemo_reinforcer/utils/native_checkpoint.py
@@ -154,7 +154,10 @@ def save_checkpoint(
         save_hf: Whether to save in HuggingFace format
     """
     if save_hf:
-        model_state_dict = model._fsdp_wrapped_module.state_dict()
+        if hasattr(model, "_fsdp_wrapped_module"):
+            model_state_dict = model._fsdp_wrapped_module.state_dict()
+        else:
+            model_state_dict = model.state_dict()

         if torch.distributed.get_rank() == 0:
             # Create a new path by appending "-hf" to the weights path
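The hasattr probe generalizes save_checkpoint to both wrapped and bare models; the same idea as a tiny helper (the function name is ours, not the repo's):

    import torch.nn as nn

    def unwrapped_state_dict(model: nn.Module) -> dict:
        # FSDP1 keeps the original module under _fsdp_wrapped_module;
        # a bare single-GPU model has no such attribute.
        inner = getattr(model, "_fsdp_wrapped_module", model)
        return inner.state_dict()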
2 changes: 2 additions & 0 deletions tests/unit/models/generation/test_vllm_generation.py
@@ -61,6 +61,8 @@
     "max_new_tokens": 16,
     "do_sample": False,
     "precision": "float32",
+    "fsdp_offload_enabled": False,
+    "activation_checkpointing_enabled": False,
     "optimizer": {
         "name": "torch.optim.AdamW",
         "kwargs": {