From 9a9f6383c7f5d7ac4906afde1fdef52ace1bb725 Mon Sep 17 00:00:00 2001 From: Yubo Gao Date: Thu, 4 Sep 2025 17:24:37 -0700 Subject: [PATCH 1/3] make clear_cache optional Signed-off-by: Yubo Gao --- ...-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml | 1 + nemo_rl/models/policy/__init__.py | 1 + nemo_rl/models/policy/dtensor_policy_worker.py | 9 ++++++++- nemo_rl/models/policy/dtensor_policy_worker_v2.py | 10 +++++++++- 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml b/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml index 084ea843f2..e20d7970bb 100644 --- a/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml +++ b/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml @@ -49,6 +49,7 @@ policy: tensor_parallel_size: 8 context_parallel_size: 1 custom_parallel_plan: null + clear_cache_every_n_steps: 1 env_vars: PYTORCH_CUDA_ALLOC_CONF: "max_split_size_mb:64" diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index 3f2fcfe877..7e38938db9 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -27,6 +27,7 @@ class DTensorConfig(TypedDict): tensor_parallel_size: NotRequired[int] context_parallel_size: NotRequired[int] custom_parallel_plan: NotRequired[str] + clear_cache_every_n_steps: NotRequired[int] class SequencePackingConfig(TypedDict): diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 7b2f0de271..c0ad1965d3 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -16,6 +16,7 @@ import gc import itertools import os +import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from typing 
import Any, Generator, Iterable, Optional, Set, Union, cast @@ -629,10 +630,16 @@ def train( mb_iterator = batch.make_microbatch_iterator(mbs) iterator_len = batch.size // mbs + empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get("empty_cache_every_n_steps") + if empty_cache_steps: + warnings.warn(f"Emptying cache every {empty_cache_steps} microbatches, doing so unnnecessarily would incur a large performance overhead.") + for mb_idx, mb in enumerate( itertools.chain(mb_iterator, dummy_iterator) ): - torch.cuda.empty_cache() + # Conditionally empty cache when sensitive to fragmentation + if empty_cache_steps and mb_idx % empty_cache_steps == 0: + torch.cuda.empty_cache() with torch.autocast(device_type="cuda", dtype=self.dtype): if self.enable_seq_packing: diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 8d56c3e6eb..868d35c7fa 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -15,6 +15,7 @@ import gc import itertools import os +import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from typing import Any, Generator, Iterable, Optional, cast @@ -557,10 +558,17 @@ def train( mb_iterator = batch.make_microbatch_iterator(mbs) iterator_len = batch.size // mbs + empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get("empty_cache_every_n_steps") + if empty_cache_steps: + warnings.warn(f"Emptying cache every {empty_cache_steps} microbatches, doing so unnnecessarily would incur a large performance overhead.") + for mb_idx, mb in enumerate( itertools.chain(mb_iterator, dummy_iterator) ): - torch.cuda.empty_cache() + + # Conditioanlly empty cache when sensitive to fragmentation + if empty_cache_steps and mb_idx % empty_cache_steps == 0: + torch.cuda.empty_cache() with torch.autocast(device_type="cuda", dtype=self.dtype): if self.enable_seq_packing: From 
852e2532db273645ab2b0055cd329429b8104526 Mon Sep 17 00:00:00 2001 From: Yubo Gao Date: Thu, 4 Sep 2025 17:30:49 -0700 Subject: [PATCH 2/3] lint Signed-off-by: Yubo Gao --- nemo_rl/models/policy/dtensor_policy_worker.py | 8 ++++++-- nemo_rl/models/policy/dtensor_policy_worker_v2.py | 9 ++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index c0ad1965d3..14d6a118e6 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -630,9 +630,13 @@ def train( mb_iterator = batch.make_microbatch_iterator(mbs) iterator_len = batch.size // mbs - empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get("empty_cache_every_n_steps") + empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get( + "clear_cache_every_n_steps" + ) if empty_cache_steps: - warnings.warn(f"Emptying cache every {empty_cache_steps} microbatches, doing so unnnecessarily would incur a large performance overhead.") + warnings.warn( + f"Emptying cache every {empty_cache_steps} microbatches, doing so unnecessarily would incur a large performance overhead." 
+ ) for mb_idx, mb in enumerate( itertools.chain(mb_iterator, dummy_iterator) diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 868d35c7fa..67416dfa56 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -558,14 +558,17 @@ def train( mb_iterator = batch.make_microbatch_iterator(mbs) iterator_len = batch.size // mbs - empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get("empty_cache_every_n_steps") + empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get( + "clear_cache_every_n_steps" + ) if empty_cache_steps: - warnings.warn(f"Emptying cache every {empty_cache_steps} microbatches, doing so unnnecessarily would incur a large performance overhead.") + warnings.warn( + f"Emptying cache every {empty_cache_steps} microbatches, doing so unnecessarily would incur a large performance overhead." + ) for mb_idx, mb in enumerate( itertools.chain(mb_iterator, dummy_iterator) ): - # Conditioanlly empty cache when sensitive to fragmentation if empty_cache_steps and mb_idx % empty_cache_steps == 0: torch.cuda.empty_cache() From cf551e5d80d7cc6272d3a69ea295c24d3fed0b7c Mon Sep 17 00:00:00 2001 From: Yubo Gao Date: Fri, 5 Sep 2025 07:02:31 -0700 Subject: [PATCH 3/3] fix base config Signed-off-by: Yubo Gao --- examples/configs/dpo.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 4a438e127e..fe953390e8 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -54,6 +54,7 @@ policy: tensor_parallel_size: 1 context_parallel_size: 1 custom_parallel_plan: null + clear_cache_every_n_steps: null dynamic_batching: enabled: false