20 changes: 19 additions & 1 deletion nemo/collections/common/prompts/formatter.py
@@ -1,5 +1,4 @@
from abc import ABC
from enum import Enum
from functools import lru_cache
from typing import Any, Type

@@ -142,6 +141,14 @@ class PromptFormatter(ABC):
# PromptFormatter.encode_dialog() ends with this role, it indicates a training example.
OUTPUT_ROLE = None

# When set to True, we insert the BOS/EOS symbol at the very beginning/end of the dialog
# (i.e., not before/after every turn).
# We query self.tokenizer.bos / self.tokenizer.eos to get their int IDs.
# Note that this is a separate mechanism from BOS_SLOT / EOS_SLOT, which allow inserting
# these tokens at arbitrary positions in arbitrary turns.
INSERT_BOS = False
INSERT_EOS = False

# Internal reserved field.
_REGISTERED_FORMATTERS = {}

@@ -241,6 +248,11 @@ def encode_dialog(self, turns: list[dict]) -> dict[str, torch.Tensor]:
turn_token_counts = []
turn_mask_values = []

if self.INSERT_BOS:
    turn_tokens.append(self.tokenizer.bos)
    turn_token_counts.append(1)
    turn_mask_values.append(False)

if "preamble" in self.TEMPLATE:
preamble_turns = [idx for idx, t in enumerate(turns) if t["role"] == "preamble"]
if not preamble_turns:
@@ -267,6 +279,12 @@ def encode_dialog(self, turns: list[dict]) -> dict[str, torch.Tensor]:
turn_token_counts.append(len(tokens))
turn_mask_values.append(role == self.OUTPUT_ROLE)

# Insert EOS only when the last turn comes from the OUTPUT_ROLE.
if self.INSERT_EOS and turns[-1]["role"] == self.OUTPUT_ROLE:
    turn_tokens.append(self.tokenizer.eos)
    turn_token_counts[-1] += 1
    turn_mask_values.append(True)

ans = {"input_ids": torch.tensor(turn_tokens, dtype=torch.long)}
if turn_mask_values[-1]:
    # The last turn comes from OUTPUT_ROLE, i.e. it's a response from the system.
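As a reviewer-facing sketch of the new knobs (not part of this diff): a subclass opts into dialog-level BOS/EOS by flipping the two class attributes, and everything else works as before. The class name, registry name, and template strings below are invented for illustration, and the import path is assumed from the file changed above.

```python
from nemo.collections.common.prompts.formatter import Modality, PromptFormatter  # assumed import path


class MyPromptFormatter(PromptFormatter):
    """Hypothetical formatter used only to illustrate the new flags."""

    NAME = "my_formatter"      # made-up registry name
    OUTPUT_ROLE = "assistant"  # turns with this role are treated as the answer
    INSERT_BOS = True          # prepend tokenizer.bos once, before the whole dialog
    INSERT_EOS = True          # append tokenizer.eos once, after a final OUTPUT_ROLE turn
    TEMPLATE = {
        "user": {
            "template": "User: |message|\nAssistant: ",
            "slots": {"message": Modality.Text},
        },
        "assistant": {
            "template": "|message|",
            "slots": {"message": Modality.Text},
        },
    }
```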
2 changes: 2 additions & 0 deletions nemo/collections/common/prompts/gemma.py
@@ -19,6 +19,8 @@
class GemmaPromptFormatter(PromptFormatter):
NAME = "gemma"
OUTPUT_ROLE = "assistant"
INSERT_BOS = True
INSERT_EOS = True
TEMPLATE = {
    "user": {
        "template": f"{GEMMA_BOS}user\n|message|{GEMMA_END_OF_TURN}\n{GEMMA_BOS}model\n",
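Runtime effect of these two lines, sketched with the real formatter; the `tokenizer` here stands in for any NeMo tokenizer that exposes `.bos` / `.eos` IDs.

```python
from nemo.collections.common.prompts.gemma import GemmaPromptFormatter

ans = GemmaPromptFormatter(tokenizer).encode_dialog(
    [
        {"role": "user", "slots": {"message": "Hi"}},
        {"role": "assistant", "slots": {"message": "Hello"}},
    ]
)
# With INSERT_BOS / INSERT_EOS enabled, tokenizer.bos appears exactly once at the very
# start of ans["input_ids"] and tokenizer.eos exactly once at the very end, no matter
# how many turns the dialog has (see the multi-turn test added below).
```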
2 changes: 1 addition & 1 deletion nemo/collections/common/prompts/llama.py
@@ -26,7 +26,7 @@ class Llama2PromptFormatter(PromptFormatter):
        },
    },
    "user": {
        "template": "|bos|[INST] |message| [/INST]",
        "template": f"{BOS_SLOT}[INST] |message| [/INST]",
        "slots": {
            "message": Modality.Text,
        },
    },
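A small sanity check of what this substitution amounts to, assuming BOS_SLOT is the shared constant for the `|bos|` placeholder; if so, the rendered template text is unchanged and the edit only removes a hard-coded literal.

```python
BOS_SLOT = "|bos|"  # assumption about the constant's value, stated here only for the check
assert f"{BOS_SLOT}[INST] |message| [/INST]" == "|bos|[INST] |message| [/INST]"
```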
@@ -11,16 +11,18 @@ def test_gemma_prompt_formatter_training(bpe_tokenizer):
)
assert set(ans) == {"input_ids", "context_ids", "answer_ids", "mask"}
# fmt: off
assert ans["input_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20,
# Note: The BPE tokenizer fixture in our test doesn't have BOS/EOS defined, which is why the tokenizer
# returns an ID of -1 for these tokens.
assert ans["input_ids"].tolist() == [-1, 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20,
30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26,
18, 6, 60, 9, 7, 73, 61, 69, 1, 81, 20, 30, 104, 59,
18, 26, 18, 6, 60, 9, 7]
assert ans["context_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20,
18, 26, 18, 6, 60, 9, 7, -1]
assert ans["context_ids"].tolist() == [-1, 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20,
30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26,
18, 6, 60, 9, 7, 73, 61, 69]
assert ans["answer_ids"].tolist() == [1, 81, 20, 30, 104, 59,
18, 26, 18, 6, 60, 9, 7]
assert ans["mask"].tolist() == [False] * 36 + [True] * 13
18, 26, 18, 6, 60, 9, 7, -1]
assert ans["mask"].tolist() == [False] * 37 + [True] * 14
# fmt: on
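For context on the -1 values in the expectations above, a minimal sketch of the tokenizer behaviour the test relies on; the real fixture lives elsewhere in the test suite.

```python
# A tokenizer without BOS/EOS defined reports -1 for both IDs, so INSERT_BOS /
# INSERT_EOS contribute a literal -1 at the start and end of input_ids.
class _NoBosEosTokenizer:
    bos = -1  # undefined BOS -> sentinel -1
    eos = -1  # undefined EOS -> sentinel -1
```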


@@ -34,7 +36,33 @@ def test_gemma_prompt_formatter_inference(bpe_tokenizer):
assert set(ans) == {"input_ids", "context_ids"}
# fmt: off
assert ans["input_ids"].tolist() == ans["context_ids"].tolist()
assert ans["input_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20,
assert ans["input_ids"].tolist() == [ -1, 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20,
30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26,
18, 6, 60, 9, 7, 73, 61, 69]
# fmt: on


def test_gemma_prompt_formatter_training_bos_eos_inserted_only_once_in_multiturn(bpe_tokenizer):
    formatter = GemmaPromptFormatter(bpe_tokenizer)
    ans = formatter.encode_dialog(
        [
            {"role": "user", "slots": {"message": "TEST"}},
            {"role": "assistant", "slots": {"message": "TEST"}},
            {"role": "user", "slots": {"message": "TEST"}},
            {"role": "assistant", "slots": {"message": "TEST"}},
            {"role": "user", "slots": {"message": "TEST"}},
            {"role": "assistant", "slots": {"message": "TEST"}},
            {"role": "user", "slots": {"message": "TEST"}},
            {"role": "assistant", "slots": {"message": "TEST"}},
        ]
    )

    assert (ans["input_ids"] == -1).sum() == 2
    assert ans["input_ids"][0] == -1
    assert ans["input_ids"][-1] == -1

    assert (ans["context_ids"] == -1).sum() == 1
    assert ans["context_ids"][0] == -1

    assert (ans["answer_ids"] == -1).sum() == 1
    assert ans["answer_ids"][-1] == -1