Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 69 additions & 11 deletions src/applypilot/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
GEMINI_API_KEY, OPENAI_API_KEY, ANTHROPIC_API_KEY, then LLM_URL.
- Credentials come from provider env vars or generic LLM_API_KEY.
- LLM_URL is optional for custom OpenAI-compatible endpoints.
- LLM_STREAMING_MODE enables streaming mode for LLM proxies that require it.
"""

from __future__ import annotations
Expand Down Expand Up @@ -51,6 +52,7 @@ class LLMConfig:
api_base: str | None
model: str
api_key: str
use_streaming: bool = False


class ChatMessage(TypedDict):
Expand Down Expand Up @@ -79,9 +81,7 @@ def _env_get(env: Mapping[str, str], key: str) -> str:
def _provider_from_model(model: str) -> str:
provider, _, model_name = model.partition("/")
if not provider or not model_name:
raise RuntimeError(
"LLM_MODEL must include a provider prefix (for example 'openai/gpt-4o-mini')."
)
raise RuntimeError("LLM_MODEL must include a provider prefix (for example 'openai/gpt-4o-mini').")
return provider


Expand All @@ -106,9 +106,7 @@ def resolve_llm_config(env: Mapping[str, str] | None = None) -> LLMConfig:
provider, _ = inferred
model = f"{provider}/{model}"
else:
raise RuntimeError(
"LLM_MODEL must include a provider prefix (for example 'openai/gpt-4o-mini')."
)
raise RuntimeError("LLM_MODEL must include a provider prefix (for example 'openai/gpt-4o-mini').")
else:
if not inferred:
raise RuntimeError(
Expand All @@ -130,21 +128,21 @@ def resolve_llm_config(env: Mapping[str, str] | None = None) -> LLMConfig:
api_key = _env_get(env_map, api_key_env) or _env_get(env_map, "LLM_API_KEY")

if not api_key and not local_url:
key_help = (
f"{api_key_env} or LLM_API_KEY"
if provider in provider_api_key_env
else "LLM_API_KEY"
)
key_help = f"{api_key_env} or LLM_API_KEY" if provider in provider_api_key_env else "LLM_API_KEY"
raise RuntimeError(
f"Missing credentials for LLM_MODEL '{model}'. Set {key_help}, or set LLM_URL for "
"a local OpenAI-compatible endpoint."
)

# Check if streaming mode is enabled via environment variable
use_streaming = _env_get(env_map, "LLM_STREAMING_MODE").lower() in ("true", "1", "yes")

return LLMConfig(
provider=provider,
api_base=local_url.rstrip("/") if local_url else None,
model=model,
api_key=api_key,
use_streaming=use_streaming,
)


def __init__(self, config: LLMConfig) -> None:
    """Bind the resolved configuration onto the client.

    Args:
        config: Resolved settings (provider, model, credentials, and the
            streaming flag) produced by ``resolve_llm_config``.
    """
    self.config = config
    self.provider = config.provider
    self.model = config.model
    # Cached so chat() can branch without re-reading config each call.
    self._use_streaming = config.use_streaming
    # Quiet LiteLLM's stderr diagnostics for the whole process.
    litellm.suppress_debug_info = True

def chat(
Expand All @@ -169,6 +168,18 @@ def chat(
**extra: Unpack[LiteLLMExtra],
) -> str:
"""Send a completion request and return plain text content."""
# Use streaming mode when configured (required by some LLM proxies)
if self._use_streaming:
return self._chat_streaming(
messages=messages,
max_output_tokens=max_output_tokens,
temperature=temperature,
num_retries=num_retries,
drop_params=drop_params,
**extra,
)

# Standard non-streaming call
try:
if temperature is None:
response = litellm.completion(
Expand Down Expand Up @@ -208,6 +219,53 @@ def chat(
except Exception as exc: # pragma: no cover - provider SDK exception types vary by backend/version.
raise RuntimeError(f"LLM request failed ({self.provider}/{self.model}): {exc}") from exc

def _chat_streaming(
    self,
    messages: list[ChatMessage],
    *,
    max_output_tokens: int = 10000,
    temperature: float | None = None,
    num_retries: int = _MAX_RETRIES,
    drop_params: bool = True,
    **extra: Unpack[LiteLLMExtra],
) -> str:
    """Send a completion request in streaming mode and return plain text.

    Some LLM proxies only accept ``stream=True``; this method streams the
    response and concatenates the delta chunks into one string.

    Args:
        messages: Conversation history to send.
        max_output_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; omitted from the request when None.
        num_retries: Retry count passed through to LiteLLM.
        drop_params: Whether LiteLLM drops provider-unsupported params.
        **extra: Additional LiteLLM keyword arguments, forwarded verbatim.

    Returns:
        The accumulated, stripped text content.

    Raises:
        RuntimeError: If the request fails or the stream yields no text.
    """
    try:
        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_output_tokens,
            "num_retries": num_retries,
            "drop_params": drop_params,
            "api_key": self.config.api_key or None,
            "api_base": self.config.api_base or None,
            "stream": True,
            # BUG FIX: **extra was previously accepted but never forwarded,
            # silently dropping caller options the non-streaming path honors.
            **extra,
        }
        if temperature is not None:
            kwargs["temperature"] = temperature

        response = litellm.completion(**kwargs)

        # Accumulate content from streaming delta chunks; chunks without
        # choices or without delta content are skipped.
        content_parts: list[str] = []
        for chunk in response:
            choices = getattr(chunk, "choices", None)
            if choices:
                piece = getattr(choices[0].delta, "content", None)
                if piece:
                    content_parts.append(piece)

        text = "".join(content_parts).strip()
    except Exception as exc:  # pragma: no cover - provider SDK exception types vary by backend/version.
        raise RuntimeError(f"LLM request failed ({self.provider}/{self.model}): {exc}") from exc

    # Raised outside the try block so it is not re-caught and double-wrapped
    # as an "LLM request failed" error.
    if not text:
        raise RuntimeError("LLM response contained no text content.")
    return text

def close(self) -> None:
    """Do nothing; each litellm.completion() call is self-contained, so
    there is no connection or handle to release."""
    # Intentionally empty — kept for interface symmetry with real clients.
Expand Down