From 3460108f4fdf84aa39fd05e1900a8bbd25f0a94c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 19 Jan 2026 14:19:54 +0000 Subject: [PATCH] Improve unified_test.go with better testify assertions and edge case coverage --- internal/server/unified_test.go | 296 +++++++++++++++++++++++++++----- 1 file changed, 255 insertions(+), 41 deletions(-) diff --git a/internal/server/unified_test.go b/internal/server/unified_test.go index a6231e92..99ea491a 100644 --- a/internal/server/unified_test.go +++ b/internal/server/unified_test.go @@ -112,38 +112,29 @@ func TestUnifiedServer_GetToolsForBackend(t *testing.T) { // Test filtering for github backend githubTools := us.GetToolsForBackend("github") - if len(githubTools) != 2 { - t.Errorf("Expected 2 GitHub tools, got %d", len(githubTools)) - } + require.Len(t, githubTools, 2, "Expected 2 GitHub tools") + // Verify all tools have correct backend ID and prefix stripped for _, tool := range githubTools { - if tool.BackendID != "github" { - t.Errorf("Expected BackendID 'github', got '%s'", tool.BackendID) - } - // Check that prefix is stripped - if tool.Name == "github___issue_read" || tool.Name == "github___repo_list" { - t.Errorf("Tool name '%s' still has prefix", tool.Name) - } - if tool.Name != "issue_read" && tool.Name != "repo_list" { - t.Errorf("Unexpected tool name after prefix strip: '%s'", tool.Name) - } + assert.Equal(t, "github", tool.BackendID, "Tool should belong to github backend") + assert.NotContains(t, tool.Name, "github___", "Tool name should have prefix stripped") } - // Test filtering for fetch backend - fetchTools := us.GetToolsForBackend("fetch") - if len(fetchTools) != 1 { - t.Errorf("Expected 1 fetch tool, got %d", len(fetchTools)) + // Verify specific tool names are present + toolNames := make([]string, len(githubTools)) + for i, tool := range githubTools { + toolNames[i] = tool.Name } + assert.ElementsMatch(t, []string{"issue_read", "repo_list"}, toolNames, "Should have expected GitHub tool names") - if fetchTools[0].Name != "get" { - t.Errorf("Expected tool name 'get', got '%s'", fetchTools[0].Name) - } + // Test filtering for fetch backend + fetchTools := us.GetToolsForBackend("fetch") + require.Len(t, fetchTools, 1, "Expected 1 fetch tool") + assert.Equal(t, "get", fetchTools[0].Name, "Fetch tool should have name 'get'") // Test filtering for non-existent backend noTools := us.GetToolsForBackend("nonexistent") - if len(noTools) != 0 { - t.Errorf("Expected 0 tools for nonexistent backend, got %d", len(noTools)) - } + assert.Empty(t, noTools, "Expected no tools for nonexistent backend") } func TestGetSessionID_FromContext(t *testing.T) { @@ -193,9 +184,7 @@ func TestRequireSession(t *testing.T) { // Test with invalid session (DIFC enabled) ctxWithInvalidSession := context.WithValue(ctx, SessionIDContextKey, "invalid-session") err = us.requireSession(ctxWithInvalidSession) - if err == nil { - t.Error("requireSession() should fail for invalid session when DIFC is enabled") - } + require.Error(t, err, "requireSession() should fail for invalid session when DIFC is enabled") } func TestRequireSession_DifcDisabled(t *testing.T) { @@ -221,13 +210,9 @@ func TestRequireSession_DifcDisabled(t *testing.T) { session, exists := us.sessions[sessionID] us.sessionMu.RUnlock() - if !exists { - t.Error("Session should have been auto-created when DIFC is disabled") - } - - if session.SessionID != sessionID { - t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID) - } + require.True(t, exists, "Session should have been auto-created when DIFC is disabled") + require.NotNil(t, session, "Session should not be nil") + assert.Equal(t, sessionID, session.SessionID, "Session ID should match") } func TestRequireSession_DifcDisabled_Concurrent(t *testing.T) { @@ -257,9 +242,8 @@ func TestRequireSession_DifcDisabled_Concurrent(t *testing.T) { // Collect results for i := 0; i < numGoroutines; i++ { - if err := <-errChan; err != nil { - t.Errorf("requireSession() failed in concurrent access: %v", err) - } + err := <-errChan + require.NoError(t, err, "requireSession() should not fail in concurrent access") } // Verify exactly one session was created @@ -268,13 +252,243 @@ func TestRequireSession_DifcDisabled_Concurrent(t *testing.T) { sessionCount := len(us.sessions) us.sessionMu.RUnlock() - if !exists { - t.Error("Session should have been created") + require.True(t, exists, "Session should have been created") + require.NotNil(t, session, "Session should not be nil") + assert.Equal(t, 1, sessionCount, "Expected exactly 1 session") + assert.Equal(t, sessionID, session.SessionID, "Session ID should match") +} + +func TestGetToolsForBackend_EdgeCases(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "NewUnified() failed") + defer us.Close() + + tests := []struct { + name string + setupTools map[string]*ToolInfo + backendID string + wantCount int + wantNames []string + description string + }{ + { + name: "empty backend", + setupTools: map[string]*ToolInfo{}, + backendID: "empty", + wantCount: 0, + wantNames: []string{}, + description: "should return empty list for backend with no tools", + }, + { + name: "mixed prefix formats", + setupTools: map[string]*ToolInfo{ + "backend___tool1": { + Name: "backend___tool1", + Description: "Tool 1", + BackendID: "backend", + }, + "backend___tool2": { + Name: "backend___tool2", + Description: "Tool 2", + BackendID: "backend", + }, + }, + backendID: "backend", + wantCount: 2, + wantNames: []string{"tool1", "tool2"}, + description: "should correctly strip backend___ prefix", + }, + { + name: "case sensitive backend", + setupTools: map[string]*ToolInfo{ + "GitHub___read": { + Name: "GitHub___read", + Description: "Read", + BackendID: "GitHub", + }, + }, + backendID: "github", + wantCount: 0, + wantNames: []string{}, + description: "backend ID matching should be case-sensitive", + }, } - assert.Equal(t, 1, sessionCount, "exactly 1 session, got %d") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset tools + us.toolsMu.Lock() + us.tools = make(map[string]*ToolInfo) + for k, v := range tt.setupTools { + us.tools[k] = v + } + us.toolsMu.Unlock() + + // Get tools for backend + result := us.GetToolsForBackend(tt.backendID) + + // Verify count + assert.Len(t, result, tt.wantCount, tt.description) + + // Verify tool names if expected + if tt.wantCount > 0 { + actualNames := make([]string, len(result)) + for i, tool := range result { + actualNames[i] = tool.Name + } + assert.ElementsMatch(t, tt.wantNames, actualNames, "Tool names should match expected") + } + }) + } +} + +func TestGetSessionID_EdgeCases(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "NewUnified() failed") + defer us.Close() + + tests := []struct { + name string + ctx context.Context + wantID string + setupFunc func(context.Context) context.Context + description string + }{ + { + name: "nil context value", + ctx: ctx, + wantID: "default", + setupFunc: func(c context.Context) context.Context { return c }, + description: "should return default for context without session ID", + }, + { + name: "empty string session ID", + ctx: ctx, + wantID: "", + setupFunc: func(c context.Context) context.Context { + return context.WithValue(c, SessionIDContextKey, "") + }, + description: "should preserve empty string session ID", + }, + { + name: "whitespace session ID", + ctx: ctx, + wantID: " test ", + setupFunc: func(c context.Context) context.Context { + return context.WithValue(c, SessionIDContextKey, " test ") + }, + description: "should preserve whitespace in session ID", + }, + { + name: "special characters in session ID", + ctx: ctx, + wantID: "session-123_test@example", + setupFunc: func(c context.Context) context.Context { + return context.WithValue(c, SessionIDContextKey, "session-123_test@example") + }, + description: "should handle special characters in session ID", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testCtx := tt.setupFunc(tt.ctx) + result := us.getSessionID(testCtx) + assert.Equal(t, tt.wantID, result, tt.description) + }) + } +} + +func TestRequireSession_EdgeCases(t *testing.T) { + tests := []struct { + name string + enableDIFC bool + sessionID string + preCreate bool + wantErr bool + description string + }{ + { + name: "DIFC enabled with existing session", + enableDIFC: true, + sessionID: "existing", + preCreate: true, + wantErr: false, + description: "should allow access to existing session when DIFC enabled", + }, + { + name: "DIFC enabled without session", + enableDIFC: true, + sessionID: "nonexistent", + preCreate: false, + wantErr: true, + description: "should deny access to nonexistent session when DIFC enabled", + }, + { + name: "DIFC disabled without session", + enableDIFC: false, + sessionID: "autocreate", + preCreate: false, + wantErr: false, + description: "should auto-create session when DIFC disabled", + }, + { + name: "DIFC disabled with existing session", + enableDIFC: false, + sessionID: "existing2", + preCreate: true, + wantErr: false, + description: "should reuse existing session when DIFC disabled", + }, + } - if session.SessionID != sessionID { - t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + EnableDIFC: tt.enableDIFC, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "NewUnified() failed") + defer us.Close() + + // Pre-create session if needed + if tt.preCreate { + us.sessionMu.Lock() + us.sessions[tt.sessionID] = NewSession(tt.sessionID, "token") + us.sessionMu.Unlock() + } + + // Test requireSession + ctxWithSession := context.WithValue(ctx, SessionIDContextKey, tt.sessionID) + err = us.requireSession(ctxWithSession) + + if tt.wantErr { + require.Error(t, err, tt.description) + } else { + require.NoError(t, err, tt.description) + + // Verify session exists after call + us.sessionMu.RLock() + session, exists := us.sessions[tt.sessionID] + us.sessionMu.RUnlock() + + require.True(t, exists, "Session should exist after requireSession") + require.NotNil(t, session, "Session should not be nil") + assert.Equal(t, tt.sessionID, session.SessionID, "Session ID should match") + } + }) } }