From 99a1e21fe835e650bf7886cf11bc1d6df865df01 Mon Sep 17 00:00:00 2001 From: actae0n <19864268+xpcmdshell@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:21:36 -0700 Subject: [PATCH] Use server-managed container session IDs --- .../execution/container/client.py | 51 ++++----- .../execution/container/executor.py | 41 ++++++- .../execution/container/server.py | 40 +++---- src/py_code_mode/integrations/autogen.py | 21 ++-- tests/container/test_client.py | 102 +++++++++++++++++- tests/container/test_container_auth.py | 11 ++ tests/container/test_executor.py | 43 ++++++++ tests/container/test_server.py | 40 ++++--- 8 files changed, 275 insertions(+), 74 deletions(-) diff --git a/src/py_code_mode/execution/container/client.py b/src/py_code_mode/execution/container/client.py index a538e3d..2092e5a 100644 --- a/src/py_code_mode/execution/container/client.py +++ b/src/py_code_mode/execution/container/client.py @@ -3,6 +3,7 @@ This client connects to a running session server and provides a Python API for code execution. Each client maintains its own isolated session with separate Python namespace and artifacts. +The server allocates the session on the first execute call. Usage: async with SessionClient("http://localhost:8080") as client: @@ -13,7 +14,6 @@ from __future__ import annotations -import uuid from dataclasses import dataclass from typing import Any @@ -64,7 +64,7 @@ class ResetResult: """Reset result.""" status: str - session_id: str + session_id: str | None class SessionClient: @@ -74,15 +74,15 @@ class SessionClient: - Separate Python namespace (variables don't leak between sessions) - Separate artifact directory - Use the same client instance across requests to maintain state, - or create a new client for a fresh isolated session. + Use the same client instance across requests to maintain state. + The server issues the session ID on first execution and the client + reuses it for later session-scoped requests. """ def __init__( self, base_url: str = "http://localhost:8080", timeout: float = 30.0, - session_id: str | None = None, auth_token: str | None = None, ) -> None: """Initialize session client. @@ -90,8 +90,6 @@ def __init__( Args: base_url: Base URL of session server. timeout: Default timeout for HTTP requests. - session_id: Optional session ID. If not provided, a new - unique session is created on first request. auth_token: Optional Bearer token for API authentication. If provided, sent as Authorization header. """ @@ -101,7 +99,7 @@ def __init__( # Strip trailing slash self.base_url = base_url.rstrip("/") self.timeout = timeout - self.session_id = session_id or str(uuid.uuid4()) + self.session_id: str | None = None self.auth_token = auth_token self._client: httpx.AsyncClient | None = None @@ -113,7 +111,9 @@ async def _get_client(self) -> httpx.AsyncClient: def _headers(self) -> dict[str, str]: """Get headers with session ID and optional auth token.""" - headers = {"X-Session-ID": self.session_id} + headers: dict[str, str] = {} + if self.session_id is not None: + headers["X-Session-ID"] = self.session_id if self.auth_token: headers["Authorization"] = f"Bearer {self.auth_token}" return headers @@ -145,16 +145,15 @@ async def execute( response.raise_for_status() data = response.json() - # Update session_id if server assigned one - if "session_id" in data: - self.session_id = data["session_id"] + session_id = data["session_id"] + self.session_id = session_id return ExecuteResult( value=data["value"], stdout=data["stdout"], error=data["error"], execution_time_ms=data["execution_time_ms"], - session_id=data.get("session_id", self.session_id), + session_id=session_id, ) async def health(self) -> HealthResult: @@ -198,6 +197,9 @@ async def reset(self) -> ResetResult: Returns: ResetResult confirming reset. """ + if self.session_id is None: + return ResetResult(status="reset", session_id=None) + client = await self._get_client() response = await client.post( f"{self.base_url}/reset", @@ -205,10 +207,11 @@ async def reset(self) -> ResetResult: ) response.raise_for_status() data = response.json() + self.session_id = None return ResetResult( status=data["status"], - session_id=data.get("session_id", self.session_id), + session_id=data.get("session_id"), ) async def install_deps(self, packages: list[str]) -> dict[str, Any]: @@ -230,10 +233,8 @@ async def install_deps(self, packages: list[str]) -> dict[str, Any]: headers=self._headers(), timeout=300.0, # Long timeout for package installation ) - data = response.json() - if response.status_code != 200: - raise RuntimeError(data.get("error", "Install failed")) - return data + response.raise_for_status() + return response.json() async def uninstall_deps(self, packages: list[str]) -> dict[str, Any]: """Uninstall packages from the container. @@ -254,10 +255,8 @@ async def uninstall_deps(self, packages: list[str]) -> dict[str, Any]: headers=self._headers(), timeout=120.0, # Reasonable timeout for uninstall ) - data = response.json() - if response.status_code != 200: - raise RuntimeError(data.get("error", "Uninstall failed")) - return data + response.raise_for_status() + return response.json() # ========================================================================== # Tools API Methods @@ -481,9 +480,7 @@ async def api_add_dep(self, package: str) -> dict[str, Any]: headers=self._headers(), timeout=300.0, # Long timeout for package installation ) - if response.status_code != 200: - data = response.json() - raise RuntimeError(data.get("detail", "Add dep failed")) + response.raise_for_status() return response.json() async def api_remove_dep(self, package: str) -> dict[str, Any]: @@ -504,9 +501,7 @@ async def api_remove_dep(self, package: str) -> dict[str, Any]: json={"package": package}, headers=self._headers(), ) - if response.status_code != 200: - data = response.json() - raise RuntimeError(data.get("detail", "Remove dep failed")) + response.raise_for_status() return response.json() async def api_sync_deps(self) -> dict[str, Any]: diff --git a/src/py_code_mode/execution/container/executor.py b/src/py_code_mode/execution/container/executor.py index 154f0b2..ec16903 100644 --- a/src/py_code_mode/execution/container/executor.py +++ b/src/py_code_mode/execution/container/executor.py @@ -196,6 +196,27 @@ async def __aexit__( """Stop container and cleanup.""" await self.close() + def _translate_client_http_error(self, error: Exception) -> RuntimeError: + """Translate low-level client HTTP errors into executor-style RuntimeErrors.""" + if HTTPX_AVAILABLE and isinstance(error, httpx.HTTPStatusError): + detail: str | None = None + try: + data = error.response.json() + except ValueError: + data = None + + if isinstance(data, dict): + value = data.get("detail") or data.get("error") + if isinstance(value, str) and value: + detail = value + + if detail is None: + detail = f"HTTP {error.response.status_code}: {error.response.text}" + + return RuntimeError(detail) + + return RuntimeError(str(error)) + def _create_docker_client(self) -> Any: """Create Docker client, trying multiple socket locations if needed.""" logger = logging.getLogger(__name__) @@ -541,7 +562,10 @@ async def install_deps(self, packages: list[str]) -> dict[str, Any]: if self._client is None: raise RuntimeError("Container not started") - return await self._client.install_deps(packages) + try: + return await self._client.install_deps(packages) + except Exception as e: + raise self._translate_client_http_error(e) from e async def uninstall_deps(self, packages: list[str]) -> dict[str, Any]: """Uninstall packages from the container environment. @@ -568,7 +592,10 @@ async def uninstall_deps(self, packages: list[str]) -> dict[str, Any]: if self._client is None: raise RuntimeError("Container not started") - return await self._client.uninstall_deps(packages) + try: + return await self._client.uninstall_deps(packages) + except Exception as e: + raise self._translate_client_http_error(e) from e # ========================================================================== # Tools API Methods @@ -751,7 +778,10 @@ async def add_dep(self, package: str) -> dict[str, Any]: if self._client is None: raise RuntimeError("Container not started") - return await self._client.api_add_dep(package) + try: + return await self._client.api_add_dep(package) + except Exception as e: + raise self._translate_client_http_error(e) from e async def remove_dep(self, package: str) -> dict[str, Any]: """Remove a package from configuration and uninstall it. @@ -770,7 +800,10 @@ async def remove_dep(self, package: str) -> dict[str, Any]: if self._client is None: raise RuntimeError("Container not started") - return await self._client.api_remove_dep(package) + try: + return await self._client.api_remove_dep(package) + except Exception as e: + raise self._translate_client_http_error(e) from e async def sync_deps(self) -> dict[str, Any]: """Install all configured packages. diff --git a/src/py_code_mode/execution/container/server.py b/src/py_code_mode/execution/container/server.py index e6e9718..435b122 100644 --- a/src/py_code_mode/execution/container/server.py +++ b/src/py_code_mode/execution/container/server.py @@ -313,24 +313,22 @@ def create_session(session_id: str) -> Session: ) -def get_or_create_session(session_id: str | None) -> Session: - """Get existing session or create a new one.""" - # Generate session_id if not provided - if session_id is None: - session_id = str(uuid.uuid4()) - - # Return existing session - if session_id in _state.sessions: - session = _state.sessions[session_id] - session.last_used = time.time() - return session - - # Create new session +def create_new_session() -> Session: + """Create a new isolated session with a server-issued ID.""" + session_id = str(uuid.uuid4()) session = create_session(session_id) _state.sessions[session_id] = session return session +def get_existing_session(session_id: str) -> Session | None: + """Get an existing session if present.""" + session = _state.sessions.get(session_id) + if session is not None: + session.last_used = time.time() + return session + + def cleanup_expired_sessions() -> int: """Remove sessions that haven't been used recently.""" now = time.time() @@ -603,8 +601,12 @@ async def execute( # Cleanup expired sessions periodically cleanup_expired_sessions() - # Get or create session - session = get_or_create_session(x_session_id) + if x_session_id is None: + session = create_new_session() + else: + session = get_existing_session(x_session_id) + if session is None: + raise HTTPException(status_code=400, detail="Invalid session ID") start = time.time() timeout = body.timeout or _state.config.default_timeout @@ -667,12 +669,14 @@ async def reset( x_session_id: str | None = Header(None, alias="X-Session-ID"), ) -> ResetResponseModel: """Reset a session (clears namespace, keeps artifacts).""" - if x_session_id and x_session_id in _state.sessions: - del _state.sessions[x_session_id] + if x_session_id is None or x_session_id not in _state.sessions: + raise HTTPException(status_code=400, detail="Invalid session ID") + + del _state.sessions[x_session_id] return ResetResponseModel( status="reset", - session_id=x_session_id or "", + session_id=x_session_id, ) # NOTE: /sessions endpoint removed - session enumeration is an information disclosure risk diff --git a/src/py_code_mode/integrations/autogen.py b/src/py_code_mode/integrations/autogen.py index 06e0390..a06d488 100644 --- a/src/py_code_mode/integrations/autogen.py +++ b/src/py_code_mode/integrations/autogen.py @@ -34,7 +34,6 @@ def create_run_code_tool( executor: InProcessExecutor | None = None, session_url: str | None = None, timeout: float = 30.0, - session_id: str | None = None, ) -> Callable[[str], Any]: """Create a run_code tool for AutoGen agents. @@ -45,9 +44,6 @@ def create_run_code_tool( executor: CodeExecutor instance for in-process execution session_url: URL of py-code-mode session server for remote execution timeout: Execution timeout in seconds - session_id: Optional session ID for remote execution. If not provided, - a unique session is created. Use this to isolate different - agents using the same session server. Returns: A function that can be registered as an AutoGen tool @@ -63,7 +59,7 @@ def create_run_code_tool( if executor is not None: return _create_local_tool(executor, timeout) else: - return _create_remote_tool(session_url, timeout, session_id) # type: ignore + return _create_remote_tool(session_url, timeout) # type: ignore def _create_local_tool( @@ -109,17 +105,14 @@ async def run_code(code: str) -> str: def _create_remote_tool( session_url: str, timeout: float, - session_id: str | None = None, ) -> Callable[[str], Any]: """Create tool using remote session server.""" # Lazy import to avoid requiring httpx for local-only usage - import uuid - import httpx - # Each tool instance gets its own session - _session_id = session_id or str(uuid.uuid4()) + # Session is created lazily by the server on first execute. + _session_id: str | None = None def run_code(code: str) -> str: """Execute Python code with access to tools.*, workflows.*, and artifacts.*. @@ -137,15 +130,21 @@ def run_code(code: str) -> str: Returns: String representation of the result or error message """ + nonlocal _session_id try: + headers = {} + if _session_id is not None: + headers["X-Session-ID"] = _session_id + with httpx.Client(timeout=timeout + 5) as client: response = client.post( f"{session_url.rstrip('/')}/execute", json={"code": code, "timeout": timeout}, - headers={"X-Session-ID": _session_id}, + headers=headers, ) response.raise_for_status() result = response.json() + _session_id = result["session_id"] if result.get("error"): return f"Error: {result['error']}" diff --git a/tests/container/test_client.py b/tests/container/test_client.py index 4759f39..7dd7826 100644 --- a/tests/container/test_client.py +++ b/tests/container/test_client.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, MagicMock +import httpx import pytest from py_code_mode.execution.container.client import SessionClient @@ -33,6 +34,7 @@ def test_strips_trailing_slash(self) -> None: """Strips trailing slash from base URL.""" client = SessionClient(base_url="http://localhost:8080/") assert client.base_url == "http://localhost:8080" + assert client.session_id is None class TestSessionClientExecute: @@ -49,6 +51,7 @@ async def test_execute_simple_code(self) -> None: "stdout": "", "error": None, "execution_time_ms": 5.0, + "session_id": "server-session-1", } ) @@ -63,10 +66,13 @@ async def test_execute_simple_code(self) -> None: call_args = mock_http_client.post.call_args assert call_args[0][0] == "http://localhost:8080/execute" assert call_args[1]["json"]["code"] == "21 * 2" + assert call_args[1]["headers"] == {} assert result.value == 42 assert result.error is None assert result.stdout == "" + assert result.session_id == "server-session-1" + assert client.session_id == "server-session-1" @pytest.mark.asyncio async def test_execute_with_timeout(self) -> None: @@ -79,6 +85,7 @@ async def test_execute_with_timeout(self) -> None: "stdout": "", "error": None, "execution_time_ms": 100.0, + "session_id": "server-session-1", } ) @@ -102,6 +109,7 @@ async def test_execute_with_error(self) -> None: "stdout": "", "error": "ZeroDivisionError: division by zero", "execution_time_ms": 1.0, + "session_id": "server-session-1", } ) @@ -113,6 +121,41 @@ async def test_execute_with_error(self) -> None: assert result.error is not None assert "ZeroDivisionError" in result.error + assert client.session_id == "server-session-1" + + @pytest.mark.asyncio + async def test_execute_reuses_server_assigned_session_id(self) -> None: + """Second execute sends the server-issued session ID.""" + client = SessionClient() + + first_response = make_mock_response( + { + "value": 42, + "stdout": "", + "error": None, + "execution_time_ms": 5.0, + "session_id": "server-session-1", + } + ) + second_response = make_mock_response( + { + "value": 84, + "stdout": "", + "error": None, + "execution_time_ms": 5.0, + "session_id": "server-session-1", + } + ) + + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(side_effect=[first_response, second_response]) + client._client = mock_http_client + + await client.execute("x = 42") + await client.execute("x * 2") + + second_call = mock_http_client.post.call_args_list[1] + assert second_call[1]["headers"]["X-Session-ID"] == "server-session-1" class TestSessionClientHealth: @@ -166,6 +209,8 @@ async def test_info_returns_tools_and_workflows(self) -> None: assert info.tools[0]["name"] == "cli.nmap" assert len(info.workflows) == 1 assert info.workflows[0]["name"] == "scan" + call_args = mock_http_client.get.call_args + assert call_args[1]["headers"] == {} class TestSessionClientReset: @@ -175,6 +220,7 @@ class TestSessionClientReset: async def test_reset_clears_state(self) -> None: """Reset returns status.""" client = SessionClient() + client.session_id = "server-session-1" mock_response = make_mock_response( { @@ -190,7 +236,61 @@ async def test_reset_clears_state(self) -> None: result = await client.reset() assert result.status == "reset" - assert result.session_id == client.session_id + assert result.session_id == "server-session-1" + assert client.session_id is None + + @pytest.mark.asyncio + async def test_reset_without_session_is_local_noop(self) -> None: + """Reset without a server-issued session does not make a request.""" + client = SessionClient() + mock_http_client = AsyncMock() + client._client = mock_http_client + + result = await client.reset() + + assert result.status == "reset" + assert result.session_id is None + mock_http_client.post.assert_not_called() + + +class TestSessionClientDeps: + """Tests for dependency-related HTTP behavior.""" + + @pytest.mark.asyncio + async def test_install_deps_raises_http_status_error(self) -> None: + """install_deps propagates HTTP auth failures via raise_for_status().""" + client = SessionClient(auth_token="wrong-token") + request = httpx.Request("POST", "http://localhost:8080/install_deps") + response = httpx.Response(401, request=request, json={"detail": "Invalid token"}) + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "401 Unauthorized", request=request, response=response + ) + + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(return_value=mock_response) + client._client = mock_http_client + + with pytest.raises(httpx.HTTPStatusError): + await client.install_deps(["requests"]) + + @pytest.mark.asyncio + async def test_api_add_dep_raises_http_status_error(self) -> None: + """api_add_dep propagates HTTP auth failures via raise_for_status().""" + client = SessionClient(auth_token="wrong-token") + request = httpx.Request("POST", "http://localhost:8080/api/deps/add") + response = httpx.Response(401, request=request, json={"detail": "Invalid token"}) + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "401 Unauthorized", request=request, response=response + ) + + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(return_value=mock_response) + client._client = mock_http_client + + with pytest.raises(httpx.HTTPStatusError): + await client.api_add_dep("requests") class TestSessionClientContextManager: diff --git a/tests/container/test_container_auth.py b/tests/container/test_container_auth.py index 5fb1390..09269a8 100644 --- a/tests/container/test_container_auth.py +++ b/tests/container/test_container_auth.py @@ -250,6 +250,16 @@ def test_protected_endpoint_with_valid_token_succeeds( """Protected endpoints succeed with valid token.""" headers = {"Authorization": f"Bearer {auth_token}"} + if endpoint == "/reset": + create_response = auth_enabled_client.post( + "/execute", + json={"code": "x = 42"}, + headers=headers, + ) + assert create_response.status_code == 200 + session_id = create_response.json()["session_id"] + headers["X-Session-ID"] = session_id + if method == "post": response = auth_enabled_client.post(endpoint, json=body, headers=headers) else: @@ -649,6 +659,7 @@ async def test_session_client_sends_authorization_header(self) -> None: headers = call_args[1].get("headers", {}) assert "Authorization" in headers assert headers["Authorization"] == "Bearer client-auth-token" + assert "X-Session-ID" not in headers await client.close() diff --git a/tests/container/test_executor.py b/tests/container/test_executor.py index 9b0781d..887992e 100644 --- a/tests/container/test_executor.py +++ b/tests/container/test_executor.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest from py_code_mode.execution.container.config import ContainerConfig @@ -191,6 +192,48 @@ async def test_run_with_error(self, config) -> None: assert not result.is_ok assert "NameError" in result.error + @pytest.mark.asyncio + async def test_install_deps_translates_http_errors_to_runtime_error(self, config) -> None: + """install_deps preserves the executor RuntimeError contract.""" + executor = ContainerExecutor(config) + + mock_container = self._make_mock_container() + mock_docker = MagicMock() + mock_docker.containers.run.return_value = mock_container + + request = httpx.Request("POST", "http://localhost:8080/install_deps") + response = httpx.Response(401, request=request, json={"detail": "Invalid token"}) + error = httpx.HTTPStatusError("401 Unauthorized", request=request, response=response) + + with patch("docker.from_env", return_value=mock_docker): + with patch.object(executor, "_wait_for_healthy", new_callable=AsyncMock): + async with executor: + executor._client.install_deps = AsyncMock(side_effect=error) + + with pytest.raises(RuntimeError, match="Invalid token"): + await executor.install_deps(["requests"]) + + @pytest.mark.asyncio + async def test_add_dep_translates_http_errors_to_runtime_error(self, config) -> None: + """add_dep preserves the executor RuntimeError contract.""" + executor = ContainerExecutor(config) + + mock_container = self._make_mock_container() + mock_docker = MagicMock() + mock_docker.containers.run.return_value = mock_container + + request = httpx.Request("POST", "http://localhost:8080/api/deps/add") + response = httpx.Response(401, request=request, json={"detail": "Invalid token"}) + error = httpx.HTTPStatusError("401 Unauthorized", request=request, response=response) + + with patch("docker.from_env", return_value=mock_docker): + with patch.object(executor, "_wait_for_healthy", new_callable=AsyncMock): + async with executor: + executor._client.api_add_dep = AsyncMock(side_effect=error) + + with pytest.raises(RuntimeError, match="Invalid token"): + await executor.add_dep("requests") + def _make_mock_container() -> MagicMock: """Create a mock Docker container with proper port bindings.""" diff --git a/tests/container/test_server.py b/tests/container/test_server.py index 01d6c2b..1c7c929 100644 --- a/tests/container/test_server.py +++ b/tests/container/test_server.py @@ -97,6 +97,7 @@ def test_execute_simple_expression(self, client) -> None: data = response.json() assert data["value"] == 2 assert data["error"] is None + assert data["session_id"] def test_execute_with_stdout(self, client) -> None: """Captures stdout from print statements.""" @@ -118,12 +119,10 @@ def test_execute_with_error(self, client) -> None: def test_execute_state_persists(self, client) -> None: """Variables persist across executions within same session.""" - session_id = "test-persist-session" + create_response = client.post("/execute", json={"code": "x = 42"}) + session_id = create_response.json()["session_id"] headers = {"X-Session-ID": session_id} - # Set variable - client.post("/execute", json={"code": "x = 42"}, headers=headers) - # Access variable (same session) response = client.post("/execute", json={"code": "x * 2"}, headers=headers) @@ -134,21 +133,38 @@ def test_execute_state_persists(self, client) -> None: def test_reset_clears_state(self, client) -> None: """Reset clears session state.""" - session_id = "test-reset-session" + create_response = client.post("/execute", json={"code": "x = 42"}) + session_id = create_response.json()["session_id"] headers = {"X-Session-ID": session_id} - # Set variable - client.post("/execute", json={"code": "x = 42"}, headers=headers) - # Reset this session response = client.post("/reset", headers=headers) assert response.status_code == 200 - # Variable should be gone (new session created with same ID) + # Reusing a reset session ID is invalid. response = client.post("/execute", json={"code": "x"}, headers=headers) - data = response.json() - assert data["error"] is not None - assert "NameError" in data["error"] + assert response.status_code == 400 + + def test_execute_with_unknown_session_id_returns_400(self, client) -> None: + """Unknown session IDs are rejected instead of creating sessions.""" + response = client.post( + "/execute", + json={"code": "1 + 1"}, + headers={"X-Session-ID": "missing-session"}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Invalid session ID" + + def test_reset_requires_known_session_id(self, client) -> None: + """Reset rejects missing or unknown session IDs.""" + missing_response = client.post("/reset") + assert missing_response.status_code == 400 + assert missing_response.json()["detail"] == "Invalid session ID" + + unknown_response = client.post("/reset", headers={"X-Session-ID": "missing-session"}) + assert unknown_response.status_code == 400 + assert unknown_response.json()["detail"] == "Invalid session ID" def test_execute_returns_execution_time(self, client) -> None: """Execute response includes execution time."""