Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@
"opentelemetry-exporter-otlp-proto-http < 2",
"pydantic >= 2.11.1, < 3",
"typing_extensions",
"google-cloud-iam",
]

evaluation_extra_require = [
Expand Down Expand Up @@ -256,6 +257,7 @@
"bigframes; python_version>='3.10' and python_version<'3.14'",
# google-api-core 2.x is required since kfp requires protobuf > 4
"google-api-core >= 2.11, < 3.0.0",
"google-cloud-iam",
"grpcio-testing",
"grpcio-tools >= 1.63.0; python_version>='3.13'",
"ipython",
Expand Down
100 changes: 100 additions & 0 deletions tests/unit/vertexai/genai/test_sandbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import os

from unittest import mock
from urllib.parse import urlencode

from google import auth
from google.auth import credentials as auth_credentials
from google.cloud import aiplatform
import vertexai
from google.cloud.aiplatform import initializer
from vertexai._genai import _agent_engines_utils
from vertexai._genai import types as _genai_types
from google.genai import client
from google.genai import types as genai_types
import pytest

_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
_TEST_LOCATION = "us-central1"
_TEST_PROJECT = "test-project"
_TEST_RESOURCE_ID = "1028944691210842416"
_TEST_SANDBOX_ID = "sandbox-123"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_AGENT_ENGINE_RESOURCE_NAME = (
f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}"
)
_TEST_SANDBOX_RESOURCE_NAME = (
f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}/sandboxes/{_TEST_SANDBOX_ID}"
)
_TEST_AGENT_ENGINE_ENV_KEY = "GOOGLE_CLOUD_AGENT_ENGINE_ENV"
_TEST_AGENT_ENGINE_ENV_VALUE = "test_env_value"
_TEST_SERVICE_ACCOUNT_EMAIL = "test-sa@test-project.iam.gserviceaccount.com"


