diff --git a/mttl/datamodule/utils.py b/mttl/datamodule/utils.py index 191fab40..4d9c58ed 100644 --- a/mttl/datamodule/utils.py +++ b/mttl/datamodule/utils.py @@ -8,7 +8,11 @@ def maybe_filter_hf_dataset_by_task( - dataset, task_field, task_names: str = None, n_proc=16 + dataset, + task_field, + task_names: str = None, + n_proc=16, + should_split_on_split_column=True, ): """Filter a HuggingFace dataset by task names.""" @@ -48,7 +52,9 @@ def maybe_filter_hf_dataset_by_task( dev_dataset is None and test_dataset is None and "split" in train_dataset.features + and should_split_on_split_column ): + logger.info("Splitting train dataset on 'split' column.") train_dataset, dev_dataset, test_dataset = split_on_split_column( train_dataset, num_proc=n_proc ) diff --git a/projects/kms/train_km_simple.py b/projects/kms/train_km_simple.py index 96e13da0..3a40e2d9 100644 --- a/projects/kms/train_km_simple.py +++ b/projects/kms/train_km_simple.py @@ -332,4 +332,7 @@ def train_km(training_args: KMArguments): logger.info("Model already trained, skipping") exit(0) + # The configs still contain pointers to internal az://mttldata paths, replace them + args.dataset = args.dataset.replace("az://mttldata", BASE_PREFIX) + train_km(args) diff --git a/projects/kms/utils/km_datamodule.py b/projects/kms/utils/km_datamodule.py index 50711fb6..b132bef1 100644 --- a/projects/kms/utils/km_datamodule.py +++ b/projects/kms/utils/km_datamodule.py @@ -350,6 +350,7 @@ def setup_dataset(self): assert len(dataset) == 1, "all dataset should be in `train`" # Let's first filter out unused tasks + # NOTE: do not split on split column here, we will do custom train / dev split later ( self._task_names, self._task_to_id, @@ -360,6 +361,7 @@ def setup_dataset(self): dataset, self.config.task_name_field, self.config.finetune_task_name, + should_split_on_split_column=False, n_proc=n_proc, ) diff --git a/projects/kms/utils/longhealth_datamodule.py b/projects/kms/utils/longhealth_datamodule.py index aa7e75c7..c6bd2b7a 100644 --- a/projects/kms/utils/longhealth_datamodule.py +++ b/projects/kms/utils/longhealth_datamodule.py @@ -86,7 +86,10 @@ def setup_dataset(self): _, _, ) = maybe_filter_hf_dataset_by_task( - dataset, self.config.task_name_field, self.config.finetune_task_name + dataset, + self.config.task_name_field, + self.config.finetune_task_name, + should_split_on_split_column=False, ) # Let's make sure that the full prompt is always in context @@ -186,6 +189,8 @@ def expand_questions(examples, tokenizer, len_template): if self.tokenizer.chat_template is None: self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"] + # TODO: refactor code to leverage `split_on_split_column` in + # `maybe_filter_hf_dataset_by_task` if "split" in train_dataset.features: self.train_dataset, self.dev_dataset, self.test_dataset = ( split_on_split_column(train_dataset) diff --git a/projects/kms/utils/nqa_datamodule.py b/projects/kms/utils/nqa_datamodule.py index 56a6e1dc..5678ee0a 100644 --- a/projects/kms/utils/nqa_datamodule.py +++ b/projects/kms/utils/nqa_datamodule.py @@ -35,15 +35,14 @@ def setup_dataset(self): ( self._task_names, self._task_to_id, - train_dataset, - _, - _, + self.train_dataset, + self.dev_dataset, + self.test_dataset, ) = maybe_filter_hf_dataset_by_task( - dataset, self.config.task_name_field, self.config.finetune_task_name - ) - - self.train_dataset, self.dev_dataset, self.test_dataset = split_on_split_column( - train_dataset + dataset, + self.config.task_name_field, + self.config.finetune_task_name, + should_split_on_split_column=False, ) def expand_questions(examples, tokenizer): diff --git a/projects/kms/utils/pit_datamodule.py b/projects/kms/utils/pit_datamodule.py index 2c83d21c..82c7688d 100644 --- a/projects/kms/utils/pit_datamodule.py +++ b/projects/kms/utils/pit_datamodule.py @@ -129,6 +129,7 @@ def expand_targets_and_chat(example): self.config.task_name_field, self.config.finetune_task_name, n_proc=n_proc, + should_split_on_split_column=False, ) train_dataset = train_dataset.map( diff --git a/projects/kms/utils/quality_datamodule.py b/projects/kms/utils/quality_datamodule.py index d9f5fbbf..c06fff99 100644 --- a/projects/kms/utils/quality_datamodule.py +++ b/projects/kms/utils/quality_datamodule.py @@ -35,9 +35,9 @@ def setup_dataset(self): ( self._task_names, self._task_to_id, - train_dataset, - _, - _, + self.train_dataset, + self.dev_dataset, + self.test_dataset, ) = maybe_filter_hf_dataset_by_task( dataset, self.config.task_name_field, self.config.finetune_task_name ) @@ -124,243 +124,22 @@ def expand_questions(examples, tokenizer): if self.tokenizer.chat_template is None: self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"] - if "split" in train_dataset.features: - self.train_dataset, self.dev_dataset, self.test_dataset = ( - split_on_split_column(train_dataset) - ) - self.train_dataset = self.train_dataset.map( - lambda examples: expand_questions(examples, self.tokenizer), - batched=True, - batch_size=1000, - num_proc=1, - remove_columns=train_dataset.column_names, - ) + self.train_dataset = self.train_dataset.map( + lambda examples: expand_questions(examples, self.tokenizer), + batched=True, + batch_size=1000, + num_proc=1, + remove_columns=self.train_dataset.column_names, + ) + if self.dev_dataset: self.dev_dataset = self.dev_dataset.map( lambda examples: expand_questions(examples, self.tokenizer), batched=True, batch_size=1000, num_proc=1, - remove_columns=train_dataset.column_names, + remove_columns=self.dev_dataset.column_names, ) - self.test_dataset = self.dev_dataset else: - train_dataset = train_dataset.map( - lambda examples: expand_questions(examples, self.tokenizer), - batched=True, - batch_size=1000, - num_proc=1, - remove_columns=train_dataset.column_names, - ) - self.train_dataset = self.dev_dataset = self.test_dataset = train_dataset - - -prompt_template_w_docs = """ ---------------BEGIN CONTEXT-------------- - -{documents} - ---------------END CONTEXT-------------- - -{question_text} -{options} - -Please answer using the following format: -0. Begin your answer with the phrase "The correct answer is". -1. State the letter of the correct option (e.g., A, B, C, D). -2. Follow the letter with a colon and the exact text of the option you chose. -3. Make sure your answer is a single, concise sentence. - -For example, if the correct answer to a question is option C, and the text for C is 'Acute Bronchitis', your answer should be: -'The correct answer is C: Acute bronchitis.' -""" - -prompt_template_no_docs = """ -{question_text} -{options} - -Please answer using the following format: -1. Begin your answer with the phrase "The correct answer is". -2. State the letter of the correct option (e.g., A, B, C, D). -3. Follow the letter with a colon and the exact text of the option you chose. -4. Make sure your answer is a single, concise sentence. - -For example, if the correct answer to a question is option C, and the text for C is 'Acute Bronchitis', your answer should be: -'The correct answer is C: Acute bronchitis.' -""" - -max_new_tokens = 50 - - -@dataclass -class GenQualityDatasetConfig(DatasetConfig): - task_name_field: str = "document_id" - task_source_field: str = "document_id" - prompt: str = ( - "Answer the following question. Give only the answer, and no extra commentary, formatting, or chattiness. Question: " - ) - include_context: bool = False - topk_context: int = 10 - include_all_answers: bool = True - - -@DataModule.register("gen_quality", config_cls=GenQualityDatasetConfig) -class GenQualityDataModule(DataModule): - def setup_dataset(self): - from mttl.models.library.dataset_library import DatasetLibrary - - dataset = DatasetLibrary.pull_dataset(self.config.dataset) - - # Instead of always working with the large datasets, we can subsample it - if self.config.custom_split_file: - dataset = apply_custom_split_file(dataset, self.config.custom_split_file) - - ( - self._task_names, - self._task_to_id, - train_dataset, - _, - _, - ) = maybe_filter_hf_dataset_by_task( - dataset, self.config.task_name_field, self.config.finetune_task_name - ) - - # Let's make sure that the full prompt is always in context - len_template = len(self.tokenizer.encode(prompt_template_w_docs)) - - def expand_questions(examples, tokenizer, len_template): - batch = { - "source": [], - "target": [], - "document_id": [], - } + self.dev_dataset = self.train_dataset - for i in range(len(examples["document_id"])): - for j in range(len(examples["questions"][i])): - document_id = examples["document_id"][i] - question = examples["questions"][i][j] - options = examples["options"][i][j] - gold_label = examples["gold_label"][i][j] - if gold_label == -1: - gold_label = label_index = None - else: - label_index = gold_label - 1 - - """ NEW """ - letters = ["A", "B", "C", "D"] - option_str = "\n".join( - [f"{letters[i]}: {option}" for i, option in enumerate(options)] - ) - len_question = len(tokenizer.encode(question)) - len_options = len(tokenizer.encode(option_str)) - len_suffix = len(tokenizer.encode("The correct answer is: ")) - - total_len = len_question + len_options + len_template + len_suffix - - if self.config.include_context: - context = examples["text"][i] - - if isinstance(context, list): - # following Alan's approach - context = " ".join( - [ - f"Passage {k+1}: {context[k]}\n\n" - for k in range( - min(self.config.topk_context, len(context)) - )[::-1] - ] - ) - assert ( - type(context) == str - ), f"Context should be a string, but got {type(context)}" - - # Let's do some rough trucation if needed - context_ids = tokenizer.encode(context) - len_context = len(context_ids) - space_left = self.config.max_input_length - total_len - - if space_left < len_context: - context_ids = context_ids[: max(0, space_left - 20)] - context = tokenizer.decode( - context_ids, skip_special_tokens=True - ) - - prompt = prompt_template_w_docs.format( - documents=context, - question_text=question, - options=option_str, - ) - else: - prompt = prompt_template_no_docs.format( - question_text=question, - options=option_str, - ) - - """ - source = [ - { - "role": "system", - "content": sys_prompt, - }, - { - "role": "user", - "content": prompt, - }, - ] - """ - source = [ - { - "role": "user", - "content": prompt, - } - ] - - batch["source"].append( - tokenizer.apply_chat_template( - source, add_generation_prompt=True, tokenize=False - ) - + "The correct answer is" - ) - batch["target"].append( - letters[label_index] - ) # [options[label_index]]) - batch["document_id"].append(examples["document_id"][i]) - - return batch - - if self.tokenizer.chat_template is None: - self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"] - - if "split" in train_dataset.features: - self.train_dataset, self.dev_dataset, self.test_dataset = ( - split_on_split_column(train_dataset) - ) - self.train_dataset = self.train_dataset.map( - lambda examples: expand_questions( - examples, self.tokenizer, len_template - ), - batched=True, - batch_size=1000, - num_proc=1, - remove_columns=train_dataset.column_names, - ) - self.dev_dataset = self.dev_dataset.map( - lambda examples: expand_questions( - examples, self.tokenizer, len_template - ), - batched=True, - batch_size=1000, - num_proc=1, - remove_columns=train_dataset.column_names, - ) - self.test_dataset = self.dev_dataset - else: - train_dataset = train_dataset.map( - lambda examples: expand_questions( - examples, self.tokenizer, len_template - ), - batched=True, - batch_size=1000, - num_proc=1, - remove_columns=train_dataset.column_names, - ) - self.train_dataset = self.dev_dataset = self.test_dataset = train_dataset + self.test_dataset = self.dev_dataset diff --git a/projects/kms/utils/quality_evaluator.py b/projects/kms/utils/quality_evaluator.py index 0851ca3e..e55ef234 100644 --- a/projects/kms/utils/quality_evaluator.py +++ b/projects/kms/utils/quality_evaluator.py @@ -15,8 +15,6 @@ from mttl.logging import logger, warn_once from projects.kms.utils.nqa_datamodule import NQADatamodule, NQADatasetConfig from projects.kms.utils.quality_datamodule import ( - GenQualityDataModule, - GenQualityDatasetConfig, QualityDatamodule, QualityDatasetConfig, )