diff --git a/examples/ov.conf.example b/examples/ov.conf.example index f9c18010c..2ef493ada 100644 --- a/examples/ov.conf.example +++ b/examples/ov.conf.example @@ -14,7 +14,8 @@ "volcengine": { "region": "cn-beijing", "ak": null, - "sk": null + "sk": null, + "session_token": null } }, "agfs": { diff --git a/openviking/storage/vectordb/collection/volcengine_clients.py b/openviking/storage/vectordb/collection/volcengine_clients.py index 9eac48460..7615d7697 100644 --- a/openviking/storage/vectordb/collection/volcengine_clients.py +++ b/openviking/storage/vectordb/collection/volcengine_clients.py @@ -21,11 +21,12 @@ class ClientForConsoleApi: "cn-guangzhou": "vikingdb.cn-guangzhou.volcengineapi.com", } - def __init__(self, ak, sk, region, host=None): + def __init__(self, ak, sk, region, host=None, session_token=None): self.ak = ak self.sk = sk self.region = region self.host = host if host else ClientForConsoleApi._global_host[region] + self.session_token = session_token or "" if not all([self.ak, self.sk, self.host, self.region]): raise ValueError("AK, SK, Host, and Region are required for ClientForConsoleApi") @@ -54,7 +55,13 @@ def prepare_request(self, method, params=None, data=None): if data is not None: r.set_body(json.dumps(data)) - credentials = Credentials(self.ak, self.sk, "vikingdb", self.region) + credentials = Credentials( + self.ak, + self.sk, + "vikingdb", + self.region, + session_token=self.session_token, + ) SignerV4.sign(r, credentials) return r @@ -64,7 +71,7 @@ def do_req(self, req_method, req_params=None, req_body=None): method=req.method, url=f"https://{self.host}{req.path}", headers=req.headers, - params=req_params, + params=req.query, data=req.body, timeout=DEFAULT_TIMEOUT, ) @@ -77,11 +84,12 @@ class ClientForDataApi: "cn-guangzhou": "api-vikingdb.vikingdb.cn-guangzhou.volces.com", } - def __init__(self, ak, sk, region, host=None): + def __init__(self, ak, sk, region, host=None, session_token=None): self.ak = ak self.sk = sk self.region = region self.host = host if host else ClientForDataApi._global_host[region] + self.session_token = session_token or "" if not all([self.ak, self.sk, self.host, self.region]): raise ValueError("AK, SK, Host, and Region are required for ClientForDataApi") @@ -110,7 +118,13 @@ def prepare_request(self, method, path, params=None, data=None): if data is not None: r.set_body(json.dumps(data)) - credentials = Credentials(self.ak, self.sk, "vikingdb", self.region) + credentials = Credentials( + self.ak, + self.sk, + "vikingdb", + self.region, + session_token=self.session_token, + ) SignerV4.sign(r, credentials) return r @@ -122,7 +136,7 @@ def do_req(self, req_method, req_path, req_params=None, req_body=None): method=req.method, url=f"https://{self.host}{req.path}", headers=req.headers, - params=req_params, + params=req.query, data=req.body, timeout=DEFAULT_TIMEOUT, ) diff --git a/openviking/storage/vectordb/collection/volcengine_collection.py b/openviking/storage/vectordb/collection/volcengine_collection.py index 855eff889..45e58aef8 100644 --- a/openviking/storage/vectordb/collection/volcengine_collection.py +++ b/openviking/storage/vectordb/collection/volcengine_collection.py @@ -4,7 +4,8 @@ import json from typing import Any, Dict, List, Optional -from openviking.storage.vectordb.collection.collection import ICollection +from openviking.storage.errors import ConnectionError +from openviking.storage.vectordb.collection.collection import Collection, ICollection from openviking.storage.vectordb.collection.result import ( AggregateResult, DataItem, @@ -35,13 +36,16 @@ def get_or_create_volcengine_collection(config: Dict[str, Any], meta_data: Dict[ ak = config.get("AK") sk = config.get("SK") region = config.get("Region") + session_token = config.get("SessionToken") + if not ak or not sk or not region: + raise ValueError("AK, SK, and Region are required in config") collection_name = meta_data.get("CollectionName") if not collection_name: raise ValueError("CollectionName is required in config") # Initialize Console client for creating Collection - client = ClientForConsoleApi(ak, sk, region) + client = ClientForConsoleApi(ak, sk, region, session_token=session_token) # Try to create Collection try: @@ -63,10 +67,15 @@ def get_or_create_volcengine_collection(config: Dict[str, Any], meta_data: Dict[ raise e logger.info(f"Collection {collection_name} created successfully") - return VolcengineCollection(ak, sk, region, meta_data=meta_data) - - # Return VolcengineCollection instance - return VolcengineCollection(ak=ak, sk=sk, region=region, meta_data=meta_data) + return Collection( + VolcengineCollection( + ak, + sk, + region, + session_token=session_token, + meta_data=meta_data, + ) + ) class VolcengineCollection(ICollection): @@ -76,19 +85,59 @@ def __init__( sk: str, region: str, host: Optional[str] = None, + session_token: Optional[str] = None, meta_data: Optional[Dict[str, Any]] = None, ): - self.console_client = ClientForConsoleApi(ak, sk, region, host) - self.data_client = ClientForDataApi(ak, sk, region, host) + self.console_client = ClientForConsoleApi( + ak, + sk, + region, + host, + session_token=session_token, + ) + self.data_client = ClientForDataApi( + ak, + sk, + region, + host, + session_token=session_token, + ) self.meta_data = meta_data if meta_data is not None else {} self.project_name = self.meta_data.get("ProjectName", "default") self.collection_name = self.meta_data.get("CollectionName", "") + @staticmethod + def _build_response_error(response: Any, action: str) -> ConnectionError: + try: + result = response.json() + except json.JSONDecodeError: + result = {} + + metadata = result.get("ResponseMetadata", {}) if isinstance(result, dict) else {} + error = metadata.get("Error", {}) if isinstance(metadata, dict) else {} + code = error.get("Code", "UnknownError") + message = error.get("Message", response.text) + return ConnectionError( + f"Request to {action} failed: {response.status_code} {code} {message}" + ) + + @staticmethod + def _is_collection_not_found(response: Any, action: str) -> bool: + if action != "GetVikingdbCollection" or response.status_code != 404: + return False + try: + result = response.json() + except json.JSONDecodeError: + return False + metadata = result.get("ResponseMetadata", {}) if isinstance(result, dict) else {} + error = metadata.get("Error", {}) if isinstance(metadata, dict) else {} + return error.get("Code") == "NotFound.VikingdbCollection" + def _console_post(self, data: Dict[str, Any], action: str): params = {"Action": action, "Version": VIKING_DB_VERSION} response = self.console_client.do_req("POST", req_params=params, req_body=data) if response.status_code != 200: - logger.error(f"Request to {action} failed: {response.text}") + logger.error(str(self._build_response_error(response, action))) return {} try: result = response.json() @@ -103,11 +152,10 @@ def _console_get(self, params: Optional[Dict[str, Any]], action: str): params = {} req_params = {"Action": action, "Version": VIKING_DB_VERSION} req_body = params - response = self.console_client.do_req("POST", req_params=req_params, req_body=req_body) if response.status_code != 200: - logger.error(f"Request to {action} failed: {response.text}") + logger.error(str(self._build_response_error(response, action))) return {} try: result = response.json() diff --git a/openviking/storage/vectordb_adapters/volcengine_adapter.py b/openviking/storage/vectordb_adapters/volcengine_adapter.py index d06b0e84f..0b8ca021a 100644 --- a/openviking/storage/vectordb_adapters/volcengine_adapter.py +++ b/openviking/storage/vectordb_adapters/volcengine_adapter.py @@ -24,15 +24,18 @@ def __init__( ak: str, sk: str, region: str, + session_token: str | None, project_name: str, collection_name: str, index_name: str, ): super().__init__(collection_name=collection_name, index_name=index_name) + self._collection: Collection | None = None self.mode = "volcengine" self._ak = ak self._sk = sk self._region = region + self._session_token = session_token self._project_name = project_name @classmethod @@ -48,6 +51,7 @@ def from_config(cls, config: Any): ak=config.volcengine.ak, sk=config.volcengine.sk, region=config.volcengine.region, + session_token=config.volcengine.session_token, project_name=config.project_name or "default", collection_name=config.name or "context", index_name=config.index_name or "default", @@ -64,14 +68,18 @@ def _config(self) -> Dict[str, Any]: "AK": self._ak, "SK": self._sk, "Region": self._region, + "SessionToken": self._session_token, } - def _new_collection_handle(self) -> VolcengineCollection: - return VolcengineCollection( - ak=self._ak, - sk=self._sk, - region=self._region, - meta_data=self._meta(), + def _new_collection_handle(self) -> Collection: + return Collection( + VolcengineCollection( + ak=self._ak, + sk=self._sk, + region=self._region, + session_token=self._session_token, + meta_data=self._meta(), + ) ) def _load_existing_collection_if_needed(self) -> None: diff --git a/openviking_cli/utils/config/vectordb_config.py b/openviking_cli/utils/config/vectordb_config.py index 9a8f547b3..9f9f1739b 100644 --- a/openviking_cli/utils/config/vectordb_config.py +++ b/openviking_cli/utils/config/vectordb_config.py @@ -17,6 +17,10 @@ class VolcengineConfig(BaseModel): ak: Optional[str] = Field(default=None, description="Volcengine Access Key") sk: Optional[str] = Field(default=None, description="Volcengine Secret Key") + session_token: Optional[str] = Field( + default=None, + description="Optional Volcengine STS security token for temporary credentials", + ) region: Optional[str] = Field( default=None, description="Volcengine region (e.g., 'cn-beijing')" ) diff --git a/tests/storage/test_volcengine_clients.py b/tests/storage/test_volcengine_clients.py new file mode 100644 index 000000000..57ebc9adb --- /dev/null +++ b/tests/storage/test_volcengine_clients.py @@ -0,0 +1,180 @@ +from volcengine.base.Request import Request + +from openviking.storage.vectordb.collection.volcengine_clients import ( + ClientForConsoleApi, + ClientForDataApi, +) +from openviking.storage.vectordb.collection.volcengine_collection import VolcengineCollection +from openviking.storage.vectordb_adapters.volcengine_adapter import VolcengineCollectionAdapter +from openviking_cli.utils.config.vectordb_config import ( + VectorDBBackendConfig, + VolcengineConfig, +) + + +def test_console_client_prepare_request_includes_session_token(): + client = ClientForConsoleApi( + "test-ak", + "test-sk", + "cn-beijing", + session_token="test-session-token", + ) + + request = client.prepare_request( + "POST", + params={"Action": "ListVikingdbCollection", "Version": "2025-06-09"}, + data={"PageNumber": 1, "PageSize": 10}, + ) + + assert request.headers["X-Security-Token"] == "test-session-token" + assert "Authorization" in request.headers + + +def test_console_client_do_req_uses_signed_query_params(monkeypatch): + captured = {} + + def fake_request(**kwargs): + captured.update(kwargs) + return object() + + def fake_prepare_request(self, method, params=None, data=None): + request = Request() + request.method = method + request.path = "/" + request.body = '{"PageNumber": 1, "PageSize": 10}' + request.headers = {"Authorization": "signed-auth"} + request.query = { + "Action": "ListVikingdbCollection", + "Version": "2025-06-09", + "X-Date": "20260405T091640Z", + "X-Signature": "signed", + } + return request + + monkeypatch.setattr( + "openviking.storage.vectordb.collection.volcengine_clients.requests.request", + fake_request, + ) + monkeypatch.setattr(ClientForConsoleApi, "prepare_request", fake_prepare_request) + + client = ClientForConsoleApi("test-ak", "test-sk", "cn-beijing") + client.do_req( + "POST", + req_params={"Action": "ListVikingdbCollection", "Version": "2025-06-09"}, + req_body={"PageNumber": 1, "PageSize": 10}, + ) + + assert captured["params"]["X-Date"] == "20260405T091640Z" + assert captured["params"]["X-Signature"] == "signed" + + +def test_data_client_do_req_uses_signed_query_params(monkeypatch): + captured = {} + + def fake_request(**kwargs): + captured.update(kwargs) + return object() + + def fake_prepare_request(self, method, path, params=None, data=None): + request = Request() + request.method = method + request.path = path + request.body = '{"project": "default"}' + request.headers = {"Authorization": "signed-auth"} + request.query = { + "Action": "Search", + "Version": "2025-06-09", + "X-Date": "20260405T091640Z", + "X-Signature": "signed", + } + return request + + monkeypatch.setattr( + "openviking.storage.vectordb.collection.volcengine_clients.requests.request", + fake_request, + ) + monkeypatch.setattr(ClientForDataApi, "prepare_request", fake_prepare_request) + + client = ClientForDataApi("test-ak", "test-sk", "cn-beijing") + client.do_req( + "POST", + "/api/vikingdb/data/search/vector", + req_params={"Action": "Search", "Version": "2025-06-09"}, + req_body={"project": "default"}, + ) + + assert captured["params"]["X-Date"] == "20260405T091640Z" + assert captured["params"]["X-Signature"] == "signed" + + +def test_volcengine_adapter_preserves_session_token_from_config(): + config = VectorDBBackendConfig( + backend="volcengine", + name="context", + volcengine=VolcengineConfig( + ak="test-ak", + sk="test-sk", + region="cn-beijing", + session_token="test-session-token", + ), + ) + + adapter = VolcengineCollectionAdapter.from_config(config) + + assert adapter._config()["SessionToken"] == "test-session-token" + + +def test_volcengine_collection_get_meta_data_returns_empty_on_signature_error(monkeypatch): + class _Response: + status_code = 403 + text = "signature mismatch" + + @staticmethod + def json(): + return { + "ResponseMetadata": { + "Error": { + "Code": "SignatureDoesNotMatch", + "Message": "The request signature we calculated does not match", + } + } + } + + collection = VolcengineCollection( + ak="test-ak", + sk="test-sk", + region="cn-beijing", + meta_data={"ProjectName": "default", "CollectionName": "context"}, + ) + monkeypatch.setattr(collection.console_client, "do_req", lambda *args, **kwargs: _Response()) + + assert collection.get_meta_data() == {} + + +def test_volcengine_collection_get_meta_data_returns_empty_on_collection_not_found( + monkeypatch, +): + class _Response: + status_code = 404 + text = "collection not found" + + @staticmethod + def json(): + return { + "ResponseMetadata": { + "Error": { + "Code": "NotFound.VikingdbCollection", + "Message": "The specified collection 'context' of VikingDB does not exist.", + } + } + } + + collection = VolcengineCollection( + ak="test-ak", + sk="test-sk", + region="cn-beijing", + meta_data={"ProjectName": "default", "CollectionName": "context"}, + ) + monkeypatch.setattr(collection.console_client, "do_req", lambda *args, **kwargs: _Response()) + + assert collection.get_meta_data() == {}