Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions python/copilot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import subprocess
import threading
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, List, Optional, cast
from typing import Any, Optional, cast

from .generated.session_events import session_event_from_dict
from .jsonrpc import JsonRpcClient
Expand All @@ -28,9 +28,11 @@
from .types import (
ConnectionState,
CopilotClientOptions,
CustomAgentConfig,
GetAuthStatusResponse,
GetStatusResponse,
ModelInfo,
ProviderConfig,
ResumeSessionConfig,
SessionConfig,
SessionMetadata,
Expand Down Expand Up @@ -132,7 +134,7 @@ def __init__(self, options: Optional[CopilotClientOptions] = None):
self._process: Optional[subprocess.Popen] = None
self._client: Optional[JsonRpcClient] = None
self._state: ConnectionState = "disconnected"
self._sessions: Dict[str, CopilotSession] = {}
self._sessions: dict[str, CopilotSession] = {}
self._sessions_lock = threading.Lock()

def _parse_cli_url(self, url: str) -> tuple[str, int]:
Expand Down Expand Up @@ -218,7 +220,7 @@ async def start(self) -> None:
self._state = "error"
raise

async def stop(self) -> List[Dict[str, str]]:
async def stop(self) -> list[dict[str, str]]:
"""
Stop the CLI server and close all active sessions.

Expand All @@ -237,7 +239,7 @@ async def stop(self) -> List[Dict[str, str]]:
... for error in errors:
... print(f"Cleanup error: {error['message']}")
"""
errors: List[Dict[str, str]] = []
errors: list[dict[str, str]] = []

# Atomically take ownership of all sessions and clear the dict
# so no other thread can access them
Expand Down Expand Up @@ -357,7 +359,7 @@ async def create_session(self, config: Optional[SessionConfig] = None) -> Copilo
definition["parameters"] = tool.parameters
tool_defs.append(definition)

payload: Dict[str, Any] = {}
payload: dict[str, Any] = {}
if cfg.get("model"):
payload["model"] = cfg["model"]
if cfg.get("session_id"):
Expand Down Expand Up @@ -482,7 +484,7 @@ async def resume_session(
definition["parameters"] = tool.parameters
tool_defs.append(definition)

payload: Dict[str, Any] = {"sessionId": session_id}
payload: dict[str, Any] = {"sessionId": session_id}
if tool_defs:
payload["tools"] = tool_defs

Expand Down Expand Up @@ -612,7 +614,7 @@ async def get_auth_status(self) -> "GetAuthStatusResponse":

return await self._client.request("auth.getStatus", {})

async def list_models(self) -> List["ModelInfo"]:
async def list_models(self) -> list["ModelInfo"]:
"""
List available models with their metadata.

Expand All @@ -634,7 +636,7 @@ async def list_models(self) -> List["ModelInfo"]:
response = await self._client.request("models.list", {})
return response.get("models", [])

async def list_sessions(self) -> List["SessionMetadata"]:
async def list_sessions(self) -> list["SessionMetadata"]:
"""
List all available sessions known to the server.

Expand Down Expand Up @@ -710,7 +712,9 @@ async def _verify_protocol_version(self) -> None:
f"Please update your SDK or server to ensure compatibility."
)

def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str, Any]:
def _convert_provider_to_wire_format(
self, provider: ProviderConfig | dict[str, Any]
) -> dict[str, Any]:
"""
Convert provider config from snake_case to camelCase wire format.

Expand All @@ -720,7 +724,7 @@ def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str
Returns:
The provider configuration in camelCase wire format.
"""
wire_provider: Dict[str, Any] = {"type": provider.get("type")}
wire_provider: dict[str, Any] = {"type": provider.get("type")}
if "base_url" in provider:
wire_provider["baseUrl"] = provider["base_url"]
if "api_key" in provider:
Expand All @@ -731,14 +735,16 @@ def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str
wire_provider["bearerToken"] = provider["bearer_token"]
if "azure" in provider:
azure = provider["azure"]
wire_azure: Dict[str, Any] = {}
wire_azure: dict[str, Any] = {}
if "api_version" in azure:
wire_azure["apiVersion"] = azure["api_version"]
if wire_azure:
wire_provider["azure"] = wire_azure
return wire_provider

