diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index 6498926..c552245 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -198,8 +198,8 @@ def register_provider( "/" ) # Remove trailing slash if present - # default route name to provider name if not provided - route_name = route_name or provider_name + # Update Javelin headers into the client's _custom_headers + openai_client._custom_headers["x-javelin-provider"] = base_url_str openai_client._custom_headers["x-javelin-route"] = route_name # Store the original methods only if not already stored @@ -521,12 +521,12 @@ def get_nested_attr(obj, attr_path): return openai_client - def register_openai(self, openai_client: Any, route_name: str = "") -> Any: + def register_openai(self, openai_client: Any, route_name: str = None) -> Any: return self.register_provider( openai_client, provider_name="openai", route_name=route_name ) - def register_azureopenai(self, openai_client: Any, route_name: str = "") -> Any: + def register_azureopenai(self, openai_client: Any, route_name: str = None) -> Any: return self.register_provider( openai_client, provider_name="azureopenai", route_name=route_name ) @@ -546,7 +546,7 @@ def register_bedrock( bedrock_runtime_client: Any, bedrock_client: Any = None, bedrock_session: Any = None, - route_name: str = "", + route_name: str = None, ) -> None: """ Register an AWS Bedrock Runtime client @@ -583,7 +583,7 @@ def register_bedrock( self.bedrock_runtime_client = bedrock_runtime_client if not route_name: - route_name = "amazon" + route_name = "awsbedrock" # Store the default bedrock route if route_name is not None: diff --git a/javelin_sdk/models.py b/javelin_sdk/models.py index 128813c..871dca5 100644 --- a/javelin_sdk/models.py +++ b/javelin_sdk/models.py @@ -114,44 +114,31 @@ class ContentFilter(BaseModel): ) -class RouteConfig(BaseModel): - rate_limit: Optional[int] = Field( - default=None, description="Rate limit for the route" - ) - owner: Optional[str] = Field(default=None, description="Owner of the route") - organization: Optional[str] = Field( - default=None, description="Organization associated with the route" - ) - archive: Optional[bool] = Field( - default=None, description="Whether archiving is enabled" - ) - retries: Optional[int] = Field( - default=None, description="Number of retries for the route" - ) - llm_cache: bool = Field(False, description="Whether LLM cache is enabled") - role_to_assume: Optional[str] = Field( - None, description="Role to assume for the route" - ) - enable_telemetry: Optional[bool] = Field( - None, description="Whether telemetry is enabled" - ) +class ArchivePolicy(BaseModel): + enabled: Optional[bool] = Field(default=None, description="Whether archiving is enabled") retention: Optional[int] = Field(default=None, description="Data retention period") + + +class Policy(BaseModel): + dlp: Optional[Dlp] = Field(default=None, description="DLP configuration") + archive: Optional[ArchivePolicy] = Field(default=None, description="Archive policy configuration") + enabled: Optional[bool] = Field(default=None, description="Whether the policy is enabled") + prompt_safety: Optional[PromptSafety] = Field(default=None, description="Prompt Safety Description") + content_filter: Optional[ContentFilter] = Field(default=None, description="Content Filter Description") + security_filters: Optional[SecurityFilters] = Field(default=None, description="Security Filters Description") + + +class RouteConfig(BaseModel): + policy: Optional[Policy] = Field(default=None, description="Policy configuration") + retries: Optional[int] = Field(default=None, description="Number of retries for the route") + rate_limit: Optional[int] = Field(default=None, description="Rate limit for the route") + unified_endpoint: Optional[bool] = Field(default=None, description="Whether unified endpoint is enabled") request_chain: Optional[Dict[str, Any]] = Field( None, description="Request chain configuration" ) response_chain: Optional[Dict[str, Any]] = Field( None, description="Response chain configuration" ) - dlp: Optional[Dlp] = Field(default=None, description="DLP configuration") - content_filter: Optional[ContentFilter] = Field( - default=None, description="Content Filter Description" - ) - prompt_safety: Optional[PromptSafety] = Field( - default=None, description="Prompt Safety Description" - ) - security_filters: Optional[SecurityFilters] = Field( - default=None, description="Security Filters Description" - ) class Model(BaseModel): @@ -354,6 +341,7 @@ class SecretType(str, Enum): AWS = "aws" KUBERNETES = "kubernetes" + class Secret(BaseModel): api_key: str = Field(default=None, description="Key of the Secret") api_key_secret_name: str = Field(default=None, description="Name of the Secret") @@ -370,8 +358,6 @@ class Secret(BaseModel): enabled: Optional[bool] = Field( default=True, description="Whether the secret is enabled" ) - secret_name: str = Field(default=None, description="Secret Name of the Secret") - secrets_provider: SecretType = Field(default=SecretType.KUBERNETES, description="Type of the secret: aws or kubernetes") def masked(self): """