diff --git a/.gitignore b/.gitignore index aaadf736..eac1b08a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ *.dll *.so *.dylib +flowguard-go # Test binary, built with `go test -c` *.test diff --git a/go.mod b/go.mod index 4d0a28c6..3d3c55f9 100644 --- a/go.mod +++ b/go.mod @@ -7,3 +7,11 @@ require ( github.com/modelcontextprotocol/go-sdk v1.1.0 github.com/spf13/cobra v1.8.0 ) + +require ( + github.com/google/jsonschema-go v0.3.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.30.0 // indirect +) diff --git a/go.sum b/go.sum index 4ecfb5d0..13b3d49b 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,6 @@ -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= @@ -19,81 +16,9 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= -golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= -golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= -golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= -golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cmd/root.go b/internal/cmd/root.go index f3c19fa6..5262cf20 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -23,6 +23,7 @@ var ( routedMode bool unifiedMode bool envFile string + enableDIFC bool ) var rootCmd = &cobra.Command{ @@ -40,6 +41,7 @@ func init() { rootCmd.Flags().BoolVar(&routedMode, "routed", false, "Run in routed mode (each backend at /mcp/)") rootCmd.Flags().BoolVar(&unifiedMode, "unified", false, "Run in unified mode (all backends at /mcp)") rootCmd.Flags().StringVar(&envFile, "env", "", "Path to .env file to load environment variables") + rootCmd.Flags().BoolVar(&enableDIFC, "enable-difc", false, "Enable DIFC enforcement and session requirement (requires sys___init call before tool access)") } func run(cmd *cobra.Command, args []string) error { @@ -71,6 +73,14 @@ func run(cmd *cobra.Command, args []string) error { log.Printf("Loaded %d MCP server(s)", len(cfg.Servers)) + // Apply command-line flags to config + cfg.EnableDIFC = enableDIFC + if enableDIFC { + log.Println("DIFC enforcement and session requirement enabled") + } else { + log.Println("DIFC enforcement disabled (sessions auto-created for standard MCP client compatibility)") + } + // Determine mode (default to unified if neither flag is set) mode := "unified" if routedMode && unifiedMode { diff --git a/internal/config/config.go b/internal/config/config.go index 0df3bc5f..a56a7795 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,7 +12,8 @@ import ( // Config represents the FlowGuard configuration type Config struct { - Servers map[string]*ServerConfig `toml:"servers"` + Servers map[string]*ServerConfig `toml:"servers"` + EnableDIFC bool // When true, enables DIFC enforcement and requires sys___init call before tool access. Default is false for standard MCP client compatibility. } // ServerConfig represents a single MCP server configuration diff --git a/internal/server/unified.go b/internal/server/unified.go index 28663ebb..04b3888c 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -21,6 +21,14 @@ type Session struct { SessionID string } +// NewSession creates a new Session with the given session ID and optional token +func NewSession(sessionID, token string) *Session { + return &Session{ + Token: token, + SessionID: sessionID, + } +} + // ContextKey for session ID (exported so transport can use it) type ContextKey string @@ -52,6 +60,7 @@ type UnifiedServer struct { agentRegistry *difc.AgentRegistry capabilities *difc.Capabilities evaluator *difc.Evaluator + enableDIFC bool // When true, DIFC enforcement and session requirement are enabled } // NewUnified creates a new unified MCP server @@ -70,6 +79,7 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error) agentRegistry: difc.NewAgentRegistry(), capabilities: difc.NewCapabilities(), evaluator: difc.NewEvaluator(), + enableDIFC: cfg.EnableDIFC, } // Create MCP server @@ -214,10 +224,7 @@ func (us *UnifiedServer) registerSysTools() error { // Create session us.sessionMu.Lock() - us.sessions[sessionID] = &Session{ - Token: token, - SessionID: sessionID, - } + us.sessions[sessionID] = NewSession(sessionID, token) us.sessionMu.Unlock() log.Printf("Initialized session: %s", sessionID) @@ -507,10 +514,32 @@ func (us *UnifiedServer) getSessionID(ctx context.Context) string { } // requireSession checks that a session has been initialized for this request +// When DIFC is disabled (default), automatically creates a session if one doesn't exist func (us *UnifiedServer) requireSession(ctx context.Context) error { sessionID := us.getSessionID(ctx) log.Printf("Checking session for ID: %s", sessionID) + // If DIFC is disabled (default), use double-checked locking to auto-create session + if !us.enableDIFC { + us.sessionMu.RLock() + session := us.sessions[sessionID] + us.sessionMu.RUnlock() + + if session == nil { + // Need to create session - acquire write lock + us.sessionMu.Lock() + // Double-check after acquiring write lock to avoid race condition + if us.sessions[sessionID] == nil { + log.Printf("DIFC disabled: auto-creating session for ID: %s", sessionID) + us.sessions[sessionID] = NewSession(sessionID, "") + log.Printf("Session auto-created for ID: %s", sessionID) + } + us.sessionMu.Unlock() + } + return nil + } + + // DIFC is enabled - require explicit session initialization us.sessionMu.RLock() session := us.sessions[sessionID] us.sessionMu.RUnlock() diff --git a/internal/server/unified_test.go b/internal/server/unified_test.go index 3719b4f0..1ccdfb4b 100644 --- a/internal/server/unified_test.go +++ b/internal/server/unified_test.go @@ -52,10 +52,7 @@ func TestUnifiedServer_SessionManagement(t *testing.T) { token := "test-token" us.sessionMu.Lock() - us.sessions[sessionID] = &Session{ - Token: token, - SessionID: sessionID, - } + us.sessions[sessionID] = NewSession(sessionID, token) us.sessionMu.Unlock() // Test session retrieval @@ -92,7 +89,7 @@ func TestUnifiedServer_GetSessionKeys(t *testing.T) { sessions := []string{"session-1", "session-2", "session-3"} for _, sid := range sessions { us.sessionMu.Lock() - us.sessions[sid] = &Session{SessionID: sid, Token: "token"} + us.sessions[sid] = NewSession(sid, "token") us.sessionMu.Unlock() } @@ -210,7 +207,8 @@ func TestGetSessionID_FromContext(t *testing.T) { func TestRequireSession(t *testing.T) { cfg := &config.Config{ - Servers: map[string]*config.ServerConfig{}, + Servers: map[string]*config.ServerConfig{}, + EnableDIFC: true, // Enable DIFC for this test } ctx := context.Background() @@ -223,7 +221,7 @@ func TestRequireSession(t *testing.T) { // Create a session sessionID := "valid-session" us.sessionMu.Lock() - us.sessions[sessionID] = &Session{SessionID: sessionID, Token: "token"} + us.sessions[sessionID] = NewSession(sessionID, "token") us.sessionMu.Unlock() // Test with valid session @@ -233,10 +231,99 @@ func TestRequireSession(t *testing.T) { t.Errorf("requireSession() failed for valid session: %v", err) } - // Test with invalid session + // Test with invalid session (DIFC enabled) ctxWithInvalidSession := context.WithValue(ctx, SessionIDContextKey, "invalid-session") err = us.requireSession(ctxWithInvalidSession) if err == nil { - t.Error("requireSession() should fail for invalid session") + t.Error("requireSession() should fail for invalid session when DIFC is enabled") + } +} + +func TestRequireSession_DifcDisabled(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + EnableDIFC: false, // DIFC disabled (default) + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Test with non-existent session when DIFC is disabled + // Should auto-create a session + sessionID := "new-session" + ctxWithNewSession := context.WithValue(ctx, SessionIDContextKey, sessionID) + err = us.requireSession(ctxWithNewSession) + if err != nil { + t.Errorf("requireSession() should auto-create session when DIFC is disabled: %v", err) + } + + // Verify session was created + us.sessionMu.RLock() + session, exists := us.sessions[sessionID] + us.sessionMu.RUnlock() + + if !exists { + t.Error("Session should have been auto-created when DIFC is disabled") + } + + if session.SessionID != sessionID { + t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID) + } +} + +func TestRequireSession_DifcDisabled_Concurrent(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + EnableDIFC: false, // DIFC disabled (default) + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Test concurrent session creation to verify no race condition + sessionID := "concurrent-session" + ctxWithSession := context.WithValue(ctx, SessionIDContextKey, sessionID) + + // Run 10 goroutines trying to create the same session simultaneously + const numGoroutines = 10 + errChan := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + errChan <- us.requireSession(ctxWithSession) + }() + } + + // Collect results + for i := 0; i < numGoroutines; i++ { + if err := <-errChan; err != nil { + t.Errorf("requireSession() failed in concurrent access: %v", err) + } + } + + // Verify exactly one session was created + us.sessionMu.RLock() + session, exists := us.sessions[sessionID] + sessionCount := len(us.sessions) + us.sessionMu.RUnlock() + + if !exists { + t.Error("Session should have been created") + } + + if sessionCount != 1 { + t.Errorf("Expected exactly 1 session, got %d", sessionCount) + } + + if session.SessionID != sessionID { + t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID) } }