From 4dd00fa5dd8c6eef9a9926f1b484a837319d3b42 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 4 Feb 2026 16:41:40 +0100 Subject: [PATCH 01/15] use types to encode and decode jsonrpc queries --- go/client.go | 603 +++++++------------------------ go/client_test.go | 23 +- go/internal/jsonrpc2/jsonrpc2.go | 150 +++++--- go/session.go | 230 +++--------- go/types.go | 295 +++++++++++---- 5 files changed, 505 insertions(+), 796 deletions(-) diff --git a/go/client.go b/go/client.go index d45d3447..a6d3e950 100644 --- a/go/client.go +++ b/go/client.go @@ -396,36 +396,6 @@ func (c *Client) ForceStop() { } } -// buildProviderParams converts a ProviderConfig to a map for JSON-RPC params. -func buildProviderParams(p *ProviderConfig) map[string]any { - params := make(map[string]any) - if p.Type != "" { - params["type"] = p.Type - } - if p.WireApi != "" { - params["wireApi"] = p.WireApi - } - if p.BaseURL != "" { - params["baseUrl"] = p.BaseURL - } - if p.APIKey != "" { - params["apiKey"] = p.APIKey - } - if p.BearerToken != "" { - params["bearerToken"] = p.BearerToken - } - if p.Azure != nil { - azure := make(map[string]any) - if p.Azure.APIVersion != "" { - azure["apiVersion"] = p.Azure.APIVersion - } - if len(azure) > 0 { - params["azure"] = azure - } - } - return params -} - func (c *Client) ensureConnected() error { if c.client != nil { return nil @@ -467,166 +437,54 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses return nil, err } - params := make(map[string]any) + req := createSessionRequest{} if config != nil { - if config.Model != "" { - params["model"] = config.Model - } - if config.SessionID != "" { - params["sessionId"] = config.SessionID - } - if config.ReasoningEffort != "" { - params["reasoningEffort"] = config.ReasoningEffort - } - if len(config.Tools) > 0 { - toolDefs := make([]map[string]any, 0, len(config.Tools)) - for _, tool := range config.Tools { - if tool.Name == "" { - continue - } - definition := map[string]any{ - "name": tool.Name, - "description": tool.Description, - } - if tool.Parameters != nil { - definition["parameters"] = tool.Parameters - } - toolDefs = append(toolDefs, definition) - } - if len(toolDefs) > 0 { - params["tools"] = toolDefs - } - } - // Add system message configuration if provided - if config.SystemMessage != nil { - systemMessage := make(map[string]any) - - if config.SystemMessage.Mode != "" { - systemMessage["mode"] = config.SystemMessage.Mode - } + req.Model = config.Model + req.SessionID = config.SessionID + req.ReasoningEffort = config.ReasoningEffort + req.ConfigDir = config.ConfigDir + req.Tools = config.Tools + req.SystemMessage = config.SystemMessage + req.AvailableTools = config.AvailableTools + req.ExcludedTools = config.ExcludedTools + req.Provider = config.Provider + req.WorkingDirectory = config.WorkingDirectory + req.MCPServers = config.MCPServers + req.CustomAgents = config.CustomAgents + req.SkillDirectories = config.SkillDirectories + req.DisabledSkills = config.DisabledSkills + req.InfiniteSessions = config.InfiniteSessions - if config.SystemMessage.Mode == "replace" { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } else { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } - - if len(systemMessage) > 0 { - params["systemMessage"] = systemMessage - } - } - // Add tool filtering options - if len(config.AvailableTools) > 0 { - params["availableTools"] = config.AvailableTools - } - if len(config.ExcludedTools) > 0 { - params["excludedTools"] = config.ExcludedTools - } - // Add streaming option if config.Streaming { - params["streaming"] = config.Streaming - } - // Add provider configuration - if config.Provider != nil { - params["provider"] = buildProviderParams(config.Provider) + req.Streaming = Bool(true) } - // Add permission request flag if config.OnPermissionRequest != nil { - params["requestPermission"] = true + req.RequestPermission = Bool(true) } - // Add user input request flag if config.OnUserInputRequest != nil { - params["requestUserInput"] = true + req.RequestUserInput = Bool(true) } - // Add hooks flag if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || config.Hooks.OnPostToolUse != nil || config.Hooks.OnUserPromptSubmitted != nil || config.Hooks.OnSessionStart != nil || config.Hooks.OnSessionEnd != nil || config.Hooks.OnErrorOccurred != nil) { - params["hooks"] = true - } - // Add working directory - if config.WorkingDirectory != "" { - params["workingDirectory"] = config.WorkingDirectory - } - // Add MCP servers configuration - if len(config.MCPServers) > 0 { - params["mcpServers"] = config.MCPServers - } - // Add custom agents configuration - if len(config.CustomAgents) > 0 { - customAgents := make([]map[string]any, 0, len(config.CustomAgents)) - for _, agent := range config.CustomAgents { - agentMap := map[string]any{ - "name": agent.Name, - "prompt": agent.Prompt, - } - if agent.DisplayName != "" { - agentMap["displayName"] = agent.DisplayName - } - if agent.Description != "" { - agentMap["description"] = agent.Description - } - if len(agent.Tools) > 0 { - agentMap["tools"] = agent.Tools - } - if len(agent.MCPServers) > 0 { - agentMap["mcpServers"] = agent.MCPServers - } - if agent.Infer != nil { - agentMap["infer"] = *agent.Infer - } - customAgents = append(customAgents, agentMap) - } - params["customAgents"] = customAgents - } - // Add config directory override - if config.ConfigDir != "" { - params["configDir"] = config.ConfigDir - } - // Add skill directories configuration - if len(config.SkillDirectories) > 0 { - params["skillDirectories"] = config.SkillDirectories - } - // Add disabled skills configuration - if len(config.DisabledSkills) > 0 { - params["disabledSkills"] = config.DisabledSkills - } - // Add infinite sessions configuration - if config.InfiniteSessions != nil { - infiniteSessions := make(map[string]any) - if config.InfiniteSessions.Enabled != nil { - infiniteSessions["enabled"] = *config.InfiniteSessions.Enabled - } - if config.InfiniteSessions.BackgroundCompactionThreshold != nil { - infiniteSessions["backgroundCompactionThreshold"] = *config.InfiniteSessions.BackgroundCompactionThreshold - } - if config.InfiniteSessions.BufferExhaustionThreshold != nil { - infiniteSessions["bufferExhaustionThreshold"] = *config.InfiniteSessions.BufferExhaustionThreshold - } - params["infiniteSessions"] = infiniteSessions + req.Hooks = Bool(true) } } - result, err := c.client.Request("session.create", params) + result, err := c.client.Request("session.create", req) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } - sessionID, ok := result["sessionId"].(string) - if !ok { - return nil, fmt.Errorf("invalid response: missing sessionId") + var response createSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - workspacePath, _ := result["workspacePath"].(string) - - session := newSession(sessionID, c.client, workspacePath) + session := newSession(response.SessionID, c.client, response.WorkspacePath) if config != nil { session.registerTools(config.Tools) @@ -644,7 +502,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses } c.sessionsMux.Lock() - c.sessions[sessionID] = session + c.sessions[response.SessionID] = session c.sessionsMux.Unlock() return session, nil @@ -676,119 +534,52 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, return nil, err } - params := map[string]any{ - "sessionId": sessionID, - } - + var req resumeSessionRequest + req.SessionID = sessionID if config != nil { - if config.ReasoningEffort != "" { - params["reasoningEffort"] = config.ReasoningEffort - } - if len(config.Tools) > 0 { - toolDefs := make([]map[string]any, 0, len(config.Tools)) - for _, tool := range config.Tools { - if tool.Name == "" { - continue - } - definition := map[string]any{ - "name": tool.Name, - "description": tool.Description, - } - if tool.Parameters != nil { - definition["parameters"] = tool.Parameters - } - toolDefs = append(toolDefs, definition) - } - if len(toolDefs) > 0 { - params["tools"] = toolDefs - } - } - if config.Provider != nil { - params["provider"] = buildProviderParams(config.Provider) - } - // Add streaming option + req.ReasoningEffort = config.ReasoningEffort + req.Tools = config.Tools + req.Provider = config.Provider if config.Streaming { - params["streaming"] = config.Streaming + req.Streaming = Bool(true) } - // Add permission request flag if config.OnPermissionRequest != nil { - params["requestPermission"] = true + req.RequestPermission = Bool(true) } - // Add user input request flag - if config.OnUserInputRequest != nil { - params["requestUserInput"] = true + if config.OnPermissionRequest != nil { + req.RequestUserInput = Bool(true) } - // Add hooks flag if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || config.Hooks.OnPostToolUse != nil || config.Hooks.OnUserPromptSubmitted != nil || config.Hooks.OnSessionStart != nil || config.Hooks.OnSessionEnd != nil || config.Hooks.OnErrorOccurred != nil) { - params["hooks"] = true - } - // Add working directory - if config.WorkingDirectory != "" { - params["workingDirectory"] = config.WorkingDirectory + req.Hooks = Bool(true) } - // Add disable resume flag + req.WorkingDirectory = config.WorkingDirectory if config.DisableResume { - params["disableResume"] = true + req.DisableResume = Bool(true) } - // Add MCP servers configuration if len(config.MCPServers) > 0 { - params["mcpServers"] = config.MCPServers - } - // Add custom agents configuration - if len(config.CustomAgents) > 0 { - customAgents := make([]map[string]any, 0, len(config.CustomAgents)) - for _, agent := range config.CustomAgents { - agentMap := map[string]any{ - "name": agent.Name, - "prompt": agent.Prompt, - } - if agent.DisplayName != "" { - agentMap["displayName"] = agent.DisplayName - } - if agent.Description != "" { - agentMap["description"] = agent.Description - } - if len(agent.Tools) > 0 { - agentMap["tools"] = agent.Tools - } - if len(agent.MCPServers) > 0 { - agentMap["mcpServers"] = agent.MCPServers - } - if agent.Infer != nil { - agentMap["infer"] = *agent.Infer - } - customAgents = append(customAgents, agentMap) - } - params["customAgents"] = customAgents - } - // Add skill directories configuration - if len(config.SkillDirectories) > 0 { - params["skillDirectories"] = config.SkillDirectories - } - // Add disabled skills configuration - if len(config.DisabledSkills) > 0 { - params["disabledSkills"] = config.DisabledSkills + req.MCPServers = config.MCPServers } + req.CustomAgents = config.CustomAgents + req.SkillDirectories = config.SkillDirectories + req.DisabledSkills = config.DisabledSkills } - result, err := c.client.Request("session.resume", params) + result, err := c.client.Request("session.resume", req) if err != nil { return nil, fmt.Errorf("failed to resume session: %w", err) } - resumedSessionID, ok := result["sessionId"].(string) - if !ok { - return nil, fmt.Errorf("invalid response: missing sessionId") + var response resumeSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - workspacePath, _ := result["workspacePath"].(string) - - session := newSession(resumedSessionID, c.client, workspacePath) + session := newSession(response.SessionID, c.client, response.WorkspacePath) if config != nil { session.registerTools(config.Tools) if config.OnPermissionRequest != nil { @@ -805,7 +596,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, } c.sessionsMux.Lock() - c.sessions[resumedSessionID] = session + c.sessions[response.SessionID] = session c.sessionsMux.Unlock() return session, nil @@ -830,19 +621,13 @@ func (c *Client) ListSessions(ctx context.Context) ([]SessionMetadata, error) { return nil, err } - result, err := c.client.Request("session.list", map[string]any{}) + result, err := c.client.Request("session.list", listSessionsRequest{}) if err != nil { return nil, err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal sessions response: %w", err) - } - - var response ListSessionsResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response listSessionsResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal sessions response: %w", err) } @@ -864,23 +649,13 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { return err } - params := map[string]any{ - "sessionId": sessionID, - } - - result, err := c.client.Request("session.delete", params) + result, err := c.client.Request("session.delete", deleteSessionRequest{SessionID: sessionID}) if err != nil { return err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return fmt.Errorf("failed to marshal delete response: %w", err) - } - - var response DeleteSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response deleteSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return fmt.Errorf("failed to unmarshal delete response: %w", err) } @@ -925,18 +700,13 @@ func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) { } } - result, err := c.client.Request("session.getForeground", map[string]any{}) + result, err := c.client.Request("session.getForeground", getForegroundSessionRequest{}) if err != nil { return nil, err } - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal getForeground response: %w", err) - } - - var response GetForegroundSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response getForegroundSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal getForeground response: %w", err) } @@ -964,22 +734,13 @@ func (c *Client) SetForegroundSessionID(ctx context.Context, sessionID string) e } } - params := map[string]any{ - "sessionId": sessionID, - } - - result, err := c.client.Request("session.setForeground", params) + result, err := c.client.Request("session.setForeground", setForegroundSessionRequest{SessionID: sessionID}) if err != nil { return err } - jsonBytes, err := json.Marshal(result) - if err != nil { - return fmt.Errorf("failed to marshal setForeground response: %w", err) - } - - var response SetForegroundSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response setForegroundSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return fmt.Errorf("failed to unmarshal setForeground response: %w", err) } @@ -1057,7 +818,7 @@ func (c *Client) OnEventType(eventType SessionLifecycleEventType, handler Sessio } // dispatchLifecycleEvent dispatches a lifecycle event to all registered handlers -func (c *Client) dispatchLifecycleEvent(event SessionLifecycleEvent) { +func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) { c.lifecycleHandlersMux.Lock() // Copy handlers to avoid holding lock during callbacks typedHandlers := make([]SessionLifecycleHandler, 0) @@ -1111,87 +872,57 @@ func (c *Client) State() ConnectionState { // } else { // log.Printf("Server responded at %d", resp.Timestamp) // } -func (c *Client) Ping(ctx context.Context, message string) (*PingResponse, error) { +func (c *Client) Ping(ctx context.Context, message string) (*pingResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } - params := map[string]any{} - if message != "" { - params["message"] = message - } - - result, err := c.client.Request("ping", params) + result, err := c.client.Request("ping", pingRequest{Message: message}) if err != nil { return nil, err } - response := &PingResponse{} - if msg, ok := result["message"].(string); ok { - response.Message = msg - } - if ts, ok := result["timestamp"].(float64); ok { - response.Timestamp = int64(ts) - } - if pv, ok := result["protocolVersion"].(float64); ok { - v := int(pv) - response.ProtocolVersion = &v + var response pingResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // GetStatus returns CLI status including version and protocol information -func (c *Client) GetStatus(ctx context.Context) (*GetStatusResponse, error) { +func (c *Client) GetStatus(ctx context.Context) (*getStatusResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } - result, err := c.client.Request("status.get", map[string]any{}) + result, err := c.client.Request("status.get", getStatusRequest{}) if err != nil { return nil, err } - response := &GetStatusResponse{} - if v, ok := result["version"].(string); ok { - response.Version = v - } - if pv, ok := result["protocolVersion"].(float64); ok { - response.ProtocolVersion = int(pv) + var response getStatusResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // GetAuthStatus returns current authentication status -func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, error) { +func (c *Client) GetAuthStatus(ctx context.Context) (*getAuthStatusResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } - result, err := c.client.Request("auth.getStatus", map[string]any{}) + result, err := c.client.Request("auth.getStatus", getAuthStatusRequest{}) if err != nil { return nil, err } - response := &GetAuthStatusResponse{} - if v, ok := result["isAuthenticated"].(bool); ok { - response.IsAuthenticated = v - } - if v, ok := result["authType"].(string); ok { - response.AuthType = &v - } - if v, ok := result["host"].(string); ok { - response.Host = &v - } - if v, ok := result["login"].(string); ok { - response.Login = &v - } - if v, ok := result["statusMessage"].(string); ok { - response.StatusMessage = &v + var response getAuthStatusResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // ListModels returns available models with their metadata. @@ -1216,19 +947,13 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { } // Cache miss - fetch from backend while holding lock - result, err := c.client.Request("models.list", map[string]any{}) + result, err := c.client.Request("models.list", listModelsRequest{}) if err != nil { return nil, err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal models response: %w", err) - } - - var response GetModelsResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response listModelsResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal models response: %w", err) } @@ -1422,82 +1147,48 @@ func (c *Client) connectViaTcp(ctx context.Context) error { // setupNotificationHandler configures handlers for session events, tool calls, and permission requests. func (c *Client) setupNotificationHandler() { - c.client.SetNotificationHandler(func(method string, params map[string]any) { - switch method { - case "session.event": - // Extract sessionId and event - sessionID, ok := params["sessionId"].(string) - if !ok { - return - } - - // Marshal the event back to JSON and unmarshal into typed struct - eventJSON, err := json.Marshal(params["event"]) - if err != nil { - return - } - - event, err := UnmarshalSessionEvent(eventJSON) - if err != nil { - return - } - - // Dispatch to session - c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] - c.sessionsMux.Unlock() - - if ok { - session.dispatchEvent(event) - } - case "session.lifecycle": - // Handle session lifecycle events - eventJSON, err := json.Marshal(params) - if err != nil { - return - } - - var event SessionLifecycleEvent - if err := json.Unmarshal(eventJSON, &event); err != nil { - return - } + c.client.SetRequestHandler("session.event", jsonrpc2.NotificationHandlerFor(c.handleSessionEvent)) + c.client.SetRequestHandler("session.lifecycle", jsonrpc2.NotificationHandlerFor(c.handleLifecycleEvent)) + c.client.SetRequestHandler("tool.call", jsonrpc2.RequestHandlerFor(c.handleToolCallRequest)) + c.client.SetRequestHandler("permission.request", jsonrpc2.RequestHandlerFor(c.handlePermissionRequest)) + c.client.SetRequestHandler("userInput.request", jsonrpc2.RequestHandlerFor(c.handleUserInputRequest)) + c.client.SetRequestHandler("hooks.invoke", jsonrpc2.RequestHandlerFor(c.handleHooksInvoke)) +} - c.dispatchLifecycleEvent(event) - } - }) +func (c *Client) handleSessionEvent(req sessionEventRequest) { + if req.SessionID == "" { + return + } + // Dispatch to session + c.sessionsMux.Lock() + session, ok := c.sessions[req.SessionID] + c.sessionsMux.Unlock() - c.client.SetRequestHandler("tool.call", c.handleToolCallRequest) - c.client.SetRequestHandler("permission.request", c.handlePermissionRequest) - c.client.SetRequestHandler("userInput.request", c.handleUserInputRequest) - c.client.SetRequestHandler("hooks.invoke", c.handleHooksInvoke) + if ok { + session.dispatchEvent(req.Event) + } } // handleToolCallRequest handles a tool call request from the CLI server. -func (c *Client) handleToolCallRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - toolCallID, _ := params["toolCallId"].(string) - toolName, _ := params["toolName"].(string) - - if sessionID == "" || toolCallID == "" || toolName == "" { +func (c *Client) handleToolCallRequest(req toolCallRequest) (*toolCallResponse, *jsonrpc2.Error) { + if req.SessionID == "" || req.ToolCallID == "" || req.ToolName == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid tool call payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - handler, ok := session.getToolHandler(toolName) + handler, ok := session.getToolHandler(req.ToolName) if !ok { - return map[string]any{"result": buildUnsupportedToolResult(toolName)}, nil + return &toolCallResponse{Result: buildUnsupportedToolResult(req.ToolName)}, nil } - arguments := params["arguments"] - result := c.executeToolCall(sessionID, toolCallID, toolName, arguments, handler) - - return map[string]any{"result": result}, nil + result := c.executeToolCall(req.SessionID, req.ToolCallID, req.ToolName, req.Arguments, handler) + return &toolCallResponse{Result: result}, nil } // executeToolCall executes a tool handler and returns the result. @@ -1531,100 +1222,70 @@ func (c *Client) executeToolCall( } // handlePermissionRequest handles a permission request from the CLI server. -func (c *Client) handlePermissionRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - permissionRequest, _ := params["permissionRequest"].(map[string]any) - - if sessionID == "" { +func (c *Client) handlePermissionRequest(req permissionRequestRequest) (*permissionRequestResponse, *jsonrpc2.Error) { + if req.SessionID == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid permission request payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - result, err := session.handlePermissionRequest(permissionRequest) + result, err := session.handlePermissionRequest(req.Request) if err != nil { // Return denial on error - return map[string]any{ - "result": map[string]any{ - "kind": "denied-no-approval-rule-and-could-not-request-from-user", + return &permissionRequestResponse{ + Result: PermissionRequestResult{ + Kind: "denied-no-approval-rule-and-could-not-request-from-user", }, }, nil } - return map[string]any{"result": result}, nil + return &permissionRequestResponse{Result: result}, nil } // handleUserInputRequest handles a user input request from the CLI server. -func (c *Client) handleUserInputRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - question, _ := params["question"].(string) - - if sessionID == "" || question == "" { +func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) { + if req.SessionID == "" || req.Question == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - // Parse choices - var choices []string - if choicesRaw, ok := params["choices"].([]any); ok { - for _, choice := range choicesRaw { - if s, ok := choice.(string); ok { - choices = append(choices, s) - } - } - } - - var allowFreeform *bool - if af, ok := params["allowFreeform"].(bool); ok { - allowFreeform = &af - } - - request := UserInputRequest{ - Question: question, - Choices: choices, - AllowFreeform: allowFreeform, - } - - response, err := session.handleUserInputRequest(request) + response, err := session.handleUserInputRequest(UserInputRequest{ + Question: req.Question, + Choices: req.Choices, + AllowFreeform: req.AllowFreeform, + }) if err != nil { return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} } - return map[string]any{ - "answer": response.Answer, - "wasFreeform": response.WasFreeform, - }, nil + return &userInputResponse{Answer: response.Answer, WasFreeform: response.WasFreeform}, nil } // handleHooksInvoke handles a hooks invocation from the CLI server. -func (c *Client) handleHooksInvoke(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - hookType, _ := params["hookType"].(string) - input, _ := params["input"].(map[string]any) - - if sessionID == "" || hookType == "" { +func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jsonrpc2.Error) { + if req.SessionID == "" || req.Type == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - output, err := session.handleHooksInvoke(hookType, input) + output, err := session.handleHooksInvoke(req.Type, req.Input) if err != nil { return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} } diff --git a/go/client_test.go b/go/client_test.go index 185bb4cb..176dad8c 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -25,25 +25,20 @@ func TestClient_HandleToolCallRequest(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - params := map[string]any{ - "sessionId": session.SessionID, - "toolCallId": "123", - "toolName": "missing_tool", - "arguments": map[string]any{}, + params := toolCallRequest{ + SessionID: session.SessionID, + ToolCallID: "123", + ToolName: "missing_tool", + Arguments: map[string]any{}, } response, _ := client.handleToolCallRequest(params) - result, ok := response["result"].(ToolResult) - if !ok { - t.Fatalf("Expected result to be ToolResult, got %T", response["result"]) + if response.Result.ResultType != "failure" { + t.Errorf("Expected resultType to be 'failure', got %q", response.Result.ResultType) } - if result.ResultType != "failure" { - t.Errorf("Expected resultType to be 'failure', got %q", result.ResultType) - } - - if result.Error != "tool 'missing_tool' not supported" { - t.Errorf("Expected error to be \"tool 'missing_tool' not supported\", got %q", result.Error) + if response.Result.Error != "tool 'missing_tool' not supported" { + t.Errorf("Expected error to be \"tool 'missing_tool' not supported\", got %q", response.Result.Error) } }) } diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index 8e4a0f6a..1a6e17d1 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "reflect" "sync" ) @@ -23,43 +24,39 @@ func (e *Error) Error() string { // Request represents a JSON-RPC 2.0 request type Request struct { JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id"` + ID json.RawMessage `json:"id"` // nil for notifications Method string `json:"method"` - Params map[string]any `json:"params"` + Params json.RawMessage `json:"params"` +} + +func (r *Request) IsCall() bool { + return len(r.ID) > 0 } // Response represents a JSON-RPC 2.0 response type Response struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id,omitempty"` - Result map[string]any `json:"result,omitempty"` + Result json.RawMessage `json:"result,omitempty"` Error *Error `json:"error,omitempty"` } -// Notification represents a JSON-RPC 2.0 notification -type Notification struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params map[string]any `json:"params"` -} - // NotificationHandler handles incoming notifications -type NotificationHandler func(method string, params map[string]any) +type NotificationHandler func(method string, params json.RawMessage) // RequestHandler handles incoming server requests and returns a result or error -type RequestHandler func(params map[string]any) (map[string]any, *Error) +type RequestHandler func(params json.RawMessage) (json.RawMessage, *Error) // Client is a minimal JSON-RPC 2.0 client for stdio transport type Client struct { - stdin io.WriteCloser - stdout io.ReadCloser - mu sync.Mutex - pendingRequests map[string]chan *Response - notificationHandler NotificationHandler - requestHandlers map[string]RequestHandler - running bool - stopChan chan struct{} - wg sync.WaitGroup + stdin io.WriteCloser + stdout io.ReadCloser + mu sync.Mutex + pendingRequests map[string]chan *Response + requestHandlers map[string]RequestHandler + running bool + stopChan chan struct{} + wg sync.WaitGroup } // NewClient creates a new JSON-RPC client @@ -96,11 +93,55 @@ func (c *Client) Stop() { c.wg.Wait() } -// SetNotificationHandler sets the handler for incoming notifications -func (c *Client) SetNotificationHandler(handler NotificationHandler) { - c.mu.Lock() - defer c.mu.Unlock() - c.notificationHandler = handler +func NotificationHandlerFor[In any](handler func(params In)) RequestHandler { + return func(params json.RawMessage) (json.RawMessage, *Error) { + var in In + // If In is a pointer type, allocate the underlying value and unmarshal into it directly + var target any = &in + if t := reflect.TypeOf(in); t != nil && t.Kind() == reflect.Pointer { + in = reflect.New(t.Elem()).Interface().(In) + target = in + } + if err := json.Unmarshal(params, target); err != nil { + return nil, &Error{ + Code: -32602, + Message: fmt.Sprintf("Invalid params: %v", err), + } + } + handler(in) + return nil, nil + } +} + +// RequestHandlerFor creates a RequestHandler from a typed function +func RequestHandlerFor[In, Out any](handler func(params In) (Out, *Error)) RequestHandler { + return func(params json.RawMessage) (json.RawMessage, *Error) { + var in In + // If In is a pointer type, allocate the underlying value and unmarshal into it directly + var target any = &in + if t := reflect.TypeOf(in); t != nil && t.Kind() == reflect.Pointer { + in = reflect.New(t.Elem()).Interface().(In) + target = in + } + if err := json.Unmarshal(params, target); err != nil { + return nil, &Error{ + Code: -32602, + Message: fmt.Sprintf("Invalid params: %v", err), + } + } + out, errj := handler(in) + if errj != nil { + return nil, errj + } + outData, err := json.Marshal(out) + if err != nil { + return nil, &Error{ + Code: -32603, + Message: fmt.Sprintf("Failed to marshal response: %v", err), + } + } + return outData, nil + } } // SetRequestHandler registers a handler for incoming requests from the server @@ -115,7 +156,7 @@ func (c *Client) SetRequestHandler(method string, handler RequestHandler) { } // Request sends a JSON-RPC request and waits for the response -func (c *Client) Request(method string, params map[string]any) (map[string]any, error) { +func (c *Client) Request(method string, params any) (json.RawMessage, error) { requestID := generateUUID() // Create response channel @@ -131,12 +172,17 @@ func (c *Client) Request(method string, params map[string]any) (map[string]any, c.mu.Unlock() }() + paramsData, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + // Send request request := Request{ JSONRPC: "2.0", ID: json.RawMessage(`"` + requestID + `"`), Method: method, - Params: params, + Params: json.RawMessage(paramsData), } if err := c.sendMessage(request); err != nil { @@ -156,11 +202,16 @@ func (c *Client) Request(method string, params map[string]any) (map[string]any, } // Notify sends a JSON-RPC notification (no response expected) -func (c *Client) Notify(method string, params map[string]any) error { - notification := Notification{ +func (c *Client) Notify(method string, params any) error { + paramsData, err := json.Marshal(params) + if err != nil { + return fmt.Errorf("failed to marshal params: %w", err) + } + + notification := Request{ JSONRPC: "2.0", Method: method, - Params: params, + Params: json.RawMessage(paramsData), } return c.sendMessage(notification) } @@ -231,7 +282,7 @@ func (c *Client) readLoop() { // Try to parse as request first (has both ID and Method) var request Request - if err := json.Unmarshal(body, &request); err == nil && request.Method != "" && len(request.ID) > 0 { + if err := json.Unmarshal(body, &request); err == nil && request.Method != "" { c.handleRequest(&request) continue } @@ -242,13 +293,6 @@ func (c *Client) readLoop() { c.handleResponse(&response) continue } - - // Try to parse as notification (has Method but no ID) - var notification Notification - if err := json.Unmarshal(body, ¬ification); err == nil && notification.Method != "" { - c.handleNotification(¬ification) - continue - } } } @@ -270,47 +314,41 @@ func (c *Client) handleResponse(response *Response) { } } -// handleNotification dispatches a notification to the handler -func (c *Client) handleNotification(notification *Notification) { - c.mu.Lock() - handler := c.notificationHandler - c.mu.Unlock() - - if handler != nil { - handler(notification.Method, notification.Params) - } -} - func (c *Client) handleRequest(request *Request) { c.mu.Lock() handler := c.requestHandlers[request.Method] c.mu.Unlock() if handler == nil { - c.sendErrorResponse(request.ID, -32601, fmt.Sprintf("Method not found: %s", request.Method), nil) + if request.IsCall() { + c.sendErrorResponse(request.ID, -32601, fmt.Sprintf("Method not found: %s", request.Method), nil) + } return } go func() { defer func() { if r := recover(); r != nil { - c.sendErrorResponse(request.ID, -32603, fmt.Sprintf("request handler panic: %v", r), nil) + if request.IsCall() { + c.sendErrorResponse(request.ID, -32603, fmt.Sprintf("request handler panic: %v", r), nil) + } } }() result, err := handler(request.Params) + if !request.IsCall() { + // Only send a response if this is a call + return + } if err != nil { c.sendErrorResponse(request.ID, err.Code, err.Message, err.Data) return } - if result == nil { - result = make(map[string]any) - } c.sendResponse(request.ID, result) }() } -func (c *Client) sendResponse(id json.RawMessage, result map[string]any) { +func (c *Client) sendResponse(id json.RawMessage, result json.RawMessage) { response := Response{ JSONRPC: "2.0", ID: id, diff --git a/go/session.go b/go/session.go index e4f1473d..5d494710 100644 --- a/go/session.go +++ b/go/session.go @@ -106,29 +106,23 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) // log.Printf("Failed to send message: %v", err) // } func (s *Session) Send(ctx context.Context, options MessageOptions) (string, error) { - params := map[string]any{ - "sessionId": s.SessionID, - "prompt": options.Prompt, + req := sessionSendRequest{ + SessionID: s.SessionID, + Prompt: options.Prompt, + Attachments: options.Attachments, + Mode: options.Mode, } - if options.Attachments != nil { - params["attachments"] = options.Attachments - } - if options.Mode != "" { - params["mode"] = options.Mode - } - - result, err := s.client.Request("session.send", params) + result, err := s.client.Request("session.send", req) if err != nil { return "", fmt.Errorf("failed to send message: %w", err) } - messageID, ok := result["messageId"].(string) - if !ok { - return "", fmt.Errorf("invalid response: missing messageId") + var response sessionSendResponse + if err := json.Unmarshal(result, &response); err != nil { + return "", fmt.Errorf("failed to unmarshal send response: %w", err) } - - return messageID, nil + return response.MessageID, nil } // SendAndWait sends a message to this session and waits until the session becomes idle. @@ -306,7 +300,7 @@ func (s *Session) getPermissionHandler() PermissionHandler { // handlePermissionRequest handles a permission request from the Copilot CLI. // This is an internal method called by the SDK when the CLI requests permission. -func (s *Session) handlePermissionRequest(requestData map[string]any) (PermissionRequestResult, error) { +func (s *Session) handlePermissionRequest(request PermissionRequest) (PermissionRequestResult, error) { handler := s.getPermissionHandler() if handler == nil { @@ -315,16 +309,6 @@ func (s *Session) handlePermissionRequest(requestData map[string]any) (Permissio }, nil } - // Convert map to PermissionRequest struct - kind, _ := requestData["kind"].(string) - toolCallID, _ := requestData["toolCallId"].(string) - - request := PermissionRequest{ - Kind: kind, - ToolCallID: toolCallID, - Extra: requestData, - } - invocation := PermissionInvocation{ SessionID: s.SessionID, } @@ -388,7 +372,7 @@ func (s *Session) getHooks() *SessionHooks { // handleHooksInvoke handles a hook invocation from the Copilot CLI. // This is an internal method called by the SDK when the CLI invokes a hook. -func (s *Session) handleHooksInvoke(hookType string, input map[string]any) (any, error) { +func (s *Session) handleHooksInvoke(hookType string, rawInput json.RawMessage) (any, error) { hooks := s.getHooks() if hooks == nil { @@ -404,153 +388,66 @@ func (s *Session) handleHooksInvoke(hookType string, input map[string]any) (any, if hooks.OnPreToolUse == nil { return nil, nil } - hookInput := parsePreToolUseInput(input) - return hooks.OnPreToolUse(hookInput, invocation) + var input PreToolUseHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnPreToolUse(input, invocation) case "postToolUse": if hooks.OnPostToolUse == nil { return nil, nil } - hookInput := parsePostToolUseInput(input) - return hooks.OnPostToolUse(hookInput, invocation) + var input PostToolUseHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnPostToolUse(input, invocation) case "userPromptSubmitted": if hooks.OnUserPromptSubmitted == nil { return nil, nil } - hookInput := parseUserPromptSubmittedInput(input) - return hooks.OnUserPromptSubmitted(hookInput, invocation) + var input UserPromptSubmittedHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnUserPromptSubmitted(input, invocation) case "sessionStart": if hooks.OnSessionStart == nil { return nil, nil } - hookInput := parseSessionStartInput(input) - return hooks.OnSessionStart(hookInput, invocation) + var input SessionStartHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnSessionStart(input, invocation) case "sessionEnd": if hooks.OnSessionEnd == nil { return nil, nil } - hookInput := parseSessionEndInput(input) - return hooks.OnSessionEnd(hookInput, invocation) + var input SessionEndHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnSessionEnd(input, invocation) case "errorOccurred": if hooks.OnErrorOccurred == nil { return nil, nil } - hookInput := parseErrorOccurredInput(input) - return hooks.OnErrorOccurred(hookInput, invocation) - + var input ErrorOccurredHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnErrorOccurred(input, invocation) default: return nil, fmt.Errorf("unknown hook type: %s", hookType) } } -// Helper functions to parse hook inputs - -func parsePreToolUseInput(input map[string]any) PreToolUseHookInput { - result := PreToolUseHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if name, ok := input["toolName"].(string); ok { - result.ToolName = name - } - result.ToolArgs = input["toolArgs"] - return result -} - -func parsePostToolUseInput(input map[string]any) PostToolUseHookInput { - result := PostToolUseHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if name, ok := input["toolName"].(string); ok { - result.ToolName = name - } - result.ToolArgs = input["toolArgs"] - result.ToolResult = input["toolResult"] - return result -} - -func parseUserPromptSubmittedInput(input map[string]any) UserPromptSubmittedHookInput { - result := UserPromptSubmittedHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if prompt, ok := input["prompt"].(string); ok { - result.Prompt = prompt - } - return result -} - -func parseSessionStartInput(input map[string]any) SessionStartHookInput { - result := SessionStartHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if source, ok := input["source"].(string); ok { - result.Source = source - } - if prompt, ok := input["initialPrompt"].(string); ok { - result.InitialPrompt = prompt - } - return result -} - -func parseSessionEndInput(input map[string]any) SessionEndHookInput { - result := SessionEndHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if reason, ok := input["reason"].(string); ok { - result.Reason = reason - } - if msg, ok := input["finalMessage"].(string); ok { - result.FinalMessage = msg - } - if errStr, ok := input["error"].(string); ok { - result.Error = errStr - } - return result -} - -func parseErrorOccurredInput(input map[string]any) ErrorOccurredHookInput { - result := ErrorOccurredHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if errMsg, ok := input["error"].(string); ok { - result.Error = errMsg - } - if ctx, ok := input["errorContext"].(string); ok { - result.ErrorContext = ctx - } - if rec, ok := input["recoverable"].(bool); ok { - result.Recoverable = rec - } - return result -} - // dispatchEvent dispatches an event to all registered handlers. // This is an internal method; handlers are called synchronously and any panics // are recovered to prevent crashing the event dispatcher. @@ -596,38 +493,17 @@ func (s *Session) dispatchEvent(event SessionEvent) { // } // } func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { - params := map[string]any{ - "sessionId": s.SessionID, - } - result, err := s.client.Request("session.getMessages", params) + result, err := s.client.Request("session.getMessages", sessionGetMessagesRequest{SessionID: s.SessionID}) if err != nil { return nil, fmt.Errorf("failed to get messages: %w", err) } - eventsRaw, ok := result["events"].([]any) - if !ok { - return nil, fmt.Errorf("invalid response: missing events") + var response sessionGetMessagesResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal get messages response: %w", err) } - - // Convert to SessionEvent structs - events := make([]SessionEvent, 0, len(eventsRaw)) - for _, eventRaw := range eventsRaw { - // Marshal back to JSON and unmarshal into typed struct - eventJSON, err := json.Marshal(eventRaw) - if err != nil { - continue - } - - event, err := UnmarshalSessionEvent(eventJSON) - if err != nil { - continue - } - - events = append(events, event) - } - - return events, nil + return response.Events, nil } // Destroy destroys this session and releases all associated resources. @@ -645,11 +521,7 @@ func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { // log.Printf("Failed to destroy session: %v", err) // } func (s *Session) Destroy() error { - params := map[string]any{ - "sessionId": s.SessionID, - } - - _, err := s.client.Request("session.destroy", params) + _, err := s.client.Request("session.destroy", sessionDestroyRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to destroy session: %w", err) } @@ -692,11 +564,11 @@ func (s *Session) Destroy() error { // log.Printf("Failed to abort: %v", err) // } func (s *Session) Abort(ctx context.Context) error { - params := map[string]any{ - "sessionId": s.SessionID, + req := sessionAbortRequest{ + SessionID: s.SessionID, } - _, err := s.client.Request("session.abort", params) + _, err := s.client.Request("session.abort", req) if err != nil { return fmt.Errorf("failed to abort session: %w", err) } diff --git a/go/types.go b/go/types.go index 7a1917f0..9ca57dc7 100644 --- a/go/types.go +++ b/go/types.go @@ -1,5 +1,7 @@ package copilot +import "encoding/json" + // ConnectionState represents the client connection state type ConnectionState string @@ -113,15 +115,15 @@ type PermissionInvocation struct { // UserInputRequest represents a request for user input from the agent type UserInputRequest struct { - Question string `json:"question"` - Choices []string `json:"choices,omitempty"` - AllowFreeform *bool `json:"allowFreeform,omitempty"` + Question string + Choices []string + AllowFreeform *bool } // UserInputResponse represents the user's response to an input request type UserInputResponse struct { - Answer string `json:"answer"` - WasFreeform bool `json:"wasFreeform"` + Answer string + WasFreeform bool } // UserInputHandler handles user input requests from the agent @@ -307,13 +309,13 @@ type CustomAgentConfig struct { // limits through background compaction and persist state to a workspace directory. type InfiniteSessionConfig struct { // Enabled controls whether infinite sessions are enabled (default: true) - Enabled *bool + Enabled *bool `json:"enabled,omitempty"` // BackgroundCompactionThreshold is the context utilization (0.0-1.0) at which // background compaction starts. Default: 0.80 - BackgroundCompactionThreshold *float64 + BackgroundCompactionThreshold *float64 `json:"backgroundCompactionThreshold,omitempty"` // BufferExhaustionThreshold is the context utilization (0.0-1.0) at which // the session blocks until compaction completes. Default: 0.95 - BufferExhaustionThreshold *float64 + BufferExhaustionThreshold *float64 `json:"bufferExhaustionThreshold,omitempty"` } // SessionConfig configures a new session @@ -369,10 +371,10 @@ type SessionConfig struct { // Tool describes a caller-implemented tool that can be invoked by Copilot type Tool struct { - Name string - Description string // optional - Parameters map[string]any - Handler ToolHandler + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` + Handler ToolHandler `json:"-"` } // ToolInvocation describes a tool call initiated by Copilot @@ -477,43 +479,6 @@ type MessageOptions struct { // SessionEventHandler is a callback for session events type SessionEventHandler func(event SessionEvent) -// PingResponse is the response from a ping request -type PingResponse struct { - Message string `json:"message"` - Timestamp int64 `json:"timestamp"` - ProtocolVersion *int `json:"protocolVersion,omitempty"` -} - -// SessionCreateResponse is the response from session.create -type SessionCreateResponse struct { - SessionID string `json:"sessionId"` -} - -// SessionSendResponse is the response from session.send -type SessionSendResponse struct { - MessageID string `json:"messageId"` -} - -// SessionGetMessagesResponse is the response from session.getMessages -type SessionGetMessagesResponse struct { - Events []SessionEvent `json:"events"` -} - -// GetStatusResponse is the response from status.get -type GetStatusResponse struct { - Version string `json:"version"` - ProtocolVersion int `json:"protocolVersion"` -} - -// GetAuthStatusResponse is the response from auth.getStatus -type GetAuthStatusResponse struct { - IsAuthenticated bool `json:"isAuthenticated"` - AuthType *string `json:"authType,omitempty"` - Host *string `json:"host,omitempty"` - Login *string `json:"login,omitempty"` - StatusMessage *string `json:"statusMessage,omitempty"` -} - // ModelVisionLimits contains vision-specific limits type ModelVisionLimits struct { SupportedMediaTypes []string `json:"supported_media_types"` @@ -562,11 +527,6 @@ type ModelInfo struct { DefaultReasoningEffort string `json:"defaultReasoningEffort,omitempty"` } -// GetModelsResponse is the response from models.list -type GetModelsResponse struct { - Models []ModelInfo `json:"models"` -} - // SessionMetadata contains metadata about a session type SessionMetadata struct { SessionID string `json:"sessionId"` @@ -576,22 +536,6 @@ type SessionMetadata struct { IsRemote bool `json:"isRemote"` } -// ListSessionsResponse is the response from session.list -type ListSessionsResponse struct { - Sessions []SessionMetadata `json:"sessions"` -} - -// DeleteSessionRequest is the request for session.delete -type DeleteSessionRequest struct { - SessionID string `json:"sessionId"` -} - -// DeleteSessionResponse is the response from session.delete -type DeleteSessionResponse struct { - Success bool `json:"success"` - Error *string `json:"error,omitempty"` -} - // SessionLifecycleEventType represents the type of session lifecycle event type SessionLifecycleEventType string @@ -620,19 +564,218 @@ type SessionLifecycleEventMetadata struct { // SessionLifecycleHandler is a callback for session lifecycle events type SessionLifecycleHandler func(event SessionLifecycleEvent) -// GetForegroundSessionResponse is the response from session.getForeground -type GetForegroundSessionResponse struct { +// permissionRequestRequest represents the request data for a permission request +type permissionRequestRequest struct { + SessionID string `json:"sessionId"` + Request PermissionRequest `json:"permissionRequest"` +} + +// permissionRequestResponse represents the response to a permission request +type permissionRequestResponse struct { + Result PermissionRequestResult `json:"result"` +} + +// createSessionRequest is the request for session.create +type createSessionRequest struct { + Model string `json:"model,omitempty"` + SessionID string `json:"sessionId,omitempty"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` + AvailableTools []string `json:"availableTools,omitempty"` + ExcludedTools []string `json:"excludedTools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + ConfigDir string `json:"configDir,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` + InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` +} + +// createSessionResponse is the response from session.create +type createSessionResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` +} + +// resumeSessionRequest is the request for session.resume +type resumeSessionRequest struct { + SessionID string `json:"sessionId"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + DisableResume *bool `json:"disableResume,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` +} + +// resumeSessionResponse is the response from session.resume +type resumeSessionResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` +} + +type hooksInvokeRequest struct { + SessionID string `json:"sessionId"` + Type string `json:"hookType"` + Input json.RawMessage `json:"input"` +} + +// listSessionsRequest is the request for session.list +type listSessionsRequest struct{} + +// listSessionsResponse is the response from session.list +type listSessionsResponse struct { + Sessions []SessionMetadata `json:"sessions"` +} + +// deleteSessionRequest is the request for session.delete +type deleteSessionRequest struct { + SessionID string `json:"sessionId"` +} + +// deleteSessionResponse is the response from session.delete +type deleteSessionResponse struct { + Success bool `json:"success"` + Error *string `json:"error,omitempty"` +} + +// getForegroundSessionRequest is the request for session.getForeground +type getForegroundSessionRequest struct{} + +// getForegroundSessionResponse is the response from session.getForeground +type getForegroundSessionResponse struct { SessionID *string `json:"sessionId,omitempty"` WorkspacePath *string `json:"workspacePath,omitempty"` } -// SetForegroundSessionRequest is the request for session.setForeground -type SetForegroundSessionRequest struct { +// setForegroundSessionRequest is the request for session.setForeground +type setForegroundSessionRequest struct { SessionID string `json:"sessionId"` } -// SetForegroundSessionResponse is the response from session.setForeground -type SetForegroundSessionResponse struct { +// setForegroundSessionResponse is the response from session.setForeground +type setForegroundSessionResponse struct { Success bool `json:"success"` Error *string `json:"error,omitempty"` } + +type pingRequest struct { + Message string `json:"message,omitempty"` +} + +// pingResponse is the response from a ping request +type pingResponse struct { + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` + ProtocolVersion *int `json:"protocolVersion,omitempty"` +} + +// getStatusRequest is the request for status.get +type getStatusRequest struct{} + +// getStatusResponse is the response from status.get +type getStatusResponse struct { + Version string `json:"version"` + ProtocolVersion int `json:"protocolVersion"` +} + +// getAuthStatusRequest is the request for auth.getStatus +type getAuthStatusRequest struct{} + +// getAuthStatusResponse is the response from auth.getStatus +type getAuthStatusResponse struct { + IsAuthenticated bool `json:"isAuthenticated"` + AuthType *string `json:"authType,omitempty"` + Host *string `json:"host,omitempty"` + Login *string `json:"login,omitempty"` + StatusMessage *string `json:"statusMessage,omitempty"` +} + +// listModelsRequest is the request for models.list +type listModelsRequest struct{} + +// listModelsResponse is the response from models.list +type listModelsResponse struct { + Models []ModelInfo `json:"models"` +} + +// sessionGetMessagesRequest is the request for session.getMessages +type sessionGetMessagesRequest struct { + SessionID string `json:"sessionId"` +} + +// sessionGetMessagesResponse is the response from session.getMessages +type sessionGetMessagesResponse struct { + Events []SessionEvent `json:"events"` +} + +// sessionDestroyRequest is the request for session.destroy +type sessionDestroyRequest struct { + SessionID string `json:"sessionId"` +} + +// sessionAbortRequest is the request for session.abort +type sessionAbortRequest struct { + SessionID string `json:"sessionId"` +} + +type sessionSendRequest struct { + SessionID string `json:"sessionId"` + Prompt string `json:"prompt"` + Attachments []Attachment `json:"attachments,omitempty"` + Mode string `json:"mode,omitempty"` +} + +// sessionSendResponse is the response from session.send +type sessionSendResponse struct { + MessageID string `json:"messageId"` +} + +// sessionEventRequest is the request for session event notifications +type sessionEventRequest struct { + SessionID string `json:"sessionId"` + Event SessionEvent `json:"event"` +} + +// toolCallRequest represents a tool call request from the server +// to the client for execution. +type toolCallRequest struct { + SessionID string `json:"sessionId"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + Arguments any `json:"arguments"` +} + +// toolCallResponse represents the response to a tool call request +// from the client back to the server. +type toolCallResponse struct { + Result ToolResult `json:"result"` +} + +// userInputRequest represents a request for user input from the agent +type userInputRequest struct { + SessionID string `json:"sessionId"` + Question string `json:"question"` + Choices []string `json:"choices,omitempty"` + AllowFreeform *bool `json:"allowFreeform,omitempty"` +} + +// userInputResponse represents the user's response to an input request +type userInputResponse struct { + Answer string `json:"answer"` + WasFreeform bool `json:"wasFreeform"` +} From 94c5f50545b1a5729acddc814d64cebce0893338 Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Wed, 4 Feb 2026 16:51:24 +0100 Subject: [PATCH 02/15] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- go/client.go | 4 ++-- go/internal/jsonrpc2/jsonrpc2.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/client.go b/go/client.go index a6d3e950..53a38b23 100644 --- a/go/client.go +++ b/go/client.go @@ -546,7 +546,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, if config.OnPermissionRequest != nil { req.RequestPermission = Bool(true) } - if config.OnPermissionRequest != nil { + if config.OnUserInputRequest != nil { req.RequestUserInput = Bool(true) } if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || @@ -817,7 +817,7 @@ func (c *Client) OnEventType(eventType SessionLifecycleEventType, handler Sessio } } -// dispatchLifecycleEvent dispatches a lifecycle event to all registered handlers +// handleLifecycleEvent dispatches a lifecycle event to all registered handlers func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) { c.lifecycleHandlersMux.Lock() // Copy handlers to avoid holding lock during callbacks diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index 1a6e17d1..a226f11f 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -98,7 +98,7 @@ func NotificationHandlerFor[In any](handler func(params In)) RequestHandler { var in In // If In is a pointer type, allocate the underlying value and unmarshal into it directly var target any = &in - if t := reflect.TypeOf(in); t != nil && t.Kind() == reflect.Pointer { + if t := reflect.TypeFor[In](); t.Kind() == reflect.Pointer { in = reflect.New(t.Elem()).Interface().(In) target = in } From d00a6147e472b2aaf8eefa12b30afebfbcd60a0f Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 4 Feb 2026 16:54:02 +0100 Subject: [PATCH 03/15] reexport some types --- go/client.go | 8 ++++---- go/types.go | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/go/client.go b/go/client.go index 53a38b23..b89f2f36 100644 --- a/go/client.go +++ b/go/client.go @@ -890,7 +890,7 @@ func (c *Client) Ping(ctx context.Context, message string) (*pingResponse, error } // GetStatus returns CLI status including version and protocol information -func (c *Client) GetStatus(ctx context.Context) (*getStatusResponse, error) { +func (c *Client) GetStatus(ctx context.Context) (*GetStatusResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } @@ -900,7 +900,7 @@ func (c *Client) GetStatus(ctx context.Context) (*getStatusResponse, error) { return nil, err } - var response getStatusResponse + var response GetStatusResponse if err := json.Unmarshal(result, &response); err != nil { return nil, err } @@ -908,7 +908,7 @@ func (c *Client) GetStatus(ctx context.Context) (*getStatusResponse, error) { } // GetAuthStatus returns current authentication status -func (c *Client) GetAuthStatus(ctx context.Context) (*getAuthStatusResponse, error) { +func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } @@ -918,7 +918,7 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*getAuthStatusResponse, err return nil, err } - var response getAuthStatusResponse + var response GetAuthStatusResponse if err := json.Unmarshal(result, &response); err != nil { return nil, err } diff --git a/go/types.go b/go/types.go index 9ca57dc7..1dee8a43 100644 --- a/go/types.go +++ b/go/types.go @@ -687,8 +687,8 @@ type pingResponse struct { // getStatusRequest is the request for status.get type getStatusRequest struct{} -// getStatusResponse is the response from status.get -type getStatusResponse struct { +// GetStatusResponse is the response from status.get +type GetStatusResponse struct { Version string `json:"version"` ProtocolVersion int `json:"protocolVersion"` } @@ -696,8 +696,8 @@ type getStatusResponse struct { // getAuthStatusRequest is the request for auth.getStatus type getAuthStatusRequest struct{} -// getAuthStatusResponse is the response from auth.getStatus -type getAuthStatusResponse struct { +// GetAuthStatusResponse is the response from auth.getStatus +type GetAuthStatusResponse struct { IsAuthenticated bool `json:"isAuthenticated"` AuthType *string `json:"authType,omitempty"` Host *string `json:"host,omitempty"` From 107d5455b88d985681abbb2e587cf93c08f4c7c8 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 4 Feb 2026 17:24:46 +0100 Subject: [PATCH 04/15] fix race --- go/internal/e2e/mcp_and_agents_test.go | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/go/internal/e2e/mcp_and_agents_test.go b/go/internal/e2e/mcp_and_agents_test.go index 1d21651b..244589a1 100644 --- a/go/internal/e2e/mcp_and_agents_test.go +++ b/go/internal/e2e/mcp_and_agents_test.go @@ -37,18 +37,13 @@ func TestMCPServers(t *testing.T) { } // Simple interaction to verify session works - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 2+2?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { t.Errorf("Expected message to contain '4', got: %v", message.Data.Content) } @@ -168,18 +163,13 @@ func TestCustomAgents(t *testing.T) { } // Simple interaction to verify session works - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 5+5?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "10") { t.Errorf("Expected message to contain '10', got: %v", message.Data.Content) } @@ -373,18 +363,13 @@ func TestCombinedConfiguration(t *testing.T) { t.Error("Expected non-empty session ID") } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 7+7?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "14") { t.Errorf("Expected message to contain '14', got: %v", message.Data.Content) } From 55759d750b628bff1c5f657324f674fefd279444 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 4 Feb 2026 17:40:12 +0100 Subject: [PATCH 05/15] fix test --- go/internal/e2e/session_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 62183286..5d225b35 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -68,6 +68,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { t.Errorf("Expected assistant message to contain '2', got %v", assistantMessage.Data.Content) } @@ -77,6 +81,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send second message: %v", err) } + if secondMessage == nil { + t.Fatal("Expected second assistant message, got nil") + } + if secondMessage.Data.Content == nil || !strings.Contains(*secondMessage.Data.Content, "4") { t.Errorf("Expected second message to contain '4', got %v", secondMessage.Data.Content) } From 76a805bb4c5f931035061d1867fdd58da40cd097 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 5 Feb 2026 10:22:55 +0100 Subject: [PATCH 06/15] export PingResponse --- go/client.go | 4 ++-- go/types.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go/client.go b/go/client.go index b89f2f36..6692e93d 100644 --- a/go/client.go +++ b/go/client.go @@ -872,7 +872,7 @@ func (c *Client) State() ConnectionState { // } else { // log.Printf("Server responded at %d", resp.Timestamp) // } -func (c *Client) Ping(ctx context.Context, message string) (*pingResponse, error) { +func (c *Client) Ping(ctx context.Context, message string) (*PingResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } @@ -882,7 +882,7 @@ func (c *Client) Ping(ctx context.Context, message string) (*pingResponse, error return nil, err } - var response pingResponse + var response PingResponse if err := json.Unmarshal(result, &response); err != nil { return nil, err } diff --git a/go/types.go b/go/types.go index 1dee8a43..b421c85c 100644 --- a/go/types.go +++ b/go/types.go @@ -677,8 +677,8 @@ type pingRequest struct { Message string `json:"message,omitempty"` } -// pingResponse is the response from a ping request -type pingResponse struct { +// PingResponse is the response from a ping request +type PingResponse struct { Message string `json:"message"` Timestamp int64 `json:"timestamp"` ProtocolVersion *int `json:"protocolVersion,omitempty"` From 35eb0929c282125d78db7c160459ab5438d0b800 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 5 Feb 2026 17:21:44 +0100 Subject: [PATCH 07/15] use SendAndWait instead of GetFinalAssistantMessage --- go/internal/e2e/permissions_test.go | 14 +-- go/internal/e2e/session_test.go | 117 ++++++-------------------- go/internal/e2e/testharness/helper.go | 102 ---------------------- go/internal/e2e/tools_test.go | 28 +----- go/session.go | 6 +- 5 files changed, 33 insertions(+), 234 deletions(-) diff --git a/go/internal/e2e/permissions_test.go b/go/internal/e2e/permissions_test.go index a891548c..ad8485e3 100644 --- a/go/internal/e2e/permissions_test.go +++ b/go/internal/e2e/permissions_test.go @@ -134,18 +134,13 @@ func TestPermissions(t *testing.T) { t.Fatalf("Failed to write test file: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "Edit protected.txt and replace 'protected' with 'hacked'.", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - _, err = testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - // Verify the file was NOT modified content, err := os.ReadFile(testFile) if err != nil { @@ -165,16 +160,11 @@ func TestPermissions(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { t.Errorf("Expected message to contain '4', got: %v", message.Data.Content) } diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 5d225b35..846cf1a4 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -2,6 +2,7 @@ package e2e import ( "regexp" + "slices" "strings" "testing" "time" @@ -152,16 +153,11 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is your full name?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is your full name?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - content := "" if assistantMessage.Data.Content != nil { content = *assistantMessage.Data.Content @@ -198,16 +194,11 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - _, err = testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - // Validate that only the specified tools are present traffic, err := ctx.GetExchanges() if err != nil { @@ -236,16 +227,11 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - _, err = testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - // Validate that excluded tool is not present but others are traffic, err := ctx.GetExchanges() if err != nil { @@ -303,16 +289,11 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is the secret number for key ALPHA?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is the secret number for key ALPHA?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - content := "" if assistantMessage.Data.Content != nil { content = *assistantMessage.Data.Content @@ -337,16 +318,11 @@ func TestSession(t *testing.T) { } sessionID := session1.SessionID - _, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + answer, err := session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "2") { t.Errorf("Expected answer to contain '2', got %v", answer.Data.Content) } @@ -361,13 +337,21 @@ func TestSession(t *testing.T) { t.Errorf("Expected resumed session ID to match, got %q vs %q", session2.SessionID, sessionID) } - answer2, err := testharness.GetFinalAssistantMessage(t.Context(), session2) + messages, err := session2.GetMessages(t.Context()) if err != nil { - t.Fatalf("Failed to get assistant message from resumed session: %v", err) + t.Fatalf("Failed to get messages: %v", err) + } + + answer2Idx := slices.IndexFunc(messages, func(m copilot.SessionEvent) bool { + return m.Type == "assistant.message" + }) + + if answer2Idx == -1 { + t.Fatalf("Expected to find an assistant.message in resumed session messages, got %v", messages) } - if answer2.Data.Content == nil || !strings.Contains(*answer2.Data.Content, "2") { - t.Errorf("Expected resumed session answer to contain '2', got %v", answer2.Data.Content) + if messages[answer2Idx].Data.Content == nil || !strings.Contains(*messages[answer2Idx].Data.Content, "2") { + t.Errorf("Expected resumed session answer to contain '2', got %v", messages[answer2Idx].Data.Content) } }) @@ -381,16 +365,11 @@ func TestSession(t *testing.T) { } sessionID := session1.SessionID - _, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + answer, err := session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "2") { t.Errorf("Expected answer to contain '2', got %v", answer.Data.Content) } @@ -570,42 +549,28 @@ func TestSession(t *testing.T) { } var deltaContents []string - done := make(chan bool) - session.On(func(event copilot.SessionEvent) { + unsubscribe := session.On(func(event copilot.SessionEvent) { switch event.Type { case "assistant.message_delta": if event.Data.DeltaContent != nil { deltaContents = append(deltaContents, *event.Data.DeltaContent) } - case "session.idle": - close(done) + case "assistant.message": } }) - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) + unsubscribe() if err != nil { t.Fatalf("Failed to send message: %v", err) } - // Wait for completion - select { - case <-done: - case <-time.After(60 * time.Second): - t.Fatal("Timed out waiting for session.idle") - } - // Should have received delta events if len(deltaContents) == 0 { t.Error("Expected to receive delta events, got none") } - // Get the final message to compare - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - // Accumulated deltas should equal the final message accumulated := strings.Join(deltaContents, "") if assistantMessage.Data.Content != nil && accumulated != *assistantMessage.Data.Content { @@ -635,16 +600,11 @@ func TestSession(t *testing.T) { } // Session should still work normally - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { t.Errorf("Expected assistant message to contain '2', got %v", assistantMessage.Data.Content) } @@ -659,31 +619,16 @@ func TestSession(t *testing.T) { } var receivedEvents []copilot.SessionEvent - idle := make(chan bool) - session.On(func(event copilot.SessionEvent) { receivedEvents = append(receivedEvents, event) - if event.Type == "session.idle" { - select { - case idle <- true: - default: - } - } }) // Send a message to trigger events - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 100+200?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 100+200?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - // Wait for session to become idle - select { - case <-idle: - case <-time.After(60 * time.Second): - t.Fatal("Timed out waiting for session.idle") - } - // Should have received multiple events if len(receivedEvents) == 0 { t.Error("Expected to receive events, got none") @@ -713,11 +658,6 @@ func TestSession(t *testing.T) { t.Error("Expected to receive session.idle event") } - // Verify the assistant response contains the expected answer - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "300") { t.Errorf("Expected assistant message to contain '300', got %v", assistantMessage.Data.Content) } @@ -740,16 +680,11 @@ func TestSession(t *testing.T) { } // Session should work normally with custom config dir - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { t.Errorf("Expected assistant message to contain '2', got %v", assistantMessage.Data.Content) } diff --git a/go/internal/e2e/testharness/helper.go b/go/internal/e2e/testharness/helper.go index 05947c80..c523b6db 100644 --- a/go/internal/e2e/testharness/helper.go +++ b/go/internal/e2e/testharness/helper.go @@ -1,60 +1,12 @@ package testharness import ( - "context" "errors" "time" copilot "github.com/github/copilot-sdk/go" ) -// GetFinalAssistantMessage waits for and returns the final assistant message from a session turn. -func GetFinalAssistantMessage(ctx context.Context, session *copilot.Session) (*copilot.SessionEvent, error) { - result := make(chan *copilot.SessionEvent, 1) - errCh := make(chan error, 1) - - // Subscribe to future events - var finalAssistantMessage *copilot.SessionEvent - unsubscribe := session.On(func(event copilot.SessionEvent) { - switch event.Type { - case "assistant.message": - finalAssistantMessage = &event - case "session.idle": - if finalAssistantMessage != nil { - result <- finalAssistantMessage - } - case "session.error": - msg := "session error" - if event.Data.Message != nil { - msg = *event.Data.Message - } - errCh <- errors.New(msg) - } - }) - defer unsubscribe() - - // Also check existing messages in case the response already arrived - go func() { - existing, err := getExistingFinalResponse(ctx, session) - if err != nil { - errCh <- err - return - } - if existing != nil { - result <- existing - } - }() - - select { - case msg := <-result: - return msg, nil - case err := <-errCh: - return nil, err - case <-ctx.Done(): - return nil, errors.New("timeout waiting for assistant message") - } -} - // GetNextEventOfType waits for and returns the next event of the specified type from a session. func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEventType, timeout time.Duration) (*copilot.SessionEvent, error) { result := make(chan *copilot.SessionEvent, 1) @@ -89,57 +41,3 @@ func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEvent return nil, errors.New("timeout waiting for event: " + string(eventType)) } } - -func getExistingFinalResponse(ctx context.Context, session *copilot.Session) (*copilot.SessionEvent, error) { - messages, err := session.GetMessages(ctx) - if err != nil { - return nil, err - } - - // Find last user message - finalUserMessageIndex := -1 - for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Type == "user.message" { - finalUserMessageIndex = i - break - } - } - - var currentTurnMessages []copilot.SessionEvent - if finalUserMessageIndex < 0 { - currentTurnMessages = messages - } else { - currentTurnMessages = messages[finalUserMessageIndex:] - } - - // Check for errors - for _, msg := range currentTurnMessages { - if msg.Type == "session.error" { - errMsg := "session error" - if msg.Data.Message != nil { - errMsg = *msg.Data.Message - } - return nil, errors.New(errMsg) - } - } - - // Find session.idle and get last assistant message before it - sessionIdleIndex := -1 - for i, msg := range currentTurnMessages { - if msg.Type == "session.idle" { - sessionIdleIndex = i - break - } - } - - if sessionIdleIndex != -1 { - // Find last assistant.message before session.idle - for i := sessionIdleIndex - 1; i >= 0; i-- { - if currentTurnMessages[i].Type == "assistant.message" { - return ¤tTurnMessages[i], nil - } - } - } - - return nil, nil -} diff --git a/go/internal/e2e/tools_test.go b/go/internal/e2e/tools_test.go index 5af9079c..f9ddcbc3 100644 --- a/go/internal/e2e/tools_test.go +++ b/go/internal/e2e/tools_test.go @@ -30,16 +30,11 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What's the first line of README.md in this directory?"}) + answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What's the first line of README.md in this directory?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "ELIZA") { t.Errorf("Expected answer to contain 'ELIZA', got %v", answer.Data.Content) } @@ -64,16 +59,11 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"}) + answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "HELLO") { t.Errorf("Expected answer to contain 'HELLO', got %v", answer.Data.Content) } @@ -96,18 +86,13 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is my location? If you can't find out, just say 'unknown'.", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - // Check the underlying traffic traffic, err := ctx.GetExchanges() if err != nil { @@ -213,7 +198,7 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "Perform a DB query for the 'cities' table using IDs 12 and 19, sorting ascending. " + "Reply only with lines of the form: [cityname] [population]", }) @@ -221,11 +206,6 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - if answer == nil || answer.Data.Content == nil { t.Fatalf("Expected assistant message with content") } diff --git a/go/session.go b/go/session.go index 5d494710..37cfe52f 100644 --- a/go/session.go +++ b/go/session.go @@ -564,11 +564,7 @@ func (s *Session) Destroy() error { // log.Printf("Failed to abort: %v", err) // } func (s *Session) Abort(ctx context.Context) error { - req := sessionAbortRequest{ - SessionID: s.SessionID, - } - - _, err := s.client.Request("session.abort", req) + _, err := s.client.Request("session.abort", sessionAbortRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to abort session: %w", err) } From cf2ce0732acfec68ab04c956b901fdd246cd5906 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 5 Feb 2026 17:25:45 +0100 Subject: [PATCH 08/15] fix nil access --- go/internal/e2e/mcp_and_agents_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/go/internal/e2e/mcp_and_agents_test.go b/go/internal/e2e/mcp_and_agents_test.go index 244589a1..e19bd0a8 100644 --- a/go/internal/e2e/mcp_and_agents_test.go +++ b/go/internal/e2e/mcp_and_agents_test.go @@ -370,6 +370,10 @@ func TestCombinedConfiguration(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "14") { t.Errorf("Expected message to contain '14', got: %v", message.Data.Content) } From 90c8e9dff28ad159e02597db76fe7a8a0bb92092 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 5 Feb 2026 17:30:48 +0100 Subject: [PATCH 09/15] fix nil access --- go/internal/e2e/compaction_test.go | 3 +++ go/internal/e2e/mcp_and_agents_test.go | 8 ++++++++ go/internal/e2e/session_test.go | 23 +++++++++++++++++++++-- go/internal/e2e/skills_test.go | 8 ++++++++ go/internal/e2e/tools_test.go | 12 ++++++++++++ 5 files changed, 52 insertions(+), 2 deletions(-) diff --git a/go/internal/e2e/compaction_test.go b/go/internal/e2e/compaction_test.go index da9ea240..5fae9393 100644 --- a/go/internal/e2e/compaction_test.go +++ b/go/internal/e2e/compaction_test.go @@ -83,6 +83,9 @@ func TestCompaction(t *testing.T) { if err != nil { t.Fatalf("Failed to send verification message: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer, got nil") + } if answer.Data.Content == nil || !strings.Contains(strings.ToLower(*answer.Data.Content), "dragon") { t.Errorf("Expected answer to contain 'dragon', got %v", answer.Data.Content) } diff --git a/go/internal/e2e/mcp_and_agents_test.go b/go/internal/e2e/mcp_and_agents_test.go index e19bd0a8..40f5ba3f 100644 --- a/go/internal/e2e/mcp_and_agents_test.go +++ b/go/internal/e2e/mcp_and_agents_test.go @@ -44,6 +44,10 @@ func TestMCPServers(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatal("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { t.Errorf("Expected message to contain '4', got: %v", message.Data.Content) } @@ -170,6 +174,10 @@ func TestCustomAgents(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatal("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "10") { t.Errorf("Expected message to contain '10', got: %v", message.Data.Content) } diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 846cf1a4..2a16e18e 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -533,6 +533,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message after abort: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer after abort, got nil") + } + if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "4") { t.Errorf("Expected answer to contain '4', got %v", answer.Data.Content) } @@ -550,7 +554,7 @@ func TestSession(t *testing.T) { var deltaContents []string - unsubscribe := session.On(func(event copilot.SessionEvent) { + session.On(func(event copilot.SessionEvent) { switch event.Type { case "assistant.message_delta": if event.Data.DeltaContent != nil { @@ -561,11 +565,14 @@ func TestSession(t *testing.T) { }) assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) - unsubscribe() if err != nil { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + // Should have received delta events if len(deltaContents) == 0 { t.Error("Expected to receive delta events, got none") @@ -605,6 +612,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { t.Errorf("Expected assistant message to contain '2', got %v", assistantMessage.Data.Content) } @@ -629,6 +640,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + // Should have received multiple events if len(receivedEvents) == 0 { t.Error("Expected to receive events, got none") @@ -685,6 +700,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { t.Errorf("Expected assistant message to contain '2', got %v", assistantMessage.Data.Content) } diff --git a/go/internal/e2e/skills_test.go b/go/internal/e2e/skills_test.go index ed3578ab..f49f6688 100644 --- a/go/internal/e2e/skills_test.go +++ b/go/internal/e2e/skills_test.go @@ -71,6 +71,10 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, skillMarker) { t.Errorf("Expected message to contain skill marker '%s', got: %v", skillMarker, message.Data.Content) } @@ -99,6 +103,10 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content != nil && strings.Contains(*message.Data.Content, skillMarker) { t.Errorf("Expected message to NOT contain skill marker '%s' when disabled, got: %v", skillMarker, *message.Data.Content) } diff --git a/go/internal/e2e/tools_test.go b/go/internal/e2e/tools_test.go index f9ddcbc3..b6af6ef0 100644 --- a/go/internal/e2e/tools_test.go +++ b/go/internal/e2e/tools_test.go @@ -35,6 +35,10 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer, got nil") + } + if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "ELIZA") { t.Errorf("Expected answer to contain 'ELIZA', got %v", answer.Data.Content) } @@ -64,6 +68,10 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer, got nil") + } + if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "HELLO") { t.Errorf("Expected answer to contain 'HELLO', got %v", answer.Data.Content) } @@ -93,6 +101,10 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer, got nil") + } + // Check the underlying traffic traffic, err := ctx.GetExchanges() if err != nil { From 668b22f153959e01c3466078b6c438a9e4056b91 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 5 Feb 2026 17:34:14 +0100 Subject: [PATCH 10/15] fix nil access --- go/internal/e2e/session_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 2a16e18e..6fb05051 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -158,6 +158,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + content := "" if assistantMessage.Data.Content != nil { content = *assistantMessage.Data.Content @@ -294,6 +298,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + content := "" if assistantMessage.Data.Content != nil { content = *assistantMessage.Data.Content @@ -323,6 +331,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer, got nil") + } + if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "2") { t.Errorf("Expected answer to contain '2', got %v", answer.Data.Content) } @@ -370,6 +382,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer, got nil") + } + if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "2") { t.Errorf("Expected answer to contain '2', got %v", answer.Data.Content) } From d2d5dec772266c4a1423bd849a5975916f274468 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 5 Feb 2026 17:38:22 +0100 Subject: [PATCH 11/15] fix nil access --- go/internal/e2e/mcp_and_agents_test.go | 8 ++++++++ go/internal/e2e/permissions_test.go | 4 ++++ go/internal/e2e/skills_test.go | 8 ++++++++ 3 files changed, 20 insertions(+) diff --git a/go/internal/e2e/mcp_and_agents_test.go b/go/internal/e2e/mcp_and_agents_test.go index 40f5ba3f..33ad8479 100644 --- a/go/internal/e2e/mcp_and_agents_test.go +++ b/go/internal/e2e/mcp_and_agents_test.go @@ -96,6 +96,10 @@ func TestMCPServers(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "6") { t.Errorf("Expected message to contain '6', got: %v", message.Data.Content) } @@ -226,6 +230,10 @@ func TestCustomAgents(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "12") { t.Errorf("Expected message to contain '12', got: %v", message.Data.Content) } diff --git a/go/internal/e2e/permissions_test.go b/go/internal/e2e/permissions_test.go index ad8485e3..cde53b1d 100644 --- a/go/internal/e2e/permissions_test.go +++ b/go/internal/e2e/permissions_test.go @@ -165,6 +165,10 @@ func TestPermissions(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatal("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { t.Errorf("Expected message to contain '4', got: %v", message.Data.Content) } diff --git a/go/internal/e2e/skills_test.go b/go/internal/e2e/skills_test.go index f49f6688..52367422 100644 --- a/go/internal/e2e/skills_test.go +++ b/go/internal/e2e/skills_test.go @@ -133,6 +133,10 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message1 == nil { + t.Fatalf("Expected a message, got nil") + } + if message1.Data.Content != nil && strings.Contains(*message1.Data.Content, skillMarker) { t.Errorf("Expected message to NOT contain skill marker before skill was added, got: %v", *message1.Data.Content) } @@ -155,6 +159,10 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message2 == nil { + t.Fatalf("Expected a message, got nil") + } + if message2.Data.Content == nil || !strings.Contains(*message2.Data.Content, skillMarker) { t.Errorf("Expected message to contain skill marker '%s' after resume, got: %v", skillMarker, message2.Data.Content) } From 101df6f78f19c3b61f68d28e3b783de2b5302baf Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 5 Feb 2026 17:44:25 +0100 Subject: [PATCH 12/15] fix race --- go/internal/jsonrpc2/jsonrpc2.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index a226f11f..e44e1231 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -326,20 +326,20 @@ func (c *Client) handleRequest(request *Request) { return } + // Notifications run synchronously, calls run in a goroutine to avoid blocking + if !request.IsCall() { + handler(request.Params) + return + } + go func() { defer func() { if r := recover(); r != nil { - if request.IsCall() { - c.sendErrorResponse(request.ID, -32603, fmt.Sprintf("request handler panic: %v", r), nil) - } + c.sendErrorResponse(request.ID, -32603, fmt.Sprintf("request handler panic: %v", r), nil) } }() result, err := handler(request.Params) - if !request.IsCall() { - // Only send a response if this is a call - return - } if err != nil { c.sendErrorResponse(request.ID, err.Code, err.Message, err.Data) return From d46feaaf9bd859b109ed607395d60be0dc87dc17 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 5 Feb 2026 20:27:57 +0100 Subject: [PATCH 13/15] remove if --- go/client.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/go/client.go b/go/client.go index 35f164b4..c66801a8 100644 --- a/go/client.go +++ b/go/client.go @@ -566,9 +566,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, if config.DisableResume { req.DisableResume = Bool(true) } - if len(config.MCPServers) > 0 { - req.MCPServers = config.MCPServers - } + req.MCPServers = config.MCPServers req.CustomAgents = config.CustomAgents req.SkillDirectories = config.SkillDirectories req.DisabledSkills = config.DisabledSkills From 79596295956a4952e71c7742c0bdc34f2722a1ec Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 6 Feb 2026 13:18:32 +0100 Subject: [PATCH 14/15] remove getFinalAssistantMessage from nodejs test harness --- nodejs/test/e2e/harness/sdkTestHelper.ts | 73 +----------------------- nodejs/test/e2e/session.test.ts | 34 ++++++----- 2 files changed, 20 insertions(+), 87 deletions(-) diff --git a/nodejs/test/e2e/harness/sdkTestHelper.ts b/nodejs/test/e2e/harness/sdkTestHelper.ts index 4e8ff203..f94f606e 100644 --- a/nodejs/test/e2e/harness/sdkTestHelper.ts +++ b/nodejs/test/e2e/harness/sdkTestHelper.ts @@ -2,78 +2,7 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { AssistantMessageEvent, CopilotSession, SessionEvent } from "../../../src"; - -export async function getFinalAssistantMessage( - session: CopilotSession -): Promise { - // We don't know whether the answer has already arrived or not, so race both possibilities - return new Promise(async (resolve, reject) => { - getFutureFinalResponse(session).then(resolve).catch(reject); - getExistingFinalResponse(session) - .then((msg) => { - if (msg) { - resolve(msg); - } - }) - .catch(reject); - }); -} - -function getExistingFinalResponse( - session: CopilotSession -): Promise { - return new Promise(async (resolve, reject) => { - const messages = await session.getMessages(); - const finalUserMessageIndex = messages.findLastIndex((m) => m.type === "user.message"); - const currentTurnMessages = - finalUserMessageIndex < 0 ? messages : messages.slice(finalUserMessageIndex); - - const currentTurnError = currentTurnMessages.find((m) => m.type === "session.error"); - if (currentTurnError) { - const error = new Error(currentTurnError.data.message); - error.stack = currentTurnError.data.stack; - reject(error); - return; - } - - const sessionIdleMessageIndex = currentTurnMessages.findIndex( - (m) => m.type === "session.idle" - ); - if (sessionIdleMessageIndex !== -1) { - const lastAssistantMessage = currentTurnMessages - .slice(0, sessionIdleMessageIndex) - .findLast((m) => m.type === "assistant.message"); - resolve(lastAssistantMessage as AssistantMessageEvent | undefined); - return; - } - - resolve(undefined); - }); -} - -function getFutureFinalResponse(session: CopilotSession): Promise { - return new Promise((resolve, reject) => { - let finalAssistantMessage: AssistantMessageEvent | undefined; - session.on((event) => { - if (event.type === "assistant.message") { - finalAssistantMessage = event; - } else if (event.type === "session.idle") { - if (!finalAssistantMessage) { - reject( - new Error("Received session.idle without a preceding assistant.message") - ); - } else { - resolve(finalAssistantMessage); - } - } else if (event.type === "session.error") { - const error = new Error(event.data.message); - error.stack = event.data.stack; - reject(error); - } - }); - }); -} +import { CopilotSession, SessionEvent } from "../../../src"; export async function retry( message: string, diff --git a/nodejs/test/e2e/session.test.ts b/nodejs/test/e2e/session.test.ts index b3fba475..78f041f7 100644 --- a/nodejs/test/e2e/session.test.ts +++ b/nodejs/test/e2e/session.test.ts @@ -1,8 +1,8 @@ import { describe, expect, it, onTestFinished } from "vitest"; import { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy.js"; -import { CopilotClient } from "../../src/index.js"; +import { CopilotClient, SessionEvent } from "../../src/index.js"; import { CLI_PATH, createSdkTestContext } from "./harness/sdkTestContext.js"; -import { getFinalAssistantMessage, getNextEventOfType } from "./harness/sdkTestHelper.js"; +import { getNextEventOfType } from "./harness/sdkTestHelper.js"; describe("Sessions", async () => { const { copilotClient: client, openAiEndpoint, homeDir, env } = await createSdkTestContext(); @@ -167,7 +167,7 @@ describe("Sessions", async () => { expect(session2.sessionId).toBe(sessionId); // TODO: There's an inconsistency here. When resuming with a new client, we don't see - // the session.idle message in the history, which means we can't use getFinalAssistantMessage. + // the session.idle message in the history, so we can't easily identify when a turn completed. const messages = await session2.getMessages(); expect(messages).toContainEqual(expect.objectContaining({ type: "user.message" })); @@ -328,9 +328,8 @@ describe("Sessions", async () => { expect(session.sessionId).toMatch(/^[a-f0-9-]+$/); // Session should work normally with custom config dir - await session.send({ prompt: "What is 1+1?" }); - const assistantMessage = await getFinalAssistantMessage(session); - expect(assistantMessage.data.content).toContain("2"); + const assistantMessage = await session.sendAndWait({ prompt: "What is 1+1?" }); + expect(assistantMessage?.data.content).toContain("2"); }); }); @@ -348,23 +347,28 @@ describe("Send Blocking Behavior", async () => { it("send returns immediately while events stream in background", async () => { const session = await client.createSession(); - const events: string[] = []; - session.on((event) => { - events.push(event.type); - }); + const events: Array = []; + session.on((event) => events.push(event)); + + // Set up promise to wait for idle BEFORE sending + const idlePromise = getNextEventOfType(session, "session.idle"); // Use a slow command so we can verify send() returns before completion await session.send({ prompt: "Run 'sleep 2 && echo done'" }); // send() should return before turn completes (no session.idle yet) - expect(events).not.toContain("session.idle"); + expect(events.some((e) => e.type === "session.idle")).toBe(false); // Wait for turn to complete - const message = await getFinalAssistantMessage(session); + await idlePromise; - expect(message.data.content).toContain("done"); - expect(events).toContain("session.idle"); - expect(events).toContain("assistant.message"); + // Find the last assistant message from collected events + const assistantMessages = events.filter((e) => e.type === "assistant.message"); + const message = assistantMessages[assistantMessages.length - 1]; + + expect(message.data?.content).toContain("done"); + expect(events.some((e) => e.type === "session.idle")).toBe(true); + expect(events.some((e) => e.type === "assistant.message")).toBe(true); }); it("sendAndWait blocks until session.idle and returns final assistant message", async () => { From 31f5f81637d877bb96beb75bb09835eff9123f80 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 6 Feb 2026 13:28:30 +0100 Subject: [PATCH 15/15] revert test changes --- go/internal/e2e/compaction_test.go | 3 - go/internal/e2e/mcp_and_agents_test.go | 29 +++--- go/internal/e2e/permissions_test.go | 14 ++- go/internal/e2e/session_test.go | 126 +++++++++++++---------- go/internal/e2e/skills_test.go | 16 --- go/internal/e2e/testharness/helper.go | 102 ++++++++++++++++++ go/internal/e2e/tools_test.go | 28 +++-- nodejs/test/e2e/harness/sdkTestHelper.ts | 73 ++++++++++++- nodejs/test/e2e/session.test.ts | 34 +++--- 9 files changed, 303 insertions(+), 122 deletions(-) diff --git a/go/internal/e2e/compaction_test.go b/go/internal/e2e/compaction_test.go index 5fae9393..da9ea240 100644 --- a/go/internal/e2e/compaction_test.go +++ b/go/internal/e2e/compaction_test.go @@ -83,9 +83,6 @@ func TestCompaction(t *testing.T) { if err != nil { t.Fatalf("Failed to send verification message: %v", err) } - if answer == nil { - t.Fatalf("Expected an answer, got nil") - } if answer.Data.Content == nil || !strings.Contains(strings.ToLower(*answer.Data.Content), "dragon") { t.Errorf("Expected answer to contain 'dragon', got %v", answer.Data.Content) } diff --git a/go/internal/e2e/mcp_and_agents_test.go b/go/internal/e2e/mcp_and_agents_test.go index 33ad8479..1d21651b 100644 --- a/go/internal/e2e/mcp_and_agents_test.go +++ b/go/internal/e2e/mcp_and_agents_test.go @@ -37,15 +37,16 @@ func TestMCPServers(t *testing.T) { } // Simple interaction to verify session works - message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ + _, err = session.Send(t.Context(), copilot.MessageOptions{ Prompt: "What is 2+2?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if message == nil { - t.Fatal("Expected a message, got nil") + message, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get final message: %v", err) } if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { @@ -96,10 +97,6 @@ func TestMCPServers(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - if message == nil { - t.Fatalf("Expected a message, got nil") - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "6") { t.Errorf("Expected message to contain '6', got: %v", message.Data.Content) } @@ -171,15 +168,16 @@ func TestCustomAgents(t *testing.T) { } // Simple interaction to verify session works - message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ + _, err = session.Send(t.Context(), copilot.MessageOptions{ Prompt: "What is 5+5?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if message == nil { - t.Fatal("Expected a message, got nil") + message, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get final message: %v", err) } if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "10") { @@ -230,10 +228,6 @@ func TestCustomAgents(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - if message == nil { - t.Fatalf("Expected a message, got nil") - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "12") { t.Errorf("Expected message to contain '12', got: %v", message.Data.Content) } @@ -379,15 +373,16 @@ func TestCombinedConfiguration(t *testing.T) { t.Error("Expected non-empty session ID") } - message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ + _, err = session.Send(t.Context(), copilot.MessageOptions{ Prompt: "What is 7+7?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if message == nil { - t.Fatalf("Expected a message, got nil") + message, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get final message: %v", err) } if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "14") { diff --git a/go/internal/e2e/permissions_test.go b/go/internal/e2e/permissions_test.go index cde53b1d..a891548c 100644 --- a/go/internal/e2e/permissions_test.go +++ b/go/internal/e2e/permissions_test.go @@ -134,13 +134,18 @@ func TestPermissions(t *testing.T) { t.Fatalf("Failed to write test file: %v", err) } - _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ + _, err = session.Send(t.Context(), copilot.MessageOptions{ Prompt: "Edit protected.txt and replace 'protected' with 'hacked'.", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } + _, err = testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get final message: %v", err) + } + // Verify the file was NOT modified content, err := os.ReadFile(testFile) if err != nil { @@ -160,13 +165,14 @@ func TestPermissions(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if message == nil { - t.Fatal("Expected a message, got nil") + message, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get final message: %v", err) } if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 6fb05051..62183286 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -2,7 +2,6 @@ package e2e import ( "regexp" - "slices" "strings" "testing" "time" @@ -69,10 +68,6 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - if assistantMessage == nil { - t.Fatal("Expected assistant message, got nil") - } - if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { t.Errorf("Expected assistant message to contain '2', got %v", assistantMessage.Data.Content) } @@ -82,10 +77,6 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send second message: %v", err) } - if secondMessage == nil { - t.Fatal("Expected second assistant message, got nil") - } - if secondMessage.Data.Content == nil || !strings.Contains(*secondMessage.Data.Content, "4") { t.Errorf("Expected second message to contain '4', got %v", secondMessage.Data.Content) } @@ -153,13 +144,14 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is your full name?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is your full name?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if assistantMessage == nil { - t.Fatal("Expected assistant message, got nil") + assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } content := "" @@ -198,11 +190,16 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } + _, err = testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + // Validate that only the specified tools are present traffic, err := ctx.GetExchanges() if err != nil { @@ -231,11 +228,16 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } + _, err = testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + // Validate that excluded tool is not present but others are traffic, err := ctx.GetExchanges() if err != nil { @@ -293,13 +295,14 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is the secret number for key ALPHA?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is the secret number for key ALPHA?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if assistantMessage == nil { - t.Fatal("Expected assistant message, got nil") + assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } content := "" @@ -326,13 +329,14 @@ func TestSession(t *testing.T) { } sessionID := session1.SessionID - answer, err := session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if answer == nil { - t.Fatalf("Expected an answer, got nil") + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "2") { @@ -349,21 +353,13 @@ func TestSession(t *testing.T) { t.Errorf("Expected resumed session ID to match, got %q vs %q", session2.SessionID, sessionID) } - messages, err := session2.GetMessages(t.Context()) + answer2, err := testharness.GetFinalAssistantMessage(t.Context(), session2) if err != nil { - t.Fatalf("Failed to get messages: %v", err) - } - - answer2Idx := slices.IndexFunc(messages, func(m copilot.SessionEvent) bool { - return m.Type == "assistant.message" - }) - - if answer2Idx == -1 { - t.Fatalf("Expected to find an assistant.message in resumed session messages, got %v", messages) + t.Fatalf("Failed to get assistant message from resumed session: %v", err) } - if messages[answer2Idx].Data.Content == nil || !strings.Contains(*messages[answer2Idx].Data.Content, "2") { - t.Errorf("Expected resumed session answer to contain '2', got %v", messages[answer2Idx].Data.Content) + if answer2.Data.Content == nil || !strings.Contains(*answer2.Data.Content, "2") { + t.Errorf("Expected resumed session answer to contain '2', got %v", answer2.Data.Content) } }) @@ -377,13 +373,14 @@ func TestSession(t *testing.T) { } sessionID := session1.SessionID - answer, err := session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if answer == nil { - t.Fatalf("Expected an answer, got nil") + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "2") { @@ -549,10 +546,6 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message after abort: %v", err) } - if answer == nil { - t.Fatalf("Expected an answer after abort, got nil") - } - if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "4") { t.Errorf("Expected answer to contain '4', got %v", answer.Data.Content) } @@ -569,6 +562,7 @@ func TestSession(t *testing.T) { } var deltaContents []string + done := make(chan bool) session.On(func(event copilot.SessionEvent) { switch event.Type { @@ -576,17 +570,21 @@ func TestSession(t *testing.T) { if event.Data.DeltaContent != nil { deltaContents = append(deltaContents, *event.Data.DeltaContent) } - case "assistant.message": + case "session.idle": + close(done) } }) - assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if assistantMessage == nil { - t.Fatal("Expected assistant message, got nil") + // Wait for completion + select { + case <-done: + case <-time.After(60 * time.Second): + t.Fatal("Timed out waiting for session.idle") } // Should have received delta events @@ -594,6 +592,12 @@ func TestSession(t *testing.T) { t.Error("Expected to receive delta events, got none") } + // Get the final message to compare + assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + // Accumulated deltas should equal the final message accumulated := strings.Join(deltaContents, "") if assistantMessage.Data.Content != nil && accumulated != *assistantMessage.Data.Content { @@ -623,13 +627,14 @@ func TestSession(t *testing.T) { } // Session should still work normally - assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if assistantMessage == nil { - t.Fatal("Expected assistant message, got nil") + assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { @@ -646,18 +651,29 @@ func TestSession(t *testing.T) { } var receivedEvents []copilot.SessionEvent + idle := make(chan bool) + session.On(func(event copilot.SessionEvent) { receivedEvents = append(receivedEvents, event) + if event.Type == "session.idle" { + select { + case idle <- true: + default: + } + } }) // Send a message to trigger events - assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 100+200?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 100+200?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if assistantMessage == nil { - t.Fatal("Expected assistant message, got nil") + // Wait for session to become idle + select { + case <-idle: + case <-time.After(60 * time.Second): + t.Fatal("Timed out waiting for session.idle") } // Should have received multiple events @@ -689,6 +705,11 @@ func TestSession(t *testing.T) { t.Error("Expected to receive session.idle event") } + // Verify the assistant response contains the expected answer + assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "300") { t.Errorf("Expected assistant message to contain '300', got %v", assistantMessage.Data.Content) } @@ -711,13 +732,14 @@ func TestSession(t *testing.T) { } // Session should work normally with custom config dir - assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if assistantMessage == nil { - t.Fatal("Expected assistant message, got nil") + assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { diff --git a/go/internal/e2e/skills_test.go b/go/internal/e2e/skills_test.go index 52367422..ed3578ab 100644 --- a/go/internal/e2e/skills_test.go +++ b/go/internal/e2e/skills_test.go @@ -71,10 +71,6 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - if message == nil { - t.Fatalf("Expected a message, got nil") - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, skillMarker) { t.Errorf("Expected message to contain skill marker '%s', got: %v", skillMarker, message.Data.Content) } @@ -103,10 +99,6 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - if message == nil { - t.Fatalf("Expected a message, got nil") - } - if message.Data.Content != nil && strings.Contains(*message.Data.Content, skillMarker) { t.Errorf("Expected message to NOT contain skill marker '%s' when disabled, got: %v", skillMarker, *message.Data.Content) } @@ -133,10 +125,6 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - if message1 == nil { - t.Fatalf("Expected a message, got nil") - } - if message1.Data.Content != nil && strings.Contains(*message1.Data.Content, skillMarker) { t.Errorf("Expected message to NOT contain skill marker before skill was added, got: %v", *message1.Data.Content) } @@ -159,10 +147,6 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - if message2 == nil { - t.Fatalf("Expected a message, got nil") - } - if message2.Data.Content == nil || !strings.Contains(*message2.Data.Content, skillMarker) { t.Errorf("Expected message to contain skill marker '%s' after resume, got: %v", skillMarker, message2.Data.Content) } diff --git a/go/internal/e2e/testharness/helper.go b/go/internal/e2e/testharness/helper.go index c523b6db..05947c80 100644 --- a/go/internal/e2e/testharness/helper.go +++ b/go/internal/e2e/testharness/helper.go @@ -1,12 +1,60 @@ package testharness import ( + "context" "errors" "time" copilot "github.com/github/copilot-sdk/go" ) +// GetFinalAssistantMessage waits for and returns the final assistant message from a session turn. +func GetFinalAssistantMessage(ctx context.Context, session *copilot.Session) (*copilot.SessionEvent, error) { + result := make(chan *copilot.SessionEvent, 1) + errCh := make(chan error, 1) + + // Subscribe to future events + var finalAssistantMessage *copilot.SessionEvent + unsubscribe := session.On(func(event copilot.SessionEvent) { + switch event.Type { + case "assistant.message": + finalAssistantMessage = &event + case "session.idle": + if finalAssistantMessage != nil { + result <- finalAssistantMessage + } + case "session.error": + msg := "session error" + if event.Data.Message != nil { + msg = *event.Data.Message + } + errCh <- errors.New(msg) + } + }) + defer unsubscribe() + + // Also check existing messages in case the response already arrived + go func() { + existing, err := getExistingFinalResponse(ctx, session) + if err != nil { + errCh <- err + return + } + if existing != nil { + result <- existing + } + }() + + select { + case msg := <-result: + return msg, nil + case err := <-errCh: + return nil, err + case <-ctx.Done(): + return nil, errors.New("timeout waiting for assistant message") + } +} + // GetNextEventOfType waits for and returns the next event of the specified type from a session. func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEventType, timeout time.Duration) (*copilot.SessionEvent, error) { result := make(chan *copilot.SessionEvent, 1) @@ -41,3 +89,57 @@ func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEvent return nil, errors.New("timeout waiting for event: " + string(eventType)) } } + +func getExistingFinalResponse(ctx context.Context, session *copilot.Session) (*copilot.SessionEvent, error) { + messages, err := session.GetMessages(ctx) + if err != nil { + return nil, err + } + + // Find last user message + finalUserMessageIndex := -1 + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Type == "user.message" { + finalUserMessageIndex = i + break + } + } + + var currentTurnMessages []copilot.SessionEvent + if finalUserMessageIndex < 0 { + currentTurnMessages = messages + } else { + currentTurnMessages = messages[finalUserMessageIndex:] + } + + // Check for errors + for _, msg := range currentTurnMessages { + if msg.Type == "session.error" { + errMsg := "session error" + if msg.Data.Message != nil { + errMsg = *msg.Data.Message + } + return nil, errors.New(errMsg) + } + } + + // Find session.idle and get last assistant message before it + sessionIdleIndex := -1 + for i, msg := range currentTurnMessages { + if msg.Type == "session.idle" { + sessionIdleIndex = i + break + } + } + + if sessionIdleIndex != -1 { + // Find last assistant.message before session.idle + for i := sessionIdleIndex - 1; i >= 0; i-- { + if currentTurnMessages[i].Type == "assistant.message" { + return ¤tTurnMessages[i], nil + } + } + } + + return nil, nil +} diff --git a/go/internal/e2e/tools_test.go b/go/internal/e2e/tools_test.go index b6af6ef0..5af9079c 100644 --- a/go/internal/e2e/tools_test.go +++ b/go/internal/e2e/tools_test.go @@ -30,13 +30,14 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What's the first line of README.md in this directory?"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What's the first line of README.md in this directory?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if answer == nil { - t.Fatalf("Expected an answer, got nil") + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "ELIZA") { @@ -63,13 +64,14 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"}) + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if answer == nil { - t.Fatalf("Expected an answer, got nil") + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "HELLO") { @@ -94,15 +96,16 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ + _, err = session.Send(t.Context(), copilot.MessageOptions{ Prompt: "What is my location? If you can't find out, just say 'unknown'.", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if answer == nil { - t.Fatalf("Expected an answer, got nil") + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) } // Check the underlying traffic @@ -210,7 +213,7 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ + _, err = session.Send(t.Context(), copilot.MessageOptions{ Prompt: "Perform a DB query for the 'cities' table using IDs 12 and 19, sorting ascending. " + "Reply only with lines of the form: [cityname] [population]", }) @@ -218,6 +221,11 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + if answer == nil || answer.Data.Content == nil { t.Fatalf("Expected assistant message with content") } diff --git a/nodejs/test/e2e/harness/sdkTestHelper.ts b/nodejs/test/e2e/harness/sdkTestHelper.ts index f94f606e..4e8ff203 100644 --- a/nodejs/test/e2e/harness/sdkTestHelper.ts +++ b/nodejs/test/e2e/harness/sdkTestHelper.ts @@ -2,7 +2,78 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { CopilotSession, SessionEvent } from "../../../src"; +import { AssistantMessageEvent, CopilotSession, SessionEvent } from "../../../src"; + +export async function getFinalAssistantMessage( + session: CopilotSession +): Promise { + // We don't know whether the answer has already arrived or not, so race both possibilities + return new Promise(async (resolve, reject) => { + getFutureFinalResponse(session).then(resolve).catch(reject); + getExistingFinalResponse(session) + .then((msg) => { + if (msg) { + resolve(msg); + } + }) + .catch(reject); + }); +} + +function getExistingFinalResponse( + session: CopilotSession +): Promise { + return new Promise(async (resolve, reject) => { + const messages = await session.getMessages(); + const finalUserMessageIndex = messages.findLastIndex((m) => m.type === "user.message"); + const currentTurnMessages = + finalUserMessageIndex < 0 ? messages : messages.slice(finalUserMessageIndex); + + const currentTurnError = currentTurnMessages.find((m) => m.type === "session.error"); + if (currentTurnError) { + const error = new Error(currentTurnError.data.message); + error.stack = currentTurnError.data.stack; + reject(error); + return; + } + + const sessionIdleMessageIndex = currentTurnMessages.findIndex( + (m) => m.type === "session.idle" + ); + if (sessionIdleMessageIndex !== -1) { + const lastAssistantMessage = currentTurnMessages + .slice(0, sessionIdleMessageIndex) + .findLast((m) => m.type === "assistant.message"); + resolve(lastAssistantMessage as AssistantMessageEvent | undefined); + return; + } + + resolve(undefined); + }); +} + +function getFutureFinalResponse(session: CopilotSession): Promise { + return new Promise((resolve, reject) => { + let finalAssistantMessage: AssistantMessageEvent | undefined; + session.on((event) => { + if (event.type === "assistant.message") { + finalAssistantMessage = event; + } else if (event.type === "session.idle") { + if (!finalAssistantMessage) { + reject( + new Error("Received session.idle without a preceding assistant.message") + ); + } else { + resolve(finalAssistantMessage); + } + } else if (event.type === "session.error") { + const error = new Error(event.data.message); + error.stack = event.data.stack; + reject(error); + } + }); + }); +} export async function retry( message: string, diff --git a/nodejs/test/e2e/session.test.ts b/nodejs/test/e2e/session.test.ts index 78f041f7..b3fba475 100644 --- a/nodejs/test/e2e/session.test.ts +++ b/nodejs/test/e2e/session.test.ts @@ -1,8 +1,8 @@ import { describe, expect, it, onTestFinished } from "vitest"; import { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy.js"; -import { CopilotClient, SessionEvent } from "../../src/index.js"; +import { CopilotClient } from "../../src/index.js"; import { CLI_PATH, createSdkTestContext } from "./harness/sdkTestContext.js"; -import { getNextEventOfType } from "./harness/sdkTestHelper.js"; +import { getFinalAssistantMessage, getNextEventOfType } from "./harness/sdkTestHelper.js"; describe("Sessions", async () => { const { copilotClient: client, openAiEndpoint, homeDir, env } = await createSdkTestContext(); @@ -167,7 +167,7 @@ describe("Sessions", async () => { expect(session2.sessionId).toBe(sessionId); // TODO: There's an inconsistency here. When resuming with a new client, we don't see - // the session.idle message in the history, so we can't easily identify when a turn completed. + // the session.idle message in the history, which means we can't use getFinalAssistantMessage. const messages = await session2.getMessages(); expect(messages).toContainEqual(expect.objectContaining({ type: "user.message" })); @@ -328,8 +328,9 @@ describe("Sessions", async () => { expect(session.sessionId).toMatch(/^[a-f0-9-]+$/); // Session should work normally with custom config dir - const assistantMessage = await session.sendAndWait({ prompt: "What is 1+1?" }); - expect(assistantMessage?.data.content).toContain("2"); + await session.send({ prompt: "What is 1+1?" }); + const assistantMessage = await getFinalAssistantMessage(session); + expect(assistantMessage.data.content).toContain("2"); }); }); @@ -347,28 +348,23 @@ describe("Send Blocking Behavior", async () => { it("send returns immediately while events stream in background", async () => { const session = await client.createSession(); - const events: Array = []; - session.on((event) => events.push(event)); - - // Set up promise to wait for idle BEFORE sending - const idlePromise = getNextEventOfType(session, "session.idle"); + const events: string[] = []; + session.on((event) => { + events.push(event.type); + }); // Use a slow command so we can verify send() returns before completion await session.send({ prompt: "Run 'sleep 2 && echo done'" }); // send() should return before turn completes (no session.idle yet) - expect(events.some((e) => e.type === "session.idle")).toBe(false); + expect(events).not.toContain("session.idle"); // Wait for turn to complete - await idlePromise; - - // Find the last assistant message from collected events - const assistantMessages = events.filter((e) => e.type === "assistant.message"); - const message = assistantMessages[assistantMessages.length - 1]; + const message = await getFinalAssistantMessage(session); - expect(message.data?.content).toContain("done"); - expect(events.some((e) => e.type === "session.idle")).toBe(true); - expect(events.some((e) => e.type === "assistant.message")).toBe(true); + expect(message.data.content).toContain("done"); + expect(events).toContain("session.idle"); + expect(events).toContain("assistant.message"); }); it("sendAndWait blocks until session.idle and returns final assistant message", async () => {