
fix(infra): drain RTensor _fetch_buffer on all consumer workers #1282

Merged

garrett4wade merged 2 commits into inclusionAI:main from guozhihao-224:fix/rtensor-fetch-buffer-leak on Apr 30, 2026

Conversation

@guozhihao-224
Collaborator

Description

Fixes a per-process memory leak in _fetch_buffer (a module-level cache in areal/infra/rpc/rtensor.py) on cross-node consumer workers. The cache is populated by RTensor.to_local() on every RPC arrival carrying a meta RTensor, but the existing end-of-step cleanup only touched the controller's cache (RTensor.clear_node) and the storage owner's _storage (HTTP DELETE /data/clear -> rtensor.remove()). Actor/critic/ref DP head processes — each a separate Python process with its own _fetch_buffer — were never drained. For VLM training this caused RSS to grow ~2 GB/step and crash around step 113 at train_batch_size=32.

Establishes a three-layer cleanup invariant — every process that can populate _fetch_buffer now drains at step end:

  1. Controller: RTensor.clear_node pops locally (unchanged).
  2. Storage owner: remove() also pops _fetch_buffer, covering the storage-owner-as-consumer case.
  3. Cross-node consumer: TrainController.clear_batches fans out a replicated RPC carrying a flat list[str] of shard IDs to every DP head, which runs clear_fetch_buffer(sids) on its local buffer.

Sending a flat list (not the original nested RTensor structure) is deliberate: engine_blueprint always runs RTensor.localize on incoming args — passing RTensors would re-fetch the shards we are trying to clear. _is_tensor_like(list[str]) is False, so dispatch goes through _replicate_inputs and every DP head sees the full sid set.
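
To make the consumer-side drain concrete, here is a minimal sketch of the two worker-side helpers (the names _fetch_buffer, _fetch_buffer_lock, clear_fetch_buffer, and fetch_buffer_stats are from this PR; the surrounding code is simplified and the stats payload is illustrative):

```python
import threading

# Per-process cache populated by RTensor.to_local() on RPC arrival (simplified).
_fetch_buffer: dict[str, object] = {}
_fetch_buffer_lock = threading.Lock()

def clear_fetch_buffer(shard_ids: list[str] | None = None) -> int:
    """Drop cached shards; returns how many entries were actually removed."""
    with _fetch_buffer_lock:
        if shard_ids is None:
            # Flush everything (the "flush" path exercised by the unit tests).
            n = len(_fetch_buffer)
            _fetch_buffer.clear()
            return n
        # Selective drain: pop exactly the shard IDs broadcast by the controller.
        return sum(_fetch_buffer.pop(sid, None) is not None for sid in shard_ids)

def fetch_buffer_stats() -> dict[str, int]:
    """Cheap observability hook used by the controller's post-drain check."""
    with _fetch_buffer_lock:
        return {"num_entries": len(_fetch_buffer)}
```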

Observability: TrainController.clear_batches always calls fetch_buffer_stats on DP head 0 after drain and logs WARNING on leak, DEBUG on clean. The stats RPC is wrapped in try/except — observability must not break training.
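
Roughly, the drain check has this shape (a sketch only: call_dp_head_0 stands in for however TrainController reaches DP head 0, and the stats payload mirrors the sketch above; only the WARNING/DEBUG split and the try/except come from the PR):

```python
import logging

logger = logging.getLogger(__name__)

def log_drain_result(call_dp_head_0, sids: list[str]) -> None:
    """Post-drain canary check on DP head 0; never allowed to raise."""
    try:
        stats = call_dp_head_0("fetch_buffer_stats")
        leaked = stats.get("num_entries", 0)
        if leaked:
            logger.warning(
                "_fetch_buffer still holds %d entries after draining %d shard IDs",
                leaked, len(sids),
            )
        else:
            logger.debug("_fetch_buffer fully drained (%d shard IDs cleared)", len(sids))
    except Exception:
        # Observability must not break training: swallow stats-RPC failures.
        logger.debug("fetch_buffer_stats RPC failed; skipping drain check", exc_info=True)
```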

Trainer fan-out widened: rl_trainer now clears actor/critic/ref (was: actor only); dpo_trainer adds ref clear. sft/rw trainers unchanged (only actor localizes batches).

Related Issue

Fixes #1209

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

N/A

Additional Context

Known limitations / follow-ups (intentionally deferred):

  • Ray backend: RayRTensorBackend.delete goes through ray.internal.free, not rtensor.remove(), so the storage-owner-as-consumer case is not covered for Ray. Issue #1209 ([BUG] RTensor._fetch_buffer leaks unboundedly on worker processes, crashes long-running training (VLM)) only reproduces on the HTTP backend; Ray has its own object-store refcount semantics. Separate PR.
  • Observability on DP head 0 only: _custom_function_call collapses scalar dispatch via _collect_results[0], so the drain-check log reflects only DP head 0. Leaks are symmetric across heads in steady state, so head 0 is a sufficient canary.

Files changed:

  • areal/infra/rpc/rtensor.py: add clear_fetch_buffer, fetch_buffer_stats, flatten_shard_ids helpers; remove() pops _fetch_buffer alongside _storage
  • areal/infra/controller/train_controller.py: two-fan-out clear_batches (HTTP DELETE + replicated engine RPC) with observability log
  • areal/engine/{fsdp_engine,megatron_engine,sglang_remote,vllm_remote}.py and areal/experimental/engine/archon_engine.py: implement clear_batches(shard_ids) and fetch_buffer_stats() RPC targets
  • areal/trainer/{rl_trainer,dpo_trainer}.py: fan out clear_batches to every role that localizes the batch
  • tests/test_rtensor.py: 7 new unit tests — clear_fetch_buffer selective/flush, fetch_buffer_stats, remove()-pops-fetch-buffer regression, flatten_shard_ids on nested/empty structures

🤖 Generated with Claude Code

@gemini-code-assist (Bot) left a comment

Code Review

This pull request addresses memory leaks (issue #1209) by ensuring that client-side RTensor fetch buffers are properly drained across all processes at the end of training steps. It implements a two-stage clear_batches process in the TrainController, adds utility functions for buffer management in rtensor.py, and updates various engine and trainer classes to invoke these cleanup routines. Feedback includes a suggestion to deduplicate shard IDs in flatten_shard_ids to optimize RPC payloads and a note on potential lock contention in clear_fetch_buffer.

Reviewed hunk (tail of flatten_shard_ids in areal/infra/rpc/rtensor.py):

```python
    so every head must see the full sid set to drain completely.
    """
    shards_by_node = RTensor.collect_shards(obj)
    return [sid for sids in shards_by_node.values() for sid in sids]
```

Severity: medium

The flatten_shard_ids function returns a list that may contain duplicate shard IDs if the same RTensor appears multiple times in the input structure (e.g., across different micro-batches or roles). While clear_fetch_buffer handles duplicates correctly, deduplicating the list here reduces the RPC payload size and the number of pop operations on the workers.

Suggested change:

```diff
-    return [sid for sids in shards_by_node.values() for sid in sids]
+    return list({sid for sids in shards_by_node.values() for sid in sids})
```

Reviewed hunk (clear_fetch_buffer in areal/infra/rpc/rtensor.py):

```python
        n = len(_fetch_buffer)
        _fetch_buffer.clear()
        return n
    return sum(_fetch_buffer.pop(sid, None) is not None for sid in shard_ids)
```

Severity: medium

Iterating over shard_ids while holding _fetch_buffer_lock can block other threads (like those calling to_local() or localize()) for the duration of the loop. If shard_ids is a very large collection, this could lead to contention. Consider deduplicating shard_ids before entering the lock context or using a more efficient bulk removal if possible.
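
If both review notes were taken together, one possible shape (a sketch against the module-level _fetch_buffer / _fetch_buffer_lock shown earlier, not the PR's actual code) is to deduplicate before entering the lock:

```python
def clear_fetch_buffer(shard_ids: list[str] | None = None) -> int:
    # Assumes the module-level _fetch_buffer / _fetch_buffer_lock from rtensor.py.
    if shard_ids is not None:
        # Deduplicate outside the critical section; dict.fromkeys preserves order,
        # unlike the set comprehension in the suggestion above.
        shard_ids = list(dict.fromkeys(shard_ids))
    with _fetch_buffer_lock:
        if shard_ids is None:
            n = len(_fetch_buffer)
            _fetch_buffer.clear()
            return n
        return sum(_fetch_buffer.pop(sid, None) is not None for sid in shard_ids)
```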

Comment thread on areal/engine/vllm_remote.py (outdated)

@guozhihao-224
Collaborator Author

@sitabulaixizawaluduo Went with Option B. Changes:

  1. Engine clear_batches signatures — dropped | None = None, now shard_ids: list[str] (required). Removed the redundant if shard_ids: guard since upstream TrainController.clear_batches already returns early on empty sids.

    • fsdp_engine.py, megatron_engine.py, sglang_remote.py, vllm_remote.py, experimental/engine/archon_engine.py
  2. Trainer call sites — wrapped the clear_batches fan-out in if is_single_controller(): for rl_trainer, dpo_trainer, sft_trainer, rw_trainer. Encodes the "_fetch_buffer only exists in single-controller mode" invariant directly in the code (a rough sketch follows below).
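
A rough sketch of that wiring (the is_single_controller() guard and the actor/critic/ref role set are from this thread; its import path and the trainer attribute names are illustrative):

```python
def clear_step_batches(trainer, shard_ids: list[str]) -> None:
    # is_single_controller() is the guard named above; its exact import path is
    # not shown in this PR, so treat this as pseudocode for the call-site shape.
    if not is_single_controller():
        return  # _fetch_buffer only exists in single-controller mode
    # Fan out to every role that localized the batch this step.
    for role in ("actor", "critic", "ref"):
        controller = getattr(trainer, role, None)  # attribute names are illustrative
        if controller is not None:
            controller.clear_batches(shard_ids)
```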

Ready for another look when you have time.

@sitabulaixizawaluduo
Collaborator

Hey, I accidentally hit the rebase button on the GitHub UI, which made me the committer on your commits. Could you do a quick rebase and force push to fix it?

@guozhihao-224 force-pushed the fix/rtensor-fetch-buffer-leak branch from 38605ba to 5b3872c on April 30, 2026 10:07
@guozhihao-224
Collaborator Author

guozhihao-224 commented Apr 30, 2026

> Hey, I accidentally hit the rebase button on the GitHub UI, which made me the committer on your commits. Could you do a quick rebase and force push to fix it?

@sitabulaixizawaluduo done

@guozhihao-224 force-pushed the fix/rtensor-fetch-buffer-leak branch from 5b3872c to acda439 on April 30, 2026 10:33
guozhihao-224 and others added 2 commits April 30, 2026 18:37
The module-level _fetch_buffer in areal/infra/rpc/rtensor.py is a
per-process cache populated by RTensor.to_local() on every RPC arrival
that carries meta RTensors. End-of-step cleanup only touched the
controller's cache (RTensor.clear_node) and the storage owner's
_storage (HTTP DELETE /data/clear -> rtensor.remove). The _fetch_buffer
on cross-node consumer workers (actor/critic/ref DP heads) was never
drained, so RSS grew ~2 GB/step for VLM training and crashed around
step 113 at train_batch_size=32.

Establish a three-layer cleanup invariant -- every process that could
populate _fetch_buffer now drains at step end:

1. Controller: RTensor.clear_node pops locally (unchanged).
2. Storage owner: remove() also pops _fetch_buffer, covering the
   storage-owner-as-consumer case where one process both stores and
   localizes a shard.
3. Cross-node consumer: TrainController.clear_batches now fans out a
   replicated RPC carrying a flat list[str] of shard IDs to every DP
   head, which runs clear_fetch_buffer(sids) on its local buffer.

Sending a flat list (rather than the original nested RTensor structure)
is deliberate: engine_blueprint always runs RTensor.localize on
incoming args -- passing RTensors would re-fetch the shards we are
trying to clear. _is_tensor_like(list[str]) is False, so dispatch goes
through _replicate_inputs and every DP head sees the full sid set.

Observability: TrainController.clear_batches always calls
fetch_buffer_stats on DP head 0 after drain and logs WARNING on leak,
DEBUG on clean. The stats RPC is wrapped in try/except -- observability
must not break training.

Trainer fan-out widened: rl_trainer now clears actor/critic/ref (was:
actor only); dpo_trainer adds ref clear. sft/rw trainers unchanged
(only actor localizes batches).

Key changes:
- areal/infra/rpc/rtensor.py: add clear_fetch_buffer,
  fetch_buffer_stats, flatten_shard_ids helpers; remove() pops
  _fetch_buffer alongside _storage
- areal/infra/controller/train_controller.py: two-fan-out
  clear_batches (HTTP DELETE + replicated engine RPC) with
  observability log
- areal/engine/{fsdp,megatron,sglang_remote,vllm_remote}.py and
  areal/experimental/engine/archon_engine.py: implement
  clear_batches(shard_ids) and fetch_buffer_stats() RPC targets
- areal/trainer/{rl_trainer,dpo_trainer}.py: fan out clear_batches
  to every role that localizes the batch
- tests/test_rtensor.py: 7 new unit tests covering
  clear_fetch_buffer selective/flush, fetch_buffer_stats,
  remove()-pops-fetch-buffer regression, flatten_shard_ids on
  nested/empty structures

Refs: inclusionAI#1209

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Updated the clear_batches method across multiple engine classes to require a non-empty list of shard_ids. This change simplifies the logic by removing the conditional check for empty inputs, as upstream calls guarantee that shard_ids will always be non-empty. The documentation has been updated to reflect this assumption, ensuring clarity for future developers.

Key changes:
- Modified clear_batches method signature in FSDPEngine, MegatronEngine, SGLangEngine, vLLMEngine, and ArchonEngine to accept only non-empty shard_ids.
- Removed unnecessary conditional checks for empty shard_ids in the clear_batches implementation.

Refs: inclusionAI#1209
@guozhihao-224 force-pushed the fix/rtensor-fetch-buffer-leak branch from acda439 to d907735 on April 30, 2026 10:38
@garrett4wade merged commit e0c004a into inclusionAI:main on Apr 30, 2026
6 checks passed
Adiactive added a commit to Adiactive/AReaL that referenced this pull request Apr 30, 2026
`TrainController._call_workers` (train_controller.py:558-589) sends
positional args only to DP-head ranks; non-DP-head ranks invoke methods
with no args. The post-inclusionAI#1282 signature `clear_batches(self, shard_ids:
list[str])` therefore crashes the first end-of-step call on any TP/PP > 1
config with `missing 1 required positional argument: 'shard_ids'`,
taking the whole job down at step ~1.

Restore the pre-merge default and guard on FSDP, Megatron, vLLM-remote,
SGLang-remote, and Archon engines so non-DP-head ranks accept the no-args
call and noop (their _fetch_buffer is empty).

See inclusionAI#1209 follow-up comment.
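
A minimal engine-side sketch of the restored default plus guard described in that follow-up (only the | None = None signature and the no-op for argument-less calls come from the message above; the body and return value are illustrative):

```python
from areal.infra.rpc.rtensor import clear_fetch_buffer

class FSDPEngine:  # the same shape applies to the other engines listed above
    def clear_batches(self, shard_ids: list[str] | None = None) -> int:
        # Non-DP-head ranks are invoked with no args by _call_workers; their
        # _fetch_buffer is empty, so returning early is the correct no-op.
        if not shard_ids:
            return 0
        return clear_fetch_buffer(shard_ids)
```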