@pytest.fixture(scope="module")
def google_auth_mock():
with mock.patch.object(auth, "default") as google_auth_mock:
google_auth_mock.return_value = (
auth_credentials.AnonymousCredentials(),
_TEST_PROJECT,
)
yield google_auth_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestSandbox:
def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)
importlib.reload(vertexai)
os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE
self.client = vertexai.Client(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@mock.patch.object(client.Client, "_get_api_client")
def test_send_command(self, mock_get_api_client):
mock_sandbox = mock.Mock()
mock_sandbox.connection_info.load_balancer_ip = "127.0.0.1"
mock_sandbox.connection_info.load_balancer_hostname = None
mock_http_client = mock_get_api_client.return_value
mock_http_client.request.return_value = (
genai_types.HttpResponse(body=b"{}", headers={})
)

self.client.agent_engines.sandboxes.send_command(
http_method="GET",
access_token="test_token",
sandbox_environment=mock_sandbox,
path="test/path",
)

call_args = mock_get_api_client.call_args
assert call_args is not None
_, kwargs = call_args
http_options = kwargs["http_options"]
assert http_options.base_url == "http://127.0.0.1/test/path"
assert http_options.headers["Authorization"] == "Bearer test_token"

mock_http_client.request.assert_called_with("GET", "test/path", {})
160 changes: 160 additions & 0 deletions vertexai/_genai/sandboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
import json
import logging
import mimetypes
import secrets
import time
from typing import Any, Iterator, Optional, Union
from urllib.parse import urlencode

from google import genai
from google.cloud import iam_credentials_v1
from google.genai import _api_module
from google.genai import _common
from google.genai import types as genai_types
from google.genai._common import get_value_by_path as getv
from google.genai._common import set_value_by_path as setv
from google.genai.pagers import Pager
Expand Down Expand Up @@ -704,6 +709,161 @@ def delete(
"""
return self._delete(name=name, config=config)

def generate_access_token(
self,
service_account_email: str,
sandbox_id: str,
port: str = "8080",
timeout: int = 3600,
) -> str:
"""Signs a JWT with a Google Cloud service account.

Args:
service_account_email (str):
Required. The email of the service account to use for signing.
sandbox_id (str):
Required. The resource name of the sandbox to generate a token for.
port (str):
Optional. The port to use for the token. Defaults to "8080".
timeout (int):
Optional. The timeout in seconds for the token. Defaults to 3600.

Returns:
str: The signed JWT.
"""
client = iam_credentials_v1.IAMCredentialsClient()
name = f"projects/-/serviceAccounts/{service_account_email}"
custom_claims = {"port": port, "sandbox_id": sandbox_id}
payload = {
"iat": int(time.time()),
"exp": int(time.time()) + timeout,
"iss": service_account_email,
"nonce": secrets.randbelow(1000000000) + 1,
"aud": "vmaas-proxy-api", # default audience for sandbox proxy
**custom_claims,
}
request = iam_credentials_v1.SignJwtRequest(
name=name,
payload=json.dumps(payload),
)
response = client.sign_jwt(request=request)
return response.signed_jwt

def send_command(
self,
*,
http_method: str,
access_token: str,
sandbox_environment: types.SandboxEnvironment,
path: str = None,
query_params: Optional[dict[str, object]] = None,
headers: Optional[dict[str, str]] = None,
request_dict: Optional[dict[str, object]] = None,
) -> genai_types.HttpResponse:
"""Sends a command to the sandbox.

Args:
http_method (str):
Required. The HTTP method to use for the command.
access_token (str):
Required. The access token to use for authorization.
sandbox_environment (types.SandboxEnvironment):
Required. The sandbox environment to send the command to.
path (str):
Optional. The path to send the command to.
query_params (dict[str, object]):
Optional. The query parameters to include in the command.
headers (dict[str, str]):
Optional. The headers to include in the command.
request_dict (dict[str, object]):
Optional. The request body to include in the command.

Returns:
genai_types.HttpResponse: The response from the sandbox.
"""
headers = headers or {}
request_dict = request_dict or {}
connection_info = sandbox_environment.connection_info
if not connection_info:
raise ValueError("Connection info is not available.")
if connection_info.load_balancer_hostname:
endpoint = "https://" + connection_info.load_balancer_hostname
elif connection_info.load_balancer_ip:
endpoint = "http://" + connection_info.load_balancer_ip
else:
raise ValueError("Load balancer hostname or ip is not available.")

path = path or ""
if query_params:
path = f"{path}?{urlencode(query_params)}"
headers["Authorization"] = f"Bearer {access_token}"
endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
http_client = genai.Client(vertexai=True, http_options=http_options)
# Full path is constructed in this function. The passed in path into request
# function will not be used.
response = http_client._api_client.request(http_method, path, request_dict)
return genai_types.HttpResponse(
headers=response.headers,
body=response.body,
)

def generate_browser_ws_headers(
self,
sandbox_environment: types.SandboxEnvironment,
service_account_email: str,
timeout: int = 3600,
) -> tuple[str, dict[str, str]]:
"""Generates the websocket upgrade headers for the browser.

Args:
sandbox_environment (types.SandboxEnvironment):
Required. The sandbox environment to generate websocket headers for.
service_account_email (str):
Required. The email of the service account to use for signing.
timeout (int):
Optional. The timeout in seconds for the token. Defaults to 3600.

Returns:
tuple[str, dict[str, str]]: A tuple containing the websocket URL and
the headers for websocket upgrade.
"""
sandbox_id = sandbox_environment.name
# port 8080 is the default port for http endpoint.
http_access_token = self.generate_access_token(
service_account_email, sandbox_id, "8080", timeout
)
response = self.send_command(
http_method="GET",
access_token=http_access_token,
sandbox_environment=sandbox_environment,
path="/cdp_ws_endpoint",
)
if not response:
raise ValueError("Failed to get the websocket endpoint.")
body_dict = json.loads(response.body)
ws_path = body_dict["endpoint"]

ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
if sandbox_environment and sandbox_environment.connection_info:
connection_info = sandbox_environment.connection_info
if connection_info.load_balancer_hostname:
ws_url = "wss://" + connection_info.load_balancer_hostname
elif connection_info.load_balancer_ip:
ws_url = "ws://" + connection_info.load_balancer_ip
else:
raise ValueError("Load balancer hostname or ip is not available.")
ws_url = ws_url + "/" + ws_path

# port 9222 is the default port for the browser websocket endpoint.
ws_access_token = self.generate_access_token(
service_account_email, sandbox_id, "9222", timeout
)

headers = {}
headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}"
return ws_url, headers


class AsyncSandboxes(_api_module.BaseModule):

Expand Down
Loading