Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*.dll
*.so
*.dylib
flowguard-go

# Test binary, built with `go test -c`
*.test
Expand Down
8 changes: 8 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@ require (
github.com/modelcontextprotocol/go-sdk v1.1.0
github.com/spf13/cobra v1.8.0
)

require (
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
)
75 changes: 0 additions & 75 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
Expand All @@ -19,81 +16,9 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
10 changes: 10 additions & 0 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
routedMode bool
unifiedMode bool
envFile string
enableDIFC bool
)

var rootCmd = &cobra.Command{
Expand All @@ -40,6 +41,7 @@ func init() {
rootCmd.Flags().BoolVar(&routedMode, "routed", false, "Run in routed mode (each backend at /mcp/<server>)")
rootCmd.Flags().BoolVar(&unifiedMode, "unified", false, "Run in unified mode (all backends at /mcp)")
rootCmd.Flags().StringVar(&envFile, "env", "", "Path to .env file to load environment variables")
rootCmd.Flags().BoolVar(&enableDIFC, "enable-difc", false, "Enable DIFC enforcement and session requirement (requires sys___init call before tool access)")
}

func run(cmd *cobra.Command, args []string) error {
Expand Down Expand Up @@ -71,6 +73,14 @@ func run(cmd *cobra.Command, args []string) error {

log.Printf("Loaded %d MCP server(s)", len(cfg.Servers))

// Apply command-line flags to config
cfg.EnableDIFC = enableDIFC
if enableDIFC {
log.Println("DIFC enforcement and session requirement enabled")
} else {
log.Println("DIFC enforcement disabled (sessions auto-created for standard MCP client compatibility)")
}

// Determine mode (default to unified if neither flag is set)
mode := "unified"
if routedMode && unifiedMode {
Expand Down
3 changes: 2 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (

// Config represents the FlowGuard configuration
type Config struct {
Servers map[string]*ServerConfig `toml:"servers"`
Servers map[string]*ServerConfig `toml:"servers"`
EnableDIFC bool // When true, enables DIFC enforcement and requires sys___init call before tool access. Default is false for standard MCP client compatibility.
}

// ServerConfig represents a single MCP server configuration
Expand Down
37 changes: 33 additions & 4 deletions internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ type Session struct {
SessionID string
}

// NewSession creates a new Session with the given session ID and optional token
func NewSession(sessionID, token string) *Session {
return &Session{
Token: token,
SessionID: sessionID,
}
}

// ContextKey for session ID (exported so transport can use it)
type ContextKey string

Expand Down Expand Up @@ -52,6 +60,7 @@ type UnifiedServer struct {
agentRegistry *difc.AgentRegistry
capabilities *difc.Capabilities
evaluator *difc.Evaluator
enableDIFC bool // When true, DIFC enforcement and session requirement are enabled
}

// NewUnified creates a new unified MCP server
Expand All @@ -70,6 +79,7 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error)
agentRegistry: difc.NewAgentRegistry(),
capabilities: difc.NewCapabilities(),
evaluator: difc.NewEvaluator(),
enableDIFC: cfg.EnableDIFC,
}

// Create MCP server
Expand Down Expand Up @@ -214,10 +224,7 @@ func (us *UnifiedServer) registerSysTools() error {

// Create session
us.sessionMu.Lock()
us.sessions[sessionID] = &Session{
Token: token,
SessionID: sessionID,
}
us.sessions[sessionID] = NewSession(sessionID, token)
us.sessionMu.Unlock()

log.Printf("Initialized session: %s", sessionID)
Expand Down Expand Up @@ -507,10 +514,32 @@ func (us *UnifiedServer) getSessionID(ctx context.Context) string {
}

// requireSession checks that a session has been initialized for this request
// When DIFC is disabled (default), automatically creates a session if one doesn't exist
func (us *UnifiedServer) requireSession(ctx context.Context) error {
sessionID := us.getSessionID(ctx)
log.Printf("Checking session for ID: %s", sessionID)

// If DIFC is disabled (default), use double-checked locking to auto-create session
if !us.enableDIFC {
us.sessionMu.RLock()
session := us.sessions[sessionID]
us.sessionMu.RUnlock()

if session == nil {
// Need to create session - acquire write lock
us.sessionMu.Lock()
// Double-check after acquiring write lock to avoid race condition
if us.sessions[sessionID] == nil {
log.Printf("DIFC disabled: auto-creating session for ID: %s", sessionID)
us.sessions[sessionID] = NewSession(sessionID, "")
log.Printf("Session auto-created for ID: %s", sessionID)
}
us.sessionMu.Unlock()
}
return nil
}

// DIFC is enabled - require explicit session initialization
us.sessionMu.RLock()
session := us.sessions[sessionID]
us.sessionMu.RUnlock()
Expand Down
105 changes: 96 additions & 9 deletions internal/server/unified_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ func TestUnifiedServer_SessionManagement(t *testing.T) {
token := "test-token"

us.sessionMu.Lock()
us.sessions[sessionID] = &Session{
Token: token,
SessionID: sessionID,
}
us.sessions[sessionID] = NewSession(sessionID, token)
us.sessionMu.Unlock()

// Test session retrieval
Expand Down Expand Up @@ -92,7 +89,7 @@ func TestUnifiedServer_GetSessionKeys(t *testing.T) {
sessions := []string{"session-1", "session-2", "session-3"}
for _, sid := range sessions {
us.sessionMu.Lock()
us.sessions[sid] = &Session{SessionID: sid, Token: "token"}
us.sessions[sid] = NewSession(sid, "token")
us.sessionMu.Unlock()
}

Expand Down Expand Up @@ -210,7 +207,8 @@ func TestGetSessionID_FromContext(t *testing.T) {

func TestRequireSession(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{},
Servers: map[string]*config.ServerConfig{},
EnableDIFC: true, // Enable DIFC for this test
}

ctx := context.Background()
Expand All @@ -223,7 +221,7 @@ func TestRequireSession(t *testing.T) {
// Create a session
sessionID := "valid-session"
us.sessionMu.Lock()
us.sessions[sessionID] = &Session{SessionID: sessionID, Token: "token"}
us.sessions[sessionID] = NewSession(sessionID, "token")
us.sessionMu.Unlock()

// Test with valid session
Expand All @@ -233,10 +231,99 @@ func TestRequireSession(t *testing.T) {
t.Errorf("requireSession() failed for valid session: %v", err)
}

// Test with invalid session
// Test with invalid session (DIFC enabled)
ctxWithInvalidSession := context.WithValue(ctx, SessionIDContextKey, "invalid-session")
err = us.requireSession(ctxWithInvalidSession)
if err == nil {
t.Error("requireSession() should fail for invalid session")
t.Error("requireSession() should fail for invalid session when DIFC is enabled")
}
}

func TestRequireSession_DifcDisabled(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{},
EnableDIFC: false, // DIFC disabled (default)
}

ctx := context.Background()
us, err := NewUnified(ctx, cfg)
if err != nil {
t.Fatalf("NewUnified() failed: %v", err)
}
defer us.Close()

// Test with non-existent session when DIFC is disabled
// Should auto-create a session
sessionID := "new-session"
ctxWithNewSession := context.WithValue(ctx, SessionIDContextKey, sessionID)
err = us.requireSession(ctxWithNewSession)
if err != nil {
t.Errorf("requireSession() should auto-create session when DIFC is disabled: %v", err)
}

// Verify session was created
us.sessionMu.RLock()
session, exists := us.sessions[sessionID]
us.sessionMu.RUnlock()

if !exists {
t.Error("Session should have been auto-created when DIFC is disabled")
}

if session.SessionID != sessionID {
t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID)
}
}

func TestRequireSession_DifcDisabled_Concurrent(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{},
EnableDIFC: false, // DIFC disabled (default)
}

ctx := context.Background()
us, err := NewUnified(ctx, cfg)
if err != nil {
t.Fatalf("NewUnified() failed: %v", err)
}
defer us.Close()

// Test concurrent session creation to verify no race condition
sessionID := "concurrent-session"
ctxWithSession := context.WithValue(ctx, SessionIDContextKey, sessionID)

// Run 10 goroutines trying to create the same session simultaneously
const numGoroutines = 10
errChan := make(chan error, numGoroutines)

for i := 0; i < numGoroutines; i++ {
go func() {
errChan <- us.requireSession(ctxWithSession)
}()
}

// Collect results
for i := 0; i < numGoroutines; i++ {
if err := <-errChan; err != nil {
t.Errorf("requireSession() failed in concurrent access: %v", err)
}
}

// Verify exactly one session was created
us.sessionMu.RLock()
session, exists := us.sessions[sessionID]
sessionCount := len(us.sessions)
us.sessionMu.RUnlock()

if !exists {
t.Error("Session should have been created")
}

if sessionCount != 1 {
t.Errorf("Expected exactly 1 session, got %d", sessionCount)
}

if session.SessionID != sessionID {
t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID)
}
}