From 3774b56ea220b714d1a6b25ef1cf2d15429e6820 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 2 Jun 2025 19:30:57 +0000 Subject: [PATCH 01/48] use flexible model loading --- verifiers/utils/model_utils.py | 42 ++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 12679412b..9d1a8ac63 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -1,8 +1,9 @@ from importlib.util import find_spec +from importlib import import_module from typing import Dict, Any, Union, Tuple, Callable import torch -from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PreTrainedModel # type: ignore import torch.nn as nn @@ -59,6 +60,43 @@ def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn. def is_liger_available() -> bool: return find_spec("liger_kernel") is not None +def generic_model_loader(model_id: str, **model_kwargs) -> PreTrainedModel: + cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + for arch in cfg.architectures or []: + try: + cls = getattr(import_module("transformers"), arch) + return cls.from_pretrained( + model_id, + trust_remote_code=True, + **model_kwargs, + ) + except (AttributeError, ImportError, ValueError): + pass + + from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForVision2Seq, + ) + + for auto_cls in ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForVision2Seq, + AutoModel, + ): + try: + return auto_cls.from_pretrained( + model_id, + trust_remote_code=True, + **model_kwargs, + ) + except ValueError: + continue + + raise RuntimeError(f"No suitable loader found for model type {cfg.model_type!r}") + def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Any: if model_kwargs is None: model_kwargs = dict( @@ -71,7 +109,7 @@ def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[ from liger_kernel.transformers import AutoLigerKernelForCausalLM # type: ignore return AutoLigerKernelForCausalLM.from_pretrained(model_name, **model_kwargs) else: - return AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + return generic_model_loader(model_name, **model_kwargs) def get_tokenizer(model_name: str) -> Any: tokenizer = AutoTokenizer.from_pretrained(model_name) From 2fbd45a39b59fa83e2ad958681727f5b012c5595 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 2 Jun 2025 19:31:16 +0000 Subject: [PATCH 02/48] start example --- verifiers/examples/docvqa.py | 55 ++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 verifiers/examples/docvqa.py diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py new file mode 100644 index 000000000..e5a489eff --- /dev/null +++ b/verifiers/examples/docvqa.py @@ -0,0 +1,55 @@ +from datasets import load_dataset +import verifiers as vf + +""" +CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' +TODO: + - check liger kernel support + - check that generic_model_loader didn't break anything +""" + +def preprocess_docvqa(x): + return { + "question": x["question"], + "images": [x["image"]], + "answer": x["answers"][0], + } + + +dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation") +dataset = dataset.map( + preprocess_docvqa, num_proc=10, 
remove_columns=dataset.column_names +) + +parser = vf.XMLParser(["think", "answer"], answer_field="answer") +system_prompt = f"""Answer the questions. + +Respond in the following format: +{parser.get_format_str()}""" + +rubric = vf.Rubric( + funcs=[ + parser.get_format_reward_func(), + ] +) + +vf_env = vf.SingleTurnEnv( + dataset=dataset, system_prompt=system_prompt, parser=parser, rubric=rubric +) + +model_name = "Qwen/Qwen2.5-VL-3B-Instruct" +model, tokenizer = vf.get_model_and_tokenizer(model_name, use_liger=False) # TODO: modify model loading to add liger support +run_name = "docvqa_" + model_name.split("/")[-1].lower() + +training_args = vf.grpo_defaults(run_name=run_name) +training_args.per_device_train_batch_size = 4 +training_args.num_generations = 4 + +trainer = vf.GRPOTrainer( + model=model, + processing_class=tokenizer, + env=vf_env, + args=training_args, +) +import pdb;pdb.set_trace() +trainer.train() From ad19f9fac9bac94bf6ee5bcb39c7349009eab875 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 2 Jun 2025 20:46:22 +0000 Subject: [PATCH 03/48] use AutoProcessor class --- README.md | 2 +- verifiers/__init__.py | 6 ++--- verifiers/examples/docvqa.py | 6 ++--- verifiers/examples/doublecheck.py | 2 +- verifiers/examples/math_python.py | 2 +- verifiers/examples/reverse_text.py | 2 +- verifiers/examples/self_reward.py | 2 +- verifiers/examples/sft/math_python.py | 2 +- verifiers/examples/sft/reverse_text.py | 2 +- verifiers/examples/smola_math_tools.py | 2 +- verifiers/trainers/grpo_trainer.py | 33 ++++++++++++++++++++++---- verifiers/utils/__init__.py | 6 ++--- verifiers/utils/model_utils.py | 13 +++++----- 13 files changed, 52 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 6ef770d67..e91fba9bf 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ See `verifiers/examples/sft/reverse_text.py` for an example script using TRL's S ```python # train.py -model, tokenizer = vf.get_model_and_tokenizer(model_name) +model, tokenizer = vf.get_model_and_processor(model_name) trainer = vf.GRPOTrainer( model=model, processing_class=tokenizer, diff --git a/verifiers/__init__.py b/verifiers/__init__.py index 92816d83e..c02b21d81 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -27,7 +27,7 @@ from .trainers.grpo_trainer import GRPOTrainer from .trainers.grpo_config import GRPOConfig from .utils.data_utils import extract_boxed_answer, extract_hash_answer, load_example_dataset -from .utils.model_utils import get_model, get_tokenizer, get_model_and_tokenizer +from .utils.model_utils import get_model, get_processor, get_model_and_processor from .utils.config_utils import grpo_defaults, lora_defaults __version__ = "0.1.0" @@ -53,8 +53,8 @@ "GRPOConfig", "VLLMClient", "get_model", - "get_tokenizer", - "get_model_and_tokenizer", + "get_processor", + "get_model_and_processor", "grpo_defaults", "lora_defaults", "extract_boxed_answer", diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index e5a489eff..c1f89a6ee 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -3,6 +3,7 @@ """ CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' +CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py TODO: - check liger kernel support - check that generic_model_loader didn't break anything @@ -38,7 +39,7 @@ def preprocess_docvqa(x): ) model_name = "Qwen/Qwen2.5-VL-3B-Instruct" -model, tokenizer = vf.get_model_and_tokenizer(model_name, use_liger=False) # TODO: modify model loading to 
add liger support +model, processor = vf.get_model_and_processor(model_name, use_liger=False) # TODO: modify model loading to add liger support run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) @@ -47,9 +48,8 @@ def preprocess_docvqa(x): trainer = vf.GRPOTrainer( model=model, - processing_class=tokenizer, + processing_class=processor, env=vf_env, args=training_args, ) -import pdb;pdb.set_trace() trainer.train() diff --git a/verifiers/examples/doublecheck.py b/verifiers/examples/doublecheck.py index 71671bcb7..fa57bf0ec 100644 --- a/verifiers/examples/doublecheck.py +++ b/verifiers/examples/doublecheck.py @@ -12,7 +12,7 @@ system_prompt=SIMPLE_PROMPT, few_shot=[] ) -model, tokenizer = vf.get_model_and_tokenizer(model_name) +model, tokenizer = vf.get_model_and_processor(model_name) args = vf.grpo_defaults(run_name="doublecheck-{}".format(model_name.split("/")[-1].lower())) trainer = vf.GRPOTrainer( model=model, diff --git a/verifiers/examples/math_python.py b/verifiers/examples/math_python.py index 9b06d597f..7499c7ab3 100644 --- a/verifiers/examples/math_python.py +++ b/verifiers/examples/math_python.py @@ -58,7 +58,7 @@ print(vf_env.system_prompt) model_name = "willcb/Qwen2.5-7B-Math-Python-SFT" -model, tokenizer = vf.get_model_and_tokenizer(model_name) +model, tokenizer = vf.get_model_and_processor(model_name) run_name = "math-grpo_" + model_name.split("/")[-1].lower() training_args=vf.grpo_defaults(run_name=run_name) diff --git a/verifiers/examples/reverse_text.py b/verifiers/examples/reverse_text.py index 494445aef..811f66afc 100644 --- a/verifiers/examples/reverse_text.py +++ b/verifiers/examples/reverse_text.py @@ -48,7 +48,7 @@ def lcs_ratio(x: str, y: str) -> float: args.eval_steps = 10 args.max_steps = 100 -model, tokenizer = vf.get_model_and_tokenizer(model_name) +model, tokenizer = vf.get_model_and_processor(model_name) trainer = vf.GRPOTrainer( model=model, processing_class=tokenizer, diff --git a/verifiers/examples/self_reward.py b/verifiers/examples/self_reward.py index 30722441a..701a030ee 100644 --- a/verifiers/examples/self_reward.py +++ b/verifiers/examples/self_reward.py @@ -10,6 +10,6 @@ system_prompt="You are a helpful assistant.", rubric=rubric ) -model, tokenizer = vf.get_model_and_tokenizer(model_name) +model, tokenizer = vf.get_model_and_processor(model_name) trainer = vf.GRPOTrainer(env=vf_env, model=model, processing_class=tokenizer, args=vf.grpo_defaults(run_name="self_reward")) trainer.train() \ No newline at end of file diff --git a/verifiers/examples/sft/math_python.py b/verifiers/examples/sft/math_python.py index a15aff4a0..7562218e7 100644 --- a/verifiers/examples/sft/math_python.py +++ b/verifiers/examples/sft/math_python.py @@ -7,7 +7,7 @@ """ # convenience function for FA2 initialization -model, tokenizer = vf.get_model_and_tokenizer("Qwen/Qwen2.5-7B-Instruct", use_liger=False) +model, tokenizer = vf.get_model_and_processor("Qwen/Qwen2.5-7B-Instruct", use_liger=False) dataset = load_dataset('willcb/V3-gsm8k-python-test', split='train') tok_counts = [] diff --git a/verifiers/examples/sft/reverse_text.py b/verifiers/examples/sft/reverse_text.py index 77f3cb2d5..8aab76f1e 100644 --- a/verifiers/examples/sft/reverse_text.py +++ b/verifiers/examples/sft/reverse_text.py @@ -7,7 +7,7 @@ """ # convenience function for FA2 initialization -model, tokenizer = vf.get_model_and_tokenizer("Qwen/Qwen2.5-0.5B-Instruct", use_liger=False) +model, tokenizer = vf.get_model_and_processor("Qwen/Qwen2.5-0.5B-Instruct", 
use_liger=False) dataset = load_dataset('willcb/R1-reverse-wikipedia-paragraphs-v1-1000', split='train') tok_counts = [] diff --git a/verifiers/examples/smola_math_tools.py b/verifiers/examples/smola_math_tools.py index b2eeb4538..d9b65d189 100644 --- a/verifiers/examples/smola_math_tools.py +++ b/verifiers/examples/smola_math_tools.py @@ -47,7 +47,7 @@ print(vf_env.system_prompt) model_name = "Qwen/Qwen2.5-7B-Instruct" -model, tokenizer = vf.get_model_and_tokenizer(model_name) +model, tokenizer = vf.get_model_and_processor(model_name) run_name = "math-smola-grpo_" + model_name.split("/")[-1].lower() args = vf.grpo_defaults(run_name=run_name) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 13fde83d7..9cf661984 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1,5 +1,6 @@ # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py +import inspect import logging from collections import defaultdict, deque from contextlib import nullcontext @@ -37,6 +38,18 @@ from verifiers.utils.logging_utils import print_prompt_completions_sample from verifiers.utils.trainer_utils import RepeatSampler +def _accepts_logits_to_keep(model) -> bool: + forward = ( + model.get_base_model().forward + if hasattr(model, "get_base_model") + else model.forward + ) + try: + inspect.signature(forward).bind_partial(**{"logits_to_keep": None}) + return True + except TypeError: + return False + # torch.nanstd doesn't exist, so we define it here def nanstd(tensor: torch.Tensor) -> torch.Tensor: """ @@ -447,16 +460,18 @@ def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, log return last_hidden_state # Get the per-token log probabilities for the completions for the model and the reference model - def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor: + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None, **model_kwargs) -> torch.Tensor: batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] + if _accepts_logits_to_keep(model): + model_kwargs["logits_to_keep"] = logits_to_keep + 1 for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model( - input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1 + input_ids=input_ids_batch, attention_mask=attention_mask_batch, **model_kwargs ).logits logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred input_ids_batch = input_ids_batch[:, -logits_to_keep:] @@ -820,7 +835,15 @@ def compute_loss(self, input_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature( + model.get_base_model().forward + ).parameters.keys() + ) + model_kwargs = {k: inputs[k] for k in model_kwarg_keys if k in inputs} + 
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, **model_kwargs) # Compute the loss advantages = inputs["advantages"] @@ -847,12 +870,12 @@ def compute_loss(self, with torch.no_grad(): if self.ref_model is not None: ref_per_token_logps = self._get_per_token_logps( - self.ref_model, input_ids, attention_mask, logits_to_keep + self.ref_model, input_ids, attention_mask, logits_to_keep, **model_kwargs ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): # type: ignore ref_per_token_logps = self._get_per_token_logps( - self.model, input_ids, attention_mask, logits_to_keep + self.model, input_ids, attention_mask, logits_to_keep, **model_kwargs ) per_token_kl = ( torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 diff --git a/verifiers/utils/__init__.py b/verifiers/utils/__init__.py index a83ce148a..1abce403c 100644 --- a/verifiers/utils/__init__.py +++ b/verifiers/utils/__init__.py @@ -1,6 +1,6 @@ from .data_utils import extract_boxed_answer, extract_hash_answer, load_example_dataset from .config_utils import grpo_defaults, lora_defaults -from .model_utils import get_model, get_tokenizer, get_model_and_tokenizer +from .model_utils import get_model, get_processor, get_model_and_processor from .logging_utils import setup_logging, print_prompt_completions_sample __all__ = [ @@ -10,8 +10,8 @@ "grpo_defaults", "lora_defaults", "get_model", - "get_tokenizer", - "get_model_and_tokenizer", + "get_processor", + "get_model_and_processor", "setup_logging", "print_prompt_completions_sample", ] \ No newline at end of file diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 9d1a8ac63..7519ae0e6 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -3,7 +3,7 @@ from typing import Dict, Any, Union, Tuple, Callable import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PreTrainedModel # type: ignore +from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig, PreTrainedModel # type: ignore import torch.nn as nn @@ -111,15 +111,16 @@ def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[ else: return generic_model_loader(model_name, **model_kwargs) -def get_tokenizer(model_name: str) -> Any: - tokenizer = AutoTokenizer.from_pretrained(model_name) +def get_processor(model_name: str, padding_side: str = "left") -> Any: + processor = AutoProcessor.from_pretrained(model_name, padding_side=padding_side) + tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor if not hasattr(tokenizer, "chat_template"): raise ValueError(f"Tokenizer for model {model_name} does not have chat_template attribute, \ and could not find a tokenizer with the same name as the model with suffix \ '-Instruct'. 
Please provide a tokenizer with the chat_template attribute.") - return tokenizer + return processor -def get_model_and_tokenizer(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Tuple[Any, Any]: +def get_model_and_processor(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Tuple[Any, Any]: model = get_model(model_name, use_liger, model_kwargs) - tokenizer = get_tokenizer(model_name) + tokenizer = get_processor(model_name) return model, tokenizer \ No newline at end of file From 88405ce6cc8d09bbed6354e860e33c61d83f2e21 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 2 Jun 2025 22:41:46 +0000 Subject: [PATCH 04/48] fix processor calls and pin transformers --- pyproject.toml | 2 +- verifiers/envs/environment.py | 16 +++++++++---- verifiers/trainers/grpo_trainer.py | 36 +++++++++++++++++++++++------- 3 files changed, 41 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e75a0542b..71117d49a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "vllm>=0.8.5.post1", "openai>=1.81.0", "datasets>=3.6.0", - "transformers", + "transformers==4.51.1", ] [project.optional-dependencies] diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index f3d2558c5..c410258c9 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -344,7 +344,11 @@ def process_chat_format( # tokenize just the prompt prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) assert isinstance(prompt_text, str) - prompt_ids = processing_class.encode(prompt_text) + if hasattr(processing_class, "tokenizer"): + encode = processing_class.tokenizer.encode + else: + encode = processing_class.encode + prompt_ids = encode(prompt_text) prompt_mask = [1] * len(prompt_ids) # track completion tokens and masks by processing incrementally @@ -366,7 +370,7 @@ def process_chat_format( add_generation_prompt=False, ) assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}" - current_ids = processing_class.encode(prefix_text) + current_ids = encode(prefix_text) assert current_ids[:len(prev_ids)] == prev_ids, f"Tokenization difference in chat format. 
Current ids: {current_ids}, previous ids: {prev_ids}" # add new tokens to completion tokens @@ -409,11 +413,15 @@ def process_completion_format( prompt_ids, prompt_mask, completion_ids, completion_mask """ # Tokenize prompt - prompt_ids = processing_class.encode(prompt) + if hasattr(processing_class, "tokenizer"): + encode = processing_class.tokenizer.encode + else: + encode = processing_class.encode + prompt_ids = encode(prompt) prompt_mask = [1] * len(prompt_ids) # Tokenize completion - completion_ids = processing_class.encode(completion) + completion_ids = encode(completion) completion_mask = [1] * len(completion_ids) return prompt_ids, prompt_mask, completion_ids, completion_mask diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 9cf661984..e6181fc6d 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -171,8 +171,12 @@ def __init__( model.warnings_issued["estimate_tokens"] = True # Tokenizer pad token - if processing_class.pad_token is None: # type: ignore - processing_class.pad_token = processing_class.eos_token # type: ignore + if hasattr(processing_class, "tokenizer"): + if processing_class.tokenizer.pad_token is None: + processing_class.tokenizer.pad_token = processing_class.tokenizer.eos_token # type: ignore + else: + if processing_class.pad_token is None: # type: ignore + processing_class.pad_token = processing_class.eos_token # type: ignore # Training arguments self.per_device_train_batch_size = args.per_device_train_batch_size @@ -227,7 +231,11 @@ def filter_by_prompt_length(example): else: # Completion format prompt_text = prompt - prompt_ids = processing_class.encode(prompt_text) # type: ignore + if hasattr(processing_class, "tokenizer"): + encode = processing_class.tokenizer.encode + else: + encode = processing_class.encode + prompt_ids = encode(prompt_text) # type: ignore return len(prompt_ids) <= max_length original_size = len(train_dataset) @@ -743,10 +751,14 @@ def _prepare_inputs( # type: ignore completion_mask_list.append(torch.tensor(broadcast_data['completion_mask'][i], device=self.accelerator.device)) # Pad sequences - prompt_ids = pad(prompt_ids_list, padding_value=self.processing_class.pad_token_id, padding_side='left') # type: ignore + if hasattr(self.processing_class, "tokenizer"): + pad_token_id = self.processing_class.tokenizer.pad_token_id + else: + pad_token_id = self.processing_class.pad_token_id + prompt_ids = pad(prompt_ids_list, padding_value=pad_token_id, padding_side='left') # type: ignore prompt_mask = pad(prompt_mask_list, padding_side='left') # type: ignore - completion_ids = pad(completion_ids_list, padding_value=self.processing_class.pad_token_id, padding_side='right') # type: ignore - completion_mask = pad(completion_mask_list) + completion_ids = pad(completion_ids_list, padding_value=pad_token_id, padding_side='left') # type: ignore + completion_mask = pad(completion_mask_list, padding_side="left") # Truncate if needed if self.max_prompt_length is not None and prompt_ids.size(1) > self.max_prompt_length: @@ -951,7 +963,11 @@ def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval" completions = eval_results['completion'] if isinstance(completions[0], str): # Completion format - directly tokenize strings - completion_lengths = [len(self.processing_class.encode(c)) for c in completions] # type: ignore + if hasattr(self.processing_class, "tokenizer"): + encode = self.processing_class.tokenizer.encode + else: + encode = 
self.processing_class.encode + completion_lengths = [len(encode(c)) for c in completions] # type: ignore else: # Chat format - use apply_chat_template completion_lengths = [] @@ -1127,8 +1143,12 @@ def _log_completion_metrics_primary( # Check for EOS tokens term_lengths = [] + if hasattr(self.processing_class, "tokenizer"): + eos_token_id = self.processing_class.tokenizer.eos_token_id + else: + eos_token_id = self.processing_class.eos_token_id for comp_ids, comp_mask in zip(all_completion_ids, all_completion_mask): - has_eos = any(token == self.processing_class.eos_token_id for token, mask in zip(comp_ids, comp_mask) if mask) # type: ignore + has_eos = any(token == eos_token_id for token, mask in zip(comp_ids, comp_mask) if mask) # type: ignore if has_eos: term_lengths.append(sum(comp_mask)) From 38062e01bcf22ccff0444668b5924fb6e19f319a Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 3 Jun 2025 02:24:01 +0000 Subject: [PATCH 05/48] gather images --- verifiers/trainers/grpo_trainer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index e6181fc6d..e46acd0ec 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -590,7 +590,7 @@ def _ids_to_tensors(self, 'mask': mask } - def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any], List[Any]]: + def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[List[Any]] | None, List[Any], List[Any]]: """ Gather batch data from all processes. @@ -607,14 +607,17 @@ def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any # Gather batch data from all processes prompts = [x['prompt'] for x in batch] + images = [x['images'] for x in batch] + images = None if images == [] else images answers = [x['answer'] for x in batch] tasks = [x.get('task', 'default') for x in batch] all_prompts = gather_object(prompts) + all_images = gather_object(images) all_answers = gather_object(answers) all_tasks = gather_object(tasks) - return all_prompts, all_answers, all_tasks + return all_prompts, all_images, all_answers, all_tasks def _prepare_inputs( # type: ignore self, inputs: list[dict[str, Any]] @@ -661,7 +664,8 @@ def _prepare_inputs( # type: ignore for batch_id in range(self._next_batch_id, target_batch_id + 1): batch_offset = batch_id - batch_id_to_retrieve - all_prompts, all_answers, all_tasks = self._gather_batch_data(batch_offset) + all_prompts, all_images, all_answers, all_tasks = self._gather_batch_data(batch_offset) + import pdb;pdb.set_trace() local_batch_size = len(all_prompts) // self.accelerator.num_processes @@ -669,7 +673,7 @@ def _prepare_inputs( # type: ignore if self.accelerator.is_main_process: request = BatchRequest( batch_id=batch_id, - env_inputs={'prompt': all_prompts, 'answer': all_answers, 'task': all_tasks}, + env_inputs={'prompt': all_prompts, 'images': all_images, 'answer': all_answers, 'task': all_tasks}, processing_class=self.processing_class, mask_env_responses=self.mask_env_responses, max_completion_length=self.max_completion_length, From f91df352ea5a079b0f6a633f0e293b16536f9aaf Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 3 Jun 2025 05:56:03 +0000 Subject: [PATCH 06/48] format images --- verifiers/envs/environment.py | 56 ++++++++++++++------- verifiers/examples/docvqa.py | 4 ++ verifiers/inference/vllm_server.py | 2 +- verifiers/trainers/async_batch_generator.py | 2 +- verifiers/trainers/grpo_trainer.py | 4 +- 5 files changed, 
44 insertions(+), 24 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index c410258c9..040a08afd 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from typing import Any, Dict, List, Literal, Tuple, Optional, Union +import base64, io import torch @@ -12,10 +13,19 @@ from datasets import Dataset from openai import OpenAI +from PIL import Image + from verifiers import RewardFunc from verifiers.parsers import Parser from verifiers.rubrics import Rubric +def _pil_to_data_url(img: Image.Image, fmt: str | None = None) -> str: + buf = io.BytesIO() + fmt = (fmt or img.format or "PNG").upper() + img.save(buf, format=fmt) + b64 = base64.b64encode(buf.getvalue()).decode("utf-8") + return f"data:image/{fmt.lower()};base64,{b64}" + class Environment(ABC): """ Base class for all environments. @@ -96,26 +106,34 @@ def format_dataset(self, system_prompt: str | None = None, few_shot: List[Dict[str, Any]] | None = None, question_key: str = "question", - answer_key: str = "answer") -> Dataset: + images_key: str = "images") -> Dataset: # Extract format_prompt as a standalone function to avoid capturing self - def format_prompt_fn(prompt: str) -> List[Dict[str, Any]]: - messages = [] - if system_prompt: - messages.append({'role': 'system', 'content': system_prompt}) - if few_shot: - messages.extend(few_shot) - messages.append({'role': 'user', 'content': prompt}) - return messages - - if answer_key == "answer": - return dataset.map(lambda x: { - "prompt": format_prompt_fn(x[question_key]), - }, num_proc=self.max_concurrent) - else: - return dataset.map(lambda x: { - "prompt": format_prompt_fn(x[question_key]), - "answer": x[answer_key] - }, num_proc=self.max_concurrent) + def format_prompt_fn(batch) -> Dict[str, List[Any]]: + batch_size = len(batch[question_key]) + formatted_prompts = [] + for i in range(batch_size): + messages = [] + if system_prompt: + messages.append({'role': 'system', 'content': system_prompt}) + if few_shot: + messages.extend(few_shot) + content_blocks = [ + {"type": "text", "text": batch[question_key][i]} + ] + if images_key in batch.keys(): + for img in batch[images_key][i]: + content_blocks.append( + { + "type": "image_url", + "image_url": {"url": _pil_to_data_url(img)}, + } + ) + messages.append({"role": "user", "content": content_blocks}) + formatted_prompts.append(messages) + batch["prompt"] = formatted_prompts + return batch + + return dataset.with_transform(format_prompt_fn) def get_dataset(self, n: int = -1, seed: int = 0, **kwargs: Any) -> Dataset | None: if n > 0 and self.dataset is not None: diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index c1f89a6ee..03c7ea511 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -7,6 +7,9 @@ TODO: - check liger kernel support - check that generic_model_loader didn't break anything + - check completions format, not just chat + - what happens if dataset is already formatted? 
+ - fix print completions """ def preprocess_docvqa(x): @@ -45,6 +48,7 @@ def preprocess_docvqa(x): training_args = vf.grpo_defaults(run_name=run_name) training_args.per_device_train_batch_size = 4 training_args.num_generations = 4 +training_args.log_completions = False trainer = vf.GRPOTrainer( model=model, diff --git a/verifiers/inference/vllm_server.py b/verifiers/inference/vllm_server.py index dce5207ef..80ff6d10b 100644 --- a/verifiers/inference/vllm_server.py +++ b/verifiers/inference/vllm_server.py @@ -81,7 +81,7 @@ async def get_next_worker_connection(connections: list[AnyType]) -> tuple[int, A # -------- OpenAI /v1/chat/completions Pydantic Models ---------- # class OAChatMessage(BaseModel): role: str - content: str + content: str | list class OAChatCompletionRequest(BaseModel): model: str diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 28946cd5e..28b08c28c 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -11,7 +11,7 @@ class BatchRequest: """Request for batch generation""" batch_id: int - env_inputs: Dict[str, List[Any]] + env_inputs: Dict[str, List[Any] | None] processing_class: Any mask_env_responses: bool max_completion_length: int diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index e46acd0ec..a0f966b09 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -590,7 +590,7 @@ def _ids_to_tensors(self, 'mask': mask } - def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[List[Any]] | None, List[Any], List[Any]]: + def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any] | None, List[Any], List[Any]]: """ Gather batch data from all processes. @@ -665,10 +665,8 @@ def _prepare_inputs( # type: ignore for batch_id in range(self._next_batch_id, target_batch_id + 1): batch_offset = batch_id - batch_id_to_retrieve all_prompts, all_images, all_answers, all_tasks = self._gather_batch_data(batch_offset) - import pdb;pdb.set_trace() local_batch_size = len(all_prompts) // self.accelerator.num_processes - # Submit batch (main process only) if self.accelerator.is_main_process: request = BatchRequest( From 85c3f3128b15f66ea06a5383e2fbc5f0fe0460f0 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 3 Jun 2025 06:08:10 +0000 Subject: [PATCH 07/48] fix rich log --- verifiers/examples/docvqa.py | 4 ++-- verifiers/utils/logging_utils.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 03c7ea511..ff9523ecc 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -9,7 +9,7 @@ - check that generic_model_loader didn't break anything - check completions format, not just chat - what happens if dataset is already formatted? 
- - fix print completions + - fix wandb log """ def preprocess_docvqa(x): @@ -48,7 +48,7 @@ def preprocess_docvqa(x): training_args = vf.grpo_defaults(run_name=run_name) training_args.per_device_train_batch_size = 4 training_args.num_generations = 4 -training_args.log_completions = False +training_args.log_completions = True trainer = vf.GRPOTrainer( model=model, diff --git a/verifiers/utils/logging_utils.py b/verifiers/utils/logging_utils.py index a36b97f1f..454d07537 100644 --- a/verifiers/utils/logging_utils.py +++ b/verifiers/utils/logging_utils.py @@ -79,6 +79,8 @@ def print_prompt_completions_sample( if prompt: last_message = prompt[-1] content = last_message.get("content", "") + if isinstance(content, list): # multimodal case + content = content[0]["text"] formatted_prompt = Text(content, style="bright_yellow") else: formatted_prompt = Text("") From d4e981843dfc66abe25ed45a9e3b768fe46f28ac Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 3 Jun 2025 00:15:47 -0600 Subject: [PATCH 08/48] update comments --- verifiers/examples/docvqa.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index ff9523ecc..eed097204 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -10,6 +10,8 @@ - check completions format, not just chat - what happens if dataset is already formatted? - fix wandb log +NOTES: + - transformers seems to be having an issue, so it's pinned for now: https://github.com/volcengine/verl/issues/1710 """ def preprocess_docvqa(x): @@ -42,7 +44,7 @@ def preprocess_docvqa(x): ) model_name = "Qwen/Qwen2.5-VL-3B-Instruct" -model, processor = vf.get_model_and_processor(model_name, use_liger=False) # TODO: modify model loading to add liger support +model, processor = vf.get_model_and_processor(model_name, use_liger=False) run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) From f38eee869c0a091390d2a9255329e46c5b380584 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Fri, 6 Jun 2025 20:29:16 +0000 Subject: [PATCH 09/48] update example --- verifiers/examples/docvqa.py | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index eed097204..1dededfc3 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -1,8 +1,15 @@ +import re + from datasets import load_dataset + import verifiers as vf """ +# install qwen stuff +uv pip install qwen-vl-utils +# inference CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' +# train CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py TODO: - check liger kernel support @@ -14,6 +21,7 @@ - transformers seems to be having an issue, so it's pinned for now: https://github.com/volcengine/verl/issues/1710 """ + def preprocess_docvqa(x): return { "question": x["question"], @@ -33,9 +41,49 @@ def preprocess_docvqa(x): Respond in the following format: {parser.get_format_str()}""" + +def correctness_reward_func(completion: list[dict[str, str]], **kwargs) -> float: + def get_assistant_messages(messages: list[dict[str, str]]) -> list[dict[str, str]]: + return [msg for msg in messages if msg.get("role") == "assistant"] + + def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: + pattern = rf"<{tag}>\s*(.*?)\s*" + match = re.search(pattern, text, re.DOTALL) + if match: + content = match.group(1) + return content.strip() if strip 
else content + return None + + assistant_messages = get_assistant_messages(completion) + if assistant_messages is None: + return 0.0 + msgs_scores = [] + for msg in assistant_messages: + content = msg.get("content", "") + answer = parse_xml_content(content, "answer") + if answer is None: + continue + gt_answers = kwargs["answer"] + mean_gt_len = sum([len(gt_answer) for gt_answer in gt_answers]) / len( + gt_answers + ) + diff_from_mean = min(mean_gt_len / len(answer), 1.0) # penalize long answers + if answer in gt_answers: + msgs_scores.append(2.0) + elif answer.lower() in [ans.lower() for ans in gt_answers]: + msgs_scores.append(1.0) + elif any(ans.lower() in answer.lower() for ans in gt_answers): + msgs_scores.append(diff_from_mean) + if msgs_scores == []: + return 0.0 + else: + return sum(msgs_scores) / len(msgs_scores) + + rubric = vf.Rubric( funcs=[ parser.get_format_reward_func(), + correctness_reward_func, ] ) From c40152ac315860204f820ed67c0ad40fbdcfc2f9 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sat, 7 Jun 2025 18:52:22 +0000 Subject: [PATCH 10/48] fix wandb logs --- verifiers/examples/docvqa.py | 4 +--- verifiers/trainers/grpo_trainer.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 1dededfc3..b271f0686 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -8,7 +8,7 @@ # install qwen stuff uv pip install qwen-vl-utils # inference -CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' +CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' --max-model-len 16000 # train CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py TODO: @@ -96,8 +96,6 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) -training_args.per_device_train_batch_size = 4 -training_args.num_generations = 4 training_args.log_completions = True trainer = vf.GRPOTrainer( diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index a0f966b09..1ab7c13c8 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1053,10 +1053,19 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: import pandas as pd - + if list(self._textual_logs["prompt"]) and isinstance(list(self._textual_logs["prompt"])[0], list): + prompt = [] + for messages in list(self._textual_logs["prompt"]): + last_message = messages[-1] + content = last_message.get("content", "") + if isinstance(content, list): + content = content[0]["text"] + prompt.append(content) + else: + prompt = list(self._textual_logs["prompt"]) table = { "step": [str(self.state.global_step)] * len(self._textual_logs["prompt"]), - "prompt": list(self._textual_logs["prompt"]), + "prompt": prompt, "completion": list(self._textual_logs["completion"]), **{k: list(v) for k, v in self._textual_logs["rewards"].items()}, } From 6eddfd6159612dcdd93aaf537a43cccd06056788 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sun, 8 Jun 2025 15:12:30 +0000 Subject: [PATCH 11/48] resize --- verifiers/examples/docvqa.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index b271f0686..902073d13 100644 --- 
a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -3,12 +3,13 @@ from datasets import load_dataset import verifiers as vf +from qwen_vl_utils import smart_resize """ # install qwen stuff uv pip install qwen-vl-utils # inference -CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' --max-model-len 16000 +CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' # train CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py TODO: @@ -25,7 +26,7 @@ def preprocess_docvqa(x): return { "question": x["question"], - "images": [x["image"]], + "images": [x["image"].resize(smart_resize(480, 640))], "answer": x["answers"][0], } From 3be9f4504bf0f0939e2c73f4789b5785d40f37e3 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sun, 8 Jun 2025 17:48:05 +0000 Subject: [PATCH 12/48] model len --- verifiers/examples/docvqa.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 902073d13..4afc96dda 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -9,7 +9,7 @@ # install qwen stuff uv pip install qwen-vl-utils # inference -CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' +CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' --max-model-len 64000 # train CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py TODO: @@ -26,7 +26,7 @@ def preprocess_docvqa(x): return { "question": x["question"], - "images": [x["image"].resize(smart_resize(480, 640))], + "images": [x["image"].resize(smart_resize(768, 1024))], # XGA "answer": x["answers"][0], } @@ -98,6 +98,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: training_args = vf.grpo_defaults(run_name=run_name) training_args.log_completions = True +training_args.num_train_epochs = 2 trainer = vf.GRPOTrainer( model=model, From 65aeb1b8c3d92bee7d30e017ac91334d705feb96 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 02:39:35 +0000 Subject: [PATCH 13/48] update example --- verifiers/examples/docvqa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 4afc96dda..fe1c9b94f 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -98,7 +98,8 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: training_args = vf.grpo_defaults(run_name=run_name) training_args.log_completions = True -training_args.num_train_epochs = 2 +training_args.num_train_epochs = 3 +training_args.max_steps = -1 trainer = vf.GRPOTrainer( model=model, From 606027a4d2a75fc5a1de86e1f1bcdd4eb431c9e5 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 03:59:17 +0000 Subject: [PATCH 14/48] fix format and remove unused images --- verifiers/envs/environment.py | 12 +++++++----- verifiers/trainers/async_batch_generator.py | 2 +- verifiers/trainers/grpo_trainer.py | 9 +++------ 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 040a08afd..863c6aba8 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -117,18 +117,20 @@ def format_prompt_fn(batch) -> Dict[str, List[Any]]: messages.append({'role': 'system', 'content': system_prompt}) if few_shot: messages.extend(few_shot) - content_blocks = [ - {"type": "text", "text": batch[question_key][i]} - ] if images_key in 
batch.keys(): + content = [ + {"type": "text", "text": batch[question_key][i]} + ] for img in batch[images_key][i]: - content_blocks.append( + content.append( { "type": "image_url", "image_url": {"url": _pil_to_data_url(img)}, } ) - messages.append({"role": "user", "content": content_blocks}) + else: + content = batch[question_key][i] + messages.append({"role": "user", "content": content}) formatted_prompts.append(messages) batch["prompt"] = formatted_prompts return batch diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 28b08c28c..28946cd5e 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -11,7 +11,7 @@ class BatchRequest: """Request for batch generation""" batch_id: int - env_inputs: Dict[str, List[Any] | None] + env_inputs: Dict[str, List[Any]] processing_class: Any mask_env_responses: bool max_completion_length: int diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 1ab7c13c8..16957b119 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -590,7 +590,7 @@ def _ids_to_tensors(self, 'mask': mask } - def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any] | None, List[Any], List[Any]]: + def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any], List[Any], List[Any]]: """ Gather batch data from all processes. @@ -607,17 +607,14 @@ def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any # Gather batch data from all processes prompts = [x['prompt'] for x in batch] - images = [x['images'] for x in batch] - images = None if images == [] else images answers = [x['answer'] for x in batch] tasks = [x.get('task', 'default') for x in batch] all_prompts = gather_object(prompts) - all_images = gather_object(images) all_answers = gather_object(answers) all_tasks = gather_object(tasks) - return all_prompts, all_images, all_answers, all_tasks + return all_prompts, all_answers, all_tasks def _prepare_inputs( # type: ignore self, inputs: list[dict[str, Any]] @@ -671,7 +668,7 @@ def _prepare_inputs( # type: ignore if self.accelerator.is_main_process: request = BatchRequest( batch_id=batch_id, - env_inputs={'prompt': all_prompts, 'images': all_images, 'answer': all_answers, 'task': all_tasks}, + env_inputs={'prompt': all_prompts, 'answer': all_answers, 'task': all_tasks}, processing_class=self.processing_class, mask_env_responses=self.mask_env_responses, max_completion_length=self.max_completion_length, From bd76eababa12005ecaf6899de3eb407cea3ffc1a Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 04:02:57 +0000 Subject: [PATCH 15/48] fix image unpacking --- verifiers/examples/docvqa.py | 1 - verifiers/trainers/grpo_trainer.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index fe1c9b94f..6cf5b7110 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -98,7 +98,6 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: training_args = vf.grpo_defaults(run_name=run_name) training_args.log_completions = True -training_args.num_train_epochs = 3 training_args.max_steps = -1 trainer = vf.GRPOTrainer( diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 16957b119..b6960bd76 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ 
-661,7 +661,7 @@ def _prepare_inputs( # type: ignore for batch_id in range(self._next_batch_id, target_batch_id + 1): batch_offset = batch_id - batch_id_to_retrieve - all_prompts, all_images, all_answers, all_tasks = self._gather_batch_data(batch_offset) + all_prompts, all_answers, all_tasks = self._gather_batch_data(batch_offset) local_batch_size = len(all_prompts) // self.accelerator.num_processes # Submit batch (main process only) From a89de9b13ea1515752102b179544402b1b18d7d7 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 16:09:37 +0000 Subject: [PATCH 16/48] change format dataset --- verifiers/envs/environment.py | 56 +++++++++++++++++------------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 863c6aba8..39eb9e08e 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -106,36 +106,36 @@ def format_dataset(self, system_prompt: str | None = None, few_shot: List[Dict[str, Any]] | None = None, question_key: str = "question", - images_key: str = "images") -> Dataset: + images_key: str = "images", + answer_key: str = "answer") -> Dataset: # Extract format_prompt as a standalone function to avoid capturing self - def format_prompt_fn(batch) -> Dict[str, List[Any]]: - batch_size = len(batch[question_key]) - formatted_prompts = [] - for i in range(batch_size): - messages = [] - if system_prompt: - messages.append({'role': 'system', 'content': system_prompt}) - if few_shot: - messages.extend(few_shot) - if images_key in batch.keys(): - content = [ - {"type": "text", "text": batch[question_key][i]} - ] - for img in batch[images_key][i]: - content.append( - { - "type": "image_url", - "image_url": {"url": _pil_to_data_url(img)}, - } - ) - else: - content = batch[question_key][i] - messages.append({"role": "user", "content": content}) - formatted_prompts.append(messages) - batch["prompt"] = formatted_prompts - return batch + def format_prompt_fn(prompt: str, images: list | None) -> List[Dict[str, Any]]: + messages = [] + if system_prompt: + messages.append({'role': 'system', 'content': [{"type": "text", "text": system_prompt}]}) + if few_shot: + messages.extend(few_shot) + content = [{"type": "text", "text": prompt}] + if images is not None: + for img in images: + content.append( + { + "type": "image_url", + "image_url": {"url": _pil_to_data_url(img)}, + } + ) + messages.append({"role": "user", "content": content}) + return messages - return dataset.with_transform(format_prompt_fn) + if answer_key == "answer": + return dataset.map(lambda x: { + "prompt": format_prompt_fn(x[question_key], x.get(images_key)), + }, num_proc=self.max_concurrent) + else: + return dataset.map(lambda x: { + "prompt": format_prompt_fn(x[question_key], x.get(images_key)), + "answer": x[answer_key] + }, num_proc=self.max_concurrent) def get_dataset(self, n: int = -1, seed: int = 0, **kwargs: Any) -> Dataset | None: if n > 0 and self.dataset is not None: From 6c464f95c3a51aef2047b7ada67c88b5f93264a8 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 16:36:31 +0000 Subject: [PATCH 17/48] opt --- verifiers/examples/docvqa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 6cf5b7110..211c4b49e 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -97,8 +97,9 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: run_name = "docvqa_" + 
model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) -training_args.log_completions = True training_args.max_steps = -1 +training_args.lr_scheduler_type = "cosine" +training_args.learning_rate = 3e-6 trainer = vf.GRPOTrainer( model=model, From 5036f693642ea49a8e4f265f6de154bffdf0b70b Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 20:41:54 +0000 Subject: [PATCH 18/48] fix format on text-only --- verifiers/envs/environment.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 39eb9e08e..e5e40aaf3 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -111,12 +111,18 @@ def format_dataset(self, # Extract format_prompt as a standalone function to avoid capturing self def format_prompt_fn(prompt: str, images: list | None) -> List[Dict[str, Any]]: messages = [] - if system_prompt: - messages.append({'role': 'system', 'content': [{"type": "text", "text": system_prompt}]}) - if few_shot: - messages.extend(few_shot) - content = [{"type": "text", "text": prompt}] - if images is not None: + if images is None: + if system_prompt: + messages.append({'role': 'system', 'content': system_prompt}) + if few_shot: + messages.extend(few_shot) + messages.append({'role': 'user', 'content': prompt}) + else: + if system_prompt: + messages.append({'role': 'system', 'content': [{"type": "text", "text": system_prompt}]}) + if few_shot: + messages.extend(few_shot) + content = [{"type": "text", "text": prompt}] for img in images: content.append( { @@ -124,7 +130,7 @@ def format_prompt_fn(prompt: str, images: list | None) -> List[Dict[str, Any]]: "image_url": {"url": _pil_to_data_url(img)}, } ) - messages.append({"role": "user", "content": content}) + messages.append({"role": "user", "content": content}) return messages if answer_key == "answer": From f3959f0d320e30fe62a48599faaa608556d2740b Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 21:42:58 +0000 Subject: [PATCH 19/48] fix _gather_batch_data type --- verifiers/trainers/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index b6960bd76..bba9a938d 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -590,7 +590,7 @@ def _ids_to_tensors(self, 'mask': mask } - def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any], List[Any], List[Any]]: + def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any], List[Any]]: """ Gather batch data from all processes. 
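
A minimal, self-contained sketch (not part of any patch; the question text and the blank image below are illustrative stand-ins) of why `_gather_batch_data` can drop a separate images field after PATCH 14: once `format_dataset` (introduced in PATCH 06, revised in PATCHes 16 and 18) has run, each image travels inside the chat prompt itself as a base64 data URL, so gathering prompts also gathers the images.

    import base64
    import io

    from PIL import Image

    def pil_to_data_url(img: Image.Image, fmt: str | None = None) -> str:
        # Mirrors _pil_to_data_url from the environment.py change in PATCH 06.
        buf = io.BytesIO()
        fmt = (fmt or img.format or "PNG").upper()
        img.save(buf, format=fmt)
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return f"data:image/{fmt.lower()};base64,{b64}"

    page = Image.new("RGB", (64, 64), "white")  # stand-in for a DocVQA page image
    prompt = [
        {"role": "system",
         "content": [{"type": "text", "text": "Answer the question."}]},
        {"role": "user",
         "content": [
             {"type": "text", "text": "What is the title of the document?"},
             {"type": "image_url",
              "image_url": {"url": pil_to_data_url(page)}},
         ]},
    ]
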
From b8732f77d185212f0ffa19db5a76500581c8dc22 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 21:56:37 +0000 Subject: [PATCH 20/48] relax transformers condition --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 71117d49a..835425c45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "vllm>=0.8.5.post1", "openai>=1.81.0", "datasets>=3.6.0", - "transformers==4.51.1", + "transformers<4.52.0", ] [project.optional-dependencies] From e4612583c42022d6f7fea714fc1297d2d1550498 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 21:57:12 +0000 Subject: [PATCH 21/48] update comment / increase lr --- verifiers/examples/docvqa.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 211c4b49e..721c673a4 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -14,12 +14,11 @@ CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py TODO: - check liger kernel support - - check that generic_model_loader didn't break anything - check completions format, not just chat - - what happens if dataset is already formatted? - fix wandb log -NOTES: - - transformers seems to be having an issue, so it's pinned for now: https://github.com/volcengine/verl/issues/1710 + - transformers changed weight keys. pinned for now, but should update: + - https://github.com/volcengine/verl/issues/1710 + - https://github.com/huggingface/transformers/pull/38385 """ @@ -97,9 +96,8 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) +training_args.learning_rate = 2e-6 training_args.max_steps = -1 -training_args.lr_scheduler_type = "cosine" -training_args.learning_rate = 3e-6 trainer = vf.GRPOTrainer( model=model, From 58ac1c78fbf573526c7596570fa6a0046f73278e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 9 Jun 2025 22:41:12 +0000 Subject: [PATCH 22/48] liger monkey patch --- verifiers/examples/docvqa.py | 5 ++--- verifiers/utils/model_utils.py | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 721c673a4..acd5cef00 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -13,7 +13,6 @@ # train CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py TODO: - - check liger kernel support - check completions format, not just chat - fix wandb log - transformers changed weight keys. 
pinned for now, but should update: @@ -77,7 +76,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: if msgs_scores == []: return 0.0 else: - return sum(msgs_scores) / len(msgs_scores) + return (sum(msgs_scores) / len(msgs_scores) / 2.0) rubric = vf.Rubric( @@ -92,7 +91,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: ) model_name = "Qwen/Qwen2.5-VL-3B-Instruct" -model, processor = vf.get_model_and_processor(model_name, use_liger=False) +model, processor = vf.get_model_and_processor(model_name) run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 7519ae0e6..06b68d29d 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -106,8 +106,21 @@ def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[ ) if is_liger_available() and use_liger: print("Using Liger kernel") - from liger_kernel.transformers import AutoLigerKernelForCausalLM # type: ignore - return AutoLigerKernelForCausalLM.from_pretrained(model_name, **model_kwargs) + from liger_kernel.transformers import AutoLigerKernelForCausalLM, apply_liger_kernel_to_qwen2_5_vl # type: ignore + patch_mapping = { + "Qwen2_5_VLConfig": apply_liger_kernel_to_qwen2_5_vl + } + try: + model = AutoLigerKernelForCausalLM.from_pretrained(model_name, **model_kwargs) + return model + except ValueError: # try monkey patch + print(f"Model {model_name} is not supported with AutoLigerKernelForCausalLM. Attempting monkey patch...") + config_name = AutoConfig.from_pretrained(model_name, trust_remote_code=True).__class__.__name__ + if config_name in patch_mapping.keys(): + patch_mapping[config_name]() + return generic_model_loader(model_name, **model_kwargs) + else: + raise ValueError(f"Model {model_name} is not supported with Liger-Kernel in verifiers") else: return generic_model_loader(model_name, **model_kwargs) From 2e25655fa7076447646211410dd6c9cd0a7906cc Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 00:17:39 +0000 Subject: [PATCH 23/48] generic liger patch --- verifiers/utils/model_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 06b68d29d..c9d0d20d3 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -1,3 +1,4 @@ +import importlib from importlib.util import find_spec from importlib import import_module from typing import Dict, Any, Union, Tuple, Callable @@ -106,19 +107,21 @@ def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[ ) if is_liger_available() and use_liger: print("Using Liger kernel") - from liger_kernel.transformers import AutoLigerKernelForCausalLM, apply_liger_kernel_to_qwen2_5_vl # type: ignore - patch_mapping = { - "Qwen2_5_VLConfig": apply_liger_kernel_to_qwen2_5_vl - } try: + from liger_kernel.transformers import AutoLigerKernelForCausalLM # type: ignore model = AutoLigerKernelForCausalLM.from_pretrained(model_name, **model_kwargs) return model except ValueError: # try monkey patch print(f"Model {model_name} is not supported with AutoLigerKernelForCausalLM. 
Attempting monkey patch...") - config_name = AutoConfig.from_pretrained(model_name, trust_remote_code=True).__class__.__name__ - if config_name in patch_mapping.keys(): - patch_mapping[config_name]() - return generic_model_loader(model_name, **model_kwargs) + model_type = AutoConfig.from_pretrained(model_name, trust_remote_code=True).model_type + patch_func_name = f"apply_liger_kernel_to_{model_type}" + ligermod = importlib.import_module("liger_kernel.transformers") + patch_func = getattr(ligermod, patch_func_name, None) + if callable(patch_func): + patch_func() + model = generic_model_loader(model_name, **model_kwargs) + print(f"Applied Liger-Kernel patch to {model_name}") + return model else: raise ValueError(f"Model {model_name} is not supported with Liger-Kernel in verifiers") else: From 0da588829a0a2caf61fb3400cf41800950005b64 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 00:47:13 +0000 Subject: [PATCH 24/48] increase lr --- verifiers/examples/docvqa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index acd5cef00..195abb895 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -95,7 +95,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) -training_args.learning_rate = 2e-6 +training_args.learning_rate = 3e-6 training_args.max_steps = -1 trainer = vf.GRPOTrainer( From c9eaa0229f8d51c01c828cdf5c8b211e505089cf Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 01:05:22 +0000 Subject: [PATCH 25/48] return to old naming --- README.md | 2 +- verifiers/__init__.py | 6 +++--- verifiers/examples/docvqa.py | 2 +- verifiers/examples/doublecheck.py | 2 +- verifiers/examples/math_python.py | 2 +- verifiers/examples/reverse_text.py | 2 +- verifiers/examples/self_reward.py | 2 +- verifiers/examples/sft/math_python.py | 2 +- verifiers/examples/sft/reverse_text.py | 2 +- verifiers/examples/smola_math_tools.py | 2 +- verifiers/utils/__init__.py | 6 +++--- verifiers/utils/model_utils.py | 6 +++--- 12 files changed, 18 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index e91fba9bf..6ef770d67 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ See `verifiers/examples/sft/reverse_text.py` for an example script using TRL's S ```python # train.py -model, tokenizer = vf.get_model_and_processor(model_name) +model, tokenizer = vf.get_model_and_tokenizer(model_name) trainer = vf.GRPOTrainer( model=model, processing_class=tokenizer, diff --git a/verifiers/__init__.py b/verifiers/__init__.py index c02b21d81..92816d83e 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -27,7 +27,7 @@ from .trainers.grpo_trainer import GRPOTrainer from .trainers.grpo_config import GRPOConfig from .utils.data_utils import extract_boxed_answer, extract_hash_answer, load_example_dataset -from .utils.model_utils import get_model, get_processor, get_model_and_processor +from .utils.model_utils import get_model, get_tokenizer, get_model_and_tokenizer from .utils.config_utils import grpo_defaults, lora_defaults __version__ = "0.1.0" @@ -53,8 +53,8 @@ "GRPOConfig", "VLLMClient", "get_model", - "get_processor", - "get_model_and_processor", + "get_tokenizer", + "get_model_and_tokenizer", "grpo_defaults", "lora_defaults", "extract_boxed_answer", diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 
195abb895..2f4ff35cf 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -91,7 +91,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: ) model_name = "Qwen/Qwen2.5-VL-3B-Instruct" -model, processor = vf.get_model_and_processor(model_name) +model, processor = vf.get_model_and_tokenizer(model_name) run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) diff --git a/verifiers/examples/doublecheck.py b/verifiers/examples/doublecheck.py index fa57bf0ec..71671bcb7 100644 --- a/verifiers/examples/doublecheck.py +++ b/verifiers/examples/doublecheck.py @@ -12,7 +12,7 @@ system_prompt=SIMPLE_PROMPT, few_shot=[] ) -model, tokenizer = vf.get_model_and_processor(model_name) +model, tokenizer = vf.get_model_and_tokenizer(model_name) args = vf.grpo_defaults(run_name="doublecheck-{}".format(model_name.split("/")[-1].lower())) trainer = vf.GRPOTrainer( model=model, diff --git a/verifiers/examples/math_python.py b/verifiers/examples/math_python.py index 7499c7ab3..9b06d597f 100644 --- a/verifiers/examples/math_python.py +++ b/verifiers/examples/math_python.py @@ -58,7 +58,7 @@ print(vf_env.system_prompt) model_name = "willcb/Qwen2.5-7B-Math-Python-SFT" -model, tokenizer = vf.get_model_and_processor(model_name) +model, tokenizer = vf.get_model_and_tokenizer(model_name) run_name = "math-grpo_" + model_name.split("/")[-1].lower() training_args=vf.grpo_defaults(run_name=run_name) diff --git a/verifiers/examples/reverse_text.py b/verifiers/examples/reverse_text.py index 811f66afc..494445aef 100644 --- a/verifiers/examples/reverse_text.py +++ b/verifiers/examples/reverse_text.py @@ -48,7 +48,7 @@ def lcs_ratio(x: str, y: str) -> float: args.eval_steps = 10 args.max_steps = 100 -model, tokenizer = vf.get_model_and_processor(model_name) +model, tokenizer = vf.get_model_and_tokenizer(model_name) trainer = vf.GRPOTrainer( model=model, processing_class=tokenizer, diff --git a/verifiers/examples/self_reward.py b/verifiers/examples/self_reward.py index 701a030ee..30722441a 100644 --- a/verifiers/examples/self_reward.py +++ b/verifiers/examples/self_reward.py @@ -10,6 +10,6 @@ system_prompt="You are a helpful assistant.", rubric=rubric ) -model, tokenizer = vf.get_model_and_processor(model_name) +model, tokenizer = vf.get_model_and_tokenizer(model_name) trainer = vf.GRPOTrainer(env=vf_env, model=model, processing_class=tokenizer, args=vf.grpo_defaults(run_name="self_reward")) trainer.train() \ No newline at end of file diff --git a/verifiers/examples/sft/math_python.py b/verifiers/examples/sft/math_python.py index 7562218e7..a15aff4a0 100644 --- a/verifiers/examples/sft/math_python.py +++ b/verifiers/examples/sft/math_python.py @@ -7,7 +7,7 @@ """ # convenience function for FA2 initialization -model, tokenizer = vf.get_model_and_processor("Qwen/Qwen2.5-7B-Instruct", use_liger=False) +model, tokenizer = vf.get_model_and_tokenizer("Qwen/Qwen2.5-7B-Instruct", use_liger=False) dataset = load_dataset('willcb/V3-gsm8k-python-test', split='train') tok_counts = [] diff --git a/verifiers/examples/sft/reverse_text.py b/verifiers/examples/sft/reverse_text.py index 8aab76f1e..77f3cb2d5 100644 --- a/verifiers/examples/sft/reverse_text.py +++ b/verifiers/examples/sft/reverse_text.py @@ -7,7 +7,7 @@ """ # convenience function for FA2 initialization -model, tokenizer = vf.get_model_and_processor("Qwen/Qwen2.5-0.5B-Instruct", use_liger=False) +model, tokenizer = vf.get_model_and_tokenizer("Qwen/Qwen2.5-0.5B-Instruct", 
use_liger=False) dataset = load_dataset('willcb/R1-reverse-wikipedia-paragraphs-v1-1000', split='train') tok_counts = [] diff --git a/verifiers/examples/smola_math_tools.py b/verifiers/examples/smola_math_tools.py index d9b65d189..b2eeb4538 100644 --- a/verifiers/examples/smola_math_tools.py +++ b/verifiers/examples/smola_math_tools.py @@ -47,7 +47,7 @@ print(vf_env.system_prompt) model_name = "Qwen/Qwen2.5-7B-Instruct" -model, tokenizer = vf.get_model_and_processor(model_name) +model, tokenizer = vf.get_model_and_tokenizer(model_name) run_name = "math-smola-grpo_" + model_name.split("/")[-1].lower() args = vf.grpo_defaults(run_name=run_name) diff --git a/verifiers/utils/__init__.py b/verifiers/utils/__init__.py index 1abce403c..a83ce148a 100644 --- a/verifiers/utils/__init__.py +++ b/verifiers/utils/__init__.py @@ -1,6 +1,6 @@ from .data_utils import extract_boxed_answer, extract_hash_answer, load_example_dataset from .config_utils import grpo_defaults, lora_defaults -from .model_utils import get_model, get_processor, get_model_and_processor +from .model_utils import get_model, get_tokenizer, get_model_and_tokenizer from .logging_utils import setup_logging, print_prompt_completions_sample __all__ = [ @@ -10,8 +10,8 @@ "grpo_defaults", "lora_defaults", "get_model", - "get_processor", - "get_model_and_processor", + "get_tokenizer", + "get_model_and_tokenizer", "setup_logging", "print_prompt_completions_sample", ] \ No newline at end of file diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index c9d0d20d3..e5af811b4 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -127,7 +127,7 @@ def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[ else: return generic_model_loader(model_name, **model_kwargs) -def get_processor(model_name: str, padding_side: str = "left") -> Any: +def get_tokenizer(model_name: str, padding_side: str = "left") -> Any: processor = AutoProcessor.from_pretrained(model_name, padding_side=padding_side) tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor if not hasattr(tokenizer, "chat_template"): @@ -136,7 +136,7 @@ def get_processor(model_name: str, padding_side: str = "left") -> Any: '-Instruct'. Please provide a tokenizer with the chat_template attribute.") return processor -def get_model_and_processor(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Tuple[Any, Any]: +def get_model_and_tokenizer(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Tuple[Any, Any]: model = get_model(model_name, use_liger, model_kwargs) - tokenizer = get_processor(model_name) + tokenizer = get_tokenizer(model_name) return model, tokenizer \ No newline at end of file From 8aff5d1b007d6d67ffb4fac84c1394d4c9d5efe4 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 01:06:56 +0000 Subject: [PATCH 26/48] remove todos --- verifiers/examples/docvqa.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 2f4ff35cf..4da54c15f 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -12,12 +12,6 @@ CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' --max-model-len 64000 # train CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py -TODO: - - check completions format, not just chat - - fix wandb log - - transformers changed weight keys. 
pinned for now, but should update: - - https://github.com/volcengine/verl/issues/1710 - - https://github.com/huggingface/transformers/pull/38385 """ From 9bb135c2e983b0f3ff56e2d3513d7c826dab5f0f Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 01:41:58 +0000 Subject: [PATCH 27/48] restore padding side --- verifiers/trainers/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index bba9a938d..00f4151a3 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -756,7 +756,7 @@ def _prepare_inputs( # type: ignore pad_token_id = self.processing_class.pad_token_id prompt_ids = pad(prompt_ids_list, padding_value=pad_token_id, padding_side='left') # type: ignore prompt_mask = pad(prompt_mask_list, padding_side='left') # type: ignore - completion_ids = pad(completion_ids_list, padding_value=pad_token_id, padding_side='left') # type: ignore + completion_ids = pad(completion_ids_list, padding_value=pad_token_id, padding_side='right') # type: ignore completion_mask = pad(completion_mask_list, padding_side="left") # Truncate if needed From a0323d7c7135519079a4e24e6167c1fbfda6853b Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 01:45:35 +0000 Subject: [PATCH 28/48] remove padding side for completion_mask --- verifiers/trainers/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 00f4151a3..7a01373c6 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -757,7 +757,7 @@ def _prepare_inputs( # type: ignore prompt_ids = pad(prompt_ids_list, padding_value=pad_token_id, padding_side='left') # type: ignore prompt_mask = pad(prompt_mask_list, padding_side='left') # type: ignore completion_ids = pad(completion_ids_list, padding_value=pad_token_id, padding_side='right') # type: ignore - completion_mask = pad(completion_mask_list, padding_side="left") + completion_mask = pad(completion_mask_list) # Truncate if needed if self.max_prompt_length is not None and prompt_ids.size(1) > self.max_prompt_length: From 6b630ba30afb745d4267777043d65927e1832896 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 02:16:16 +0000 Subject: [PATCH 29/48] fix wandb logging --- verifiers/trainers/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 7a01373c6..85376cab8 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1057,7 +1057,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: content = last_message.get("content", "") if isinstance(content, list): content = content[0]["text"] - prompt.append(content) + prompt.append([{'role': 'user', 'content': content}]) # format like text-only msgs else: prompt = list(self._textual_logs["prompt"]) table = { From edd9b294f4a810f30a5d68a040668c3ab99ea746 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 02:45:14 +0000 Subject: [PATCH 30/48] logging format --- verifiers/trainers/grpo_trainer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 85376cab8..8d2373572 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1050,16 +1050,15 @@ def log(self, 
logs: dict[str, float], start_time: float | None = None) -> None: if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: import pandas as pd - if list(self._textual_logs["prompt"]) and isinstance(list(self._textual_logs["prompt"])[0], list): - prompt = [] + # format prompt for logging + prompt = [] + if list(self._textual_logs["prompt"]): for messages in list(self._textual_logs["prompt"]): last_message = messages[-1] content = last_message.get("content", "") if isinstance(content, list): - content = content[0]["text"] - prompt.append([{'role': 'user', 'content': content}]) # format like text-only msgs - else: - prompt = list(self._textual_logs["prompt"]) + content = content[0]["text"] # extract text only in multimodal case + prompt.append([{'role': 'user', 'content': content}]) table = { "step": [str(self.state.global_step)] * len(self._textual_logs["prompt"]), "prompt": prompt, From f3f403e057559d72ed2be10b2d31ac38228597c4 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 21:06:27 +0000 Subject: [PATCH 31/48] use data collator --- verifiers/examples/docvqa.py | 44 ++++++++++++++++++++++++------ verifiers/trainers/grpo_trainer.py | 7 +++-- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 4da54c15f..9193c18b0 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -15,18 +15,43 @@ """ -def preprocess_docvqa(x): - return { - "question": x["question"], - "images": [x["image"].resize(smart_resize(768, 1024))], # XGA - "answer": x["answers"][0], - } +# def preprocess_docvqa(x): +# return { +# "question": x["question"], +# "images": [x["image"].resize(smart_resize(768, 1024))], # XGA +# "answer": x["answers"][0], +# } + +def data_collator(batch: list[dict]) -> list[dict]: + processed_samples = [] + for sample in batch: + messages = [] + messages.append({"role": "system", "content": SYSTEM_PROMPT}) + content_block = [] + content_block.append({"type": "text", "text": sample["question"]}) + content_block.append( + { + "type": "image", + "image": sample["image"], # only one image in this ds + "resized_height": 768, # XGA resolution + "resized_width": 1024, + } + ) + messages.append({"role": "user", "content": content_block}) + processed_images, *_ = process_vision_info( # process with qwen utils + messages.copy() + ) + sample["prompt"] = messages + sample["images"] = processed_images + sample["answer"] = sample["answers"] + processed_samples.append(sample) + return processed_samples dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation") -dataset = dataset.map( - preprocess_docvqa, num_proc=10, remove_columns=dataset.column_names -) +# dataset = dataset.map( +# preprocess_docvqa, num_proc=10, remove_columns=dataset.column_names +# ) parser = vf.XMLParser(["think", "answer"], answer_field="answer") system_prompt = f"""Answer the questions. 
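Note on the hunk above: the collator replaces the earlier `dataset.map` preprocessing (left commented out for now), turning each raw DocVQA row into a chat-format sample at batch time, with the question text appended before the image part. A rough sanity-check sketch of what one collated sample looks like; this is not part of the patch, `row` is a hypothetical record with the fields the dataset provides, and it assumes the script's context (`data_collator`, `system_prompt`, and `qwen_vl_utils` available):

    from PIL import Image

    row = {
        "question": "What is the invoice total?",
        "image": Image.new("RGB", (1024, 768), "white"),
        "answers": ["$42.00"],
    }
    [sample] = data_collator([row])
    assert sample["prompt"][0]["role"] == "system"
    assert sample["prompt"][1]["content"][0]["type"] == "text"  # text part comes first
    assert len(sample["images"]) == 1   # one PIL image, resized by process_vision_info
    assert sample["answer"] == ["$42.00"]  # the full answers list, not answers[0]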
@@ -97,5 +122,6 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: processing_class=processor, env=vf_env, args=training_args, + data_collator=data_collator, ) trainer.train() diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 8d2373572..423925dec 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -14,6 +14,7 @@ from peft import PeftConfig, get_peft_model from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM +from transformers.data.data_collator import DataCollator from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -155,6 +156,7 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional[PeftConfig] = None, + data_collator: Optional[DataCollator] = None, **kwargs, ): self.logger = logging.getLogger(__name__) @@ -245,12 +247,12 @@ def filter_by_prompt_length(example): self.logger.info(f"Filtered dataset from {original_size} to {filtered_size} examples ({original_size - filtered_size} prompts were too long)") # dummy data collator - def data_collator(features): + def default_data_collator(features): return features super().__init__( model=model, args=args, - data_collator=data_collator, + data_collator=data_collator if data_collator is not None else default_data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, @@ -629,6 +631,7 @@ def _prepare_inputs( # type: ignore """ # Ensure all processes are synchronized at the start self.accelerator.wait_for_everyone() + import pdb;pdb.set_trace() # inputs = list of dicts for all gradient accumulation steps generate_every = self.gradient_accumulation_steps * self.num_iterations From 815e0003c876eee4b590608cb9324a1db346a1aa Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 10 Jun 2025 22:25:22 +0000 Subject: [PATCH 32/48] format oai-api prompts --- verifiers/envs/environment.py | 57 ++++++++++++++++++++- verifiers/examples/docvqa.py | 4 +- verifiers/trainers/async_batch_generator.py | 2 +- verifiers/trainers/grpo_trainer.py | 15 +++--- 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index e5e40aaf3..f1ba2e039 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -26,6 +26,56 @@ def _pil_to_data_url(img: Image.Image, fmt: str | None = None) -> str: b64 = base64.b64encode(buf.getvalue()).decode("utf-8") return f"data:image/{fmt.lower()};base64,{b64}" +def format_oai_chat_msg( + prompts: List[List[Dict[str, Any]]], + images: List[List[Image.Image]] +) -> List[List[Dict[str, Any]]]: + """ + Given: + - prompts: a list (for each convo) of lists of message‐dicts, + where a user message's content may be a list of parts, + some with {'type': 'image', 'image': PIL.Image, ...}. + - images: a parallel list (for each convo) of lists of PIL.Image objects. 
+ + Returns: + A new list of the same shape, but with every image part replaced by: + {'type': 'image_url', 'image_url': DATA_URL} + """ + formatted_conversations: List[List[Dict[str, Any]]] = [] + + for conv_prompts, conv_images in zip(prompts, images): + img_iter = iter(conv_images) + new_conv: List[Dict[str, Any]] = [] + + for msg in conv_prompts: + role = msg["role"] + content = msg["content"] + + # If this message's content is a list of parts (text/image) + if isinstance(content, list): + new_parts: List[Dict[str, Any]] = [] + for part in content: + if part.get("type") == "image": + # grab the next PIL.Image from the images list + img = next(img_iter) + data_url = _pil_to_data_url(img) + new_parts.append({ + "type": "image_url", + "image_url": {"url": data_url} + }) + else: + # leave text (or any other part) untouched + new_parts.append(part.copy()) + new_conv.append({"role": role, "content": new_parts}) + + else: + # system or assistant messages with string content + new_conv.append({"role": role, "content": content}) + + formatted_conversations.append(new_conv) + + return formatted_conversations + class Environment(ABC): """ Base class for all environments. @@ -324,8 +374,12 @@ def generate(self, results = {col: deepcopy(inputs[col]) for col in inputs.column_names} else: results = deepcopy(inputs) + if results.get('images') is not None: + prompts = format_oai_chat_msg(results['prompt'], results['images']) + else: + prompts = results['prompt'] rollouts = self.run_rollouts( - prompts=results['prompt'], + prompts=prompts, client=client, model=model, sampling_args=gen_sampling_args, @@ -333,6 +387,7 @@ def generate(self, **kwargs ) results['completion'] = [rollout[0] for rollout in rollouts] + import pdb;pdb.set_trace() results['state'] = [rollout[1] for rollout in rollouts] if 'task' not in results: results['task'] = ['default'] * len(results['prompt']) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index 9193c18b0..c061799ab 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -3,7 +3,7 @@ from datasets import load_dataset import verifiers as vf -from qwen_vl_utils import smart_resize +from qwen_vl_utils import process_vision_info """ # install qwen stuff @@ -26,7 +26,7 @@ def data_collator(batch: list[dict]) -> list[dict]: processed_samples = [] for sample in batch: messages = [] - messages.append({"role": "system", "content": SYSTEM_PROMPT}) + messages.append({"role": "system", "content": system_prompt}) content_block = [] content_block.append({"type": "text", "text": sample["question"]}) content_block.append( diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 28946cd5e..28b08c28c 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -11,7 +11,7 @@ class BatchRequest: """Request for batch generation""" batch_id: int - env_inputs: Dict[str, List[Any]] + env_inputs: Dict[str, List[Any] | None] processing_class: Any mask_env_responses: bool max_completion_length: int diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 423925dec..bb9c39dbd 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -14,7 +14,6 @@ from peft import PeftConfig, get_peft_model from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM -from transformers.data.data_collator import DataCollator from transformers.integrations.deepspeed import 
is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -156,7 +155,7 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional[PeftConfig] = None, - data_collator: Optional[DataCollator] = None, + data_collator: Optional[Any] = None, **kwargs, ): self.logger = logging.getLogger(__name__) @@ -592,7 +591,7 @@ def _ids_to_tensors(self, 'mask': mask } - def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any], List[Any]]: + def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any] | None, List[Any], List[Any]]: """ Gather batch data from all processes. @@ -609,14 +608,17 @@ def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any # Gather batch data from all processes prompts = [x['prompt'] for x in batch] + images = [x['images'] for x in batch if 'images' in x] answers = [x['answer'] for x in batch] tasks = [x.get('task', 'default') for x in batch] all_prompts = gather_object(prompts) + all_images = gather_object(images) + all_images = all_images if all_images != [] else None all_answers = gather_object(answers) all_tasks = gather_object(tasks) - return all_prompts, all_answers, all_tasks + return all_prompts, all_images, all_answers, all_tasks def _prepare_inputs( # type: ignore self, inputs: list[dict[str, Any]] @@ -631,7 +633,6 @@ def _prepare_inputs( # type: ignore """ # Ensure all processes are synchronized at the start self.accelerator.wait_for_everyone() - import pdb;pdb.set_trace() # inputs = list of dicts for all gradient accumulation steps generate_every = self.gradient_accumulation_steps * self.num_iterations @@ -664,14 +665,14 @@ def _prepare_inputs( # type: ignore for batch_id in range(self._next_batch_id, target_batch_id + 1): batch_offset = batch_id - batch_id_to_retrieve - all_prompts, all_answers, all_tasks = self._gather_batch_data(batch_offset) + all_prompts, all_images, all_answers, all_tasks = self._gather_batch_data(batch_offset) local_batch_size = len(all_prompts) // self.accelerator.num_processes # Submit batch (main process only) if self.accelerator.is_main_process: request = BatchRequest( batch_id=batch_id, - env_inputs={'prompt': all_prompts, 'answer': all_answers, 'task': all_tasks}, + env_inputs={'prompt': all_prompts, 'images': all_images, 'answer': all_answers, 'task': all_tasks}, processing_class=self.processing_class, mask_env_responses=self.mask_env_responses, max_completion_length=self.max_completion_length, From 6ed448dab407f65c3dcdea5ddc7b2cd57a25ef0e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 11 Jun 2025 08:32:19 +0000 Subject: [PATCH 33/48] post-process images --- verifiers/envs/environment.py | 1 - verifiers/examples/docvqa.py | 5 +- verifiers/trainers/async_batch_generator.py | 4 +- verifiers/trainers/grpo_trainer.py | 88 ++++++++++++++++++--- verifiers/utils/logging_utils.py | 1 + 5 files changed, 84 insertions(+), 15 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index f1ba2e039..b04fd647c 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -387,7 +387,6 @@ def generate(self, **kwargs ) results['completion'] = [rollout[0] for rollout in rollouts] - import pdb;pdb.set_trace() results['state'] = [rollout[1] for rollout in rollouts] if 'task' not in results: 
results['task'] = ['default'] * len(results['prompt']) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index c061799ab..ecfa28392 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -33,10 +33,11 @@ def data_collator(batch: list[dict]) -> list[dict]: { "type": "image", "image": sample["image"], # only one image in this ds - "resized_height": 768, # XGA resolution - "resized_width": 1024, + "resized_height": 480, # VGA resolution + "resized_width": 640, } ) + # content_block.append({"type": "text", "text": sample["question"]}) messages.append({"role": "user", "content": content_block}) processed_images, *_ = process_vision_info( # process with qwen utils messages.copy() diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 28b08c28c..c6ca1ad80 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -33,6 +33,7 @@ class BatchResult: all_reward_dict: Dict[str, List[float]] = field(default_factory=dict) # All reward scores completions: List[Any] = field(default_factory=list) # Store completions for logging prompts: List[Any] = field(default_factory=list) # Store prompts for logging + images: List[Any] | None = None # Store images for further processing class AsyncBatchGenerator: @@ -253,5 +254,6 @@ def _generate_batch(self, request: BatchRequest) -> BatchResult: processed_results=processed_results, all_reward_dict=all_reward_dict, completions=env_results['completion'], - prompts=env_results['prompt'] + prompts=env_results['prompt'], + images=env_results['images'], ) \ No newline at end of file diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index bb9c39dbd..b3d2559bb 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -116,6 +116,54 @@ def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[ permutation = torch.randperm(batch_size) return {key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items()} +def shuffle_data_dict(data_dict: dict[str, Any]) -> dict[str, Any]: + """ + Shuffles a dictionary of tensors or lists along the first dimension in unison. + """ + first_item = next(item for item in data_dict.values() if item is not None) + batch_size = len(first_item) + permutation = torch.randperm(batch_size) + + shuffled_dict = {} + for key, value in data_dict.items(): + if value is None: + shuffled_dict[key] = None + elif isinstance(value, torch.Tensor): + shuffled_dict[key] = value[permutation] + elif isinstance(value, list): + shuffled_dict[key] = [value[i] for i in permutation] + else: + raise TypeError(f"Unsupported type for shuffling: {type(value)}") + return shuffled_dict + +def split_data_dict( + data_dict: dict[str, Any], num_chunks: int +) -> list[dict[str, Any]]: + """ + Splits a dictionary of tensors or lists along the first dimension into `num_chunks` equal parts. + """ + first_item = next(item for item in data_dict.values() if item is not None) + # Ensure chunk_size is an integer + chunk_size = len(first_item) // num_chunks + if len(first_item) % num_chunks != 0: + logging.warning( + f"The total number of samples ({len(first_item)}) is not divisible by the number of chunks ({num_chunks}). " + f"The last {len(first_item) % num_chunks} samples will be dropped." 
+ ) + + chunked_list = [] + for i in range(num_chunks): + chunk = {} + start_idx = i * chunk_size + end_idx = (i + 1) * chunk_size + for key, value in data_dict.items(): + if value is None: + chunk[key] = None + else: + chunk[key] = value[start_idx:end_idx] + chunked_list.append(chunk) + return chunked_list + def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors. @@ -719,6 +767,7 @@ def _prepare_inputs( # type: ignore 'all_reward_dict': batch_result.all_reward_dict if hasattr(batch_result, 'all_reward_dict') else {'reward': processed_results['rewards']}, 'completions': batch_result.completions if hasattr(batch_result, 'completions') else [], 'prompts': batch_result.prompts if hasattr(batch_result, 'prompts') else [], + 'images': batch_result.images if hasattr(batch_result, 'images') else [], } else: broadcast_data = None @@ -774,6 +823,8 @@ def _prepare_inputs( # type: ignore # Take this process's slice of advantages advantages = all_advantages[process_slice] + + images = broadcast_data['images'][process_slice] # Log metrics on main process only if self.accelerator.is_main_process: @@ -797,7 +848,7 @@ def _prepare_inputs( # type: ignore all_completion_ids=broadcast_data['completion_ids'], all_prompt_mask=broadcast_data['prompt_mask'] ) - + # Concatenate all data for shuffling full_batch = { "prompt_ids": prompt_ids, @@ -806,11 +857,12 @@ def _prepare_inputs( # type: ignore "completion_mask": completion_mask, "old_per_token_logps": None, "advantages": advantages, + "images": images, } # Shuffle and split for gradient accumulation - full_batch = shuffle_tensor_dict(full_batch) - self._buffered_inputs = split_tensor_dict(full_batch, self.gradient_accumulation_steps) + full_batch = shuffle_data_dict(full_batch) + self._buffered_inputs = split_data_dict(full_batch, self.gradient_accumulation_steps) self.accelerator.wait_for_everyone() # Return appropriate slice from buffer result = self._buffered_inputs[self._step % self.gradient_accumulation_steps] @@ -847,17 +899,30 @@ def compute_loss(self, # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + images = inputs.get('images') + model_kwargs = {} + if images is not None: + prompt_texts = self.processing_class.batch_decode(prompt_ids) + prompt_inputs = self.processing_class( + text=prompt_texts, + images=images, + return_tensors="pt", + padding=True, + padding_side="left", + add_special_tokens=False, + ).to(self.accelerator.device) + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] # TODO: remove this. 
these should come from previous processing + model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature( + model.get_base_model().forward + ).parameters.keys() + ) + model_kwargs = {k: prompt_inputs[k] for k in model_kwarg_keys if k in prompt_inputs and k not in ["input_ids", "attention_mask"]} input_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - model_kwarg_keys = ( - inspect.signature(model.forward).parameters.keys() - if not hasattr(model, "get_base_model") - else inspect.signature( - model.get_base_model().forward - ).parameters.keys() - ) - model_kwargs = {k: inputs[k] for k in model_kwarg_keys if k in inputs} per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, **model_kwargs) # Compute the loss @@ -1061,6 +1126,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: last_message = messages[-1] content = last_message.get("content", "") if isinstance(content, list): + # content = content[-1]["text"] # extract text only in multimodal case content = content[0]["text"] # extract text only in multimodal case prompt.append([{'role': 'user', 'content': content}]) table = { diff --git a/verifiers/utils/logging_utils.py b/verifiers/utils/logging_utils.py index 454d07537..f16dc6e3e 100644 --- a/verifiers/utils/logging_utils.py +++ b/verifiers/utils/logging_utils.py @@ -80,6 +80,7 @@ def print_prompt_completions_sample( last_message = prompt[-1] content = last_message.get("content", "") if isinstance(content, list): # multimodal case + # content = content[-1]["text"] content = content[0]["text"] formatted_prompt = Text(content, style="bright_yellow") else: From 24ed694c67f041376ebf5749260db4efd16c11fe Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 11 Jun 2025 09:08:15 +0000 Subject: [PATCH 34/48] fix text position --- verifiers/examples/docvqa.py | 12 ------------ verifiers/trainers/grpo_trainer.py | 1 - verifiers/utils/logging_utils.py | 1 - 3 files changed, 14 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index ecfa28392..e11afe6d0 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -15,13 +15,6 @@ """ -# def preprocess_docvqa(x): -# return { -# "question": x["question"], -# "images": [x["image"].resize(smart_resize(768, 1024))], # XGA -# "answer": x["answers"][0], -# } - def data_collator(batch: list[dict]) -> list[dict]: processed_samples = [] for sample in batch: @@ -37,7 +30,6 @@ def data_collator(batch: list[dict]) -> list[dict]: "resized_width": 640, } ) - # content_block.append({"type": "text", "text": sample["question"]}) messages.append({"role": "user", "content": content_block}) processed_images, *_ = process_vision_info( # process with qwen utils messages.copy() @@ -50,9 +42,6 @@ def data_collator(batch: list[dict]) -> list[dict]: dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation") -# dataset = dataset.map( -# preprocess_docvqa, num_proc=10, remove_columns=dataset.column_names -# ) parser = vf.XMLParser(["think", "answer"], answer_field="answer") system_prompt = f"""Answer the questions. 
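Note: both the wandb table and `print_prompt_completions_sample` read `content[0]["text"]`, so they quietly depend on the text part staying first in each multimodal user message, which is what this commit enforces in the collator. A more order-tolerant lookup, sketched for reference only (`first_text_part` is a hypothetical helper, not part of the diff):

    def first_text_part(content) -> str:
        # `content` is either a plain string or a list of typed parts
        # ({"type": "text"} / {"type": "image"}) as built by the collator.
        if isinstance(content, str):
            return content
        for part in content:
            if part.get("type") == "text":
                return part.get("text", "")
        return ""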
@@ -115,7 +104,6 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) -training_args.learning_rate = 3e-6 training_args.max_steps = -1 trainer = vf.GRPOTrainer( diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index b3d2559bb..39f516750 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1126,7 +1126,6 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: last_message = messages[-1] content = last_message.get("content", "") if isinstance(content, list): - # content = content[-1]["text"] # extract text only in multimodal case content = content[0]["text"] # extract text only in multimodal case prompt.append([{'role': 'user', 'content': content}]) table = { diff --git a/verifiers/utils/logging_utils.py b/verifiers/utils/logging_utils.py index f16dc6e3e..454d07537 100644 --- a/verifiers/utils/logging_utils.py +++ b/verifiers/utils/logging_utils.py @@ -80,7 +80,6 @@ def print_prompt_completions_sample( last_message = prompt[-1] content = last_message.get("content", "") if isinstance(content, list): # multimodal case - # content = content[-1]["text"] content = content[0]["text"] formatted_prompt = Text(content, style="bright_yellow") else: From ffee1d605becbae40eba0f4c469a2e199cbe670e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 11 Jun 2025 17:29:45 +0000 Subject: [PATCH 35/48] process inputs in environment --- verifiers/envs/environment.py | 153 +++++++++++++------- verifiers/trainers/async_batch_generator.py | 4 +- verifiers/trainers/grpo_trainer.py | 51 +++---- 3 files changed, 126 insertions(+), 82 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index b04fd647c..3f0ca85fb 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -406,6 +406,7 @@ def generate(self, def process_chat_format( self, prompt: List[Dict[str, str]], + images: Optional[List[List[Any]]], completion: List[Dict[str, str]], processing_class: PreTrainedTokenizerBase, mask_env_responses: bool = False @@ -421,59 +422,105 @@ def process_chat_format( Returns: prompt_ids, prompt_mask, completion_ids, completion_mask """ - # tokenize just the prompt - prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) - assert isinstance(prompt_text, str) - if hasattr(processing_class, "tokenizer"): - encode = processing_class.tokenizer.encode - else: - encode = processing_class.encode - prompt_ids = encode(prompt_text) - prompt_mask = [1] * len(prompt_ids) - - # track completion tokens and masks by processing incrementally completion_ids = [] completion_mask = [] - - # previous tokenization (starts with just prompt) - prev_ids = prompt_ids - - # process each completion message incrementally - for i, msg in enumerate(completion): - # create conversation prefix: prompt + completion[:i+1] - conversation_prefix = prompt + completion[:i+1] + remaining_inputs = {} + if images: + # Tokenize the prompt with images to establish the initial state. + prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + # The multimodal processor call requires both text and images. 
+ inputs = processing_class(text=prompt_text, images=images, return_tensors="pt") + remaining_inputs = { + k: v + for k, v in inputs.items() + if k not in ["input_ids", "attention_mask"] + } + prev_ids = inputs.input_ids[0].tolist() + prompt_ids = prev_ids + prompt_mask = [1] * len(prompt_ids) + + # Process each completion message incrementally. + for i, msg in enumerate(completion): + conversation_prefix = prompt + completion[:i+1] + + # Get the full text representation of the conversation up to this point. + prefix_text = processing_class.apply_chat_template( + conversation_prefix, + tokenize=False, + add_generation_prompt=False, + ) + + # Tokenize the new prefix, passing the images each time. + current_ids = processing_class(text=prefix_text, images=images, return_tensors="pt").input_ids[0].tolist() + assert current_ids[:len(prev_ids)] == prev_ids, "Tokenization difference in chat format." + + new_tokens = current_ids[len(prev_ids):] + completion_ids.extend(new_tokens) + + # Create mask for the new tokens. + if msg["role"] == "assistant": + msg_mask = [1] * len(new_tokens) + elif msg["role"] != "assistant" and mask_env_responses: + msg_mask = [0] * len(new_tokens) + else: + msg_mask = [1] * len(new_tokens) + + completion_mask.extend(msg_mask) + prev_ids = current_ids + else: + # tokenize just the prompt + prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + assert isinstance(prompt_text, str) + if hasattr(processing_class, "tokenizer"): + encode = processing_class.tokenizer.encode + else: + encode = processing_class.encode + prompt_ids = encode(prompt_text) + prompt_mask = [1] * len(prompt_ids) - # tokenize the full prefix - prefix_text = processing_class.apply_chat_template( - conversation_prefix, - tokenize=False, - add_generation_prompt=False, - ) - assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}" - current_ids = encode(prefix_text) - assert current_ids[:len(prev_ids)] == prev_ids, f"Tokenization difference in chat format. Current ids: {current_ids}, previous ids: {prev_ids}" + # track completion tokens and masks by processing incrementally + completion_ids = [] + completion_mask = [] - # add new tokens to completion tokens - new_tokens = current_ids[len(prev_ids):] - assert len(new_tokens) > 0, f"No new tokens in chat format. Current ids: {current_ids}, previous ids: {prev_ids}" - completion_ids.extend(new_tokens) - - # create mask - if msg["role"] == "assistant": - msg_mask = [1] * len(new_tokens) - elif msg["role"] != "assistant" and mask_env_responses: - # mask intermediate 'user' and/or 'tool' messages - msg_mask = [0] * len(new_tokens) - else: - # default to not masking - msg_mask = [1] * len(new_tokens) + # previous tokenization (starts with just prompt) + prev_ids = prompt_ids - completion_mask.extend(msg_mask) - # Update previous tokenization for next iteration - prev_ids = current_ids - assert len(completion_ids) == len(completion_mask), f"Length mismatch in chat format. 
Completion ids: {completion_ids}, completion mask: {completion_mask}" + # process each completion message incrementally + for i, msg in enumerate(completion): + # create conversation prefix: prompt + completion[:i+1] + conversation_prefix = prompt + completion[:i+1] + + # tokenize the full prefix + prefix_text = processing_class.apply_chat_template( + conversation_prefix, + tokenize=False, + add_generation_prompt=False, + ) + assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}" + current_ids = encode(prefix_text) + assert current_ids[:len(prev_ids)] == prev_ids, f"Tokenization difference in chat format. Current ids: {current_ids}, previous ids: {prev_ids}" + + # add new tokens to completion tokens + new_tokens = current_ids[len(prev_ids):] + assert len(new_tokens) > 0, f"No new tokens in chat format. Current ids: {current_ids}, previous ids: {prev_ids}" + completion_ids.extend(new_tokens) + + # create mask + if msg["role"] == "assistant": + msg_mask = [1] * len(new_tokens) + elif msg["role"] != "assistant" and mask_env_responses: + # mask intermediate 'user' and/or 'tool' messages + msg_mask = [0] * len(new_tokens) + else: + # default to not masking + msg_mask = [1] * len(new_tokens) + + completion_mask.extend(msg_mask) + # Update previous tokenization for next iteration + prev_ids = current_ids + assert len(completion_ids) == len(completion_mask), f"Length mismatch in chat format. Completion ids: {completion_ids}, completion mask: {completion_mask}" - return prompt_ids, prompt_mask, completion_ids, completion_mask + return prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs def process_completion_format( self, @@ -509,6 +556,7 @@ def process_completion_format( def process_env_results( self, prompts: List[Union[str, List[Dict[str, Any]]]], + images: Optional[List[List[Any]]], completions: List[Union[str, List[Dict[str, Any]]]], states: List[Dict[str, Any]], rewards: List[float], @@ -532,13 +580,16 @@ def process_env_results( all_prompt_masks = [] all_completion_ids = [] all_completion_masks = [] + all_remaining_inputs = [] + + images = images or [None] * len(prompts) - for i, (prompt, completion, state, reward) in enumerate(zip(prompts, completions, states, rewards)): + for i, (prompt, images, completion, state, reward) in enumerate(zip(prompts, images, completions, states, rewards)): # Format-specific processing if is_chat_format: assert isinstance(prompt, list) and isinstance(completion, list) - prompt_ids, prompt_mask, completion_ids, completion_mask = self.process_chat_format( - prompt, completion, processing_class, mask_env_responses + prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs = self.process_chat_format( + prompt, images, completion, processing_class, mask_env_responses ) else: assert isinstance(prompt, str) and isinstance(completion, str) @@ -552,6 +603,7 @@ def process_env_results( all_prompt_masks.append(prompt_mask) all_completion_ids.append(completion_ids) all_completion_masks.append(completion_mask) + all_remaining_inputs.append(remaining_inputs) return { "prompt_ids": all_prompt_ids, @@ -559,6 +611,7 @@ def process_env_results( "completion_ids": all_completion_ids, "completion_mask": all_completion_masks, "rewards": rewards, + "remaining_inputs": all_remaining_inputs, } # Evaluation and dataset generation diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index c6ca1ad80..450643c8b 100644 --- 
a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -240,6 +240,7 @@ def _generate_batch(self, request: BatchRequest) -> BatchResult: # Process results processed_results = self.env.process_env_results( env_results['prompt'], + env_results['images'], env_results['completion'], env_results['state'], env_results['reward'], @@ -248,12 +249,11 @@ def _generate_batch(self, request: BatchRequest) -> BatchResult: max_completion_length=request.max_completion_length, mask_truncated_completions=request.mask_truncated_completions ) - + return BatchResult( batch_id=request.batch_id, processed_results=processed_results, all_reward_dict=all_reward_dict, completions=env_results['completion'], prompts=env_results['prompt'], - images=env_results['images'], ) \ No newline at end of file diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 39f516750..a1bae4282 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -520,15 +520,24 @@ def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, log def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None, **model_kwargs) -> torch.Tensor: batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] - if _accepts_logits_to_keep(model): - model_kwargs["logits_to_keep"] = logits_to_keep + 1 for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] - + model_kwargs_batch = {} + for key, value in model_kwargs.items(): + if isinstance(value, list): + # 1. Slice the list to get the tensors for this micro-batch + sub_list = value[i : i + batch_size] + # 2. 
Batch the tensors in the sub-list together + model_kwargs_batch[key] = torch.cat(sub_list, dim=0).to(self.accelerator.device) + else: + # Handle non-list arguments (like the 'logits_to_keep' we added) + model_kwargs_batch[key] = value + if _accepts_logits_to_keep(model): + model_kwargs_batch["logits_to_keep"] = logits_to_keep + 1 # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model( - input_ids=input_ids_batch, attention_mask=attention_mask_batch, **model_kwargs + input_ids=input_ids_batch, attention_mask=attention_mask_batch, **model_kwargs_batch ).logits logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred input_ids_batch = input_ids_batch[:, -logits_to_keep:] @@ -764,10 +773,10 @@ def _prepare_inputs( # type: ignore 'completion_ids': processed_results['completion_ids'], 'completion_mask': processed_results['completion_mask'], 'rewards': processed_results['rewards'], + 'remaining_inputs': processed_results['remaining_inputs'], 'all_reward_dict': batch_result.all_reward_dict if hasattr(batch_result, 'all_reward_dict') else {'reward': processed_results['rewards']}, 'completions': batch_result.completions if hasattr(batch_result, 'completions') else [], 'prompts': batch_result.prompts if hasattr(batch_result, 'prompts') else [], - 'images': batch_result.images if hasattr(batch_result, 'images') else [], } else: broadcast_data = None @@ -824,8 +833,9 @@ def _prepare_inputs( # type: ignore # Take this process's slice of advantages advantages = all_advantages[process_slice] - images = broadcast_data['images'][process_slice] - + # slice remaining inputs + remaining_inputs = broadcast_data['remaining_inputs'][process_slice] + # Log metrics on main process only if self.accelerator.is_main_process: self._log_reward_metrics_primary( @@ -857,7 +867,7 @@ def _prepare_inputs( # type: ignore "completion_mask": completion_mask, "old_per_token_logps": None, "advantages": advantages, - "images": images, + "remaining_inputs": remaining_inputs, } # Shuffle and split for gradient accumulation @@ -899,31 +909,12 @@ def compute_loss(self, # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - images = inputs.get('images') - model_kwargs = {} - if images is not None: - prompt_texts = self.processing_class.batch_decode(prompt_ids) - prompt_inputs = self.processing_class( - text=prompt_texts, - images=images, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - ).to(self.accelerator.device) - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] # TODO: remove this. 
these should come from previous processing - model_kwarg_keys = ( - inspect.signature(model.forward).parameters.keys() - if not hasattr(model, "get_base_model") - else inspect.signature( - model.get_base_model().forward - ).parameters.keys() - ) - model_kwargs = {k: prompt_inputs[k] for k in model_kwarg_keys if k in prompt_inputs and k not in ["input_ids", "attention_mask"]} + model_kwargs = inputs["remaining_inputs"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, **model_kwargs) + model_kwargs = {key: [d[key] for d in model_kwargs] for key in model_kwargs[0].keys()} + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, batch_size = 1 if model_kwargs != {} else None, **model_kwargs) # Compute the loss advantages = inputs["advantages"] From e0026a8dda9b4051632ee35587a799dcd438ed34 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Thu, 12 Jun 2025 03:20:40 +0000 Subject: [PATCH 36/48] increase res and lr --- verifiers/examples/docvqa.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index e11afe6d0..e05ec6d0a 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -26,8 +26,8 @@ def data_collator(batch: list[dict]) -> list[dict]: { "type": "image", "image": sample["image"], # only one image in this ds - "resized_height": 480, # VGA resolution - "resized_width": 640, + "resized_height": 768, # XGA resolution + "resized_width": 1024, } ) messages.append({"role": "user", "content": content_block}) @@ -104,6 +104,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: run_name = "docvqa_" + model_name.split("/")[-1].lower() training_args = vf.grpo_defaults(run_name=run_name) +training_args.learning_rate = 3e-6 training_args.max_steps = -1 trainer = vf.GRPOTrainer( From f51c6659502623e604239fd9b64a5b994d70aa86 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Thu, 12 Jun 2025 03:47:28 +0000 Subject: [PATCH 37/48] fix --- verifiers/envs/environment.py | 24 ---------- verifiers/trainers/async_batch_generator.py | 5 +- verifiers/trainers/grpo_trainer.py | 51 +-------------------- 3 files changed, 4 insertions(+), 76 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 3f0ca85fb..29826d3f7 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -30,17 +30,6 @@ def format_oai_chat_msg( prompts: List[List[Dict[str, Any]]], images: List[List[Image.Image]] ) -> List[List[Dict[str, Any]]]: - """ - Given: - - prompts: a list (for each convo) of lists of message‐dicts, - where a user message's content may be a list of parts, - some with {'type': 'image', 'image': PIL.Image, ...}. - - images: a parallel list (for each convo) of lists of PIL.Image objects. 
- - Returns: - A new list of the same shape, but with every image part replaced by: - {'type': 'image_url', 'image_url': DATA_URL} - """ formatted_conversations: List[List[Dict[str, Any]]] = [] for conv_prompts, conv_images in zip(prompts, images): @@ -51,12 +40,10 @@ def format_oai_chat_msg( role = msg["role"] content = msg["content"] - # If this message's content is a list of parts (text/image) if isinstance(content, list): new_parts: List[Dict[str, Any]] = [] for part in content: if part.get("type") == "image": - # grab the next PIL.Image from the images list img = next(img_iter) data_url = _pil_to_data_url(img) new_parts.append({ @@ -64,12 +51,10 @@ def format_oai_chat_msg( "image_url": {"url": data_url} }) else: - # leave text (or any other part) untouched new_parts.append(part.copy()) new_conv.append({"role": role, "content": new_parts}) else: - # system or assistant messages with string content new_conv.append({"role": role, "content": content}) formatted_conversations.append(new_conv) @@ -426,9 +411,7 @@ def process_chat_format( completion_mask = [] remaining_inputs = {} if images: - # Tokenize the prompt with images to establish the initial state. prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) - # The multimodal processor call requires both text and images. inputs = processing_class(text=prompt_text, images=images, return_tensors="pt") remaining_inputs = { k: v @@ -439,25 +422,18 @@ def process_chat_format( prompt_ids = prev_ids prompt_mask = [1] * len(prompt_ids) - # Process each completion message incrementally. for i, msg in enumerate(completion): conversation_prefix = prompt + completion[:i+1] - - # Get the full text representation of the conversation up to this point. prefix_text = processing_class.apply_chat_template( conversation_prefix, tokenize=False, add_generation_prompt=False, ) - - # Tokenize the new prefix, passing the images each time. current_ids = processing_class(text=prefix_text, images=images, return_tensors="pt").input_ids[0].tolist() assert current_ids[:len(prev_ids)] == prev_ids, "Tokenization difference in chat format." - new_tokens = current_ids[len(prev_ids):] completion_ids.extend(new_tokens) - # Create mask for the new tokens. 
if msg["role"] == "assistant": msg_mask = [1] * len(new_tokens) elif msg["role"] != "assistant" and mask_env_responses: diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 450643c8b..6e73af494 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -33,7 +33,6 @@ class BatchResult: all_reward_dict: Dict[str, List[float]] = field(default_factory=dict) # All reward scores completions: List[Any] = field(default_factory=list) # Store completions for logging prompts: List[Any] = field(default_factory=list) # Store prompts for logging - images: List[Any] | None = None # Store images for further processing class AsyncBatchGenerator: @@ -249,11 +248,11 @@ def _generate_batch(self, request: BatchRequest) -> BatchResult: max_completion_length=request.max_completion_length, mask_truncated_completions=request.mask_truncated_completions ) - + return BatchResult( batch_id=request.batch_id, processed_results=processed_results, all_reward_dict=all_reward_dict, completions=env_results['completion'], - prompts=env_results['prompt'], + prompts=env_results['prompt'] ) \ No newline at end of file diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index a1bae4282..c33146547 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -68,54 +68,6 @@ def nanstd(tensor: torch.Tensor) -> torch.Tensor: variance *= count / (count - 1) # Bessel's correction return torch.sqrt(variance) -def split_tensor_dict( - tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int -) -> list[dict[str, Optional[torch.Tensor]]]: - """ - Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts. - - Example: - >>> x = torch.arange(12).reshape(6, 2) - >>> y = torch.arange(6).reshape(6, 1) - >>> tensor_dict = {"x": x, "y": y} - >>> split_tensor_dict(tensor_dict, 3) - [ - {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])}, - {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])}, - {"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])} - ] - """ - first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) - chunk_size = first_tensor.shape[0] // num_chunks - return [ - { - key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None - for key, tensor in tensor_dict.items() - } - for i in range(num_chunks) - ] - -def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[str, Optional[torch.Tensor]]: - """ - Shuffles a dictionary of tensors along the first dimension in unison. - - Example: - >>> x = torch.arange(6).reshape(3, 2) - >>> y = torch.arange(3).reshape(3, 1) - >>> tensor_dict = {"x": x, "y": y} - >>> shuffle_tensor_dict(tensor_dict) - {'x': tensor([[2, 3], - [0, 1], - [4, 5]]), - 'y': tensor([[1], - [0], - [2]])} - """ - first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) - batch_size = first_tensor.shape[0] - permutation = torch.randperm(batch_size) - return {key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items()} - def shuffle_data_dict(data_dict: dict[str, Any]) -> dict[str, Any]: """ Shuffles a dictionary of tensors or lists along the first dimension in unison. 
@@ -520,6 +472,7 @@ def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, log def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None, **model_kwargs) -> torch.Tensor: batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] + accepts_logits_to_keep = _accepts_logits_to_keep(model) for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] @@ -533,7 +486,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, else: # Handle non-list arguments (like the 'logits_to_keep' we added) model_kwargs_batch[key] = value - if _accepts_logits_to_keep(model): + if accepts_logits_to_keep: model_kwargs_batch["logits_to_keep"] = logits_to_keep + 1 # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model( From 4dd16b7209898c7f8d0d613025c0e07d9e64c3af Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 16 Jun 2025 03:17:47 +0000 Subject: [PATCH 38/48] fix eval with data collator --- verifiers/envs/environment.py | 14 +++++++++++++- verifiers/examples/docvqa.py | 21 ++++++++++++++------- verifiers/trainers/grpo_trainer.py | 21 ++++++++++++++++----- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 29826d3f7..9923d81e6 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -77,6 +77,7 @@ def __init__(self, sampling_args: Dict[str, Any] = {}, max_concurrent: int = 32, message_type: Literal['chat', 'completion'] = 'chat', + data_collator: Any | None = None, **kwargs: Any): self.client = client self.model = model @@ -84,6 +85,7 @@ def __init__(self, self.system_prompt = system_prompt self.few_shot = few_shot self.max_concurrent = max_concurrent + self.data_collator = data_collator if self.message_type == 'chat': if dataset is not None: self.dataset = self.format_dataset(dataset, self.system_prompt, self.few_shot) @@ -619,8 +621,18 @@ def evaluate(self, if num_samples > 0: inputs = inputs.select(range(num_samples)) + if self.data_collator: + batch = list(inputs) + processed_batch = self.data_collator(batch) + if not processed_batch: + processed_inputs = {} + else: + keys = processed_batch[0].keys() + processed_inputs = {key: [sample.get(key) for sample in processed_batch] for key in keys} + else: + processed_inputs = inputs results = self.generate( - inputs, client, model, sampling_args, max_concurrent, **kwargs + processed_inputs, client, model, sampling_args, max_concurrent, **kwargs ) return results diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index e05ec6d0a..dd53e21d2 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -1,9 +1,9 @@ import re from datasets import load_dataset +from qwen_vl_utils import process_vision_info import verifiers as vf -from qwen_vl_utils import process_vision_info """ # install qwen stuff @@ -25,8 +25,8 @@ def data_collator(batch: list[dict]) -> list[dict]: content_block.append( { "type": "image", - "image": sample["image"], # only one image in this ds - "resized_height": 768, # XGA resolution + "image": sample["image"], # only one image in this ds + "resized_height": 768, # XGA resolution "resized_width": 1024, } ) @@ -41,7 +41,8 @@ def data_collator(batch: list[dict]) -> list[dict]: return processed_samples -dataset = 
load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation") +dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[10%:]") +eval_dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[:10%]") parser = vf.XMLParser(["think", "answer"], answer_field="answer") system_prompt = f"""Answer the questions. @@ -85,7 +86,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: if msgs_scores == []: return 0.0 else: - return (sum(msgs_scores) / len(msgs_scores) / 2.0) + return sum(msgs_scores) / len(msgs_scores) / 2.0 rubric = vf.Rubric( @@ -96,7 +97,12 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: ) vf_env = vf.SingleTurnEnv( - dataset=dataset, system_prompt=system_prompt, parser=parser, rubric=rubric + dataset=dataset, + eval_dataset=eval_dataset, + system_prompt=system_prompt, + parser=parser, + rubric=rubric, + data_collator=data_collator, ) model_name = "Qwen/Qwen2.5-VL-3B-Instruct" @@ -106,12 +112,13 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: training_args = vf.grpo_defaults(run_name=run_name) training_args.learning_rate = 3e-6 training_args.max_steps = -1 +training_args.eval_strategy = "steps" +training_args.eval_steps = 2 trainer = vf.GRPOTrainer( model=model, processing_class=processor, env=vf_env, args=training_args, - data_collator=data_collator, ) trainer.train() diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index c33146547..8ab73905c 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -155,7 +155,6 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional[PeftConfig] = None, - data_collator: Optional[Any] = None, **kwargs, ): self.logger = logging.getLogger(__name__) @@ -245,13 +244,14 @@ def filter_by_prompt_length(example): if filtered_size < original_size: self.logger.info(f"Filtered dataset from {original_size} to {filtered_size} examples ({original_size - filtered_size} prompts were too long)") + self.data_collator = env.data_collator # dummy data collator def default_data_collator(features): return features super().__init__( model=model, args=args, - data_collator=data_collator if data_collator is not None else default_data_collator, + data_collator=self.data_collator if self.data_collator is not None else default_data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, @@ -985,7 +985,10 @@ def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval" completion_lengths = [] for comp in completions: # Apply chat template to get the full text - tokens = self.processing_class.apply_chat_template(comp, tokenize=True, add_generation_prompt=False) # type: ignore + if hasattr(self.processing_class, "tokenizer"): # if multimodal processor, use tokenizer; ow, it expects mm inputs + tokens = self.processing_class.tokenizer.apply_chat_template(comp, tokenize=True, add_generation_prompt=False) # type: ignore + else: + tokens = self.processing_class.apply_chat_template(comp, tokenize=True, add_generation_prompt=False) # type: ignore # Tokenize and count completion_lengths.append(len(tokens)) @@ -1018,10 +1021,18 @@ def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval" # Log to wandb if available if self.args.report_to and "wandb" in self.args.report_to 
and wandb.run is not None: import pandas as pd - + # format prompt for logging + prompt = [] + if prompts: + for messages in prompts: + last_message = messages[-1] + content = last_message.get("content", "") + if isinstance(content, list): + content = content[0]["text"] # extract text only in multimodal case + prompt.append([{'role': 'user', 'content': content}]) table_data = { "step": [str(self.state.global_step)] * len(prompts), - "prompt": prompts, + "prompt": prompt, "completion": completions, } for k, v in reward_dict.items(): From 020d9d79f269bfa8c47aea5f3b6e15c96c62fcc1 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 16 Jun 2025 04:14:36 +0000 Subject: [PATCH 39/48] transform eval ds once --- verifiers/envs/environment.py | 27 ++++++++++++++------------- verifiers/examples/docvqa.py | 5 ++++- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 9923d81e6..0bd6e8c3e 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from copy import deepcopy from typing import Any, Dict, List, Literal, Tuple, Optional, Union -import base64, io +import base64 +import io import torch @@ -104,6 +105,13 @@ def __init__(self, ) self.dataset = dataset self.eval_dataset = eval_dataset + if self.data_collator is not None and self.eval_dataset is not None: + processed_dataset = self.data_collator(list(self.eval_dataset)) + if not processed_dataset: + self.eval_dataset = {} + else: + keys = processed_dataset[0].keys() + self.eval_dataset = {key: [sample.get(key) for sample in processed_dataset] for key in keys} self.parser = parser self.rubric = rubric self.sampling_args = { @@ -619,20 +627,13 @@ def evaluate(self, else: inputs = self.eval_dataset if num_samples > 0: - inputs = inputs.select(range(num_samples)) + if isinstance(inputs, dict): + inputs = {key: value_list[:num_samples] for key, value_list in inputs.items()} + elif isinstance(inputs, Dataset): + inputs = inputs.select(range(num_samples)) - if self.data_collator: - batch = list(inputs) - processed_batch = self.data_collator(batch) - if not processed_batch: - processed_inputs = {} - else: - keys = processed_batch[0].keys() - processed_inputs = {key: [sample.get(key) for sample in processed_batch] for key in keys} - else: - processed_inputs = inputs results = self.generate( - processed_inputs, client, model, sampling_args, max_concurrent, **kwargs + inputs, client, model, sampling_args, max_concurrent, **kwargs ) return results diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index dd53e21d2..126797f80 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -76,7 +76,10 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: mean_gt_len = sum([len(gt_answer) for gt_answer in gt_answers]) / len( gt_answers ) - diff_from_mean = min(mean_gt_len / len(answer), 1.0) # penalize long answers + if len(answer) > 0: + diff_from_mean = min(mean_gt_len / len(answer), 1.0) # penalize long answers + else: + diff_from_mean = 0.0 if answer in gt_answers: msgs_scores.append(2.0) elif answer.lower() in [ans.lower() for ans in gt_answers]: From ac5682bb6e9d01e141e6cbe32dea0f07aeaaa8f6 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Mon, 16 Jun 2025 04:16:58 +0000 Subject: [PATCH 40/48] change eval steps --- verifiers/examples/docvqa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/examples/docvqa.py 
b/verifiers/examples/docvqa.py
index 126797f80..c3da7c835 100644
--- a/verifiers/examples/docvqa.py
+++ b/verifiers/examples/docvqa.py
@@ -116,7 +116,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None:
 training_args.learning_rate = 3e-6
 training_args.max_steps = -1
 training_args.eval_strategy = "steps"
-training_args.eval_steps = 2
+training_args.eval_steps = 100
 
 trainer = vf.GRPOTrainer(
     model=model,

From 2538d622e70bbf31918f946dcd0ecd5a5d9f0ab9 Mon Sep 17 00:00:00 2001
From: nph4rd
Date: Tue, 17 Jun 2025 17:30:42 +0000
Subject: [PATCH 41/48] fix batch size in func call

---
 verifiers/trainers/grpo_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py
index 8ab73905c..2231e6233 100644
--- a/verifiers/trainers/grpo_trainer.py
+++ b/verifiers/trainers/grpo_trainer.py
@@ -894,12 +894,12 @@ def compute_loss(self,
         with torch.no_grad():
             if self.ref_model is not None:
                 ref_per_token_logps = self._get_per_token_logps(
-                    self.ref_model, input_ids, attention_mask, logits_to_keep, **model_kwargs
+                    self.ref_model, input_ids, attention_mask, logits_to_keep, batch_size=1 if model_kwargs else None, **model_kwargs
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():  # type: ignore
                     ref_per_token_logps = self._get_per_token_logps(
-                        self.model, input_ids, attention_mask, logits_to_keep, **model_kwargs
+                        self.model, input_ids, attention_mask, logits_to_keep, batch_size=1 if model_kwargs else None, **model_kwargs
                     )
             per_token_kl = (
                 torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

From b1b90e4473d72fa153589510e472701e54fe4f90 Mon Sep 17 00:00:00 2001
From: nph4rd
Date: Tue, 17 Jun 2025 18:26:52 +0000
Subject: [PATCH 42/48] liger patch suffix opt

---
 verifiers/utils/model_utils.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py
index e5af811b4..c8360595e 100644
--- a/verifiers/utils/model_utils.py
+++ b/verifiers/utils/model_utils.py
@@ -98,7 +98,7 @@ def generic_model_loader(model_id: str, **model_kwargs) -> PreTrainedModel:
 
     raise RuntimeError(f"No suitable loader found for model type {cfg.model_type!r}")
 
-def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Any:
+def get_model(model_name: str, use_liger: bool = True, liger_patch_suffix: str | None = None, model_kwargs: Union[Dict[str, Any], None] = None) -> Any:
     if model_kwargs is None:
         model_kwargs = dict(
             torch_dtype=torch.bfloat16,
@@ -113,8 +113,10 @@ def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[
             return model
         except ValueError:  # try monkey patch
             print(f"Model {model_name} is not supported with AutoLigerKernelForCausalLM. Attempting monkey patch...")
-            model_type = AutoConfig.from_pretrained(model_name, trust_remote_code=True).model_type
-            patch_func_name = f"apply_liger_kernel_to_{model_type}"
+            if liger_patch_suffix is None:  # try with the model type
+                liger_patch_suffix = AutoConfig.from_pretrained(model_name, trust_remote_code=True).model_type
+                print(f"No liger_patch_suffix provided, attempting with model_type: {liger_patch_suffix}")
+            patch_func_name = f"apply_liger_kernel_to_{liger_patch_suffix}"
             ligermod = importlib.import_module("liger_kernel.transformers")
             patch_func = getattr(ligermod, patch_func_name, None)
             if callable(patch_func):
@@ -123,7 +125,7 @@ def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[
                 print(f"Applied Liger-Kernel patch to {model_name}")
                 return model
             else:
-                raise ValueError(f"Model {model_name} is not supported with Liger-Kernel in verifiers")
+                raise ValueError(f"Model {model_name} may not be supported with Liger-Kernel in verifiers. Check the Liger-Kernel documentation.")
     else:
         return generic_model_loader(model_name, **model_kwargs)
 
@@ -136,7 +138,7 @@ def get_tokenizer(model_name: str, padding_side: str = "left") -> Any:
 '-Instruct'. Please provide a tokenizer with the chat_template attribute.")
     return processor
 
-def get_model_and_tokenizer(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Tuple[Any, Any]:
-    model = get_model(model_name, use_liger, model_kwargs)
+def get_model_and_tokenizer(model_name: str, use_liger: bool = True, liger_patch_suffix: str | None = None, model_kwargs: Union[Dict[str, Any], None] = None) -> Tuple[Any, Any]:
+    model = get_model(model_name, use_liger, liger_patch_suffix, model_kwargs)
     tokenizer = get_tokenizer(model_name)
     return model, tokenizer
\ No newline at end of file

From 5dbc658c6bbcd26d73b28b0187fa033cb0f60aa6 Mon Sep 17 00:00:00 2001
From: nph4rd
Date: Tue, 17 Jun 2025 22:58:37 +0000
Subject: [PATCH 43/48] load ref with generic_model_loader

---
 verifiers/trainers/grpo_trainer.py | 3 ++-
 verifiers/utils/__init__.py        | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py
index 2231e6233..fee49670a 100644
--- a/verifiers/trainers/grpo_trainer.py
+++ b/verifiers/trainers/grpo_trainer.py
@@ -37,6 +37,7 @@
 from verifiers.trainers.async_dataloader_wrapper import AsyncDataLoaderWrapper
 from verifiers.utils.logging_utils import print_prompt_completions_sample
 from verifiers.utils.trainer_utils import RepeatSampler
+from verifiers.utils.model_utils import generic_model_loader
 
 def _accepts_logits_to_keep(model) -> bool:
     forward = (
@@ -266,7 +267,7 @@ def default_data_collator(features):
         elif is_deepspeed_zero3_enabled():
             model_id = model.config._name_or_path
             model_init_kwargs = {"torch_dtype": "auto"}
-            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+            self.ref_model = generic_model_loader(model_id, **model_init_kwargs)
         elif is_peft_model(model):
             # If PEFT is used, the reference model is not needed since the adapter can be disabled
             # to revert to the initial model.
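
With this change the reference policy under ZeRO-3 is materialized through the same architecture-aware path as the trained model, which matters for VLM checkpoints that `AutoModelForCausalLM` refuses to load. A rough usage sketch; the model id and kwargs below are examples only, not pinned by this patch:

    from verifiers.utils.model_utils import generic_model_loader

    # For a Qwen2.5-VL checkpoint, cfg.architectures resolves to
    # Qwen2_5_VLForConditionalGeneration instead of a causal-LM class.
    ref_model = generic_model_loader("Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto")
    ref_model.eval()  # frozen reference; it only supplies logprobs for the KL term
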
diff --git a/verifiers/utils/__init__.py b/verifiers/utils/__init__.py index a83ce148a..ac567b6c1 100644 --- a/verifiers/utils/__init__.py +++ b/verifiers/utils/__init__.py @@ -1,6 +1,6 @@ from .data_utils import extract_boxed_answer, extract_hash_answer, load_example_dataset from .config_utils import grpo_defaults, lora_defaults -from .model_utils import get_model, get_tokenizer, get_model_and_tokenizer +from .model_utils import get_model, get_tokenizer, get_model_and_tokenizer, generic_model_loader from .logging_utils import setup_logging, print_prompt_completions_sample __all__ = [ @@ -14,4 +14,5 @@ "get_model_and_tokenizer", "setup_logging", "print_prompt_completions_sample", + "generic_model_loader", ] \ No newline at end of file From ee2e44caa8d429083a64dafe703b115db7f9f54d Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 17 Jun 2025 22:59:17 +0000 Subject: [PATCH 44/48] set use_reentrant false --- verifiers/examples/docvqa.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py index c3da7c835..a91cef04c 100644 --- a/verifiers/examples/docvqa.py +++ b/verifiers/examples/docvqa.py @@ -117,6 +117,9 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: training_args.max_steps = -1 training_args.eval_strategy = "steps" training_args.eval_steps = 100 +training_args.gradient_checkpointing_kwargs = { + "use_reentrant": False, +} trainer = vf.GRPOTrainer( model=model, From 9ff8b7995826920b36ac220b4d486312db987ac6 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 18 Jun 2025 00:50:23 +0000 Subject: [PATCH 45/48] reset format_dataset func --- verifiers/envs/environment.py | 34 +++++++++------------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 0bd6e8c3e..d88db1740 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -151,40 +151,24 @@ def format_dataset(self, system_prompt: str | None = None, few_shot: List[Dict[str, Any]] | None = None, question_key: str = "question", - images_key: str = "images", answer_key: str = "answer") -> Dataset: # Extract format_prompt as a standalone function to avoid capturing self - def format_prompt_fn(prompt: str, images: list | None) -> List[Dict[str, Any]]: + def format_prompt_fn(prompt: str) -> List[Dict[str, Any]]: messages = [] - if images is None: - if system_prompt: - messages.append({'role': 'system', 'content': system_prompt}) - if few_shot: - messages.extend(few_shot) - messages.append({'role': 'user', 'content': prompt}) - else: - if system_prompt: - messages.append({'role': 'system', 'content': [{"type": "text", "text": system_prompt}]}) - if few_shot: - messages.extend(few_shot) - content = [{"type": "text", "text": prompt}] - for img in images: - content.append( - { - "type": "image_url", - "image_url": {"url": _pil_to_data_url(img)}, - } - ) - messages.append({"role": "user", "content": content}) + if system_prompt: + messages.append({'role': 'system', 'content': system_prompt}) + if few_shot: + messages.extend(few_shot) + messages.append({'role': 'user', 'content': prompt}) return messages - + if answer_key == "answer": return dataset.map(lambda x: { - "prompt": format_prompt_fn(x[question_key], x.get(images_key)), + "prompt": format_prompt_fn(x[question_key]), }, num_proc=self.max_concurrent) else: return dataset.map(lambda x: { - "prompt": format_prompt_fn(x[question_key], x.get(images_key)), + "prompt": 
format_prompt_fn(x[question_key]), "answer": x[answer_key] }, num_proc=self.max_concurrent) From 49c28bb7eb007c4cbce1785aaae0235455c2d48e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 18 Jun 2025 03:40:47 +0000 Subject: [PATCH 46/48] format stuff --- verifiers/envs/environment.py | 47 ++++++++++++++---------------- verifiers/trainers/grpo_trainer.py | 39 +++++++++++++------------ verifiers/utils/model_utils.py | 10 ++----- 3 files changed, 44 insertions(+), 52 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index d88db1740..bb72599a8 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -3,7 +3,7 @@ from asyncio import Semaphore from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Dict, List, Literal, Tuple, Optional, Union +from typing import Any, Dict, List, Literal, Tuple, Optional, Union, Callable import base64 import io @@ -30,19 +30,19 @@ def _pil_to_data_url(img: Image.Image, fmt: str | None = None) -> str: def format_oai_chat_msg( prompts: List[List[Dict[str, Any]]], images: List[List[Image.Image]] -) -> List[List[Dict[str, Any]]]: - formatted_conversations: List[List[Dict[str, Any]]] = [] +) -> List[Any]: + formatted_conversations = [] for conv_prompts, conv_images in zip(prompts, images): img_iter = iter(conv_images) - new_conv: List[Dict[str, Any]] = [] + new_conv = [] for msg in conv_prompts: role = msg["role"] content = msg["content"] if isinstance(content, list): - new_parts: List[Dict[str, Any]] = [] + new_parts = [] for part in content: if part.get("type") == "image": img = next(img_iter) @@ -78,7 +78,7 @@ def __init__(self, sampling_args: Dict[str, Any] = {}, max_concurrent: int = 32, message_type: Literal['chat', 'completion'] = 'chat', - data_collator: Any | None = None, + data_collator: Callable | None = None, **kwargs: Any): self.client = client self.model = model @@ -177,7 +177,7 @@ def get_dataset(self, n: int = -1, seed: int = 0, **kwargs: Any) -> Dataset | No return self.dataset.shuffle(seed=seed).select(range(n)) # type: ignore return self.dataset - def get_eval_dataset(self, n: int = -1, seed: int = 0, **kwargs: Any) -> Dataset | None: + def get_eval_dataset(self, n: int = -1, seed: int = 0, **kwargs: Any) -> Dataset | dict[Any, list[Any]] | None: if n > 0 and self.eval_dataset is not None: return self.eval_dataset.shuffle(seed=seed).select(range(n)) # type: ignore return self.eval_dataset @@ -387,9 +387,9 @@ def process_chat_format( prompt: List[Dict[str, str]], images: Optional[List[List[Any]]], completion: List[Dict[str, str]], - processing_class: PreTrainedTokenizerBase, + processing_class: Any, mask_env_responses: bool = False - ) -> Tuple[List[int], List[int], List[int], List[int]]: + ) -> Tuple[List[int], List[int], List[int], List[int], dict[str, Any]]: """ Process chat format conversations using incremental prefixes. 
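
The incremental-prefix scheme named in this docstring can be exercised in isolation with an ordinary HF tokenizer. A sketch, where the model id and messages are placeholders and the prefix property it relies on is exactly what the asserts in this function guard:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
    chat = [
        {"role": "user", "content": "What is shown in the document?"},
        {"role": "assistant", "content": "<think>a header</think><answer>an invoice</answer>"},
    ]
    # Tokenize the prompt alone, then the prompt plus the assistant turn.
    prev_ids = tok.encode(tok.apply_chat_template(chat[:1], tokenize=False, add_generation_prompt=True))
    full_ids = tok.encode(tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))
    assert full_ids[: len(prev_ids)] == prev_ids  # prefix property
    completion_ids = full_ids[len(prev_ids):]     # tokens attributable to the assistant turn
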
@@ -405,7 +405,9 @@ def process_chat_format( completion_mask = [] remaining_inputs = {} if images: + assert not isinstance(processing_class, PreTrainedTokenizerBase) prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + assert isinstance(prompt_text, str) inputs = processing_class(text=prompt_text, images=images, return_tensors="pt") remaining_inputs = { k: v @@ -423,6 +425,7 @@ def process_chat_format( tokenize=False, add_generation_prompt=False, ) + assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}" current_ids = processing_class(text=prefix_text, images=images, return_tensors="pt").input_ids[0].tolist() assert current_ids[:len(prev_ids)] == prev_ids, "Tokenization difference in chat format." new_tokens = current_ids[len(prev_ids):] @@ -438,14 +441,11 @@ def process_chat_format( completion_mask.extend(msg_mask) prev_ids = current_ids else: + assert isinstance(processing_class, PreTrainedTokenizerBase) # tokenize just the prompt prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) assert isinstance(prompt_text, str) - if hasattr(processing_class, "tokenizer"): - encode = processing_class.tokenizer.encode - else: - encode = processing_class.encode - prompt_ids = encode(prompt_text) + prompt_ids = processing_class.encode(prompt_text) prompt_mask = [1] * len(prompt_ids) # track completion tokens and masks by processing incrementally @@ -467,7 +467,7 @@ def process_chat_format( add_generation_prompt=False, ) assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}" - current_ids = encode(prefix_text) + current_ids = processing_class.encode(prefix_text) assert current_ids[:len(prev_ids)] == prev_ids, f"Tokenization difference in chat format. 
Current ids: {current_ids}, previous ids: {prev_ids}" # add new tokens to completion tokens @@ -510,17 +510,13 @@ def process_completion_format( prompt_ids, prompt_mask, completion_ids, completion_mask """ # Tokenize prompt - if hasattr(processing_class, "tokenizer"): - encode = processing_class.tokenizer.encode - else: - encode = processing_class.encode - prompt_ids = encode(prompt) + prompt_ids = processing_class.encode(prompt) prompt_mask = [1] * len(prompt_ids) # Tokenize completion - completion_ids = encode(completion) + completion_ids = processing_class.encode(completion) completion_mask = [1] * len(completion_ids) - + return prompt_ids, prompt_mask, completion_ids, completion_mask def process_env_results( @@ -530,7 +526,7 @@ def process_env_results( completions: List[Union[str, List[Dict[str, Any]]]], states: List[Dict[str, Any]], rewards: List[float], - processing_class: PreTrainedTokenizerBase, + processing_class: Any, max_completion_length: int = -1, mask_truncated_completions: bool = False, mask_env_responses: bool = False, @@ -552,9 +548,9 @@ def process_env_results( all_completion_masks = [] all_remaining_inputs = [] - images = images or [None] * len(prompts) + input_images = images or [None] * len(prompts) - for i, (prompt, images, completion, state, reward) in enumerate(zip(prompts, images, completions, states, rewards)): + for i, (prompt, images, completion, state, reward) in enumerate(zip(prompts, input_images, completions, states, rewards)): # Format-specific processing if is_chat_format: assert isinstance(prompt, list) and isinstance(completion, list) @@ -566,6 +562,7 @@ def process_env_results( prompt_ids, prompt_mask, completion_ids, completion_mask = self.process_completion_format( prompt, completion, processing_class ) + remaining_inputs = [None] * len(prompt_ids) if mask_truncated_completions and max_completion_length > 0 and len(completion_ids) > max_completion_length: completion_ids = completion_ids[:max_completion_length] completion_mask = [0] * len(completion_ids) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index fee49670a..c4147757e 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -12,11 +12,10 @@ from torch.utils.data import DataLoader, Sampler from accelerate.utils import broadcast_object_list, gather_object, is_peft_model from peft import PeftConfig, get_peft_model -from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.processing_utils import ProcessorMixin from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import seed_worker @@ -152,7 +151,7 @@ def __init__( model: PreTrainedModel, env: Environment, args: GRPOConfig, - processing_class: PreTrainedTokenizerBase, + processing_class: ProcessorMixin, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional[PeftConfig] = None, @@ -171,9 +170,10 @@ def __init__( # Suppress irrelevant warning model.warnings_issued["estimate_tokens"] = True + self.tokenizer_base = getattr(processing_class, "tokenizer", None) # extract tokenizer in multimodal case # Tokenizer pad token - if 
hasattr(processing_class, "tokenizer"): - if processing_class.tokenizer.pad_token is None: + if self.tokenizer_base is not None: + if processing_class.tokenizer.pad_token is None: # type: ignore processing_class.tokenizer.pad_token = processing_class.tokenizer.eos_token # type: ignore else: if processing_class.pad_token is None: # type: ignore @@ -216,8 +216,6 @@ def __init__( train_dataset = env.get_dataset() assert train_dataset is not None - eval_dataset = env.get_eval_dataset() - # Filter out prompts that are too long if max_prompt_length is set if self.max_prompt_length is not None: self.logger.info(f"Filtering dataset for prompts with length <= {self.max_prompt_length}") @@ -232,9 +230,10 @@ def filter_by_prompt_length(example): else: # Completion format prompt_text = prompt - if hasattr(processing_class, "tokenizer"): - encode = processing_class.tokenizer.encode + if self.tokenizer_base is not None: + encode = self.tokenizer_base.encode else: + assert isinstance(processing_class, PreTrainedTokenizerBase) encode = processing_class.encode prompt_ids = encode(prompt_text) # type: ignore return len(prompt_ids) <= max_length @@ -245,16 +244,15 @@ def filter_by_prompt_length(example): if filtered_size < original_size: self.logger.info(f"Filtered dataset from {original_size} to {filtered_size} examples ({original_size - filtered_size} prompts were too long)") - self.data_collator = env.data_collator # dummy data collator def default_data_collator(features): return features super().__init__( model=model, args=args, - data_collator=self.data_collator if self.data_collator is not None else default_data_collator, + data_collator=env.data_collator if env.data_collator is not None else default_data_collator, train_dataset=train_dataset, - eval_dataset=eval_dataset, + eval_dataset=datasets.Dataset.from_dict({}), # dummy eval ds. 
This is actually handled by the environment
             processing_class=processing_class,
             callbacks=callbacks,
             optimizers=optimizers,
@@ -766,9 +764,10 @@ def _prepare_inputs(  # type: ignore
             completion_mask_list.append(torch.tensor(broadcast_data['completion_mask'][i], device=self.accelerator.device))
 
         # Pad sequences
-        if hasattr(self.processing_class, "tokenizer"):
-            pad_token_id = self.processing_class.tokenizer.pad_token_id
+        if self.tokenizer_base is not None:
+            pad_token_id = self.tokenizer_base.pad_token_id
         else:
+            assert isinstance(self.processing_class, PreTrainedTokenizerBase)
             pad_token_id = self.processing_class.pad_token_id
         prompt_ids = pad(prompt_ids_list, padding_value=pad_token_id, padding_side='left')  # type: ignore
         prompt_mask = pad(prompt_mask_list, padding_side='left')  # type: ignore
@@ -856,7 +855,7 @@ def _compute_advantages(
 
     def compute_loss(self,
                      model: PreTrainedModel,
-                     inputs: Dict[str, torch.Tensor],
+                     inputs: Dict[str, Any],
                      return_outputs: bool = False,
                      num_items_in_batch: int | None = None) -> torch.Tensor:
         mode = "train"
@@ -976,9 +975,10 @@ def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"
         completions = eval_results['completion']
         if isinstance(completions[0], str):
             # Completion format - directly tokenize strings
-            if hasattr(self.processing_class, "tokenizer"):
-                encode = self.processing_class.tokenizer.encode
+            if self.tokenizer_base is not None:
+                encode = self.tokenizer_base.encode
             else:
+                assert isinstance(self.processing_class, PreTrainedTokenizerBase)
                 encode = self.processing_class.encode
             completion_lengths = [len(encode(c)) for c in completions]  # type: ignore
         else:
@@ -1175,9 +1175,10 @@ def _log_completion_metrics_primary(
 
         # Check for EOS tokens
         term_lengths = []
-        if hasattr(self.processing_class, "tokenizer"):
-            eos_token_id = self.processing_class.tokenizer.eos_token_id
+        if self.tokenizer_base is not None:
+            eos_token_id = self.tokenizer_base.eos_token_id
         else:
+            assert isinstance(self.processing_class, PreTrainedTokenizerBase)
             eos_token_id = self.processing_class.eos_token_id
         for comp_ids, comp_mask in zip(all_completion_ids, all_completion_mask):
             has_eos = any(token == eos_token_id for token, mask in zip(comp_ids, comp_mask) if mask)  # type: ignore
diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py
index c8360595e..21894439a 100644
--- a/verifiers/utils/model_utils.py
+++ b/verifiers/utils/model_utils.py
@@ -4,7 +4,8 @@
 from typing import Dict, Any, Union, Tuple, Callable
 
 import torch
-from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig, PreTrainedModel  # type: ignore
+from transformers import AutoModelForCausalLM, AutoModel, AutoProcessor, AutoConfig, PreTrainedModel  # type: ignore
+from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM, AutoModelForVision2Seq
 import torch.nn as nn
 
@@ -74,13 +75,6 @@ def generic_model_loader(model_id: str, **model_kwargs) -> PreTrainedModel:
         except (AttributeError, ImportError, ValueError):
             pass
 
-    from transformers import (
-        AutoModel,
-        AutoModelForCausalLM,
-        AutoModelForSeq2SeqLM,
-        AutoModelForVision2Seq,
-    )
-
     for auto_cls in (
         AutoModelForCausalLM,
         AutoModelForSeq2SeqLM,

From cc69cc6e739b189ea87c62d1106d36f3df15ac59 Mon Sep 17 00:00:00 2001
From: nph4rd
Date: Wed, 18 Jun 2025 03:51:07 +0000
Subject: [PATCH 47/48] raise error

---
 verifiers/envs/environment.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index bb72599a8..d2ee01681 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -558,6 +558,8 @@ def process_env_results(
                     prompt, images, completion, processing_class, mask_env_responses
                 )
             else:
+                if images is not None:
+                    raise NotImplementedError("Multi-modal training is not supported with completion formats yet")
                 assert isinstance(prompt, str) and isinstance(completion, str)
                 prompt_ids, prompt_mask, completion_ids, completion_mask = self.process_completion_format(
                     prompt, completion, processing_class
                 )

From d41d00f7061890514f4375923022639b414e4f08 Mon Sep 17 00:00:00 2001
From: nph4rd
Date: Wed, 18 Jun 2025 22:30:01 +0000
Subject: [PATCH 48/48] update example

---
 verifiers/examples/docvqa.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py
index a91cef04c..6b4c8da29 100644
--- a/verifiers/examples/docvqa.py
+++ b/verifiers/examples/docvqa.py
@@ -9,9 +9,9 @@
 # install qwen stuff
 uv pip install qwen-vl-utils
 # inference
-CUDA_VISIBLE_DEVICES=0 uv run vf-vllm --model 'Qwen/Qwen2.5-VL-3B-Instruct' --max-model-len 64000
+CUDA_VISIBLE_DEVICES=0,1,2,3 vf-vllm --model 'Qwen/Qwen2.5-VL-7B-Instruct' --max-model-len 32000 --tensor_parallel_size 4
 # train
-CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py
+CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config-file configs/zero3.yaml --num-processes 4 verifiers/examples/docvqa.py
 """
 
@@ -108,7 +108,7 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None:
     data_collator=data_collator,
 )
 
-model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
+model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
 model, processor = vf.get_model_and_tokenizer(model_name)
 run_name = "docvqa_" + model_name.split("/")[-1].lower()
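
Before launching training, the serving side can be smoke-tested with the same data-URL message format the environment emits. A sketch, assuming vf-vllm exposes vLLM's OpenAI-compatible API on localhost:8000; adjust host, port, and image to your setup:

    import base64
    import io

    from openai import OpenAI
    from PIL import Image

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    img = Image.new("RGB", (1024, 768), "white")  # stand-in for a DocVQA page image
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    resp = client.chat.completions.create(
        model="Qwen/Qwen2.5-VL-7B-Instruct",
        messages=[{"role": "user", "content": [
            {"type": "text", "text": "What is the document title?"},
            {"type": "image_url", "image_url": {"url": data_url}},
        ]}],
    )
    print(resp.choices[0].message.content)
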