From ea1711ecc05e3c02c7c3c2708e321f154d762d1e Mon Sep 17 00:00:00 2001 From: axfor Date: Wed, 7 Jan 2026 14:58:18 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9Afix=20distributed=20lock?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .claude/settings.local.json | 176 +-- api/etcd/kv.go | 41 +- api/etcd/lease.go | 21 +- api/etcd/lease_manager.go | 21 + internal/kvstore/store.go | 6 +- internal/kvstore/types.go | 34 + internal/memory/kvstore.go | 26 + internal/memory/store.go | 115 +- internal/rocksdb/kvstore.go | 152 ++- pkg/concurrency/election.go | 115 +- pkg/concurrency/mutex.go | 200 ++- pkg/concurrency/mutex_test.go | 1757 +++++++++++++++++++++++++ test/distributed_lock_memory_test.go | 1757 +++++++++++++++++++++++++ test/distributed_lock_rocksdb_test.go | 967 ++++++++++++++ 14 files changed, 5100 insertions(+), 288 deletions(-) create mode 100644 pkg/concurrency/mutex_test.go create mode 100644 test/distributed_lock_memory_test.go create mode 100644 test/distributed_lock_rocksdb_test.go diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 7aabe00..da676bd 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -1,181 +1,7 @@ { "permissions": { "allow": [ - "Bash(find:*)", - "Bash(CGO_LDFLAGS=\"-Wl,-U,_SecTrustCopyCertificateChain\" go build:*)", - "Bash(tree:*)", - "Bash(grep -E \"(Starting|became leader|elected leader)\" echo \"\" echo \"=== 节点2 ===\" tail -5 /tmp/node2.log)", - "Bash(grep -E \"(Starting|became leader|became follower|elected leader)\" echo \"\" echo \"=== 节点3 ===\" tail -5 /tmp/node3.log)", - "Bash(mkdir:*)", - "Bash(mv:*)", - "Bash(go version:*)", - "Bash(go mod tidy:*)", - "Bash(go mod download:*)", - "Bash(cat:*)", - "Bash(make help:*)", - "Bash(make clean:*)", - "Bash(make build:*)", - "Bash(make cluster-memory:*)", - "Bash(make status:*)", - "Bash(make stop-cluster:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm 
-lzstd -llz4 -lz -lsnappy -lbz2\" go build:*)", - "Bash(./metaStore --help)", - "Bash(go test:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2\" go test:*)", - "Bash(make test:*)", - "Bash(git add:*)", - "Bash(git commit -m \"$(cat <<''EOF''\nrefactor: migrate to golang-standards/project-layout structure\n\nThis is a major refactoring that reorganizes the codebase from a flat\nstructure in the root directory to a standard Go project layout following\ngolang-standards/project-layout conventions.\n\nKey changes:\n- Move application entry point to cmd/metastore/main.go\n- Reorganize core code into internal/ packages by functional layer:\n - internal/store: Storage interface and implementations (memory/rocksdb)\n - internal/raft: Raft consensus layer (node implementations)\n - internal/http: HTTP API layer\n - internal/storage: Low-level storage engine (RocksDB wrapper)\n - internal/kvstore: KV store interface definitions\n - internal/memory: Memory-based KV store implementation\n - internal/rocksdb: RocksDB-based implementations\n- Remove flat structure files from root directory\n- Update Makefile to use new cmd/metastore path\n- Add PROJECT_LAYOUT.md documenting the new structure\n- Add comprehensive testing documentation in docs/TESTING.md\n- Migrate all tests to corresponding internal packages\n\nBenefits:\n- Better code organization with clear functional separation\n- Follows Go community best practices\n- Better encapsulation using internal packages\n- Improved modularity with interface-based interactions\n- Easier to test with clear package boundaries\n- Enhanced project professionalism and maintainability\n\nAll functionality and APIs remain unchanged - this is purely a structural\nrefactoring with no behavioral changes.\n\n🤖 Generated with [Claude Code](https://claude.com/claude-code)\n\nCo-Authored-By: Claude \nEOF\n)\")", - "Bash(git commit:*)", - "Bash(git restore:*)", - "Bash(export 
CGO_ENABLED=1:*)", - "Bash(export CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2\":*)", - "Bash(timeout 60 go test:*)", - "Bash(timeout 120 go test:*)", - "Bash(timeout 30 go test:*)", - "Bash(git rm:*)", - "Bash(go get:*)", - "Bash(go build:*)", - "Bash(CGO_ENABLED=0 go build:*)", - "Bash(pkill:*)", - "Bash(CGO_ENABLED=0 go test:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" go test:*)", - "Bash(timeout 120 bash:*)", - "Bash(/dev/null)", - "Bash(go clean:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" go test ./test -run \"TestEtcdRocksDBSingleNodeOperations/PutAndGet\" -v -timeout=30s)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" go run:*)", - "Bash(go doc:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" go build:*)", - "Bash(tee:*)", - "Read(//Users/bast/go/pkg/mod/go.etcd.io/etcd/client/v3@v3.6.4/**)", - "Read(//Users/bast/go/pkg/mod/go.etcd.io/etcd/**)", - "Read(//private/tmp/**)", - "Bash(make test-unit:*)", - "Bash(for file in /Users/bast/code/MetaStore/PROJECT_LAYOUT.md )", - "Bash(/Users/bast/code/MetaStore/docs/PROJECT_SUMMARY.md )", - "Bash(/Users/bast/code/MetaStore/docs/ROCKSDB_3NODE_TEST_REPORT.md )", - "Bash(/Users/bast/code/MetaStore/docs/phase2-design.md )", - "Bash(/Users/bast/code/MetaStore/docs/DIRECTORY_STRUCTURE_CHANGE_REPORT.md )", - "Bash(/Users/bast/code/MetaStore/docs/ROCKSDB_BUILD_MACOS.md )", - "Bash(/Users/bast/code/MetaStore/docs/ROCKSDB_TEST_GUIDE.md )", - "Bash(/Users/bast/code/MetaStore/docs/PHASE2_COMPLETION_REPORT.md )", - "Bash(/Users/bast/code/MetaStore/docs/ROCKSDB_BUILD_MACOS_EN.md )", - 
"Bash(/Users/bast/code/MetaStore/docs/ROCKSDB_TEST_REPORT.md )", - "Bash(/Users/bast/code/MetaStore/docs/TEST_COVERAGE_REPORT.md )", - "Bash(/Users/bast/code/MetaStore/test_phase2.sh )", - "Bash(/Users/bast/code/MetaStore/test_phase2_cluster.sh:*)", - "Bash(do:*)", - "Bash(done)", - "Bash(./metastore:*)", - "Bash(git checkout:*)", - "Bash(for:*)", - "Bash(/Users/bast/code/MetaStore/test/etcd_memory_integration_test.go )", - "Bash(/Users/bast/code/MetaStore/test/cross_protocol_integration_test.go )", - "Bash(/Users/bast/code/MetaStore/test/http_api_memory_integration_test.go )", - "Bash(/Users/bast/code/MetaStore/test/http_api_memory_consistency_test.go )", - "Bash(/Users/bast/code/MetaStore/internal/memory/kvstore_etcd_raft.go)", - "Bash(/Users/bast/code/MetaStore/test/etcd_compatibility_test.go )", - "Bash(/Users/bast/code/MetaStore/test/etcd_rocksdb_integration_test.go )", - "Bash(/Users/bast/code/MetaStore/test/http_api_rocksdb_consistency_test.go )", - "Bash(/Users/bast/code/MetaStore/internal/rocksdb/kvstore_etcd_raft.go)", - "Bash(git mv:*)", - "Bash(awk:*)", - "Bash(/Users/bast/code/MetaStore/pkg/etcdapi/server.go )", - "Bash(/Users/bast/code/MetaStore/pkg/etcdapi/kv.go )", - "Bash(/Users/bast/code/MetaStore/pkg/etcdapi/lease_manager.go )", - "Bash(/Users/bast/code/MetaStore/pkg/etcdapi/maintenance.go )", - "Bash(/Users/bast/code/MetaStore/pkg/etcdapi/watch_manager.go )", - "Bash(/Users/bast/code/MetaStore/pkg/etcdapi/lease.go )", - "Bash(/Users/bast/code/MetaStore/pkg/etcdapi/watch.go )", - "Bash(/Users/bast/code/MetaStore/pkg/etcdapi/errors.go)", - "Bash(GOPROXY=https://proxy.golang.org,direct go mod tidy:*)", - "Bash(pkg/etcdapi/kv.go)", - "Bash([ -f \"$file\" ])", - "Bash(\"$file\")", - "Bash(pkg/etcdapi/lease_manager.go)", - "Bash(sed:*)", - "Bash(internal/rocksdb/kvstore.go)", - "Bash(internal/memory/store.go)", - "Bash(internal/memory/watch.go)", - "Bash(chmod:*)", - "Bash(xargs kill:*)", - "Bash(lsof:*)", - "Bash(make test-maintenance:*)", - 
"Bash(timeout 600 make test:*)", - "Bash(while ps aux)", - "Bash(sample:*)", - "Bash(if ! ps aux)", - "Bash(then echo \"测试已完成!\")", - "Bash(exit 0)", - "Bash(fi)", - "Bash(kill:*)", - "Bash(echo:*)", - "Bash(ps:*)", - "Bash(while ps:*)", - "Bash(while true)", - "Bash(if ps -p 58114)", - "Bash(then)", - "Bash(else)", - "Bash(break)", - "Bash(xargs:*)", - "Bash(while read f)", - "Bash(sort:*)", - "Bash(./test_cluster.sh:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" timeout 60 go test:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" timeout 120 go test:*)", - "Bash(go install:*)", - "Bash(brew install:*)", - "Bash(protoc:*)", - "Bash(export PATH=$PATH:$HOME/go/bin)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" timeout 600 go test:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" timeout 180 go test:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCeptificateChain\" go test:*)", - "Bash(make test-perf-memory:*)", - "Bash(make test-perf-rocksdb:*)", - "Bash(protoc-gen-go:*)", - "Bash(timeout 5 ./metastore:*)", - "Bash(timeout:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" go list:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" timeout 300 go test:*)", - "Bash(CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 
-Wl,-U,_SecTrustCopyCertificateChain\" timeout 90 go test:*)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=0 go build:*)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" go build:*)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=0 go test:*)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=1 CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\" go test:*)", - "Bash(bash -n:*)", - "Bash(./test_mysql_single.sh:*)", - "Bash(mysql:*)", - "Bash(./test_mysql_quick.sh:*)", - "Bash(bash test_mysql_cluster.sh:*)", - "Bash(bash test_mysql_cluster_simple.sh:*)", - "Bash(bash:*)", - "Bash(export CGO_LDFLAGS=\"-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain\":*)", - "Bash(go run:*)", - "Bash(brew search:*)", - "Bash(done echo 'Installation completed' ls -la /usr/local/opt/)", - "Bash(/usr/local/Cellar/mysql-client@8.0/*/bin/mysql)", - "Bash(/usr/local/Cellar/mysql-client@8.0/8.0.44/bin/mysql:*)", - "Bash(curl:*)", - "Bash(tools/etcdctl:*)", - "Bash(unset http_proxy unset https_proxy curl -X PUT http://127.0.0.1:9991/testkey3 -d \"testvalue3\")", - "Bash(env -u http_proxy -u https_proxy curl -X PUT http://127.0.0.1:9991/testkey3 -d \"testvalue3\")", - "Bash(env -u http_proxy -u https_proxy curl http://127.0.0.1:9991/testkey3)", - "Bash(env -u http_proxy -u https_proxy curl -v -X PUT http://127.0.0.1:9991/testkey -d \"testvalue\")", - "Bash(env -u http_proxy -u https_proxy -u all_proxy curl -v -X PUT http://127.0.0.1:9991/testkey -d \"testvalue\")", - "Bash(env -u http_proxy -u https_proxy -u all_proxy curl -v http://127.0.0.1:9991/testkey)", - "Bash(./test_tidb_parser.sh:*)", - "Bash(./test_sql_parser.sh:*)", - "Bash(then echo \"✓ 新格式 (包含 etcd/http/mysql 配置)\" elif grep -q \"mysql:\" \"$f\")", - "Bash(then if grep -q \"enable:\" 
\"$f\")", - "Bash(then echo \"✗ 旧格式 (包含 enable 字段)\" else echo \"? 部分更新\" fi else echo \"- 无 MySQL 配置\" fi done)", - "Bash(grep:*)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=1 'CGO_LDFLAGS=-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain' go test -p 1 -run ^Test[^B] ./test/...)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=1 'CGO_LDFLAGS=-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain' go test -v -p 1 ./test/... -run TestEtcdMemorySingleNode -timeout=120s)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=1 'CGO_LDFLAGS=-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain' go test -p 1 ./internal/... ./pkg/... ./api/...)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=1 'CGO_LDFLAGS=-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain' go test -v ./test/... -run TestCrossProtocolMemoryDataInteroperability -timeout=120s)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=1 'CGO_LDFLAGS=-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain' go test -v ./test/... -run TestEtcdMemoryClusterBasicConsistency -timeout=120s)", - "Bash(GOEXPERIMENT=greenteagc CGO_ENABLED=1 'CGO_LDFLAGS=-lrocksdb -lpthread -lstdc++ -ldl -lm -lzstd -llz4 -lz -lsnappy -lbz2 -Wl,-U,_SecTrustCopyCertificateChain' go test -v ./test/... 
-run TestEtcdRocksDBSingleNodeOperations -timeout=120s)", - "Bash(make:*)" + ], "deny": [], "ask": [] diff --git a/api/etcd/kv.go b/api/etcd/kv.go index 9f14419..7f7ba7c 100644 --- a/api/etcd/kv.go +++ b/api/etcd/kv.go @@ -32,11 +32,46 @@ type KVServer struct { func (s *KVServer) Range(ctx context.Context, req *pb.RangeRequest) (*pb.RangeResponse, error) { key := string(req.Key) rangeEnd := string(req.RangeEnd) - limit := req.Limit - revision := req.Revision + + // 构建 RangeOptions + opts := kvstore.RangeOptions{ + Limit: req.Limit, + Revision: req.Revision, + MaxCreateRevision: req.MaxCreateRevision, + MinCreateRevision: req.MinCreateRevision, + MaxModRevision: req.MaxModRevision, + MinModRevision: req.MinModRevision, + CountOnly: req.CountOnly, + KeysOnly: req.KeysOnly, + } + + // 转换排序选项 + switch req.SortOrder { + case pb.RangeRequest_ASCEND: + opts.SortOrder = kvstore.SortAscend + case pb.RangeRequest_DESCEND: + opts.SortOrder = kvstore.SortDescend + default: + opts.SortOrder = kvstore.SortNone + } + + switch req.SortTarget { + case pb.RangeRequest_KEY: + opts.SortTarget = kvstore.SortByKey + case pb.RangeRequest_VERSION: + opts.SortTarget = kvstore.SortByVersion + case pb.RangeRequest_CREATE: + opts.SortTarget = kvstore.SortByCreate + case pb.RangeRequest_MOD: + opts.SortTarget = kvstore.SortByMod + case pb.RangeRequest_VALUE: + opts.SortTarget = kvstore.SortByValue + default: + opts.SortTarget = kvstore.SortByKey + } // 从 store 查询 - resp, err := s.server.store.Range(ctx, key, rangeEnd, limit, revision) + resp, err := s.server.store.RangeWithOptions(ctx, key, rangeEnd, opts) if err != nil { return nil, toGRPCError(err) } diff --git a/api/etcd/lease.go b/api/etcd/lease.go index 7805b91..d55b7f8 100644 --- a/api/etcd/lease.go +++ b/api/etcd/lease.go @@ -16,6 +16,7 @@ package etcd import ( "context" + "errors" pb "go.etcd.io/etcd/api/v3/etcdserverpb" ) @@ -31,9 +32,9 @@ func (s *LeaseServer) LeaseGrant(ctx context.Context, req *pb.LeaseGrantRequest) ttl := req.TTL 
id := req.ID - // 如果没有指定 ID,自动生成 + // 如果没有指定 ID,自动生成唯一 ID if id == 0 { - id = s.server.store.CurrentRevision() + 1 + id = s.server.leaseMgr.GenerateLeaseID() } // 创建 lease @@ -98,13 +99,23 @@ func (s *LeaseServer) LeaseTimeToLive(ctx context.Context, req *pb.LeaseTimeToLi // 获取 lease 信息 lease, err := s.server.leaseMgr.TimeToLive(id) if err != nil { + // 对于不存在的 Lease,etcd 返回 TTL=-1 而不是错误 + // 这符合 etcd 客户端的期望行为 + if errors.Is(err, ErrLeaseNotFound) { + return &pb.LeaseTimeToLiveResponse{ + Header: s.server.getResponseHeader(), + ID: id, + TTL: -1, + GrantedTTL: 0, + }, nil + } return nil, toGRPCError(err) } resp := &pb.LeaseTimeToLiveResponse{ - Header: s.server.getResponseHeader(), - ID: lease.ID, - TTL: lease.Remaining(), + Header: s.server.getResponseHeader(), + ID: lease.ID, + TTL: lease.Remaining(), GrantedTTL: lease.TTL, } diff --git a/api/etcd/lease_manager.go b/api/etcd/lease_manager.go index ee54a91..3aeaa9a 100644 --- a/api/etcd/lease_manager.go +++ b/api/etcd/lease_manager.go @@ -38,11 +38,21 @@ type LeaseManager struct { checkInterval time.Duration // Lease 过期检查间隔 defaultTTL time.Duration // 默认 TTL maxLeaseCount int // 最大 Lease 数量限制(0 表示无限制) + + // Lease ID 生成器 (集群安全) + // ID 格式: 高16位为节点ID,低48位为计数器 + nodeID uint64 + leaseIDCounter atomic.Int64 } // NewLeaseManager 创建新的 Lease 管理器 // 参数: store, leaseConfig (可选), limitsConfig (可选) func NewLeaseManager(store kvstore.Store, leaseCfg *config.LeaseConfig, limitsCfg *config.LimitsConfig) *LeaseManager { + return NewLeaseManagerWithNodeID(store, leaseCfg, limitsCfg, 1) +} + +// NewLeaseManagerWithNodeID 创建新的 Lease 管理器(带节点 ID,用于集群) +func NewLeaseManagerWithNodeID(store kvstore.Store, leaseCfg *config.LeaseConfig, limitsCfg *config.LimitsConfig, nodeID uint64) *LeaseManager { // 使用配置或默认值 if leaseCfg == nil { defaultCfg := config.DefaultConfig(1, 1, ":2379") @@ -61,6 +71,7 @@ func NewLeaseManager(store kvstore.Store, leaseCfg *config.LeaseConfig, limitsCf checkInterval: leaseCfg.CheckInterval, defaultTTL: 
leaseCfg.DefaultTTL, maxLeaseCount: maxLeases, + nodeID: nodeID, } } @@ -77,6 +88,16 @@ func (lm *LeaseManager) Stop() { close(lm.stopCh) } +// GenerateLeaseID 生成集群唯一的 Lease ID +// ID 格式: 高16位为节点ID,低48位为计数器 +// 这样每个节点可以独立生成不冲突的 ID +func (lm *LeaseManager) GenerateLeaseID() int64 { + counter := lm.leaseIDCounter.Add(1) + // 高16位: nodeID, 低48位: counter + // 支持最多 65535 个节点,每个节点 2^48 个 Lease + return int64(lm.nodeID<<48) | (counter & 0x0000FFFFFFFFFFFF) +} + // Grant 创建一个新的 lease func (lm *LeaseManager) Grant(id int64, ttl int64) (*kvstore.Lease, error) { if lm.stopped.Load() { diff --git a/internal/kvstore/store.go b/internal/kvstore/store.go index 592488c..21293c4 100644 --- a/internal/kvstore/store.go +++ b/internal/kvstore/store.go @@ -26,7 +26,7 @@ type Store interface { // etcd-compatible methods with Context support - // Range executes a range query + // Range executes a range query (legacy, for backward compatibility) // ctx: context for timeout and cancellation // key: start key // rangeEnd: end key (empty for single key, "\x00" for all keys) @@ -34,6 +34,10 @@ type Store interface { // revision: query data at specific revision (0 for latest) Range(ctx context.Context, key, rangeEnd string, limit int64, revision int64) (*RangeResponse, error) + // RangeWithOptions executes a range query with full options support + // This is the preferred method for complex queries (sorting, filtering by revision, etc.) 
+ RangeWithOptions(ctx context.Context, key, rangeEnd string, opts RangeOptions) (*RangeResponse, error) + // PutWithLease stores a key-value pair with optional lease // Returns new revision and previous value (if any) PutWithLease(ctx context.Context, key, value string, leaseID int64) (revision int64, prevKv *KeyValue, err error) diff --git a/internal/kvstore/types.go b/internal/kvstore/types.go index a859caf..55fe23b 100644 --- a/internal/kvstore/types.go +++ b/internal/kvstore/types.go @@ -139,6 +139,40 @@ type OpResponse struct { DeleteResp *DeleteResponse } +// RangeOptions Range 操作选项 +type RangeOptions struct { + Limit int64 // 返回键数量限制 + Revision int64 // 查询指定 revision 的数据 + SortOrder SortOrder // 排序顺序 + SortTarget SortTarget // 排序目标 + MaxCreateRevision int64 // 最大创建 revision 过滤 + MinCreateRevision int64 // 最小创建 revision 过滤 + MaxModRevision int64 // 最大修改 revision 过滤 + MinModRevision int64 // 最小修改 revision 过滤 + CountOnly bool // 只返回数量 + KeysOnly bool // 只返回键 +} + +// SortOrder 排序顺序 +type SortOrder int + +const ( + SortNone SortOrder = 0 + SortAscend SortOrder = 1 + SortDescend SortOrder = 2 +) + +// SortTarget 排序目标 +type SortTarget int + +const ( + SortByKey SortTarget = 0 + SortByVersion SortTarget = 1 + SortByCreate SortTarget = 2 + SortByMod SortTarget = 3 + SortByValue SortTarget = 4 +) + // RangeResponse Range 操作响应 type RangeResponse struct { Kvs []*KeyValue diff --git a/internal/memory/kvstore.go b/internal/memory/kvstore.go index 44509c3..1fa096b 100644 --- a/internal/memory/kvstore.go +++ b/internal/memory/kvstore.go @@ -743,3 +743,29 @@ func (m *Memory) Range(ctx context.Context, key, rangeEnd string, limit int64, r // 未启用 Lease Read 或 RaftNode 不可用,直接读取 return m.MemoryEtcd.Range(ctx, key, rangeEnd, limit, revision) } + +// RangeWithOptions 执行范围查询(支持完整选项,带 Lease Read 优化) +func (m *Memory) RangeWithOptions(ctx context.Context, key, rangeEnd string, opts kvstore.RangeOptions) (*kvstore.RangeResponse, error) { + // 如果启用了 Lease Read 且 RaftNode 可用 + if 
m.raftNode != nil { + leaseManager := m.raftNode.LeaseManager() + readIndexManager := m.raftNode.ReadIndexManager() + + if leaseManager != nil && readIndexManager != nil { + // Fast Path: Leader 有有效租约 + if leaseManager.IsLeader() && leaseManager.HasValidLease() { + // 记录快速路径读取 + readIndexManager.RecordFastPathRead() + + // 直接读取本地状态(已由租约保证线性一致性) + return m.MemoryEtcd.RangeWithOptions(ctx, key, rangeEnd, opts) + } + + // 当前简化实现:直接读取(在完整实现前保持向后兼容) + return m.MemoryEtcd.RangeWithOptions(ctx, key, rangeEnd, opts) + } + } + + // 未启用 Lease Read 或 RaftNode 不可用,直接读取 + return m.MemoryEtcd.RangeWithOptions(ctx, key, rangeEnd, opts) +} diff --git a/internal/memory/store.go b/internal/memory/store.go index ad23ff0..5b03265 100644 --- a/internal/memory/store.go +++ b/internal/memory/store.go @@ -72,6 +72,15 @@ func (m *MemoryEtcd) CurrentRevision() int64 { // Range 执行范围查询 func (m *MemoryEtcd) Range(ctx context.Context, key, rangeEnd string, limit int64, revision int64) (*kvstore.RangeResponse, error) { + // 转换为 RangeOptions 调用 + return m.RangeWithOptions(ctx, key, rangeEnd, kvstore.RangeOptions{ + Limit: limit, + Revision: revision, + }) +} + +// RangeWithOptions 执行范围查询(支持完整选项) +func (m *MemoryEtcd) RangeWithOptions(ctx context.Context, key, rangeEnd string, opts kvstore.RangeOptions) (*kvstore.RangeResponse, error) { var kvs []*kvstore.KeyValue // 如果 rangeEnd 为空,查询单个键 @@ -81,25 +90,119 @@ func (m *MemoryEtcd) Range(ctx context.Context, key, rangeEnd string, limit int6 } } else { // 范围查询 - ShardedMap 内部会处理锁和排序 - kvs = m.kvData.Range(key, rangeEnd, limit) + // 先获取全部,后面再应用过滤和排序 + kvs = m.kvData.Range(key, rangeEnd, 0) } - // 应用 limit(Range 已经处理了,这里是为了计算 more 和 count) - more := false + // Apply CreateRevision filter + // Note: MaxCreateRevision filtering should be applied when explicitly set + // When the etcd client uses WithMaxCreateRev(myRev-1) and myRev=1, MaxCreateRevision=0 + // In this case, all keys should be filtered out (all keys have CreateRevision >= 1) + if 
opts.MaxCreateRevision > 0 || opts.MinCreateRevision > 0 { + filtered := make([]*kvstore.KeyValue, 0, len(kvs)) + for _, kv := range kvs { + if opts.MaxCreateRevision > 0 && kv.CreateRevision > opts.MaxCreateRevision { + continue + } + if opts.MinCreateRevision > 0 && kv.CreateRevision < opts.MinCreateRevision { + continue + } + filtered = append(filtered, kv) + } + kvs = filtered + } + + // 应用 ModRevision 过滤 + if opts.MaxModRevision > 0 || opts.MinModRevision > 0 { + filtered := make([]*kvstore.KeyValue, 0, len(kvs)) + for _, kv := range kvs { + if opts.MaxModRevision > 0 && kv.ModRevision > opts.MaxModRevision { + continue + } + if opts.MinModRevision > 0 && kv.ModRevision < opts.MinModRevision { + continue + } + filtered = append(filtered, kv) + } + kvs = filtered + } + + // 应用排序 + if opts.SortOrder != kvstore.SortNone && len(kvs) > 1 { + m.sortKvs(kvs, opts.SortTarget, opts.SortOrder) + } + + // 计算 count(在应用 limit 之前) count := int64(len(kvs)) - if limit > 0 && int64(len(kvs)) > limit { - kvs = kvs[:limit] + + // 如果只需要计数 + if opts.CountOnly { + return &kvstore.RangeResponse{ + Kvs: nil, + More: false, + Count: count, + Revision: m.revision.Load(), + }, nil + } + + // 应用 limit + more := false + if opts.Limit > 0 && int64(len(kvs)) > opts.Limit { + kvs = kvs[:opts.Limit] more = true } + // 如果只需要 keys + if opts.KeysOnly { + for _, kv := range kvs { + kv.Value = nil + } + } + return &kvstore.RangeResponse{ Kvs: kvs, More: more, Count: count, - Revision: m.revision.Load(), // ✅ atomic 操作,无需加锁 + Revision: m.revision.Load(), }, nil } +// sortKvs 对 kvs 进行排序 +func (m *MemoryEtcd) sortKvs(kvs []*kvstore.KeyValue, target kvstore.SortTarget, order kvstore.SortOrder) { + // 使用标准库排序 + less := func(i, j int) bool { + var cmp int + switch target { + case kvstore.SortByKey: + cmp = bytes.Compare(kvs[i].Key, kvs[j].Key) + case kvstore.SortByCreate: + cmp = int(kvs[i].CreateRevision - kvs[j].CreateRevision) + case kvstore.SortByMod: + cmp = int(kvs[i].ModRevision - 
kvs[j].ModRevision) + case kvstore.SortByVersion: + cmp = int(kvs[i].Version - kvs[j].Version) + case kvstore.SortByValue: + cmp = bytes.Compare(kvs[i].Value, kvs[j].Value) + default: + cmp = bytes.Compare(kvs[i].Key, kvs[j].Key) + } + if order == kvstore.SortDescend { + return cmp > 0 + } + return cmp < 0 + } + + // 简单的冒泡排序(对于分布式锁通常只有少量 key) + n := len(kvs) + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if !less(j, j+1) { + kvs[j], kvs[j+1] = kvs[j+1], kvs[j] + } + } + } +} + // PutWithLease 存储键值对,可选关联 lease func (m *MemoryEtcd) PutWithLease(ctx context.Context, key, value string, leaseID int64) (int64, *kvstore.KeyValue, error) { // 验证 lease(如果指定) diff --git a/internal/rocksdb/kvstore.go b/internal/rocksdb/kvstore.go index df2b785..f30c082 100644 --- a/internal/rocksdb/kvstore.go +++ b/internal/rocksdb/kvstore.go @@ -363,12 +363,15 @@ func (r *RocksDB) applyOperationsBatch(ops []*RaftOperation) { } case "LEASE_REVOKE": - if err := r.prepareLeaseRevokeBatch(batch, op.LeaseID); err != nil { + events, err := r.prepareLeaseRevokeBatch(batch, op.LeaseID) + if err != nil { log.Error("Failed to prepare LEASE_REVOKE in batch", zap.Error(err), zap.Int64("leaseID", op.LeaseID), zap.String("component", "storage-rocksdb")) + continue } + watchEvents = append(watchEvents, events...) 
case "TXN": // Transactions need special handling - apply individually for now @@ -483,6 +486,15 @@ func (r *RocksDB) incrementRevision() (int64, error) { // Range performs range query func (r *RocksDB) Range(ctx context.Context, key, rangeEnd string, limit int64, revision int64) (*kvstore.RangeResponse, error) { + // 转换为 RangeWithOptions 调用 + return r.RangeWithOptions(ctx, key, rangeEnd, kvstore.RangeOptions{ + Limit: limit, + Revision: revision, + }) +} + +// RangeWithOptions performs range query with full options support +func (r *RocksDB) RangeWithOptions(ctx context.Context, key, rangeEnd string, opts kvstore.RangeOptions) (*kvstore.RangeResponse, error) { // Lease Read 优化: 检查是否可以使用快速路径 if r.raftNode != nil { leaseManager := r.raftNode.LeaseManager() @@ -503,8 +515,8 @@ func (r *RocksDB) Range(ctx context.Context, key, rangeEnd string, limit int64, // Pre-allocate slice with estimated capacity estimatedCap := 100 - if limit > 0 && limit < 100 { - estimatedCap = int(limit) + if opts.Limit > 0 && opts.Limit < 100 { + estimatedCap = int(opts.Limit) } kvs := make([]*kvstore.KeyValue, 0, estimatedCap) @@ -532,11 +544,6 @@ func (r *RocksDB) Range(ctx context.Context, key, rangeEnd string, limit int64, if err == nil && kv != nil { kvs = append(kvs, kv) } - - // Early exit if limit reached - if limit > 0 && int64(len(kvs)) >= limit { - break - } } if rangeEnd != "\x00" && k >= rangeEnd { @@ -545,21 +552,75 @@ func (r *RocksDB) Range(ctx context.Context, key, rangeEnd string, limit int64, it.Next() } + } + + // Apply CreateRevision filter + if opts.MaxCreateRevision > 0 || opts.MinCreateRevision > 0 { + filtered := make([]*kvstore.KeyValue, 0, len(kvs)) + for _, kv := range kvs { + if opts.MaxCreateRevision > 0 && kv.CreateRevision > opts.MaxCreateRevision { + continue + } + if opts.MinCreateRevision > 0 && kv.CreateRevision < opts.MinCreateRevision { + continue + } + filtered = append(filtered, kv) + } + kvs = filtered + } + + // Apply ModRevision filter + if 
opts.MaxModRevision > 0 || opts.MinModRevision > 0 { + filtered := make([]*kvstore.KeyValue, 0, len(kvs)) + for _, kv := range kvs { + if opts.MaxModRevision > 0 && kv.ModRevision > opts.MaxModRevision { + continue + } + if opts.MinModRevision > 0 && kv.ModRevision < opts.MinModRevision { + continue + } + filtered = append(filtered, kv) + } + kvs = filtered + } - // Sort by key + // Apply sorting + if opts.SortOrder != kvstore.SortNone && len(kvs) > 1 { + r.sortKvs(kvs, opts.SortTarget, opts.SortOrder) + } else if len(kvs) > 1 { + // Default sort by key sort.Slice(kvs, func(i, j int) bool { return string(kvs[i].Key) < string(kvs[j].Key) }) } + // Calculate count before applying limit + count := int64(len(kvs)) + + // CountOnly: only return count + if opts.CountOnly { + return &kvstore.RangeResponse{ + Kvs: nil, + More: false, + Count: count, + Revision: r.CurrentRevision(), + }, nil + } + // Apply limit more := false - count := int64(len(kvs)) - if limit > 0 && int64(len(kvs)) > limit { - kvs = kvs[:limit] + if opts.Limit > 0 && int64(len(kvs)) > opts.Limit { + kvs = kvs[:opts.Limit] more = true } + // KeysOnly: clear values + if opts.KeysOnly { + for _, kv := range kvs { + kv.Value = nil + } + } + return &kvstore.RangeResponse{ Kvs: kvs, More: more, @@ -568,6 +629,33 @@ func (r *RocksDB) Range(ctx context.Context, key, rangeEnd string, limit int64, }, nil } +// sortKvs sorts key-value pairs according to target and order +func (r *RocksDB) sortKvs(kvs []*kvstore.KeyValue, target kvstore.SortTarget, order kvstore.SortOrder) { + less := func(i, j int) bool { + var cmp int + switch target { + case kvstore.SortByKey: + cmp = bytes.Compare(kvs[i].Key, kvs[j].Key) + case kvstore.SortByCreate: + cmp = int(kvs[i].CreateRevision - kvs[j].CreateRevision) + case kvstore.SortByMod: + cmp = int(kvs[i].ModRevision - kvs[j].ModRevision) + case kvstore.SortByVersion: + cmp = int(kvs[i].Version - kvs[j].Version) + case kvstore.SortByValue: + cmp = bytes.Compare(kvs[i].Value, 
kvs[j].Value) + default: + cmp = bytes.Compare(kvs[i].Key, kvs[j].Key) + } + if order == kvstore.SortDescend { + return cmp > 0 + } + return cmp < 0 + } + + sort.Slice(kvs, less) +} + // PutWithLease stores key-value with optional lease func (r *RocksDB) PutWithLease(ctx context.Context, key, value string, leaseID int64) (int64, *kvstore.KeyValue, error) { // Check prevKv before submitting to Raft @@ -809,29 +897,59 @@ func (r *RocksDB) prepareLeaseGrantBatch(batch *grocksdb.WriteBatch, leaseID, tt } // prepareLeaseRevokeBatch prepares a LEASE_REVOKE operation to be added to a WriteBatch -func (r *RocksDB) prepareLeaseRevokeBatch(batch *grocksdb.WriteBatch, leaseID int64) error { +// Returns watch events to be emitted after batch write succeeds +func (r *RocksDB) prepareLeaseRevokeBatch(batch *grocksdb.WriteBatch, leaseID int64) ([]kvstore.WatchEvent, error) { // Get the lease to find associated keys lease, err := r.getLease(leaseID) if err != nil { - return fmt.Errorf("failed to get lease %d: %v", leaseID, err) + return nil, fmt.Errorf("failed to get lease %d: %v", leaseID, err) } if lease == nil { // Lease doesn't exist, nothing to revoke - return nil + return nil, nil } - // Delete all keys associated with this lease + var events []kvstore.WatchEvent + + // Delete all keys associated with this lease and prepare watch events for key := range lease.Keys { + // Get old value first for watch event + prevKv, _ := r.getKeyValue(key) + dbKey := []byte(kvPrefix + key) batch.Delete(dbKey) + + // Prepare watch event if key existed + if prevKv != nil { + // Get revision for watch event + newRevision, err := r.incrementRevision() + if err != nil { + return nil, err + } + + deletedKv := &kvstore.KeyValue{ + Key: prevKv.Key, + Value: nil, + CreateRevision: prevKv.CreateRevision, + ModRevision: newRevision, + Version: 0, + Lease: 0, + } + events = append(events, kvstore.WatchEvent{ + Type: kvstore.EventTypeDelete, + Kv: deletedKv, + PrevKv: prevKv, + Revision: newRevision, + 
}) + } } // Delete the lease itself leaseKey := []byte(fmt.Sprintf("%s%d", leasePrefix, leaseID)) batch.Delete(leaseKey) - return nil + return events, nil } // putUnlocked applies put operation (called after Raft commit) diff --git a/pkg/concurrency/election.go b/pkg/concurrency/election.go index 743fb32..40a01b1 100644 --- a/pkg/concurrency/election.go +++ b/pkg/concurrency/election.go @@ -22,6 +22,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" pb "go.etcd.io/etcd/api/v3/etcdserverpb" + mvccpb "go.etcd.io/etcd/api/v3/mvccpb" ) var ( @@ -101,50 +102,112 @@ func (e *Election) Campaign(ctx context.Context, val string) error { } // waitLeader 等待成为 Leader(所有更早的 key 被删除) +// Automatically retries if Watch is canceled or network errors occur func (e *Election) waitLeader(ctx context.Context, myKey string, myRev int64) error { client := e.s.client - // 获取所有前缀匹配的 key - getOpts := append(clientv3.WithFirstCreate(), clientv3.WithMaxCreateRev(myRev-1)) for { - // 获取所有 CreateRevision < myRev 的 key - resp, err := client.Get(ctx, e.pfx, getOpts...) 
+ // 检查会话是否还有效 + select { + case <-e.s.Done(): + return errors.New("session expired") + case <-ctx.Done(): + return ctx.Err() + default: + } + + // 获取所有前缀匹配的 key,按 CreateRevision 排序 + resp, err := client.Get(ctx, e.pfx, + clientv3.WithPrefix(), + clientv3.WithSort(clientv3.SortByCreateRevision, clientv3.SortAscend)) if err != nil { return err } + // 手动过滤出 CreateRevision < myRev 的 key + var earlierKeys []*mvccpb.KeyValue + for _, kv := range resp.Kvs { + if kv.CreateRevision < myRev { + earlierKeys = append(earlierKeys, kv) + } + } + // 没有更早的 key,成为 Leader - if len(resp.Kvs) == 0 { + if len(earlierKeys) == 0 { return nil } - // 找到最早的 key - lastKey := string(resp.Kvs[0].Key) + // 找到最早的 key (first one after sorting and filtering) + lastKey := string(earlierKeys[0].Key) - // Watch 该 key,等待其删除 - wch := client.Watch(ctx, lastKey, clientv3.WithRev(myRev)) - for wresp := range wch { - if wresp.Canceled { - return errors.New("watch canceled") + // Watch for deletion with automatic retry on cancellation + err = e.watchKeyDeletion(ctx, lastKey, resp.Header.Revision) + if err != nil { + // If watch was canceled or had network error, retry the loop + // The loop will recheck if the key still exists + if isElectionWatchCanceledOrNetworkError(err) { + continue } - for _, ev := range wresp.Events { - if ev.Type == clientv3.EventTypeDelete { - // key 被删除,继续检查 - goto RETRY - } + return err + } + + // Key was deleted, loop will recheck for more earlier keys + } +} + +// watchKeyDeletion watches a specific key for deletion +// Returns nil when key is deleted, error otherwise +func (e *Election) watchKeyDeletion(ctx context.Context, key string, revision int64) error { + client := e.s.client + + // Create a cancellable context for watch + watchCtx, watchCancel := context.WithCancel(ctx) + defer watchCancel() + + // Watch for deletion starting from the current revision + wch := client.Watch(watchCtx, key, clientv3.WithRev(revision)) + + for wresp := range wch { + if wresp.Canceled { + // 
Watch was canceled - could be network error or context cancellation + if wresp.Err() != nil { + return wresp.Err() } + return errors.New("watch canceled") } - - RETRY: - // 检查会话是否还有效 - select { - case <-e.s.Done(): - return errors.New("session expired") - case <-ctx.Done(): - return ctx.Err() - default: + for _, ev := range wresp.Events { + if ev.Type == clientv3.EventTypeDelete { + // Key deleted successfully + return nil + } } } + + // Watch channel closed without delete event, check context + select { + case <-ctx.Done(): + return ctx.Err() + case <-e.s.Done(): + return errors.New("session expired") + default: + // Watch channel closed unexpectedly, return error to trigger retry + return errors.New("watch channel closed") + } +} + +// isElectionWatchCanceledOrNetworkError checks if error is due to watch cancellation or network issue +func isElectionWatchCanceledOrNetworkError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + // Check for common watch cancellation and network error patterns + return errStr == "watch canceled" || + errStr == "watch channel closed" || + errStr == "context canceled" || + errStr == "rpc error" || + errStr == "connection" || + errStr == "EOF" } // Resign 主动放弃 Leader 身份 diff --git a/pkg/concurrency/mutex.go b/pkg/concurrency/mutex.go index cc5b601..4df5094 100644 --- a/pkg/concurrency/mutex.go +++ b/pkg/concurrency/mutex.go @@ -45,28 +45,30 @@ func NewMutex(s *Session, pfx string) *Mutex { } } -// Lock 获取锁,阻塞直到成功 +// Lock acquires the lock, blocking until successful func (m *Mutex) Lock(ctx context.Context) error { s := m.s client := m.s.client m.mu.Lock() - // 如果已经持有锁,直接返回 + // Already holding the lock if m.myKey != "" { m.mu.Unlock() return nil } m.mu.Unlock() - // 使用 Lease 创建临时 key - // key 格式: prefix/lease_id + // Create a unique key using lease ID + // Key format: prefix/lease_id myKey := fmt.Sprintf("%s%x", m.pfx, s.Lease()) - - // 使用事务创建 key(仅当不存在时) + + // Step 1: Use transaction to create key 
(only if not exists) + // Note: We cannot query owner in the same transaction because the query + // would see the state before our Put operation completes cmp := clientv3.Compare(clientv3.CreateRevision(myKey), "=", 0) put := clientv3.OpPut(myKey, "", clientv3.WithLease(s.Lease())) get := clientv3.OpGet(myKey) - + resp, err := client.Txn(ctx).If(cmp).Then(put).Else(get).Commit() if err != nil { return err @@ -76,69 +78,157 @@ func (m *Mutex) Lock(ctx context.Context) error { if resp.Succeeded { myRev = resp.Header.Revision } else { - // key 已存在,获取其 revision + // Key already exists, get its revision myRev = resp.Responses[0].GetResponseRange().Kvs[0].CreateRevision } - // 保存锁信息 + // Save lock info m.mu.Lock() m.myKey = myKey m.myRev = myRev m.hdr = resp.Header m.mu.Unlock() - // 等待获取锁 - return m.waitDeletes(ctx, myKey, myRev) + // Step 2: Query for the current lock owner (after key creation) + // This must be done separately to see the key we just created + ownerResp, err := client.Get(ctx, m.pfx, clientv3.WithFirstCreate()...) 
+ if err != nil { + m.Unlock(ctx) + return err + } + + // Check if we already hold the lock (we are the first key) + if len(ownerResp.Kvs) == 0 || ownerResp.Kvs[0].CreateRevision == myRev { + return nil + } + + // Wait for all earlier keys to be deleted + err = m.waitDeletes(ctx, myKey, myRev) + if err != nil { + // Release the key on error + m.Unlock(ctx) + return err + } + + // Verify we still own the key after waiting + gresp, err := client.Get(ctx, myKey) + if err != nil { + m.Unlock(ctx) + return err + } + if len(gresp.Kvs) == 0 { + return errors.New("session expired") + } + + m.mu.Lock() + m.hdr = gresp.Header + m.mu.Unlock() + + return nil } -// waitDeletes 等待所有更早的 key 被删除 +// waitDeletes waits for all keys with CreateRevision <= maxCreateRev to be deleted +// Automatically retries if Watch is canceled or network errors occur func (m *Mutex) waitDeletes(ctx context.Context, myKey string, myRev int64) error { client := m.s.client - // 获取所有前缀匹配的 key - getOpts := append(clientv3.WithFirstCreate(), clientv3.WithMaxCreateRev(myRev-1)) + // Use WithLastCreate to get the key with the largest CreateRevision <= myRev-1 + // This is the key we need to wait for + getOpts := append(clientv3.WithLastCreate(), clientv3.WithMaxCreateRev(myRev-1)) for { - // 获取所有 CreateRevision < myRev 的 key + // Check if session is still valid + select { + case <-m.s.Done(): + return errors.New("session expired") + case <-ctx.Done(): + return ctx.Err() + default: + } + resp, err := client.Get(ctx, m.pfx, getOpts...) 
if err != nil { return err } - // 没有更早的 key,获得锁 + // No earlier keys exist, we have the lock if len(resp.Kvs) == 0 { return nil } - // 找到最大的 CreateRevision 小于 myRev 的 key + // Wait for this key to be deleted lastKey := string(resp.Kvs[0].Key) - // Watch 该 key,等待其删除 - wch := client.Watch(ctx, lastKey, clientv3.WithRev(myRev)) - for wresp := range wch { - if wresp.Canceled { - return errors.New("watch canceled") + // Watch for deletion with automatic retry on cancellation + err = m.watchKeyDeletion(ctx, lastKey, resp.Header.Revision) + if err != nil { + // If watch was canceled or had network error, retry the loop + // The loop will recheck if the key still exists + if isWatchCanceledOrNetworkError(err) { + continue } - for _, ev := range wresp.Events { - if ev.Type == clientv3.EventTypeDelete { - // key 被删除,继续检查 - goto RETRY - } + return err + } + + // Key was deleted, loop will recheck for more earlier keys + } +} + +// watchKeyDeletion watches a specific key for deletion +// Returns nil when key is deleted, error otherwise +func (m *Mutex) watchKeyDeletion(ctx context.Context, key string, revision int64) error { + client := m.s.client + + // Create a cancellable context for watch + watchCtx, watchCancel := context.WithCancel(ctx) + defer watchCancel() + + // Watch for deletion starting from the current revision + wch := client.Watch(watchCtx, key, clientv3.WithRev(revision)) + + for wresp := range wch { + if wresp.Canceled { + // Watch was canceled - could be network error or context cancellation + if wresp.Err() != nil { + return wresp.Err() } + return errors.New("watch canceled") } - - RETRY: - // 检查会话是否还有效 - select { - case <-m.s.Done(): - return errors.New("session expired") - case <-ctx.Done(): - return ctx.Err() - default: + for _, ev := range wresp.Events { + if ev.Type == clientv3.EventTypeDelete { + // Key deleted successfully + return nil + } } } + + // Watch channel closed without delete event, check context + select { + case <-ctx.Done(): + return 
// isWatchCanceledOrNetworkError reports whether err looks like a watch
// cancellation or a transport-level failure, i.e. a condition after which the
// caller should re-check the key and watch again instead of giving up.
//
// It returns false for a nil error. A non-nil error matches when it wraps
// context.Canceled or when its message contains one of the known fragments
// ("watch canceled", "watch channel closed", "context canceled", "rpc error",
// "connection", "EOF").
func isWatchCanceledOrNetworkError(err error) bool {
	if err == nil {
		return false
	}

	// Catches cancellation even when the error has been wrapped.
	if errors.Is(err, context.Canceled) {
		return true
	}

	// Real RPC and network errors carry extra detail around these fragments
	// (e.g. "rpc error: code = Unavailable desc = ..."), so they must be
	// matched as substrings; comparing the whole message never matches them.
	// The scan is written out by hand so this helper needs no imports beyond
	// context and errors.
	msg := err.Error()
	for _, frag := range [...]string{
		"watch canceled",
		"watch channel closed",
		"context canceled",
		"rpc error",
		"connection",
		"EOF",
	} {
		for i := 0; i+len(frag) <= len(msg); i++ {
			if msg[i:i+len(frag)] == frag {
				return true
			}
		}
	}
	return false
}
if err != nil { + _, _ = client.Delete(ctx, myKey) return err } - if len(gresp.Kvs) > 0 { - // 有更早的 key,删除自己的 key - _, _ = client.Delete(ctx, myKey) - return concurrency.ErrLocked + // Check if we are the owner (first key by creation revision) + if len(ownerResp.Kvs) == 0 || ownerResp.Kvs[0].CreateRevision == myRev { + // We are the owner + m.mu.Lock() + m.myKey = myKey + m.myRev = myRev + m.hdr = resp.Header + m.mu.Unlock() + return nil } - // 获得锁 - m.mu.Lock() - m.myKey = myKey - m.myRev = myRev - m.hdr = resp.Header - m.mu.Unlock() - - return nil + // Not the owner, delete our key + _, _ = client.Delete(ctx, myKey) + return concurrency.ErrLocked } // Unlock 释放锁 diff --git a/pkg/concurrency/mutex_test.go b/pkg/concurrency/mutex_test.go new file mode 100644 index 0000000..611e2ff --- /dev/null +++ b/pkg/concurrency/mutex_test.go @@ -0,0 +1,1757 @@ +// Copyright 2025 The axfor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package concurrency + +import ( + "context" + "fmt" + "sort" + "sync" + "sync/atomic" + "testing" + "time" + + etcdapi "metaStore/api/etcd" + "metaStore/internal/memory" + + clientv3 "go.etcd.io/etcd/client/v3" + etcdconcurrency "go.etcd.io/etcd/client/v3/concurrency" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================================ +// Test Helper Functions +// ============================================================================ + +// startLockTestServer 启动用于锁测试的服务器 +func startLockTestServer(t *testing.T) (*etcdapi.Server, *clientv3.Client) { + store := memory.NewMemoryEtcd() + server, err := etcdapi.NewServer(etcdapi.ServerConfig{ + Store: store, + Address: "127.0.0.1:0", + ClusterID: 1, + MemberID: 1, + }) + require.NoError(t, err) + + go func() { + if err := server.Start(); err != nil { + t.Logf("Server error: %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + cli, err := clientv3.New(clientv3.Config{ + Endpoints: []string{server.Address()}, + DialTimeout: 5 * time.Second, + }) + require.NoError(t, err) + + t.Cleanup(func() { + cli.Close() + server.Stop() + }) + + return server, cli +} + +// ============================================================================ +// Session Tests +// ============================================================================ + +// TestSessionCreate 测试 Session 创建 +func TestSessionCreate(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 创建会话 + session, err := NewSession(cli, WithTTL(10)) + require.NoError(t, err) + require.NotNil(t, session) + + // 验证 Lease ID + leaseID := session.Lease() + assert.NotEqual(t, clientv3.NoLease, leaseID) + + // 验证 Lease 有效 + ttlResp, err := cli.TimeToLive(ctx, leaseID) + require.NoError(t, err) + assert.Greater(t, ttlResp.TTL, int64(0)) + assert.LessOrEqual(t, ttlResp.TTL, int64(10)) + + // 关闭会话 + err = session.Close() + 
require.NoError(t, err) + + // 验证 Lease 已被撤销 + ttlResp, err = cli.TimeToLive(ctx, leaseID) + require.NoError(t, err) + assert.Equal(t, int64(-1), ttlResp.TTL) +} + +// TestSessionWithExistingLease 测试使用现有 Lease 创建会话 +func TestSessionWithExistingLease(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 先创建 Lease + leaseResp, err := cli.Grant(ctx, 30) + require.NoError(t, err) + + // 使用现有 Lease 创建会话 + session, err := NewSession(cli, WithLease(leaseResp.ID)) + require.NoError(t, err) + require.NotNil(t, session) + + // 验证使用的是同一个 Lease + assert.Equal(t, leaseResp.ID, session.Lease()) + + session.Close() +} + +// TestSessionOrphan 测试 Orphan 功能 +func TestSessionOrphan(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + leaseID := session.Lease() + + // 使用 Orphan 结束会话但保留 Lease + session.Orphan() + + // 验证 Lease 仍然有效 + ttlResp, err := cli.TimeToLive(ctx, leaseID) + require.NoError(t, err) + assert.Greater(t, ttlResp.TTL, int64(0)) + + // 手动撤销 Lease + _, err = cli.Revoke(ctx, leaseID) + require.NoError(t, err) +} + +// TestSessionExpiry 测试 Session 过期 +func TestSessionExpiry(t *testing.T) { + _, cli := startLockTestServer(t) + + // 创建短期会话(2秒) + session, err := NewSession(cli, WithTTL(2)) + require.NoError(t, err) + leaseID := session.Lease() + + // 关闭会话(停止 keepalive) + session.Close() + + // 等待 Lease 过期 + time.Sleep(3 * time.Second) + + // 验证 Lease 已过期 + ctx := context.Background() + ttlResp, err := cli.TimeToLive(ctx, leaseID) + require.NoError(t, err) + assert.Equal(t, int64(-1), ttlResp.TTL) +} + +// ============================================================================ +// Basic Mutex Tests +// ============================================================================ + +// TestMutexLockUnlock 测试基本的 Lock 和 Unlock +func TestMutexLockUnlock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + 
session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/lock") + + // 验证初始状态 + assert.False(t, mutex.IsOwner()) + assert.Empty(t, mutex.Key()) + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 验证锁状态 + assert.True(t, mutex.IsOwner()) + assert.NotEmpty(t, mutex.Key()) + assert.NotNil(t, mutex.Header()) + + // 释放锁 + err = mutex.Unlock(ctx) + require.NoError(t, err) + + // 验证锁已释放 + assert.False(t, mutex.IsOwner()) + assert.Empty(t, mutex.Key()) +} + +// TestMutexReentrantLock 测试重入锁(同一个 Mutex 多次 Lock) +func TestMutexReentrantLock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/reentrant") + + // 第一次获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + firstKey := mutex.Key() + + // 第二次获取锁(应该立即返回) + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 验证 key 没有变化 + assert.Equal(t, firstKey, mutex.Key()) + + // 释放锁 + err = mutex.Unlock(ctx) + require.NoError(t, err) +} + +// TestMutexUnlockWithoutLock 测试未持有锁时 Unlock +func TestMutexUnlockWithoutLock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/unlock-without-lock") + + // 未持有锁时 Unlock 应该是安全的 + err = mutex.Unlock(ctx) + require.NoError(t, err) +} + +// ============================================================================ +// TryLock Tests +// ============================================================================ + +// TestTryLockSuccess 测试 TryLock 成功场景 +func TestTryLockSuccess(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := 
NewMutex(session, "/test/trylock") + + // TryLock 应该立即成功 + err = mutex.TryLock(ctx) + require.NoError(t, err) + assert.True(t, mutex.IsOwner()) + + mutex.Unlock(ctx) +} + +// TestTryLockFail 测试 TryLock 失败场景 +func TestTryLockFail(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 第一个会话获取锁 + session1, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session1.Close() + + mutex1 := NewMutex(session1, "/test/trylock-fail") + err = mutex1.Lock(ctx) + require.NoError(t, err) + + // 第二个会话尝试 TryLock + session2, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := NewMutex(session2, "/test/trylock-fail") + err = mutex2.TryLock(ctx) + + // 应该返回 ErrLocked + assert.Error(t, err) + assert.Equal(t, etcdconcurrency.ErrLocked, err) + assert.False(t, mutex2.IsOwner()) + + mutex1.Unlock(ctx) +} + +// TestTryLockAfterUnlock 测试解锁后 TryLock +func TestTryLockAfterUnlock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session1, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session1.Close() + + session2, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session2.Close() + + mutex1 := NewMutex(session1, "/test/trylock-after-unlock") + mutex2 := NewMutex(session2, "/test/trylock-after-unlock") + + // session1 获取锁 + err = mutex1.Lock(ctx) + require.NoError(t, err) + + // session2 TryLock 失败 + err = mutex2.TryLock(ctx) + assert.Equal(t, etcdconcurrency.ErrLocked, err) + + // session1 释放锁 + err = mutex1.Unlock(ctx) + require.NoError(t, err) + + // session2 TryLock 成功 + err = mutex2.TryLock(ctx) + require.NoError(t, err) + assert.True(t, mutex2.IsOwner()) + + mutex2.Unlock(ctx) +} + +// ============================================================================ +// Concurrent Lock Tests +// ============================================================================ + +// TestMutexContention 测试锁竞争 +func 
TestMutexContention(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numClients = 5 + var wg sync.WaitGroup + acquired := make(chan int, numClients) + released := make(chan int, numClients) + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/contention") + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + acquired <- id + t.Logf("Client %d acquired lock", id) + + // 持有锁一小段时间 + time.Sleep(50 * time.Millisecond) + + // 释放锁 + err = mutex.Unlock(ctx) + require.NoError(t, err) + + released <- id + t.Logf("Client %d released lock", id) + }(i) + } + + // 等待所有 goroutine 完成 + wg.Wait() + close(acquired) + close(released) + + // 验证每个客户端都获取并释放了锁 + acquiredClients := make(map[int]bool) + for id := range acquired { + acquiredClients[id] = true + } + assert.Len(t, acquiredClients, numClients) + + releasedClients := make(map[int]bool) + for id := range released { + releasedClients[id] = true + } + assert.Len(t, releasedClients, numClients) +} + +// TestMutexFIFOOrder 测试锁的 FIFO 顺序 +func TestMutexFIFOOrder(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numClients = 5 + var orderMu sync.Mutex + acquireOrder := make([]int, 0, numClients) + + // 创建一个信号通道来控制启动顺序 + startSignals := make([]chan struct{}, numClients) + for i := range startSignals { + startSignals[i] = make(chan struct{}) + } + + var wg sync.WaitGroup + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // 等待启动信号 + <-startSignals[id] + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/fifo") + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 记录获取顺序 + orderMu.Lock() + acquireOrder = append(acquireOrder, id) + 
orderMu.Unlock() + + t.Logf("Client %d acquired lock at position %d", id, len(acquireOrder)) + + // 持有锁一小段时间 + time.Sleep(20 * time.Millisecond) + + mutex.Unlock(ctx) + }(i) + } + + // 按顺序发送启动信号 + for i := 0; i < numClients; i++ { + close(startSignals[i]) + time.Sleep(30 * time.Millisecond) // 确保按顺序注册到锁队列 + } + + wg.Wait() + + // 验证获取顺序 + t.Logf("Acquire order: %v", acquireOrder) + assert.Len(t, acquireOrder, numClients) + + // 验证是 FIFO 顺序 + expectedOrder := make([]int, numClients) + for i := range expectedOrder { + expectedOrder[i] = i + } + assert.Equal(t, expectedOrder, acquireOrder, "Lock acquisition should follow FIFO order") +} + +// TestMutexCriticalSection 测试临界区保护 +func TestMutexCriticalSection(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numClients = 10 + const iterations = 5 + var counter int64 + var violations int64 + + var wg sync.WaitGroup + for i := 0; i < numClients; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/critical-section") + + for j := 0; j < iterations; j++ { + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 临界区操作 + oldVal := atomic.LoadInt64(&counter) + time.Sleep(time.Millisecond) // 模拟工作 + newVal := atomic.AddInt64(&counter, 1) + + // 检查是否有竞态条件 + if newVal != oldVal+1 { + atomic.AddInt64(&violations, 1) + } + + mutex.Unlock(ctx) + } + }() + } + + wg.Wait() + + assert.Equal(t, int64(numClients*iterations), atomic.LoadInt64(&counter)) + assert.Equal(t, int64(0), atomic.LoadInt64(&violations), "No race conditions should occur") +} + +// ============================================================================ +// Lock Timeout and Cancellation Tests +// ============================================================================ + +// TestMutexLockWithTimeout 测试带超时的锁获取 +func TestMutexLockWithTimeout(t *testing.T) { + _, cli := startLockTestServer(t) 
+ bgCtx := context.Background() + + // 第一个会话持有锁 + session1, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session1.Close() + + mutex1 := NewMutex(session1, "/test/timeout") + err = mutex1.Lock(bgCtx) + require.NoError(t, err) + + // 第二个会话尝试获取锁,带超时 + session2, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := NewMutex(session2, "/test/timeout") + + ctx, cancel := context.WithTimeout(bgCtx, 500*time.Millisecond) + defer cancel() + + start := time.Now() + err = mutex2.Lock(ctx) + + elapsed := time.Since(start) + assert.Error(t, err) + assert.True(t, elapsed >= 400*time.Millisecond && elapsed < 1*time.Second, + "Lock should timeout around 500ms, got %v", elapsed) + + mutex1.Unlock(bgCtx) +} + +// TestMutexLockCancellation 测试锁获取取消 +func TestMutexLockCancellation(t *testing.T) { + _, cli := startLockTestServer(t) + + // 第一个会话持有锁 + session1, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session1.Close() + + mutex1 := NewMutex(session1, "/test/cancel") + err = mutex1.Lock(context.Background()) + require.NoError(t, err) + + // 第二个会话尝试获取锁 + session2, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := NewMutex(session2, "/test/cancel") + + ctx, cancel := context.WithCancel(context.Background()) + + // 启动 goroutine 获取锁 + done := make(chan error, 1) + go func() { + done <- mutex2.Lock(ctx) + }() + + // 等待一会儿然后取消 + time.Sleep(200 * time.Millisecond) + cancel() + + // 验证锁获取被取消 + select { + case err := <-done: + assert.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + case <-time.After(2 * time.Second): + t.Fatal("Lock should be canceled") + } + + mutex1.Unlock(context.Background()) +} + +// ============================================================================ +// Session Failure Tests +// ============================================================================ + +// TestMutexReleaseOnSessionClose 测试 
Session 关闭时锁自动释放 +func TestMutexReleaseOnSessionClose(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 第一个会话获取锁 + session1, err := NewSession(cli, WithTTL(5)) + require.NoError(t, err) + + mutex1 := NewMutex(session1, "/test/session-close") + err = mutex1.Lock(ctx) + require.NoError(t, err) + + // 第二个会话准备获取锁 + session2, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := NewMutex(session2, "/test/session-close") + + // 启动 goroutine 等待锁 + acquired := make(chan struct{}) + go func() { + err := mutex2.Lock(ctx) + if err == nil { + close(acquired) + } + }() + + // 关闭第一个会话 + time.Sleep(100 * time.Millisecond) + session1.Close() + + // 验证第二个会话能获取锁 + select { + case <-acquired: + t.Log("Second session acquired lock after first session closed") + assert.True(t, mutex2.IsOwner()) + case <-time.After(5 * time.Second): + t.Fatal("Second session should acquire lock after first session closes") + } + + mutex2.Unlock(ctx) +} + +// ============================================================================ +// Election Tests +// ============================================================================ + +// TestElectionCampaign 测试 Leader 选举 +func TestElectionCampaign(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + election := NewElection(session, "/test/election") + + // 初始状态 + assert.False(t, election.IsLeader()) + + // 竞选 Leader + err = election.Campaign(ctx, "leader-value") + require.NoError(t, err) + + // 验证成为 Leader + assert.True(t, election.IsLeader()) + assert.NotEmpty(t, election.Key()) + assert.Greater(t, election.Rev(), int64(0)) + + // 查询当前 Leader + _, val, err := election.Leader(ctx) + require.NoError(t, err) + assert.Equal(t, "leader-value", val) + + // 放弃 Leader + err = election.Resign(ctx) + require.NoError(t, err) + + assert.False(t, 
election.IsLeader()) +} + +// TestElectionMultipleCandidates 测试多候选人选举 +func TestElectionMultipleCandidates(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numCandidates = 3 + var wg sync.WaitGroup + leaderChan := make(chan int, numCandidates) + + for i := 0; i < numCandidates; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + election := NewElection(session, "/test/multi-election") + + // 竞选 + value := fmt.Sprintf("candidate-%d", id) + err = election.Campaign(ctx, value) + require.NoError(t, err) + + leaderChan <- id + t.Logf("Candidate %d became leader", id) + + // 持有一段时间 + time.Sleep(100 * time.Millisecond) + + // 放弃 + election.Resign(ctx) + t.Logf("Candidate %d resigned", id) + }(i) + } + + // 等待所有候选人完成 + wg.Wait() + close(leaderChan) + + // 验证所有候选人都成为过 Leader + leaders := make(map[int]bool) + for id := range leaderChan { + leaders[id] = true + } + assert.Len(t, leaders, numCandidates) +} + +// TestElectionObserve 测试 Leader 变化观察 +func TestElectionObserve(t *testing.T) { + _, cli := startLockTestServer(t) + + session1, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session1.Close() + + session2, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session2.Close() + + election1 := NewElection(session1, "/test/observe") + election2 := NewElection(session2, "/test/observe") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // election1 成为 Leader + err = election1.Campaign(ctx, "leader-1") + require.NoError(t, err) + + // 启动观察者 + observeCh := election2.Observe(ctx) + + // 收集观察到的 Leader + var observedLeaders []string + done := make(chan struct{}) + + go func() { + defer close(done) + for i := 0; i < 3; i++ { + select { + case leader, ok := <-observeCh: + if !ok { + return + } + observedLeaders = append(observedLeaders, leader) + 
t.Logf("Observed leader: %s", leader) + case <-ctx.Done(): + return + } + } + }() + + // 等待第一次观察 + time.Sleep(200 * time.Millisecond) + + // election1 放弃 + election1.Resign(ctx) + time.Sleep(100 * time.Millisecond) + + // election2 成为 Leader + err = election2.Campaign(ctx, "leader-2") + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + cancel() + <-done + + // 验证观察到了 Leader 变化 + t.Logf("Observed leaders: %v", observedLeaders) + assert.GreaterOrEqual(t, len(observedLeaders), 1) +} + +// TestElectionResignNotLeader 测试非 Leader 放弃 +func TestElectionResignNotLeader(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + election := NewElection(session, "/test/resign-not-leader") + + // 未成为 Leader 就放弃 + err = election.Resign(ctx) + assert.Error(t, err) + assert.Equal(t, ErrElectionNotLeader, err) +} + +// ============================================================================ +// Stress Tests +// ============================================================================ + +// TestMutexHighConcurrency 高并发锁测试 +func TestMutexHighConcurrency(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numGoroutines = 20 + const iterations = 10 + var wg sync.WaitGroup + var successCount int64 + var failCount int64 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := NewSession(cli, WithTTL(60)) + if err != nil { + atomic.AddInt64(&failCount, int64(iterations)) + return + } + defer session.Close() + + mutex := NewMutex(session, "/test/high-concurrency") + + for j := 0; j < iterations; j++ { + lockCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + err := mutex.Lock(lockCtx) + cancel() + + if err != nil { + atomic.AddInt64(&failCount, 1) + continue + } + + atomic.AddInt64(&successCount, 1) + + // 短暂持有锁 + time.Sleep(5 * time.Millisecond) + 
+ mutex.Unlock(ctx) + } + }(i) + } + + wg.Wait() + + t.Logf("Success: %d, Fail: %d", successCount, failCount) + assert.Equal(t, int64(numGoroutines*iterations), successCount) + assert.Equal(t, int64(0), failCount) +} + +// TestMutexRapidLockUnlock 快速加解锁测试 +func TestMutexRapidLockUnlock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/rapid") + + const iterations = 100 + for i := 0; i < iterations; i++ { + err := mutex.Lock(ctx) + require.NoError(t, err, "Lock failed at iteration %d", i) + + assert.True(t, mutex.IsOwner()) + + err = mutex.Unlock(ctx) + require.NoError(t, err, "Unlock failed at iteration %d", i) + + assert.False(t, mutex.IsOwner()) + } +} + +// ============================================================================ +// Edge Case Tests +// ============================================================================ + +// TestMutexDifferentPrefixes 测试不同前缀的锁互不影响 +func TestMutexDifferentPrefixes(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex1 := NewMutex(session, "/test/prefix1") + mutex2 := NewMutex(session, "/test/prefix2") + + // 同时获取两个不同前缀的锁 + err = mutex1.Lock(ctx) + require.NoError(t, err) + + err = mutex2.Lock(ctx) + require.NoError(t, err) + + // 两个都应该成功 + assert.True(t, mutex1.IsOwner()) + assert.True(t, mutex2.IsOwner()) + + mutex1.Unlock(ctx) + mutex2.Unlock(ctx) +} + +// TestMutexSameSessionDifferentMutex 测试同一会话的不同 Mutex 实例 +func TestMutexSameSessionDifferentMutex(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + // 同一会话创建两个 Mutex 实例(相同前缀) + mutex1 := NewMutex(session, "/test/same-prefix") + mutex2 
:= NewMutex(session, "/test/same-prefix") + + // mutex1 获取锁 + err = mutex1.Lock(ctx) + require.NoError(t, err) + + // mutex2 也能获取锁(因为使用相同的 Lease,key 相同) + err = mutex2.Lock(ctx) + require.NoError(t, err) + + // 两个都认为自己是 owner + assert.True(t, mutex1.IsOwner()) + assert.True(t, mutex2.IsOwner()) + + // 但实际上是同一个 key + assert.Equal(t, mutex1.Key(), mutex2.Key()) + + mutex1.Unlock(ctx) +} + +// TestMutexEmptyPrefix 测试空前缀 +func TestMutexEmptyPrefix(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "") + + err = mutex.Lock(ctx) + require.NoError(t, err) + assert.True(t, mutex.IsOwner()) + + mutex.Unlock(ctx) +} + +// TestMutexSpecialCharacterPrefix 测试特殊字符前缀 +func TestMutexSpecialCharacterPrefix(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + prefixes := []string{ + "/test/special/chars", + "/test/with spaces", + "/test/with-dashes", + "/test/with_underscores", + "/test/with.dots", + } + + for _, prefix := range prefixes { + t.Run(prefix, func(t *testing.T) { + mutex := NewMutex(session, prefix) + err := mutex.Lock(ctx) + require.NoError(t, err) + assert.True(t, mutex.IsOwner()) + mutex.Unlock(ctx) + }) + } +} + +// ============================================================================ +// Benchmark Tests +// ============================================================================ + +// BenchmarkMutexLockUnlock 基准测试锁性能 +func BenchmarkMutexLockUnlock(b *testing.B) { + store := memory.NewMemoryEtcd() + server, err := etcdapi.NewServer(etcdapi.ServerConfig{ + Store: store, + Address: "127.0.0.1:0", + ClusterID: 1, + MemberID: 1, + }) + if err != nil { + b.Fatal(err) + } + + go server.Start() + time.Sleep(100 * time.Millisecond) + defer server.Stop() + + cli, err := 
clientv3.New(clientv3.Config{ + Endpoints: []string{server.Address()}, + DialTimeout: 5 * time.Second, + }) + if err != nil { + b.Fatal(err) + } + defer cli.Close() + + session, err := NewSession(cli, WithTTL(60)) + if err != nil { + b.Fatal(err) + } + defer session.Close() + + mutex := NewMutex(session, "/bench/lock") + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mutex.Lock(ctx) + mutex.Unlock(ctx) + } +} + +// BenchmarkTryLock 基准测试 TryLock 性能 +func BenchmarkTryLock(b *testing.B) { + store := memory.NewMemoryEtcd() + server, err := etcdapi.NewServer(etcdapi.ServerConfig{ + Store: store, + Address: "127.0.0.1:0", + ClusterID: 1, + MemberID: 1, + }) + if err != nil { + b.Fatal(err) + } + + go server.Start() + time.Sleep(100 * time.Millisecond) + defer server.Stop() + + cli, err := clientv3.New(clientv3.Config{ + Endpoints: []string{server.Address()}, + DialTimeout: 5 * time.Second, + }) + if err != nil { + b.Fatal(err) + } + defer cli.Close() + + session, err := NewSession(cli, WithTTL(60)) + if err != nil { + b.Fatal(err) + } + defer session.Close() + + mutex := NewMutex(session, "/bench/trylock") + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := mutex.TryLock(ctx) + if err == nil { + mutex.Unlock(ctx) + } + } +} + +// BenchmarkSessionCreate 基准测试会话创建性能 +func BenchmarkSessionCreate(b *testing.B) { + store := memory.NewMemoryEtcd() + server, err := etcdapi.NewServer(etcdapi.ServerConfig{ + Store: store, + Address: "127.0.0.1:0", + ClusterID: 1, + MemberID: 1, + }) + if err != nil { + b.Fatal(err) + } + + go server.Start() + time.Sleep(100 * time.Millisecond) + defer server.Stop() + + cli, err := clientv3.New(clientv3.Config{ + Endpoints: []string{server.Address()}, + DialTimeout: 5 * time.Second, + }) + if err != nil { + b.Fatal(err) + } + defer cli.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + session, err := NewSession(cli, WithTTL(10)) + if err != nil { + b.Fatal(err) + } + 
session.Close() + } +} + +// ============================================================================ +// Integration Tests with etcd concurrency package +// ============================================================================ + +// TestCompatibilityWithEtcdConcurrency 测试与 etcd 官方 concurrency 包的兼容性 +func TestCompatibilityWithEtcdConcurrency(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 使用 etcd 官方的 concurrency 包创建会话和锁 + etcdSession, err := etcdconcurrency.NewSession(cli, etcdconcurrency.WithTTL(30)) + require.NoError(t, err) + defer etcdSession.Close() + + etcdMutex := etcdconcurrency.NewMutex(etcdSession, "/test/etcd-compat") + + // 获取锁 + err = etcdMutex.Lock(ctx) + require.NoError(t, err) + + // 验证锁状态 + assert.NotEmpty(t, etcdMutex.Key()) + + // 释放锁 + err = etcdMutex.Unlock(ctx) + require.NoError(t, err) +} + +// TestMixedLockUsage 测试混合使用自定义和 etcd 官方的锁 +func TestMixedLockUsage(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 使用自定义的 concurrency 包 + customSession, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer customSession.Close() + + customMutex := NewMutex(customSession, "/test/mixed") + + // 使用 etcd 官方的 concurrency 包 + etcdSession, err := etcdconcurrency.NewSession(cli, etcdconcurrency.WithTTL(30)) + require.NoError(t, err) + defer etcdSession.Close() + + etcdMutex := etcdconcurrency.NewMutex(etcdSession, "/test/mixed") + + // 自定义锁获取 + err = customMutex.Lock(ctx) + require.NoError(t, err) + + // etcd 锁尝试获取应该失败 + tryCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + err = etcdMutex.Lock(tryCtx) + cancel() + assert.Error(t, err, "etcd mutex should not be able to acquire lock") + + // 释放自定义锁 + customMutex.Unlock(ctx) + + // 现在 etcd 锁应该能获取 + err = etcdMutex.Lock(ctx) + require.NoError(t, err) + + etcdMutex.Unlock(ctx) +} + +// ============================================================================ +// Verify Lock Key Format +// 
============================================================================ + +// TestMutexKeyFormat 验证锁 key 格式 +func TestMutexKeyFormat(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + prefix := "/test/key-format" + mutex := NewMutex(session, prefix) + + err = mutex.Lock(ctx) + require.NoError(t, err) + + key := mutex.Key() + t.Logf("Lock key: %s", key) + + // 验证 key 格式: prefix/ + lease_id(十六进制) + assert.Contains(t, key, prefix+"/") + assert.Contains(t, key, fmt.Sprintf("%x", session.Lease())) + + mutex.Unlock(ctx) +} + +// ============================================================================ +// Ordering Verification Tests +// ============================================================================ + +// TestLockAcquisitionOrderWithTimestamp 测试锁获取顺序(带时间戳验证) +func TestLockAcquisitionOrderWithTimestamp(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numClients = 5 + type lockEvent struct { + id int + timestamp time.Time + } + + var mu sync.Mutex + events := make([]lockEvent, 0, numClients) + + var wg sync.WaitGroup + startCh := make(chan struct{}) + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/order-timestamp") + + // 等待启动信号 + <-startCh + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 记录获取时间 + mu.Lock() + events = append(events, lockEvent{id: id, timestamp: time.Now()}) + mu.Unlock() + + time.Sleep(10 * time.Millisecond) + mutex.Unlock(ctx) + }(i) + } + + // 同时启动所有 goroutine + close(startCh) + wg.Wait() + + // 验证事件顺序 + assert.Len(t, events, numClients) + + // 验证时间戳是递增的 + for i := 1; i < len(events); i++ { + assert.True(t, events[i].timestamp.After(events[i-1].timestamp) || + 
events[i].timestamp.Equal(events[i-1].timestamp), + "Lock acquisition timestamps should be ordered") + } + + // 打印顺序 + var order []int + for _, e := range events { + order = append(order, e.id) + } + t.Logf("Acquisition order: %v", order) +} + +// ============================================================================ +// Recovery Tests +// ============================================================================ + +// TestMutexRecoveryAfterSessionClose 测试会话关闭后的锁恢复 +func TestMutexRecoveryAfterSessionClose(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + prefix := "/test/recovery" + + // 第一个会话获取锁 + session1, err := NewSession(cli, WithTTL(5)) + require.NoError(t, err) + + mutex1 := NewMutex(session1, prefix) + err = mutex1.Lock(ctx) + require.NoError(t, err) + t.Log("Session 1 acquired lock") + + // 关闭第一个会话 + session1.Close() + t.Log("Session 1 closed") + + // 第二个会话应该能获取锁 + session2, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := NewMutex(session2, prefix) + + // 应该能够获取锁 + lockCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + err = mutex2.Lock(lockCtx) + cancel() + + require.NoError(t, err, "Session 2 should acquire lock after session 1 closes") + assert.True(t, mutex2.IsOwner()) + t.Log("Session 2 acquired lock") + + mutex2.Unlock(ctx) +} + +// ============================================================================ +// Additional Concurrency Tests +// ============================================================================ + +// TestMultipleLocksSequential 测试顺序获取多个锁 +func TestMultipleLocksSequential(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + locks := make([]*Mutex, 5) + for i := range locks { + locks[i] = NewMutex(session, fmt.Sprintf("/test/multi/%d", i)) + } + + // 顺序获取所有锁 + for i, lock := range locks { + err := 
lock.Lock(ctx) + require.NoError(t, err, "Failed to acquire lock %d", i) + } + + // 验证所有锁都被持有 + for i, lock := range locks { + assert.True(t, lock.IsOwner(), "Lock %d should be owned", i) + } + + // 顺序释放所有锁 + for i, lock := range locks { + err := lock.Unlock(ctx) + require.NoError(t, err, "Failed to release lock %d", i) + } +} + +// TestConcurrentDifferentLocks 测试并发获取不同的锁 +func TestConcurrentDifferentLocks(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numLocks = 10 + var wg sync.WaitGroup + errors := make(chan error, numLocks) + + for i := 0; i < numLocks; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := NewSession(cli, WithTTL(30)) + if err != nil { + errors <- err + return + } + defer session.Close() + + mutex := NewMutex(session, fmt.Sprintf("/test/concurrent/%d", id)) + + if err := mutex.Lock(ctx); err != nil { + errors <- err + return + } + + time.Sleep(10 * time.Millisecond) + + if err := mutex.Unlock(ctx); err != nil { + errors <- err + return + } + }(i) + } + + wg.Wait() + close(errors) + + // 检查是否有错误 + for err := range errors { + t.Errorf("Unexpected error: %v", err) + } +} + +// TestLockFairness 测试锁公平性 +func TestLockFairness(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numRounds = 5 + const numClients = 3 + + var mu sync.Mutex + acquisitions := make(map[int]int) + + for round := 0; round < numRounds; round++ { + var wg sync.WaitGroup + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/fairness") + + err = mutex.Lock(ctx) + require.NoError(t, err) + + mu.Lock() + acquisitions[id]++ + mu.Unlock() + + time.Sleep(5 * time.Millisecond) + mutex.Unlock(ctx) + }(i) + } + + wg.Wait() + } + + // 验证每个客户端都获取了锁 + t.Logf("Acquisitions: %v", acquisitions) + for i := 0; i < numClients; 
i++ { + assert.Greater(t, acquisitions[i], 0, "Client %d should have acquired lock at least once", i) + } + + // 验证分布相对均匀(每个客户端应该获取约 numRounds 次) + total := 0 + for _, count := range acquisitions { + total += count + } + assert.Equal(t, numRounds*numClients, total) +} + +// TestLockWithContextDeadline 测试带截止时间的锁 +func TestLockWithContextDeadline(t *testing.T) { + _, cli := startLockTestServer(t) + + session1, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session1.Close() + + session2, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session2.Close() + + mutex1 := NewMutex(session1, "/test/deadline") + mutex2 := NewMutex(session2, "/test/deadline") + + // session1 获取锁 + ctx1 := context.Background() + err = mutex1.Lock(ctx1) + require.NoError(t, err) + + // session2 尝试获取锁,带截止时间 + deadline := time.Now().Add(500 * time.Millisecond) + ctx2, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + start := time.Now() + err = mutex2.Lock(ctx2) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.True(t, elapsed >= 400*time.Millisecond, "Should wait until deadline") + assert.True(t, elapsed < 1*time.Second, "Should not wait too long after deadline") + + mutex1.Unlock(ctx1) +} + +// ============================================================================ +// Data Race Detection Tests +// ============================================================================ + +// TestMutexNoDataRace 测试无数据竞争 +func TestMutexNoDataRace(t *testing.T) { + _, cli := startLockTestServer(t) + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/race") + + var wg sync.WaitGroup + + // 并发调用各种方法 + for i := 0; i < 10; i++ { + wg.Add(3) + + go func() { + defer wg.Done() + _ = mutex.IsOwner() + }() + + go func() { + defer wg.Done() + _ = mutex.Key() + }() + + go func() { + defer wg.Done() + _ = mutex.Header() + }() + } + + wg.Wait() 
+} + +// TestSessionNoDataRace 测试 Session 无数据竞争 +func TestSessionNoDataRace(t *testing.T) { + _, cli := startLockTestServer(t) + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + + var wg sync.WaitGroup + + // 并发调用各种方法 + for i := 0; i < 10; i++ { + wg.Add(2) + + go func() { + defer wg.Done() + _ = session.Lease() + }() + + go func() { + defer wg.Done() + _ = session.Done() + }() + } + + wg.Wait() + session.Close() +} + +// ============================================================================ +// Edge Cases for Watch-based Waiting +// ============================================================================ + +// TestMutexWaitingQueue 测试锁等待队列 +func TestMutexWaitingQueue(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numWaiters = 5 + var orderMu sync.Mutex + order := make([]int, 0, numWaiters) + + // 信号通道用于同步 + ready := make([]chan struct{}, numWaiters) + for i := range ready { + ready[i] = make(chan struct{}) + } + + var wg sync.WaitGroup + + for i := 0; i < numWaiters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/queue") + + // 通知已准备好 + close(ready[id]) + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 记录顺序 + orderMu.Lock() + order = append(order, id) + orderMu.Unlock() + + time.Sleep(20 * time.Millisecond) + mutex.Unlock(ctx) + }(i) + + // 等待 goroutine 准备好后再启动下一个 + <-ready[i] + time.Sleep(30 * time.Millisecond) + } + + wg.Wait() + + t.Logf("Acquisition order: %v", order) + assert.Len(t, order, numWaiters) + + // 验证顺序 + expected := make([]int, numWaiters) + for i := range expected { + expected[i] = i + } + assert.Equal(t, expected, order) +} + +// TestMutexWatchEventHandling 测试 Watch 事件处理 +func TestMutexWatchEventHandling(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 创建多个会话和锁 + const 
numSessions = 3 + sessions := make([]*Session, numSessions) + mutexes := make([]*Mutex, numSessions) + + for i := range sessions { + var err error + sessions[i], err = NewSession(cli, WithTTL(60)) + require.NoError(t, err) + mutexes[i] = NewMutex(sessions[i], "/test/watch-events") + } + + defer func() { + for _, s := range sessions { + s.Close() + } + }() + + // 第一个会话获取锁 + err := mutexes[0].Lock(ctx) + require.NoError(t, err) + + // 其他会话尝试获取锁(会等待) + done := make([]chan error, numSessions-1) + for i := 1; i < numSessions; i++ { + done[i-1] = make(chan error, 1) + go func(idx int) { + done[idx-1] <- mutexes[idx].Lock(ctx) + }(i) + } + + // 等待其他会话进入等待状态 + time.Sleep(200 * time.Millisecond) + + // 释放第一个锁 + err = mutexes[0].Unlock(ctx) + require.NoError(t, err) + + // 验证等待的会话依次获取锁 + for i := 1; i < numSessions; i++ { + select { + case err := <-done[i-1]: + require.NoError(t, err) + t.Logf("Session %d acquired lock", i) + mutexes[i].Unlock(ctx) + case <-time.After(5 * time.Second): + t.Fatalf("Session %d failed to acquire lock", i) + } + } +} + +// ============================================================================ +// Performance Characterization Tests +// ============================================================================ + +// TestLockLatencyDistribution 测试锁延迟分布 +func TestLockLatencyDistribution(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := NewSession(cli, WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := NewMutex(session, "/test/latency") + + const iterations = 50 + latencies := make([]time.Duration, iterations) + + for i := 0; i < iterations; i++ { + start := time.Now() + err := mutex.Lock(ctx) + latencies[i] = time.Since(start) + require.NoError(t, err) + mutex.Unlock(ctx) + } + + // 计算统计信息 + sort.Slice(latencies, func(i, j int) bool { + return latencies[i] < latencies[j] + }) + + var total time.Duration + for _, l := range latencies { + total += l + } + + avg := 
total / time.Duration(iterations) + p50 := latencies[iterations/2] + p95 := latencies[iterations*95/100] + p99 := latencies[iterations*99/100] + + t.Logf("Lock latency distribution (n=%d):", iterations) + t.Logf(" Average: %v", avg) + t.Logf(" P50: %v", p50) + t.Logf(" P95: %v", p95) + t.Logf(" P99: %v", p99) + t.Logf(" Min: %v", latencies[0]) + t.Logf(" Max: %v", latencies[iterations-1]) + + // 验证延迟合理 + assert.Less(t, avg, 100*time.Millisecond, "Average latency should be reasonable") +} diff --git a/test/distributed_lock_memory_test.go b/test/distributed_lock_memory_test.go new file mode 100644 index 0000000..8a468cb --- /dev/null +++ b/test/distributed_lock_memory_test.go @@ -0,0 +1,1757 @@ +// Copyright 2025 The axfor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package test + +import ( + "context" + "fmt" + "sort" + "sync" + "sync/atomic" + "testing" + "time" + + etcdapi "metaStore/api/etcd" + "metaStore/internal/memory" + "metaStore/pkg/concurrency" + + clientv3 "go.etcd.io/etcd/client/v3" + etcdconcurrency "go.etcd.io/etcd/client/v3/concurrency" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================================ +// Test Helper Functions +// ============================================================================ + +// startLockTestServer 启动用于锁测试的服务器 +func startLockTestServer(t *testing.T) (*etcdapi.Server, *clientv3.Client) { + store := memory.NewMemoryEtcd() + server, err := etcdapi.NewServer(etcdapi.ServerConfig{ + Store: store, + Address: "127.0.0.1:0", + ClusterID: 1, + MemberID: 1, + }) + require.NoError(t, err) + + go func() { + if err := server.Start(); err != nil { + t.Logf("Server error: %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + cli, err := clientv3.New(clientv3.Config{ + Endpoints: []string{server.Address()}, + DialTimeout: 5 * time.Second, + }) + require.NoError(t, err) + + t.Cleanup(func() { + cli.Close() + server.Stop() + }) + + return server, cli +} + +// ============================================================================ +// Session Tests +// ============================================================================ + +// TestSessionCreate 测试 Session 创建 +func TestSessionCreate(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 创建会话 + session, err := concurrency.NewSession(cli, concurrency.WithTTL(10)) + require.NoError(t, err) + require.NotNil(t, session) + + // 验证 Lease ID + leaseID := session.Lease() + assert.NotEqual(t, clientv3.NoLease, leaseID) + + // 验证 Lease 有效 + ttlResp, err := cli.TimeToLive(ctx, leaseID) + require.NoError(t, err) + assert.Greater(t, ttlResp.TTL, int64(0)) + assert.LessOrEqual(t, ttlResp.TTL, int64(10)) + + 
// 关闭会话 + err = session.Close() + require.NoError(t, err) + + // 验证 Lease 已被撤销 + ttlResp, err = cli.TimeToLive(ctx, leaseID) + require.NoError(t, err) + assert.Equal(t, int64(-1), ttlResp.TTL) +} + +// TestSessionWithExistingLease 测试使用现有 Lease 创建会话 +func TestSessionWithExistingLease(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 先创建 Lease + leaseResp, err := cli.Grant(ctx, 30) + require.NoError(t, err) + + // 使用现有 Lease 创建会话 + session, err := concurrency.NewSession(cli, concurrency.WithLease(leaseResp.ID)) + require.NoError(t, err) + require.NotNil(t, session) + + // 验证使用的是同一个 Lease + assert.Equal(t, leaseResp.ID, session.Lease()) + + session.Close() +} + +// TestSessionOrphan 测试 Orphan 功能 +func TestSessionOrphan(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + leaseID := session.Lease() + + // 使用 Orphan 结束会话但保留 Lease + session.Orphan() + + // 验证 Lease 仍然有效 + ttlResp, err := cli.TimeToLive(ctx, leaseID) + require.NoError(t, err) + assert.Greater(t, ttlResp.TTL, int64(0)) + + // 手动撤销 Lease + _, err = cli.Revoke(ctx, leaseID) + require.NoError(t, err) +} + +// TestSessionExpiry 测试 Session 过期 +func TestSessionExpiry(t *testing.T) { + _, cli := startLockTestServer(t) + + // 创建短期会话(2秒) + session, err := concurrency.NewSession(cli, concurrency.WithTTL(2)) + require.NoError(t, err) + leaseID := session.Lease() + + // 关闭会话(停止 keepalive) + session.Close() + + // 等待 Lease 过期 + time.Sleep(3 * time.Second) + + // 验证 Lease 已过期 + ctx := context.Background() + ttlResp, err := cli.TimeToLive(ctx, leaseID) + require.NoError(t, err) + assert.Equal(t, int64(-1), ttlResp.TTL) +} + +// ============================================================================ +// Basic Mutex Tests +// ============================================================================ + +// TestMutexLockUnlock 测试基本的 Lock 和 Unlock +func 
TestMutexLockUnlock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/lock") + + // 验证初始状态 + assert.False(t, mutex.IsOwner()) + assert.Empty(t, mutex.Key()) + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 验证锁状态 + assert.True(t, mutex.IsOwner()) + assert.NotEmpty(t, mutex.Key()) + assert.NotNil(t, mutex.Header()) + + // 释放锁 + err = mutex.Unlock(ctx) + require.NoError(t, err) + + // 验证锁已释放 + assert.False(t, mutex.IsOwner()) + assert.Empty(t, mutex.Key()) +} + +// TestMutexReentrantLock 测试重入锁(同一个 Mutex 多次 Lock) +func TestMutexReentrantLock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/reentrant") + + // 第一次获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + firstKey := mutex.Key() + + // 第二次获取锁(应该立即返回) + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 验证 key 没有变化 + assert.Equal(t, firstKey, mutex.Key()) + + // 释放锁 + err = mutex.Unlock(ctx) + require.NoError(t, err) +} + +// TestMutexUnlockWithoutLock 测试未持有锁时 Unlock +func TestMutexUnlockWithoutLock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/unlock-without-lock") + + // 未持有锁时 Unlock 应该是安全的 + err = mutex.Unlock(ctx) + require.NoError(t, err) +} + +// ============================================================================ +// TryLock Tests +// ============================================================================ + +// TestTryLockSuccess 测试 TryLock 成功场景 +func 
TestTryLockSuccess(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/trylock") + + // TryLock 应该立即成功 + err = mutex.TryLock(ctx) + require.NoError(t, err) + assert.True(t, mutex.IsOwner()) + + mutex.Unlock(ctx) +} + +// TestTryLockFail 测试 TryLock 失败场景 +func TestTryLockFail(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 第一个会话获取锁 + session1, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session1.Close() + + mutex1 := concurrency.NewMutex(session1, "/test/trylock-fail") + err = mutex1.Lock(ctx) + require.NoError(t, err) + + // 第二个会话尝试 TryLock + session2, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := concurrency.NewMutex(session2, "/test/trylock-fail") + err = mutex2.TryLock(ctx) + + // 应该返回 ErrLocked + assert.Error(t, err) + assert.Equal(t, etcdconcurrency.ErrLocked, err) + assert.False(t, mutex2.IsOwner()) + + mutex1.Unlock(ctx) +} + +// TestTryLockAfterUnlock 测试解锁后 TryLock +func TestTryLockAfterUnlock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session1, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session1.Close() + + session2, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session2.Close() + + mutex1 := concurrency.NewMutex(session1, "/test/trylock-after-unlock") + mutex2 := concurrency.NewMutex(session2, "/test/trylock-after-unlock") + + // session1 获取锁 + err = mutex1.Lock(ctx) + require.NoError(t, err) + + // session2 TryLock 失败 + err = mutex2.TryLock(ctx) + assert.Equal(t, etcdconcurrency.ErrLocked, err) + + // session1 释放锁 + err = mutex1.Unlock(ctx) + 
require.NoError(t, err) + + // session2 TryLock 成功 + err = mutex2.TryLock(ctx) + require.NoError(t, err) + assert.True(t, mutex2.IsOwner()) + + mutex2.Unlock(ctx) +} + +// ============================================================================ +// Concurrent Lock Tests +// ============================================================================ + +// TestMutexContention 测试锁竞争 +func TestMutexContention(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numClients = 5 + var wg sync.WaitGroup + acquired := make(chan int, numClients) + released := make(chan int, numClients) + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/contention") + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + acquired <- id + t.Logf("Client %d acquired lock", id) + + // 持有锁一小段时间 + time.Sleep(50 * time.Millisecond) + + // 释放锁 + err = mutex.Unlock(ctx) + require.NoError(t, err) + + released <- id + t.Logf("Client %d released lock", id) + }(i) + } + + // 等待所有 goroutine 完成 + wg.Wait() + close(acquired) + close(released) + + // 验证每个客户端都获取并释放了锁 + acquiredClients := make(map[int]bool) + for id := range acquired { + acquiredClients[id] = true + } + assert.Len(t, acquiredClients, numClients) + + releasedClients := make(map[int]bool) + for id := range released { + releasedClients[id] = true + } + assert.Len(t, releasedClients, numClients) +} + +// TestMutexFIFOOrder 测试锁的 FIFO 顺序 +func TestMutexFIFOOrder(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numClients = 5 + var orderMu sync.Mutex + acquireOrder := make([]int, 0, numClients) + + // 创建一个信号通道来控制启动顺序 + startSignals := make([]chan struct{}, numClients) + for i := range startSignals { + startSignals[i] = make(chan struct{}) + } + + var 
wg sync.WaitGroup + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // 等待启动信号 + <-startSignals[id] + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/fifo") + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 记录获取顺序 + orderMu.Lock() + acquireOrder = append(acquireOrder, id) + orderMu.Unlock() + + t.Logf("Client %d acquired lock at position %d", id, len(acquireOrder)) + + // 持有锁一小段时间 + time.Sleep(20 * time.Millisecond) + + mutex.Unlock(ctx) + }(i) + } + + // 按顺序发送启动信号 + for i := 0; i < numClients; i++ { + close(startSignals[i]) + time.Sleep(30 * time.Millisecond) // 确保按顺序注册到锁队列 + } + + wg.Wait() + + // 验证获取顺序 + t.Logf("Acquire order: %v", acquireOrder) + assert.Len(t, acquireOrder, numClients) + + // 验证是 FIFO 顺序 + expectedOrder := make([]int, numClients) + for i := range expectedOrder { + expectedOrder[i] = i + } + assert.Equal(t, expectedOrder, acquireOrder, "Lock acquisition should follow FIFO order") +} + +// TestMutexCriticalSection 测试临界区保护 +func TestMutexCriticalSection(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numClients = 10 + const iterations = 5 + var counter int64 + var violations int64 + + var wg sync.WaitGroup + for i := 0; i < numClients; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/critical-section") + + for j := 0; j < iterations; j++ { + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 临界区操作 + oldVal := atomic.LoadInt64(&counter) + time.Sleep(time.Millisecond) // 模拟工作 + newVal := atomic.AddInt64(&counter, 1) + + // 检查是否有竞态条件 + if newVal != oldVal+1 { + atomic.AddInt64(&violations, 1) + } + + mutex.Unlock(ctx) + } + }() + } + + wg.Wait() + + 
assert.Equal(t, int64(numClients*iterations), atomic.LoadInt64(&counter)) + assert.Equal(t, int64(0), atomic.LoadInt64(&violations), "No race conditions should occur") +} + +// ============================================================================ +// Lock Timeout and Cancellation Tests +// ============================================================================ + +// TestMutexLockWithTimeout 测试带超时的锁获取 +func TestMutexLockWithTimeout(t *testing.T) { + _, cli := startLockTestServer(t) + + // 第一个会话持有锁 + session1, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session1.Close() + + mutex1 := concurrency.NewMutex(session1, "/test/timeout") + err = mutex1.Lock(context.Background()) + require.NoError(t, err) + + // 第二个会话尝试获取锁,带超时 + session2, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := concurrency.NewMutex(session2, "/test/timeout") + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + start := time.Now() + err = mutex2.Lock(ctx) + + elapsed := time.Since(start) + assert.Error(t, err) + assert.True(t, elapsed >= 400*time.Millisecond && elapsed < 1*time.Second, + "Lock should timeout around 500ms, got %v", elapsed) + + mutex1.Unlock(context.Background()) +} + +// TestMutexLockCancellation 测试锁获取取消 +func TestMutexLockCancellation(t *testing.T) { + _, cli := startLockTestServer(t) + + // 第一个会话持有锁 + session1, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session1.Close() + + mutex1 := concurrency.NewMutex(session1, "/test/cancel") + err = mutex1.Lock(context.Background()) + require.NoError(t, err) + + // 第二个会话尝试获取锁 + session2, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := concurrency.NewMutex(session2, "/test/cancel") + + ctx, cancel := 
context.WithCancel(context.Background()) + + // 启动 goroutine 获取锁 + done := make(chan error, 1) + go func() { + done <- mutex2.Lock(ctx) + }() + + // 等待一会儿然后取消 + time.Sleep(200 * time.Millisecond) + cancel() + + // 验证锁获取被取消 + select { + case err := <-done: + assert.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + case <-time.After(2 * time.Second): + t.Fatal("Lock should be canceled") + } + + mutex1.Unlock(context.Background()) +} + +// ============================================================================ +// Session Failure Tests +// ============================================================================ + +// TestMutexReleaseOnSessionClose 测试 Session 关闭时锁自动释放 +func TestMutexReleaseOnSessionClose(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 第一个会话获取锁 + session1, err := concurrency.NewSession(cli, concurrency.WithTTL(5)) + require.NoError(t, err) + + mutex1 := concurrency.NewMutex(session1, "/test/session-close") + err = mutex1.Lock(ctx) + require.NoError(t, err) + + // 第二个会话准备获取锁 + session2, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := concurrency.NewMutex(session2, "/test/session-close") + + // 启动 goroutine 等待锁 + acquired := make(chan struct{}) + go func() { + err := mutex2.Lock(ctx) + if err == nil { + close(acquired) + } + }() + + // 关闭第一个会话 + time.Sleep(100 * time.Millisecond) + session1.Close() + + // 验证第二个会话能获取锁 + select { + case <-acquired: + t.Log("Second session acquired lock after first session closed") + assert.True(t, mutex2.IsOwner()) + case <-time.After(5 * time.Second): + t.Fatal("Second session should acquire lock after first session closes") + } + + mutex2.Unlock(ctx) +} + +// ============================================================================ +// Election Tests +// ============================================================================ + +// TestElectionCampaign 测试 Leader 选举 +func 
TestElectionCampaign(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + election := concurrency.NewElection(session, "/test/election") + + // 初始状态 + assert.False(t, election.IsLeader()) + + // 竞选 Leader + err = election.Campaign(ctx, "leader-value") + require.NoError(t, err) + + // 验证成为 Leader + assert.True(t, election.IsLeader()) + assert.NotEmpty(t, election.Key()) + assert.Greater(t, election.Rev(), int64(0)) + + // 查询当前 Leader + _, val, err := election.Leader(ctx) + require.NoError(t, err) + assert.Equal(t, "leader-value", val) + + // 放弃 Leader + err = election.Resign(ctx) + require.NoError(t, err) + + assert.False(t, election.IsLeader()) +} + +// TestElectionMultipleCandidates 测试多候选人选举 +func TestElectionMultipleCandidates(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numCandidates = 3 + var wg sync.WaitGroup + leaderChan := make(chan int, numCandidates) + + for i := 0; i < numCandidates; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + election := concurrency.NewElection(session, "/test/multi-election") + + // 竞选 + value := fmt.Sprintf("candidate-%d", id) + err = election.Campaign(ctx, value) + require.NoError(t, err) + + leaderChan <- id + t.Logf("Candidate %d became leader", id) + + // 持有一段时间 + time.Sleep(100 * time.Millisecond) + + // 放弃 + election.Resign(ctx) + t.Logf("Candidate %d resigned", id) + }(i) + } + + // 等待所有候选人完成 + wg.Wait() + close(leaderChan) + + // 验证所有候选人都成为过 Leader + leaders := make(map[int]bool) + for id := range leaderChan { + leaders[id] = true + } + assert.Len(t, leaders, numCandidates) +} + +// TestElectionObserve 测试 Leader 变化观察 +func TestElectionObserve(t *testing.T) { + _, cli := startLockTestServer(t) + 
+ session1, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session1.Close() + + session2, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session2.Close() + + election1 := concurrency.NewElection(session1, "/test/observe") + election2 := concurrency.NewElection(session2, "/test/observe") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // election1 成为 Leader + err = election1.Campaign(ctx, "leader-1") + require.NoError(t, err) + + // 启动观察者 + observeCh := election2.Observe(ctx) + + // 收集观察到的 Leader + var observedLeaders []string + done := make(chan struct{}) + + go func() { + defer close(done) + for i := 0; i < 3; i++ { + select { + case leader, ok := <-observeCh: + if !ok { + return + } + observedLeaders = append(observedLeaders, leader) + t.Logf("Observed leader: %s", leader) + case <-ctx.Done(): + return + } + } + }() + + // 等待第一次观察 + time.Sleep(200 * time.Millisecond) + + // election1 放弃 + election1.Resign(ctx) + time.Sleep(100 * time.Millisecond) + + // election2 成为 Leader + err = election2.Campaign(ctx, "leader-2") + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + cancel() + <-done + + // 验证观察到了 Leader 变化 + t.Logf("Observed leaders: %v", observedLeaders) + assert.GreaterOrEqual(t, len(observedLeaders), 1) +} + +// TestElectionResignNotLeader 测试非 Leader 放弃 +func TestElectionResignNotLeader(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + election := concurrency.NewElection(session, "/test/resign-not-leader") + + // 未成为 Leader 就放弃 + err = election.Resign(ctx) + assert.Error(t, err) + assert.Equal(t, concurrency.ErrElectionNotLeader, err) +} + +// ============================================================================ +// Stress Tests +// 
============================================================================ + +// TestMutexHighConcurrency 高并发锁测试 +func TestMutexHighConcurrency(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numGoroutines = 20 + const iterations = 10 + var wg sync.WaitGroup + var successCount int64 + var failCount int64 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + if err != nil { + atomic.AddInt64(&failCount, int64(iterations)) + return + } + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/high-concurrency") + + for j := 0; j < iterations; j++ { + lockCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + err := mutex.Lock(lockCtx) + cancel() + + if err != nil { + atomic.AddInt64(&failCount, 1) + continue + } + + atomic.AddInt64(&successCount, 1) + + // 短暂持有锁 + time.Sleep(5 * time.Millisecond) + + mutex.Unlock(ctx) + } + }(i) + } + + wg.Wait() + + t.Logf("Success: %d, Fail: %d", successCount, failCount) + assert.Equal(t, int64(numGoroutines*iterations), successCount) + assert.Equal(t, int64(0), failCount) +} + +// TestMutexRapidLockUnlock 快速加解锁测试 +func TestMutexRapidLockUnlock(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/rapid") + + const iterations = 100 + for i := 0; i < iterations; i++ { + err := mutex.Lock(ctx) + require.NoError(t, err, "Lock failed at iteration %d", i) + + assert.True(t, mutex.IsOwner()) + + err = mutex.Unlock(ctx) + require.NoError(t, err, "Unlock failed at iteration %d", i) + + assert.False(t, mutex.IsOwner()) + } +} + +// ============================================================================ +// Edge Case Tests +// 
============================================================================ + +// TestMutexDifferentPrefixes 测试不同前缀的锁互不影响 +func TestMutexDifferentPrefixes(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex1 := concurrency.NewMutex(session, "/test/prefix1") + mutex2 := concurrency.NewMutex(session, "/test/prefix2") + + // 同时获取两个不同前缀的锁 + err = mutex1.Lock(ctx) + require.NoError(t, err) + + err = mutex2.Lock(ctx) + require.NoError(t, err) + + // 两个都应该成功 + assert.True(t, mutex1.IsOwner()) + assert.True(t, mutex2.IsOwner()) + + mutex1.Unlock(ctx) + mutex2.Unlock(ctx) +} + +// TestMutexSameSessionDifferentMutex 测试同一会话的不同 Mutex 实例 +func TestMutexSameSessionDifferentMutex(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + // 同一会话创建两个 Mutex 实例(相同前缀) + mutex1 := concurrency.NewMutex(session, "/test/same-prefix") + mutex2 := concurrency.NewMutex(session, "/test/same-prefix") + + // mutex1 获取锁 + err = mutex1.Lock(ctx) + require.NoError(t, err) + + // mutex2 也能获取锁(因为使用相同的 Lease,key 相同) + err = mutex2.Lock(ctx) + require.NoError(t, err) + + // 两个都认为自己是 owner + assert.True(t, mutex1.IsOwner()) + assert.True(t, mutex2.IsOwner()) + + // 但实际上是同一个 key + assert.Equal(t, mutex1.Key(), mutex2.Key()) + + mutex1.Unlock(ctx) +} + +// TestMutexEmptyPrefix 测试空前缀 +func TestMutexEmptyPrefix(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "") + + err = mutex.Lock(ctx) + require.NoError(t, err) + assert.True(t, mutex.IsOwner()) + + mutex.Unlock(ctx) +} + +// 
TestMutexSpecialCharacterPrefix 测试特殊字符前缀 +func TestMutexSpecialCharacterPrefix(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + prefixes := []string{ + "/test/special/chars", + "/test/with spaces", + "/test/with-dashes", + "/test/with_underscores", + "/test/with.dots", + } + + for _, prefix := range prefixes { + t.Run(prefix, func(t *testing.T) { + mutex := concurrency.NewMutex(session, prefix) + err := mutex.Lock(ctx) + require.NoError(t, err) + assert.True(t, mutex.IsOwner()) + mutex.Unlock(ctx) + }) + } +} + +// ============================================================================ +// Benchmark Tests +// ============================================================================ + +// BenchmarkMutexLockUnlock 基准测试锁性能 +func BenchmarkMutexLockUnlock(b *testing.B) { + store := memory.NewMemoryEtcd() + server, err := etcdapi.NewServer(etcdapi.ServerConfig{ + Store: store, + Address: "127.0.0.1:0", + ClusterID: 1, + MemberID: 1, + }) + if err != nil { + b.Fatal(err) + } + + go server.Start() + time.Sleep(100 * time.Millisecond) + defer server.Stop() + + cli, err := clientv3.New(clientv3.Config{ + Endpoints: []string{server.Address()}, + DialTimeout: 5 * time.Second, + }) + if err != nil { + b.Fatal(err) + } + defer cli.Close() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + if err != nil { + b.Fatal(err) + } + defer session.Close() + + mutex := concurrency.NewMutex(session, "/bench/lock") + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mutex.Lock(ctx) + mutex.Unlock(ctx) + } +} + +// BenchmarkTryLock 基准测试 TryLock 性能 +func BenchmarkTryLock(b *testing.B) { + store := memory.NewMemoryEtcd() + server, err := etcdapi.NewServer(etcdapi.ServerConfig{ + Store: store, + Address: "127.0.0.1:0", + ClusterID: 1, + MemberID: 1, + }) + if err != nil { + 
b.Fatal(err) + } + + go server.Start() + time.Sleep(100 * time.Millisecond) + defer server.Stop() + + cli, err := clientv3.New(clientv3.Config{ + Endpoints: []string{server.Address()}, + DialTimeout: 5 * time.Second, + }) + if err != nil { + b.Fatal(err) + } + defer cli.Close() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + if err != nil { + b.Fatal(err) + } + defer session.Close() + + mutex := concurrency.NewMutex(session, "/bench/trylock") + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := mutex.TryLock(ctx) + if err == nil { + mutex.Unlock(ctx) + } + } +} + +// BenchmarkSessionCreate 基准测试会话创建性能 +func BenchmarkSessionCreate(b *testing.B) { + store := memory.NewMemoryEtcd() + server, err := etcdapi.NewServer(etcdapi.ServerConfig{ + Store: store, + Address: "127.0.0.1:0", + ClusterID: 1, + MemberID: 1, + }) + if err != nil { + b.Fatal(err) + } + + go server.Start() + time.Sleep(100 * time.Millisecond) + defer server.Stop() + + cli, err := clientv3.New(clientv3.Config{ + Endpoints: []string{server.Address()}, + DialTimeout: 5 * time.Second, + }) + if err != nil { + b.Fatal(err) + } + defer cli.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + session, err := concurrency.NewSession(cli, concurrency.WithTTL(10)) + if err != nil { + b.Fatal(err) + } + session.Close() + } +} + +// ============================================================================ +// Integration Tests with etcd concurrency package +// ============================================================================ + +// TestCompatibilityWithEtcdConcurrency 测试与 etcd 官方 concurrency 包的兼容性 +func TestCompatibilityWithEtcdConcurrency(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 使用 etcd 官方的 concurrency 包创建会话和锁 + etcdSession, err := etcdconcurrency.NewSession(cli, etcdconcurrency.WithTTL(30)) + require.NoError(t, err) + defer etcdSession.Close() + + etcdMutex := 
etcdconcurrency.NewMutex(etcdSession, "/test/etcd-compat") + + // 获取锁 + err = etcdMutex.Lock(ctx) + require.NoError(t, err) + + // 验证锁状态 + assert.NotEmpty(t, etcdMutex.Key()) + + // 释放锁 + err = etcdMutex.Unlock(ctx) + require.NoError(t, err) +} + +// TestMixedLockUsage 测试混合使用自定义和 etcd 官方的锁 +func TestMixedLockUsage(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 使用自定义的 concurrency 包 + customSession, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer customSession.Close() + + customMutex := concurrency.NewMutex(customSession, "/test/mixed") + + // 使用 etcd 官方的 concurrency 包 + etcdSession, err := etcdconcurrency.NewSession(cli, etcdconcurrency.WithTTL(30)) + require.NoError(t, err) + defer etcdSession.Close() + + etcdMutex := etcdconcurrency.NewMutex(etcdSession, "/test/mixed") + + // 自定义锁获取 + err = customMutex.Lock(ctx) + require.NoError(t, err) + + // etcd 锁尝试获取应该失败 + tryCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + err = etcdMutex.Lock(tryCtx) + cancel() + assert.Error(t, err, "etcd mutex should not be able to acquire lock") + + // 释放自定义锁 + customMutex.Unlock(ctx) + + // 现在 etcd 锁应该能获取 + err = etcdMutex.Lock(ctx) + require.NoError(t, err) + + etcdMutex.Unlock(ctx) +} + +// ============================================================================ +// Verify Lock Key Format +// ============================================================================ + +// TestMutexKeyFormat 验证锁 key 格式 +func TestMutexKeyFormat(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + prefix := "/test/key-format" + mutex := concurrency.NewMutex(session, prefix) + + err = mutex.Lock(ctx) + require.NoError(t, err) + + key := mutex.Key() + t.Logf("Lock key: %s", key) + + // 验证 key 格式: prefix/ + lease_id(十六进制) + assert.Contains(t, key, 
prefix+"/") + assert.Contains(t, key, fmt.Sprintf("%x", session.Lease())) + + mutex.Unlock(ctx) +} + +// ============================================================================ +// Ordering Verification Tests +// ============================================================================ + +// TestLockAcquisitionOrderWithTimestamp 测试锁获取顺序(带时间戳验证) +func TestLockAcquisitionOrderWithTimestamp(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numClients = 5 + type lockEvent struct { + id int + timestamp time.Time + } + + var mu sync.Mutex + events := make([]lockEvent, 0, numClients) + + var wg sync.WaitGroup + startCh := make(chan struct{}) + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/order-timestamp") + + // 等待启动信号 + <-startCh + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 记录获取时间 + mu.Lock() + events = append(events, lockEvent{id: id, timestamp: time.Now()}) + mu.Unlock() + + time.Sleep(10 * time.Millisecond) + mutex.Unlock(ctx) + }(i) + } + + // 同时启动所有 goroutine + close(startCh) + wg.Wait() + + // 验证事件顺序 + assert.Len(t, events, numClients) + + // 验证时间戳是递增的 + for i := 1; i < len(events); i++ { + assert.True(t, events[i].timestamp.After(events[i-1].timestamp) || + events[i].timestamp.Equal(events[i-1].timestamp), + "Lock acquisition timestamps should be ordered") + } + + // 打印顺序 + var order []int + for _, e := range events { + order = append(order, e.id) + } + t.Logf("Acquisition order: %v", order) +} + +// ============================================================================ +// Recovery Tests +// ============================================================================ + +// TestMutexRecoveryAfterSessionClose 测试会话关闭后的锁恢复 +func TestMutexRecoveryAfterSessionClose(t *testing.T) { + 
_, cli := startLockTestServer(t) + ctx := context.Background() + + prefix := "/test/recovery" + + // 第一个会话获取锁 + session1, err := concurrency.NewSession(cli, concurrency.WithTTL(5)) + require.NoError(t, err) + + mutex1 := concurrency.NewMutex(session1, prefix) + err = mutex1.Lock(ctx) + require.NoError(t, err) + t.Log("Session 1 acquired lock") + + // 关闭第一个会话 + session1.Close() + t.Log("Session 1 closed") + + // 第二个会话应该能获取锁 + session2, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session2.Close() + + mutex2 := concurrency.NewMutex(session2, prefix) + + // 应该能够获取锁 + lockCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + err = mutex2.Lock(lockCtx) + cancel() + + require.NoError(t, err, "Session 2 should acquire lock after session 1 closes") + assert.True(t, mutex2.IsOwner()) + t.Log("Session 2 acquired lock") + + mutex2.Unlock(ctx) +} + +// ============================================================================ +// Additional Concurrency Tests +// ============================================================================ + +// TestMultipleLocksSequential 测试顺序获取多个锁 +func TestMultipleLocksSequential(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + require.NoError(t, err) + defer session.Close() + + locks := make([]*concurrency.Mutex, 5) + for i := range locks { + locks[i] = concurrency.NewMutex(session, fmt.Sprintf("/test/multi/%d", i)) + } + + // 顺序获取所有锁 + for i, lock := range locks { + err := lock.Lock(ctx) + require.NoError(t, err, "Failed to acquire lock %d", i) + } + + // 验证所有锁都被持有 + for i, lock := range locks { + assert.True(t, lock.IsOwner(), "Lock %d should be owned", i) + } + + // 顺序释放所有锁 + for i, lock := range locks { + err := lock.Unlock(ctx) + require.NoError(t, err, "Failed to release lock %d", i) + } +} + +// TestConcurrentDifferentLocks 测试并发获取不同的锁 +func TestConcurrentDifferentLocks(t 
*testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numLocks = 10 + var wg sync.WaitGroup + errors := make(chan error, numLocks) + + for i := 0; i < numLocks; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(30)) + if err != nil { + errors <- err + return + } + defer session.Close() + + mutex := concurrency.NewMutex(session, fmt.Sprintf("/test/concurrent/%d", id)) + + if err := mutex.Lock(ctx); err != nil { + errors <- err + return + } + + time.Sleep(10 * time.Millisecond) + + if err := mutex.Unlock(ctx); err != nil { + errors <- err + return + } + }(i) + } + + wg.Wait() + close(errors) + + // 检查是否有错误 + for err := range errors { + t.Errorf("Unexpected error: %v", err) + } +} + +// TestLockFairness 测试锁公平性 +func TestLockFairness(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numRounds = 5 + const numClients = 3 + + var mu sync.Mutex + acquisitions := make(map[int]int) + + for round := 0; round < numRounds; round++ { + var wg sync.WaitGroup + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/fairness") + + err = mutex.Lock(ctx) + require.NoError(t, err) + + mu.Lock() + acquisitions[id]++ + mu.Unlock() + + time.Sleep(5 * time.Millisecond) + mutex.Unlock(ctx) + }(i) + } + + wg.Wait() + } + + // 验证每个客户端都获取了锁 + t.Logf("Acquisitions: %v", acquisitions) + for i := 0; i < numClients; i++ { + assert.Greater(t, acquisitions[i], 0, "Client %d should have acquired lock at least once", i) + } + + // 验证分布相对均匀(每个客户端应该获取约 numRounds 次) + total := 0 + for _, count := range acquisitions { + total += count + } + assert.Equal(t, numRounds*numClients, total) +} + +// TestLockWithContextDeadline 测试带截止时间的锁 +func 
TestLockWithContextDeadline(t *testing.T) { + _, cli := startLockTestServer(t) + + session1, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session1.Close() + + session2, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session2.Close() + + mutex1 := concurrency.NewMutex(session1, "/test/deadline") + mutex2 := concurrency.NewMutex(session2, "/test/deadline") + + // session1 获取锁 + ctx1 := context.Background() + err = mutex1.Lock(ctx1) + require.NoError(t, err) + + // session2 尝试获取锁,带截止时间 + deadline := time.Now().Add(500 * time.Millisecond) + ctx2, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + start := time.Now() + err = mutex2.Lock(ctx2) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.True(t, elapsed >= 400*time.Millisecond, "Should wait until deadline") + assert.True(t, elapsed < 1*time.Second, "Should not wait too long after deadline") + + mutex1.Unlock(ctx1) +} + +// ============================================================================ +// Data Race Detection Tests +// ============================================================================ + +// TestMutexNoDataRace 测试无数据竞争 +func TestMutexNoDataRace(t *testing.T) { + _, cli := startLockTestServer(t) + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/race") + + var wg sync.WaitGroup + + // 并发调用各种方法 + for i := 0; i < 10; i++ { + wg.Add(3) + + go func() { + defer wg.Done() + _ = mutex.IsOwner() + }() + + go func() { + defer wg.Done() + _ = mutex.Key() + }() + + go func() { + defer wg.Done() + _ = mutex.Header() + }() + } + + wg.Wait() +} + +// TestSessionNoDataRace 测试 Session 无数据竞争 +func TestSessionNoDataRace(t *testing.T) { + _, cli := startLockTestServer(t) + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + 
require.NoError(t, err) + + var wg sync.WaitGroup + + // 并发调用各种方法 + for i := 0; i < 10; i++ { + wg.Add(2) + + go func() { + defer wg.Done() + _ = session.Lease() + }() + + go func() { + defer wg.Done() + _ = session.Done() + }() + } + + wg.Wait() + session.Close() +} + +// ============================================================================ +// Edge Cases for Watch-based Waiting +// ============================================================================ + +// TestMutexWaitingQueue 测试锁等待队列 +func TestMutexWaitingQueue(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + const numWaiters = 5 + var orderMu sync.Mutex + order := make([]int, 0, numWaiters) + + // 信号通道用于同步 + ready := make([]chan struct{}, numWaiters) + for i := range ready { + ready[i] = make(chan struct{}) + } + + var wg sync.WaitGroup + + for i := 0; i < numWaiters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/queue") + + // 通知已准备好 + close(ready[id]) + + // 获取锁 + err = mutex.Lock(ctx) + require.NoError(t, err) + + // 记录顺序 + orderMu.Lock() + order = append(order, id) + orderMu.Unlock() + + time.Sleep(20 * time.Millisecond) + mutex.Unlock(ctx) + }(i) + + // 等待 goroutine 准备好后再启动下一个 + <-ready[i] + time.Sleep(30 * time.Millisecond) + } + + wg.Wait() + + t.Logf("Acquisition order: %v", order) + assert.Len(t, order, numWaiters) + + // 验证顺序 + expected := make([]int, numWaiters) + for i := range expected { + expected[i] = i + } + assert.Equal(t, expected, order) +} + +// TestMutexWatchEventHandling 测试 Watch 事件处理 +func TestMutexWatchEventHandling(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + // 创建多个会话和锁 + const numSessions = 3 + sessions := make([]*concurrency.Session, numSessions) + mutexes := make([]*concurrency.Mutex, numSessions) + + for i := 
range sessions { + var err error + sessions[i], err = concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + mutexes[i] = concurrency.NewMutex(sessions[i], "/test/watch-events") + } + + defer func() { + for _, s := range sessions { + s.Close() + } + }() + + // 第一个会话获取锁 + err := mutexes[0].Lock(ctx) + require.NoError(t, err) + + // 其他会话尝试获取锁(会等待) + done := make([]chan error, numSessions-1) + for i := 1; i < numSessions; i++ { + done[i-1] = make(chan error, 1) + go func(idx int) { + done[idx-1] <- mutexes[idx].Lock(ctx) + }(i) + } + + // 等待其他会话进入等待状态 + time.Sleep(200 * time.Millisecond) + + // 释放第一个锁 + err = mutexes[0].Unlock(ctx) + require.NoError(t, err) + + // 验证等待的会话依次获取锁 + for i := 1; i < numSessions; i++ { + select { + case err := <-done[i-1]: + require.NoError(t, err) + t.Logf("Session %d acquired lock", i) + mutexes[i].Unlock(ctx) + case <-time.After(5 * time.Second): + t.Fatalf("Session %d failed to acquire lock", i) + } + } +} + +// ============================================================================ +// Performance Characterization Tests +// ============================================================================ + +// TestLockLatencyDistribution 测试锁延迟分布 +func TestLockLatencyDistribution(t *testing.T) { + _, cli := startLockTestServer(t) + ctx := context.Background() + + session, err := concurrency.NewSession(cli, concurrency.WithTTL(60)) + require.NoError(t, err) + defer session.Close() + + mutex := concurrency.NewMutex(session, "/test/latency") + + const iterations = 50 + latencies := make([]time.Duration, iterations) + + for i := 0; i < iterations; i++ { + start := time.Now() + err := mutex.Lock(ctx) + latencies[i] = time.Since(start) + require.NoError(t, err) + mutex.Unlock(ctx) + } + + // 计算统计信息 + sort.Slice(latencies, func(i, j int) bool { + return latencies[i] < latencies[j] + }) + + var total time.Duration + for _, l := range latencies { + total += l + } + + avg := total / time.Duration(iterations) + p50 := 
latencies[iterations/2] + p95 := latencies[iterations*95/100] + p99 := latencies[iterations*99/100] + + t.Logf("Lock latency distribution (n=%d):", iterations) + t.Logf(" Average: %v", avg) + t.Logf(" P50: %v", p50) + t.Logf(" P95: %v", p95) + t.Logf(" P99: %v", p99) + t.Logf(" Min: %v", latencies[0]) + t.Logf(" Max: %v", latencies[iterations-1]) + + // 验证延迟合理 + assert.Less(t, avg, 100*time.Millisecond, "Average latency should be reasonable") +} diff --git a/test/distributed_lock_rocksdb_test.go b/test/distributed_lock_rocksdb_test.go new file mode 100644 index 0000000..92fbfac --- /dev/null +++ b/test/distributed_lock_rocksdb_test.go @@ -0,0 +1,967 @@ +// Copyright 2025 The axfor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+//go:build cgo
+// +build cgo
+
+package test
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"metaStore/pkg/concurrency"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	clientv3 "go.etcd.io/etcd/client/v3"
+	etcdconcurrency "go.etcd.io/etcd/client/v3/concurrency"
+)
+
+// ============================================================================
+// RocksDB Distributed Lock Test Helper
+// ============================================================================
+
+// startRocksDBLockTestServer starts one node through startRocksDBNode and
+// returns an etcd v3 client connected to that node's client address together
+// with a teardown function. Callers are expected to invoke the teardown
+// function (typically via defer) once the test is done.
+func startRocksDBLockTestServer(t *testing.T) (*clientv3.Client, func()) {
+	t.Helper()
+
+	node, cleanup := startRocksDBNode(t, 1)
+
+	cli, err := clientv3.New(clientv3.Config{
+		Endpoints:   []string{node.clientAddr},
+		DialTimeout: 5 * time.Second,
+	})
+	if err != nil {
+		// Do not leak the already started node when the client cannot be created.
+		cleanup()
+		require.NoError(t, err)
+	}
+
+	return cli, func() {
+		// Important: leave some time for session cleanup to finish, which
+		// avoids "connection is closing" errors during teardown.
+		time.Sleep(500 * time.Millisecond)
+		// Shut the server down first ...
+		cleanup()
+		// ... then close the client. Best effort: the server is already gone,
+		// so a close error carries no useful information for the test.
+		_ = cli.Close()
+	}
+}
+
+// ============================================================================
+// RocksDB Session Tests
+// ============================================================================
+
+// TestRocksDB_SessionCreate checks that a new session owns a live lease whose
+// TTL is within (0, 10] and that closing the session makes TimeToLive report
+// a TTL of -1 for that lease.
+func TestRocksDB_SessionCreate(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(10))
+	require.NoError(t, err)
+	require.NotNil(t, session)
+
+	leaseID := session.Lease()
+	assert.NotEqual(t, clientv3.NoLease, leaseID)
+
+	ttlResp, err := cli.TimeToLive(ctx, leaseID)
+	require.NoError(t, err)
+	assert.Greater(t, ttlResp.TTL, int64(0))
+	assert.LessOrEqual(t, ttlResp.TTL, int64(10))
+
+	err = session.Close()
+	require.NoError(t, err)
+
+	ttlResp, err = cli.TimeToLive(ctx, leaseID)
+	require.NoError(t, err)
+	assert.Equal(t, int64(-1), ttlResp.TTL)
+}
+
+// TestRocksDB_SessionWithExistingLease checks that a session created with
+// WithLease reports the ID of the lease it was given.
+func TestRocksDB_SessionWithExistingLease(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	leaseResp, err := cli.Grant(ctx, 30)
+	require.NoError(t, err)
+
+	session, err := concurrency.NewSession(cli, concurrency.WithLease(leaseResp.ID))
+	require.NoError(t, err)
+	require.NotNil(t, session)
+
+	assert.Equal(t, leaseResp.ID, session.Lease())
+	assert.NoError(t, session.Close())
+}
+
+// ============================================================================
+// RocksDB Basic Mutex Tests
+// ============================================================================
+
+// TestRocksDB_MutexLockUnlock checks the ownership state (IsOwner, Key,
+// Header) of a mutex before Lock, after Lock and after Unlock.
+func TestRocksDB_MutexLockUnlock(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session.Close()
+
+	mutex := concurrency.NewMutex(session, "/rocksdb/test/lock")
+
+	assert.False(t, mutex.IsOwner())
+	assert.Empty(t, mutex.Key())
+
+	err = mutex.Lock(ctx)
+	require.NoError(t, err)
+
+	assert.True(t, mutex.IsOwner())
+	assert.NotEmpty(t, mutex.Key())
+	assert.NotNil(t, mutex.Header())
+
+	err = mutex.Unlock(ctx)
+	require.NoError(t, err)
+
+	assert.False(t, mutex.IsOwner())
+	assert.Empty(t, mutex.Key())
+}
+
+// TestRocksDB_MutexReentrantLock checks that calling Lock twice on the same
+// mutex succeeds and leaves the lock key unchanged.
+func TestRocksDB_MutexReentrantLock(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session.Close()
+
+	mutex := concurrency.NewMutex(session, "/rocksdb/test/reentrant")
+
+	err = mutex.Lock(ctx)
+	require.NoError(t, err)
+	firstKey := mutex.Key()
+
+	err = mutex.Lock(ctx)
+	require.NoError(t, err)
+
+	assert.Equal(t, firstKey, mutex.Key())
+
+	err = mutex.Unlock(ctx)
+	require.NoError(t, err)
+}
+
+// ============================================================================
+// RocksDB TryLock Tests
+// ============================================================================
+
+// TestRocksDB_TryLockSuccess checks that TryLock on an uncontended mutex
+// succeeds and makes the caller the owner.
+func TestRocksDB_TryLockSuccess(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session.Close()
+
+	mutex := concurrency.NewMutex(session, "/rocksdb/test/trylock")
+
+	err = mutex.TryLock(ctx)
+	require.NoError(t, err)
+	assert.True(t, mutex.IsOwner())
+
+	assert.NoError(t, mutex.Unlock(ctx))
+}
+
+// TestRocksDB_TryLockFail checks that TryLock returns
+// etcdconcurrency.ErrLocked while another session holds the lock.
+func TestRocksDB_TryLockFail(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session1, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session1.Close()
+
+	mutex1 := concurrency.NewMutex(session1, "/rocksdb/test/trylock-fail")
+	err = mutex1.Lock(ctx)
+	require.NoError(t, err)
+
+	session2, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session2.Close()
+
+	mutex2 := concurrency.NewMutex(session2, "/rocksdb/test/trylock-fail")
+	err = mutex2.TryLock(ctx)
+
+	assert.Error(t, err)
+	assert.Equal(t, etcdconcurrency.ErrLocked, err)
+	assert.False(t, mutex2.IsOwner())
+
+	assert.NoError(t, mutex1.Unlock(ctx))
+}
+
+// TestRocksDB_TryLockAfterUnlock checks that TryLock fails with ErrLocked
+// while the lock is held and succeeds once the holder has unlocked.
+func TestRocksDB_TryLockAfterUnlock(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session1, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session1.Close()
+
+	session2, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session2.Close()
+
+	mutex1 := concurrency.NewMutex(session1, "/rocksdb/test/trylock-after-unlock")
+	mutex2 := concurrency.NewMutex(session2, "/rocksdb/test/trylock-after-unlock")
+
+	err = mutex1.Lock(ctx)
+	require.NoError(t, err)
+
+	err = mutex2.TryLock(ctx)
+	assert.Equal(t, etcdconcurrency.ErrLocked, err)
+
+	err = mutex1.Unlock(ctx)
+	require.NoError(t, err)
+
+	err = mutex2.TryLock(ctx)
+	require.NoError(t, err)
+	assert.True(t, mutex2.IsOwner())
+
+	assert.NoError(t, mutex2.Unlock(ctx))
+}
+
+// ============================================================================
+// RocksDB Concurrent Lock Tests
+// ============================================================================
+
+// TestRocksDB_MutexContention lets numClients goroutines, each with its own
+// session, compete for one lock and checks that every client both acquired
+// and released it.
+func TestRocksDB_MutexContention(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	const numClients = 5
+	var wg sync.WaitGroup
+	acquired := make(chan int, numClients)
+	released := make(chan int, numClients)
+
+	for i := 0; i < numClients; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+
+			// assert + return instead of require: FailNow must only be
+			// called from the goroutine running the test function.
+			session, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+			if !assert.NoError(t, err) {
+				return
+			}
+			defer session.Close()
+
+			mutex := concurrency.NewMutex(session, "/rocksdb/test/contention")
+
+			if err := mutex.Lock(ctx); !assert.NoError(t, err) {
+				return
+			}
+
+			acquired <- id
+			t.Logf("RocksDB Client %d acquired lock", id)
+
+			time.Sleep(50 * time.Millisecond)
+
+			if err := mutex.Unlock(ctx); !assert.NoError(t, err) {
+				return
+			}
+
+			released <- id
+			t.Logf("RocksDB Client %d released lock", id)
+		}(i)
+	}
+
+	wg.Wait()
+	close(acquired)
+	close(released)
+
+	acquiredClients := make(map[int]bool)
+	for id := range acquired {
+		acquiredClients[id] = true
+	}
+	assert.Len(t, acquiredClients, numClients)
+
+	releasedClients := make(map[int]bool)
+	for id := range released {
+		releasedClients[id] = true
+	}
+	assert.Len(t, releasedClients, numClients)
+}
+
+// TestRocksDB_MutexFIFOOrder starts clients one after another and expects the
+// lock to be acquired in the same order in which the clients were started.
+func TestRocksDB_MutexFIFOOrder(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	const numClients = 5
+	var orderMu sync.Mutex
+	acquireOrder := make([]int, 0, numClients)
+
+	// Important: two-phase signalling.
+	// 1. startSignals: tells a goroutine to start creating its session.
+	// 2. sessionReady: the session attempt finished, the next client may start.
+	startSignals := make([]chan struct{}, numClients)
+	sessionReady := make([]chan struct{}, numClients)
+	for i := range startSignals {
+		startSignals[i] = make(chan struct{})
+		sessionReady[i] = make(chan struct{})
+	}
+
+	var wg sync.WaitGroup
+	for i := 0; i < numClients; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+
+			<-startSignals[id]
+
+			session, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+			// Notify the coordinating goroutine in every case; otherwise a
+			// failed session creation would block it forever.
+			close(sessionReady[id])
+			// assert + return instead of require: FailNow must only be
+			// called from the goroutine running the test function.
+			if !assert.NoError(t, err) {
+				return
+			}
+			defer session.Close()
+
+			mutex := concurrency.NewMutex(session, "/rocksdb/test/fifo")
+
+			if err := mutex.Lock(ctx); !assert.NoError(t, err) {
+				return
+			}
+
+			// Capture the position while holding orderMu so the slice is
+			// never read concurrently with another goroutine's append.
+			orderMu.Lock()
+			acquireOrder = append(acquireOrder, id)
+			position := len(acquireOrder)
+			orderMu.Unlock()
+
+			t.Logf("RocksDB Client %d acquired lock at position %d", id, position)
+
+			time.Sleep(20 * time.Millisecond)
+
+			assert.NoError(t, mutex.Unlock(ctx))
+		}(i)
+	}
+
+	// Start the clients in order and wait for each session to be created
+	// before starting the next one.
+	for i := 0; i < numClients; i++ {
+		close(startSignals[i])
+		// Wait for the session (at this point its lease has already gone
+		// through Raft consensus).
+		<-sessionReady[i]
+		// Small delay to keep the ordering stable.
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	wg.Wait()
+
+	t.Logf("RocksDB Acquire order: %v", acquireOrder)
+	assert.Len(t, acquireOrder, numClients)
+
+	expectedOrder := make([]int, numClients)
+	for i := range expectedOrder {
+		expectedOrder[i] = i
+	}
+	assert.Equal(t, expectedOrder, acquireOrder, "Lock acquisition should follow FIFO order")
+}
+
+// TestRocksDB_MutexCriticalSection increments a shared counter inside the
+// locked region from several clients and counts a violation whenever another
+// increment was observed between the load and the add.
+func TestRocksDB_MutexCriticalSection(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	const numClients = 10
+	const iterations = 5
+	var counter int64
+	var violations int64
+
+	var wg sync.WaitGroup
+	for i := 0; i < numClients; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+
+			// assert + return instead of require: FailNow must only be
+			// called from the goroutine running the test function.
+			session, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+			if !assert.NoError(t, err) {
+				return
+			}
+			defer session.Close()
+
+			mutex := concurrency.NewMutex(session, "/rocksdb/test/critical-section")
+
+			for j := 0; j < iterations; j++ {
+				if err := mutex.Lock(ctx); !assert.NoError(t, err) {
+					return
+				}
+
+				oldVal := atomic.LoadInt64(&counter)
+				time.Sleep(time.Millisecond)
+				newVal := atomic.AddInt64(&counter, 1)
+
+				if newVal != oldVal+1 {
+					atomic.AddInt64(&violations, 1)
+				}
+
+				if err := mutex.Unlock(ctx); !assert.NoError(t, err) {
+					return
+				}
+			}
+		}()
+	}
+
+	wg.Wait()
+
+	assert.Equal(t, int64(numClients*iterations), atomic.LoadInt64(&counter))
+	assert.Equal(t, int64(0), atomic.LoadInt64(&violations), "No race conditions should occur")
+}
+
+// ============================================================================
+// RocksDB Lock Timeout and Cancellation Tests
+// ============================================================================
+
+// TestRocksDB_MutexLockWithTimeout checks that Lock on a held mutex fails
+// once its 500ms context deadline expires (accepted window: 400ms to 1s).
+func TestRocksDB_MutexLockWithTimeout(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+
+	session1, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+	require.NoError(t, err)
+	defer session1.Close()
+
+	mutex1 := concurrency.NewMutex(session1, "/rocksdb/test/timeout")
+	err = mutex1.Lock(context.Background())
+	require.NoError(t, err)
+
+	session2, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+	require.NoError(t, err)
+	defer session2.Close()
+
+	mutex2 := concurrency.NewMutex(session2, "/rocksdb/test/timeout")
+
+	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	defer cancel()
+
+	start := time.Now()
+	err = mutex2.Lock(ctx)
+
+	elapsed := time.Since(start)
+	assert.Error(t, err)
+	assert.True(t, elapsed >= 400*time.Millisecond && elapsed < 1*time.Second,
+		"Lock should timeout around 500ms, got %v", elapsed)
+
+	assert.NoError(t, mutex1.Unlock(context.Background()))
+}
+
+// TestRocksDB_MutexLockCancellation checks that a blocked Lock call returns a
+// "context canceled" error within two seconds of its context being cancelled.
+func TestRocksDB_MutexLockCancellation(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+
+	session1, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+	require.NoError(t, err)
+	defer session1.Close()
+
+	mutex1 := concurrency.NewMutex(session1, "/rocksdb/test/cancel")
+	err = mutex1.Lock(context.Background())
+	require.NoError(t, err)
+
+	session2, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+	require.NoError(t, err)
+	defer session2.Close()
+
+	mutex2 := concurrency.NewMutex(session2, "/rocksdb/test/cancel")
+
+	ctx, cancel := context.WithCancel(context.Background())
+	// Releases the context on every exit path; a second cancel is a no-op.
+	defer cancel()
+
+	// Buffered so the goroutine can always deliver its result and exit.
+	done := make(chan error, 1)
+	go func() {
+		done <- mutex2.Lock(ctx)
+	}()
+
+	time.Sleep(200 * time.Millisecond)
+	cancel()
+
+	select {
+	case err := <-done:
+		assert.Error(t, err)
+		// Guarded so a nil error is reported as a failure instead of a panic.
+		if err != nil {
+			assert.Contains(t, err.Error(), "context canceled")
+		}
+	case <-time.After(2 * time.Second):
+		t.Fatal("Lock should be canceled")
+	}
+
+	assert.NoError(t, mutex1.Unlock(context.Background()))
+}
+
+// ============================================================================
+// RocksDB Session Failure Tests
+// ============================================================================
+
+// TestRocksDB_MutexReleaseOnSessionClose checks that a second session blocked
+// in Lock acquires the lock after the session holding it has been closed
+// without an explicit Unlock.
+func TestRocksDB_MutexReleaseOnSessionClose(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	// TTL shortened to 3 seconds for the session that is going to be closed.
+	session1, err := concurrency.NewSession(cli, concurrency.WithTTL(3))
+	require.NoError(t, err)
+	lease1ID := session1.Lease()
+	t.Logf("Session1 created with lease ID: %x", lease1ID)
+
+	mutex1 := concurrency.NewMutex(session1, "/rocksdb/test/session-close")
+	err = mutex1.Lock(ctx)
+	require.NoError(t, err)
+	lockKey1 := mutex1.Key()
+	t.Logf("Session1 acquired lock with key: %s", lockKey1)
+
+	// Verify lease is alive
+	ttl1, err := cli.TimeToLive(ctx, lease1ID)
+	require.NoError(t, err)
+	t.Logf("Session1 lease TTL before close: %d seconds", ttl1.TTL)
+
+	session2, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+	require.NoError(t, err)
+	defer session2.Close()
+	t.Logf("Session2 created with lease ID: %x", session2.Lease())
+
+	mutex2 := concurrency.NewMutex(session2, "/rocksdb/test/session-close")
+
+	// The waiter gets its own cancellable context so it is released when the
+	// test returns. It reports through a buffered channel and never touches
+	// t, so it cannot log after the test has completed.
+	lockCtx, cancelLock := context.WithCancel(ctx)
+	defer cancelLock()
+
+	lockResult := make(chan error, 1)
+	t.Log("Session2 trying to acquire lock...")
+	go func() {
+		lockResult <- mutex2.Lock(lockCtx)
+	}()
+
+	time.Sleep(100 * time.Millisecond)
+	t.Log("Closing Session1...")
+	err = session1.Close()
+	t.Logf("Session1 closed, error: %v", err)
+
+	// Check if lease was revoked
+	ttl2, err := cli.TimeToLive(ctx, lease1ID)
+	require.NoError(t, err)
+	t.Logf("Session1 lease TTL after close: %d seconds (should be -1)", ttl2.TTL)
+
+	// Check if lock key still exists
+	getResp, err := cli.Get(ctx, lockKey1)
+	require.NoError(t, err)
+	t.Logf("Session1 lock key after close: exists=%v, count=%d", len(getResp.Kvs) > 0, len(getResp.Kvs))
+
+	select {
+	case lockErr := <-lockResult:
+		require.NoError(t, lockErr, "Session2 failed to acquire lock")
+		t.Log("Session2 acquired lock!")
+		t.Log("RocksDB Second session acquired lock after first session closed")
+		assert.True(t, mutex2.IsOwner())
+	case <-time.After(10 * time.Second):
+		// Timeout raised to 10 seconds to leave enough time for Raft
+		// consensus and lease revocation.
+		// Check state before failing; the responses are only used when the
+		// corresponding call succeeded.
+		if ttl3, ttlErr := cli.TimeToLive(ctx, lease1ID); ttlErr == nil {
+			t.Logf("Final Session1 lease TTL: %d", ttl3.TTL)
+		} else {
+			t.Logf("Final Session1 lease TTL lookup failed: %v", ttlErr)
+		}
+		if getResp2, getErr := cli.Get(ctx, lockKey1); getErr == nil {
+			t.Logf("Final Session1 lock key: exists=%v", len(getResp2.Kvs) > 0)
+		} else {
+			t.Logf("Final Session1 lock key lookup failed: %v", getErr)
+		}
+
+		t.Fatal("Second session should acquire lock after first session closes")
+	}
+
+	assert.NoError(t, mutex2.Unlock(ctx))
+}
+
+// ============================================================================
+// RocksDB Election Tests
+// ============================================================================
+
+// TestRocksDB_ElectionCampaign checks the leader state of a single candidate
+// before Campaign, after Campaign and after Resign.
+func TestRocksDB_ElectionCampaign(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session.Close()
+
+	election := concurrency.NewElection(session, "/rocksdb/test/election")
+
+	assert.False(t, election.IsLeader())
+
+	err = election.Campaign(ctx, "leader-value")
+	require.NoError(t, err)
+
+	assert.True(t, election.IsLeader())
+	assert.NotEmpty(t, election.Key())
+	assert.Greater(t, election.Rev(), int64(0))
+
+	_, val, err := election.Leader(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, "leader-value", val)
+
+	err = election.Resign(ctx)
+	require.NoError(t, err)
+
+	assert.False(t, election.IsLeader())
+}
+
+// TestRocksDB_ElectionMultipleCandidates runs several candidates against the
+// same election and checks that each one became leader once.
+func TestRocksDB_ElectionMultipleCandidates(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	const numCandidates = 3
+	var wg sync.WaitGroup
+	leaderChan := make(chan int, numCandidates)
+
+	for i := 0; i < numCandidates; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+
+			// assert + return instead of require: FailNow must only be
+			// called from the goroutine running the test function.
+			session, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+			if !assert.NoError(t, err) {
+				return
+			}
+			defer session.Close()
+
+			election := concurrency.NewElection(session, "/rocksdb/test/multi-election")
+
+			value := fmt.Sprintf("candidate-%d", id)
+			if err := election.Campaign(ctx, value); !assert.NoError(t, err) {
+				return
+			}
+
+			leaderChan <- id
+			t.Logf("RocksDB Candidate %d became leader", id)
+
+			time.Sleep(100 * time.Millisecond)
+
+			assert.NoError(t, election.Resign(ctx))
+			t.Logf("RocksDB Candidate %d resigned", id)
+		}(i)
+	}
+
+	wg.Wait()
+	close(leaderChan)
+
+	leaders := make(map[int]bool)
+	for id := range leaderChan {
+		leaders[id] = true
+	}
+	assert.Len(t, leaders, numCandidates)
+}
+
+// ============================================================================
+// RocksDB Stress Tests
+// ============================================================================
+
+// TestRocksDB_MutexHighConcurrency runs numGoroutines clients that each lock
+// and unlock the same mutex several times with a 5s per-attempt timeout, and
+// expects every attempt to succeed.
+func TestRocksDB_MutexHighConcurrency(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	const numGoroutines = 20
+	const iterations = 10
+	var wg sync.WaitGroup
+	var successCount int64
+	var failCount int64
+
+	for i := 0; i < numGoroutines; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+
+			session, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+			if err != nil {
+				// Without a session none of this client's attempts can run.
+				atomic.AddInt64(&failCount, int64(iterations))
+				return
+			}
+			defer session.Close()
+
+			mutex := concurrency.NewMutex(session, "/rocksdb/test/high-concurrency")
+
+			for j := 0; j < iterations; j++ {
+				lockCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+				err := mutex.Lock(lockCtx)
+				cancel()
+
+				if err != nil {
+					atomic.AddInt64(&failCount, 1)
+					continue
+				}
+
+				atomic.AddInt64(&successCount, 1)
+
+				time.Sleep(5 * time.Millisecond)
+
+				assert.NoError(t, mutex.Unlock(ctx))
+			}
+		}()
+	}
+
+	wg.Wait()
+
+	t.Logf("RocksDB Success: %d, Fail: %d", successCount, failCount)
+	assert.Equal(t, int64(numGoroutines*iterations), successCount)
+	assert.Equal(t, int64(0), failCount)
+}
+
+// TestRocksDB_MutexRapidLockUnlock locks and unlocks one mutex 100 times in a
+// row and checks the ownership flag after every step.
+func TestRocksDB_MutexRapidLockUnlock(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+	require.NoError(t, err)
+	defer session.Close()
+
+	mutex := concurrency.NewMutex(session, "/rocksdb/test/rapid")
+
+	const iterations = 100
+	for i := 0; i < iterations; i++ {
+		err := mutex.Lock(ctx)
+		require.NoError(t, err, "Lock failed at iteration %d", i)
+
+		assert.True(t, mutex.IsOwner())
+
+		err = mutex.Unlock(ctx)
+		require.NoError(t, err, "Unlock failed at iteration %d", i)
+
+		assert.False(t, mutex.IsOwner())
+	}
+}
+
+// ============================================================================
+// RocksDB Compatibility Tests
+// ============================================================================
+
+// TestRocksDB_CompatibilityWithEtcdConcurrency checks that the upstream etcd
+// concurrency session and mutex can lock and unlock against this server.
+func TestRocksDB_CompatibilityWithEtcdConcurrency(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	etcdSession, err := etcdconcurrency.NewSession(cli, etcdconcurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer etcdSession.Close()
+
+	etcdMutex := etcdconcurrency.NewMutex(etcdSession, "/rocksdb/test/etcd-compat")
+
+	err = etcdMutex.Lock(ctx)
+	require.NoError(t, err)
+
+	assert.NotEmpty(t, etcdMutex.Key())
+
+	err = etcdMutex.Unlock(ctx)
+	require.NoError(t, err)
+}
+
+// TestRocksDB_MixedLockUsage checks that the project mutex and the upstream
+// etcd mutex exclude each other when they use the same prefix.
+func TestRocksDB_MixedLockUsage(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	customSession, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer customSession.Close()
+
+	customMutex := concurrency.NewMutex(customSession, "/rocksdb/test/mixed")
+
+	etcdSession, err := etcdconcurrency.NewSession(cli, etcdconcurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer etcdSession.Close()
+
+	etcdMutex := etcdconcurrency.NewMutex(etcdSession, "/rocksdb/test/mixed")
+
+	err = customMutex.Lock(ctx)
+	require.NoError(t, err)
+
+	tryCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
+	err = etcdMutex.Lock(tryCtx)
+	cancel()
+	assert.Error(t, err, "etcd mutex should not be able to acquire lock")
+
+	assert.NoError(t, customMutex.Unlock(ctx))
+
+	err = etcdMutex.Lock(ctx)
+	require.NoError(t, err)
+
+	assert.NoError(t, etcdMutex.Unlock(ctx))
+}
+
+// ============================================================================
+// RocksDB Key Format Verification
+// ============================================================================
+
+// TestRocksDB_MutexKeyFormat checks that the lock key contains the prefix
+// followed by "/" and the hexadecimal lease ID of the session.
+func TestRocksDB_MutexKeyFormat(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session.Close()
+
+	prefix := "/rocksdb/test/key-format"
+	mutex := concurrency.NewMutex(session, prefix)
+
+	err = mutex.Lock(ctx)
+	require.NoError(t, err)
+
+	key := mutex.Key()
+	t.Logf("RocksDB Lock key: %s", key)
+
+	assert.Contains(t, key, prefix+"/")
+	assert.Contains(t, key, fmt.Sprintf("%x", session.Lease()))
+
+	assert.NoError(t, mutex.Unlock(ctx))
+}
+
+// ============================================================================
+// RocksDB Recovery Tests
+// ============================================================================
+
+// TestRocksDB_MutexRecoveryAfterSessionClose checks that a new session can
+// take the lock within 3 seconds after the previous holder's session was
+// closed without unlocking.
+func TestRocksDB_MutexRecoveryAfterSessionClose(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	prefix := "/rocksdb/test/recovery"
+
+	session1, err := concurrency.NewSession(cli, concurrency.WithTTL(5))
+	require.NoError(t, err)
+
+	mutex1 := concurrency.NewMutex(session1, prefix)
+	err = mutex1.Lock(ctx)
+	require.NoError(t, err)
+	t.Log("RocksDB Session 1 acquired lock")
+
+	require.NoError(t, session1.Close())
+	t.Log("RocksDB Session 1 closed")
+
+	session2, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session2.Close()
+
+	mutex2 := concurrency.NewMutex(session2, prefix)
+
+	lockCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
+	err = mutex2.Lock(lockCtx)
+	cancel()
+
+	require.NoError(t, err, "Session 2 should acquire lock after session 1 closes")
+	assert.True(t, mutex2.IsOwner())
+	t.Log("RocksDB Session 2 acquired lock")
+
+	assert.NoError(t, mutex2.Unlock(ctx))
+}
+
+// ============================================================================
+// RocksDB Edge Case Tests
+// ============================================================================
+
+// TestRocksDB_MutexDifferentPrefixes checks that one session can hold two
+// mutexes with different prefixes at the same time.
+func TestRocksDB_MutexDifferentPrefixes(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(30))
+	require.NoError(t, err)
+	defer session.Close()
+
+	mutex1 := concurrency.NewMutex(session, "/rocksdb/test/prefix1")
+	mutex2 := concurrency.NewMutex(session, "/rocksdb/test/prefix2")
+
+	err = mutex1.Lock(ctx)
+	require.NoError(t, err)
+
+	err = mutex2.Lock(ctx)
+	require.NoError(t, err)
+
+	assert.True(t, mutex1.IsOwner())
+	assert.True(t, mutex2.IsOwner())
+
+	assert.NoError(t, mutex1.Unlock(ctx))
+	assert.NoError(t, mutex2.Unlock(ctx))
+}
+
+// TestRocksDB_MutexWaitingQueue starts waiters 30ms apart and expects them to
+// acquire the lock in the order in which they were started.
+func TestRocksDB_MutexWaitingQueue(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+	ctx := context.Background()
+
+	const numWaiters = 5
+	var orderMu sync.Mutex
+	order := make([]int, 0, numWaiters)
+
+	ready := make([]chan struct{}, numWaiters)
+	for i := range ready {
+		ready[i] = make(chan struct{})
+	}
+
+	var wg sync.WaitGroup
+
+	for i := 0; i < numWaiters; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+
+			// assert + return instead of require: FailNow must only be
+			// called from the goroutine running the test function.
+			session, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+			if !assert.NoError(t, err) {
+				// Still release the launcher loop, which waits on ready[id].
+				close(ready[id])
+				return
+			}
+			defer session.Close()
+
+			mutex := concurrency.NewMutex(session, "/rocksdb/test/queue")
+
+			close(ready[id])
+
+			if err := mutex.Lock(ctx); !assert.NoError(t, err) {
+				return
+			}
+
+			orderMu.Lock()
+			order = append(order, id)
+			orderMu.Unlock()
+
+			time.Sleep(20 * time.Millisecond)
+			assert.NoError(t, mutex.Unlock(ctx))
+		}(i)
+
+		<-ready[i]
+		time.Sleep(30 * time.Millisecond)
+	}
+
+	wg.Wait()
+
+	t.Logf("RocksDB Acquisition order: %v", order)
+	assert.Len(t, order, numWaiters)
+
+	expected := make([]int, numWaiters)
+	for i := range expected {
+		expected[i] = i
+	}
+	assert.Equal(t, expected, order)
+}
+
+// ============================================================================
+// RocksDB Data Race Detection Tests
+// ============================================================================
+
+// TestRocksDB_MutexNoDataRace calls the read-only accessors IsOwner, Key and
+// Header concurrently on one mutex. It makes no assertions itself; presumably
+// it is meant to be run under the race detector (go test -race).
+func TestRocksDB_MutexNoDataRace(t *testing.T) {
+	cli, cleanup := startRocksDBLockTestServer(t)
+	defer cleanup()
+
+	session, err := concurrency.NewSession(cli, concurrency.WithTTL(60))
+	require.NoError(t, err)
+	defer session.Close()
+
+	mutex := concurrency.NewMutex(session, "/rocksdb/test/race")
+
+	var wg sync.WaitGroup
+
+	for i := 0; i < 10; i++ {
+		wg.Add(3)
+
+		go func() {
+			defer wg.Done()
+			_ = mutex.IsOwner()
+		}()
+
+		go func() {
+			defer wg.Done()
+			_ = mutex.Key()
+		}()
+
+		go func() {
+			defer wg.Done()
+			_ = mutex.Header()
+		}()
+	}
+
+	wg.Wait()
+}