diff --git a/contrib/tool-finetune/README.md b/contrib/tool-finetune/README.md
new file mode 100644
index 000000000..9bdb775ad
--- /dev/null
+++ b/contrib/tool-finetune/README.md
@@ -0,0 +1,75 @@
+## Function-call Finetune
+
+### Conda environment
+
+```
+bitsandbytes==0.40.0
+deepspeed==0.12.0
+flash-attn==2.5.7
+peft==0.10.0
+torch==2.1.2+cu118
+transformers==4.40.1
+vllm==0.5.2
+xformers==0.0.27
+```
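+
+A quick way to verify that the pinned packages are importable after setting up the environment, as a minimal sketch (assuming the PyPI distribution names above):
+
+```
+import importlib.metadata as md
+
+for pkg in ["bitsandbytes", "deepspeed", "flash-attn", "peft",
+            "torch", "transformers", "vllm", "xformers"]:
+    # Raises PackageNotFoundError if the distribution is missing.
+    print(pkg, md.version(pkg))
+```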
+
+### Conversation Template
+```
+{
+  "type": "conversation",
+  "instances": [
+    {
+      "system": "You are a helpful assistant with access to the following functions. Use them if required - ",
+      "tools": ["{\"name\": \"\", \"description\": \"\", \"parameters\": {\"type\": \"object\", \"properties\": {\"property_1\": {\"type\": \"xxx\", \"description\": \"\"}, \"property_2\": {\"type\": \"xxx\", \"description\": \"\"}}, \"required\": [\"property_1\", \"property_2\"]}}"],
+      "messages": [
+        {
+          "role": "user",
+          "content": ""
+        },
+        {
+          "role": "function",
+          "content": ""
+        },
+        {
+          "role": "observation",
+          "content": ""
+        },
+        {
+          "role": "assistant",
+          "content": ""
+        }
+      ]
+    },
+    {
+      "system": "You are a helpful assistant, with no access to external functions.",
+      "tools": [],
+      "messages": [
+        {
+          "role": "user",
+          "content": ""
+        },
+        {
+          "role": "assistant",
+          "content": ""
+        }
+      ]
+    }
+  ]
+}
+```
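+
+For reference, the following minimal sketch writes one instance in this format; all values are placeholders (the `get_weather` tool is hypothetical, not taken from the actual dataset):
+
+```
+import json
+
+tool = {
+    "name": "get_weather",  # hypothetical tool
+    "description": "Get the current weather for a city",
+    "parameters": {
+        "type": "object",
+        "properties": {"city": {"type": "string", "description": ""}},
+        "required": ["city"],
+    },
+}
+
+instance = {
+    "system": "You are a helpful assistant with access to the following functions. Use them if required - ",
+    "tools": [json.dumps(tool)],  # each tool is a JSON string, not an object
+    "messages": [
+        {"role": "user", "content": "What is the weather in Paris?"},
+        {"role": "function", "content": "{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Paris\"}}"},
+        {"role": "observation", "content": "{\"temperature\": \"15C\"}"},
+        {"role": "assistant", "content": "It is currently 15C in Paris."},
+    ],
+}
+
+with open("train.json", "w") as f:
+    json.dump({"type": "conversation", "instances": [instance]}, f, indent=2)
+```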
+
+### Run Function-call Finetune Example
+```
+./contrib/tool-finetune/run_function_call_finetune.sh \
+ --model_name_or_path meta-llama/Meta-Llama-3-8B \
+ --trust_remote_code True \
+ --conversation_template llama3_for_tool \
+ --dataset_path /home/wenhesun/LMFlow/data/glaive-function-calling-v2 \
+ --output_model_path /home/wenhesun/LMFlow/output_models/function-call-finetuned-llama
+```
+
+### Command-Line Arguments
+- `--model_name_or_path` - Name or path of the base model to finetune.
+- `--conversation_template` - Conversation template to use; currently supports `llama3_for_tool` and `qwen2_for_tool`.
+- `--dataset_path` - Path to a dataset that has been converted to the format above.
+- `--output_model_path` - Directory in which to store the finetuned model and logs.
\ No newline at end of file
diff --git a/contrib/tool-finetune/function_call_finetune.py b/contrib/tool-finetune/function_call_finetune.py
new file mode 100644
index 000000000..313563db5
--- /dev/null
+++ b/contrib/tool-finetune/function_call_finetune.py
@@ -0,0 +1,282 @@
+import os
+import sys
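+# Remove the script's own directory from sys.path so that the `lmflow` imports
+# below resolve to the installed package instead of any local files here.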
+sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
+import hashlib
+from typing import Dict, List, Union, Tuple, Optional, Sequence
+import logging
+from datasets import Features
+import transformers
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.testing_utils import CaptureLogger
+from transformers import HfArgumentParser, PreTrainedTokenizer, PreTrainedTokenizerFast
+from peft import LoraConfig, TaskType
+from lmflow.datasets.dataset import Dataset
+from lmflow.pipeline.auto_pipeline import AutoPipeline
+from lmflow.tokenization.hf_decoder_model import blocking
+from lmflow.utils.conversation_template.base import TemplateComponent
+from lmflow.utils.constants import (
+ TEXT_ONLY_DATASET_DESCRIPTION,
+ TEXT2TEXT_DATASET_DESCRIPTION,
+ CONVERSATION_DATASET_DESCRIPTION,
+ CONVERSATION_ROLE_NAMES
+)
+from lmflow.utils.conversation_template import ConversationTemplateForTool, PRESET_TEMPLATES
+from lmflow.args import (
+ ModelArguments,
+ DatasetArguments,
+ AutoArguments,
+)
+from lmflow.models.hf_decoder_model import HFDecoderModel
+tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
+logger = logging.getLogger(__name__)
+
+
+class HFDecoderModelForTool(HFDecoderModel):
+ def tokenize(
+ self,
+ dataset,
+ add_special_tokens=True,
+ *args,
+ **kwargs
+ ) -> Dataset:
+ """
+ Tokenize the full dataset.
+
+ Parameters
+ ------------
+ dataset : lmflow.datasets.Dataset.
+
+ args : Optional.
+ Positional arguments.
+
+ kwargs : Optional.
+ Keyword arguments.
+
+ Returns
+ ------------
+ tokenized_datasets :
+ The tokenized dataset, without any leading or trailing special
+ tokens (normally they are Begin-Of-Sentence or End-Of-Sentence
+ tokens).
+ """
+ # Preprocessing the datasets.
+ # First we tokenize all the texts.
+        if dataset.get_backend() != "huggingface":
+            raise NotImplementedError(
+                "tokenization of datasets with non-huggingface backends is"
+                " not supported yet"
+            )
+
+ dataset_type = dataset.get_type()
+ model_args = self.model_args
+ raw_datasets = dataset
+ hf_raw_datasets = dataset.get_backend_dataset()
+ column_names = list(hf_raw_datasets.features)
+ data_args = raw_datasets.get_data_args()
+
+ # Requires three types of information for tokenizing different datasets
+ # 1) Which fields require tokenization, e.g.
+ # "text2float": "text", but not "float"
+ # "text2text": both "input" and "output"
+        # 2) How the tokenized sequences are concatenated together, e.g.
+ # "text_only": "text" -> "text"
+ # "text2text": "input", "output" -> "input" + "output"
+ # 3) Which fields require loss in final computation, e.g.
+ # "text_only": "text"
+ # "text2text": "output" only
+ tokenized_column_order = None # Handles 1) and 2)
+ label_columns = None # Handles 3)
+ if dataset_type == "text_only":
+ tokenized_column_order = ["text"]
+ label_columns = ["text"]
+ elif dataset_type == "text2text":
+ tokenized_column_order = ["input", "output"]
+ label_columns = ["output"]
+ add_special_tokens = False
+ elif dataset_type == "conversation":
+ if data_args.conversation_template:
+ if data_args.conversation_template in PRESET_TEMPLATES.keys():
+ conversation_template = PRESET_TEMPLATES[data_args.conversation_template]
+ else:
+ raise NotImplementedError(
+ f"Conversation template {data_args.conversation_template} is not supported yet."
+ )
+ else:
+ logger.warning("No conversation template provided. Using default template.")
+ conversation_template = PRESET_TEMPLATES['empty']
+
+ logger.warning(f"Conversation template: {conversation_template}")
+ else:
+ raise NotImplementedError(
+ f"dataset type \"{dataset_type}\" is not supported, currently"
+ " only support following data types:\n"
+ f" 1) {TEXT_ONLY_DATASET_DESCRIPTION}\n"
+ f" 2) {TEXT2TEXT_DATASET_DESCRIPTION}\n"
+ f" 3) {CONVERSATION_DATASET_DESCRIPTION}\n"
+ )
+
+ # Whether to truncate long sequences to fit into max_length
+ use_truncation = False
+ if model_args.use_lora or data_args.disable_group_texts:
+ use_truncation = True
+
+ tokenize_fn = conversation_tokenize_function
+ tokenize_fn_kwargs = {
+ "data_args": data_args,
+ "tokenizer": self.tokenizer,
+ "column_names": column_names,
+ }
+ if "conversation" in dataset_type:
+ tokenize_fn_kwargs["conversation_template"] = conversation_template
+ else:
+ tokenize_fn_kwargs["label_columns"] = label_columns
+ tokenize_fn_kwargs["tokenized_column_order"] = tokenized_column_order
+ tokenize_fn_kwargs["add_special_tokens"] = add_special_tokens
+ tokenize_fn_kwargs["use_truncation"] = use_truncation
+
+ tokenize_kwargs = {}
+ if not data_args.streaming:
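+            # Hash the tokenizer and preprocessing settings into the fingerprint
+            # so the HF datasets cache is invalidated when any of them changes.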
+ fingerprint = hashlib.md5(
+ (
+ raw_datasets.get_fingerprint()
+ + str(self.tokenizer)
+ + f'###padding_side={self.tokenizer.padding_side}'
+ + ('###conversation_template=' + str(conversation_template) if "conversation" in dataset_type else "")
+ + f'###disable_group_texts={data_args.disable_group_texts}'
+ + f'###block_size={data_args.block_size}'
+ ).encode("utf-8")
+ ).hexdigest()
+ tokenize_kwargs = {
+ "num_proc": data_args.preprocessing_num_workers,
+ "load_from_cache_file": not data_args.overwrite_cache,
+ "desc": "Running tokenizer on dataset",
+ "new_fingerprint": fingerprint,
+ }
+
+ tokenized_datasets = raw_datasets.map(
+ tokenize_fn,
+ batched=True,
+ remove_columns=column_names,
+ fn_kwargs=tokenize_fn_kwargs,
+ **tokenize_kwargs
+ )
+
+ return tokenized_datasets
+
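+# Label-masking convention used by conversation_tokenize_function below:
+# prompt-side tokens (user inputs and tool observations) get label -100 so the
+# cross-entropy loss ignores them; function calls and assistant replies keep
+# their token ids as labels. Illustration with made-up token ids:
+#   user_input = [1, 2, 3], assistant_result = [4, 5]
+#   input_ids -> [1, 2, 3, 4, 5]
+#   labels    -> [-100, -100, -100, 4, 5]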
+def conversation_tokenize_function(
+ examples,
+ data_args: DatasetArguments,
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+ column_names,
+ conversation_template: ConversationTemplateForTool,
+) -> Dict:
+ """Handels conversation datasets tokenization
+ """
+ num_example = len(examples[column_names[0]])
+ token_dict = {
+ "input_ids": [[] for _ in range(num_example)],
+ "attention_mask": [[] for _ in range(num_example)],
+ "labels": [[] for _ in range(num_example)],
+ }
+ with CaptureLogger(tok_logger) as cl:
+ for i in range(len(examples["messages"])):
+ messages = examples["messages"][i]
+ system = examples.get("system", [None] * num_example)[i]
+ tools = examples.get("tools", [None] * num_example)[i]
+            if len(messages) < 2 or messages[0]['role'] != CONVERSATION_ROLE_NAMES['user']:
+                tok_logger.warning(
+                    "Invalid instance encountered. Either the conversation has fewer than "
+                    "one full round or the first message is not from the user."
+                )
+                continue
+
+            if len(messages) % 2 != 0:
+                logger.warning(
+                    "The number of messages is odd; the last message will be ignored."
+                )
+                messages = messages[:-1]
+
+ encoded_conversation = conversation_template.encode_conversation(
+ tokenizer=tokenizer,
+ messages=messages,
+ system=system,
+ tools=tools,
+ )
+
+ input_ids, labels = [], []
+ for turn_idx, conversation_tuple in enumerate(encoded_conversation):
+ if len(conversation_tuple) == 2:
+ user_input = conversation_tuple[0]
+ assistant_result = conversation_tuple[1]
+ input_ids += user_input + assistant_result
+ if data_args.train_on_prompt:
+ labels += user_input + assistant_result
+ else:
+ labels += [-100] * len(user_input) + assistant_result
+ elif len(conversation_tuple) == 4:
+ user_input = conversation_tuple[0]
+ function_result = conversation_tuple[1]
+ observation_input = conversation_tuple[2]
+ assistant_result = conversation_tuple[3]
+ input_ids += user_input + function_result + observation_input + assistant_result
+ if data_args.train_on_prompt:
+ labels += user_input + function_result + observation_input + assistant_result
+ else:
+ labels += [-100] * len(user_input) + function_result + [-100] * len(observation_input) + assistant_result
+                else:
+                    logger.warning(
+                        "Unexpected number of segments in a conversation turn; expected 2 "
+                        "(user, assistant) or 4 (user, function, observation, assistant)."
+                    )
+
+ token_dict["input_ids"][i].extend(input_ids)
+ token_dict["attention_mask"][i].extend([1] * len(input_ids))
+ token_dict["labels"][i].extend(labels)
+
+ if data_args.disable_group_texts:
+ token_dict = blocking(
+ token_dict=token_dict,
+ block_size=data_args.block_size,
+ model_max_length=tokenizer.model_max_length,
+ pad_token_id=tokenizer.pad_token_id,
+ padding_side=tokenizer.padding_side,
+ )
+
+ # clm input could be much much longer than block_size
+ if "Token indices sequence length is longer than the" in cl.out:
+ tok_logger.warning(
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+ " before being passed to the model."
+ )
+ return token_dict
+
+
+def train():
+
+ # Initialize args
+ ## Prepare training_args
+ pipeline_name = "finetuner"
+ PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
+ parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments))
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+ else:
+ model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()
+ print("Model args", model_args)
+ print("data_args", data_args)
+ print("training_args", pipeline_args)
+
+ # Init model
+ model = HFDecoderModelForTool(model_args)
+
+ # Process data
+ dataset = Dataset(data_args)
+
+ # Finetune
+ finetuner = AutoPipeline.get_pipeline(
+ pipeline_name=pipeline_name,
+ model_args=model_args,
+ data_args=data_args,
+ pipeline_args=pipeline_args,
+ )
+ tuned_model = finetuner.tune(model=model, dataset=dataset)
+
+if __name__ == "__main__":
+ train()
\ No newline at end of file
diff --git a/contrib/tool-finetune/run_function_call_finetune.sh b/contrib/tool-finetune/run_function_call_finetune.sh
new file mode 100755
index 000000000..220abee9a
--- /dev/null
+++ b/contrib/tool-finetune/run_function_call_finetune.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+# Please run this script under the project root directory of
+# https://github.com/shizhediao/llm-ft
+# COMMIT: d5fecf30ba8011067b10cf51fede53a5ab6574e4
+
+# Parses arguments
+model_name_or_path=gpt2
+dataset_path=data/alpaca/train_conversation
+output_dir=output_models/finetune
+deepspeed_args="--master_port=12000"
+conversation_template=llama2
+
+# Safety related arguments
+trust_remote_code=0
+
+while [[ $# -ge 1 ]]; do
+ key="$1"
+ case ${key} in
+ -m|--model_name_or_path)
+ model_name_or_path="$2"
+ shift
+ ;;
+ -d|--dataset_path)
+ dataset_path="$2"
+ shift
+ ;;
+ -o|--output_model_path)
+ output_dir="$2"
+ shift
+ ;;
+ --conversation_template)
+ conversation_template="$2"
+ shift
+ ;;
+ --deepspeed_args)
+ deepspeed_args="$2"
+ shift
+ ;;
+ --trust_remote_code)
+ trust_remote_code="$2"
+ shift
+ ;;
+ *)
+ echo "error: unknown option \"${key}\"" 1>&2
+ exit 1
+ esac
+ shift
+done
+
+# Finetune
+exp_id=finetune
+project_dir=$(cd "$(dirname $0)"/../..; pwd)  # the script lives two levels below the project root
+log_dir=${project_dir}/log/${exp_id}
+mkdir -p ${output_dir} ${log_dir}
+
+deepspeed ${deepspeed_args} \
+ contrib/tool-finetune/function_call_finetune.py \
+ --model_name_or_path ${model_name_or_path} \
+ --trust_remote_code ${trust_remote_code} \
+ --dataset_path ${dataset_path} \
+ --output_dir ${output_dir} --overwrite_output_dir \
+ --conversation_template ${conversation_template} \
+ --num_train_epochs 0.01 \
+ --learning_rate 2e-5 \
+ --disable_group_texts 1 \
+ --block_size 1024 \
+ --per_device_train_batch_size 1 \
+ --deepspeed configs/ds_config_zero3.json \
+ --fp16 \
+ --run_name finetune \
+ --validation_split_percentage 0 \
+ --logging_steps 20 \
+ --do_train \
+ --ddp_timeout 72000 \
+ --save_steps 5000 \
+ --dataloader_num_workers 1 \
+ > >(tee ${log_dir}/train.log) \
+ 2> >(tee ${log_dir}/train.err >&2)
diff --git a/src/lmflow/utils/conversation_template/__init__.py b/src/lmflow/utils/conversation_template/__init__.py
index da5266e64..3f62001e5 100644
--- a/src/lmflow/utils/conversation_template/__init__.py
+++ b/src/lmflow/utils/conversation_template/__init__.py
@@ -1,15 +1,15 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
-from .base import EMPTY_TEMPLATE, EMPTY_NO_SPECIAL_TOKENS_TEMPLATE, ConversationTemplate
+from .base import EMPTY_TEMPLATE, EMPTY_NO_SPECIAL_TOKENS_TEMPLATE, ConversationTemplate, ConversationTemplateForTool
from .chatglm import CHATGLM3_TEMPLATE
from .chatml import CHATML_TEMPLATE
from .deepseek import DEEPSEEK_TEMPLATE
from .gemma import GEMMA_TEMPLATE
from .internlm import INTERNLM2_TEMPLATE
-from .llama import LLAMA2_TEMPLATE, LLAMA3_TEMPLATE
+from .llama import LLAMA2_TEMPLATE, LLAMA3_TEMPLATE, LLAMA3_TEMPLATE_FOR_TOOL
from .phi import PHI3_TEMPLATE
-from .qwen import QWEN2_TEMPLATE
+from .qwen import QWEN2_TEMPLATE, QWEN2_TEMPLATE_FOR_TOOL
from .yi import YI1_5_TEMPLATE
from .zephyr import ZEPHYR_TEMPLATE
@@ -25,8 +25,10 @@
'internlm2': INTERNLM2_TEMPLATE,
'llama2': LLAMA2_TEMPLATE,
'llama3': LLAMA3_TEMPLATE,
+ 'llama3_for_tool': LLAMA3_TEMPLATE_FOR_TOOL,
'phi3': PHI3_TEMPLATE,
'qwen2': QWEN2_TEMPLATE,
+ 'qwen2_for_tool': QWEN2_TEMPLATE_FOR_TOOL,
'yi': CHATML_TEMPLATE,
'yi1_5': YI1_5_TEMPLATE,
'zephyr': ZEPHYR_TEMPLATE
diff --git a/src/lmflow/utils/conversation_template/base.py b/src/lmflow/utils/conversation_template/base.py
index 4542cbee0..e93afacd2 100644
--- a/src/lmflow/utils/conversation_template/base.py
+++ b/src/lmflow/utils/conversation_template/base.py
@@ -9,6 +9,7 @@
from transformers import PreTrainedTokenizer
+from lmflow.utils.constants import CONVERSATION_ROLE_NAMES
logger = logging.getLogger(__name__)
@@ -156,6 +157,8 @@ def format(self, **kwargs) -> list:
class ConversationTemplate:
user_formatter: Formatter
assistant_formatter: Formatter
+    function_formatter: Optional[Formatter] = None
+    observation_formatter: Optional[Formatter] = None
system_formatter: Optional[Formatter] = None
tools_formatter: Optional[Formatter] = None
separator: Optional[TemplateComponent] = None
@@ -393,6 +396,168 @@ def _ensure_id_list(self, obj: Union[int, List[int]]) -> List[int]:
else:
raise ValueError(f"Object type {type(obj)} is not supported yet.")
+@dataclass
+class ConversationTemplateForTool(ConversationTemplate):
+ def encode_conversation(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ messages: List[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[List[str]] = None,
+ remove_last_sep: bool = False,
+ **kwargs
+ ) -> Sequence[Tuple[List[int], List[int]]]:
+ r'''
+        Messages here should be guaranteed to be in pairs, with the first message being the user message and the second message being the assistant message.
+ Data example:
+ ```json
+ {
+ "conversation_id": 2,
+ "system": "sysinfo1",
+ "tools": ["tool_1_desc"],
+ "messages": [
+ {
+ "role": "user",
+ "content": "hi"
+ },
+ {
+ "role": "assistant",
+ "content": "Hello!"
+ }
+ ]
+ }
+ ```
+ '''
+ assert isinstance(messages, list), "Messages must be a list."
+
+        if tools is not None:
+            # Tool descriptions are serialized into a single string and later
+            # appended to the system message.
+            tools = ','.join(tools)
+        else:
+            tools = ''
+
+        if system:
+            if system.strip():
+                if not self.system_formatter:
+                    raise ValueError("Your dataset contains a system message but no system formatter is provided. "
+                                     "Consider either providing a system formatter or removing the system prompt from your dataset.")
+            else:
+                system = None
+ encoded_pairs = self._encode(tokenizer, messages, system, tools, **kwargs)
+
+        if self.separator and remove_last_sep:
+            # For models that require a separator between messages, users can
+            # include the separator at the end of each template and specify it
+            # here. Auto formatting will then remove the last separator once
+            # this option is enabled.
+            encoded_pairs = self.remove_last_separator(encoded_pairs, tokenizer)
+
+        if self.special_starter:
+            # For models that have ONLY ONE bos token at the beginning of
+            # a conversation session (not every conversation pair), users can
+            # specify a special starter to add it to the very beginning of
+            # the conversation session.
+            # e.g.:
+            #   llama-2: <s> and </s> at every pair of conversation
+            #   v.s.
+            #   llama-3: <|begin_of_text|> only at the beginning of a session
+            encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)
+
+ if self.special_stopper:
+ encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)
+
+ return encoded_pairs
+
+ def _encode(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ messages: List[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ **kwargs
+ ) -> Sequence[Tuple[List[int], List[int]]]:
+ # TODO: truncation according to model max length
+ # TODO: make sure the last few tokens are "learnable", not masked with token_id = -100.
+
+        res_all = []
+        # Concatenate the system message and the serialized tool descriptions;
+        # guard against a None system so the concatenation cannot fail.
+        system = (system or '') + (tools or '')
+        system_formatted = self.system_formatter.format(content=system) if system else []
+        system_encoded = self._encode_template(system_formatted, tokenizer)
+        ls_for_save = []
+        for i in range(len(messages)):
+            if messages[i]['role'] == CONVERSATION_ROLE_NAMES['user']:
+                user_message = messages[i]
+                user_formatted = self.user_formatter.format(content=user_message["content"])
+                user_encoded = self._encode_template(user_formatted, tokenizer)
+                if i == 0:
+                    # Prepend the encoded system message to the first user turn.
+                    user_encoded = system_encoded + user_encoded
+                ls_for_save.append(user_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['function']:
+                function_message = messages[i]
+                # Function calls are generated by the model, so fall back to the
+                # assistant formatter if no dedicated function formatter is set.
+                formatter = self.function_formatter or self.assistant_formatter
+                function_formatted = formatter.format(content=function_message['content'])
+                function_encoded = self._encode_template(function_formatted, tokenizer)
+                ls_for_save.append(function_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['observation']:
+                observation_message = messages[i]
+                # Tool outputs are prompt-side; fall back to the user formatter
+                # if no dedicated observation formatter is set.
+                formatter = self.observation_formatter or self.user_formatter
+                observation_formatted = formatter.format(content=observation_message['content'])
+                observation_encoded = self._encode_template(observation_formatted, tokenizer)
+                ls_for_save.append(observation_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['assistant']:
+                assistant_message = messages[i]
+                assistant_formatted = self.assistant_formatter.format(content=assistant_message["content"])
+                assistant_encoded = self._encode_template(assistant_formatted, tokenizer)
+                ls_for_save.append(assistant_encoded)
+                # An assistant reply closes the current turn, yielding either
+                # (user, assistant) or (user, function, observation, assistant).
+                res_all.append(tuple(ls_for_save))
+                ls_for_save = []
+
+        if ls_for_save:
+            res_all.append(tuple(ls_for_save))
+
+        return res_all
+
+ def _encode_template(
+ self,
+ template: List[TemplateComponent],
+ tokenizer: PreTrainedTokenizer,
+ **kwargs
+ ) -> List[int]:
+ """Encode template components into token ids.
+
+ Parameters
+ ----------
+ template : List[TemplateComponent]
+ Formatted template components.
+ tokenizer : PreTrainedTokenizer
+ Tokenizer to convert tokens into token ids.
+
+ Returns
+ -------
+ List[int]
+ Encoded token ids.
+ """
+ encoded_ids = []
+ for component in template:
+ if component.type == 'string':
+ if len(component.content) == 0:
+ logger.warning("Empty string component found in the template.")
+ continue
+ else:
+ encoded_ids += tokenizer.encode(component.content, add_special_tokens=False)
+ elif component.type == 'token':
+ if component.content == 'bos_token':
+ encoded_ids += [tokenizer.bos_token_id]
+ elif component.content == 'eos_token':
+ encoded_ids += [tokenizer.eos_token_id]
+ else:
+ encoded_ids += self._ensure_id_list(tokenizer.convert_tokens_to_ids(component.content))
+ elif component.type == 'token_id':
+ encoded_ids += self._ensure_id_list(component.content)
+ else:
+ raise NotImplementedError(f"Component type {component.type} is not supported yet.")
+ return encoded_ids
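+
+# A minimal usage sketch (hypothetical tokenizer and messages; concrete
+# templates such as LLAMA3_TEMPLATE_FOR_TOOL are defined in llama.py):
+#   pairs = LLAMA3_TEMPLATE_FOR_TOOL.encode_conversation(
+#       tokenizer=tokenizer,
+#       messages=[{"role": "user", "content": "hi"},
+#                 {"role": "assistant", "content": "Hello!"}],
+#       system="sysinfo1",
+#       tools=["tool_1_desc"],
+#   )
+#   # -> one tuple per turn: (user_ids, assistant_ids) for plain turns,
+#   #    (user_ids, function_ids, observation_ids, assistant_ids) for tool turns.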
+
EMPTY_TEMPLATE = ConversationTemplate(
template_name='empty',
diff --git a/src/lmflow/utils/conversation_template/llama.py b/src/lmflow/utils/conversation_template/llama.py
index 1895ee144..5cc147cfe 100644
--- a/src/lmflow/utils/conversation_template/llama.py
+++ b/src/lmflow/utils/conversation_template/llama.py
@@ -6,12 +6,12 @@
from transformers import PreTrainedTokenizer
-from .base import StringFormatter, TemplateComponent, ConversationTemplate
+from .base import StringFormatter, TemplateComponent, ConversationTemplate, ConversationTemplateForTool
+from lmflow.utils.constants import CONVERSATION_ROLE_NAMES
logger = logging.getLogger(__name__)
-
class Llama2ConversationTemplate(ConversationTemplate):
def _encode(
self,
@@ -47,8 +47,58 @@ def _encode(
))
return res_all
-
-
+
+class Llama2ConversationTemplateForTool(Llama2ConversationTemplate):
+ def _encode(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ messages: List[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ **kwargs
+ ) -> Sequence[Tuple[List[int], List[int]]]:
+        if tools:
+            # Llama2 has no dedicated tool section in its prompt format, so the
+            # serialized tool descriptions are appended to the system message.
+            system = (system or '') + tools
+ res_all = []
+ system_formatted = self.system_formatter.format(content=system) if system else []
+ system_formatted_text = "".join([component.content for component in system_formatted if component.type == 'string']) # HACK
+ ls_for_save = []
+        for i in range(len(messages)):
+            if messages[i]['role'] == CONVERSATION_ROLE_NAMES['user']:
+                user_message = messages[i]
+                if i == 0:
+                    # Llama2 carries the system message inside the first user turn.
+                    user_content = system_formatted_text + user_message['content']
+                else:
+                    user_content = user_message['content']
+                user_formatted = self.user_formatter.format(content=user_content)
+                user_encoded = self._encode_template(user_formatted, tokenizer)
+                ls_for_save.append(user_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['function']:
+                function_message = messages[i]
+                formatter = self.function_formatter or self.assistant_formatter
+                function_formatted = formatter.format(content=function_message['content'])
+                function_encoded = self._encode_template(function_formatted, tokenizer)
+                ls_for_save.append(function_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['observation']:
+                observation_message = messages[i]
+                formatter = self.observation_formatter or self.user_formatter
+                observation_formatted = formatter.format(content=observation_message['content'])
+                observation_encoded = self._encode_template(observation_formatted, tokenizer)
+                ls_for_save.append(observation_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['assistant']:
+                assistant_message = messages[i]
+                assistant_formatted = self.assistant_formatter.format(content=assistant_message["content"])
+                assistant_encoded = self._encode_template(assistant_formatted, tokenizer)
+                ls_for_save.append(assistant_encoded)
+                # Flush the completed turn.
+                res_all.append(tuple(ls_for_save))
+                ls_for_save = []
+
+ if ls_for_save:
+ res_all.append(tuple(ls_for_save))
+
+ return res_all
+
LLAMA3_TEMPLATE = ConversationTemplate(
template_name='llama3',
user_formatter=StringFormatter(
@@ -69,6 +119,36 @@ def _encode(
special_starter=TemplateComponent(type='token', content='bos_token')
)
+LLAMA3_TEMPLATE_FOR_TOOL = ConversationTemplateForTool(
+ template_name='llama3_for_tool',
+ user_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>')
+ ]
+ ),
+ function_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>')
+ ]
+ ),
+ observation_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>')
+ ]
+ ),
+ assistant_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>')
+ ]
+ ),
+ system_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>')
+ ]
+ ),
+ special_starter=TemplateComponent(type='token', content='bos_token')
+)
+
LLAMA2_TEMPLATE = Llama2ConversationTemplate(
template_name='llama2',
@@ -89,4 +169,25 @@ def _encode(
        TemplateComponent(type='string', content='<<SYS>>\n{{content}}\n<</SYS>>\n\n')
]
)
+)
+
+LLAMA2_TEMPLATE_FOR_TOOL = Llama2ConversationTemplateForTool(
+    template_name='llama2_for_tool',
+ user_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='token', content='bos_token'),
+ TemplateComponent(type='string', content='[INST] {{content}} [/INST]')
+ ]
+ ),
+ assistant_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='{{content}}'),
+ TemplateComponent(type='token', content='eos_token')
+ ]
+ ),
+ system_formatter=StringFormatter(
+ template=[
+            TemplateComponent(type='string', content='<<SYS>>\n{{content}}\n<</SYS>>\n\n')
+ ]
+ )
)
\ No newline at end of file
diff --git a/src/lmflow/utils/conversation_template/qwen.py b/src/lmflow/utils/conversation_template/qwen.py
index 9783f1fba..88c8fb581 100644
--- a/src/lmflow/utils/conversation_template/qwen.py
+++ b/src/lmflow/utils/conversation_template/qwen.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
-from .base import StringFormatter, TemplateComponent, ConversationTemplate
+from .base import StringFormatter, TemplateComponent, ConversationTemplate, ConversationTemplateForTool
QWEN2_TEMPLATE = ConversationTemplate(
@@ -22,4 +22,34 @@
]
),
separator=TemplateComponent(type='string', content='\n')
+)
+
+QWEN2_TEMPLATE_FOR_TOOL = ConversationTemplateForTool(
+ template_name='qwen2_for_tool',
+ user_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|im_start|>user\n{{content}}<|im_end|>\n')
+ ]
+ ),
+ function_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|im_start|>assistant\n{{content}}<|im_end|>\n')
+ ]
+ ),
+ observation_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|im_start|>tool\n{{content}}<|im_end|>\n')
+ ]
+ ),
+ assistant_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|im_start|>assistant\n{{content}}<|im_end|>\n')
+ ]
+ ),
+ system_formatter=StringFormatter(
+ template=[
+ TemplateComponent(type='string', content='<|im_start|>system\n{{content}}<|im_end|>\n')
+ ]
+ ),
+ separator=TemplateComponent(type='string', content='\n')
)
\ No newline at end of file