Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mttl/datamodule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@


def maybe_filter_hf_dataset_by_task(
dataset, task_field, task_names: str = None, n_proc=16
dataset,
task_field,
task_names: str = None,
n_proc=16,
should_split_on_split_column=True,
):
"""Filter a HuggingFace dataset by task names."""

Expand Down Expand Up @@ -48,7 +52,9 @@ def maybe_filter_hf_dataset_by_task(
dev_dataset is None
and test_dataset is None
and "split" in train_dataset.features
and should_split_on_split_column
):
logger.info("Splitting train dataset on 'split' column.")
train_dataset, dev_dataset, test_dataset = split_on_split_column(
train_dataset, num_proc=n_proc
)
Expand Down
3 changes: 3 additions & 0 deletions projects/kms/train_km_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,7 @@ def train_km(training_args: KMArguments):
logger.info("Model already trained, skipping")
exit(0)

# The configs still contain pointers to internal az://mttldata paths, replace them
args.dataset = args.dataset.replace("az://mttldata", BASE_PREFIX)

train_km(args)
2 changes: 2 additions & 0 deletions projects/kms/utils/km_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def setup_dataset(self):
assert len(dataset) == 1, "all dataset should be in `train`"

# Let's first filter out unused tasks
# NOTE: do not split on split column here, we will do custom train / dev split later
(
self._task_names,
self._task_to_id,
Expand All @@ -360,6 +361,7 @@ def setup_dataset(self):
dataset,
self.config.task_name_field,
self.config.finetune_task_name,
should_split_on_split_column=False,
n_proc=n_proc,
)

Expand Down
7 changes: 6 additions & 1 deletion projects/kms/utils/longhealth_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def setup_dataset(self):
_,
_,
) = maybe_filter_hf_dataset_by_task(
dataset, self.config.task_name_field, self.config.finetune_task_name
dataset,
self.config.task_name_field,
self.config.finetune_task_name,
should_split_on_split_column=False,
)

# Let's make sure that the full prompt is always in context
Expand Down Expand Up @@ -186,6 +189,8 @@ def expand_questions(examples, tokenizer, len_template):
if self.tokenizer.chat_template is None:
self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"]

# TODO: refactor code to leverage `split_on_split_column` in
# `maybe_filter_hf_dataset_by_task`
if "split" in train_dataset.features:
self.train_dataset, self.dev_dataset, self.test_dataset = (
split_on_split_column(train_dataset)
Expand Down
15 changes: 7 additions & 8 deletions projects/kms/utils/nqa_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ def setup_dataset(self):
(
self._task_names,
self._task_to_id,
train_dataset,
_,
_,
self.train_dataset,
self.dev_dataset,
self.test_dataset,
) = maybe_filter_hf_dataset_by_task(
dataset, self.config.task_name_field, self.config.finetune_task_name
)

self.train_dataset, self.dev_dataset, self.test_dataset = split_on_split_column(
train_dataset
dataset,
self.config.task_name_field,
self.config.finetune_task_name,
should_split_on_split_column=False,
)

def expand_questions(examples, tokenizer):
Expand Down
1 change: 1 addition & 0 deletions projects/kms/utils/pit_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def expand_targets_and_chat(example):
self.config.task_name_field,
self.config.finetune_task_name,
n_proc=n_proc,
should_split_on_split_column=False,
)

