Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 255 additions & 41 deletions internal/server/unified_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,38 +112,29 @@ func TestUnifiedServer_GetToolsForBackend(t *testing.T) {

// Test filtering for github backend
githubTools := us.GetToolsForBackend("github")
if len(githubTools) != 2 {
t.Errorf("Expected 2 GitHub tools, got %d", len(githubTools))
}
require.Len(t, githubTools, 2, "Expected 2 GitHub tools")

// Verify all tools have correct backend ID and prefix stripped
for _, tool := range githubTools {
if tool.BackendID != "github" {
t.Errorf("Expected BackendID 'github', got '%s'", tool.BackendID)
}
// Check that prefix is stripped
if tool.Name == "github___issue_read" || tool.Name == "github___repo_list" {
t.Errorf("Tool name '%s' still has prefix", tool.Name)
}
if tool.Name != "issue_read" && tool.Name != "repo_list" {
t.Errorf("Unexpected tool name after prefix strip: '%s'", tool.Name)
}
assert.Equal(t, "github", tool.BackendID, "Tool should belong to github backend")
assert.NotContains(t, tool.Name, "github___", "Tool name should have prefix stripped")
}

// Test filtering for fetch backend
fetchTools := us.GetToolsForBackend("fetch")
if len(fetchTools) != 1 {
t.Errorf("Expected 1 fetch tool, got %d", len(fetchTools))
// Verify specific tool names are present
toolNames := make([]string, len(githubTools))
for i, tool := range githubTools {
toolNames[i] = tool.Name
}
assert.ElementsMatch(t, []string{"issue_read", "repo_list"}, toolNames, "Should have expected GitHub tool names")

if fetchTools[0].Name != "get" {
t.Errorf("Expected tool name 'get', got '%s'", fetchTools[0].Name)
}
// Test filtering for fetch backend
fetchTools := us.GetToolsForBackend("fetch")
require.Len(t, fetchTools, 1, "Expected 1 fetch tool")
assert.Equal(t, "get", fetchTools[0].Name, "Fetch tool should have name 'get'")

// Test filtering for non-existent backend
noTools := us.GetToolsForBackend("nonexistent")
if len(noTools) != 0 {
t.Errorf("Expected 0 tools for nonexistent backend, got %d", len(noTools))
}
assert.Empty(t, noTools, "Expected no tools for nonexistent backend")
}

func TestGetSessionID_FromContext(t *testing.T) {
Expand Down Expand Up @@ -193,9 +184,7 @@ func TestRequireSession(t *testing.T) {
// Test with invalid session (DIFC enabled)
ctxWithInvalidSession := context.WithValue(ctx, SessionIDContextKey, "invalid-session")
err = us.requireSession(ctxWithInvalidSession)
if err == nil {
t.Error("requireSession() should fail for invalid session when DIFC is enabled")
}
require.Error(t, err, "requireSession() should fail for invalid session when DIFC is enabled")
}

func TestRequireSession_DifcDisabled(t *testing.T) {
Expand All @@ -221,13 +210,9 @@ func TestRequireSession_DifcDisabled(t *testing.T) {
session, exists := us.sessions[sessionID]
us.sessionMu.RUnlock()

if !exists {
t.Error("Session should have been auto-created when DIFC is disabled")
}

if session.SessionID != sessionID {
t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID)
}
require.True(t, exists, "Session should have been auto-created when DIFC is disabled")
require.NotNil(t, session, "Session should not be nil")
assert.Equal(t, sessionID, session.SessionID, "Session ID should match")
}

func TestRequireSession_DifcDisabled_Concurrent(t *testing.T) {
Expand Down Expand Up @@ -257,9 +242,8 @@ func TestRequireSession_DifcDisabled_Concurrent(t *testing.T) {

// Collect results
for i := 0; i < numGoroutines; i++ {
if err := <-errChan; err != nil {
t.Errorf("requireSession() failed in concurrent access: %v", err)
}
err := <-errChan
require.NoError(t, err, "requireSession() should not fail in concurrent access")
}

// Verify exactly one session was created
Expand All @@ -268,13 +252,243 @@ func TestRequireSession_DifcDisabled_Concurrent(t *testing.T) {
sessionCount := len(us.sessions)
us.sessionMu.RUnlock()

if !exists {
t.Error("Session should have been created")
require.True(t, exists, "Session should have been created")
require.NotNil(t, session, "Session should not be nil")
assert.Equal(t, 1, sessionCount, "Expected exactly 1 session")
assert.Equal(t, sessionID, session.SessionID, "Session ID should match")
}

func TestGetToolsForBackend_EdgeCases(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{},
}

ctx := context.Background()
us, err := NewUnified(ctx, cfg)
require.NoError(t, err, "NewUnified() failed")
defer us.Close()

tests := []struct {
name string
setupTools map[string]*ToolInfo
backendID string
wantCount int
wantNames []string
description string
}{
{
name: "empty backend",
setupTools: map[string]*ToolInfo{},
backendID: "empty",
wantCount: 0,
wantNames: []string{},
description: "should return empty list for backend with no tools",
},
{
name: "mixed prefix formats",
setupTools: map[string]*ToolInfo{
"backend___tool1": {
Name: "backend___tool1",
Description: "Tool 1",
BackendID: "backend",
},
"backend___tool2": {
Name: "backend___tool2",
Description: "Tool 2",
BackendID: "backend",
},
},
backendID: "backend",
wantCount: 2,
wantNames: []string{"tool1", "tool2"},
description: "should correctly strip backend___ prefix",
},
{
name: "case sensitive backend",
setupTools: map[string]*ToolInfo{
"GitHub___read": {
Name: "GitHub___read",
Description: "Read",
BackendID: "GitHub",
},
},
backendID: "github",
wantCount: 0,
wantNames: []string{},
description: "backend ID matching should be case-sensitive",
},
}

assert.Equal(t, 1, sessionCount, "exactly 1 session, got %d")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset tools
us.toolsMu.Lock()
us.tools = make(map[string]*ToolInfo)
for k, v := range tt.setupTools {
us.tools[k] = v
}
us.toolsMu.Unlock()

// Get tools for backend
result := us.GetToolsForBackend(tt.backendID)

// Verify count
assert.Len(t, result, tt.wantCount, tt.description)

// Verify tool names if expected
if tt.wantCount > 0 {
actualNames := make([]string, len(result))
for i, tool := range result {
actualNames[i] = tool.Name
}
assert.ElementsMatch(t, tt.wantNames, actualNames, "Tool names should match expected")
}
})
}
}

func TestGetSessionID_EdgeCases(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{},
}

ctx := context.Background()
us, err := NewUnified(ctx, cfg)
require.NoError(t, err, "NewUnified() failed")
defer us.Close()

tests := []struct {
name string
ctx context.Context
wantID string
setupFunc func(context.Context) context.Context
description string
}{
{
name: "nil context value",
ctx: ctx,
wantID: "default",
setupFunc: func(c context.Context) context.Context { return c },
description: "should return default for context without session ID",
},
{
name: "empty string session ID",
ctx: ctx,
wantID: "",
setupFunc: func(c context.Context) context.Context {
return context.WithValue(c, SessionIDContextKey, "")
},
description: "should preserve empty string session ID",
},
{
name: "whitespace session ID",
ctx: ctx,
wantID: " test ",
setupFunc: func(c context.Context) context.Context {
return context.WithValue(c, SessionIDContextKey, " test ")
},
description: "should preserve whitespace in session ID",
},
{
name: "special characters in session ID",
ctx: ctx,
wantID: "session-123_test@example",
setupFunc: func(c context.Context) context.Context {
return context.WithValue(c, SessionIDContextKey, "session-123_test@example")
},
description: "should handle special characters in session ID",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testCtx := tt.setupFunc(tt.ctx)
result := us.getSessionID(testCtx)
assert.Equal(t, tt.wantID, result, tt.description)
})
}
}

