From f2ecbab594a0fe42ff521cff15f5c1c1ea031203 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Thu, 17 Oct 2024 08:52:10 -0400
Subject: [PATCH 1/3] Mechanism to insert BOS/EOS at the beginning/end of dialog
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko
---
 nemo/collections/common/prompts/formatter.py | 19 ++++++++++++++++++-
 nemo/collections/common/prompts/gemma.py     |  2 ++
 nemo/collections/common/prompts/llama.py     |  2 +-
 3 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py
index 6d2c67f5311d..4b58135307c1 100644
--- a/nemo/collections/common/prompts/formatter.py
+++ b/nemo/collections/common/prompts/formatter.py
@@ -1,5 +1,4 @@
 from abc import ABC
-from enum import Enum
 from functools import lru_cache
 from typing import Any, Type

@@ -142,6 +141,14 @@ class PromptFormatter(ABC):
     # PromptFormatter.encode_dialog() ends with this role, it indicates a training example.
     OUTPUT_ROLE = None

+    # When set to true, we will insert BOS/EOS symbol at the very beginning/end of the dialog
+    # (i.e., not before/after every turn).
+    # We query self.tokenizer.bos / self.tokenizer.eos to get their int IDs.
+    # Note that this is a separate mechanism from BOS_SLOT / EOS_SLOT which allows inserting
+    # these tokens at arbitrary positions in arbitrary turns.
+    INSERT_BOS = False
+    INSERT_EOS = False
+
     # Internal reserved field.
     _REGISTERED_FORMATTERS = {}

@@ -241,6 +248,11 @@ def encode_dialog(self, turns: list[dict]) -> dict[str, torch.Tensor]:
         turn_token_counts = []
         turn_mask_values = []

+        if self.INSERT_BOS:
+            turn_tokens.append(self.tokenizer.bos)
+            turn_token_counts.append(1)
+            turn_mask_values.append(False)
+
         if "preamble" in self.TEMPLATE:
             preamble_turns = [idx for idx, t in enumerate(turns) if t["role"] == "preamble"]
             if not preamble_turns:
@@ -267,6 +279,11 @@ def encode_dialog(self, turns: list[dict]) -> dict[str, torch.Tensor]:
             turn_token_counts.append(len(tokens))
             turn_mask_values.append(role == self.OUTPUT_ROLE)

+        if self.INSERT_EOS:
+            turn_tokens.append(self.tokenizer.eos)
+            turn_token_counts.append(1)
+            turn_mask_values.append(False)
+
         ans = {"input_ids": torch.tensor(turn_tokens, dtype=torch.long)}
         if turn_mask_values[-1]:
             # The last turn comes from OUTPUT_ROLE, i.e. it's a response from the system.
diff --git a/nemo/collections/common/prompts/gemma.py b/nemo/collections/common/prompts/gemma.py
index 128a5689e07f..2570995625ee 100644
--- a/nemo/collections/common/prompts/gemma.py
+++ b/nemo/collections/common/prompts/gemma.py
@@ -19,6 +19,8 @@ class GemmaPromptFormatter(PromptFormatter):
     NAME = "gemma"
     OUTPUT_ROLE = "assistant"
