Skip to content

[BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch#28201

Merged
ArthurZucker merged 1 commit intohuggingface:mainfrom
inkinworld:fix-barkeos-processor
Jan 10, 2024
Merged

[BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch#28201
ArthurZucker merged 1 commit intohuggingface:mainfrom
inkinworld:fix-barkeos-processor

Conversation

@inkinworld
Copy link
Copy Markdown
Contributor

@inkinworld inkinworld commented Dec 22, 2023

What does this PR do?

Fixes bug about transformers.generation.logits_process.BarkEosPrioritizerLogitsProcessor.
when BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch.

such as below test case:

    def test_early_stop_processor_multi_eos(self):
        input_ids = None
        eos_token_id = [2, 3]
        min_eos_p = 0.1  ## some small float

        scores = self._get_uniform_logits(2, 4)
        scores[0][eos_token_id] = -6  ## less than log(min_eos_p)

        esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
        actual_scores = esp(input_ids, scores)
        expected_scores_list = [
            scores[0].tolist(),
            [float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
        ]
        self.assertListEqual(actual_scores.tolist(), expected_scores_list)

will occur this exception

self = <transformers.generation.logits_process.BarkEosPrioritizerLogitsProcessor object at 0x12f1e0220>
input_ids = None
scores = tensor([[ 0.2500,  0.2500, -6.0000, -6.0000],
        [ 0.2500,  0.2500,  0.2500,  0.2500]])

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.min_eos_p:
            probs = torch.nn.functional.softmax(scores.float(), dim=-1)
            # create scores full of -inf except for the eos_token_id
            early_stop_scores = torch.ones_like(scores) * -float("inf")
            early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
    
            do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
            # do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
>           scores = torch.where(do_early_stop, early_stop_scores, scores)
E           RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1

src/transformers/generation/logits_process.py:2142: RuntimeError

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@gante

…eos_token_id use list, tensor size mismatch
@inkinworld inkinworld force-pushed the fix-barkeos-processor branch from 9f1faed to ce52bef Compare December 22, 2023 11:23
Copy link
Copy Markdown
Contributor

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 👍 Thank you for reporting and fixing the issue! 💛

@gante gante requested a review from ArthurZucker January 9, 2024 19:18
@gante
Copy link
Copy Markdown
Contributor

gante commented Jan 9, 2024

(cc @ylacombe )

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix and taking the time to add a test 🤗

@ArthurZucker ArthurZucker merged commit 4df1d69 into huggingface:main Jan 10, 2024
@ylacombe
Copy link
Copy Markdown
Contributor

Thanks for fixing @inkinworld !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants