diff --git a/evals/cli/llmreport.py b/evals/cli/llmreport.py new file mode 100644 index 0000000000..15dd95471a --- /dev/null +++ b/evals/cli/llmreport.py @@ -0,0 +1,96 @@ +import argparse +import json +import pickle +import re +import glob +from io import StringIO +from pathlib import Path + +import pandas as pd +import matplotlib.pyplot as plt + + +def main() -> None: + parser = argparse.ArgumentParser(description="Report evals results") + parser.add_argument("run_id", type=str, nargs="+", help="Eval Run id") + parser.add_argument("--mlops", type=str, default=None) + parser.add_argument("--name", type=str, default="LLM_Eval") + + args = parser.parse_args() + + logfiles = [] + for run_id in args.run_id: + logfiles += glob.glob(f"/tmp/evallogs/{run_id}*/**", recursive=True) + logfiles = sorted([f for f in logfiles if Path(f).suffix == ".jsonl"]) + logger_data = {} + table_collection = [] + qa_collection = [] + + for logfile in logfiles: + with open(logfile, "r") as f: + events_df = pd.read_json(f, lines=True) + if not "final_report" in events_df.columns: + continue + final_report = events_df["final_report"].dropna().iloc[0] + + print(events_df) + run_config = events_df.loc[0, "spec"] + evalname = run_config["base_eval"] + model = run_config["completion_fns"][0].replace("/", ".") + matches_df = events_df[events_df["type"] == "match"].reset_index(drop=True) + matches_df = matches_df.join(pd.json_normalize(matches_df.data)) + + qa_collection.append({"eval": evalname, "model": model, **final_report}) + + if "file_name" in matches_df.columns: + matches_df["doi"] = [re.sub("__([0-9]+)__", r"(\1)", Path(f).stem).replace("_", "/") for f in matches_df["file_name"]] + + # TODO: compare on different completion_functions + if "jobtype" in matches_df.columns: + # Table extract tasks + accuracy_by_type_and_file = matches_df.groupby(["jobtype", "doi"])['correct'].mean().reset_index() + accuracy_by_type_and_file["model"] = model + table_collection.append(accuracy_by_type_and_file) + + accuracy_by_type = matches_df.groupby(["jobtype"])['correct'].mean().to_dict() + print(accuracy_by_type_and_file) + + logger_data = {**logger_data, **{f"Accuracy_{key}/model:{model}": value for key, value in accuracy_by_type.items()}} + + for doi, df in matches_df.groupby("doi"): + print(df) + logger_data[f"{doi.replace('/', '_')}/model:{model},context:match"] = df[df["jobtype"] != "match_all"][["correct", "expected", "picked", "jobtype"]] + match_all_data = df[df["jobtype"] == "match_all"].iloc[0, :] + logger_data[f"{doi.replace('/', '_')}/context:truth"] = pd.read_csv(StringIO(match_all_data["expected"]), header=[0, 1]) + logger_data[f"{doi.replace('/', '_')}/model:{model},context:extract"] = pd.read_csv(StringIO(match_all_data["picked"]), header=[0, 1]) \ + if df["jobtype"].iloc[0] != "match_all" else match_all_data["picked"] + else: + # Regular tasks + pass + + if len(table_collection) > 0: + accuracy_by_model_type_and_file = pd.concat(table_collection) + metrics_by_eval = pd.DataFrame(qa_collection) + accuracies = metrics_by_eval[metrics_by_eval["accuracy"] >= 0] + scores = metrics_by_eval[metrics_by_eval["score"] >= 0] + + if args.mlops: + import plotly.express as px + logger_data["TableExtraction"] = px.box(accuracy_by_model_type_and_file, + x="jobtype", y="correct", color="model", + title="Accuracy by jobtype and model") + logger_data["QA_accuracy"] = px.bar(accuracies, x="eval", y="accuracy", color="model", + title="Accuracy by eval and model") + logger_data["QA_score"] = px.bar(scores, x="eval", y="accuracy", color="model", + title="Accuracy by eval and model") + if args.mlops: + config_logger = json.load(open(args.mlops, 'r')) + if "name" not in config_logger.keys(): + config_logger["name"] = args.name + if "dp_mlops" in config_logger: + from evals.reporters.DPTracking import DPTrackingReporter + DPTrackingReporter.report_run(config_logger, {}, logger_data, step=0) + + +if __name__ == "__main__": + main() diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py index 20b7d4c3bf..b8adfcba5a 100644 --- a/evals/cli/oaieval.py +++ b/evals/cli/oaieval.py @@ -2,9 +2,14 @@ This file defines the `oaieval` CLI for running evals. """ import argparse +import json import logging +import pickle +import re import shlex import sys +from io import StringIO +from pathlib import Path from typing import Any, Mapping, Optional, Union, cast import openai @@ -229,6 +234,7 @@ def to_number(x: str) -> Union[int, float, str]: logger.info("Final report:") for key, value in result.items(): logger.info(f"{key}: {value}") + return run_spec.run_id diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py index ed50818630..b57570e0e9 100644 --- a/evals/completion_fns/openai.py +++ b/evals/completion_fns/openai.py @@ -15,6 +15,7 @@ from evals.utils.api_utils import ( openai_chat_completion_create_retrying, openai_completion_create_retrying, + openai_rag_completion_create_retrying ) @@ -46,6 +47,15 @@ def get_completions(self) -> list[str]: return completions +class RetrievalCompletionResult(CompletionResult): + def __init__(self, response: str, prompt: Any) -> None: + self.response = response + self.prompt = prompt + + def get_completions(self) -> list[str]: + return [self.response.strip()] + + class OpenAICompletionFn(CompletionFn): def __init__( self, @@ -81,13 +91,22 @@ def __call__( openai_create_prompt: OpenAICreatePrompt = prompt.to_formatted_prompt() - result = openai_completion_create_retrying( - OpenAI(api_key=self.api_key, base_url=self.api_base), - model=self.model, - prompt=openai_create_prompt, - **{**kwargs, **self.extra_options}, - ) - result = OpenAICompletionResult(raw_data=result, prompt=openai_create_prompt) + if "file_name" not in kwargs: + result = openai_completion_create_retrying( + OpenAI(api_key=self.api_key, base_url=self.api_base), + model=self.model, + prompt=openai_create_prompt, + **{**kwargs, **self.extra_options}, + ) + result = OpenAICompletionResult(raw_data=result, prompt=openai_create_prompt) + else: + answer = openai_rag_completion_create_retrying( + OpenAI(api_key=self.api_key, base_url=self.api_base), + model=self.model, + instructions=kwargs.get("instructions", ""), + file_name=kwargs.get("file_name", ""), + ) + result = RetrievalCompletionResult(answer, prompt=openai_create_prompt) record_sampling(prompt=result.prompt, sampled=result.get_completions()) return result @@ -126,12 +145,23 @@ def __call__( openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt() - result = openai_chat_completion_create_retrying( - OpenAI(api_key=self.api_key, base_url=self.api_base), - model=self.model, - messages=openai_create_prompt, - **{**kwargs, **self.extra_options}, - ) - result = OpenAIChatCompletionResult(raw_data=result, prompt=openai_create_prompt) + if "file_name" not in kwargs: + result = openai_chat_completion_create_retrying( + OpenAI(api_key=self.api_key, base_url=self.api_base), + model=self.model, + messages=openai_create_prompt, + **{**kwargs, **self.extra_options}, + ) + result = OpenAIChatCompletionResult(raw_data=result, prompt=openai_create_prompt) + else: + chatmodel_to_apimodel = lambda x: "gpt-3.5-turbo-1106" if x.startswith("gpt-3.5-turbo") else "gpt-4-1106-preview" if x.startswith("gpt-4") else "" + answer = openai_rag_completion_create_retrying( + OpenAI(api_key=self.api_key, base_url=self.api_base), + model=chatmodel_to_apimodel(self.model), + instructions=kwargs.get("instructions", ""), + file_name=kwargs.get("file_name", ""), + prompt=CompletionPrompt(raw_prompt=openai_create_prompt).to_formatted_prompt() + ) + result = RetrievalCompletionResult(answer, prompt=openai_create_prompt) record_sampling(prompt=result.prompt, sampled=result.get_completions()) return result diff --git a/evals/completion_fns/retrieval_native.py b/evals/completion_fns/retrieval_native.py new file mode 100644 index 0000000000..f06e2da423 --- /dev/null +++ b/evals/completion_fns/retrieval_native.py @@ -0,0 +1,60 @@ +""" +Extending Completion Functions with Embeddings-based retrieval from a fetched dataset +""" +import os +from ast import literal_eval +import time +from typing import Any, Optional, Union + +import numpy as np +from openai import OpenAI + +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + +from evals.api import CompletionFn, CompletionResult +from evals.completion_fns.openai import RetrievalCompletionResult +from evals.prompt.base import ChatCompletionPrompt, CompletionPrompt +from evals.record import record_sampling +from evals.utils.api_utils import openai_rag_completion_create_retrying + + +class OpenAIRetrievalCompletionFn(CompletionFn): + """ + This Completion Function uses embeddings to retrieve the top k relevant docs from a dataset to the prompt, then adds them to the context before calling the completion. + """ + + def __init__( + self, + model: Optional[str] = None, + instructions: Optional[str] = "You are a helpful assistant on extracting information from files.", + api_base: Optional[str] = None, + api_key: Optional[str] = None, + n_ctx: Optional[int] = None, + extra_options: Optional[dict] = {}, + **kwargs + ): + self.model = model + self.instructions = instructions + self.api_base = api_base + self.api_key = api_key + self.n_ctx = n_ctx + self.extra_options = extra_options + + def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> RetrievalCompletionResult: + """ + Args: + prompt: The prompt to complete, in either text string or Chat format. + kwargs: Additional arguments to pass to the completion function call method. + """ + + assert "file_name" in kwargs, "Must provide a file_name to retrieve." + + answer = openai_rag_completion_create_retrying( + client, + model=self.model, + instructions=self.instructions, + file_name=kwargs.get("file_name", ""), + prompt=CompletionPrompt(raw_prompt=prompt).to_formatted_prompt(), + ) + record_sampling(prompt=prompt, sampled=answer) + return RetrievalCompletionResult(answer, prompt=prompt) diff --git a/evals/completion_fns/uni_finder.py b/evals/completion_fns/uni_finder.py new file mode 100644 index 0000000000..5a6fa04c9e --- /dev/null +++ b/evals/completion_fns/uni_finder.py @@ -0,0 +1,103 @@ +""" +Extending Completion Functions with Embeddings-based retrieval from a fetched dataset +""" +import json +import os +import time +from pathlib import Path + +import requests +from typing import Any, Optional, Union + +from evals.prompt.base import CompletionPrompt +from evals.api import CompletionFn, CompletionResult +from evals.record import record_sampling + + +class UniFinderCompletionResult(CompletionResult): + def __init__(self, response: str) -> None: + self.response = response + + def get_completions(self) -> list[str]: + return [self.response.strip()] if self.response else ["Unknown"] + + +class UniFinderCompletionFn(CompletionFn): + """ + This Completion Function uses embeddings to retrieve the top k relevant docs from a dataset to the prompt, then adds them to the context before calling the completion. + """ + + def __init__( + self, + model: Optional[str] = None, + instructions: Optional[str] = "You are a helpful assistant on extracting information from files.", + api_base: Optional[str] = None, + api_key: Optional[str] = None, + n_ctx: Optional[int] = None, + cache_dir: Optional[str] = str(Path.home() / ".uni_finder/knowledge_base.json"), + pdf_parse_mode: Optional[str] = 'fast', # or 'precise', 指定使用的pdf解析版本 + extra_options: Optional[dict] = {}, + **kwargs + ): + self.model = model + self.instructions = instructions + self.api_base = api_base or os.environ.get("UNIFINDER_API_BASE") + self.api_key = api_key or os.environ.get("UNIFINDER_API_KEY") + self.n_ctx = n_ctx + self.extra_options = extra_options + self.cache_dir = cache_dir + self.pdf_parse_mode = pdf_parse_mode + Path(self.cache_dir).parent.mkdir(parents=True, exist_ok=True) + if not Path(self.cache_dir).exists(): + json.dump({}, open(self.cache_dir, "w")) + + def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCompletionResult: + """ + Args: + prompt: The prompt to complete, in either text string or Chat format. + kwargs: Additional arguments to pass to the completion function call method. + """ + + pdf_token = [] + if "file_name" in kwargs: + cache = json.load(open(self.cache_dir, 'r+')) + + if cache.get(kwargs["file_name"], {}).get(self.pdf_parse_mode, None) is None: + url = f"{self.api_base}/api/external/upload_pdf" + files = {'file': open(kwargs["file_name"], 'rb')} + data = { + 'pdf_parse_mode': self.pdf_parse_mode, + 'api_key': self.api_key + } + response = requests.post(url, data=data, files=files) + pdf_id = response.json()['pdf_token'] # 获得pdf的id,表示上传成功,后续可以使用这个id来指定pdf + + if kwargs["file_name"] not in cache: + cache[kwargs["file_name"]] = {self.pdf_parse_mode: pdf_id} + else: + cache[kwargs["file_name"]][self.pdf_parse_mode] = pdf_id + json.dump(cache, open(self.cache_dir, "w")) + else: + pdf_id = cache[kwargs["file_name"]][self.pdf_parse_mode] + print("############# pdf_id ##############", pdf_id) + pdf_token.append(pdf_id) + + url = f"{self.api_base}/api/external/chatpdf" + + if type(prompt) == list: + prompt = CompletionPrompt(prompt).to_formatted_prompt() + + payload = { + "model_engine": self.model, + "pdf_token": pdf_token, + "query": prompt, + 'api_key': self.api_key + } + response = requests.post(url, json=payload, timeout=1200) + try: + answer = response.json()['answer'] + except: + print(response.text) + answer = response.text + record_sampling(prompt=prompt, sampled=answer) + return UniFinderCompletionResult(answer) diff --git a/evals/completion_fns/zhishu.py b/evals/completion_fns/zhishu.py new file mode 100644 index 0000000000..5677c67f54 --- /dev/null +++ b/evals/completion_fns/zhishu.py @@ -0,0 +1,109 @@ +from typing import Any, Optional, Union +import os +import requests + +from evals.api import CompletionFn, CompletionResult +from evals.prompt.base import ( + OpenAICreateChatPrompt, + OpenAICreatePrompt, + Prompt, +) +from evals.record import record_sampling +from evals.utils.api_utils import ( + request_with_timeout +) + +default_prompts = { + "activity": "请汇总文献中全部抑制剂(分子请分别用名字和SMILES表达)的结合活性、活性种类(IC50, EC50, TC50, Ki, Kd中的一个),并备注每类结合活性的实验手段。以json格式输出,活性和活性类型的字段名分别为 \"Affinity\" 和 \"Affinity_type\"", +} + + +class Struct: + def __init__(self, **entries): + self.__dict__.update({k: self._wrap(v) for k, v in entries.items()}) + + def _wrap(self, value): + if isinstance(value, (tuple, list, set, frozenset)): + return type(value)([self._wrap(v) for v in value]) + else: + return Struct(**value) if isinstance(value, dict) else value + + def __repr__(self): + return '<%s>' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items())) + + +class BaseCompletionResult(CompletionResult): + def __init__(self, raw_data: Any, prompt: Any): + self.raw_data = Struct(**raw_data) if type(raw_data) == dict else raw_data + self.prompt = prompt + + def get_completions(self) -> list[str]: + raise NotImplementedError + + +class ZhishuCompletionResult(BaseCompletionResult): + def get_completions(self) -> list[str]: + completions = [] + if self.raw_data: + for choice in self.raw_data.choices: + if choice.message.content is not None: + completions.append(choice.message.content) + return completions + + +class ZhishuCompletionFn(CompletionFn): + def __init__( + self, + model: Optional[str] = None, + instructions: Optional[str] = "You are a helpful assistant on extracting information from files.", + api_base: Optional[str] = None, + api_key: Optional[str] = None, + n_ctx: Optional[int] = None, + all_tools: Optional[bool] = False, + extra_options: Optional[dict] = {}, + **kwargs, + ): + self.model = model + self.instructions = instructions + self.api_base = api_base + self.api_key = api_key + self.n_ctx = n_ctx + self.all_tools = all_tools + self.extra_options = extra_options + + def __call__( + self, + prompt: Union[str, OpenAICreateChatPrompt], + **kwargs, + ) -> ZhishuCompletionResult: + if not isinstance(prompt, Prompt): + assert ( + isinstance(prompt, str) + or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt)) + or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt)) + or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt)) + ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]" + + url = f"https://api.zhishuyun.com/openai/gpt-4-all?token={self.api_key or os.environ['ZHISHU_API_KEY']}" + headers = { + "content-type": "application/json" + } + + basic_message = [{"role": "system", "content": self.instructions}] if self.all_tools else [] + + messages = basic_message + [ + {"role": "user", "content": f"{kwargs['file_link']} {prompt}"} + ] if "file_link" in kwargs else prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}] + + payload = { + "model": self.model, + "messages": messages + } + + # result = request_with_timeout(requests.post, url, json=payload, headers=headers) + result = requests.post(url, json=payload, headers=headers) + + result = ZhishuCompletionResult(raw_data=result.json(), prompt=prompt) + print(result.get_completions()[0].replace("\\n", "\n")) + record_sampling(prompt=result.prompt, sampled=result.get_completions()) + return result diff --git a/evals/elsuite/modelgraded/rag_classify.py b/evals/elsuite/modelgraded/rag_classify.py new file mode 100644 index 0000000000..aa471ebd39 --- /dev/null +++ b/evals/elsuite/modelgraded/rag_classify.py @@ -0,0 +1,131 @@ +""" +Generic eval that uses a prompt + classification. +""" +from collections import Counter +from random import Random +from typing import Any, Optional, Union + +import evals +import evals.record +from evals.elsuite.modelgraded.classify_utils import classify, sample_and_concat_n_completions +from evals.elsuite.rag_match import get_rag_dataset +from evals.elsuite.utils import PromptFn, scrub_formatting_from_prompt + + +class RAGModelBasedClassify(evals.Eval): + def __init__( + self, + modelgraded_spec: str, + *args, + modelgraded_spec_args: Optional[dict[str, dict[str, str]]] = None, + sample_kwargs: Optional[dict[str, Any]] = None, + eval_kwargs: Optional[dict[str, Any]] = None, + multicomp_n: Union[int, str] = 1, + eval_type: Optional[str] = None, + match_fn: Optional[str] = None, + metaeval: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + # treat last completion_fn as eval_completion_fn + self.eval_completion_fn = self.completion_fns[-1] + if len(self.completion_fns) > 1: + self.completion_fns = self.completion_fns[:-1] + n_models = len(self.completion_fns) + self.sample_kwargs = {"max_tokens": 1024} + self.sample_kwargs.update(sample_kwargs or {}) + self.eval_kwargs = {"max_tokens": 1024} + self.eval_kwargs.update(eval_kwargs or {}) + self.metaeval = metaeval + self.modelgraded_spec_args = modelgraded_spec_args or {} + self.eval_type = eval_type + self.match_fn = match_fn + if multicomp_n == "from_models": + assert n_models > 1 + self.multicomp_n = n_models + else: + assert isinstance(multicomp_n, int) + self.multicomp_n = multicomp_n + if len(self.completion_fns) > 1: + assert self.multicomp_n == n_models + + self.mg = self.registry.get_modelgraded_spec(modelgraded_spec) + + def eval_sample(self, test_sample: dict, rng: Random) -> None: + """Evaluate a single sample. + + Recorded metrics are always: one of the self.choice_strings, or "__invalid__". + """ + # process test_sample + sample_file_dict = {key: value for key, value in test_sample.items() if key.startswith("file")} + test_sample = {key: value for key, value in test_sample.items() if not key.startswith("file")} + for k in self.mg.input_outputs: + test_sample[k] = scrub_formatting_from_prompt(test_sample[k]) + + # run policy completions + completions = {} + for k, v in self.mg.input_outputs.items(): + if v in test_sample: # test_sample already has completion, skip. + continue + + if self.multicomp_n > 1: + completion = sample_and_concat_n_completions( + self.completion_fns, + prompt=test_sample[k], + template_i=self.mg.output_template, + sample_kwargs={**self.sample_kwargs, "completion_kwargs": sample_file_dict}, + n=self.multicomp_n, + ) + else: + get_input_completion = PromptFn( + test_sample[k], completion_fn=self.completion_fn, **{**self.sample_kwargs, "completion_kwargs": sample_file_dict} + ) + completion, _ = get_input_completion() + completions[v] = completion + + # run modelgraded eval + metrics = {} + choice, info = classify( + mg=self.mg, + completion_fn=self.eval_completion_fn, + completion_kwargs=self.eval_kwargs, + eval_type=self.eval_type, + n=self.multicomp_n, + match_fn=self.match_fn, + format_kwargs={**completions, **test_sample, **self.modelgraded_spec_args}, + ) + metrics.update(dict(choice=choice, score=info["score"])) + + # run metaeval if requested + if self.metaeval: + assert "choice" in test_sample + metrics["metascore"] = choice == test_sample["choice"] + + evals.record.record_metrics(**metrics) + + return choice + + def run(self, recorder): + samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix()) + + self.eval_all_samples(recorder, samples) + record_metrics = {} + + all_sample_metrics = recorder.get_metrics() + if not all_sample_metrics: + return record_metrics + + # record the counts + choices = [m["choice"] for m in all_sample_metrics] + counts = dict(Counter(choices)) + record_metrics.update({f"counts/{k}": v for k, v in counts.items()}) + + # record the scores + scores = [m["score"] for m in all_sample_metrics if m["score"] is not None] + if scores: + record_metrics["score"] = sum(scores) / len(scores) + metascores = [m["metascore"] for m in all_sample_metrics if "metascore" in m] + if metascores: + record_metrics["metascore"] = sum(metascores) / len(metascores) + + return record_metrics diff --git a/evals/elsuite/rag_match.py b/evals/elsuite/rag_match.py new file mode 100644 index 0000000000..e541e520e8 --- /dev/null +++ b/evals/elsuite/rag_match.py @@ -0,0 +1,120 @@ +import os +from pathlib import Path +from typing import Any + +import oss2 +from oss2.credentials import EnvironmentVariableCredentialsProvider + +import evals +import evals.metrics +from evals.api import CompletionFn +from evals.prompt.base import is_chat_prompt + + +def init_oss(): + """ + Initialize OSS client. + """ + # Please set OSS_ACCESS_KEY_ID & OSS_ACCESS_KEY_SECRET in your environment variables. + auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider()) + + # 设置 Endpoint + endpoint = 'https://oss-cn-beijing.aliyuncs.com' + + # 设置 Bucket + bucket_name = 'dp-filetrans-bj' + bucket = oss2.Bucket(auth, endpoint, bucket_name) + + return bucket + + +def get_rag_dataset(samples_jsonl: str) -> list[dict]: + bucket = init_oss() + raw_samples = evals.get_jsonl(samples_jsonl) + + for raw_sample in raw_samples: + for ftype in ["", "answer"]: + if f"{ftype}file_name" not in raw_sample and f"{ftype}file_link" not in raw_sample: + continue + if f"{ftype}file_name" in raw_sample: + oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_name"]) + raw_sample[f"{ftype}file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file + + exists = bucket.object_exists(oss_file) + if exists: + print(f"文件 {oss_file} 已存在于 OSS 中。") + else: + # 上传文件 + bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"]) + print(f"文件 {oss_file} 已上传到 OSS。") + if f"{ftype}file_link" in raw_sample: + local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else \ + os.path.basename(raw_sample[f"{ftype}file_link"]) + oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"]) + if not os.path.exists(local_file): + if bucket.object_exists(oss_file): + # 从 OSS 下载文件 + Path(local_file).parent.mkdir(parents=True, exist_ok=True) + bucket.get_object_to_file(oss_file, local_file) + print(f"文件 {oss_file} 已下载到本地。") + return raw_samples + + +class RAGMatch(evals.Eval): + def __init__( + self, + completion_fns: list[CompletionFn], + samples_jsonl: str, + *args, + max_tokens: int = 500, + num_few_shot: int = 0, + few_shot_jsonl: str = None, + **kwargs, + ): + super().__init__(completion_fns, *args, **kwargs) + assert len(completion_fns) == 1, "Match only supports one completion fn" + self.max_tokens = max_tokens + self.samples_jsonl = samples_jsonl + self.num_few_shot = num_few_shot + if self.num_few_shot > 0: + assert few_shot_jsonl is not None, "few shot requires few shot sample dataset" + self.few_shot_jsonl = few_shot_jsonl + self.few_shot = evals.get_jsonl(self._prefix_registry_path(self.few_shot_jsonl)) + + def eval_sample(self, sample: Any, *_): + assert isinstance(sample, dict), "sample must be a dict" + assert "input" in sample, "sample must have an 'input' key" + assert "ideal" in sample, "sample must have an 'ideal' key" + assert isinstance(sample["ideal"], str) or isinstance( + sample["ideal"], list + ), "sample['ideal'] must be a string or list of strings" + + prompt = sample["input"] + if self.num_few_shot > 0: + assert is_chat_prompt(sample["input"]), "few shot requires chat prompt" + prompt = sample["input"][:-1] + for s in self.few_shot[: self.num_few_shot]: + prompt += s["sample"] + prompt += sample["input"][-1:] + + result = self.completion_fn( + prompt=prompt, + temperature=0.0, + **{k: v for k, v in sample.items() if k not in ["input", "ideal"]} + ) + sampled = result.get_completions()[0] + + return evals.record_and_check_match( + prompt=prompt, + sampled=sampled, + expected=sample["ideal"], + ) + + def run(self, recorder): + samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix()) + self.eval_all_samples(recorder, samples) + events = recorder.get_events("match") + return { + "accuracy": evals.metrics.get_accuracy(events), + "boostrap_std": evals.metrics.get_bootstrap_accuracy_std(events), + } diff --git a/evals/elsuite/rag_table_extract.py b/evals/elsuite/rag_table_extract.py new file mode 100644 index 0000000000..281977ef61 --- /dev/null +++ b/evals/elsuite/rag_table_extract.py @@ -0,0 +1,309 @@ +import os +import traceback +from io import StringIO +import json +import re +from pathlib import Path + +from typing import List, Optional, Tuple, Union + +import pandas as pd +from pydantic import BaseModel +import uuid + +import evals +import evals.metrics +from evals.api import CompletionFn +from evals.elsuite.rag_match import get_rag_dataset +from evals.record import RecorderBase, record_match + + +code_pattern = r"```[\s\S]*?\n([\s\S]+?)\n```" +json_pattern = r"```json[\s\S]*?\n([\s\S]+?)\n```" +csv_pattern = r"```csv[\s\S]*?\n([\s\S]+?)\n```" +outlink_pattern = r"\[Download[a-zA-Z0-9 ]+?\]\((https://[a-zA-Z0-9_. /]+?)\)" + + +def parse_csv_text(csvtext: str) -> str: + lines = csvtext.strip().split("\n") + tuple_pattern = r"\((\"[\s\S]*?\"),(\"[\s\S]*?\")\)" + if re.search(tuple_pattern, lines[0]) is not None: + lines[0] = re.sub(tuple_pattern, r"(\1|\2)", lines[0]) + lines_clr = [re.sub(r"\"[\s\S]*?\"", "", line) for line in lines] + max_commas = max([line_clr.count(",") for line_clr in lines_clr]) + unified_lines = [line + ("," * (max_commas - line_clr.count(","))) for line, line_clr in zip(lines, lines_clr)] + return "\n".join(unified_lines) + + +def parse_table_multiindex(table: pd.DataFrame) -> pd.DataFrame: + """ + Parse a table with multiindex columns. + """ + + df = table.copy() + if df.columns.nlevels == 1: + coltypes = {col: type(df[col].iloc[0]) for col in df.columns} + for col, ctype in coltypes.items(): + if ctype == str: + if ":" in df[col].iloc[0] and "," in df[col].iloc[0]: + df[col] = [{key: value for key, value in [pair.split(": ") for pair in data.split(", ")]} for data + in df[col]] + coltypes[col] = dict + dfs = [] + + for col, ctype in coltypes.items(): + if ctype == dict: + d = pd.DataFrame(df.pop(col).tolist()) + d.columns = pd.MultiIndex.from_tuples([(col, fuzzy_normalize(key)) for key in d.columns]) + dfs.append(d) + df.columns = pd.MultiIndex.from_tuples([eval(col.replace("|", ",")) if (col[0] == "(" and col[-1] == ")") else + (col, "") for col in df.columns]) + df = pd.concat([df] + dfs, axis=1) + if df.columns.nlevels > 1: + df.columns = pd.MultiIndex.from_tuples([(col, fuzzy_normalize(subcol)) for col, subcol in df.columns]) + + return df + + +class FileSample(BaseModel): + file_name: Optional[str] + file_link: Optional[str] + answerfile_name: Optional[str] + answerfile_link: Optional[str] + compare_fields: List[Union[str, Tuple]] + index: Union[str, Tuple] = ("Compound", "") + + +def fuzzy_compare(a: str, b: str) -> Union[bool, float]: + """ + Compare two strings with fuzzy matching. + """ + + def standardize_unit(s: str) -> str: + """ + Standardize a (affinity) string to common units. + """ + mark = "" if re.search(r"[><=]", s) is None else re.search(r"[><=]", s).group() + unit = s.rstrip()[-2:] + number = float(re.search(r"[\+\-]*[0-9.]+", s).group()) + + if unit in ["µM", "uM"]: + unit = "nM" + number *= 1000 + elif unit in ["mM", "mm"]: + unit = "nM" + number *= 1000000 + + if mark == "=": + mark = "" + return f"{mark}{number:.1f} {unit}" + + unit_str = ["nM", "uM", "µM", "mM", "M", "%", " %"] + nan_str = ["n/a", "nan", "na", "n.a.", "nd", "not determined", "not tested", "inactive"] + a = a.strip() + b = b.strip() + if (a[-2:] in unit_str or a[-1] in unit_str) and (b[-2:] in unit_str or b[-1] in unit_str): + a = standardize_unit(a) + b = standardize_unit(b) + return a == b + elif a.lower() in nan_str and b.lower() in nan_str: + return True + elif (a.lower() in b.lower()) or (b.lower() in a.lower()): + return True + else: + import Levenshtein + return Levenshtein.distance(a.lower(), b.lower()) / (len(a) + len(b)) < 0.1 + + +def fuzzy_normalize(s): + if s.startswith("Unnamed"): + return "" + else: + """ 标准化字符串 """ + # 定义需要移除的单位和符号 + units = ["µM", "µg/mL", "nM"] + for unit in units: + s = s.replace(unit, "") + + # 定义特定关键字 + keywords = ["pIC50", "IC50", "EC50", "TC50", "GI50", "Ki", "Kd", "Kb", "pKb"] + + # 移除非字母数字的字符,除了空格 + # s = re.sub(r'[^\w\s]', '', s) + + # 分割字符串为单词列表 + words = s.split() + + # 将关键字移到末尾 + reordered_words = [word for word in words if word not in keywords] + keywords_in_string = [word for word in words if word in keywords] + reordered_words.extend(keywords_in_string) + # 重新组合为字符串 + return ' '.join(reordered_words) + + +class TableExtract(evals.Eval): + def __init__( + self, + completion_fns: list[CompletionFn], + samples_jsonl: str, + *args, + instructions: Optional[str] = "", + **kwargs, + ): + super().__init__(completion_fns, *args, **kwargs) + assert len(completion_fns) < 3, "TableExtract only supports 3 completion fns" + self.samples_jsonl = samples_jsonl + self.instructions = instructions + + def eval_sample(self, sample, rng): + assert isinstance(sample, FileSample) + + prompt = \ + self.instructions + # + f"\nThe fields should at least contain {sample.compare_fields}" + result = self.completion_fn( + prompt=prompt, + temperature=0.0, + max_tokens=5, + file_name=sample.file_name, + file_link=sample.file_link + ) + sampled = result.get_completions()[0] + + compare_fields_types = [type(x) for x in sample.compare_fields] + header_rows = [0, 1] if tuple in compare_fields_types else [0] + + correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=header_rows).astype(str)) + correct_answer.to_csv("temp.csv", index=False) + correct_str = open("temp.csv", 'r').read() + + try: + if re.search(outlink_pattern, sampled) is not None: + code = re.search(outlink_pattern, sampled).group() + link = re.sub(outlink_pattern, r"\1", code) + + fname = f"/tmp/LLMEvals_{uuid.uuid4()}.csv" + os.system(f"wget {link} -O {fname}") + table = pd.read_csv(fname) + if pd.isna(table.iloc[0, 0]): + table = pd.read_csv(fname, header=header_rows) + elif "csv" in prompt: + code = re.search(csv_pattern, sampled).group() + code_content = re.sub(csv_pattern, r"\1", code) + code_content_processed = parse_csv_text(code_content) + # table = pd.read_csv(StringIO(code_content_processed), header=header_rows) + table = pd.read_csv(StringIO(code_content_processed)) + if pd.isna(table.iloc[0, 0]): + table = pd.read_csv(StringIO(code_content_processed), header=header_rows) + + elif "json" in prompt: + code = re.search(json_pattern, sampled).group() + code_content = re.sub(json_pattern, r"\1", code).replace("\"", "") + table = pd.DataFrame(json.loads(code_content)) + else: + table = pd.DataFrame() + table = parse_table_multiindex(table) + + if sample.index not in table.columns: + table.columns = [sample.index] + list(table.columns)[1:] + answerfile_out = sample.answerfile_name.replace(".csv", "_output.csv") + table.to_csv(answerfile_out, index=False) + picked_str = open(answerfile_out, 'r').read() + except: + print(Path(sample.file_name).stem) + code = re.search(code_pattern, sampled).group() + code_content = re.sub(code_pattern, r"\1", code) + code_content_processed = parse_csv_text(code_content) + print(code_content) + print(code_content_processed) + traceback.print_exc() + record_match( + correct=False, + expected=correct_str, + picked=sampled, + file_name=sample.file_name, + jobtype="match_all" + ) + return + + # TODO: Use similarity and Bipartite matching to match fields + renames = {} + for field in sample.compare_fields: + for i, sample_field in enumerate(table.columns): + field_query = field if type(field) != tuple else field[0] if field[1] == "" else field[1] + sample_field_query = sample_field if type(sample_field) != tuple else sample_field[0] if sample_field[1] == "" else sample_field[1] + if fuzzy_normalize(field_query) == "" or fuzzy_normalize(sample_field_query) == "": + continue + if fuzzy_compare(fuzzy_normalize(field_query), fuzzy_normalize(sample_field_query)) and \ + fuzzy_normalize(field_query).split()[-1] == fuzzy_normalize(sample_field_query).split()[-1]: + if sample_field not in renames.keys() and field_query not in renames.values(): + renames[sample_field_query] = field_query + break + renames = {key: value for key, value in renames.items() if key not in ["Compound", "Name", "SMILES"]} + if len(renames) > 0: + print("Find similar fields between answer and correct:", renames) + table.rename(columns=renames, inplace=True) + print(table) + + table[sample.index] = table[sample.index].astype(str) + correct_answer[sample.index] = correct_answer[sample.index].astype(str) + comparison_df = pd.merge(table.set_index(sample.index, drop=False), + correct_answer.set_index(sample.index, drop=False), + how="right", left_index=True, right_index=True) + + match_all = True + for field in sample.compare_fields: + if type(field) == tuple and len(field) > 1: + field_sample, field_correct = (f"{field[0]}_x", field[1]), (f"{field[0]}_y", field[1]) + else: + field_sample, field_correct = f"{field}_x", f"{field}_y" + match_field = field in table.columns and field in correct_answer.columns + match_all = match_all and match_field + record_match( + correct=match_field, + expected=field, + picked=str(list(table.columns)), + file_name=sample.file_name, + jobtype="match_field" + ) + if match_field: + match_number = table[field].shape[0] == correct_answer[field].shape[0] + match_all = match_all and match_number + record_match( + correct=match_number, + expected=correct_answer[field].shape[0], + picked=table[field].shape[0], + file_name=sample.file_name, + jobtype="match_number" + ) + + for sample_value, correct_value in zip(comparison_df[field_sample], comparison_df[field_correct]): + match_value = fuzzy_compare(str(sample_value), str(correct_value)) + match_all = match_all and match_value + record_match( + correct=match_value, + expected=correct_value, + picked=sample_value, + file_name=sample.file_name, + jobtype=field if type(field) == str else field[0] + ) + record_match( + correct=match_all, + expected=correct_str, + picked=picked_str, + file_name=sample.file_name, + jobtype="match_all" + ) + + def run(self, recorder: RecorderBase): + raw_samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix()) + for raw_sample in raw_samples: + raw_sample["compare_fields"] = [field if type(field) == str else tuple(field) for field in + raw_sample["compare_fields"]] + + samples = [FileSample(**raw_sample) for raw_sample in raw_samples] + self.eval_all_samples(recorder, samples) + return { + "accuracy": evals.metrics.get_accuracy(recorder.get_events("match")), + } diff --git a/evals/registry/completion_fns/retrieve.yaml b/evals/registry/completion_fns/retrieve.yaml new file mode 100644 index 0000000000..648d58a5f9 --- /dev/null +++ b/evals/registry/completion_fns/retrieve.yaml @@ -0,0 +1,23 @@ +retrieval/presidents/gpt-3.5-turbo: + class: evals.completion_fns.retrieval:RetrievalCompletionFn + args: + completion_fn: gpt-3.5-turbo + embeddings_and_text_path: presidents_embeddings.csv + k: 2 + +retrieval/presidents/cot/gpt-3.5-turbo: + class: evals.completion_fns.retrieval:RetrievalCompletionFn + args: + completion_fn: cot/gpt-3.5-turbo + embeddings_and_text_path: presidents_embeddings.csv + k: 2 + +retrieval_native/gpt-3.5-turbo: + class: evals.completion_fns.retrieval_native:OpenAIRetrievalCompletionFn + args: + model: gpt-3.5-turbo-1106 + +retrieval_native/gpt-4-all: + class: evals.completion_fns.retrieval_native:OpenAIRetrievalCompletionFn + args: + model: gpt-4-1106-preview \ No newline at end of file diff --git a/evals/registry/completion_fns/uni_finder.yaml b/evals/registry/completion_fns/uni_finder.yaml new file mode 100644 index 0000000000..ae2c7b778e --- /dev/null +++ b/evals/registry/completion_fns/uni_finder.yaml @@ -0,0 +1,23 @@ +uni_finder/fast/gpt-3.5-turbo: + class: evals.completion_fns.uni_finder:UniFinderCompletionFn + args: + pdf_parse_mode: fast + model: gpt35 + +uni_finder/precise/gpt-3.5-turbo: + class: evals.completion_fns.uni_finder:UniFinderCompletionFn + args: + pdf_parse_mode: precise + model: gpt35 + +uni_finder/fast/gpt-4-all: + class: evals.completion_fns.uni_finder:UniFinderCompletionFn + args: + pdf_parse_mode: fast + model: gpt4 + +uni_finder/precise/gpt-4-all: + class: evals.completion_fns.uni_finder:UniFinderCompletionFn + args: + pdf_parse_mode: precise + model: gpt4 \ No newline at end of file diff --git a/evals/registry/completion_fns/zhishu.yaml b/evals/registry/completion_fns/zhishu.yaml new file mode 100644 index 0000000000..13c281f84a --- /dev/null +++ b/evals/registry/completion_fns/zhishu.yaml @@ -0,0 +1,11 @@ +zhishu/gpt-4: + class: evals.completion_fns.zhishu:ZhishuCompletionFn + args: + model: gpt-4-all + all_tools: False + +zhishu/gpt-4-all: + class: evals.completion_fns.zhishu:ZhishuCompletionFn + args: + model: gpt-4-all + all_tools: True diff --git a/evals/registry/data/00_scipaper_affinity/samples.jsonl b/evals/registry/data/00_scipaper_affinity/samples.jsonl new file mode 100644 index 0000000000..ed8c716f73 --- /dev/null +++ b/evals/registry/data/00_scipaper_affinity/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e7193711a5f342aa26e16fe275e108eefc8fdc1ef6f4f6545dcb8f901132f2d +size 4916 diff --git a/evals/registry/data/01_scipaper_hasmol/samples.jsonl b/evals/registry/data/01_scipaper_hasmol/samples.jsonl new file mode 100644 index 0000000000..2962fd16b8 --- /dev/null +++ b/evals/registry/data/01_scipaper_hasmol/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac93b432f74bf7347b09b91a1c501738b65d0f2cfff6fdfdb6acc9786430ac86 +size 2151 diff --git a/evals/registry/data/01_scipaper_tag2mol/samples.jsonl b/evals/registry/data/01_scipaper_tag2mol/samples.jsonl new file mode 100644 index 0000000000..805bb85da8 --- /dev/null +++ b/evals/registry/data/01_scipaper_tag2mol/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b49af9ced518fdc5477f229494a29e9b927349dddad630217c44eaa813af4abc +size 2131 diff --git a/evals/registry/data/02_markush2mol/samples.jsonl b/evals/registry/data/02_markush2mol/samples.jsonl new file mode 100644 index 0000000000..e6ddf2c625 --- /dev/null +++ b/evals/registry/data/02_markush2mol/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fda4937180b350dacffafdd5dcaa299fca3a1468f269be03ef4b158a50e8e02 +size 502 diff --git a/evals/registry/data/03_scipaper_targets/samples.jsonl b/evals/registry/data/03_scipaper_targets/samples.jsonl new file mode 100644 index 0000000000..d14eb83645 --- /dev/null +++ b/evals/registry/data/03_scipaper_targets/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0efe816302535508a331db6feeb3cbbd36a95ea1c59d16f11e41a5e80a47daa8 +size 3482 diff --git a/evals/registry/eval_sets/chemistry.yaml b/evals/registry/eval_sets/chemistry.yaml new file mode 100644 index 0000000000..ad417da139 --- /dev/null +++ b/evals/registry/eval_sets/chemistry.yaml @@ -0,0 +1,7 @@ +chemistry: + evals: + - abstract2title + - research-question-extraction + - balance-chemical-equation + - mmlu-college-chemistry + - mmlu-high-school-chemistry \ No newline at end of file diff --git a/evals/registry/eval_sets/chemistry_drug.yaml b/evals/registry/eval_sets/chemistry_drug.yaml new file mode 100644 index 0000000000..f152a92231 --- /dev/null +++ b/evals/registry/eval_sets/chemistry_drug.yaml @@ -0,0 +1,8 @@ +chemistry_drug: + evals: + - scipaper_affinity + - scipaper_tag2mol + - scipaper_hasmol + - markush2mol + - scipaper_targets + - medmcqa \ No newline at end of file diff --git a/evals/registry/evals/00_scipaper_affinity.yaml b/evals/registry/evals/00_scipaper_affinity.yaml new file mode 100644 index 0000000000..e02548dcc3 --- /dev/null +++ b/evals/registry/evals/00_scipaper_affinity.yaml @@ -0,0 +1,38 @@ +scipaper_affinity: + id: scipaper_affinity.val.csv + metrics: [accuracy] +scipaper_affinity.val.json: + class: evals.elsuite.rag_table_extract:TableExtract + args: + samples_jsonl: 00_scipaper_affinity/samples.jsonl + instructions: | + Please give a complete list of names, affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all inhibitors in the paper. If there are multiple tables, combine them. Don't give me reference. Output in json format. For example: + ```json + [ + { + "Compound": "5a", + "Name": "Aspirin", + "Affinities": { + "5HT1A (IC50)": "2.0 nM", + "5HT1D (IC50)": "8.0 nM", + "5HT-UT (IC50)": "12.6 nM", + "5HT1E (IC50)": ">1000 nM" + } + } + ] + ``` + +scipaper_affinity.val.csv: + class: evals.elsuite.rag_table_extract:TableExtract + args: + samples_jsonl: 00_scipaper_affinity/samples.jsonl + instructions: | + Please give a complete list of SMILES structures, affinities, target info (protein or cell line), and affinity types (chosen from IC50, EC50, TC50, GI50, Ki, Kd) of all compounds in the paper. Usually the coumpounds' tags are numbers. + 1. Find all the tables with relevant information + 2. Output in csv format with multiindex (Affinities, protein/cell line), write units not in header but in the value like "10.5 µM". Quote the value if it has comma! For example: + ```csv + Compound,Name,SMILES,Affinities,Affinities,Affinities,Affinities + ,,,5HT1A (IC50),5HT1D (IC50),5HT-UT (IC50),5HT1E () + "5a","Aspirin","CC(=O)Oc1ccccc1C(=O)O",2.0 nM,8.0 nM,12.6 nM,>1000 nM + ``` + 3. If there are multiple tables, concat them. Don't give me reference or using "...", give me complete table! \ No newline at end of file diff --git a/evals/registry/evals/01_scipaper_hasmol.yaml b/evals/registry/evals/01_scipaper_hasmol.yaml new file mode 100644 index 0000000000..c176f7dc06 --- /dev/null +++ b/evals/registry/evals/01_scipaper_hasmol.yaml @@ -0,0 +1,8 @@ +scipaper_hasmol: + id: scipaper_hasmol.dev.v0 + metrics: [accuracy] + +scipaper_hasmol.dev.v0: + class: evals.elsuite.rag_match:RAGMatch + args: + samples_jsonl: 01_scipaper_hasmol/samples.jsonl \ No newline at end of file diff --git a/evals/registry/evals/01_scipaper_tag2mol.yaml b/evals/registry/evals/01_scipaper_tag2mol.yaml new file mode 100644 index 0000000000..556c73e1b9 --- /dev/null +++ b/evals/registry/evals/01_scipaper_tag2mol.yaml @@ -0,0 +1,8 @@ +scipaper_tag2mol: + id: scipaper_tag2mol.dev.v0 + metrics: [accuracy] + +scipaper_tag2mol.dev.v0: + class: evals.elsuite.rag_match:RAGMatch + args: + samples_jsonl: 01_scipaper_tag2mol/samples.jsonl \ No newline at end of file diff --git a/evals/registry/evals/02_markush2mol.yaml b/evals/registry/evals/02_markush2mol.yaml new file mode 100644 index 0000000000..d564774e9a --- /dev/null +++ b/evals/registry/evals/02_markush2mol.yaml @@ -0,0 +1,8 @@ +markush2mol: + id: markush2mol.dev.v0 + metrics: [accuracy] + +markush2mol.dev.v0: + class: evals.elsuite.basic.match:Match + args: + samples_jsonl: 02_markush2mol/samples.jsonl \ No newline at end of file diff --git a/evals/registry/evals/03_scipaper_targets.yaml b/evals/registry/evals/03_scipaper_targets.yaml new file mode 100644 index 0000000000..722a1eb7eb --- /dev/null +++ b/evals/registry/evals/03_scipaper_targets.yaml @@ -0,0 +1,12 @@ +scipaper_targets: + id: scipaper_targets.test.v1 + metrics: [accuracy] + description: Test the model's ability to retrieve protein/cell line targets from literature. + +scipaper_targets.test.v1: + class: evals.elsuite.modelgraded.rag_classify:RAGModelBasedClassify + args: + samples_jsonl: 03_scipaper_targets/samples.jsonl + modelgraded_spec: closedqa + modelgraded_spec_args: + criteria: "conciseness: Does the answer has the same biological meaning as the content?" diff --git a/evals/reporters/DPTracking.py b/evals/reporters/DPTracking.py new file mode 100644 index 0000000000..485170711c --- /dev/null +++ b/evals/reporters/DPTracking.py @@ -0,0 +1,104 @@ +import glob +import os +import time +import uuid +from copy import deepcopy +from datetime import datetime +from pathlib import Path +from typing import Dict, Union, List, Any + +import numpy as np +import pandas as pd +import aim +from PIL import Image + + +class DPTrackingReporter: + @staticmethod + def _convert_logger_table(df: pd.DataFrame) -> aim.Table: + aim_df = deepcopy(df) + if aim_df.shape[0] == 0: + return aim.Table(aim_df) + for col in aim_df.columns: + i = 0 + while not aim_df[col].iloc[i]: + i += 1 + if i == aim_df.shape[0]: + i = 0 + break + data0 = aim_df[col].iloc[i] + # if isinstance(data0, Chem.Mol): + # molfiles = [] + # tmpdir = f"aim-tmp-{uuid.uuid4().hex}" + # Path(tmpdir).mkdir(exist_ok=True, parents=True) + # for i, mol in enumerate(aim_df[col]): + # if mol: + # molfile = f"{tmpdir}/{i}.sdf" + # Chem.MolToMolFile(mol, molfile) + # molfiles.append(molfile) + # else: + # molfiles.append(None) + # aim_df[col] = [aim.Molecule(molfile) if molfile else None for molfile in molfiles] + if isinstance(data0, Image.Image): + imgfiles = [] + tmpdir = f"aim-tmp-{uuid.uuid4().hex}" + Path(tmpdir).mkdir(exist_ok=True, parents=True) + for i, img in enumerate(aim_df[col]): + if img: + imgfile = f"{tmpdir}/{i}.png" + img.save(imgfile) + imgfiles.append(imgfile) + else: + imgfiles.append(None) + aim_df[col] = [aim.TableImage(imgfile) if imgfile else None for imgfile in imgfiles] + return aim.Table(aim_df) + + @staticmethod + def _convert_logger_data(v: Any) -> Any: + import matplotlib.pyplot as plt + try: + import plotly.graph_objects as go + except ImportError: + go = plt + if type(v) in [go.Figure, plt.Figure]: + return aim.Figure(v) + if type(v) in [Image.Image] or (type(v) == str and Path(v).exists() and Path(v).suffix in [".png", ".jpg"]): + return aim.Image(v) + if type(v) in [pd.DataFrame]: + return DPTrackingReporter._convert_logger_table(v) + if type(v) in [np.ndarray, list]: + return aim.Distribution(v) + if type(v) == str: + return aim.Text(v) + return v + + @staticmethod + def report_run(config_logger: Dict, config_run: Dict = {}, logger_data: Dict = {}, step: int = -1): + dp_mlops_config = config_logger["dp_mlops"] + + # Experiment Tracking + if "aim_personal_token" in dp_mlops_config.keys(): + os.environ["AIM_ACCESS_TOKEN"] = dp_mlops_config["aim_personal_token"] + run = aim.Run( + experiment=config_logger["project"], + run_hash=config_logger.get("hash", None), + repo=dp_mlops_config["aim_repo"] + ) + run.name = config_logger["name"] + run["config"] = config_run + for tag in set([config_logger["name"]] + dp_mlops_config.get("tags", [])): + if tag and tag.lower() not in [t.lower() for t in run.props.tags]: + print(tag.lower(), run.props.tags) + run.add_tag(tag.lower()) + + logger_data_aim = {key: DPTrackingReporter._convert_logger_data(value) for key, value in logger_data.items()} + + for key, value in logger_data_aim.items(): + print(key, type(value)) + if "/" not in key or "kcal/mol" in key or "10.1021/" in key or "10.1016/" in key: + run.track(value, name=key) + else: + key, context_str = key.split("/") + context_dict = {k: v for k, v in [kv.split(":") for kv in context_str.split(",")]} + run.track(value, name=key, context={**context_dict}) + run.close() diff --git a/evals/reporters/Feishu.py b/evals/reporters/Feishu.py new file mode 100644 index 0000000000..65ccc610aa --- /dev/null +++ b/evals/reporters/Feishu.py @@ -0,0 +1,369 @@ +import os +from pathlib import Path +import json +import datetime + +from typing import Dict, Union, List + +import requests + +# 时间、实验名、项目、成功体系占比、Protocol、imgkey、Tracking链接、工作流链接 +FEISHU_MESSAGE_STRING = \ + ''' +{ + "config": { + "wide_screen_mode": true + }, + "elements": [ + { + "fields": [ + { + "is_short": true, + "text": { + "content": "**🕐 时间:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": true, + "text": { + "content": "**🔢 实验名:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": false, + "text": { + "content": "", + "tag": "lark_md" + } + }, + { + "is_short": true, + "text": { + "content": "**📋 项目:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": true, + "text": { + "content": "**📋 成功体系:**\n%s", + "tag": "lark_md" + } + } + ], + "tag": "div" + }, + { + "fields": [ + { + "is_short": false, + "text": { + "content": "**🕐 Protocol:**\n%s", + "tag": "lark_md" + } + } + ], + "tag": "div" + }, + { + "alt": { + "content": "", + "tag": "plain_text" + }, + "img_key": "%s", + "tag": "img", + "title": { + "content": "Metrics 汇总:", + "tag": "lark_md" + } + }, + { + "actions": [ + { + "tag": "button", + "text": { + "content": "跟进处理", + "tag": "plain_text" + }, + "type": "primary", + "value": { + "key1": "value1" + } + }, + { + "options": [ + { + "text": { + "content": "屏蔽10分钟", + "tag": "plain_text" + }, + "value": "1" + }, + { + "text": { + "content": "屏蔽30分钟", + "tag": "plain_text" + }, + "value": "2" + }, + { + "text": { + "content": "屏蔽1小时", + "tag": "plain_text" + }, + "value": "3" + }, + { + "text": { + "content": "屏蔽24小时", + "tag": "plain_text" + }, + "value": "4" + } + ], + "placeholder": { + "content": "暂时屏蔽实验跟踪", + "tag": "plain_text" + }, + "tag": "select_static", + "value": { + "key": "value" + } + } + ], + "tag": "action" + }, + { + "tag": "hr" + }, + { + "tag": "div", + "text": { + "content": "📝 [Tracking链接](%s) | 🙋 [工作流链接](%s)", + "tag": "lark_md" + } + } + ], + "header": { + "template": "green", + "title": { + "content": "IFD 实验跟踪", + "tag": "plain_text" + } + } +} +''' + +FEISHU_MESSAGE = { + "config": { + "wide_screen_mode": True + }, + "elements": [ + { + "fields": [ + { + "is_short": True, + "text": { + "content": "**🕐 时间:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": True, + "text": { + "content": "**🔢 实验名:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": False, + "text": { + "content": "", + "tag": "lark_md" + } + }, + { + "is_short": True, + "text": { + "content": "**📋 项目:**\n%s", + "tag": "lark_md" + } + }, + { + "is_short": True, + "text": { + "content": "**📋 成功体系:**\n%s", + "tag": "lark_md" + } + } + ], + "tag": "div" + }, + { + "fields": [ + { + "is_short": False, + "text": { + "content": "**🕐 Protocol:**\n%s", + "tag": "lark_md" + } + } + ], + "tag": "div" + }, + { + "alt": { + "content": "", + "tag": "plain_text" + }, + "img_key": "%s", + "tag": "img", + "title": { + "content": "Metrics 汇总:", + "tag": "lark_md" + } + }, + { + "actions": [ + { + "tag": "button", + "text": { + "content": "跟进处理", + "tag": "plain_text" + }, + "type": "primary", + "value": { + "key1": "value1" + } + }, + { + "options": [ + { + "text": { + "content": "屏蔽10分钟", + "tag": "plain_text" + }, + "value": "1" + }, + { + "text": { + "content": "屏蔽30分钟", + "tag": "plain_text" + }, + "value": "2" + }, + { + "text": { + "content": "屏蔽1小时", + "tag": "plain_text" + }, + "value": "3" + }, + { + "text": { + "content": "屏蔽24小时", + "tag": "plain_text" + }, + "value": "4" + } + ], + "placeholder": { + "content": "暂时屏蔽实验跟踪", + "tag": "plain_text" + }, + "tag": "select_static", + "value": { + "key": "value" + } + } + ], + "tag": "action" + }, + { + "tag": "hr" + }, + { + "tag": "div", + "text": { + "content": "📝 [Tracking链接](%s) | 🙋 [工作流链接](%s)", + "tag": "lark_md" + } + } + ], + "header": { + "template": "green", + "title": { + "content": "IFD 实验跟踪", + "tag": "plain_text" + } + } +} + + +class FeishuReporter: + @staticmethod + def _get_tenant_token(app_id: str = "cli_a301e6759d32500c", app_secret: str = "uLiHOmf0QOQRkhwymy8AmfHWykMQaMFk"): + url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" + + payload = json.dumps({ + "app_id": app_id, + "app_secret": app_secret + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + response.raise_for_status() + data = response.json() + assert data['code'] == 0 + return data['tenant_access_token'] + + @staticmethod + def _upload_image(file_path, type='image/png', app_id: str = "cli_a301e6759d32500c", + app_secret: str = "uLiHOmf0QOQRkhwymy8AmfHWykMQaMFk"): + url = "https://open.feishu.cn/open-apis/im/v1/images" + payload = {'image_type': 'message'} + files = [ + ('image', (Path(file_path).stem, open(file_path, 'rb'), type)) + ] + token = FeishuReporter._get_tenant_token(app_id=app_id, app_secret=app_secret) + headers = { + 'Authorization': f'Bearer {token}' + } + response = requests.request("POST", url, headers=headers, data=payload, files=files) + response.raise_for_status() + data = response.json() + assert data['code'] == 0 + return data['data']['image_key'] + + @staticmethod + def report_run(feishu_groups: List, experiment_group: str, project: str, success_ratio: str, + config_protocol: Dict, + imgfile: Union[str, Path], track_url: str, workflow_url: str, + app_id: str = "", app_secret: str = ""): + app_id = os.environ.get("FEISHU_APP_ID", app_id) + app_secret = os.environ.get("FEISHU_APP_SECRET", app_secret) + now = datetime.datetime.now() + img_key = FeishuReporter._upload_image(imgfile, app_id=app_id, app_secret=app_secret) + + message = FEISHU_MESSAGE.copy() + + message["elements"][0]["fields"][0]["text"]["content"] = \ + message["elements"][0]["fields"][0]["text"]["content"] % now.strftime("%Y-%m-%d %H:%M:%S") + message["elements"][0]["fields"][1]["text"]["content"] = \ + message["elements"][0]["fields"][1]["text"]["content"] % experiment_group + message["elements"][0]["fields"][3]["text"]["content"] = \ + message["elements"][0]["fields"][3]["text"]["content"] % project + message["elements"][0]["fields"][4]["text"]["content"] = \ + message["elements"][0]["fields"][4]["text"]["content"] % success_ratio + message["elements"][1]["fields"][0]["text"]["content"] = \ + message["elements"][1]["fields"][0]["text"]["content"] % json.dumps(config_protocol, indent=4) + message["elements"][2]["img_key"] = img_key + message["elements"][5]["text"]["content"] = message["elements"][5]["text"]["content"] % ( + track_url, workflow_url) + + for feishu_group in feishu_groups: + requests.post(feishu_group, + json={"msg_type": "interactive", "card": message}) diff --git a/evals/reporters/WandB.py b/evals/reporters/WandB.py new file mode 100644 index 0000000000..7fa45d4c25 --- /dev/null +++ b/evals/reporters/WandB.py @@ -0,0 +1,43 @@ +from pathlib import Path +from typing import Dict, Union, List +import traceback + +import pandas as pd + +try: + import wandb +except: + print("No wandb found!") + + +class WandBReporter: + @staticmethod + def report_run(config_logger: Dict, metric_data: pd.DataFrame, step: int = -1): + logger_data = {} + + logger_data[f"correlation_ligand_sidechain"] = wandb.Plotly(fig) + + wandb_config = config_logger.get("wandb", {}).copy() + wandb_config["name"] = config_logger["name"] + wandb_config["group"] = config_logger["group"] + wandb_config["id"] = config_logger["id"] + wandb.login(key=wandb_config.pop('key')) + + try: + run = wandb.init(**wandb_config) + except: + traceback.print_exc() + wandb_config["mode"] = "offline" + run = wandb.init(**wandb_config) + sampler_metric_wb = wandb.Table(dataframe=metric_data) + logger_data["sampler_metrics"] = sampler_metric_wb + + if step >= 0: + wandb.log(data=logger_data, step=step) + else: + wandb.log(data=logger_data) + wandb.finish() + + @staticmethod + def report_summary(): + pass diff --git a/evals/reporters/__init__.py b/evals/reporters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/evals/utils/api_utils.py b/evals/utils/api_utils.py index ae6d34ae30..f6592919f9 100644 --- a/evals/utils/api_utils.py +++ b/evals/utils/api_utils.py @@ -4,6 +4,7 @@ import concurrent import logging import os +import time import backoff import openai @@ -70,3 +71,61 @@ def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs): logging.warning(result) raise openai.error.APIError(result["error"]) return result + + +@backoff.on_exception( + wait_gen=backoff.expo, + exception=( + openai.RateLimitError, + openai.APIConnectionError, + openai.APITimeoutError, + openai.InternalServerError, + ), + max_value=60, + factor=1.5, +) +def openai_rag_completion_create_retrying(client: OpenAI, *args, **kwargs): + """ + Helper function for creating a RAG completion. + `args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`. + """ + + file = client.files.create(file=open(kwargs["file_name"], "rb"), purpose='assistants') + + # Create an Assistant (Note model="gpt-3.5-turbo-1106" instead of "gpt-4-1106-preview") + assistant = client.beta.assistants.create( + name="File Assistant", + instructions=kwargs.get("instructions", ""), + model=kwargs.get("model", "gpt-3.5-turbo-1106"), + tools=[{"type": "retrieval"}], + file_ids=[file.id] + ) + + # Create a Thread + thread = client.beta.threads.create() + + # Add a Message to a Thread + message = client.beta.threads.messages.create(thread_id=thread.id, role="user", + content=kwargs.get("prompt", "") + ) + + # Run the Assistant + run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id) + + # If run is 'completed', get messages and print + while True: + # Retrieve the run status + run_status = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + time.sleep(10) + if run_status.status == 'completed': + messages = client.beta.threads.messages.list(thread_id=thread.id) + answer = messages.data[0].content[0].text.value + break + else: + ### sleep again + time.sleep(2) + + # if "error" in result: + # logging.warning(result) + # raise openai.error.APIError(result["error"]) + return answer diff --git a/examples/config_logger.json b/examples/config_logger.json new file mode 100644 index 0000000000..dcfee21077 --- /dev/null +++ b/examples/config_logger.json @@ -0,0 +1,7 @@ +{ + "name": "20231231-unifinder-poc", + "project": "Uni-Finder/Benchmark", + "dp_mlops":{ + "aim_repo": "aim://tracking-api.mlops.dp.tech:443" + } +} diff --git a/pyproject.toml b/pyproject.toml index b6eff11e67..05a5be64fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ formatters = [ [project.scripts] oaieval = "evals.cli.oaieval:main" oaievalset = "evals.cli.oaievalset:main" +llmreport = "evals.cli.llmreport:main" [tool.setuptools] packages = ["evals"]