[StreamingDataLoader, 5/N] Refactor StreamDataLoader implementation#23
[StreamingDataLoader, 5/N] Refactor StreamDataLoader implementation#230oshowero0 merged 1 commit intoAscend:mainfrom
StreamDataLoader implementation#23Conversation
|
Please use |
There was a problem hiding this comment.
Pull request overview
This pull request refactors the StreamingDataLoader implementation to support fully asynchronous mode, as part 5 of a series. The main changes involve:
Purpose: Refactor the data loading infrastructure to use a simpler, cache-based coordination mechanism for distributed data parallel training.
Changes:
- Replace the three-parameter data replica coordination system (
data_replica_group,data_replica_rank,data_replica_world_size) with a singledp_rankparameter - Introduce batch-index-based caching in samplers for deterministic sample distribution
- Add buffer management and new iteration control methods (
reset(),step()) to StreamingDataset - Split batch retrieval and post-processing logic into separate configurable functions
- Export AsyncTransferQueueClient in the public API
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 12 comments.
Show a summary per file
| File | Description |
|---|---|
| tutorial/05_streaming_dataloader.py | Updated example to use simplified dp_rank parameter instead of three-parameter data replica system |
| transfer_queue/sampler/rank_aware_sampler.py | Refactored to use dp_rank and batch_index for caching instead of rank-based coordination |
| transfer_queue/sampler/grpo_group_n_sampler.py | Added caching mechanism using dp_rank and batch_index for deterministic sampling |
| transfer_queue/sampler/base.py | Added clear_cache method for partition cleanup |
| transfer_queue/dataloader/streaming_dataset.py | Major refactoring: added buffer management, split batch retrieval into default_get_batch and default_post_process_for_micro_func functions |
| transfer_queue/dataloader/streaming_dataloader.py | Added delegation methods (reset, step, get_buffer) to underlying dataset |
| transfer_queue/controller.py | Fixed partition creation logic and added kwargs propagation to sampler |
| transfer_queue/client.py | Added defensive check for empty consumption status tensor |
| transfer_queue/init.py | Exported AsyncTransferQueueClient in public API |
| tests/test_samplers.py | Updated all tests to use new dp_rank and batch_index parameters |
Comments suppressed due to low confidence (2)
transfer_queue/sampler/rank_aware_sampler.py:86
- The docstring is outdated and doesn't match the new implementation. It still describes the old behavior of "first rank in each data replica group" and "subsequent ranks in the same data replica group", but the new implementation uses a simpler caching mechanism based on
dp_rankandbatch_indexwhere any call with the samedp_rankandbatch_indexreturns cached results. The docstring should be updated to reflect this change.
"""Sample indices for the current rank, coordinating with other data replica ranks.
This method implements coordinated sampling for distributed training.
The first rank in each data replica group to call this method performs actual sampling
from ``ready_indexes`` and caches the result. Subsequent ranks in the same
data replica group receive the cached indices directly.
Internal state structure (self._states):
.. code-block:: python
self._states = {
"partition_id": {
"task_name": {
dp_rank: [sampled_indexes]
}
}
}
State lifecycle:
1. First rank samples from ``ready_indexes``, caches results for other ranks
2. Other ranks pop and retrieve the cached indices
transfer_queue/sampler/base.py:53
- Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_rank'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'batch_index'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'task_name'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'partition_id'. Overriding method method RankAwareSampler.sample matches the call.
def sample(
self,
ready_indexes: list[int],
batch_size: int,
*args: Any,
**kwargs: Any,
) -> tuple[list[int], list[int]]:
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
tutorial/05_streaming_dataloader.py
Outdated
| dataset = StreamingDataset( | ||
| config=config, | ||
| batch_size=2, | ||
| micro_batch_size=2, # Number of samples per batch |
There was a problem hiding this comment.
The inline comment "Number of samples per batch" on line 208 is misleading. The micro_batch_size parameter controls the size of each micro-batch that results from splitting the main batch (controlled by batch_size on line 207). Since both are set to 2, there will be no splitting, but the comment should clarify that this is the micro-batch size, not the total batch size.
| micro_batch_size=2, # Number of samples per batch | |
| micro_batch_size=2, # Number of samples per micro-batch (main batch size is controlled by batch_size) |
tutorial/05_streaming_dataloader.py
Outdated
| Ranks in the same group receive the same data samples | ||
| data_replica_rank: Local rank index within the data replica group | ||
| Range: [0, data_replica_world_size - 1] | ||
| data_replica_world_size: Total number of ranks in this data replica group | ||
| config: TransferQueue configuration | ||
| max_steps: Maximum number of batches to consume | ||
|
|
||
| Returns: | ||
| dict: Contains data_replica_rank, data_replica_group, and consumed_ids | ||
|
|
||
| Example: | ||
| For a setup with 2 data replica groups (0 and 1), each with 2 ranks: | ||
| For a setup with 2 data rank (0 and 1), each with 2 ranks: | ||
| - Group 0: ranks [0, 1] receive identical samples | ||
| - Group 1: ranks [2, 3] receive identical samples | ||
| All ranks within the same group get the same global indexes. |
There was a problem hiding this comment.
The example description states "Ranks in the same group receive the same data samples" (line 184) and "Group 0: ranks [0, 1] receive identical samples" (line 193), but with the new implementation, this only works if ranks with the same dp_rank call the sampler with the same batch_index values. If ranks iterate at different speeds or start at different times, they will have different batch_index values and receive different samples. The documentation should clarify this synchronization requirement or the behavior when ranks are not synchronized.
tutorial/05_streaming_dataloader.py
Outdated
|
|
||
| Example: | ||
| For a setup with 2 data replica groups (0 and 1), each with 2 ranks: | ||
| For a setup with 2 data rank (0 and 1), each with 2 ranks: |
There was a problem hiding this comment.
The tutorial example comment is grammatically incorrect. "For a setup with 2 data rank (0 and 1), each with 2 ranks" should be "For a setup with 2 data parallel ranks (0 and 1)" - but this doesn't make logical sense as written. The example describes groups where ranks [0,1] receive identical samples and ranks [2,3] receive identical samples, which suggests there should be 2 groups with 2 ranks each, not 2 ranks.
| For a setup with 2 data rank (0 and 1), each with 2 ranks: | |
| For a setup with 2 data replica groups, each with 2 ranks (global ranks 0–3): |
| if dp_rank is not None: | ||
| if dp_rank not in states: | ||
| states[dp_rank] = {} | ||
| states[dp_rank][batch_index] = (sampled_indexes, consumed_indexes, 1) | ||
| elif batch_index not in states[dp_rank]: | ||
| states[dp_rank][batch_index] = (sampled_indexes, consumed_indexes, 1) |
There was a problem hiding this comment.
The caching logic has a bug. When dp_rank not in states, it creates states[dp_rank] = {} and immediately sets states[dp_rank][batch_index] in line 185. However, if the condition on line 186 (elif batch_index not in states[dp_rank]) is entered (which can only happen if dp_rank in states already), it sets the same value again on line 187. This redundant assignment in the elif branch is unnecessary since it's already handled in the if branch. The elif should just be updating the states without the initial dictionary creation.
transfer_queue/sampler/base.py
Outdated
| partition_id: The partition ID associated with the task. | ||
| """ | ||
| if partition_id in self._states.keys(): | ||
| self._states[partition_id] = {} |
There was a problem hiding this comment.
Redundant operation: Line 89 calls self._states.pop(partition_id) after already clearing the partition with self._states[partition_id] = {} on line 88. The line 88 assignment is unnecessary since line 90 immediately removes the key. Either remove line 89 (just clear the dict) or remove line 88 (just pop the key).
| self._states[partition_id] = {} |
| sampling_config["batch_index"] = batch_index | ||
| sampling_config["partition_id"] = partition_id | ||
| batch_meta = tq_client.get_meta( | ||
| data_fields=data_fields, | ||
| batch_size=batch_size, | ||
| partition_id=partition_id, | ||
| task_name=task_name, | ||
| sampling_config=sampling_config, |
There was a problem hiding this comment.
Mutation of shared state: The function mutates the sampling_config dictionary by adding batch_index and partition_id keys (lines 256-257). This could cause unexpected side effects if the same dictionary is reused across multiple calls. Consider creating a copy of the config before mutation: config = {**sampling_config, "batch_index": batch_index, "partition_id": partition_id}.
| sampling_config["batch_index"] = batch_index | |
| sampling_config["partition_id"] = partition_id | |
| batch_meta = tq_client.get_meta( | |
| data_fields=data_fields, | |
| batch_size=batch_size, | |
| partition_id=partition_id, | |
| task_name=task_name, | |
| sampling_config=sampling_config, | |
| config = {**sampling_config, "batch_index": batch_index, "partition_id": partition_id} | |
| batch_meta = tq_client.get_meta( | |
| data_fields=data_fields, | |
| batch_size=batch_size, | |
| partition_id=partition_id, | |
| task_name=task_name, | |
| sampling_config=config, |
| identify which data replica group this rank belongs to. | ||
| data_replica_rank: Local rank inside this data_replica_group. | ||
| data_replica_world_size: Total number of ranks in this data_replica_group. | ||
| dp_rank: The group id of current dp_rank. Used to identify which dp_rank belongs to. |
There was a problem hiding this comment.
Grammatical error in the docstring. "The group id of current dp_rank. Used to identify which dp_rank belongs to" is incomplete and grammatically incorrect. Should be something like "The data parallel rank ID for this worker. Used to identify which data parallel group this rank belongs to."
| dp_rank: The group id of current dp_rank. Used to identify which dp_rank belongs to. | |
| dp_rank: The data parallel rank ID for this worker. Used to identify which | |
| data parallel group this rank belongs to. |
tutorial/05_streaming_dataloader.py
Outdated
| Args: | ||
| rank_id: Global rank identifier for logging and display purposes | ||
| data_replica_group: ID of the data parallel group this rank belongs to | ||
| dp_rank: ID of the data parallel rank belongs to |
There was a problem hiding this comment.
Grammatical error in docstring. "ID of the data parallel rank belongs to" is incomplete. Should be something like "Data parallel rank ID that this worker belongs to" or similar complete sentence.
| dp_rank: ID of the data parallel rank belongs to | |
| dp_rank: Data parallel rank ID that this worker belongs to |
| if len(required_fields) < 1: | ||
| raise ValueError(f"required_fields must be a list with at least one field name, got {required_fields}") | ||
| if len(data_fields) < 1: | ||
| raise ValueError(f"required_fields must be a list with at least one field name, got {data_fields}") |
There was a problem hiding this comment.
The error message still references the old parameter name required_fields instead of the new parameter name data_fields. The message should be updated to: "data_fields must be a list with at least one field name, got {data_fields}"
| raise ValueError(f"required_fields must be a list with at least one field name, got {data_fields}") | |
| raise ValueError(f"data_fields must be a list with at least one field name, got {data_fields}") |
|
|
||
| # Return cached result if available | ||
| if dp_rank in states.keys() and batch_index in states[dp_rank].keys(): | ||
| return states[dp_rank][batch_index] |
There was a problem hiding this comment.
The cached result is a tuple of 3 elements (sampled_indexes, consumed_indexes, 1) as shown in lines 185 and 187, but here it's being returned directly as a 3-element tuple. However, the function signature indicates it should return a 2-element tuple (sampled_indexes, consumed_indexes). This will cause the caller to receive an unexpected third element.
| return states[dp_rank][batch_index] | |
| cached = states[dp_rank][batch_index] | |
| sampled_indexes, consumed_indexes = cached[0], cached[1] | |
| return sampled_indexes, consumed_indexes |
| data_replica_rank: int, | ||
| data_replica_world_size: int, | ||
| dp_rank: int, | ||
| batch_index: int, |
There was a problem hiding this comment.
Need to update docstring so user can know what is this param
There was a problem hiding this comment.
BTW, batch_index has already used to refer to the relative position of each sample in BatchMeta. Consider use another one?
| data_replica_group: { | ||
| data_replica_rank: [[sampled_indexes], ...] # Buffer of cached sampled indices | ||
| } | ||
| dp_rank: [sampled_indexes] |
| identify which data replica group this rank belongs to. | ||
| data_replica_rank: Local rank inside this data_replica_group. | ||
| data_replica_world_size: Total number of ranks in this data_replica_group. | ||
| dp_rank: The group id of current dp_rank. Used to identify which dp_rank belongs to. |
transfer_queue/__init__.py
Outdated
| import os | ||
|
|
||
| from .client import ( | ||
| AsyncTransferQueueClient, |
There was a problem hiding this comment.
Now TransferQueueClient has both sync & async function. So we don't need this to be exposed.
|
|
||
| if mode == "insert": | ||
| partition = self._get_partition(partition_id) | ||
| if partition_id not in self.partitions: |
There was a problem hiding this comment.
Is this safe for other modes when user doesn't specify partition_id?
There was a problem hiding this comment.
Yes, I've checked—when partition_id is not provided, it defaults to returning an empty value, which should not cause any issues. I plan to rely on the pipeline to ensure no additional problems are introduced.
| data_replica_rank: int, | ||
| data_replica_world_size: int, | ||
| dp_rank: int, | ||
| n_samples_per_prompt: int, |
There was a problem hiding this comment.
This is algorithm related param. Can we implement this in Sampler rather than in StreamingDataset?
There was a problem hiding this comment.
batch_meta = tq_client.get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
task_name=task_name,
sampling_config=sampling_config,
)
The get_meta function requires n_samples_per_prompt; I can set its default value to 1.
| data_replica_world_size: int, | ||
| dp_rank: int, | ||
| n_samples_per_prompt: int, | ||
| custom_get_batch_func: Any = None, |
There was a problem hiding this comment.
How about fetch_batch_fn: Callable = None?
| dp_rank: int, | ||
| n_samples_per_prompt: int, | ||
| custom_get_batch_func: Any = None, | ||
| custom_post_process_for_micro_func: Any = None, |
There was a problem hiding this comment.
How about process_batch_fn: Callable = None?
StreamDataLoader implementation
tests/test_samplers.py
Outdated
| assert sampled1_g0 == [0, 1] | ||
| assert consumed1_g0 == [0, 1] | ||
| assert sampled1 == [0, 1] | ||
| assert sampled2 == [0, 1] # Same because both sample from the beginning of ready_indexes |
There was a problem hiding this comment.
This is a little bit misleading. We should mimic the controller's behavior that manually delete the samples in ready_indexes that show up in consumed1
tests/test_samplers.py
Outdated
| assert sampler._states["test"]["task0"][0][1] == [[0, 1]] | ||
| assert sampler._states["test"]["task1"][0][0] == [] | ||
| assert sampler._states["test"]["task1"][0][1] == [[0, 1]] | ||
| # Should return cached result |
There was a problem hiding this comment.
Is this already checked in test_rank_aware_sampler_multiple_dp_ranks()?
| data_replica_world_size: int, | ||
| dp_rank: int, | ||
| n_samples_per_prompt: int = 1, | ||
| custom_get_batch_func: Callable | None = None, |
| data_replica_rank: int, | ||
| data_replica_world_size: int, | ||
| dp_rank: int, | ||
| n_samples_per_prompt: int = 1, |
There was a problem hiding this comment.
Maybe we need to add a TODO here. It's possible to set this n_sample param during the initialization of Sampler, so we don't have to set it during get_meta(). It can make the StreamingDataLoader more universal
tests/test_samplers.py
Outdated
| assert consumed0_g0 == [0, 1] | ||
| assert consumed0_g1 == [2, 3] | ||
| assert sampled1 == [0, 1] | ||
| assert sampled2 == [2, 3] # Same because both sample from the beginning of ready_indexes |
There was a problem hiding this comment.
The comment here is wrong
…entation to support fully asynchronous mode Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>
RankAwareSamplerimplementation.GRPOGroupNSampler.StreamingDataLoader.