From fe1fe1756f1a55813fa49b15399b5c31316f7663 Mon Sep 17 00:00:00 2001 From: yuyuan12138 <2797274057@qq.com> Date: Tue, 27 Jan 2026 13:08:42 +0800 Subject: [PATCH 1/4] update --- CLAUDE.md | 126 ++++++++++++++++++++++++++ common.py | 186 ++++++++++++++++++++++++++++++++++++++ eval_huggingface.sh | 4 +- run_api.py | 216 +++++++++++++------------------------------- run_huggingface.py | 160 +++----------------------------- 5 files changed, 390 insertions(+), 302 deletions(-) create mode 100644 CLAUDE.md create mode 100644 common.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..fa0a9ce --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,126 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +ToMBench is a bilingual (Chinese/English) benchmark for evaluating Theory of Mind (ToM) capabilities in Large Language Models. It contains 2,860 testing samples across 21 tasks, covering 6 ability categories and 31 specific abilities from the ATOMS framework. + +**IMPORTANT**: Use ToMBench for evaluation purposes only. Do NOT use ToMBench data for training to prevent data contamination. + +## Core Architecture + +The evaluation pipeline consists of three main components: + +1. **Data Layer** (`data/`): 21 JSONL files containing stories, questions, and multiple-choice options +2. **Evaluation Scripts**: `run_api.py` (API-based) and `run_huggingface.py` (local models) +3. **Results Analysis**: `get_results.py` aggregates multiple trials and calculates accuracy + +### Evaluation Flow + +``` +data/*.jsonl → run_api.py/run_huggingface.py → results/*.jsonl → get_results.py → results.json +``` + +Each sample is evaluated `try_times` (default: 5) times with randomized answer options to mitigate position bias. Final answers use majority voting. + +### Key Design Patterns + +- **Option Randomization**: Answer choices are shuffled each trial, with a mapping dict tracking original positions +- **Bilingual Prompts**: All prompts in `prompts.py` exist in both Chinese (`_zh`) and English (`_en`) +- **Chain-of-Thought**: Optional CoT mode via `--cot` flag uses different system prompts +- **Concurrent Processing**: `run_api.py` uses ThreadPoolExecutor (32 workers) for parallel API calls + +## Commands + +### API-based Evaluation (OpenAI, etc.) + +```bash +python3 run_api.py \ + --task "" \ + --model_name gpt-4-1106-preview \ + --api_base "" \ + --api_key "" \ + --language "zh" \ + --cot True \ + --try_times 5 \ + --output_path ./results +``` + +### Local Model Evaluation (HuggingFace) + +```bash +python3 run_huggingface.py \ + --task "" \ + --model_name "" \ + --language "zh" \ + --cot True \ + --try_times 5 +``` + +### Results Aggregation + +```bash +python3 get_results.py --try_times 5 +``` + +Output: `results.json` containing accuracy per task and per ability. + +### Running a Specific Task + +Replace `--task ""` with a specific task name (e.g., `--task "False Belief Task.jsonl"`) + +## Data Format + +Each JSONL entry contains bilingual fields: + +**Chinese fields:** +- `故事` (story), `问题` (question), `选项A/B/C/D` (options), `答案` (answer), `能力` (ability) + +**English fields:** +- `STORY`, `QUESTION`, `OPTION-A/B/C/D`, `ANSWER` + +**Metadata:** +- `INDEX` (sample number) + +## Argument Reference + +| Argument | Description | Default | +|----------|-------------|---------| +| `--task` | Specific task file (empty = all tasks) | "" | +| `--model_name` | Model identifier | Required | +| `--language` | "zh" or "en" | "zh" | +| `--cot` | Enable chain-of-thought reasoning | False | +| `--try_times` | Number of trials per sample | 5 | +| `--seed` | Random seed for reproducibility | 42 | +| `--output_path` | Results directory | "./results" | + +## Prompt Templates + +All prompts are defined in `prompts.py`: +- `SystemEvaluatePrompt_zh/en`: Standard evaluation prompt +- `SystemEvaluatePrompt_zh_cot/en_cot`: Chain-of-thought variant +- `UserEvaluatePrompt4Choices_zh/en`: 4-option format +- `UserEvaluatePrompt2Choices_zh/en`: 2-option format + +Models must output answers in format `[[A]]`, `[[B]]`, etc. Answer extraction in `get_results.py:extract_answer()` handles this parsing. + +## 8 ToM Tasks + +1. **Unexpected Outcome Test**: Infer emotions when expectations are violated +2. **Scalar Implicature Task**: Understand pragmatic implications (e.g., "some" → "not all") +3. **Persuasion Story Task**: Choose effective persuasion strategies +4. **False Belief Task**: Distinguish true vs. false beliefs +5. **Ambiguous Story Task**: Infer mental states in uncertain situations +6. **Hinting Test**: Infer mental states from indirect hints +7. **Strange Story Task**: Handle complex social communication (lies, irony, jokes) +8. **Faux-pas Recognition Test**: Recognize social norm violations + +## 6 ATOMS Ability Categories + +- **Emotion** (7 abilities): Typical/atypical/discrepant/mixed/hidden/moral emotions, regulation +- **Desire** (4 abilities): Discrepant/multiple desires, influence on emotions/actions, contradiction +- **Intention** (4 abilities): Failed actions, discrepant intentions, action prediction, explanations +- **Knowledge** (4 abilities): Pretend play, percepts, information, attention links +- **Belief** (6 abilities): Content/location/identity/second-order/sequence false beliefs, belief-based actions +- **Non-Literal Communication** (6 abilities): Irony, egocentric/white/involuntary lies, humor, faux pas diff --git a/common.py b/common.py new file mode 100644 index 0000000..4459ba7 --- /dev/null +++ b/common.py @@ -0,0 +1,186 @@ +""" +Common utility functions for ToMBench evaluation scripts. +This module contains shared functions used by both API-based and HuggingFace evaluation. +""" +import random +from prompts import * + + +def format_prompt_4(d, args): + """ + Format prompt for 4-choice questions. + + Args: + d: Data dictionary containing the question and options + args: Arguments containing language setting + + Returns: + tuple: (mapping_dict, formatted_prompt) + """ + if args.language == 'zh': + cA = d['选项A'].replace("A. ", "") + cB = d['选项B'].replace("B. ", "") + cC = d['选项C'].replace("C. ", "") + cD = d['选项D'].replace("D. ", "") + choices = [cA, cB, cC, cD] + random.shuffle(choices) + prompt = UserEvaluatePrompt4Choices_zh.format(story=d['故事'], question=d['问题'], choice_a=choices[0], choice_b=choices[1], choice_c=choices[2], choice_d=choices[3]) + map = {"A": "", "B": "", "C": "", "D": ""} + + if choices[0] == cA: + map['A'] = 'A' + elif choices[0] == cB: + map['A'] = 'B' + elif choices[0] == cC: + map['A'] = 'C' + elif choices[0] == cD: + map['A'] = 'D' + + if choices[1] == cA: + map['B'] = 'A' + elif choices[1] == cB: + map['B'] = 'B' + elif choices[1] == cC: + map['B'] = 'C' + elif choices[1] == cD: + map['B'] = 'D' + + if choices[2] == cA: + map['C'] = 'A' + elif choices[2] == cB: + map['C'] = 'B' + elif choices[2] == cC: + map['C'] = 'C' + elif choices[2] == cD: + map['C'] = 'D' + + if choices[3] == cA: + map['D'] = 'A' + elif choices[3] == cB: + map['D'] = 'B' + elif choices[3] == cC: + map['D'] = 'C' + elif choices[3] == cD: + map['D'] = 'D' + else: + cA = d['OPTION-A'].replace("A. ", "") + cB = d['OPTION-B'].replace("B. ", "") + cC = d['OPTION-C'].replace("C. ", "") + cD = d['OPTION-D'].replace("D. ", "") + choices = [cA, cB, cC, cD] + random.shuffle(choices) + prompt = UserEvaluatePrompt4Choices_en.format(story=d['STORY'], question=d['QUESTION'], choice_a=choices[0], choice_b=choices[1], choice_c=choices[2], choice_d=choices[3]) + map = {"A": "", "B": "", "C": "", "D": ""} + + if choices[0] == cA: + map['A'] = 'A' + elif choices[0] == cB: + map['A'] = 'B' + elif choices[0] == cC: + map['A'] = 'C' + elif choices[0] == cD: + map['A'] = 'D' + + if choices[1] == cA: + map['B'] = 'A' + elif choices[1] == cB: + map['B'] = 'B' + elif choices[1] == cC: + map['B'] = 'C' + elif choices[1] == cD: + map['B'] = 'D' + + if choices[2] == cA: + map['C'] = 'A' + elif choices[2] == cB: + map['C'] = 'B' + elif choices[2] == cC: + map['C'] = 'C' + elif choices[2] == cD: + map['C'] = 'D' + + if choices[3] == cA: + map['D'] = 'A' + elif choices[3] == cB: + map['D'] = 'B' + elif choices[3] == cC: + map['D'] = 'C' + elif choices[3] == cD: + map['D'] = 'D' + return map, prompt + + +def format_prompt_2(d, args): + """ + Format prompt for 2-choice questions. + + Args: + d: Data dictionary containing the question and options + args: Arguments containing language setting + + Returns: + tuple: (mapping_dict, formatted_prompt) + """ + if args.language == 'zh': + cA = d['选项A'].replace("A. ", "") + cB = d['选项B'].replace("B. ", "") + choices = [cA, cB] + random.shuffle(choices) + prompt = UserEvaluatePrompt2Choices_zh.format(story=d['故事'], question=d['问题'], choice_a=choices[0], choice_b=choices[1]) + map = {"A": "", "B": "", "C": "", "D": ""} + if choices[0] == cA: + map['A'] = 'A' + elif choices[0] == cB: + map['A'] = 'B' + + if choices[1] == cA: + map['B'] = 'A' + elif choices[1] == cB: + map['B'] = 'B' + else: + cA = d['OPTION-A'].replace("A. ", "") + cB = d['OPTION-B'].replace("B. ", "") + choices = [cA, cB] + random.shuffle(choices) + prompt = UserEvaluatePrompt2Choices_en.format(story=d['STORY'], question=d['QUESTION'], choice_a=choices[0], choice_b=choices[1]) + map = {"A": "", "B": "", "C": "", "D": ""} + if choices[0] == cA: + map['A'] = 'A' + elif choices[0] == cB: + map['A'] = 'B' + + if choices[1] == cA: + map['B'] = 'A' + elif choices[1] == cB: + map['B'] = 'B' + + return map, prompt + + +def get_system_prompt(args): + """ + Get the appropriate system prompt based on language and CoT settings. + + Args: + args: Arguments containing language and cot settings + + Returns: + str: The formatted system prompt + """ + if args.language == "zh": + return SystemEvaluatePrompt_zh_cot if args.cot else SystemEvaluatePrompt_zh + else: + return SystemEvaluatePrompt_en_cot if args.cot else SystemEvaluatePrompt_en + + +def has_four_choices(d): + """ + Check if the data entry has four answer options. + + Args: + d: Data dictionary containing the question and options + + Returns: + bool: True if 4-choice options exist, False for 2-choice + """ + return ('选项C' in d and d.get('选项C')) or ('OPTION-C' in d and d.get('OPTION-C')) diff --git a/eval_huggingface.sh b/eval_huggingface.sh index fe41f52..374aa83 100644 --- a/eval_huggingface.sh +++ b/eval_huggingface.sh @@ -1,9 +1,7 @@ python3 run_huggingface.py \ --task "" \ --model_name "" \ - --api_base "" \ - --api_key "" \ --language "zh" \ --cot True \ --try_times 5 \ - --output_dir ./results \ \ No newline at end of file + --output_path ./results \ \ No newline at end of file diff --git a/run_api.py b/run_api.py index 6318a10..882231d 100644 --- a/run_api.py +++ b/run_api.py @@ -1,6 +1,7 @@ import json import random import time +import logging from tqdm import tqdm import multiprocessing from concurrent.futures import ThreadPoolExecutor @@ -8,13 +9,19 @@ import os import argparse from prompts import * +from common import format_prompt_4, format_prompt_2, get_system_prompt, has_four_choices + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) class ChatGPTProcessor: - def __init__(self): + def __init__(self, api_base="", api_key=""): self.lock = multiprocessing.Lock() - openai.api_key = "" - openai.api_base = "" + # Use provided values or fall back to environment variables + openai.api_key = api_key or os.environ.get("OPENAI_API_KEY", "") + openai.api_base = api_base or os.environ.get("OPENAI_API_BASE", "") def read_jsonl(self, input_file): with open(input_file, 'r', encoding='utf-8') as f: @@ -22,159 +29,70 @@ def read_jsonl(self, input_file): return list(map(json.loads, tqdm(lines, desc='Reading...'))) def write_to_json(self, data, file_path): + """ + Thread-safe JSON file writing. + + Args: + data: Dictionary to write to JSON file + file_path: Path to the output file + """ with self.lock: + # Ensure directory exists + os.makedirs(os.path.dirname(file_path), exist_ok=True) + # Write to file with proper encoding with open(file_path, 'a', encoding='utf-8') as file: json.dump(data, file, ensure_ascii=False) file.write('\n') - def multiple_gpt(self, payload): - while True: + def multiple_gpt(self, payload, max_retries=5): + """ + Send request to API with retry logic and error logging. + + Args: + payload: Dictionary containing request parameters + max_retries: Maximum number of retry attempts (default: 5) + """ + retries = 0 + while retries < max_retries: try: - chat_completion = openai.ChatCompletion.create(model=payload['model'], temperature=0, messages=payload['messages']) + chat_completion = openai.ChatCompletion.create( + model=payload['model'], + temperature=0, + messages=payload['messages'] + ) data = payload.copy() data['messages'] = payload['messages'] data['answer'] = payload['answer'] data['output'] = chat_completion.choices[0].message.content break + except openai.error.RateLimitError as e: + retries += 1 + wait_time = random.randint(2, 5) * retries + logger.warning(f"Rate limit error (attempt {retries}/{max_retries}): {e}. Waiting {wait_time}s...") + time.sleep(wait_time) + except openai.error.APIError as e: + retries += 1 + wait_time = random.randint(1, 3) + logger.warning(f"API error (attempt {retries}/{max_retries}): {e}. Waiting {wait_time}s...") + time.sleep(wait_time) + except openai.error.AuthenticationError as e: + logger.error(f"Authentication error: {e}. Please check your API key.") + raise except Exception as e: + retries += 1 + logger.error(f"Unexpected error (attempt {retries}/{max_retries}): {e}") time.sleep(random.randint(1, 3)) + else: + logger.error(f"Max retries ({max_retries}) exceeded for payload idx={payload['idx']}, number={payload['number']}") + # Write error result to maintain data continuity + data = payload.copy() + data['output'] = f"ERROR: Max retries exceeded" self.write_to_json(data, payload['save_path']) time.sleep(random.randint(1, 3)) -def format_prompt_4(d, args): - if args.language == 'zh': - cA = d['选项A'].replace("A. ", "") - cB = d['选项B'].replace("B. ", "") - cC = d['选项C'].replace("C. ", "") - cD = d['选项D'].replace("D. ", "") - choices = [cA, cB, cC, cD] - random.shuffle(choices) - prompt = UserEvaluatePrompt4Choices_zh.format(story=d['故事'], question=d['问题'], choice_a=choices[0], choice_b=choices[1], choice_c=choices[2], choice_d=choices[3]) - map = {"A": "", "B": "", "C": "", "D": ""} - - if choices[0] == cA: - map['A'] = 'A' - elif choices[0] == cB: - map['A'] = 'B' - elif choices[0] == cC: - map['A'] = 'C' - elif choices[0] == cD: - map['A'] = 'D' - - if choices[1] == cA: - map['B'] = 'A' - elif choices[1] == cB: - map['B'] = 'B' - elif choices[1] == cC: - map['B'] = 'C' - elif choices[1] == cD: - map['B'] = 'D' - - if choices[2] == cA: - map['C'] = 'A' - elif choices[2] == cB: - map['C'] = 'B' - elif choices[2] == cC: - map['C'] = 'C' - elif choices[2] == cD: - map['C'] = 'D' - - if choices[3] == cA: - map['D'] = 'A' - elif choices[3] == cB: - map['D'] = 'B' - elif choices[3] == cC: - map['D'] = 'C' - elif choices[3] == cD: - map['D'] = 'D' - else: - cA = d['OPTION-A'].replace("A. ", "") - cB = d['OPTION-B'].replace("B. ", "") - cC = d['OPTION-C'].replace("C. ", "") - cD = d['OPTION-D'].replace("D. ", "") - choices = [cA, cB, cC, cD] - random.shuffle(choices) - prompt = UserEvaluatePrompt4Choices_en.format(story=d['STORY'], question=d['QUESTION'], choice_a=choices[0], choice_b=choices[1], choice_c=choices[2], choice_d=choices[3]) - map = {"A": "", "B": "", "C": "", "D": ""} - - if choices[0] == cA: - map['A'] = 'A' - elif choices[0] == cB: - map['A'] = 'B' - elif choices[0] == cC: - map['A'] = 'C' - elif choices[0] == cD: - map['A'] = 'D' - - if choices[1] == cA: - map['B'] = 'A' - elif choices[1] == cB: - map['B'] = 'B' - elif choices[1] == cC: - map['B'] = 'C' - elif choices[1] == cD: - map['B'] = 'D' - - if choices[2] == cA: - map['C'] = 'A' - elif choices[2] == cB: - map['C'] = 'B' - elif choices[2] == cC: - map['C'] = 'C' - elif choices[2] == cD: - map['C'] = 'D' - - if choices[3] == cA: - map['D'] = 'A' - elif choices[3] == cB: - map['D'] = 'B' - elif choices[3] == cC: - map['D'] = 'C' - elif choices[3] == cD: - map['D'] = 'D' - return map, prompt - - -def format_prompt_2(d, args): - if args.language == 'zh': - cA = d['选项A'].replace("A. ", "") - cB = d['选项B'].replace("B. ", "") - choices = [cA, cB] - random.shuffle(choices) - prompt = UserEvaluatePrompt2Choices_zh.format(story=d['故事'], question=d['问题'], choice_a=choices[0], choice_b=choices[1]) - map = {"A": "", "B": "", "C": "", "D": ""} - if choices[0] == cA: - map['A'] = 'A' - elif choices[0] == cB: - map['A'] = 'B' - - if choices[1] == cA: - map['B'] = 'A' - elif choices[1] == cB: - map['B'] = 'B' - else: - cA = d['OPTION-A'].replace("A. ", "") - cB = d['OPTION-B'].replace("B. ", "") - choices = [cA, cB] - random.shuffle(choices) - prompt = UserEvaluatePrompt2Choices_en.format(story=d['STORY'], question=d['QUESTION'], choice_a=choices[0], choice_b=choices[1]) - map = {"A": "", "B": "", "C": "", "D": ""} - if choices[0] == cA: - map['A'] = 'A' - elif choices[0] == cB: - map['A'] = 'B' - - if choices[1] == cA: - map['B'] = 'A' - elif choices[1] == cB: - map['B'] = 'B' - - return map, prompt - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="") @@ -189,7 +107,7 @@ def format_prompt_2(d, args): args = parser.parse_args() random.seed(args.seed) - processor = ChatGPTProcessor() + processor = ChatGPTProcessor(api_base=args.api_base, api_key=args.api_key) files = os.listdir("./data") if args.task != "": @@ -200,28 +118,20 @@ def format_prompt_2(d, args): try: with open(f"data/{file}", "r", encoding='utf-8') as f: data = [json.loads(line) for line in f.readlines()] - except: + except (FileNotFoundError, json.JSONDecodeError, IOError) as e: + print(f"Error reading {file}: {e}") continue payloads = [] for i, d in enumerate(data): for j in range(args.try_times): - if d['选项C'] != None: + # Check if 4-choice options exist (supports both Chinese and English keys) + if has_four_choices(d): maps, prompt = format_prompt_4(d, args) else: maps, prompt = format_prompt_2(d, args) - - system_prompt = "" - if args.language == "zh": - if args.cot == False: - system_prompt = SystemEvaluatePrompt_zh - else: - system_prompt = SystemEvaluatePrompt_zh_cot - else: - if args.cot == False: - system_prompt = SystemEvaluatePrompt_en - else: - system_prompt = SystemEvaluatePrompt_en_cot - + + system_prompt = get_system_prompt(args) + payload = { "model": args.model_name, "stream": False, diff --git a/run_huggingface.py b/run_huggingface.py index ac4e9bf..c1a99e0 100644 --- a/run_huggingface.py +++ b/run_huggingface.py @@ -5,137 +5,7 @@ from prompts import * from tqdm import tqdm import os - - -def format_prompt_4(d, args): - if args.language == 'zh': - cA = d['选项A'].replace("A. ", "") - cB = d['选项B'].replace("B. ", "") - cC = d['选项C'].replace("C. ", "") - cD = d['选项D'].replace("D. ", "") - choices = [cA, cB, cC, cD] - random.shuffle(choices) - prompt = UserEvaluatePrompt4Choices_zh.format(story=d['故事'], question=d['问题'], choice_a=choices[0], choice_b=choices[1], choice_c=choices[2], choice_d=choices[3]) - map = {"A": "", "B": "", "C": "", "D": ""} - - if choices[0] == cA: - map['A'] = 'A' - elif choices[0] == cB: - map['A'] = 'B' - elif choices[0] == cC: - map['A'] = 'C' - elif choices[0] == cD: - map['A'] = 'D' - - if choices[1] == cA: - map['B'] = 'A' - elif choices[1] == cB: - map['B'] = 'B' - elif choices[1] == cC: - map['B'] = 'C' - elif choices[1] == cD: - map['B'] = 'D' - - if choices[2] == cA: - map['C'] = 'A' - elif choices[2] == cB: - map['C'] = 'B' - elif choices[2] == cC: - map['C'] = 'C' - elif choices[2] == cD: - map['C'] = 'D' - - if choices[3] == cA: - map['D'] = 'A' - elif choices[3] == cB: - map['D'] = 'B' - elif choices[3] == cC: - map['D'] = 'C' - elif choices[3] == cD: - map['D'] = 'D' - else: - cA = d['OPTION-A'].replace("A. ", "") - cB = d['OPTION-B'].replace("B. ", "") - cC = d['OPTION-C'].replace("C. ", "") - cD = d['OPTION-D'].replace("D. ", "") - choices = [cA, cB, cC, cD] - random.shuffle(choices) - prompt = UserEvaluatePrompt4Choices_en.format(story=d['STORY'], question=d['QUESTION'], choice_a=choices[0], choice_b=choices[1], choice_c=choices[2], choice_d=choices[3]) - map = {"A": "", "B": "", "C": "", "D": ""} - - if choices[0] == cA: - map['A'] = 'A' - elif choices[0] == cB: - map['A'] = 'B' - elif choices[0] == cC: - map['A'] = 'C' - elif choices[0] == cD: - map['A'] = 'D' - - if choices[1] == cA: - map['B'] = 'A' - elif choices[1] == cB: - map['B'] = 'B' - elif choices[1] == cC: - map['B'] = 'C' - elif choices[1] == cD: - map['B'] = 'D' - - if choices[2] == cA: - map['C'] = 'A' - elif choices[2] == cB: - map['C'] = 'B' - elif choices[2] == cC: - map['C'] = 'C' - elif choices[2] == cD: - map['C'] = 'D' - - if choices[3] == cA: - map['D'] = 'A' - elif choices[3] == cB: - map['D'] = 'B' - elif choices[3] == cC: - map['D'] = 'C' - elif choices[3] == cD: - map['D'] = 'D' - return map, prompt - - -def format_prompt_2(d, args): - if args.language == 'zh': - cA = d['选项A'].replace("A. ", "") - cB = d['选项B'].replace("B. ", "") - choices = [cA, cB] - random.shuffle(choices) - prompt = UserEvaluatePrompt2Choices_zh.format(story=d['故事'], question=d['问题'], choice_a=choices[0], choice_b=choices[1]) - map = {"A": "", "B": "", "C": "", "D": ""} - if choices[0] == cA: - map['A'] = 'A' - elif choices[0] == cB: - map['A'] = 'B' - - if choices[1] == cA: - map['B'] = 'A' - elif choices[1] == cB: - map['B'] = 'B' - else: - cA = d['OPTION-A'].replace("A. ", "") - cB = d['OPTION-B'].replace("B. ", "") - choices = [cA, cB] - random.shuffle(choices) - prompt = UserEvaluatePrompt2Choices_en.format(story=d['STORY'], question=d['QUESTION'], choice_a=choices[0], choice_b=choices[1]) - map = {"A": "", "B": "", "C": "", "D": ""} - if choices[0] == cA: - map['A'] = 'A' - elif choices[0] == cB: - map['A'] = 'B' - - if choices[1] == cA: - map['B'] = 'A' - elif choices[1] == cB: - map['B'] = 'B' - - return map, prompt +from common import format_prompt_4, format_prompt_2, get_system_prompt, has_four_choices if __name__ == "__main__": @@ -146,12 +16,19 @@ def format_prompt_2(d, args): parser.add_argument("--try_times", type=int, default=5) parser.add_argument("--cot", type=bool, default=False) parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--trust_remote_code", type=bool, default=False, + help="Allow trusting remote code. Only enable for trusted model sources.") args = parser.parse_args() random.seed(args.seed) - tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) - model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True).half().cuda() + # Security notice: trust_remote_code allows execution of arbitrary code from the model repository + # Only enable this for trusted model sources (e.g., HuggingFace hub models you trust) + if args.trust_remote_code: + print("WARNING: trust_remote_code=True - Only use this with trusted model sources!") + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code) + model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code).half().cuda() model_name = args.model_name.split("/")[-1] @@ -167,22 +44,13 @@ def format_prompt_2(d, args): print(file) for i, d in tqdm(enumerate(data[:10])): for j in range(args.try_times): - if d['选项C'] != None: + # Check if 4-choice options exist (supports both Chinese and English keys) + if has_four_choices(d): maps, prompt = format_prompt_4(d, args) else: maps, prompt = format_prompt_2(d, args) - - system_prompt = "" - if args.language == "zh": - if args.cot == False: - system_prompt = SystemEvaluatePrompt_zh - else: - system_prompt = SystemEvaluatePrompt_zh_cot - else: - if args.cot == False: - system_prompt = SystemEvaluatePrompt_en - else: - system_prompt = SystemEvaluatePrompt_en_cot + + system_prompt = get_system_prompt(args) messages = [ {"role": "system", "content": system_prompt}, From e476398248f35f43a31381f842ad90685c3d7987 Mon Sep 17 00:00:00 2001 From: yuyuan12138 <2797274057@qq.com> Date: Tue, 27 Jan 2026 13:09:31 +0800 Subject: [PATCH 2/4] l --- .gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..726fb45 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +claude.md From f20d4476d6074250185a20a835b09f4cb877db93 Mon Sep 17 00:00:00 2001 From: yuyuan12138 <2797274057@qq.com> Date: Tue, 27 Jan 2026 13:09:53 +0800 Subject: [PATCH 3/4] delete --- CLAUDE.md | 126 ------------------------------------------------------ 1 file changed, 126 deletions(-) delete mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index fa0a9ce..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,126 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -ToMBench is a bilingual (Chinese/English) benchmark for evaluating Theory of Mind (ToM) capabilities in Large Language Models. It contains 2,860 testing samples across 21 tasks, covering 6 ability categories and 31 specific abilities from the ATOMS framework. - -**IMPORTANT**: Use ToMBench for evaluation purposes only. Do NOT use ToMBench data for training to prevent data contamination. - -## Core Architecture - -The evaluation pipeline consists of three main components: - -1. **Data Layer** (`data/`): 21 JSONL files containing stories, questions, and multiple-choice options -2. **Evaluation Scripts**: `run_api.py` (API-based) and `run_huggingface.py` (local models) -3. **Results Analysis**: `get_results.py` aggregates multiple trials and calculates accuracy - -### Evaluation Flow - -``` -data/*.jsonl → run_api.py/run_huggingface.py → results/*.jsonl → get_results.py → results.json -``` - -Each sample is evaluated `try_times` (default: 5) times with randomized answer options to mitigate position bias. Final answers use majority voting. - -### Key Design Patterns - -- **Option Randomization**: Answer choices are shuffled each trial, with a mapping dict tracking original positions -- **Bilingual Prompts**: All prompts in `prompts.py` exist in both Chinese (`_zh`) and English (`_en`) -- **Chain-of-Thought**: Optional CoT mode via `--cot` flag uses different system prompts -- **Concurrent Processing**: `run_api.py` uses ThreadPoolExecutor (32 workers) for parallel API calls - -## Commands - -### API-based Evaluation (OpenAI, etc.) - -```bash -python3 run_api.py \ - --task "" \ - --model_name gpt-4-1106-preview \ - --api_base "" \ - --api_key "" \ - --language "zh" \ - --cot True \ - --try_times 5 \ - --output_path ./results -``` - -### Local Model Evaluation (HuggingFace) - -```bash -python3 run_huggingface.py \ - --task "" \ - --model_name "" \ - --language "zh" \ - --cot True \ - --try_times 5 -``` - -### Results Aggregation - -```bash -python3 get_results.py --try_times 5 -``` - -Output: `results.json` containing accuracy per task and per ability. - -### Running a Specific Task - -Replace `--task ""` with a specific task name (e.g., `--task "False Belief Task.jsonl"`) - -## Data Format - -Each JSONL entry contains bilingual fields: - -**Chinese fields:** -- `故事` (story), `问题` (question), `选项A/B/C/D` (options), `答案` (answer), `能力` (ability) - -**English fields:** -- `STORY`, `QUESTION`, `OPTION-A/B/C/D`, `ANSWER` - -**Metadata:** -- `INDEX` (sample number) - -## Argument Reference - -| Argument | Description | Default | -|----------|-------------|---------| -| `--task` | Specific task file (empty = all tasks) | "" | -| `--model_name` | Model identifier | Required | -| `--language` | "zh" or "en" | "zh" | -| `--cot` | Enable chain-of-thought reasoning | False | -| `--try_times` | Number of trials per sample | 5 | -| `--seed` | Random seed for reproducibility | 42 | -| `--output_path` | Results directory | "./results" | - -## Prompt Templates - -All prompts are defined in `prompts.py`: -- `SystemEvaluatePrompt_zh/en`: Standard evaluation prompt -- `SystemEvaluatePrompt_zh_cot/en_cot`: Chain-of-thought variant -- `UserEvaluatePrompt4Choices_zh/en`: 4-option format -- `UserEvaluatePrompt2Choices_zh/en`: 2-option format - -Models must output answers in format `[[A]]`, `[[B]]`, etc. Answer extraction in `get_results.py:extract_answer()` handles this parsing. - -## 8 ToM Tasks - -1. **Unexpected Outcome Test**: Infer emotions when expectations are violated -2. **Scalar Implicature Task**: Understand pragmatic implications (e.g., "some" → "not all") -3. **Persuasion Story Task**: Choose effective persuasion strategies -4. **False Belief Task**: Distinguish true vs. false beliefs -5. **Ambiguous Story Task**: Infer mental states in uncertain situations -6. **Hinting Test**: Infer mental states from indirect hints -7. **Strange Story Task**: Handle complex social communication (lies, irony, jokes) -8. **Faux-pas Recognition Test**: Recognize social norm violations - -## 6 ATOMS Ability Categories - -- **Emotion** (7 abilities): Typical/atypical/discrepant/mixed/hidden/moral emotions, regulation -- **Desire** (4 abilities): Discrepant/multiple desires, influence on emotions/actions, contradiction -- **Intention** (4 abilities): Failed actions, discrepant intentions, action prediction, explanations -- **Knowledge** (4 abilities): Pretend play, percepts, information, attention links -- **Belief** (6 abilities): Content/location/identity/second-order/sequence false beliefs, belief-based actions -- **Non-Literal Communication** (6 abilities): Irony, egocentric/white/involuntary lies, humor, faux pas From a801cd160e7004e8fe1555b4a1fe63cce66a0123 Mon Sep 17 00:00:00 2001 From: yuyuan12138 <2797274057@qq.com> Date: Tue, 27 Jan 2026 17:52:12 +0800 Subject: [PATCH 4/4] update --- get_results.py | 73 +++++++++++++++++++++++++++++++++++++------------- run_api.py | 24 ++++++++++------- 2 files changed, 69 insertions(+), 28 deletions(-) diff --git a/get_results.py b/get_results.py index 53a1a1e..a28b4bf 100644 --- a/get_results.py +++ b/get_results.py @@ -3,6 +3,17 @@ import os def most_common_element(lst): + """ + Find the most common element in a list. + + Args: + lst: List of elements + + Returns: + The most common element, or None if list is empty + """ + if not lst: + return None element_freq = {} for item in lst: element_freq[item] = element_freq.get(item, 0) + 1 @@ -11,6 +22,17 @@ def most_common_element(lst): def extract_answer(text): + """ + Extract answer from model output. + + Args: + text: Model output text + + Returns: + Extracted answer (A, B, C, or D) + """ + if not text: + return "A" if "[[A]]" in text: return "A" elif "[[B]]" in text: @@ -45,39 +67,54 @@ def extract_answer(text): parser.add_argument("--input_path", type=str, default="") parser.add_argument("--try_times", type=int, default=5) args = parser.parse_args() - + files = os.listdir("./results") acc_per_task = {} cnt_per_task = {} acc_per_ability = {} cnt_per_ability = {} - + for file in files: with open(f"./results/{file}", "r", encoding='utf-8') as f: data = [json.loads(line) for line in f.readlines()] - - answers = ["" for _ in range(len(data) // args.try_times)] - preds = [[] for _ in range(len(data) // args.try_times)] - abilities = ["" for _ in range(len(data) // args.try_times)] + + # Find max index to properly size arrays + max_idx = max(d.get('idx', 0) for d in data) + 1 if data else 0 + + answers = [""] * max_idx + preds = [[] for _ in range(max_idx)] + abilities = [""] * max_idx + for d in data: - preds[d['idx']].append(d['map'][extract_answer(d['output'])]) - if answers[d['idx']] == "": - answers[d['idx']] = d['answer'] - - if abilities[d['idx']] == "": - abilities[d['idx']] = d['data']['能力\nABILITY'] - - - for i in range(len(data) // args.try_times): + idx = d.get('idx', 0) + extracted = extract_answer(d.get('output', '')) + # Safely get mapped answer, default to extracted if not in map + mapped = d.get('map', {}).get(extracted, extracted) + if mapped: # Only append non-empty values + preds[idx].append(mapped) + if answers[idx] == "": + answers[idx] = d.get('answer', '') + + if abilities[idx] == "": + abilities[idx] = d.get('data', {}).get('能力\nABILITY', '') + + # Count valid samples + valid_samples = sum(1 for i in range(max_idx) if answers[i]) + + for i in range(max_idx): + if not answers[i]: + continue # Skip empty samples + task = file.split("_")[0] - ability = abilities[i] + ability = abilities[i] if abilities[i] else "Unknown" cnt_per_task[task] = cnt_per_task.get(task, 0) + 1 cnt_per_ability[ability] = cnt_per_ability.get(ability, 0) + 1 - - if answers[i] == most_common_element(preds[i]): + # Get most common prediction + most_common = most_common_element(preds[i]) + if most_common and answers[i] == most_common: acc_per_task[task] = acc_per_task.get(task, 0) + 1 acc_per_ability[ability] = acc_per_ability.get(ability, 0) + 1 diff --git a/run_api.py b/run_api.py index 882231d..9e9b47f 100644 --- a/run_api.py +++ b/run_api.py @@ -5,7 +5,7 @@ from tqdm import tqdm import multiprocessing from concurrent.futures import ThreadPoolExecutor -import openai +from openai import OpenAI import os import argparse from prompts import * @@ -20,8 +20,9 @@ class ChatGPTProcessor: def __init__(self, api_base="", api_key=""): self.lock = multiprocessing.Lock() # Use provided values or fall back to environment variables - openai.api_key = api_key or os.environ.get("OPENAI_API_KEY", "") - openai.api_base = api_base or os.environ.get("OPENAI_API_BASE", "") + api_key = api_key or os.environ.get("OPENAI_API_KEY", "") + api_base = api_base or os.environ.get("OPENAI_API_BASE", "") + self.client = OpenAI(api_key=api_key, base_url=api_base) def read_jsonl(self, input_file): with open(input_file, 'r', encoding='utf-8') as f: @@ -55,7 +56,7 @@ def multiple_gpt(self, payload, max_retries=5): retries = 0 while retries < max_retries: try: - chat_completion = openai.ChatCompletion.create( + chat_completion = self.client.chat.completions.create( model=payload['model'], temperature=0, messages=payload['messages'] @@ -65,17 +66,17 @@ def multiple_gpt(self, payload, max_retries=5): data['answer'] = payload['answer'] data['output'] = chat_completion.choices[0].message.content break - except openai.error.RateLimitError as e: + except openai.RateLimitError as e: retries += 1 wait_time = random.randint(2, 5) * retries logger.warning(f"Rate limit error (attempt {retries}/{max_retries}): {e}. Waiting {wait_time}s...") time.sleep(wait_time) - except openai.error.APIError as e: + except openai.APIError as e: retries += 1 wait_time = random.randint(1, 3) logger.warning(f"API error (attempt {retries}/{max_retries}): {e}. Waiting {wait_time}s...") time.sleep(wait_time) - except openai.error.AuthenticationError as e: + except openai.AuthenticationError as e: logger.error(f"Authentication error: {e}. Please check your API key.") raise except Exception as e: @@ -150,7 +151,10 @@ def multiple_gpt(self, payload, max_retries=5): } payloads.append(payload) - + + # Execute API calls in parallel and wait for completion with ThreadPoolExecutor(max_workers=32) as executor: - for payload in payloads: - executor.submit(processor.multiple_gpt, payload) + futures = [executor.submit(processor.multiple_gpt, payload) for payload in payloads] + # Wait for all tasks to complete + for future in futures: + future.result()