diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 57cebc4b..5ed6c0e5 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -322,6 +322,7 @@ private async Task CleanupConnectionAsync(List? errors) /// A that can be used to cancel the operation. /// A task that resolves to provide the . /// Thrown when the client is not connected and AutoStart is disabled, or when a session with the same ID already exists. + /// Thrown when contains an invalid model name. /// /// Sessions maintain conversation state, handle events, and manage tool execution. /// If the client is not connected and is enabled (default), @@ -344,6 +345,19 @@ public async Task CreateSessionAsync(SessionConfig? config = nul { var connection = await EnsureConnectedAsync(cancellationToken); + if (!string.IsNullOrEmpty(config?.Model)) + { + // ListModelsAsync caches results after the first call, so this validation has minimal overhead + var availableModels = await ListModelsAsync(cancellationToken).ConfigureAwait(false); + + if (!availableModels.Any(m => string.Equals(m.Id, config.Model, StringComparison.OrdinalIgnoreCase))) + { + throw new ArgumentException( + $"Invalid model '{config.Model}'. Available models: {string.Join(", ", availableModels.Select(m => m.Id))}", + nameof(config)); + } + } + var hasHooks = config?.Hooks != null && ( config.Hooks.OnPreToolUse != null || config.Hooks.OnPostToolUse != null || diff --git a/dotnet/test/SessionTests.cs b/dotnet/test/SessionTests.cs index 13b23522..1e716111 100644 --- a/dotnet/test/SessionTests.cs +++ b/dotnet/test/SessionTests.cs @@ -15,7 +15,7 @@ public class SessionTests(E2ETestFixture fixture, ITestOutputHelper output) : E2 [Fact] public async Task ShouldCreateAndDestroySessions() { - var session = await Client.CreateSessionAsync(new SessionConfig { Model = "fake-test-model" }); + var session = await Client.CreateSessionAsync(new SessionConfig { Model = "claude-sonnet-4.5" }); Assert.Matches(@"^[a-f0-9-]+$", session.SessionId); @@ -395,4 +395,19 @@ public async Task Should_Create_Session_With_Custom_Config_Dir() Assert.NotNull(assistantMessage); Assert.Contains("2", assistantMessage!.Data.Content); } + + [Fact] + public async Task CreateSessionAsync_WithInvalidModel_ThrowsArgumentException() + { + var exception = await Assert.ThrowsAsync(async () => + { + await Client.CreateSessionAsync(new SessionConfig + { + Model = "INVALID_MODEL_THAT_DOES_NOT_EXIST" + }); + }); + + Assert.Contains("Invalid model", exception.Message); + Assert.Contains("INVALID_MODEL_THAT_DOES_NOT_EXIST", exception.Message); + } } diff --git a/go/client.go b/go/client.go index d45d3447..3e055f63 100644 --- a/go/client.go +++ b/go/client.go @@ -470,6 +470,30 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses params := make(map[string]any) if config != nil { if config.Model != "" { + // Validate model if specified + // ListModels caches results after the first call, so this validation has minimal overhead + availableModels, err := c.ListModels(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list models: %w", err) + } + + modelLower := strings.ToLower(config.Model) + modelFound := false + for _, model := range availableModels { + if strings.ToLower(model.ID) == modelLower { + modelFound = true + break + } + } + + if !modelFound { + validIDs := make([]string, len(availableModels)) + for i, model := range availableModels { + validIDs[i] = model.ID + } + return nil, fmt.Errorf("invalid model '%s'. Available models: %s", config.Model, strings.Join(validIDs, ", ")) + } + params["model"] = config.Model } if config.SessionID != "" { diff --git a/go/client_test.go b/go/client_test.go index 185bb4cb..d4047f4b 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -5,9 +5,15 @@ import ( "path/filepath" "reflect" "regexp" + "strings" "testing" ) +// containsIgnoreCase checks if a string contains a substring (case-insensitive) +func containsIgnoreCase(s, substr string) bool { + return strings.Contains(strings.ToLower(s), strings.ToLower(substr)) +} + // This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.go instead func TestClient_HandleToolCallRequest(t *testing.T) { @@ -48,6 +54,38 @@ func TestClient_HandleToolCallRequest(t *testing.T) { }) } +func TestClient_CreateSession_WithInvalidModel(t *testing.T) { + cliPath := findCLIPathForTest() + if cliPath == "" { + t.Skip("CLI not found") + } + + client := NewClient(&ClientOptions{CLIPath: cliPath}) + t.Cleanup(func() { client.ForceStop() }) + + err := client.Start(t.Context()) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + _, err = client.CreateSession(t.Context(), &SessionConfig{ + Model: "INVALID_MODEL_THAT_DOES_NOT_EXIST", + }) + + if err == nil { + t.Fatal("Expected error when creating session with invalid model, got nil") + } + + errorMsg := err.Error() + if !containsIgnoreCase(errorMsg, "invalid model") { + t.Errorf("Expected error message to contain 'invalid model', got: %s", errorMsg) + } + + if !containsIgnoreCase(errorMsg, "INVALID_MODEL_THAT_DOES_NOT_EXIST") { + t.Errorf("Expected error message to contain 'INVALID_MODEL_THAT_DOES_NOT_EXIST', got: %s", errorMsg) + } +} + func TestClient_URLParsing(t *testing.T) { t.Run("should parse port-only URL format", func(t *testing.T) { client := NewClient(&ClientOptions{ diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 62183286..730703ce 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -18,7 +18,7 @@ func TestSession(t *testing.T) { t.Run("should create and destroy sessions", func(t *testing.T) { ctx.ConfigureForTest(t) - session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{Model: "fake-test-model"}) + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{Model: "claude-sonnet-4.5"}) if err != nil { t.Fatalf("Failed to create session: %v", err) } @@ -41,8 +41,8 @@ func TestSession(t *testing.T) { t.Errorf("Expected session.start sessionId to match") } - if messages[0].Data.SelectedModel == nil || *messages[0].Data.SelectedModel != "fake-test-model" { - t.Errorf("Expected selectedModel to be 'fake-test-model', got %v", messages[0].Data.SelectedModel) + if messages[0].Data.SelectedModel == nil || *messages[0].Data.SelectedModel != "claude-sonnet-4.5" { + t.Errorf("Expected selectedModel to be 'claude-sonnet-4.5', got %v", messages[0].Data.SelectedModel) } if err := session.Destroy(); err != nil { diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 20dc17f8..20fcc690 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -467,6 +467,20 @@ export class CopilotClient { } } + // Validate model if specified + if (config.model) { + // listModels() caches results, so this has minimal overhead + const availableModels = await this.listModels(); + const modelLower = config.model.toLowerCase(); + + if (!availableModels.some((m) => m.id.toLowerCase() === modelLower)) { + const validIds = availableModels.map((m) => m.id); + throw new Error( + `Invalid model '${config.model}'. Available models: ${validIds.join(", ")}` + ); + } + } + const response = await this.connection!.sendRequest("session.create", { model: config.model, sessionId: config.sessionId, diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 364ff382..674de28e 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -28,6 +28,20 @@ describe("CopilotClient", () => { }); }); + it("throws error when creating session with invalid model", async () => { + const client = new CopilotClient({ cliPath: CLI_PATH }); + await client.start(); + onTestFinished(() => client.forceStop()); + + const error = await client + .createSession({ model: "INVALID_MODEL_THAT_DOES_NOT_EXIST" }) + .catch((e) => e); + + expect(error).toBeInstanceOf(Error); + expect(error.message).toContain("Invalid model"); + expect(error.message).toContain("INVALID_MODEL_THAT_DOES_NOT_EXIST"); + }); + describe("URL parsing", () => { it("should parse port-only URL format", () => { const client = new CopilotClient({ diff --git a/nodejs/test/e2e/session.test.ts b/nodejs/test/e2e/session.test.ts index b3fba475..28cb0122 100644 --- a/nodejs/test/e2e/session.test.ts +++ b/nodejs/test/e2e/session.test.ts @@ -8,13 +8,13 @@ describe("Sessions", async () => { const { copilotClient: client, openAiEndpoint, homeDir, env } = await createSdkTestContext(); it("should create and destroy sessions", async () => { - const session = await client.createSession({ model: "fake-test-model" }); + const session = await client.createSession({ model: "claude-sonnet-4.5" }); expect(session.sessionId).toMatch(/^[a-f0-9-]+$/); expect(await session.getMessages()).toMatchObject([ { type: "session.start", - data: { sessionId: session.sessionId, selectedModel: "fake-test-model" }, + data: { sessionId: session.sessionId, selectedModel: "claude-sonnet-4.5" }, }, ]); diff --git a/python/copilot/client.py b/python/copilot/client.py index da1ba8c0..6313bba9 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -389,6 +389,19 @@ async def create_session(self, config: Optional[SessionConfig] = None) -> Copilo cfg = config or {} + # Validate model if specified + model = cfg.get("model") + if model: + # list_models() caches results, so this has minimal overhead + available_models = await self.list_models() + model_lower = model.lower() + + if not any(m.id.lower() == model_lower for m in available_models): + valid_ids = [m.id for m in available_models] + raise ValueError( + f"Invalid model '{model}'. Available models: {', '.join(valid_ids)}" + ) + tool_defs = [] tools = cfg.get("tools") if tools: diff --git a/python/e2e/test_session.py b/python/e2e/test_session.py index f2e545ed..cb2fdcd0 100644 --- a/python/e2e/test_session.py +++ b/python/e2e/test_session.py @@ -14,14 +14,14 @@ class TestSessions: async def test_should_create_and_destroy_sessions(self, ctx: E2ETestContext): - session = await ctx.client.create_session({"model": "fake-test-model"}) + session = await ctx.client.create_session({"model": "claude-sonnet-4.5"}) assert session.session_id messages = await session.get_messages() assert len(messages) > 0 assert messages[0].type.value == "session.start" assert messages[0].data.session_id == session.session_id - assert messages[0].data.selected_model == "fake-test-model" + assert messages[0].data.selected_model == "claude-sonnet-4.5" await session.destroy() @@ -456,6 +456,13 @@ def on_event(event): assistant_message = await get_final_assistant_message(session) assert "300" in assistant_message.data.content + async def test_create_session_with_invalid_model_raises_value_error(self, ctx: E2ETestContext): + with pytest.raises(ValueError) as exc_info: + await ctx.client.create_session({"model": "INVALID_MODEL_THAT_DOES_NOT_EXIST"}) + + assert "Invalid model" in str(exc_info.value) + assert "INVALID_MODEL_THAT_DOES_NOT_EXIST" in str(exc_info.value) + async def test_should_create_session_with_custom_config_dir(self, ctx: E2ETestContext): import os