Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
claude.md
186 changes: 186 additions & 0 deletions common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""
Common utility functions for ToMBench evaluation scripts.
This module contains shared functions used by both API-based and HuggingFace evaluation.
"""
import random
from prompts import *


def format_prompt_4(d, args):
"""
Format prompt for 4-choice questions.

Args:
d: Data dictionary containing the question and options
args: Arguments containing language setting

Returns:
tuple: (mapping_dict, formatted_prompt)
"""
if args.language == 'zh':
cA = d['选项A'].replace("A. ", "")
cB = d['选项B'].replace("B. ", "")
cC = d['选项C'].replace("C. ", "")
cD = d['选项D'].replace("D. ", "")
choices = [cA, cB, cC, cD]
random.shuffle(choices)
prompt = UserEvaluatePrompt4Choices_zh.format(story=d['故事'], question=d['问题'], choice_a=choices[0], choice_b=choices[1], choice_c=choices[2], choice_d=choices[3])
map = {"A": "", "B": "", "C": "", "D": ""}

if choices[0] == cA:
map['A'] = 'A'
elif choices[0] == cB:
map['A'] = 'B'
elif choices[0] == cC:
map['A'] = 'C'
elif choices[0] == cD:
map['A'] = 'D'

if choices[1] == cA:
map['B'] = 'A'
elif choices[1] == cB:
map['B'] = 'B'
elif choices[1] == cC:
map['B'] = 'C'
elif choices[1] == cD:
map['B'] = 'D'

if choices[2] == cA:
map['C'] = 'A'
elif choices[2] == cB:
map['C'] = 'B'
elif choices[2] == cC:
map['C'] = 'C'
elif choices[2] == cD:
map['C'] = 'D'

if choices[3] == cA:
map['D'] = 'A'
elif choices[3] == cB:
map['D'] = 'B'
elif choices[3] == cC:
map['D'] = 'C'
elif choices[3] == cD:
map['D'] = 'D'
else:
cA = d['OPTION-A'].replace("A. ", "")
cB = d['OPTION-B'].replace("B. ", "")
cC = d['OPTION-C'].replace("C. ", "")
cD = d['OPTION-D'].replace("D. ", "")
choices = [cA, cB, cC, cD]
random.shuffle(choices)
prompt = UserEvaluatePrompt4Choices_en.format(story=d['STORY'], question=d['QUESTION'], choice_a=choices[0], choice_b=choices[1], choice_c=choices[2], choice_d=choices[3])
map = {"A": "", "B": "", "C": "", "D": ""}

if choices[0] == cA:
map['A'] = 'A'
elif choices[0] == cB:
map['A'] = 'B'
elif choices[0] == cC:
map['A'] = 'C'
elif choices[0] == cD:
map['A'] = 'D'

if choices[1] == cA:
map['B'] = 'A'
elif choices[1] == cB:
map['B'] = 'B'
elif choices[1] == cC:
map['B'] = 'C'
elif choices[1] == cD:
map['B'] = 'D'

if choices[2] == cA:
map['C'] = 'A'
elif choices[2] == cB:
map['C'] = 'B'
elif choices[2] == cC:
map['C'] = 'C'
elif choices[2] == cD:
map['C'] = 'D'

if choices[3] == cA:
map['D'] = 'A'
elif choices[3] == cB:
map['D'] = 'B'
elif choices[3] == cC:
map['D'] = 'C'
elif choices[3] == cD:
map['D'] = 'D'
return map, prompt


def format_prompt_2(d, args):
"""
Format prompt for 2-choice questions.

Args:
d: Data dictionary containing the question and options
args: Arguments containing language setting

Returns:
tuple: (mapping_dict, formatted_prompt)
"""
if args.language == 'zh':
cA = d['选项A'].replace("A. ", "")
cB = d['选项B'].replace("B. ", "")
choices = [cA, cB]
random.shuffle(choices)
prompt = UserEvaluatePrompt2Choices_zh.format(story=d['故事'], question=d['问题'], choice_a=choices[0], choice_b=choices[1])
map = {"A": "", "B": "", "C": "", "D": ""}
if choices[0] == cA:
map['A'] = 'A'
elif choices[0] == cB:
map['A'] = 'B'

if choices[1] == cA:
map['B'] = 'A'
elif choices[1] == cB:
map['B'] = 'B'
else:
cA = d['OPTION-A'].replace("A. ", "")
cB = d['OPTION-B'].replace("B. ", "")
choices = [cA, cB]
random.shuffle(choices)
prompt = UserEvaluatePrompt2Choices_en.format(story=d['STORY'], question=d['QUESTION'], choice_a=choices[0], choice_b=choices[1])
map = {"A": "", "B": "", "C": "", "D": ""}
if choices[0] == cA:
map['A'] = 'A'
elif choices[0] == cB:
map['A'] = 'B'

if choices[1] == cA:
map['B'] = 'A'
elif choices[1] == cB:
map['B'] = 'B'

return map, prompt


def get_system_prompt(args):
"""
Get the appropriate system prompt based on language and CoT settings.

