diff --git a/.gitignore b/.gitignore index 94b6e614d..82276c686 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,6 @@ tests/state_of_the_union.txt # Build build + +# Data +fastchat/llm_judge/data/ \ No newline at end of file diff --git a/fastchat/llm_judge/gen_judgment.py b/fastchat/llm_judge/gen_judgment.py index a1c70b295..7b1b18116 100644 --- a/fastchat/llm_judge/gen_judgment.py +++ b/fastchat/llm_judge/gen_judgment.py @@ -301,7 +301,7 @@ def make_judge_single(judge_model, judge_prompts): # Show match stats and prompt enter to continue print("Stats:") print(json.dumps(match_stat, indent=4)) - input("Press Enter to confirm...") + # input("Press Enter to confirm...") # Play matches if args.parallel == 1: diff --git a/fastchat/llm_judge/run.sh b/fastchat/llm_judge/run.sh new file mode 100755 index 000000000..f7f6a319b --- /dev/null +++ b/fastchat/llm_judge/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# Usage: run.sh HUB_MODEL_ID MT_BENCH_ID [DTYPE] (DTYPE defaults to float16) +set -x -e +HUB_MODEL_ID=$1 +MT_BENCH_ID=$2 +DTYPE=${3:-float16} + +# Generate answer +python gen_model_answer.py --model-path "$HUB_MODEL_ID" --model-id "$MT_BENCH_ID" --dtype "$DTYPE" + +# Judge! 
+python gen_judgment.py --model-list "$MT_BENCH_ID" + +# Get results +python show_result.py \ No newline at end of file diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index 91fe223fb..e52ed4017 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -179,6 +179,7 @@ def load_model( """Load a model from Hugging Face.""" # get model adapter adapter = get_model_adapter(model_path) + print(f"Using model adapter: {adapter.__class__.__name__} for model path {model_path} and revision {revision}") # Handle device mapping cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( @@ -1834,6 +1835,15 @@ def match(self, model_path: str): def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("zephyr") +class H4MistralAdapter(BaseModelAdapter): + """The model adapter for HuggingFaceH4 Mistral models (e.g., HuggingFaceH4/mistral-*), matched case-insensitively""" + + def match(self, model_path: str): + return "huggingfaceh4/mistral" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + class XwinLMAdapter(BaseModelAdapter): """The model adapter for Xwin-LM V0.1 and V0.2 series of models(e.g., Xwin-LM/Xwin-LM-70B-V0.1)""" @@ -1962,6 +1972,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation: register_model_adapter(ZephyrAdapter) register_model_adapter(XwinLMAdapter) register_model_adapter(LemurAdapter) +register_model_adapter(H4MistralAdapter) register_model_adapter(PygmalionAdapter) register_model_adapter(MicrosoftOrcaAdapter) register_model_adapter(YiAdapter)