func TestRequireSession_EdgeCases(t *testing.T) {
tests := []struct {
name string
enableDIFC bool
sessionID string
preCreate bool
wantErr bool
description string
}{
{
name: "DIFC enabled with existing session",
enableDIFC: true,
sessionID: "existing",
preCreate: true,
wantErr: false,
description: "should allow access to existing session when DIFC enabled",
},
{
name: "DIFC enabled without session",
enableDIFC: true,
sessionID: "nonexistent",
preCreate: false,
wantErr: true,
description: "should deny access to nonexistent session when DIFC enabled",
},
{
name: "DIFC disabled without session",
enableDIFC: false,
sessionID: "autocreate",
preCreate: false,
wantErr: false,
description: "should auto-create session when DIFC disabled",
},
{
name: "DIFC disabled with existing session",
enableDIFC: false,
sessionID: "existing2",
preCreate: true,
wantErr: false,
description: "should reuse existing session when DIFC disabled",
},
}

if session.SessionID != sessionID {
t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{},
EnableDIFC: tt.enableDIFC,
}

ctx := context.Background()
us, err := NewUnified(ctx, cfg)
require.NoError(t, err, "NewUnified() failed")
defer us.Close()

// Pre-create session if needed
if tt.preCreate {
us.sessionMu.Lock()
us.sessions[tt.sessionID] = NewSession(tt.sessionID, "token")
us.sessionMu.Unlock()
}

// Test requireSession
ctxWithSession := context.WithValue(ctx, SessionIDContextKey, tt.sessionID)
err = us.requireSession(ctxWithSession)

if tt.wantErr {
require.Error(t, err, tt.description)
} else {
require.NoError(t, err, tt.description)

// Verify session exists after call
us.sessionMu.RLock()
session, exists := us.sessions[tt.sessionID]
us.sessionMu.RUnlock()

require.True(t, exists, "Session should exist after requireSession")
require.NotNil(t, session, "Session should not be nil")
assert.Equal(t, tt.sessionID, session.SessionID, "Session ID should match")
}
})
}
}