def _convert_custom_agent_to_wire_format(self, agent: Dict[str, Any]) -> Dict[str, Any]:
def _convert_custom_agent_to_wire_format(
self, agent: CustomAgentConfig | dict[str, Any]
) -> dict[str, Any]:
"""
Convert custom agent config from snake_case to camelCase wire format.

Expand All @@ -748,7 +754,7 @@ def _convert_custom_agent_to_wire_format(self, agent: Dict[str, Any]) -> Dict[st
Returns:
The custom agent configuration in camelCase wire format.
"""
wire_agent: Dict[str, Any] = {"name": agent.get("name"), "prompt": agent.get("prompt")}
wire_agent: dict[str, Any] = {"name": agent.get("name"), "prompt": agent.get("prompt")}
if "display_name" in agent:
wire_agent["displayName"] = agent["display_name"]
if "description" in agent:
Expand Down
7 changes: 4 additions & 3 deletions python/copilot/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import json
import threading
import uuid
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from collections.abc import Awaitable
from typing import Any, Callable, Optional, Union


class JsonRpcError(Exception):
Expand Down Expand Up @@ -41,9 +42,9 @@ def __init__(self, process):
process: subprocess.Popen with stdin=PIPE, stdout=PIPE
"""
self.process = process
self.pending_requests: Dict[str, asyncio.Future] = {}
self.pending_requests: dict[str, asyncio.Future] = {}
self.notification_handler: Optional[Callable[[str, dict], None]] = None
self.request_handlers: Dict[str, RequestHandler] = {}
self.request_handlers: dict[str, RequestHandler] = {}
self._running = False
self._read_thread: Optional[threading.Thread] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
Expand Down
10 changes: 5 additions & 5 deletions python/copilot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import asyncio
import inspect
import threading
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Optional

from .generated.session_events import SessionEvent, SessionEventType, session_event_from_dict
from .types import (
Expand Down Expand Up @@ -62,9 +62,9 @@ def __init__(self, session_id: str, client: Any):
"""
self.session_id = session_id
self._client = client
self._event_handlers: Set[Callable[[SessionEvent], None]] = set()
self._event_handlers: set[Callable[[SessionEvent], None]] = set()
self._event_handlers_lock = threading.Lock()
self._tool_handlers: Dict[str, ToolHandler] = {}
self._tool_handlers: dict[str, ToolHandler] = {}
self._tool_handlers_lock = threading.Lock()
self._permission_handler: Optional[PermissionHandler] = None
self._permission_handler_lock = threading.Lock()
Expand Down Expand Up @@ -220,7 +220,7 @@ def _dispatch_event(self, event: SessionEvent) -> None:
except Exception as e:
print(f"Error in session event handler: {e}")

def _register_tools(self, tools: Optional[List[Tool]]) -> None:
def _register_tools(self, tools: Optional[list[Tool]]) -> None:
"""
Register custom tool handlers for this session.

Expand Down Expand Up @@ -307,7 +307,7 @@ async def _handle_permission_request(self, request: dict) -> dict:
# Handler failed, deny permission
return {"kind": "denied-no-approval-rule-and-could-not-request-from-user"}

async def get_messages(self) -> List[SessionEvent]:
async def get_messages(self) -> list[SessionEvent]:
"""
Retrieve all events and messages from this session's history.

Expand Down
6 changes: 3 additions & 3 deletions python/copilot/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import inspect
import json
from typing import Any, Callable, Type, TypeVar, get_type_hints, overload
from typing import Any, Callable, TypeVar, get_type_hints, overload

from pydantic import BaseModel

Expand All @@ -33,7 +33,7 @@ def define_tool(
*,
description: str | None = None,
handler: Callable[[T, ToolInvocation], R],
params_type: Type[T],
params_type: type[T],
) -> Tool: ...


Expand All @@ -42,7 +42,7 @@ def define_tool(
*,
description: str | None = None,
handler: Callable[[Any, ToolInvocation], Any] | None = None,
params_type: Type[BaseModel] | None = None,
params_type: type[BaseModel] | None = None,
) -> Tool | Callable[[Callable[[Any, ToolInvocation], Any]], Tool]:
"""
Define a tool with automatic JSON schema generation from Pydantic models.
Expand Down
Loading
Loading