diff --git a/src/webagents_step/environment/webarena.py b/src/webagents_step/environment/webarena.py index 1363060..5cd1c09 100644 --- a/src/webagents_step/environment/webarena.py +++ b/src/webagents_step/environment/webarena.py @@ -25,6 +25,8 @@ from webagents_step.environment.env import WebEnvironment import json import re +import subprocess +import tempfile # Init an environment from browser_env import ( create_id_based_action, @@ -33,6 +35,7 @@ ActionTypes, ScriptBrowserEnv ) +from browser_env.auto_login import get_site_comb_from_filepath from evaluation_harness.evaluators import evaluator_router class WebArenaEnvironmentWrapper(WebEnvironment): @@ -44,6 +47,39 @@ def __init__(self, config_file, max_browser_rows=300, max_steps=50, slow_mo=1, o current_viewport_only=current_viewport_only, viewport_size=viewport_size ) + + with open(config_file) as f: + _c = json.load(f) + intent = _c["intent"] + task_id = _c["task_id"] + # automatically login + if _c["storage_state"]: + cookie_file_name = os.path.basename(_c["storage_state"]) + comb = get_site_comb_from_filepath(cookie_file_name) + temp_dir = tempfile.mkdtemp() + # subprocess to renew the cookie + + current_dir = os.getcwd() + script_dir = os.path.join(current_dir, "../webarena") + assert os.path.isdir(script_dir), "WebArena directory not found at {}".format(script_dir) + + subprocess.run( + [ + "python", + os.path.join(script_dir, "browser_env/auto_login.py"), + "--auth_folder", + temp_dir, + "--site_list", + *comb, + ] + ) + _c["storage_state"] = f"{temp_dir}/{cookie_file_name}" + assert os.path.exists(_c["storage_state"]) + # update the config file + config_file = f"{temp_dir}/{os.path.basename(config_file)}" + with open(config_file, "w") as f: + json.dump(_c, f) + self.config_file = config_file with open(self.config_file, "r") as f: self.config = json.load(f) @@ -136,4 +172,4 @@ def update_webarena_metrics(self, action_cmd=None): self.reward = evaluator(trajectory=self.trajectory, config_file=self.config_file, page=self.webarena_env.page, client=self.webarena_env.get_page_client(self.webarena_env.page)) except Exception as e: print(f"Got excepetion: {e}") - self.reward = 0 \ No newline at end of file + self.reward = 0