diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index f3beeeaa2..67ba7df38 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -805,6 +805,19 @@ def get_conv_template(name: str) -> Conversation:
     )
 )

+# Default ChatML format
+register_conv_template(
+    Conversation(
+        name="chatml",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep_style=SeparatorStyle.CHATML,
+        sep="<|im_end|>",
+        stop_token_ids=[32000, 32001],
+        stop_str="<|im_end|>",
+    )
+)
+
 # Baichuan-13B-Chat template
 register_conv_template(
     # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 0ddbc666c..82b8df3a4 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -1283,7 +1283,11 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer

     def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("h4_default_v3")
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        if "<|im_start|>" in tokenizer.chat_template:
+            return get_conv_template("chatml")
+        else:
+            return get_conv_template("h4_default_v3")

 class MistralAdapter(BaseModelAdapter):
     """The model adapter for mistral"""
@@ -1298,7 +1302,11 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer

     def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("h4_default_v3")
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        if "<|im_start|>" in tokenizer.chat_template:
+            return get_conv_template("chatml")
+        else:
+            return get_conv_template("h4_default_v3")

 class H4DeepSeekAdapter(BaseModelAdapter):
     """The model adapter for H4 DeepSeek models"""
@@ -1313,7 +1321,11 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer

     def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("h4_default_v3")
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        if "<|im_start|>" in tokenizer.chat_template:
+            return get_conv_template("chatml")
+        else:
+            return get_conv_template("h4_default_v3")

 class H4MixtralAdapter(BaseModelAdapter):
     """The model adapter for H4 Mixtral models"""
@@ -1328,7 +1340,11 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer

     def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("h4_default_v3")
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        if "<|im_start|>" in tokenizer.chat_template:
+            return get_conv_template("chatml")
+        else:
+            return get_conv_template("h4_default_v3")


 class CuteGPTAdapter(BaseModelAdapter):
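
A minimal standalone sketch of the detection logic the patch adds to each adapter, assuming a transformers version in which tokenizers expose a chat_template attribute. The helper name pick_default_conv_template is hypothetical, and the explicit None-guard is an illustration-only addition (the patch itself assumes chat_template is set); it is not part of the change above.

from transformers import AutoTokenizer

def pick_default_conv_template(model_path: str) -> str:
    # Only the tokenizer is needed to decide the template; no model weights are loaded.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # chat_template can be None for tokenizers that ship without a Jinja chat template,
    # so guard it before the substring test used in the patch.
    if tokenizer.chat_template and "<|im_start|>" in tokenizer.chat_template:
        return "chatml"
    return "h4_default_v3"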