+    INSERT_BOS = True
+    INSERT_EOS = True
     TEMPLATE = {
         "user": {
             "template": f"{GEMMA_BOS}user\n|message|{GEMMA_END_OF_TURN}\n{GEMMA_BOS}model\n",
diff --git a/nemo/collections/common/prompts/llama.py b/nemo/collections/common/prompts/llama.py
index 7b2e1fe1d758..7defc49cc61a 100644
--- a/nemo/collections/common/prompts/llama.py
+++ b/nemo/collections/common/prompts/llama.py
@@ -26,7 +26,7 @@ class Llama2PromptFormatter(PromptFormatter):
             },
         },
         "user": {
-            "template": "|bos|[INST] |message| [/INST]",
+            "template": f"{BOS_SLOT}[INST] |message| [/INST]",
             "slots": {
                 "message": Modality.Text,
             },

From b2970935895dbc8eae6449937c111f48ec6c5419 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Thu, 17 Oct 2024 09:38:52 -0400
Subject: [PATCH 2/3] Fix Gemma prompt formatter test
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko
---
 nemo/collections/common/prompts/formatter.py |  7 ++++---
 .../test_gemma_prompt_formatter.py           | 14 ++++++++------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py
index 4b58135307c1..76165b65fbce 100644
--- a/nemo/collections/common/prompts/formatter.py
+++ b/nemo/collections/common/prompts/formatter.py
@@ -279,10 +279,11 @@ def encode_dialog(self, turns: list[dict]) -> dict[str, torch.Tensor]:
             turn_token_counts.append(len(tokens))
             turn_mask_values.append(role == self.OUTPUT_ROLE)

-        if self.INSERT_EOS:
+        # Insert EOS only when the last turn comes from the OUTPUT_ROLE.
+        if self.INSERT_EOS and turns[-1]["role"] == self.OUTPUT_ROLE:
             turn_tokens.append(self.tokenizer.eos)
-            turn_token_counts.append(1)
-            turn_mask_values.append(False)
+            turn_token_counts[-1] += 1
+            turn_mask_values.append(True)

         ans = {"input_ids": torch.tensor(turn_tokens, dtype=torch.long)}
         if turn_mask_values[-1]:
diff --git a/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py b/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py
index be1f6de1a873..ea6454273e1a 100644
--- a/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py
+++ b/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py
@@ -11,16 +11,18 @@ def test_gemma_prompt_formatter_training(bpe_tokenizer):
     )
     assert set(ans) == {"input_ids", "context_ids", "answer_ids", "mask"}
     # fmt: off
-    assert ans["input_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20,
+    # Note: The BPE tokenizer fixture in our test doesn't have BOS/EOS defined which is why the tokenizer
+    # returns an ID of -1 for these tokens.
+ assert ans["input_ids"].tolist() == [-1, 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20, 30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26, 18, 6, 60, 9, 7, 73, 61, 69, 1, 81, 20, 30, 104, 59, - 18, 26, 18, 6, 60, 9, 7] - assert ans["context_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20, + 18, 26, 18, 6, 60, 9, 7, -1] + assert ans["context_ids"].tolist() == [-1, 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20, 30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26, 18, 6, 60, 9, 7, 73, 61, 69] assert ans["answer_ids"].tolist() == [1, 81, 20, 30, 104, 59, - 18, 26, 18, 6, 60, 9, 7] - assert ans["mask"].tolist() == [False] * 36 + [True] * 13 + 18, 26, 18, 6, 60, 9, 7, -1] + assert ans["mask"].tolist() == [False] * 37 + [True] * 14 # fmt: on @@ -34,7 +36,7 @@ def test_gemma_prompt_formatter_inference(bpe_tokenizer): assert set(ans) == {"input_ids", "context_ids"} # fmt: off assert ans["input_ids"].tolist() == ans["context_ids"].tolist() - assert ans["input_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20, + assert ans["input_ids"].tolist() == [ -1, 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20, 30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26, 18, 6, 60, 9, 7, 73, 61, 69] # fmt: on From ef36a5a941b36957938e1135c7cd7a9b0d900750 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 17 Oct 2024 09:41:34 -0400 Subject: [PATCH 3/3] Add a test specifically for multiturn insertion of bos/eos MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../test_gemma_prompt_formatter.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py b/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py index ea6454273e1a..518b24601d57 100644 --- a/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py +++ b/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py @@ -40,3 +40,29 @@ def test_gemma_prompt_formatter_inference(bpe_tokenizer): 30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26, 18, 6, 60, 9, 7, 73, 61, 69] # fmt: on + + +def test_gemma_prompt_formatter_training_bos_eos_inserted_only_once_in_multiturn(bpe_tokenizer): + formatter = GemmaPromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"message": "TEST"}}, + {"role": "assistant", "slots": {"message": "TEST"}}, + {"role": "user", "slots": {"message": "TEST"}}, + {"role": "assistant", "slots": {"message": "TEST"}}, + {"role": "user", "slots": {"message": "TEST"}}, + {"role": "assistant", "slots": {"message": "TEST"}}, + {"role": "user", "slots": {"message": "TEST"}}, + {"role": "assistant", "slots": {"message": "TEST"}}, + ] + ) + + assert (ans["input_ids"] == -1).sum() == 2 + assert ans["input_ids"][0] == -1 + assert ans["input_ids"][-1] == -1 + + assert (ans["context_ids"] == -1).sum() == 1 + assert ans["context_ids"][0] == -1 + + assert (ans["answer_ids"] == -1).sum() == 1 + assert ans["answer_ids"][-1] == -1