Args:
args: Arguments containing language and cot settings

Returns:
str: The formatted system prompt
"""
if args.language == "zh":
return SystemEvaluatePrompt_zh_cot if args.cot else SystemEvaluatePrompt_zh
else:
return SystemEvaluatePrompt_en_cot if args.cot else SystemEvaluatePrompt_en


def has_four_choices(d):
"""
Check if the data entry has four answer options.

Args:
d: Data dictionary containing the question and options

Returns:
bool: True if 4-choice options exist, False for 2-choice
"""
return ('选项C' in d and d.get('选项C')) or ('OPTION-C' in d and d.get('OPTION-C'))
4 changes: 1 addition & 3 deletions eval_huggingface.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
python3 run_huggingface.py \
--task "" \
--model_name "" \
--api_base "" \
--api_key "" \
--language "zh" \
--cot True \
--try_times 5 \
--output_dir ./results \
--output_path ./results \
73 changes: 55 additions & 18 deletions get_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
import os

def most_common_element(lst):
"""
Find the most common element in a list.

Args:
lst: List of elements

Returns:
The most common element, or None if list is empty
"""
if not lst:
return None
element_freq = {}
for item in lst:
element_freq[item] = element_freq.get(item, 0) + 1
Expand All @@ -11,6 +22,17 @@ def most_common_element(lst):


def extract_answer(text):
"""
Extract answer from model output.

Args:
text: Model output text

Returns:
Extracted answer (A, B, C, or D)
"""
if not text:
return "A"
if "[[A]]" in text:
return "A"
elif "[[B]]" in text:
Expand Down Expand Up @@ -45,39 +67,54 @@ def extract_answer(text):
parser.add_argument("--input_path", type=str, default="")
parser.add_argument("--try_times", type=int, default=5)
args = parser.parse_args()

files = os.listdir("./results")
acc_per_task = {}
cnt_per_task = {}

acc_per_ability = {}
cnt_per_ability = {}

for file in files:
with open(f"./results/{file}", "r", encoding='utf-8') as f:
data = [json.loads(line) for line in f.readlines()]

answers = ["" for _ in range(len(data) // args.try_times)]
preds = [[] for _ in range(len(data) // args.try_times)]
abilities = ["" for _ in range(len(data) // args.try_times)]

# Find max index to properly size arrays
max_idx = max(d.get('idx', 0) for d in data) + 1 if data else 0

answers = [""] * max_idx
preds = [[] for _ in range(max_idx)]
abilities = [""] * max_idx

for d in data:
preds[d['idx']].append(d['map'][extract_answer(d['output'])])
if answers[d['idx']] == "":
answers[d['idx']] = d['answer']

if abilities[d['idx']] == "":
abilities[d['idx']] = d['data']['能力\nABILITY']


for i in range(len(data) // args.try_times):
idx = d.get('idx', 0)
extracted = extract_answer(d.get('output', ''))
# Safely get mapped answer, default to extracted if not in map
mapped = d.get('map', {}).get(extracted, extracted)
if mapped: # Only append non-empty values
preds[idx].append(mapped)
if answers[idx] == "":
answers[idx] = d.get('answer', '')

if abilities[idx] == "":
abilities[idx] = d.get('data', {}).get('能力\nABILITY', '')

# Count valid samples
valid_samples = sum(1 for i in range(max_idx) if answers[i])

for i in range(max_idx):
if not answers[i]:
continue # Skip empty samples

task = file.split("_")[0]
ability = abilities[i]
ability = abilities[i] if abilities[i] else "Unknown"

cnt_per_task[task] = cnt_per_task.get(task, 0) + 1
cnt_per_ability[ability] = cnt_per_ability.get(ability, 0) + 1


if answers[i] == most_common_element(preds[i]):
# Get most common prediction
most_common = most_common_element(preds[i])
if most_common and answers[i] == most_common:
acc_per_task[task] = acc_per_task.get(task, 0) + 1
acc_per_ability[ability] = acc_per_ability.get(ability, 0) + 1

Expand Down
Loading