17 changes: 17 additions & 0 deletions openeqa/baselines/README.md
@@ -3,26 +3,43 @@
The commands required to run several baselines are listed below. Some baselines are labeled (language-only) because the model only receives an EQA question $Q$ and must answer from its prior knowledge of the world. Other baselines are vision-language models (VLMs), which jointly process the question $Q$ and image frames from the episode history $H$.
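
Each baseline script writes its answers to `data/results/` as a JSON list of `{"question_id", "answer"}` records, named `<dataset-stem>-<model>-<seed>.json` (see, e.g., `parse_args` in `llama.py` in this diff). A minimal sketch for inspecting a finished run; the filename below is illustrative:

```python
import json

# each baseline writes a JSON list of {"question_id": ..., "answer": ...} records
with open("data/results/open-eqa-v0-llama-2-7b-1234.json") as f:  # illustrative path
    results = json.load(f)

print("{:,} answers".format(len(results)))
print(results[0]["question_id"], "->", results[0]["answer"])
```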

1. GPT-4 (language-only)

```bash
# requires setting the OPENAI_API_KEY environment variable
python openeqa/baselines/gpt4.py --dry-run # remove --dry-run to process the full benchmark
```

2. LLaMA (language-only)

First, download LLaMA weights in the Hugging Face format from [here](https://huggingface.co/meta-llama). Then, run:

```bash
python openeqa/baselines/llama.py -m <path/to/hf/weights>
```

3. GPT-4V (vision + language)

```bash
# requires setting the OPENAI_API_KEY environment variable
python openeqa/baselines/gpt4v.py --num-frames 50 --dry-run # remove --dry-run to process the full benchmark
```

4. Gemini Pro (language-only)

```bash
# requires setting the GOOGLE_API_KEY environment variable
python openeqa/baselines/gemini-pro.py --dry-run # remove --dry-run to process the full benchmark
```

5. Gemini Pro Vision (vision + language)

```bash
# requires setting the GOOGLE_API_KEY environment variable
python openeqa/baselines/gemini-pro-vision.py --num-frames 15 --dry-run # remove --dry-run to process the full benchmark
```

6. Claude 3 (vision + language)

```bash
# requires setting the ANTHROPIC_API_KEY environment variable
python openeqa/baselines/claude-vision.py --num-frames 20 --dry-run # remove --dry-run to process the full benchmark
```
159 changes: 159 additions & 0 deletions openeqa/baselines/idefics.py
@@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import numpy as np
from pathlib import Path
from typing import Optional, List

import tqdm

from openeqa.utils.idefics_utils import IdeficsRunner, enable_full_determinism, prepare_idefics_vision_messages
from openeqa.utils.prompt_utils import load_prompt


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=Path,
default="data/open-eqa-v0.json",
help="path to EQA dataset (default: data/open-eqa-v0.json)",
)
parser.add_argument(
"-m",
"--model-path",
type=Path,
required=True,
help="path to weights in huggingface format",
)
parser.add_argument(
"--frames-directory",
type=Path,
default="data/frames/",
help="path image frames (default: data/frames/)",
)
parser.add_argument(
"--num-frames",
type=int,
default=10,
help="num frames in gpt4v (default: 50)",
)
parser.add_argument(
"--seed",
type=int,
default=1234,
help="gpt seed (default: 1234)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.2,
help="gpt temperature (default: 0.2)",
)
parser.add_argument(
"--max-tokens",
type=int,
default=128,
help="gpt maximum tokens (default: 128)",
)
parser.add_argument(
"--output-directory",
type=Path,
default="data/results",
help="output directory (default: data/results)",
)
parser.add_argument(
"--force",
action="store_true",
help="continue running on API errors (default: false)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="only process the first 5 questions",
)
args = parser.parse_args()
enable_full_determinism(args.seed)
args.output_directory.mkdir(parents=True, exist_ok=True)
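    # results land at <output-directory>/<dataset-stem>-<model-path>-<seed>.json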
args.output_path = args.output_directory / (
args.dataset.stem + "-{}-{}.json".format(str(args.model_path).replace('/', '-'), args.seed)
)
return args


def parse_output(output: List[str]) -> str:
    # the model echoes the prompt; keep only the text after the final "Assistant:"
    output_split = output[0].split("Assistant:")
    if len(output_split) == 1:
        raise ValueError("Invalid output string: {}".format(output[0]))
    return output_split[-1].strip()


def ask_question(
model, question: str, image_paths: List[str], max_tokens: int = 128, temperature: float = 0.2
) -> Optional[str]:
    prompt = load_prompt("idefics")
    prefix, suffix = prompt.split("User Query:")
    suffix = "User Query:" + suffix.format(question=question)

    messages = prepare_idefics_vision_messages(prefix, suffix, image_paths)
    output = model(messages, image_paths=image_paths, max_new_tokens=max_tokens, temperature=temperature)
return parse_output(output)


def main(args: argparse.Namespace):
# load dataset
dataset = json.load(args.dataset.open("r"))
print("found {:,} questions".format(len(dataset)))

# load model
model = IdeficsRunner(args.model_path)

# load results
results = []
if args.output_path.exists():
results = json.load(args.output_path.open())
print("found {:,} existing results".format(len(results)))
completed = [item["question_id"] for item in results]

# process data
for idx, item in enumerate(tqdm.tqdm(dataset)):
if args.dry_run and idx >= 5:
break

# skip completed questions
question_id = item["question_id"]
if question_id in completed:
continue # skip existing

# extract scene paths
folder = args.frames_directory / item["episode_history"]
frames = sorted(folder.glob("*-rgb.png"))
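        # uniformly subsample --num-frames indices across the episode's RGB frames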
indices = np.round(np.linspace(0, len(frames) - 1, args.num_frames)).astype(int)
paths = [str(frames[i]) for i in indices]

# generate answer
question = item["question"]
answer = ask_question(
model=model,
question=question,
image_paths=paths,
max_tokens=args.max_tokens,
temperature=args.temperature,
)

# store results
results.append({"question_id": question_id, "answer": answer})
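        # write results after every question so an interrupted run can resume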
json.dump(results, args.output_path.open("w"), indent=2)

# save at end (redundant)
json.dump(results, args.output_path.open("w"), indent=2)
print("saving {:,} answers".format(len(results)))


if __name__ == "__main__":
main(parse_args())
154 changes: 154 additions & 0 deletions openeqa/baselines/llama.py
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from pathlib import Path
from typing import Optional

import tqdm

from openeqa.utils.llama_utils import LLaMARunner, enable_full_determinism
from openeqa.utils.prompt_utils import load_prompt


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=Path,
default="data/open-eqa-v0.json",
help="path to EQA dataset (default: data/open-eqa-v0.json)",
)
parser.add_argument(
"-m",
"--model-path",
type=Path,
required=True,
help="path to weights in huggingface format",
)
parser.add_argument(
"--model-name",
type=str,
help="model name (defaults to model path folder name)",
)
parser.add_argument(
"--load-in-8bit",
action="store_true",
help="load model in 8bit mode (default: false)",
)
parser.add_argument(
"--use-fast-kernels",
action="store_true",
help="use fast kernels (default: false)",
)
parser.add_argument(
"--seed",
type=int,
default=1234,
help="gpt seed (default: 1234)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.2,
help="gpt temperature (default: 0.2)",
)
parser.add_argument(
"--max-tokens",
type=int,
default=128,
help="gpt maximum tokens (default: 128)",
)
parser.add_argument(
"--output-directory",
type=Path,
default="data/results",
help="output directory (default: data/results)",
)
parser.add_argument(
"--force",
action="store_true",
help="continue running on API errors (default: false)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="only process the first 5 questions",
)
args = parser.parse_args()
enable_full_determinism(args.seed)
if args.model_name is None:
args.model_name = args.model_path.name.lower()
args.output_directory.mkdir(parents=True, exist_ok=True)
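    # output file: <dataset-stem>-<model-name>-<seed>.json inside --output-directory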
args.output_path = args.output_directory / (
args.dataset.stem + "-{}-{}.json".format(args.model_name, args.seed)
)
return args


def parse_output(output: str) -> str:
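    # keep the text after the first "A:" up to the next newline,
    # e.g. "...A: in the kitchen\n..." -> "in the kitchen"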
start_idx = output.find("A:")
if start_idx == -1:
raise ValueError("Invalid output string: {}".format(output))
end_idx = output.find("\n", start_idx)
if end_idx == -1:
return output[start_idx:].replace("A:", "").strip()
return output[start_idx:end_idx].replace("A:", "").strip()


def ask_question(
model, question: str, max_tokens: int = 128, temperature: float = 0.2
) -> Optional[str]:
prompt = load_prompt("blind-llm")
    input_text = prompt.format(question=question)
    output = model(input_text, max_new_tokens=max_tokens, temperature=temperature)
return parse_output(output)


def main(args: argparse.Namespace):
# load dataset
dataset = json.load(args.dataset.open("r"))
print("found {:,} questions".format(len(dataset)))

# load model
model = LLaMARunner(
args.model_path,
load_in_8bit=args.load_in_8bit,
use_fast_kernels=args.use_fast_kernels,
)

# load results
results = []
if args.output_path.exists():
results = json.load(args.output_path.open())
print("found {:,} existing results".format(len(results)))
completed = [item["question_id"] for item in results]

# process data
for idx, item in enumerate(tqdm.tqdm(dataset)):
if args.dry_run and idx >= 5:
break

# skip completed questions
question_id = item["question_id"]
if question_id in completed:
continue # skip existing

# generate answer
question = item["question"]
        answer = ask_question(
            model=model,
            question=question,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
        )

# store results
results.append({"question_id": question_id, "answer": answer})
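        # checkpoint after each answer; reruns skip ids already in `completed`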
json.dump(results, args.output_path.open("w"), indent=2)

# save at end (redundant)
json.dump(results, args.output_path.open("w"), indent=2)
print("saving {:,} answers".format(len(results)))


if __name__ == "__main__":
main(parse_args())