fix(infra): drain RTensor _fetch_buffer on all consumer workers #1282
Conversation
Code Review
This pull request addresses memory leaks (issue #1209) by ensuring that client-side RTensor fetch buffers are properly drained across all processes at the end of training steps. It implements a two-stage clear_batches process in the TrainController, adds utility functions for buffer management in rtensor.py, and updates various engine and trainer classes to invoke these cleanup routines. Feedback includes a suggestion to deduplicate shard IDs in flatten_shard_ids to optimize RPC payloads and a note on potential lock contention in clear_fetch_buffer.
```python
    so every head must see the full sid set to drain completely.
    """
    shards_by_node = RTensor.collect_shards(obj)
    return [sid for sids in shards_by_node.values() for sid in sids]
```
The flatten_shard_ids function returns a list that may contain duplicate shard IDs if the same RTensor appears multiple times in the input structure (e.g., across different micro-batches or roles). While clear_fetch_buffer handles duplicates correctly, deduplicating the list here reduces the RPC payload size and the number of pop operations on the workers.
Suggested change:

```diff
-    return [sid for sids in shards_by_node.values() for sid in sids]
+    return list({sid for sids in shards_by_node.values() for sid in sids})
```
```python
        n = len(_fetch_buffer)
        _fetch_buffer.clear()
        return n
    return sum(_fetch_buffer.pop(sid, None) is not None for sid in shard_ids)
```
Iterating over shard_ids while holding _fetch_buffer_lock can block other threads (like those calling to_local() or localize()) for the duration of the loop. If shard_ids is a very large collection, this could lead to contention. Consider deduplicating shard_ids before entering the lock context or using a more efficient bulk removal if possible.
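A minimal sketch of that suggestion, assuming `_fetch_buffer` is a plain module-level dict guarded by `_fetch_buffer_lock` (names mirror the diff above; the dedup-before-lock step is the only addition):

```python
import threading
from typing import Iterable, Optional

# Module-level cache and lock, mirroring the names used in rtensor.py.
_fetch_buffer: dict[str, object] = {}
_fetch_buffer_lock = threading.Lock()


def clear_fetch_buffer(shard_ids: Optional[Iterable[str]] = None) -> int:
    """Drop cached shards; return how many entries were actually removed."""
    if shard_ids is None:
        # Flush everything (existing behaviour).
        with _fetch_buffer_lock:
            n = len(_fetch_buffer)
            _fetch_buffer.clear()
            return n
    # Deduplicate *outside* the lock so the critical section scales with the
    # number of distinct shard IDs, not the raw payload length.
    unique_ids = set(shard_ids)
    with _fetch_buffer_lock:
        return sum(_fetch_buffer.pop(sid, None) is not None for sid in unique_ids)
```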
@sitabulaixizawaluduo Went with Option B. Changes:
Ready for another look when you have time.
Force-pushed from 24d752b to 38605ba.
Hey, I accidentally hit the rebase button on the GitHub UI, which made me the committer on your commits. Could you do a quick rebase and force push to fix it?
Force-pushed from 38605ba to 5b3872c.
Force-pushed from 5b3872c to acda439.
The module-level _fetch_buffer in areal/infra/rpc/rtensor.py is a
per-process cache populated by RTensor.to_local() on every RPC arrival
that carries meta RTensors. End-of-step cleanup only touched the
controller's cache (RTensor.clear_node) and the storage owner's
_storage (HTTP DELETE /data/clear -> rtensor.remove). The _fetch_buffer
on cross-node consumer workers (actor/critic/ref DP heads) was never
drained, so RSS grew ~2 GB/step for VLM training and crashed around
step 113 at train_batch_size=32.
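For intuition, a deliberately stripped-down sketch of the leak pattern; the names and shapes are illustrative only, not the real `rtensor.py` code:

```python
# Illustrative only: a per-process, module-level cache keyed by shard ID.
# Every RPC arrival that carries a meta RTensor inserts into it; before this
# PR, nothing on a cross-node consumer worker ever popped entries back out.
from typing import Callable

_fetch_buffer: dict[str, bytes] = {}


def to_local(shard_id: str, fetch: Callable[[str], bytes]) -> bytes:
    """Return the shard's payload, fetching from the owner and caching it."""
    if shard_id not in _fetch_buffer:
        _fetch_buffer[shard_id] = fetch(shard_id)  # e.g. an HTTP GET to the owner node
    return _fetch_buffer[shard_id]
```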
Establish a three-layer cleanup invariant -- every process that could
populate _fetch_buffer now drains at step end:
1. Controller: RTensor.clear_node pops locally (unchanged).
2. Storage owner: remove() also pops _fetch_buffer, covering the
storage-owner-as-consumer case where one process both stores and
localizes a shard.
3. Cross-node consumer: TrainController.clear_batches now fans out a
replicated RPC carrying a flat list[str] of shard IDs to every DP
head, which runs clear_fetch_buffer(sids) on its local buffer.
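A hedged sketch of what layer 2 amounts to, assuming the storage owner keeps shards in a module-level `_storage` dict and `remove()` is called per shard ID (the dict names come from the text above; the lock handling around `_fetch_buffer` is an assumption):

```python
import threading

# Module-level state on the storage owner, mirroring the names above.
_storage: dict[str, object] = {}
_fetch_buffer: dict[str, object] = {}
_fetch_buffer_lock = threading.Lock()


def remove(shard_id: str) -> None:
    """Layer 2: the HTTP DELETE /data/clear handler drops the stored shard
    *and* any locally cached copy, covering the case where one process both
    stores and localizes a shard."""
    _storage.pop(shard_id, None)
    with _fetch_buffer_lock:
        _fetch_buffer.pop(shard_id, None)
```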
Sending a flat list (rather than the original nested RTensor structure)
is deliberate: engine_blueprint always runs RTensor.localize on
incoming args -- passing RTensors would re-fetch the shards we are
trying to clear. _is_tensor_like(list[str]) is False, so dispatch goes
through _replicate_inputs and every DP head sees the full sid set.
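A sketch of how the two fan-outs in TrainController.clear_batches could fit together; the collaborators are passed in as callables here because the controller's real helper methods are not shown in this PR excerpt:

```python
from typing import Callable


def clear_batches(
    batches: object,
    clear_node: Callable[[object], None],             # RTensor.clear_node
    flatten_shard_ids: Callable[[object], list[str]],
    replicate_rpc: Callable[[str, list[str]], None],  # replicated call to every DP head
) -> None:
    # Fan-out 1: controller-local pop plus HTTP DELETE /data/clear on the owners.
    clear_node(batches)
    # Fan-out 2: drain every DP head's _fetch_buffer. A flat list[str] is not
    # tensor-like, so dispatch replicates it verbatim instead of localizing it,
    # which would re-fetch the very shards being cleared.
    shard_ids = flatten_shard_ids(batches)
    replicate_rpc("clear_batches", shard_ids)
```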
Observability: TrainController.clear_batches always calls
fetch_buffer_stats on DP head 0 after drain and logs WARNING on leak,
DEBUG on clean. The stats RPC is wrapped in try/except -- observability
must not break training.
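The post-drain check could look roughly like this; `fetch_buffer_stats` is the new engine RPC, and its return shape (a dict with an entry count) is an assumption:

```python
import logging
from typing import Callable

logger = logging.getLogger(__name__)


def log_drain_check(fetch_buffer_stats: Callable[[], dict]) -> None:
    """Query DP head 0 after the drain and log the result; never raise."""
    try:
        stats = fetch_buffer_stats()
        if stats.get("entries", 0) > 0:
            logger.warning("RTensor fetch buffer still holds entries after drain: %s", stats)
        else:
            logger.debug("RTensor fetch buffer clean after drain: %s", stats)
    except Exception:
        # Observability must not break training.
        logger.debug("fetch_buffer_stats check skipped", exc_info=True)
```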
Trainer fan-out widened: rl_trainer now clears actor/critic/ref (was:
actor only); dpo_trainer adds ref clear. sft/rw trainers unchanged
(only actor localizes batches).
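The widened trainer fan-out is small; a sketch, assuming each role controller exposes clear_batches and disabled roles are None:

```python
def clear_step_batches(batch: object, roles: tuple) -> None:
    """Drain the step's batch on every role that localized it (actor, critic,
    ref for rl_trainer; actor, ref for dpo_trainer)."""
    for role in roles:
        if role is not None:
            role.clear_batches(batch)
```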
Key changes:
- areal/infra/rpc/rtensor.py: add clear_fetch_buffer,
fetch_buffer_stats, flatten_shard_ids helpers; remove() pops
_fetch_buffer alongside _storage
- areal/infra/controller/train_controller.py: two-fan-out
clear_batches (HTTP DELETE + replicated engine RPC) with
observability log
- areal/engine/{fsdp,megatron,sglang_remote,vllm_remote}.py and
areal/experimental/engine/archon_engine.py: implement
clear_batches(shard_ids) and fetch_buffer_stats() RPC targets
- areal/trainer/{rl_trainer,dpo_trainer}.py: fan out clear_batches
to every role that localizes the batch
- tests/test_rtensor.py: 7 new unit tests covering
clear_fetch_buffer selective/flush, fetch_buffer_stats,
remove()-pops-fetch-buffer regression, flatten_shard_ids on
nested/empty structures
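For illustration only, the selective-clear and remove()-regression tests might look roughly like this; the module path, the remove() signature, and direct access to the private dicts are all assumptions, not the actual test file:

```python
# Hypothetical pytest sketch -- not the actual tests/test_rtensor.py.
from areal.infra.rpc import rtensor


def test_clear_fetch_buffer_selective():
    rtensor._fetch_buffer.update({"a": 1, "b": 2, "c": 3})
    # Only the requested IDs are popped; unknown IDs are ignored.
    assert rtensor.clear_fetch_buffer(["a", "c", "missing"]) == 2
    assert "b" in rtensor._fetch_buffer


def test_remove_also_pops_fetch_buffer():
    rtensor._storage["sid-0"] = object()
    rtensor._fetch_buffer["sid-0"] = object()
    rtensor.remove("sid-0")
    assert "sid-0" not in rtensor._fetch_buffer
```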
Refs: inclusionAI#1209
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Updated the clear_batches method across multiple engine classes to require a non-empty list of shard_ids. This simplifies the logic by removing the conditional check for empty inputs, since upstream calls guarantee that shard_ids is always non-empty. The documentation has been updated to reflect this assumption.

Key changes:
- Modified the clear_batches method signature in FSDPEngine, MegatronEngine, SGLangEngine, vLLMEngine, and ArchonEngine to accept only non-empty shard_ids.
- Removed the now-unnecessary conditional checks for empty shard_ids in the clear_batches implementations.

Refs: inclusionAI#1209
Force-pushed from acda439 to d907735.
`TrainController._call_workers` (train_controller.py:558-589) sends positional args only to DP-head ranks; non-DP-head ranks invoke methods with no args. The post-inclusionAI#1282 signature `clear_batches(self, shard_ids: list[str])` therefore crashes the first end-of-step call on any TP/PP > 1 config with `missing 1 required positional argument: 'shard_ids'`, taking the whole job down at step ~1. Restore the pre-merge default and guard on FSDP, Megatron, vLLM-remote, SGLang-remote, and Archon engines so non-DP-head ranks accept the no-args call and noop (their _fetch_buffer is empty). See inclusionAI#1209 follow-up comment.
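A hedged sketch of the restored engine-side signature (the drain logic mirrors the diff above; treating a missing or empty argument as a no-op is the guard described here):

```python
from typing import Optional

_fetch_buffer: dict[str, object] = {}  # per-process cache being drained


def clear_batches(shard_ids: Optional[list[str]] = None) -> int:
    """Engine RPC target, method sketch with `self` omitted.

    _call_workers invokes non-DP-head ranks with no arguments, so the
    parameter keeps a default and the call degrades to a no-op there."""
    if not shard_ids:  # non-DP-head rank, or nothing to drain
        return 0
    return sum(_fetch_buffer.pop(sid, None) is not None for sid in set(shard_ids))
```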
Description
Fixes a per-process memory leak in `_fetch_buffer` (a module-level cache in `areal/infra/rpc/rtensor.py`) on cross-node consumer workers. The cache is populated by `RTensor.to_local()` on every RPC arrival carrying a meta RTensor, but the existing end-of-step cleanup only touched the controller's cache (`RTensor.clear_node`) and the storage owner's `_storage` (`/data/clear` → `remove()`). Actor/critic/ref DP head processes, each a separate Python process with its own `_fetch_buffer`, were never drained. For VLM training this caused RSS to grow ~2 GB/step and crash around step 113 at `train_batch_size=32`.

Establishes a three-layer cleanup invariant so that every process that can populate `_fetch_buffer` drains at step end:

1. Controller: `RTensor.clear_node` pops locally (unchanged).
2. Storage owner: `remove()` also pops `_fetch_buffer`, covering the storage-owner-as-consumer case.
3. Cross-node consumer: `TrainController.clear_batches` fans out a replicated RPC carrying a flat `list[str]` of shard IDs to every DP head, which runs `clear_fetch_buffer(sids)` on its local buffer.

Sending a flat list (not the original nested RTensor structure) is deliberate: `engine_blueprint` always runs `RTensor.localize` on incoming args, so passing RTensors would re-fetch the shards we are trying to clear. `_is_tensor_like(list[str])` is False, so dispatch goes through `_replicate_inputs` and every DP head sees the full sid set.

Observability: `TrainController.clear_batches` always calls `fetch_buffer_stats` on DP head 0 after the drain and logs WARNING on a leak, DEBUG on clean. The stats RPC is wrapped in try/except; observability must not break training.

Trainer fan-out widened: `rl_trainer` now clears actor/critic/ref (was: actor only); `dpo_trainer` adds a ref clear. `sft`/`rw` trainers are unchanged (only the actor localizes batches).

Related Issue
Fixes #1209
Type of Change
Checklist
Breaking Change Details (if applicable):
N/A
Additional Context
Known limitations / follow-ups (intentionally deferred):
- `RayRTensorBackend.delete` goes through `ray.internal.free`, not `rtensor.remove()`, so the storage-owner-as-consumer case is not covered for Ray. Issue #1209 ([BUG] RTensor._fetch_buffer leaks unboundedly on worker processes, crashes long-running training (VLM)) only reproduces on the HTTP backend; Ray has its own object-store refcount semantics. Separate PR.
- `_custom_function_call` collapses scalar dispatch via `_collect_results[0]`, so the drain-check log reflects only DP head 0. Leaks are symmetric across heads in steady state, so head 0 is a sufficient canary.

Files changed:

- `areal/infra/rpc/rtensor.py`: add `clear_fetch_buffer`, `fetch_buffer_stats`, `flatten_shard_ids` helpers; `remove()` pops `_fetch_buffer` alongside `_storage`
- `areal/infra/controller/train_controller.py`: two-fan-out `clear_batches` (HTTP DELETE + replicated engine RPC) with observability log
- `areal/engine/{fsdp,megatron,sglang_remote,vllm_remote}_engine.py` and `areal/experimental/engine/archon_engine.py`: implement `clear_batches(shard_ids)` and `fetch_buffer_stats()` RPC targets
- `areal/trainer/{rl_trainer,dpo_trainer}.py`: fan out `clear_batches` to every role that localizes the batch
- `tests/test_rtensor.py`: 7 new unit tests covering `clear_fetch_buffer` selective/flush, `fetch_buffer_stats`, the `remove()`-pops-fetch-buffer regression, and `flatten_shard_ids` on nested/empty structures

🤖 Generated with Claude Code
areal/infra/rpc/rtensor.py: addclear_fetch_buffer,fetch_buffer_stats,flatten_shard_idshelpers;remove()pops_fetch_bufferalongside_storageareal/infra/controller/train_controller.py: two-fan-outclear_batches(HTTP DELETE + replicated engine RPC) with observability logareal/engine/{fsdp,megatron,sglang_remote,vllm_remote}_engine.pyandareal/experimental/engine/archon_engine.py: implementclear_batches(shard_ids)andfetch_buffer_stats()RPC targetsareal/trainer/{rl_trainer,dpo_trainer}.py: fan outclear_batchesto every role that localizes the batchtests/test_rtensor.py: 7 new unit tests —clear_fetch_bufferselective/flush,fetch_buffer_stats,remove()-pops-fetch-buffer regression,flatten_shard_idson nested/empty structures🤖 Generated with Claude Code