
[StreamingDataLoader, 5/N] Refactor StreamDataLoader implementation #23

Merged
0oshowero0 merged 1 commit into Ascend:main from NINGBENZHE:main on Feb 2, 2026

Conversation

@NINGBENZHE
Contributor

@NINGBENZHE NINGBENZHE commented Feb 1, 2026

  1. Simplify RankAwareSampler implementation.
  2. Add rank-aware capability in GRPOGroupNSampler.
  3. Support custom fetch, post-process, and replay features in StreamingDataLoader.

@0oshowero0
Collaborator

Please use git commit -s -m so that you can pass the DCO check~

Copilot AI left a comment

Pull request overview

This pull request refactors the StreamingDataLoader implementation to support fully asynchronous mode, as part 5 of a series. The main changes involve:

Purpose: Refactor the data loading infrastructure to use a simpler, cache-based coordination mechanism for distributed data parallel training.

Changes:

  • Replace the three-parameter data replica coordination system (data_replica_group, data_replica_rank, data_replica_world_size) with a single dp_rank parameter (see the sketch after this list)
  • Introduce batch-index-based caching in samplers for deterministic sample distribution
  • Add buffer management and new iteration control methods (reset(), step()) to StreamingDataset
  • Split batch retrieval and post-processing logic into separate configurable functions
  • Export AsyncTransferQueueClient in the public API
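
A minimal sketch of what the simplified setup looks like in use, loosely following the tutorial snippet quoted later in this review; the exact constructor signature and the reset()/step() iteration pattern are assumptions for illustration, not the merged API:

```python
# Sketch only: `config` is a TransferQueue configuration created elsewhere.

# Before: three-parameter data replica coordination (removed by this PR)
# dataset = StreamingDataset(config=config, batch_size=2, micro_batch_size=2,
#                            data_replica_group=0, data_replica_rank=0,
#                            data_replica_world_size=2)

# After: a single dp_rank identifies the data parallel group
dataset = StreamingDataset(
    config=config,
    batch_size=2,
    micro_batch_size=2,
    dp_rank=0,
)

dataset.reset()              # new iteration control method added by this PR
for micro_batch in dataset:  # iteration/step() usage shown here is assumed, not from the diff
    ...                      # consume the micro-batch
    dataset.step()
```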

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 12 comments.

Summary per file:

  • tutorial/05_streaming_dataloader.py: Updated example to use the simplified dp_rank parameter instead of the three-parameter data replica system
  • transfer_queue/sampler/rank_aware_sampler.py: Refactored to use dp_rank and batch_index for caching instead of rank-based coordination
  • transfer_queue/sampler/grpo_group_n_sampler.py: Added caching mechanism using dp_rank and batch_index for deterministic sampling
  • transfer_queue/sampler/base.py: Added clear_cache method for partition cleanup
  • transfer_queue/dataloader/streaming_dataset.py: Major refactoring; added buffer management and split batch retrieval into default_get_batch and default_post_process_for_micro_func functions
  • transfer_queue/dataloader/streaming_dataloader.py: Added delegation methods (reset, step, get_buffer) to the underlying dataset
  • transfer_queue/controller.py: Fixed partition creation logic and added kwargs propagation to the sampler
  • transfer_queue/client.py: Added a defensive check for an empty consumption status tensor
  • transfer_queue/__init__.py: Exported AsyncTransferQueueClient in the public API
  • tests/test_samplers.py: Updated all tests to use the new dp_rank and batch_index parameters
Comments suppressed due to low confidence (2)

transfer_queue/sampler/rank_aware_sampler.py:86

  • The docstring is outdated and doesn't match the new implementation. It still describes the old behavior of "first rank in each data replica group" and "subsequent ranks in the same data replica group", but the new implementation uses a simpler caching mechanism based on dp_rank and batch_index where any call with the same dp_rank and batch_index returns cached results. The docstring should be updated to reflect this change.
        """Sample indices for the current rank, coordinating with other data replica ranks.

        This method implements coordinated sampling for distributed training.
        The first rank in each data replica group to call this method performs actual sampling
        from ``ready_indexes`` and caches the result. Subsequent ranks in the same
        data replica group receive the cached indices directly.

        Internal state structure (self._states):

        .. code-block:: python

            self._states = {
                "partition_id": {
                    "task_name": {
                        dp_rank: [sampled_indexes]
                    }
                }
            }

        State lifecycle:
        1. First rank samples from ``ready_indexes``, caches results for other ranks
        2. Other ranks pop and retrieve the cached indices

transfer_queue/sampler/base.py:53

    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:


dataset = StreamingDataset(
    config=config,
    batch_size=2,
    micro_batch_size=2, # Number of samples per batch

Copilot AI Feb 2, 2026

The inline comment "Number of samples per batch" on line 208 is misleading. The micro_batch_size parameter controls the size of each micro-batch that results from splitting the main batch (controlled by batch_size on line 207). Since both are set to 2, there will be no splitting, but the comment should clarify that this is the micro-batch size, not the total batch size.

Suggested change
micro_batch_size=2, # Number of samples per batch
micro_batch_size=2, # Number of samples per micro-batch (main batch size is controlled by batch_size)
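
A tiny standalone illustration of how the two parameters relate, assuming the usual convention that the main batch is split into consecutive chunks of micro_batch_size (the splitting details are not part of this diff):

```python
batch_size = 4
micro_batch_size = 2

batch = list(range(batch_size))  # e.g. the sample indices of one main batch
micro_batches = [
    batch[i : i + micro_batch_size]
    for i in range(0, batch_size, micro_batch_size)
]
print(micro_batches)  # [[0, 1], [2, 3]]
# With batch_size == micro_batch_size, as in the tutorial, there is exactly
# one micro-batch per batch, i.e. no splitting takes place.
```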

Comment on lines 184 to 195
Ranks in the same group receive the same data samples
data_replica_rank: Local rank index within the data replica group
Range: [0, data_replica_world_size - 1]
data_replica_world_size: Total number of ranks in this data replica group
config: TransferQueue configuration
max_steps: Maximum number of batches to consume

Returns:
dict: Contains data_replica_rank, data_replica_group, and consumed_ids

Example:
For a setup with 2 data replica groups (0 and 1), each with 2 ranks:
For a setup with 2 data rank (0 and 1), each with 2 ranks:
- Group 0: ranks [0, 1] receive identical samples
- Group 1: ranks [2, 3] receive identical samples
All ranks within the same group get the same global indexes.

Copilot AI Feb 2, 2026

The example description states "Ranks in the same group receive the same data samples" (line 184) and "Group 0: ranks [0, 1] receive identical samples" (line 193), but with the new implementation, this only works if ranks with the same dp_rank call the sampler with the same batch_index values. If ranks iterate at different speeds or start at different times, they will have different batch_index values and receive different samples. The documentation should clarify this synchronization requirement or the behavior when ranks are not synchronized.
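
A small sketch of the synchronization point being raised, assuming the sample() signature quoted from base.py and that dp_rank / batch_index are passed as keyword arguments; `sampler` stands in for a RankAwareSampler instance:

```python
ready_indexes = [0, 1, 2, 3]

# Two workers that share dp_rank=0 and request the same batch_index hit the
# same cache entry and therefore receive identical indices:
sampled_a, _ = sampler.sample(ready_indexes, batch_size=2, dp_rank=0, batch_index=0)
sampled_b, _ = sampler.sample(ready_indexes, batch_size=2, dp_rank=0, batch_index=0)
assert sampled_a == sampled_b

# A worker that has drifted to a different batch_index misses the cache and
# samples afresh, so "same dp_rank" alone no longer guarantees identical data:
sampled_c, _ = sampler.sample(ready_indexes, batch_size=2, dp_rank=0, batch_index=1)
# sampled_c is not guaranteed to equal sampled_a
```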


Example:
For a setup with 2 data replica groups (0 and 1), each with 2 ranks:
For a setup with 2 data rank (0 and 1), each with 2 ranks:

Copilot AI Feb 2, 2026

The tutorial example comment is grammatically incorrect. "For a setup with 2 data rank (0 and 1), each with 2 ranks" should be "For a setup with 2 data parallel ranks (0 and 1)" - but this doesn't make logical sense as written. The example describes groups where ranks [0,1] receive identical samples and ranks [2,3] receive identical samples, which suggests there should be 2 groups with 2 ranks each, not 2 ranks.

Suggested change
For a setup with 2 data rank (0 and 1), each with 2 ranks:
For a setup with 2 data replica groups, each with 2 ranks (global ranks 0–3):

Comment on lines 182 to 187
if dp_rank is not None:
    if dp_rank not in states:
        states[dp_rank] = {}
        states[dp_rank][batch_index] = (sampled_indexes, consumed_indexes, 1)
    elif batch_index not in states[dp_rank]:
        states[dp_rank][batch_index] = (sampled_indexes, consumed_indexes, 1)

Copilot AI Feb 2, 2026

The caching logic has a bug. When dp_rank not in states, it creates states[dp_rank] = {} and immediately sets states[dp_rank][batch_index] in line 185. However, if the condition on line 186 (elif batch_index not in states[dp_rank]) is entered (which can only happen if dp_rank in states already), it sets the same value again on line 187. This redundant assignment in the elif branch is unnecessary since it's already handled in the if branch. The elif should just be updating the states without the initial dictionary creation.
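
One way the branch could be collapsed, following the comment above; an illustrative simplification, not the fix that was merged:

```python
if dp_rank is not None:
    if dp_rank not in states:
        states[dp_rank] = {}
    # A single membership check now covers both the "new dp_rank" case and
    # the "known dp_rank, new batch_index" case:
    if batch_index not in states[dp_rank]:
        states[dp_rank][batch_index] = (sampled_indexes, consumed_indexes, 1)
```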

partition_id: The partition ID associated with the task.
"""
if partition_id in self._states.keys():
    self._states[partition_id] = {}

Copilot AI Feb 2, 2026

Redundant operation: Line 89 calls self._states.pop(partition_id) after already clearing the partition with self._states[partition_id] = {} on line 88. The line 88 assignment is unnecessary since line 90 immediately removes the key. Either remove line 89 (just clear the dict) or remove line 88 (just pop the key).

Suggested change
self._states[partition_id] = {}

Comment on lines 256 to 263
sampling_config["batch_index"] = batch_index
sampling_config["partition_id"] = partition_id
batch_meta = tq_client.get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
task_name=task_name,
sampling_config=sampling_config,

Copilot AI Feb 2, 2026

Mutation of shared state: The function mutates the sampling_config dictionary by adding batch_index and partition_id keys (lines 256-257). This could cause unexpected side effects if the same dictionary is reused across multiple calls. Consider creating a copy of the config before mutation: config = {**sampling_config, "batch_index": batch_index, "partition_id": partition_id}.

Suggested change
sampling_config["batch_index"] = batch_index
sampling_config["partition_id"] = partition_id
batch_meta = tq_client.get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
task_name=task_name,
sampling_config=sampling_config,
config = {**sampling_config, "batch_index": batch_index, "partition_id": partition_id}
batch_meta = tq_client.get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
task_name=task_name,
sampling_config=config,

identify which data replica group this rank belongs to.
data_replica_rank: Local rank inside this data_replica_group.
data_replica_world_size: Total number of ranks in this data_replica_group.
dp_rank: The group id of current dp_rank. Used to identify which dp_rank belongs to.

Copilot AI Feb 2, 2026

Grammatical error in the docstring. "The group id of current dp_rank. Used to identify which dp_rank belongs to" is incomplete and grammatically incorrect. Should be something like "The data parallel rank ID for this worker. Used to identify which data parallel group this rank belongs to."

Suggested change
dp_rank: The group id of current dp_rank. Used to identify which dp_rank belongs to.
dp_rank: The data parallel rank ID for this worker. Used to identify which
data parallel group this rank belongs to.

Args:
rank_id: Global rank identifier for logging and display purposes
data_replica_group: ID of the data parallel group this rank belongs to
dp_rank: ID of the data parallel rank belongs to

Copilot AI Feb 2, 2026

Grammatical error in docstring. "ID of the data parallel rank belongs to" is incomplete. Should be something like "Data parallel rank ID that this worker belongs to" or similar complete sentence.

Suggested change
dp_rank: ID of the data parallel rank belongs to
dp_rank: Data parallel rank ID that this worker belongs to

if len(required_fields) < 1:
    raise ValueError(f"required_fields must be a list with at least one field name, got {required_fields}")
if len(data_fields) < 1:
    raise ValueError(f"required_fields must be a list with at least one field name, got {data_fields}")

Copilot AI Feb 2, 2026

The error message still references the old parameter name required_fields instead of the new parameter name data_fields. The message should be updated to: "data_fields must be a list with at least one field name, got {data_fields}"

Suggested change
raise ValueError(f"required_fields must be a list with at least one field name, got {data_fields}")
raise ValueError(f"data_fields must be a list with at least one field name, got {data_fields}")


# Return cached result if available
if dp_rank in states.keys() and batch_index in states[dp_rank].keys():
    return states[dp_rank][batch_index]

Copilot AI Feb 2, 2026

The cached result is a tuple of 3 elements (sampled_indexes, consumed_indexes, 1) as shown in lines 185 and 187, but here it's being returned directly as a 3-element tuple. However, the function signature indicates it should return a 2-element tuple (sampled_indexes, consumed_indexes). This will cause the caller to receive an unexpected third element.

Suggested change
return states[dp_rank][batch_index]
cached = states[dp_rank][batch_index]
sampled_indexes, consumed_indexes = cached[0], cached[1]
return sampled_indexes, consumed_indexes

data_replica_rank: int,
data_replica_world_size: int,
dp_rank: int,
batch_index: int,

Collaborator

Need to update the docstring so users know what this param is.


Collaborator

BTW, batch_index is already used to refer to the relative position of each sample in BatchMeta. Consider using another name?

data_replica_group: {
data_replica_rank: [[sampled_indexes], ...] # Buffer of cached sampled indices
}
dp_rank: [sampled_indexes]

Collaborator

missing batch_index

identify which data replica group this rank belongs to.
data_replica_rank: Local rank inside this data_replica_group.
data_replica_world_size: Total number of ranks in this data_replica_group.
dp_rank: The group id of current dp_rank. Used to identify which dp_rank belongs to.

Collaborator

missing batch_index

import os

from .client import (
    AsyncTransferQueueClient,

Collaborator

Now TransferQueueClient has both sync & async functions, so we don't need to expose this.


if mode == "insert":
    partition = self._get_partition(partition_id)
if partition_id not in self.partitions:

Collaborator

Is this safe for other modes when the user doesn't specify partition_id?


Contributor Author

Yes, I've checked—when partition_id is not provided, it defaults to returning an empty value, which should not cause any issues. I plan to rely on the pipeline to ensure no additional problems are introduced.

data_replica_rank: int,
data_replica_world_size: int,
dp_rank: int,
n_samples_per_prompt: int,

Collaborator

This is an algorithm-related param. Can we implement this in the Sampler rather than in StreamingDataset?


Contributor Author

batch_meta = tq_client.get_meta(
    data_fields=data_fields,
    batch_size=batch_size,
    partition_id=partition_id,
    task_name=task_name,
    sampling_config=sampling_config,
)

The get_meta function requires n_samples_per_prompt; I can set its default value to 1.

data_replica_world_size: int,
dp_rank: int,
n_samples_per_prompt: int,
custom_get_batch_func: Any = None,

Collaborator

How about fetch_batch_fn: Callable = None?


Contributor Author

ok

dp_rank: int,
n_samples_per_prompt: int,
custom_get_batch_func: Any = None,
custom_post_process_for_micro_func: Any = None,

Collaborator

How about process_batch_fn: Callable = None?


Contributor Author

ok
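
For reference, a sketch of how the two renamed hooks discussed in this thread might be wired up; the parameter lists are assumptions based only on the default_get_batch call quoted earlier in the conversation, not the merged signatures:

```python
from typing import Any

def my_fetch_batch(tq_client, data_fields, batch_size, partition_id, task_name, sampling_config) -> Any:
    # Custom replacement for default_get_batch (argument list assumed).
    return tq_client.get_meta(
        data_fields=data_fields,
        batch_size=batch_size,
        partition_id=partition_id,
        task_name=task_name,
        sampling_config=sampling_config,
    )

def my_process_batch(micro_batch: Any) -> Any:
    # Custom replacement for default_post_process_for_micro_func.
    return micro_batch

dataset = StreamingDataset(
    config=config,                      # a TransferQueue configuration created elsewhere
    batch_size=2,
    micro_batch_size=2,
    dp_rank=0,
    fetch_batch_fn=my_fetch_batch,      # name proposed in this review
    process_batch_fn=my_process_batch,  # name proposed in this review
)
```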

@0oshowero0 0oshowero0 changed the title from "[StreamingDataLoader, 5/N] feat: Refactor the StreamDataLoader implem…" to "[StreamingDataLoader, 5/N] Refactor StreamDataLoader implementation" on Feb 2, 2026
assert sampled1_g0 == [0, 1]
assert consumed1_g0 == [0, 1]
assert sampled1 == [0, 1]
assert sampled2 == [0, 1] # Same because both sample from the beginning of ready_indexes

Collaborator

This is a little bit misleading. We should mimic the controller's behavior and manually delete the samples in ready_indexes that show up in consumed1.
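
Roughly what the suggested adjustment could look like in the test; the sampler call pattern and values are placeholders mirroring the quoted assertions:

```python
ready_indexes = [0, 1, 2, 3]

sampled1, consumed1 = sampler.sample(ready_indexes, batch_size=2, dp_rank=0, batch_index=0)
assert sampled1 == [0, 1]

# Mimic the controller: drop the consumed samples from ready_indexes before
# the next group samples.
ready_indexes = [i for i in ready_indexes if i not in consumed1]

sampled2, _ = sampler.sample(ready_indexes, batch_size=2, dp_rank=1, batch_index=0)
# sampled2 is now drawn from the remaining indices (e.g. [2, 3]) rather than
# coincidentally matching sampled1 because both calls saw the full list.
```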

assert sampler._states["test"]["task0"][0][1] == [[0, 1]]
assert sampler._states["test"]["task1"][0][0] == []
assert sampler._states["test"]["task1"][0][1] == [[0, 1]]
# Should return cached result

Collaborator

Is this already checked in test_rank_aware_sampler_multiple_dp_ranks()?

data_replica_world_size: int,
dp_rank: int,
n_samples_per_prompt: int = 1,
custom_get_batch_func: Callable | None = None,

Collaborator

forgot to change? :)

data_replica_rank: int,
data_replica_world_size: int,
dp_rank: int,
n_samples_per_prompt: int = 1,

Collaborator

Maybe we need to add a TODO here. It's possible to set this n_sample param during the initialization of the Sampler, so we don't have to set it during get_meta(). That would make the StreamingDataLoader more universal.
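
A sketch of the direction this TODO points at, moving the algorithm-specific parameter into the sampler's constructor; the class layout here is illustrative only:

```python
class GRPOGroupNSampler:  # base class and existing arguments omitted
    def __init__(self, n_samples_per_prompt: int = 1):
        # Configured once at construction time, so StreamingDataset and
        # get_meta() no longer need to pass it on every call.
        self.n_samples_per_prompt = n_samples_per_prompt

    def sample(self, ready_indexes, batch_size, *args, **kwargs):
        group_size = self.n_samples_per_prompt
        ...  # group-aware sampling using group_size
```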

assert consumed0_g0 == [0, 1]
assert consumed0_g1 == [2, 3]
assert sampled1 == [0, 1]
assert sampled2 == [2, 3] # Same because both sample from the beginning of ready_indexes

Collaborator

The comment here is wrong

[StreamingDataLoader, 5/N] feat: Refactor the StreamDataLoader implementation to support fully asynchronous mode

Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>

Collaborator
@0oshowero0 0oshowero0 left a comment

Great work!

@0oshowero0 0oshowero0 merged commit cd12c1e into Ascend:main Feb 2, 2026
4 checks passed