From 47bda17228b807ad264198fcae0ad7d2424331da Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 8 Oct 2025 14:38:06 -0700 Subject: [PATCH] feat: GenAI SDK client - Support agent engine sandbox http request in genai sdk PiperOrigin-RevId: 816865842 --- setup.py | 2 + tests/unit/vertexai/genai/test_sandbox.py | 100 ++++++++++++++ vertexai/_genai/sandboxes.py | 160 ++++++++++++++++++++++ 3 files changed, 262 insertions(+) create mode 100644 tests/unit/vertexai/genai/test_sandbox.py diff --git a/setup.py b/setup.py index 958205c4e1..cb960d55f5 100644 --- a/setup.py +++ b/setup.py @@ -168,6 +168,7 @@ "opentelemetry-exporter-otlp-proto-http < 2", "pydantic >= 2.11.1, < 3", "typing_extensions", + "google-cloud-iam", ] evaluation_extra_require = [ @@ -256,6 +257,7 @@ "bigframes; python_version>='3.10' and python_version<'3.14'", # google-api-core 2.x is required since kfp requires protobuf > 4 "google-api-core >= 2.11, < 3.0.0", + "google-cloud-iam", "grpcio-testing", "grpcio-tools >= 1.63.0; python_version>='3.13'", "ipython", diff --git a/tests/unit/vertexai/genai/test_sandbox.py b/tests/unit/vertexai/genai/test_sandbox.py new file mode 100644 index 0000000000..443996a5da --- /dev/null +++ b/tests/unit/vertexai/genai/test_sandbox.py @@ -0,0 +1,100 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +import os + +from unittest import mock +from urllib.parse import urlencode + +from google import auth +from google.auth import credentials as auth_credentials +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform import initializer +from vertexai._genai import _agent_engines_utils +from vertexai._genai import types as _genai_types +from google.genai import client +from google.genai import types as genai_types +import pytest + +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_RESOURCE_ID = "1028944691210842416" +_TEST_SANDBOX_ID = "sandbox-123" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_AGENT_ENGINE_RESOURCE_NAME = ( + f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}" +) +_TEST_SANDBOX_RESOURCE_NAME = ( + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}/sandboxes/{_TEST_SANDBOX_ID}" +) +_TEST_AGENT_ENGINE_ENV_KEY = "GOOGLE_CLOUD_AGENT_ENGINE_ENV" +_TEST_AGENT_ENGINE_ENV_VALUE = "test_env_value" +_TEST_SERVICE_ACCOUNT_EMAIL = "test-sa@test-project.iam.gserviceaccount.com" + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + google_auth_mock.return_value = ( + auth_credentials.AnonymousCredentials(), + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestSandbox: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + importlib.reload(vertexai) + os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE + self.client = vertexai.Client( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @mock.patch.object(client.Client, "_get_api_client") + def test_send_command(self, mock_get_api_client): + mock_sandbox = mock.Mock() + mock_sandbox.connection_info.load_balancer_ip = "127.0.0.1" + mock_sandbox.connection_info.load_balancer_hostname = None + mock_http_client = mock_get_api_client.return_value + mock_http_client.request.return_value = ( + genai_types.HttpResponse(body=b"{}", headers={}) + ) + + self.client.agent_engines.sandboxes.send_command( + http_method="GET", + access_token="test_token", + sandbox_environment=mock_sandbox, + path="test/path", + ) + + call_args = mock_get_api_client.call_args + assert call_args is not None + _, kwargs = call_args + http_options = kwargs["http_options"] + assert http_options.base_url == "http://127.0.0.1/test/path" + assert http_options.headers["Authorization"] == "Bearer test_token" + + mock_http_client.request.assert_called_with("GET", "test/path", {}) diff --git a/vertexai/_genai/sandboxes.py b/vertexai/_genai/sandboxes.py index 3e1d1181bb..c0032c033e 100644 --- a/vertexai/_genai/sandboxes.py +++ b/vertexai/_genai/sandboxes.py @@ -19,11 +19,16 @@ import json import logging import mimetypes +import secrets +import time from typing import Any, Iterator, Optional, Union from urllib.parse import urlencode +from google import genai +from google.cloud import iam_credentials_v1 from google.genai import _api_module from google.genai import _common +from google.genai import types as genai_types from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv from google.genai.pagers import Pager @@ -704,6 +709,161 @@ def delete( """ return self._delete(name=name, config=config) + def generate_access_token( + self, + service_account_email: str, + sandbox_id: str, + port: str = "8080", + timeout: int = 3600, + ) -> str: + """Signs a JWT with a Google Cloud service account. + + Args: + service_account_email (str): + Required. The email of the service account to use for signing. + sandbox_id (str): + Required. The resource name of the sandbox to generate a token for. + port (str): + Optional. The port to use for the token. Defaults to "8080". + timeout (int): + Optional. The timeout in seconds for the token. Defaults to 3600. + + Returns: + str: The signed JWT. + """ + client = iam_credentials_v1.IAMCredentialsClient() + name = f"projects/-/serviceAccounts/{service_account_email}" + custom_claims = {"port": port, "sandbox_id": sandbox_id} + payload = { + "iat": int(time.time()), + "exp": int(time.time()) + timeout, + "iss": service_account_email, + "nonce": secrets.randbelow(1000000000) + 1, + "aud": "vmaas-proxy-api", # default audience for sandbox proxy + **custom_claims, + } + request = iam_credentials_v1.SignJwtRequest( + name=name, + payload=json.dumps(payload), + ) + response = client.sign_jwt(request=request) + return response.signed_jwt + + def send_command( + self, + *, + http_method: str, + access_token: str, + sandbox_environment: types.SandboxEnvironment, + path: str = None, + query_params: Optional[dict[str, object]] = None, + headers: Optional[dict[str, str]] = None, + request_dict: Optional[dict[str, object]] = None, + ) -> genai_types.HttpResponse: + """Sends a command to the sandbox. + + Args: + http_method (str): + Required. The HTTP method to use for the command. + access_token (str): + Required. The access token to use for authorization. + sandbox_environment (types.SandboxEnvironment): + Required. The sandbox environment to send the command to. + path (str): + Optional. The path to send the command to. + query_params (dict[str, object]): + Optional. The query parameters to include in the command. + headers (dict[str, str]): + Optional. The headers to include in the command. + request_dict (dict[str, object]): + Optional. The request body to include in the command. + + Returns: + genai_types.HttpResponse: The response from the sandbox. + """ + headers = headers or {} + request_dict = request_dict or {} + connection_info = sandbox_environment.connection_info + if not connection_info: + raise ValueError("Connection info is not available.") + if connection_info.load_balancer_hostname: + endpoint = "https://" + connection_info.load_balancer_hostname + elif connection_info.load_balancer_ip: + endpoint = "http://" + connection_info.load_balancer_ip + else: + raise ValueError("Load balancer hostname or ip is not available.") + + path = path or "" + if query_params: + path = f"{path}?{urlencode(query_params)}" + headers["Authorization"] = f"Bearer {access_token}" + endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path + http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint) + http_client = genai.Client(vertexai=True, http_options=http_options) + # Full path is constructed in this function. The passed in path into request + # function will not be used. + response = http_client._api_client.request(http_method, path, request_dict) + return genai_types.HttpResponse( + headers=response.headers, + body=response.body, + ) + + def generate_browser_ws_headers( + self, + sandbox_environment: types.SandboxEnvironment, + service_account_email: str, + timeout: int = 3600, + ) -> tuple[str, dict[str, str]]: + """Generates the websocket upgrade headers for the browser. + + Args: + sandbox_environment (types.SandboxEnvironment): + Required. The sandbox environment to generate websocket headers for. + service_account_email (str): + Required. The email of the service account to use for signing. + timeout (int): + Optional. The timeout in seconds for the token. Defaults to 3600. + + Returns: + tuple[str, dict[str, str]]: A tuple containing the websocket URL and + the headers for websocket upgrade. + """ + sandbox_id = sandbox_environment.name + # port 8080 is the default port for http endpoint. + http_access_token = self.generate_access_token( + service_account_email, sandbox_id, "8080", timeout + ) + response = self.send_command( + http_method="GET", + access_token=http_access_token, + sandbox_environment=sandbox_environment, + path="/cdp_ws_endpoint", + ) + if not response: + raise ValueError("Failed to get the websocket endpoint.") + body_dict = json.loads(response.body) + ws_path = body_dict["endpoint"] + + ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog" + if sandbox_environment and sandbox_environment.connection_info: + connection_info = sandbox_environment.connection_info + if connection_info.load_balancer_hostname: + ws_url = "wss://" + connection_info.load_balancer_hostname + elif connection_info.load_balancer_ip: + ws_url = "ws://" + connection_info.load_balancer_ip + else: + raise ValueError("Load balancer hostname or ip is not available.") + ws_url = ws_url + "/" + ws_path + + # port 9222 is the default port for the browser websocket endpoint. + ws_access_token = self.generate_access_token( + service_account_email, sandbox_id, "9222", timeout + ) + + headers = {} + headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}" + return ws_url, headers + class AsyncSandboxes(_api_module.BaseModule):