[ReplayBuffer] add ReplayBuffer with various StorageBackend #1490
base: rl_design
Conversation
…aleness, or Database (to be implemented in the future)
Pull request overview
This PR introduces a new ReplayBuffer abstraction in xtuner/v1/rl/base with pluggable storage backends (e.g., FIFO and staleness-priority), plus initial unit tests for basic FIFO and staleness ordering behavior.
Changes:
- Added `ReplayBuffer`, a `StorageBackend` interface, and multiple backend implementations (`FIFOStorageBackend`, `StalenessStorageBackend`, plus stub/pseudocode backends).
- Implemented `StorageIndices` to partition storage by `(task_name, group_status, tags)`.
- Added async unit tests covering FIFO behavior, staleness priority order, and multi-task isolation.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| `xtuner/v1/rl/base/replay_buffer.py` | Adds the replay buffer API and backend implementations (FIFO + staleness), with placeholder backends for future extensions. |
| `tests/ray/test_replay_buffer.py` | Adds async unit tests validating basic replay buffer behavior for FIFO and staleness backends. |
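For orientation, a minimal usage sketch of the new abstraction, pieced together from the snippets quoted in the review comments below. Whether `ReplayBuffer.put`/`get` are awaitable and which keyword names they accept (`task_name`, `group_status`, free-form tags) are assumptions that should be checked against the actual implementation.

```python
import asyncio

from xtuner.v1.rl.base.replay_buffer import FIFOStorageBackend, ReplayBuffer


async def main() -> None:
    # Pluggable backend: FIFO here; StalenessStorageBackend would slot in the same way.
    buffer = ReplayBuffer(storage_backend=FIFOStorageBackend(limit=1024))

    rollout_states: list = []  # placeholder for real RolloutState items

    # Partition keys (task_name, group_status, plus arbitrary tags via **kwargs)
    # are mapped to a StorageIndices internally.
    await buffer.put(rollout_states, task_name="math", group_status="finished")
    batch = await buffer.get(8, task_name="math", group_status="finished")
    print(len(batch))


asyncio.run(main())
```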
```python
@dataclass
class StorageIndices:
    # 为不同存储后段提供统一的索引接口
```
Copilot AI commented on Feb 12, 2026:
StorageIndices doc comment has a typo: “存储后段” should be “存储后端” (i.e., “storage backend”).
Suggested change:

```diff
-    # 为不同存储后段提供统一的索引接口
+    # 为不同存储后端提供统一的索引接口
```
```python
class ReplayBuffer:
    def __init__(self, storage_backend: StorageBackend = None):
```
Copilot AI commented on Feb 12, 2026:
ReplayBuffer.__init__ annotates storage_backend as StorageBackend but defaults it to None, which will fail type checking under mypy’s strict optional rules. Please change the annotation to StorageBackend | None (or Optional[StorageBackend]).
Suggested change:

```diff
-    def __init__(self, storage_backend: StorageBackend = None):
+    def __init__(self, storage_backend: StorageBackend | None = None):
```
```python
        indices = self._hash_storage_indices(storage_indices)
        group_seq_staleness = max([item.seq_staleness for item in items])
```
Copilot AI commented on Feb 12, 2026:
StalenessStorageBackend.put will crash if items is empty (max() on empty list). Also, if any seq_staleness falls outside [min_staleness, max_staleness] (defaults are both 0), self._storage[indices][group_seq_staleness] will raise KeyError. Consider explicitly handling empty input (no-op or clear error) and validating/clamping seq_staleness to the configured bucket range (or dynamically creating buckets).
Suggested change:

```diff
-        indices = self._hash_storage_indices(storage_indices)
-        group_seq_staleness = max([item.seq_staleness for item in items])
+        # If there are no items, treat this as a no-op to avoid max() on an empty list.
+        if not items:
+            return
+        indices = self._hash_storage_indices(storage_indices)
+        group_seq_staleness = max(item.seq_staleness for item in items)
+        # Clamp staleness into the configured bucket range to avoid KeyError.
+        group_seq_staleness = max(self.min_staleness,
+                                  min(self.max_staleness, group_seq_staleness))
```
```python
    def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0):
        self.limit = limit
        self.max_staleness = max_staleness
        self.min_staleness = min_staleness
        self._storage = defaultdict(lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)})
        self._bucket_counts = defaultdict(int)
```
Copilot AI commented on Feb 12, 2026:
StalenessStorageBackend.__init__ accepts a limit parameter but it is never enforced (items can grow unbounded). Either implement eviction behavior consistent with FIFOStorageBackend(limit=...) or remove the parameter to avoid misleading API.
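If the limit is kept, one possible eviction policy is to drop the stalest groups first once a partition exceeds it. The sketch below is hypothetical: it assumes each staleness bucket is a deque of item groups and that `self._bucket_counts` tracks per-partition item counts, both of which are inferences from the snippet above rather than confirmed details of the PR.

```python
def _evict_if_needed(self, indices: tuple) -> None:
    """Hypothetical helper: shrink a partition back under ``self.limit``.

    A limit of 0 is treated as "unlimited" here (an assumption). Whole groups
    are evicted from the stalest bucket downwards until the count fits.
    """
    if self.limit <= 0:
        return
    buckets = self._storage[indices]
    for staleness in range(self.max_staleness, self.min_staleness - 1, -1):
        while self._bucket_counts[indices] > self.limit and buckets[staleness]:
            evicted_group = buckets[staleness].popleft()
            self._bucket_counts[indices] -= len(evicted_group)
        if self._bucket_counts[indices] <= self.limit:
            return
```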
```python
        new_df = pd.DataFrame(new_rows)
        self._df = pd.concat([self._df, new_df], ignore_index=True, sort=False)

    def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
```
Copilot AI commented on Feb 12, 2026:
PandasStorageBackend.get is defined as a synchronous method, but the StorageBackend interface defines async def get(...). Even though the class currently raises in __init__, keeping signatures consistent will prevent accidental misuse later and avoids confusing API expectations (callers will await this).
Suggested change:

```diff
-    def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
+    async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
```
```python
    def _hash_storage_indices(self, indices: StorageIndices) -> tuple:
        base = (indices.task_name, indices.group_status)

        if indices.tags:
            sorted_tags = tuple(sorted(indices.tags.items()))
            return base + sorted_tags
        return base
```
Copilot AI commented on Feb 12, 2026:
_hash_storage_indices builds a tuple used as a dict key; if any tag value is unhashable (e.g., list/dict), this will raise TypeError at runtime. Since tags come from **kwargs, consider validating/coercing tag values to hashable types (e.g., str(value)/json.dumps) or restricting the accepted tag value types.
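One possible coercion strategy, shown as standalone helpers that mirror the `_hash_storage_indices` logic quoted above; serialising complex values with `json.dumps` is just one option, and restricting the accepted tag value types would be equally valid.

```python
import json


def _coerce_tag_value(value):
    """Return a hashable, order-stable representation of a tag value."""
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    # Lists/dicts (and anything else) are serialised to canonical JSON so they
    # can participate in a tuple dict key without raising TypeError.
    return json.dumps(value, sort_keys=True, default=str)


def hash_storage_indices(task_name, group_status, tags: dict) -> tuple:
    base = (task_name, group_status)
    if not tags:
        return base
    return base + tuple((k, _coerce_tag_value(v)) for k, v in sorted(tags.items()))
```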
```python
        # Build the dynamic query.
        query = "SELECT id, data FROM replay_buffer WHERE task_name = ? AND group_status = ?"
        params = [indices.task_name, indices.group_status]

        # SQLite's JSON query syntax requires SQLite 3.38+; on older versions, emulate
        # with LIKE or skip DB-level filtering.
        # Simple approach shown here: filtering tags on the Python side is inefficient,
        # while filtering JSON on the SQL side has more complex syntax. For generality,
        # one option is to fetch candidates by task and status only and filter tags in
        # Python (if tags get complex, promote them to dedicated columns); another is
        # to use JSON_EXTRACT (recommended).
        for key, value in indices.tags.items():
            # Note the difference between numbers and strings in JSON; values are
            # assumed to be simple types here. $.key extracts the corresponding value.
            query += f" AND json_extract(tags, '$.{key}') = ?"
            params.append(value)

        query += f" LIMIT {count}"
```
Copilot AI commented on Feb 12, 2026:
SQLStorageBackend.get builds SQL using f-strings for both the JSON path ($.{key}) and LIMIT {count}. If key is user-controlled (it comes from indices.tags), this is a SQL injection risk once this backend is implemented. Prefer validating key against an allowlist/regex and using parameter binding for LIMIT (and avoid interpolating raw values into the query string).
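A possible hardening sketch, under the assumption that tag keys should look like plain identifiers; the allowlist regex, the helper name `build_select`, and the error handling are illustrative. SQLite accepts a bound parameter for LIMIT, so only the validated JSON path needs interpolation.

```python
import re

_TAG_KEY_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def build_select(indices, count: int) -> tuple[str, list]:
    """Build a parameterised query; only allowlisted tag keys reach the SQL text."""
    query = "SELECT id, data FROM replay_buffer WHERE task_name = ? AND group_status = ?"
    params: list = [indices.task_name, indices.group_status]
    for key, value in indices.tags.items():
        if not _TAG_KEY_RE.match(key):
            raise ValueError(f"Refusing to filter on suspicious tag key: {key!r}")
        # The JSON path itself cannot be bound as a parameter, so it is only
        # interpolated after passing the identifier allowlist above.
        query += f" AND json_extract(tags, '$.{key}') = ?"
        params.append(value)
    query += " LIMIT ?"
    params.append(count)
    return query, params
```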
```python
    tags: dict = field(default_factory=dict)  # For non-equality conditions, use keys like scores_gt > 0.8
```
Copilot AI commented on Feb 12, 2026:
The tags-based partitioning logic is part of the public ReplayBuffer.put/get API (via **kwargs), but there are no tests asserting that different tag values map to different storage partitions and don’t mix. Consider adding a small test that writes items with different tag combinations and verifies isolation.
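A sketch of such a test, written on the assumption that the existing async tests use pytest-asyncio and that `put`/`get` accept partition keys as keyword arguments; `make_rollout_state` is a hypothetical fixture standing in for however the real tests construct `RolloutState` items.

```python
import pytest

from xtuner.v1.rl.base.replay_buffer import FIFOStorageBackend, ReplayBuffer


@pytest.mark.asyncio
async def test_tags_partition_isolation(make_rollout_state):
    # "make_rollout_state" is a hypothetical fixture; replace with the helper
    # the existing tests use to build RolloutState items.
    buffer = ReplayBuffer(storage_backend=FIFOStorageBackend(limit=16))

    easy = [make_rollout_state(i) for i in range(4)]
    hard = [make_rollout_state(i + 100) for i in range(4)]
    await buffer.put(easy, task_name="math", group_status="finished", difficulty="easy")
    await buffer.put(hard, task_name="math", group_status="finished", difficulty="hard")

    got_easy = await buffer.get(4, task_name="math", group_status="finished", difficulty="easy")
    got_hard = await buffer.get(4, task_name="math", group_status="finished", difficulty="hard")

    # Items written under one tag value must never surface under another.
    assert got_easy == easy
    assert got_hard == hard
```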
```python
class PandasStorageBackend(StorageBackend):
    def __init__(self, limit: int = 0):
        raise NotImplementedError("PandasStorageBackend is under development and not yet implemented.")
        import pandas as pd
```
Copilot AI commented on Feb 12, 2026:
This statement is unreachable.
```python
class SQLStorageBackend(StorageBackend):
    def __init__(self, db_path: str = ":memory:"):
        raise NotImplementedError("SQLStorageBackend is under development and not yet implemented.")
        self.db_path = db_path
```
Copilot AI commented on Feb 12, 2026:
This statement is unreachable.
ReplayBuffer design notes
StorageIndices
A data-index class that gives the different backends a unified way to index stored data.
StorageBackend
An abstract storage backend that supports different kinds of storage, for example the simplest FIFO queue (FIFOStorageBackend), a staleness-priority queue (StalenessStorageBackend), databases, and so on; pseudocode for PandasStorageBackend and SQLStorageBackend is provided as a reference.
ReplayBuffer
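A minimal sketch of the StorageIndices / StorageBackend abstraction described in these notes, reconstructed from the snippets quoted in this review; the field defaults and the exact `put` signature are assumptions.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class StorageIndices:
    # Unified index interface shared by the different storage backends.
    task_name: str = ""
    group_status: str = ""
    tags: dict = field(default_factory=dict)


class StorageBackend(ABC):
    @abstractmethod
    def put(self, items: list, storage_indices: StorageIndices) -> None:
        """Store a group of rollout items under the given partition."""

    @abstractmethod
    async def get(self, count: int, indices: StorageIndices) -> list:
        """Return up to ``count`` items from the given partition."""
```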