Update RewardFunc type to use RewardCallable protocol#4938
Update RewardFunc type to use RewardCallable protocol#4938amit9oct wants to merge 2 commits intohuggingface:mainfrom
Conversation
Refactor RewardFunc type to use RewardCallable protocol for better type safety because it matches the actual usage of the reward function in the code:
```python
output_reward_func = reward_func(
prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
)
```
There was a problem hiding this comment.
Pull request overview
Refactors the GRPO trainer’s RewardFunc type alias to use a RewardCallable Protocol, aligning the public typing with how reward functions are actually invoked (keyword args like completion_ids and **kwargs).
Changes:
- Added a
RewardCallableProtocoldescribing the expected callable signature for custom reward functions. - Updated
RewardFuncto useRewardCallableinstead of a narrowCallable[[list, list], list[float]]. - Updated typing imports to include
Protocol(andOptional).
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| from functools import partial | ||
| from pathlib import Path | ||
| from typing import Any | ||
| from typing import Any, Protocol, Optional |
There was a problem hiding this comment.
Optional is imported but never used in this module. Please remove it to avoid lint/type-check failures from unused imports.
| from typing import Any, Protocol, Optional | |
| from typing import Any, Protocol |
There was a problem hiding this comment.
The Optional typing should be removed from here: it is not used.
| class RewardCallable(Protocol): | ||
| def __call__( | ||
| self, | ||
| prompts: list[Any] | None = None, | ||
| completions: list[Any] | None = None, | ||
| completion_ids: list[list[int]] | None = None, | ||
| **kwargs: Any, | ||
| ) -> list[float | None]: ... | ||
|
|
||
| # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of | ||
| # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. | ||
| RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]] | ||
| RewardFunc = str | PreTrainedModel | RewardCallable |
There was a problem hiding this comment.
GRPOTrainer supports async reward functions (see _calculate_rewards: asyncio.iscoroutinefunction(reward_func)), but RewardFunc = ... | RewardCallable currently only types sync callables (returning list[float | None]). This makes the public type alias narrower than actual supported usage. Consider extending the typing to also accept async callables (e.g., a separate AsyncRewardCallable protocol returning an awaitable, or a union return type).
albertvillanova
left a comment
There was a problem hiding this comment.
Thanks. Shouldn't we use the same typing in all the trainers?
| from functools import partial | ||
| from pathlib import Path | ||
| from typing import Any | ||
| from typing import Any, Protocol, Optional |
There was a problem hiding this comment.
The Optional typing should be removed from here: it is not used.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Refactor RewardFunc type to use RewardCallable protocol for better type safety because it matches the actual usage of the reward function in the code:
What does this PR do?
Fixes # (issue)
Before submitting
Pull Request section?
to it if that's the case. GRPO Reward Function documentation and usage mismatch #4939
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.