train_dataset = train_dataset.map(
Expand Down
249 changes: 14 additions & 235 deletions projects/kms/utils/quality_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def setup_dataset(self):
(
self._task_names,
self._task_to_id,
train_dataset,
_,
_,
self.train_dataset,
self.dev_dataset,
self.test_dataset,
) = maybe_filter_hf_dataset_by_task(
dataset, self.config.task_name_field, self.config.finetune_task_name
)
Expand Down Expand Up @@ -124,243 +124,22 @@ def expand_questions(examples, tokenizer):
if self.tokenizer.chat_template is None:
self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"]

if "split" in train_dataset.features:
self.train_dataset, self.dev_dataset, self.test_dataset = (
split_on_split_column(train_dataset)
)
self.train_dataset = self.train_dataset.map(
lambda examples: expand_questions(examples, self.tokenizer),
batched=True,
batch_size=1000,
num_proc=1,
remove_columns=train_dataset.column_names,
)
self.train_dataset = self.train_dataset.map(
lambda examples: expand_questions(examples, self.tokenizer),
batched=True,
batch_size=1000,
num_proc=1,
remove_columns=self.train_dataset.column_names,
)
if self.dev_dataset:
self.dev_dataset = self.dev_dataset.map(
lambda examples: expand_questions(examples, self.tokenizer),
batched=True,
batch_size=1000,
num_proc=1,
remove_columns=train_dataset.column_names,
remove_columns=self.dev_dataset.column_names,
)
self.test_dataset = self.dev_dataset
else:
train_dataset = train_dataset.map(
lambda examples: expand_questions(examples, self.tokenizer),
batched=True,
batch_size=1000,
num_proc=1,
remove_columns=train_dataset.column_names,
)
self.train_dataset = self.dev_dataset = self.test_dataset = train_dataset


prompt_template_w_docs = """
--------------BEGIN CONTEXT--------------

{documents}

--------------END CONTEXT--------------

{question_text}
{options}

Please answer using the following format:
0. Begin your answer with the phrase "The correct answer is".
1. State the letter of the correct option (e.g., A, B, C, D).
2. Follow the letter with a colon and the exact text of the option you chose.
3. Make sure your answer is a single, concise sentence.

For example, if the correct answer to a question is option C, and the text for C is 'Acute Bronchitis', your answer should be:
'The correct answer is C: Acute bronchitis.'
"""

prompt_template_no_docs = """
{question_text}
{options}

Please answer using the following format:
1. Begin your answer with the phrase "The correct answer is".
2. State the letter of the correct option (e.g., A, B, C, D).
3. Follow the letter with a colon and the exact text of the option you chose.
4. Make sure your answer is a single, concise sentence.

For example, if the correct answer to a question is option C, and the text for C is 'Acute Bronchitis', your answer should be:
'The correct answer is C: Acute bronchitis.'
"""

max_new_tokens = 50


@dataclass
class GenQualityDatasetConfig(DatasetConfig):
task_name_field: str = "document_id"
task_source_field: str = "document_id"
prompt: str = (
"Answer the following question. Give only the answer, and no extra commentary, formatting, or chattiness. Question: "
)
include_context: bool = False
topk_context: int = 10
include_all_answers: bool = True


@DataModule.register("gen_quality", config_cls=GenQualityDatasetConfig)
class GenQualityDataModule(DataModule):
def setup_dataset(self):
from mttl.models.library.dataset_library import DatasetLibrary

dataset = DatasetLibrary.pull_dataset(self.config.dataset)

# Instead of always working with the large datasets, we can subsample it
if self.config.custom_split_file:
dataset = apply_custom_split_file(dataset, self.config.custom_split_file)

(
self._task_names,
self._task_to_id,
train_dataset,
_,
_,
) = maybe_filter_hf_dataset_by_task(
dataset, self.config.task_name_field, self.config.finetune_task_name
)

# Let's make sure that the full prompt is always in context
len_template = len(self.tokenizer.encode(prompt_template_w_docs))

def expand_questions(examples, tokenizer, len_template):
batch = {
"source": [],
"target": [],
"document_id": [],
}
self.dev_dataset = self.train_dataset

for i in range(len(examples["document_id"])):
for j in range(len(examples["questions"][i])):
document_id = examples["document_id"][i]
question = examples["questions"][i][j]
options = examples["options"][i][j]
gold_label = examples["gold_label"][i][j]
if gold_label == -1:
gold_label = label_index = None
else:
label_index = gold_label - 1

""" NEW """
letters = ["A", "B", "C", "D"]
option_str = "\n".join(
[f"{letters[i]}: {option}" for i, option in enumerate(options)]
)
len_question = len(tokenizer.encode(question))
len_options = len(tokenizer.encode(option_str))
len_suffix = len(tokenizer.encode("The correct answer is: "))

total_len = len_question + len_options + len_template + len_suffix

if self.config.include_context:
context = examples["text"][i]

if isinstance(context, list):
# following Alan's approach
context = " ".join(
[
f"Passage {k+1}: {context[k]}\n\n"
for k in range(
min(self.config.topk_context, len(context))
)[::-1]
]
)
assert (
type(context) == str
), f"Context should be a string, but got {type(context)}"

# Let's do some rough trucation if needed
context_ids = tokenizer.encode(context)
len_context = len(context_ids)
space_left = self.config.max_input_length - total_len

if space_left < len_context:
context_ids = context_ids[: max(0, space_left - 20)]
context = tokenizer.decode(
context_ids, skip_special_tokens=True
)

prompt = prompt_template_w_docs.format(
documents=context,
question_text=question,
options=option_str,
)
else:
prompt = prompt_template_no_docs.format(
question_text=question,
options=option_str,
)

"""
source = [
{
"role": "system",
"content": sys_prompt,
},
{
"role": "user",
"content": prompt,
},
]
"""
source = [
{
"role": "user",
"content": prompt,
}
]

batch["source"].append(
tokenizer.apply_chat_template(
source, add_generation_prompt=True, tokenize=False
)
+ "The correct answer is"
)
batch["target"].append(
letters[label_index]
) # [options[label_index]])
batch["document_id"].append(examples["document_id"][i])

return batch

if self.tokenizer.chat_template is None:
self.tokenizer.apply_chat_template = lambda x, **kwargs: x[0]["content"]

if "split" in train_dataset.features:
self.train_dataset, self.dev_dataset, self.test_dataset = (
split_on_split_column(train_dataset)
)
self.train_dataset = self.train_dataset.map(
lambda examples: expand_questions(
examples, self.tokenizer, len_template
),
batched=True,
batch_size=1000,
num_proc=1,
remove_columns=train_dataset.column_names,
)
self.dev_dataset = self.dev_dataset.map(
lambda examples: expand_questions(
examples, self.tokenizer, len_template
),
batched=True,
batch_size=1000,
num_proc=1,
remove_columns=train_dataset.column_names,
)
self.test_dataset = self.dev_dataset
else:
train_dataset = train_dataset.map(
lambda examples: expand_questions(
examples, self.tokenizer, len_template
),
batched=True,
batch_size=1000,
num_proc=1,
remove_columns=train_dataset.column_names,
)
self.train_dataset = self.dev_dataset = self.test_dataset = train_dataset
self.test_dataset = self.dev_dataset
2 changes: 0 additions & 2 deletions projects/kms/utils/quality_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from mttl.logging import logger, warn_once
from projects.kms.utils.nqa_datamodule import NQADatamodule, NQADatasetConfig
from projects.kms.utils.quality_datamodule import (
GenQualityDataModule,
GenQualityDatasetConfig,
QualityDatamodule,
QualityDatasetConfig,
)
Expand Down
Loading