Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/kimi_cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class LLMProvider(BaseModel):
"""Environment variables to set before creating the provider instance"""
custom_headers: dict[str, str] | None = None
"""Custom headers to include in API requests"""
default_query: dict[str, str | int | float | bool] | None = None
"""Default query parameters to include in API requests (JSON-serializable scalars only)."""
oauth: OAuthRef | None = None
"""OAuth credential reference (do not store tokens here)."""

Expand Down
13 changes: 12 additions & 1 deletion src/kimi_cli/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal, cast, get_args
from typing import TYPE_CHECKING, Any, Literal, cast, get_args

from kosong.chat_provider import ChatProvider
from pydantic import SecretStr
Expand Down Expand Up @@ -103,6 +103,15 @@ def _kimi_default_headers(provider: LLMProvider, oauth: OAuthManager | None) ->
return headers


def _openai_client_kwargs(provider: LLMProvider) -> dict[str, Any]:
    """Build extra keyword arguments for the OpenAI client from provider config.

    Only non-empty settings are forwarded, so the OpenAI client falls back to
    its own defaults when the provider configures neither custom headers nor
    default query parameters.
    """
    extras: dict[str, Any] = {}
    # Map provider config attributes to the OpenAI client constructor kwargs.
    for attr_name, kwarg_name in (
        ("custom_headers", "default_headers"),
        ("default_query", "default_query"),
    ):
        value = getattr(provider, attr_name)
        if value:
            extras[kwarg_name] = value
    return extras


def create_llm(
provider: LLMProvider,
model: LLMModel,
Expand Down Expand Up @@ -152,6 +161,7 @@ def create_llm(
model=model.model,
base_url=provider.base_url,
api_key=resolved_api_key,
**_openai_client_kwargs(provider),
)
case "openai_responses":
from kosong.contrib.chat_provider.openai_responses import OpenAIResponses
Expand All @@ -160,6 +170,7 @@ def create_llm(
model=model.model,
base_url=provider.base_url,
api_key=resolved_api_key,
**_openai_client_kwargs(provider),
)
case "anthropic":
from kosong.contrib.chat_provider.anthropic import Anthropic
Expand Down
40 changes: 40 additions & 0 deletions tests/core/test_create_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from inline_snapshot import snapshot
from kosong.chat_provider.echo import EchoChatProvider
from kosong.chat_provider.kimi import Kimi
from kosong.contrib.chat_provider.openai_legacy import OpenAILegacy
from kosong.contrib.chat_provider.openai_responses import OpenAIResponses
from pydantic import SecretStr

from kimi_cli.config import LLMModel, LLMProvider
Expand Down Expand Up @@ -133,3 +135,41 @@ def test_create_llm_requires_base_url_for_kimi():
model = LLMModel(provider="kimi", model="kimi-base", max_context_size=4096)

assert create_llm(provider, model) is None


def test_create_llm_openai_legacy_passes_client_kwargs():
    """Custom headers and default query params reach the OpenAI legacy client."""
    headers = {"x-test": "header"}
    query = {"api-version": "2024-05-01-preview"}
    provider = LLMProvider(
        type="openai_legacy",
        base_url="https://openai.example/v1",
        api_key=SecretStr("test-key"),
        custom_headers=headers,
        default_query=query,
    )
    model = LLMModel(provider="openai", model="gpt-4o", max_context_size=4096)

    llm = create_llm(provider, model)

    assert llm is not None
    chat_provider = llm.chat_provider
    assert isinstance(chat_provider, OpenAILegacy)
    # The kwargs built from the provider config must be stored verbatim on
    # the chat provider for later client construction.
    assert chat_provider._client_kwargs["default_headers"] == headers
    assert chat_provider._client_kwargs["default_query"] == query


def test_create_llm_openai_responses_passes_client_kwargs():
    """Custom headers and default query params reach the OpenAI Responses client."""
    headers = {"x-test": "header"}
    query = {"api-version": "2024-05-01-preview"}
    provider = LLMProvider(
        type="openai_responses",
        base_url="https://openai.example/v1",
        api_key=SecretStr("test-key"),
        custom_headers=headers,
        default_query=query,
    )
    model = LLMModel(provider="openai-responses", model="gpt-4o", max_context_size=4096)

    llm = create_llm(provider, model)

    assert llm is not None
    chat_provider = llm.chat_provider
    assert isinstance(chat_provider, OpenAIResponses)
    # The kwargs built from the provider config must be stored verbatim on
    # the chat provider for later client construction.
    assert chat_provider._client_kwargs["default_headers"] == headers
    assert chat_provider._client_kwargs["default_query"] == query