|
19 | 19 | import json |
20 | 20 | import logging |
21 | 21 | import mimetypes |
| 22 | +import secrets |
| 23 | +import time |
22 | 24 | from typing import Any, Iterator, Optional, Union |
23 | 25 | from urllib.parse import urlencode |
24 | 26 |
|
| 27 | +from google import genai |
| 28 | +from google.cloud import iam_credentials_v1 |
25 | 29 | from google.genai import _api_module |
26 | 30 | from google.genai import _common |
| 31 | +from google.genai import types as genai_types |
27 | 32 | from google.genai._common import get_value_by_path as getv |
28 | 33 | from google.genai._common import set_value_by_path as setv |
29 | 34 | from google.genai.pagers import Pager |
@@ -704,6 +709,64 @@ def delete( |
704 | 709 | """ |
705 | 710 | return self._delete(name=name, config=config) |
706 | 711 |
|
| 712 | + def generate_access_token( |
| 713 | + self, |
| 714 | + service_account_email: str, |
| 715 | + sandbox_id: str, |
| 716 | + port: str = "8080", |
| 717 | + timeout: int = 3600, |
| 718 | + ) -> str: |
| 719 | + """Signs a JWT with a Google Cloud service account.""" |
| 720 | + client = iam_credentials_v1.IAMCredentialsClient() |
| 721 | + name = f"projects/-/serviceAccounts/{service_account_email}" |
| 722 | + custom_claims = {"port": port, "sandbox_id": sandbox_id} |
| 723 | + payload = { |
| 724 | + "iat": int(time.time()), |
| 725 | + "exp": int(time.time()) + timeout, |
| 726 | + "iss": service_account_email, |
| 727 | + "nonce": secrets.randbelow(1000000000) + 1, |
| 728 | + "aud": "vmaas-proxy-api", # default audience for sandbox proxy |
| 729 | + **custom_claims, |
| 730 | + } |
| 731 | + request = iam_credentials_v1.SignJwtRequest( |
| 732 | + name=name, |
| 733 | + payload=json.dumps(payload), |
| 734 | + ) |
| 735 | + response = client.sign_jwt(request=request) |
| 736 | + return response.signed_jwt |
| 737 | + |
| 738 | + def send_command( |
| 739 | + self, |
| 740 | + http_method: str, |
| 741 | + access_token: str, |
| 742 | + sandbox_environment: types.SandboxEnvironment, |
| 743 | + path: str = None, |
| 744 | + query_params: Optional[dict[str, object]] = None, |
| 745 | + headers: Optional[dict[str, str]] = None, |
| 746 | + request_dict: Optional[dict[str, object]] = None, |
| 747 | + ) -> str | None: |
| 748 | + """Sends a command to the sandbox.""" |
| 749 | + headers = headers or {} |
| 750 | + connection_info = sandbox_environment.connection_info |
| 751 | + if not connection_info: |
| 752 | + raise ValueError("Connection info is not available.") |
| 753 | + if connection_info.load_balancer_hostname: |
| 754 | + endpoint = "https://" + connection_info.load_balancer_hostname |
| 755 | + elif connection_info.load_balancer_ip: |
| 756 | + endpoint = "http://" + connection_info.load_balancer_ip |
| 757 | + else: |
| 758 | + endpoint = "https://test-us-central1.autopush-sandbox.vertexai.goog" |
| 759 | + |
| 760 | + path = path or "" |
| 761 | + if query_params: |
| 762 | + path = f"{path}?{urlencode(query_params)}" |
| 763 | + |
| 764 | + headers["Authorization"] = f"Bearer {access_token}" |
| 765 | + http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint) |
| 766 | + http_client = genai.Client(vertexai=True, http_options=http_options) |
| 767 | + response = http_client._api_client.request(http_method, path, request_dict) |
| 768 | + return response |
| 769 | + |
707 | 770 |
|
708 | 771 | class AsyncSandboxes(_api_module.BaseModule): |
709 | 772 |
|
|
0 commit comments