From 61fafe8e26564dbc6c77ff2c76e947f23da7171a Mon Sep 17 00:00:00 2001
From: shubhamugare
Date: Sat, 12 Apr 2025 16:00:55 -0500
Subject: [PATCH] Fix Phi-4 issue

---
 syncode/language_model.py            | 8 +++++++-
 syncode/mask_store/byte_tokenizer.py | 7 +++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/syncode/language_model.py b/syncode/language_model.py
index e1cac7ff..59f96605 100644
--- a/syncode/language_model.py
+++ b/syncode/language_model.py
@@ -111,7 +111,8 @@ def generate_grammar_constrained_completion(
         stop_criteria = []
 
         # Generate completions
-        if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
+        if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1:
+            # Use our own implementation for greedy search and sampling
             generated_ids = self._generate(
                 inputs,
                 gen_config,
@@ -239,6 +240,11 @@ def _generate(
             # (the clone itself is always small)
             next_token_scores = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=token_ids.device)
 
+            if len(next_token_scores.shape) == 3:
+                # FIXME: some models (e.g. Phi-4) return logits with an extra dimension;
+                # next_token_scores should have shape (batch_size, vocab_size), so reduce once more
+                next_token_scores = next_token_scores[:, -1, :]
+
             if grammar_decoder is not None:
                 next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
                 is_valid = grammar_decoder.is_valid(token_ids, next_token)
diff --git a/syncode/mask_store/byte_tokenizer.py b/syncode/mask_store/byte_tokenizer.py
index a443e917..81edd5d2 100644
--- a/syncode/mask_store/byte_tokenizer.py
+++ b/syncode/mask_store/byte_tokenizer.py
@@ -188,6 +188,13 @@ def __init__(self, tokenizer, vocab_type=None):
         # Cache special token IDs as a set for faster lookups
         self.special_token_ids = set(getattr(tokenizer, "all_special_ids", []))
 
+        # Added tokens are typically special tokens: if the tokenizer exposes
+        # an added_tokens_decoder mapping, merge its keys (token IDs) into
+        # special_token_ids as well.
+        if hasattr(tokenizer, "added_tokens_decoder"):
+            self.special_token_ids.update(tokenizer.added_tokens_decoder.keys())
+
+
     @classmethod
     def from_pretrained(cls, model_id, vocab_type=None):
         """
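
Note on the language_model.py change (not part of the patch): the new guard
handles models whose logits keep an extra dimension even after slicing the
last position. A minimal sketch of the shape handling in isolation; the
(batch, 1, seq, vocab) layout below is an assumption chosen only to
reproduce a 3-D result after the first slice:

    import torch

    batch_size, seq_len, vocab_size = 1, 5, 100

    # Typical case: (batch, seq, vocab) logits; slicing the last position
    # already yields the expected (batch, vocab) scores.
    logits = torch.randn(batch_size, seq_len, vocab_size)
    next_token_scores = logits[:, -1, :]
    assert next_token_scores.shape == (batch_size, vocab_size)

    # Phi-4-style case: an extra axis survives the first slice, so the
    # scores are still 3-D and need one more reduction, as in the patch.
    odd_logits = torch.randn(batch_size, 1, seq_len, vocab_size)
    next_token_scores = odd_logits[:, -1, :]  # still (batch, seq, vocab)
    if len(next_token_scores.shape) == 3:
        next_token_scores = next_token_scores[:, -1, :]
    assert next_token_scores.shape == (batch_size, vocab_size)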
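
Note on the byte_tokenizer.py change (not part of the patch): in Hugging Face
tokenizers, added_tokens_decoder maps token IDs to tokens registered after
pretraining (chat and control tokens, for example), and all_special_ids does
not always list them. A small sketch mirroring the change; the model id is
only an illustrative example, any tokenizer with added tokens would do:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-4")  # example model id
    special_token_ids = set(getattr(tokenizer, "all_special_ids", []))

    # Merge IDs of added tokens, as the patched __init__ now does.
    if hasattr(tokenizer, "added_tokens_decoder"):
        special_token_ids.update(tokenizer.added_tokens_decoder.keys())

    print(sorted(special_token_ids))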