diff --git a/contrib/tool-finetune/README.md b/contrib/tool-finetune/README.md
new file mode 100644
index 000000000..9bdb775ad
--- /dev/null
+++ b/contrib/tool-finetune/README.md
@@ -0,0 +1,75 @@
+## Function-call Finetune
+
+### Conda environment
+
+```
+bitsandbytes==0.40.0
+deepspeed==0.12.0
+flash-attn==2.5.7
+peft==0.10.0
+torch==2.1.2+cu118
+transformers==4.40.1
+vllm==0.5.2
+xformers==0.0.27
+```
+
+### Conversation Template
+```
+{
+    "type": "conversation",
+    "instances": [
+        {
+            "system": "You are a helpful assistant with access to the following functions. Use them if required - ",
+            "tools": ["{\"name\": \"function_name\", \"description\": \"\", \"parameters\": {\"type\": \"object\", \"properties\": {\"property_1\": {\"type\": \"xxx\", \"description\": \"\"}, \"property_2\": {\"type\": \"xxx\", \"description\": \"\"}}, \"required\": [\"property_1\", \"property_2\"]}}"],
+            "messages": [
+                {
+                    "role": "user",
+                    "content": ""
+                },
+                {
+                    "role": "function",
+                    "content": ""
+                },
+                {
+                    "role": "observation",
+                    "content": ""
+                },
+                {
+                    "role": "assistant",
+                    "content": ""
+                }
+            ]
+        },
+        {
+            "system": "You are a helpful assistant, with no access to external functions.",
+            "tools": [],
+            "messages": [
+                {
+                    "role": "user",
+                    "content": ""
+                },
+                {
+                    "role": "assistant",
+                    "content": ""
+                }
+            ]
+        }
+    ]
+}
+```
+
+### Run Function-call Finetune Example
+```
+./contrib/tool-finetune/run_function_call_finetune.sh \
+    --model_name_or_path meta-llama/Meta-Llama-3-8B \
+    --trust_remote_code True \
+    --conversation_template llama3_for_tool \
+    --dataset_path /home/wenhesun/LMFlow/data/glaive-function-calling-v2 \
+    --output_model_path /home/wenhesun/LMFlow/output_models/function-call-finetuned-llama
+```
+
+### Command-Line Arguments
+- `--model_name_or_path` - Name or path of the base model to finetune
+- `--conversation_template` - Template used to render conversations; currently supports `llama3_for_tool` and `qwen2_for_tool`
+- `--dataset_path` - Path to a dataset that has been converted to the format above
+- `--output_model_path` - Directory to store the finetuned model and logs
\ No newline at end of file
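For reference, a filled-in instance of this schema might look like the sketch below. All names and values are invented for illustration (they are not taken from glaive-function-calling-v2); the point is that each entry of `tools` is a JSON-encoded schema string, while the `function` and `observation` contents are the serialized call and its result:

```python
import json

# Hypothetical tool schema, serialized into the escaped string form
# that the "tools" field expects.
get_weather = {
    "name": "get_weather",
    "description": "Look up the current weather for a city",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "City name"}
        },
        "required": ["city"],
    },
}

instance = {
    "system": "You are a helpful assistant with access to the following functions. Use them if required - ",
    "tools": [json.dumps(get_weather)],
    "messages": [
        {"role": "user", "content": "What's the weather in Paris?"},
        {"role": "function", "content": json.dumps({"name": "get_weather", "arguments": {"city": "Paris"}})},
        {"role": "observation", "content": json.dumps({"temperature": "18C", "condition": "cloudy"})},
        {"role": "assistant", "content": "It is currently 18C and cloudy in Paris."},
    ],
}

# Write a one-instance dataset in the expected top-level layout.
with open("my_function_call_dataset.json", "w") as f:
    json.dump({"type": "conversation", "instances": [instance]}, f, indent=4)
```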
diff --git a/contrib/tool-finetune/function_call_finetune.py b/contrib/tool-finetune/function_call_finetune.py
new file mode 100644
index 000000000..313563db5
--- /dev/null
+++ b/contrib/tool-finetune/function_call_finetune.py
@@ -0,0 +1,282 @@
+import os
+import sys
+# Drop the script's own directory from sys.path so the installed lmflow
+# package is imported instead of any local module with the same name.
+sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
+import hashlib
+import logging
+from typing import Dict, Union
+
+import transformers
+from transformers.testing_utils import CaptureLogger
+from transformers import HfArgumentParser, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from lmflow.datasets.dataset import Dataset
+from lmflow.pipeline.auto_pipeline import AutoPipeline
+from lmflow.tokenization.hf_decoder_model import blocking
+from lmflow.utils.constants import (
+    TEXT_ONLY_DATASET_DESCRIPTION,
+    TEXT2TEXT_DATASET_DESCRIPTION,
+    CONVERSATION_DATASET_DESCRIPTION,
+    CONVERSATION_ROLE_NAMES,
+)
+from lmflow.utils.conversation_template import ConversationTemplateForTool, PRESET_TEMPLATES
+from lmflow.args import (
+    ModelArguments,
+    DatasetArguments,
+    AutoArguments,
+)
+from lmflow.models.hf_decoder_model import HFDecoderModel
+
+tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
+logger = logging.getLogger(__name__)
+
+
+class HFDecoderModelForTool(HFDecoderModel):
+    def tokenize(
+        self,
+        dataset,
+        add_special_tokens=True,
+        *args,
+        **kwargs
+    ) -> Dataset:
+        """
+        Tokenize the full dataset.
+
+        Parameters
+        ------------
+        dataset : lmflow.datasets.Dataset.
+
+        args : Optional.
+            Positional arguments.
+
+        kwargs : Optional.
+            Keyword arguments.
+
+        Returns
+        ------------
+        tokenized_datasets :
+            The tokenized dataset, without any leading or trailing special
+            tokens (normally they are Begin-Of-Sentence or End-Of-Sentence
+            tokens).
+        """
+        # Preprocessing the datasets.
+        # First we tokenize all the texts.
+        if dataset.get_backend() != "huggingface":
+            raise NotImplementedError(
+                "tokenization of datasets with non-huggingface backend is "
+                "not supported yet"
+            )
+
+        dataset_type = dataset.get_type()
+        model_args = self.model_args
+        raw_datasets = dataset
+        hf_raw_datasets = dataset.get_backend_dataset()
+        column_names = list(hf_raw_datasets.features)
+        data_args = raw_datasets.get_data_args()
+
+        # Tokenizing a dataset requires three pieces of information:
+        #   1) Which fields require tokenization, e.g.
+        #        "text2float": "text", but not "float"
+        #        "text2text": both "input" and "output"
+        #   2) How the tokenized sequences are concatenated, e.g.
+        #        "text_only": "text" -> "text"
+        #        "text2text": "input", "output" -> "input" + "output"
+        #   3) Which fields contribute to the final loss, e.g.
+        #        "text_only": "text"
+        #        "text2text": "output" only
+        tokenized_column_order = None    # Handles 1) and 2)
+        label_columns = None             # Handles 3)
+        if dataset_type == "text_only":
+            tokenized_column_order = ["text"]
+            label_columns = ["text"]
+        elif dataset_type == "text2text":
+            tokenized_column_order = ["input", "output"]
+            label_columns = ["output"]
+            add_special_tokens = False
+        elif dataset_type == "conversation":
+            if data_args.conversation_template:
+                if data_args.conversation_template in PRESET_TEMPLATES.keys():
+                    conversation_template = PRESET_TEMPLATES[data_args.conversation_template]
+                else:
+                    raise NotImplementedError(
+                        f"Conversation template {data_args.conversation_template} is not supported yet."
+                    )
+            else:
+                logger.warning("No conversation template provided. Using default template.")
+                conversation_template = PRESET_TEMPLATES['empty']
+
+            logger.warning(f"Conversation template: {conversation_template}")
+        else:
+            raise NotImplementedError(
+                f'dataset type "{dataset_type}" is not supported, currently'
+                ' only the following data types are supported:\n'
+                f'    1) {TEXT_ONLY_DATASET_DESCRIPTION}\n'
+                f'    2) {TEXT2TEXT_DATASET_DESCRIPTION}\n'
+                f'    3) {CONVERSATION_DATASET_DESCRIPTION}\n'
+            )
+
+        # Whether to truncate long sequences to fit into max_length
+        use_truncation = False
+        if model_args.use_lora or data_args.disable_group_texts:
+            use_truncation = True
+
+        tokenize_fn = conversation_tokenize_function
+        tokenize_fn_kwargs = {
+            "data_args": data_args,
+            "tokenizer": self.tokenizer,
+            "column_names": column_names,
+        }
+        if "conversation" in dataset_type:
+            tokenize_fn_kwargs["conversation_template"] = conversation_template
+        else:
+            tokenize_fn_kwargs["label_columns"] = label_columns
+            tokenize_fn_kwargs["tokenized_column_order"] = tokenized_column_order
+            tokenize_fn_kwargs["add_special_tokens"] = add_special_tokens
+            tokenize_fn_kwargs["use_truncation"] = use_truncation
+
+        tokenize_kwargs = {}
+        if not data_args.streaming:
+            fingerprint = hashlib.md5(
+                (
+                    raw_datasets.get_fingerprint()
+                    + str(self.tokenizer)
+                    + f'###padding_side={self.tokenizer.padding_side}'
+                    + ('###conversation_template=' + str(conversation_template) if "conversation" in dataset_type else "")
+                    + f'###disable_group_texts={data_args.disable_group_texts}'
+                    + f'###block_size={data_args.block_size}'
+                ).encode("utf-8")
+            ).hexdigest()
+            tokenize_kwargs = {
+                "num_proc": data_args.preprocessing_num_workers,
+                "load_from_cache_file": not data_args.overwrite_cache,
+                "desc": "Running tokenizer on dataset",
+                "new_fingerprint": fingerprint,
+            }
+
+        tokenized_datasets = raw_datasets.map(
+            tokenize_fn,
+            batched=True,
+            remove_columns=column_names,
+            fn_kwargs=tokenize_fn_kwargs,
+            **tokenize_kwargs
+        )
+
+        return tokenized_datasets
+
+
+def conversation_tokenize_function(
+    examples,
+    data_args: DatasetArguments,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    column_names,
+    conversation_template: ConversationTemplateForTool,
+) -> Dict:
+    """Handles tokenization for conversation datasets."""
+    num_example = len(examples[column_names[0]])
+    token_dict = {
+        "input_ids": [[] for _ in range(num_example)],
+        "attention_mask": [[] for _ in range(num_example)],
+        "labels": [[] for _ in range(num_example)],
+    }
+    with CaptureLogger(tok_logger) as cl:
+        for i in range(len(examples["messages"])):
+            messages = examples["messages"][i]
+            system = examples.get("system", [None] * num_example)[i]
+            tools = examples.get("tools", [None] * num_example)[i]
+            if len(messages) < 2 or messages[0]['role'] != CONVERSATION_ROLE_NAMES['user']:
+                tok_logger.warning(
+                    "Invalid instance encountered. Either the conversation has fewer "
+                    "than one full round or the first message is not from the user."
+                )
+                continue
+
+            if len(messages) % 2 != 0:
+                logger.warning(
+                    "The number of messages is not even; the last message will be ignored."
+                )
+                messages = messages[:-1]
+
+            encoded_conversation = conversation_template.encode_conversation(
+                tokenizer=tokenizer,
+                messages=messages,
+                system=system,
+                tools=tools,
+            )
+
+            input_ids, labels = [], []
+            for turn_idx, conversation_tuple in enumerate(encoded_conversation):
+                if len(conversation_tuple) == 2:
+                    # Plain round: (user, assistant)
+                    user_input = conversation_tuple[0]
+                    assistant_result = conversation_tuple[1]
+                    input_ids += user_input + assistant_result
+                    if data_args.train_on_prompt:
+                        labels += user_input + assistant_result
+                    else:
+                        labels += [-100] * len(user_input) + assistant_result
+                elif len(conversation_tuple) == 4:
+                    # Tool round: (user, function call, observation, assistant)
+                    user_input = conversation_tuple[0]
+                    function_result = conversation_tuple[1]
+                    observation_input = conversation_tuple[2]
+                    assistant_result = conversation_tuple[3]
+                    input_ids += user_input + function_result + observation_input + assistant_result
+                    if data_args.train_on_prompt:
+                        labels += user_input + function_result + observation_input + assistant_result
+                    else:
+                        # Learn the function call and the final answer; mask out
+                        # the user prompt and the tool observation.
+                        labels += [-100] * len(user_input) + function_result + [-100] * len(observation_input) + assistant_result
+                else:
+                    logger.warning("The number of roles in the conversation round is not supported.")
+
+            token_dict["input_ids"][i].extend(input_ids)
+            token_dict["attention_mask"][i].extend([1] * len(input_ids))
+            token_dict["labels"][i].extend(labels)
+
+    if data_args.disable_group_texts:
+        token_dict = blocking(
+            token_dict=token_dict,
+            block_size=data_args.block_size,
+            model_max_length=tokenizer.model_max_length,
+            pad_token_id=tokenizer.pad_token_id,
+            padding_side=tokenizer.padding_side,
+        )
+
+    # clm input could be much much longer than block_size
+    if "Token indices sequence length is longer than the" in cl.out:
+        tok_logger.warning(
+            "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+            " before being passed to the model."
+        )
+    return token_dict
+
+
+def train():
+    # Initialize args
+    ## Prepare training_args
+    pipeline_name = "finetuner"
+    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
+
+    parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()
+    print("Model args", model_args)
+    print("data_args", data_args)
+    print("training_args", pipeline_args)
+
+    # Init model
+    model = HFDecoderModelForTool(model_args)
+
+    # Process data
+    dataset = Dataset(data_args)
+
+    # Finetune
+    finetuner = AutoPipeline.get_pipeline(
+        pipeline_name=pipeline_name,
+        model_args=model_args,
+        data_args=data_args,
+        pipeline_args=pipeline_args,
+    )
+    tuned_model = finetuner.tune(model=model, dataset=dataset)
+
+
+if __name__ == "__main__":
+    train()
\ No newline at end of file
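The core of `conversation_tokenize_function` above is the loss mask: within each round, prompt-side segments are labeled `-100` (ignored by the cross-entropy loss), while the function call and the final answer keep their token ids as labels. A minimal standalone sketch of the 4-tuple case, with made-up token ids:

```python
IGNORE_INDEX = -100  # tokens with this label are excluded from the LM loss

# A hypothetical encoded tool round: (user, function, observation, assistant),
# mirroring the 4-tuple branch in conversation_tokenize_function.
user_ids, function_ids, observation_ids, assistant_ids = [1, 2], [3, 4], [5], [6, 7]

input_ids = user_ids + function_ids + observation_ids + assistant_ids
# Only the model-generated segments (function call and final answer) are
# learned; the user prompt and the tool observation are masked out.
labels = (
    [IGNORE_INDEX] * len(user_ids)
    + function_ids
    + [IGNORE_INDEX] * len(observation_ids)
    + assistant_ids
)
assert len(input_ids) == len(labels)
print(input_ids)  # [1, 2, 3, 4, 5, 6, 7]
print(labels)     # [-100, -100, 3, 4, -100, 6, 7]
```

Masking the observation keeps the model from learning to hallucinate tool outputs while still training it to emit the call and use the returned result.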
diff --git a/contrib/tool-finetune/run_function_call_finetune.sh b/contrib/tool-finetune/run_function_call_finetune.sh
new file mode 100755
index 000000000..220abee9a
--- /dev/null
+++ b/contrib/tool-finetune/run_function_call_finetune.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+# Please run this script under ${project_id} in project directory of
+# https://github.com/shizhediao/llm-ft
+# COMMIT: d5fecf30ba8011067b10cf51fede53a5ab6574e4
+
+# Parses arguments
+model_name_or_path=gpt2
+dataset_path=data/alpaca/train_conversation
+output_dir=output_models/finetune
+deepspeed_args="--master_port=12000"
+conversation_template=llama2
+
+# Safety related arguments
+trust_remote_code=0
+
+while [[ $# -ge 1 ]]; do
+    key="$1"
+    case ${key} in
+        -m|--model_name_or_path)
+            model_name_or_path="$2"
+            shift
+            ;;
+        -d|--dataset_path)
+            dataset_path="$2"
+            shift
+            ;;
+        -o|--output_model_path)
+            output_dir="$2"
+            shift
+            ;;
+        --conversation_template)
+            conversation_template="$2"
+            shift
+            ;;
+        --deepspeed_args)
+            deepspeed_args="$2"
+            shift
+            ;;
+        --trust_remote_code)
+            trust_remote_code="$2"
+            shift
+            ;;
+        *)
+            echo "error: unknown option \"${key}\"" 1>&2
+            exit 1
+    esac
+    shift
+done
+
+# Finetune
+exp_id=finetune
+# The script lives in contrib/tool-finetune/, so the project root is two levels up.
+project_dir=$(cd "$(dirname $0)"/../..; pwd)
+log_dir=${project_dir}/log/${exp_id}
+mkdir -p ${output_dir} ${log_dir}
+
+deepspeed ${deepspeed_args} \
+    contrib/tool-finetune/function_call_finetune.py \
+        --model_name_or_path ${model_name_or_path} \
+        --trust_remote_code ${trust_remote_code} \
+        --dataset_path ${dataset_path} \
+        --output_dir ${output_dir} --overwrite_output_dir \
+        --conversation_template ${conversation_template} \
+        --num_train_epochs 0.01 \
+        --learning_rate 2e-5 \
+        --disable_group_texts 1 \
+        --block_size 1024 \
+        --per_device_train_batch_size 1 \
+        --deepspeed configs/ds_config_zero3.json \
+        --fp16 \
+        --run_name finetune \
+        --validation_split_percentage 0 \
+        --logging_steps 20 \
+        --do_train \
+        --ddp_timeout 72000 \
+        --save_steps 5000 \
+        --dataloader_num_workers 1 \
+        > >(tee ${log_dir}/train.log) \
+        2> >(tee ${log_dir}/train.err >&2)
diff --git a/src/lmflow/utils/conversation_template/__init__.py b/src/lmflow/utils/conversation_template/__init__.py
index da5266e64..3f62001e5 100644
--- a/src/lmflow/utils/conversation_template/__init__.py
+++ b/src/lmflow/utils/conversation_template/__init__.py
@@ -1,15 +1,15 @@
 #!/usr/bin/env python
 # coding=utf-8
 # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
-from .base import EMPTY_TEMPLATE, EMPTY_NO_SPECIAL_TOKENS_TEMPLATE, ConversationTemplate
+from .base import EMPTY_TEMPLATE, EMPTY_NO_SPECIAL_TOKENS_TEMPLATE, ConversationTemplate, ConversationTemplateForTool
 from .chatglm import CHATGLM3_TEMPLATE
 from .chatml import CHATML_TEMPLATE
 from .deepseek import DEEPSEEK_TEMPLATE
 from .gemma import GEMMA_TEMPLATE
 from .internlm import INTERNLM2_TEMPLATE
-from .llama import LLAMA2_TEMPLATE, LLAMA3_TEMPLATE
+from .llama import LLAMA2_TEMPLATE, LLAMA3_TEMPLATE, LLAMA3_TEMPLATE_FOR_TOOL
 from .phi import PHI3_TEMPLATE
-from .qwen import QWEN2_TEMPLATE
+from .qwen import QWEN2_TEMPLATE, QWEN2_TEMPLATE_FOR_TOOL
 from .yi import YI1_5_TEMPLATE
 from .zephyr import ZEPHYR_TEMPLATE
 
@@ -25,8 +25,10 @@
     'internlm2': INTERNLM2_TEMPLATE,
     'llama2': LLAMA2_TEMPLATE,
     'llama3': LLAMA3_TEMPLATE,
+    'llama3_for_tool': LLAMA3_TEMPLATE_FOR_TOOL,
     'phi3': PHI3_TEMPLATE,
     'qwen2': QWEN2_TEMPLATE,
+    'qwen2_for_tool': QWEN2_TEMPLATE_FOR_TOOL,
     'yi': CHATML_TEMPLATE,
     'yi1_5': YI1_5_TEMPLATE,
     'zephyr': ZEPHYR_TEMPLATE
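Once registered here, the name passed via `--conversation_template` resolves to a template object through a plain dictionary lookup; this mirrors the membership check in `tokenize()` above:

```python
from lmflow.utils.conversation_template import PRESET_TEMPLATES

# An unknown template name fails the `in PRESET_TEMPLATES.keys()` check in
# tokenize() and raises NotImplementedError before any training starts.
template = PRESET_TEMPLATES["llama3_for_tool"]
print(type(template).__name__)  # ConversationTemplateForTool
```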
diff --git a/src/lmflow/utils/conversation_template/base.py b/src/lmflow/utils/conversation_template/base.py
index 4542cbee0..e93afacd2 100644
--- a/src/lmflow/utils/conversation_template/base.py
+++ b/src/lmflow/utils/conversation_template/base.py
@@ -9,6 +9,7 @@
 
 from transformers import PreTrainedTokenizer
 
+from lmflow.utils.constants import CONVERSATION_ROLE_NAMES
 
 logger = logging.getLogger(__name__)
 
@@ -156,6 +157,8 @@ def format(self, **kwargs) -> list:
 class ConversationTemplate:
     user_formatter: Formatter
     assistant_formatter: Formatter
+    function_formatter: Optional[Formatter] = None
+    observation_formatter: Optional[Formatter] = None
     system_formatter: Optional[Formatter] = None
     tools_formatter: Optional[Formatter] = None
     separator: Optional[TemplateComponent] = None
@@ -393,6 +396,168 @@ def _ensure_id_list(self, obj: Union[int, List[int]]) -> List[int]:
         else:
             raise ValueError(f"Object type {type(obj)} is not supported yet.")
 
+
+@dataclass
+class ConversationTemplateForTool(ConversationTemplate):
+    def encode_conversation(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        messages: List[Dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[List[str]] = None,
+        remove_last_sep: bool = False,
+        **kwargs
+    ) -> Sequence[Tuple[List[int], List[int]]]:
+        r'''
+        Messages here should be guaranteed to be in pairs, with the first
+        message being the user message and the second message being the
+        assistant message.
+        Data example:
+        ```json
+        {
+            "conversation_id": 2,
+            "system": "sysinfo1",
+            "tools": ["tool_1_desc"],
+            "messages": [
+                {
+                    "role": "user",
+                    "content": "hi"
+                },
+                {
+                    "role": "assistant",
+                    "content": "Hello!"
+                }
+            ]
+        }
+        ```
+        '''
+        assert isinstance(messages, list), "Messages must be a list."
+
+        if tools is not None:
+            tools = ','.join(tools)
+            # logger.warning("Tools are not supported yet. Please include tools in the system message manually.")
+        else:
+            tools = ''
+
+        if system:
+            if system.replace(" ", ""):
+                if not self.system_formatter:
+                    raise ValueError("Your dataset contains a system message but no system formatter is provided. "
+                                     "Consider either providing a system formatter or removing the system prompt from your dataset.")
+            else:
+                system = None
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, **kwargs)
+
+        if self.separator and remove_last_sep:
+            # For models that require a separator between messages,
+            # users can include the separator at the end of each template
+            # and specify the separator. Auto formatting will then remove
+            # the last separator once this option is specified.
+            encoded_pairs = self.remove_last_separator(encoded_pairs, tokenizer)
+
+        if self.special_starter:
+            # For models that have ONLY ONE bos token at the beginning of
+            # a conversation session (not of every conversation pair), users
+            # can specify a special starter to add it to the very beginning
+            # of the conversation session.
+            # e.g.
+            #   llama-2: <s> at the beginning of every pair of conversation
+            #   v.s.
+            #   llama-3: <|begin_of_text|> only at the beginning of a session
+            encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)
+
+        if self.special_stopper:
+            encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)
+
+        return encoded_pairs
+
+    def _encode(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        messages: List[Dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        **kwargs
+    ) -> Sequence[Tuple[List[int], List[int]]]:
+        # TODO: truncation according to model max length
+        # TODO: make sure the last few tokens are "learnable", not masked with token_id = -100.
+
+        res_all = []
+        # Concatenate the system and tools strings
+        system = (system or '') + tools
+        system_formatted = self.system_formatter.format(content=system) if system else []
+        system_encoded = self._encode_template(system_formatted, tokenizer)
+        ls_for_save = []
+        for i in range(len(messages)):
+            if messages[i]['role'] == CONVERSATION_ROLE_NAMES['user']:
+                user_message = messages[i]
+                user_formatted = self.user_formatter.format(content=user_message["content"])
+                user_encoded = self._encode_template(user_formatted, tokenizer)
+                if i == 0:
+                    user_encoded = system_encoded + user_encoded
+                ls_for_save.append(user_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['function']:
+                function_message = messages[i]
+                # Function calls use the function formatter when one is provided,
+                # falling back to the assistant formatter otherwise.
+                formatter = self.function_formatter or self.assistant_formatter
+                function_formatted = formatter.format(content=function_message['content'])
+                function_encoded = self._encode_template(function_formatted, tokenizer)
+                ls_for_save.append(function_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['observation']:
+                observation_message = messages[i]
+                # Tool observations use the observation formatter when available,
+                # so they are not rendered as user turns.
+                formatter = self.observation_formatter or self.user_formatter
+                observation_formatted = formatter.format(content=observation_message['content'])
+                observation_encoded = self._encode_template(observation_formatted, tokenizer)
+                ls_for_save.append(observation_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['assistant']:
+                assistant_message = messages[i]
+                assistant_formatted = self.assistant_formatter.format(content=assistant_message["content"])
+                assistant_encoded = self._encode_template(assistant_formatted, tokenizer)
+                # An assistant message closes a round: flush the accumulated
+                # (user, [function, observation,] assistant) tuple.
+                ls_for_save.append(assistant_encoded)
+                res_all.append(tuple(ls_for_save))
+                ls_for_save = []
+
+        if ls_for_save:
+            res_all.append(tuple(ls_for_save))
+
+        return res_all
+
+    def _encode_template(
+        self,
+        template: List[TemplateComponent],
+        tokenizer: PreTrainedTokenizer,
+        **kwargs
+    ) -> List[int]:
+        """Encode template components into token ids.
+
+        Parameters
+        ----------
+        template : List[TemplateComponent]
+            Formatted template components.
+        tokenizer : PreTrainedTokenizer
+            Tokenizer to convert tokens into token ids.
+
+        Returns
+        -------
+        List[int]
+            Encoded token ids.
+        """
+        encoded_ids = []
+        for component in template:
+            if component.type == 'string':
+                if len(component.content) == 0:
+                    logger.warning("Empty string component found in the template.")
+                    continue
+                else:
+                    encoded_ids += tokenizer.encode(component.content, add_special_tokens=False)
+            elif component.type == 'token':
+                if component.content == 'bos_token':
+                    encoded_ids += [tokenizer.bos_token_id]
+                elif component.content == 'eos_token':
+                    encoded_ids += [tokenizer.eos_token_id]
+                else:
+                    encoded_ids += self._ensure_id_list(tokenizer.convert_tokens_to_ids(component.content))
+            elif component.type == 'token_id':
+                encoded_ids += self._ensure_id_list(component.content)
+            else:
+                raise NotImplementedError(f"Component type {component.type} is not supported yet.")
+        return encoded_ids
+
 
 EMPTY_TEMPLATE = ConversationTemplate(
     template_name='empty',
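To make the return shape of `encode_conversation` concrete: plain rounds come back as 2-tuples of token-id lists and tool rounds as 4-tuples, which is exactly what the 2-/4-branch in `conversation_tokenize_function` consumes. The ids below are made up:

```python
# Hypothetical output for a two-round conversation: round 1 is a plain
# (user, assistant) pair, round 2 goes through a tool call.
encoded_conversation = [
    ([101, 5], [9, 102]),               # (user_ids, assistant_ids)
    ([101, 6], [11], [12], [13, 102]),  # (user, function, observation, assistant)
]

input_ids = []
for conversation_tuple in encoded_conversation:
    if len(conversation_tuple) not in (2, 4):
        raise ValueError("unexpected number of roles in round")
    # Segments are simply concatenated; the label mask (not shown here)
    # decides which of them contribute to the loss.
    for segment in conversation_tuple:
        input_ids += segment
print(input_ids)  # [101, 5, 9, 102, 101, 6, 11, 12, 13, 102]
```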
diff --git a/src/lmflow/utils/conversation_template/llama.py b/src/lmflow/utils/conversation_template/llama.py
index 1895ee144..5cc147cfe 100644
--- a/src/lmflow/utils/conversation_template/llama.py
+++ b/src/lmflow/utils/conversation_template/llama.py
@@ -6,12 +6,12 @@
 
 from transformers import PreTrainedTokenizer
 
-from .base import StringFormatter, TemplateComponent, ConversationTemplate
+from .base import StringFormatter, TemplateComponent, ConversationTemplate, ConversationTemplateForTool
+from lmflow.utils.constants import CONVERSATION_ROLE_NAMES
 
 logger = logging.getLogger(__name__)
 
-
 class Llama2ConversationTemplate(ConversationTemplate):
     def _encode(
         self,
@@ -47,8 +47,58 @@
         ))
 
         return res_all
-
-
+
+
+class Llama2ConversationTemplateForTool(Llama2ConversationTemplate):
+    def _encode(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        messages: List[Dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        **kwargs
+    ) -> Sequence[Tuple[List[int], List[int]]]:
+        if tools:
+            # logger.warning("Formatted tools are not supported in Llama2, thus tools will be ignored. "
+            #                "If this is intended, please include tools in the system message manually.")
+            system = (system or '') + tools
+        res_all = []
+        system_formatted = self.system_formatter.format(content=system) if system else []
+        system_formatted_text = "".join(
+            [component.content for component in system_formatted if component.type == 'string']
+        )  # HACK: flatten the formatted system block back into plain text
+        ls_for_save = []
+        for i in range(len(messages)):
+            if messages[i]['role'] == CONVERSATION_ROLE_NAMES['user']:
+                user_message = messages[i]
+                # Llama-2 has no dedicated system turn, so the system text is
+                # prepended to the first user message.
+                if i == 0:
+                    user_content = system_formatted_text + user_message['content']
+                else:
+                    user_content = user_message['content']
+                user_formatted = self.user_formatter.format(content=user_content)
+                user_encoded = self._encode_template(user_formatted, tokenizer)
+                ls_for_save.append(user_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['function']:
+                function_message = messages[i]
+                function_formatted = self.assistant_formatter.format(content=function_message['content'])
+                function_encoded = self._encode_template(function_formatted, tokenizer)
+                ls_for_save.append(function_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['observation']:
+                observation_message = messages[i]
+                observation_formatted = self.user_formatter.format(content=observation_message['content'])
+                observation_encoded = self._encode_template(observation_formatted, tokenizer)
+                ls_for_save.append(observation_encoded)
+            elif messages[i]['role'] == CONVERSATION_ROLE_NAMES['assistant']:
+                assistant_message = messages[i]
+                assistant_formatted = self.assistant_formatter.format(content=assistant_message["content"])
+                assistant_encoded = self._encode_template(assistant_formatted, tokenizer)
+                # An assistant message closes the round: flush the accumulated tuple.
+                ls_for_save.append(assistant_encoded)
+                res_all.append(tuple(ls_for_save))
+                ls_for_save = []
+
+        if ls_for_save:
+            res_all.append(tuple(ls_for_save))
+
+        return res_all
+
 
 LLAMA3_TEMPLATE = ConversationTemplate(
     template_name='llama3',
     user_formatter=StringFormatter(
@@ -69,6 +119,36 @@
     special_starter=TemplateComponent(type='token', content='bos_token')
 )
 
+LLAMA3_TEMPLATE_FOR_TOOL = ConversationTemplateForTool(
+    template_name='llama3_for_tool',
+    user_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>')
+        ]
+    ),
+    function_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>')
+        ]
+    ),
+    observation_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>')
+        ]
+    ),
+    assistant_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>')
+        ]
+    ),
+    system_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>')
+        ]
+    ),
+    special_starter=TemplateComponent(type='token', content='bos_token')
+)
+
 
 LLAMA2_TEMPLATE = Llama2ConversationTemplate(
     template_name='llama2',
@@ -89,4 +169,25 @@
             TemplateComponent(type='string', content='<<SYS>>\n{{content}}\n<</SYS>>\n\n')
         ]
     )
 )
+
+LLAMA2_TEMPLATE_FOR_TOOL = Llama2ConversationTemplateForTool(
+    template_name='llama2_for_tool',
+    user_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='token', content='bos_token'),
+            TemplateComponent(type='string', content='[INST] {{content}} [/INST]')
+        ]
+    ),
+    assistant_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='{{content}}'),
+            TemplateComponent(type='token', content='eos_token')
+        ]
+    ),
+    system_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<<SYS>>\n{{content}}\n<</SYS>>\n\n')
+        ]
+    )
+)
\ No newline at end of file
diff --git a/src/lmflow/utils/conversation_template/qwen.py b/src/lmflow/utils/conversation_template/qwen.py
index 9783f1fba..88c8fb581 100644
--- a/src/lmflow/utils/conversation_template/qwen.py
+++ b/src/lmflow/utils/conversation_template/qwen.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # coding=utf-8
 # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
-from .base import StringFormatter, TemplateComponent, ConversationTemplate
+from .base import StringFormatter, TemplateComponent, ConversationTemplate, ConversationTemplateForTool
 
 
 QWEN2_TEMPLATE = ConversationTemplate(
@@ -22,4 +22,34 @@
         ]
     ),
     separator=TemplateComponent(type='string', content='\n')
 )
+
+QWEN2_TEMPLATE_FOR_TOOL = ConversationTemplateForTool(
+    template_name='qwen2_for_tool',
+    user_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|im_start|>user\n{{content}}<|im_end|>\n')
+        ]
+    ),
+    function_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|im_start|>assistant\n{{content}}<|im_end|>\n')
+        ]
+    ),
+    observation_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|im_start|>tool\n{{content}}<|im_end|>\n')
+        ]
+    ),
+    assistant_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|im_start|>assistant\n{{content}}<|im_end|>\n')
+        ]
+    ),
+    system_formatter=StringFormatter(
+        template=[
+            TemplateComponent(type='string', content='<|im_start|>system\n{{content}}<|im_end|>\n')
+        ]
+    ),
+    separator=TemplateComponent(type='string', content='\n')
+)
\ No newline at end of file
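Putting the Qwen2 components together, a single tool round encoded with `qwen2_for_tool` should flatten to roughly the text below. The message contents are invented; the system block is the dataset's system string concatenated with the tool JSON, as `encode_conversation` does:

```python
# Approximate serialized form of one qwen2_for_tool round (contents invented).
rendered = (
    "<|im_start|>system\nYou are a helpful assistant... {tool JSON here}<|im_end|>\n"
    "<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n"
    '<|im_start|>assistant\n{"name": "get_weather", "arguments": {"city": "Paris"}}<|im_end|>\n'
    '<|im_start|>tool\n{"temperature": "18C", "condition": "cloudy"}<|im_end|>\n'
    "<|im_start|>assistant\nIt is currently 18C and cloudy in Paris.<|im_end|>\n"
)
print(rendered)
```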