Skip to content

Update RewardFunc type to use RewardCallable protocol#4938

Open
amit9oct wants to merge 2 commits intohuggingface:mainfrom
amit9oct:patch-1
Open

Update RewardFunc type to use RewardCallable protocol#4938
amit9oct wants to merge 2 commits intohuggingface:mainfrom
amit9oct:patch-1

Conversation

@amit9oct
Copy link

@amit9oct amit9oct commented Jan 31, 2026

Refactor RewardFunc type to use RewardCallable protocol for better type safety because it matches the actual usage of the reward function in the code:

output_reward_func = reward_func(
    prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
)

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case. GRPO Reward Function documentation and usage mismatch #4939
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Refactor RewardFunc type to use RewardCallable protocol for better type safety because it matches the actual usage of the reward function in the code:
```python
output_reward_func = reward_func(
    prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
)
```
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Refactors the GRPO trainer’s RewardFunc type alias to use a RewardCallable Protocol, aligning the public typing with how reward functions are actually invoked (keyword args like completion_ids and **kwargs).

Changes:

  • Added a RewardCallable Protocol describing the expected callable signature for custom reward functions.
  • Updated RewardFunc to use RewardCallable instead of a narrow Callable[[list, list], list[float]].
  • Updated typing imports to include Protocol (and Optional).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

from functools import partial
from pathlib import Path
from typing import Any
from typing import Any, Protocol, Optional
Copy link

Copilot AI Jan 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional is imported but never used in this module. Please remove it to avoid lint/type-check failures from unused imports.

Suggested change
from typing import Any, Protocol, Optional
from typing import Any, Protocol

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Optional typing should be removed from here: it is not used.

Comment on lines +113 to +124
class RewardCallable(Protocol):
def __call__(
self,
prompts: list[Any] | None = None,
completions: list[Any] | None = None,
completion_ids: list[list[int]] | None = None,
**kwargs: Any,
) -> list[float | None]: ...

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]]
RewardFunc = str | PreTrainedModel | RewardCallable
Copy link

Copilot AI Jan 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GRPOTrainer supports async reward functions (see _calculate_rewards: asyncio.iscoroutinefunction(reward_func)), but RewardFunc = ... | RewardCallable currently only types sync callables (returning list[float | None]). This makes the public type alias narrower than actual supported usage. Consider extending the typing to also accept async callables (e.g., a separate AsyncRewardCallable protocol returning an awaitable, or a union return type).

Copilot uses AI. Check for mistakes.
Copy link
Member

@albertvillanova albertvillanova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Shouldn't we use the same typing in all the trainers?

from functools import partial
from pathlib import Path
from typing import Any
from typing import Any, Protocol, Optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Optional typing should be removed from here: it is not used.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants

Comments