diff --git a/go.mod b/go.mod index ae41604..a015840 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/gorilla/mux v1.7.4 github.com/kr/pretty v0.1.0 // indirect + github.com/oklog/ulid v1.3.1 github.com/rs/zerolog v1.19.0 github.com/stretchr/testify v1.6.1 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect diff --git a/go.sum b/go.sum index f5cf236..d3239f6 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/internal/lockclient/cache/lru_cache.go b/internal/lockclient/cache/lru_cache.go index c908ab5..8983a5b 100644 --- a/internal/lockclient/cache/lru_cache.go +++ b/internal/lockclient/cache/lru_cache.go @@ -213,9 +213,13 @@ func (lru *LRUCache) deleteElementFromMap(key interface{}) error { } return nil } + +// printMap prints the LRU map and is concurrency safe. func (lru *LRUCache) printMap() { + lru.mu.Lock() for k, v := range lru.m { fmt.Printf("Key: %v, Value: ", k) fmt.Printf("LN: %v, RN: %v, NodeKey: %v\n", v.Left(), v.Right(), v.Key()) } + lru.mu.Unlock() } diff --git a/internal/lockclient/client.go b/internal/lockclient/client.go index 82575f7..024e686 100644 --- a/internal/lockclient/client.go +++ b/internal/lockclient/client.go @@ -1,6 +1,9 @@ package lockclient -import "github.com/SystemBuilders/LocKey/internal/lockservice" +import ( + "github.com/SystemBuilders/LocKey/internal/lockclient/session" + "github.com/SystemBuilders/LocKey/internal/lockservice" +) // Client describes a client that can be used to interact with // the Lockey lockservice. The client can start the lockservice @@ -17,14 +20,18 @@ type Client interface { // to do so. Starting the service should be a non-blocking call // and return as soon as the server is started and setup. StartService(Config) error + // Connect allows the user process to establish a connection + // with the client. This returns an ID of the session that + // results from the connection. + Connect() session.Session // Acquire can be used to acquire a lock on Lockey. This // implementation interacts with the underlying server and // provides the service. - Acquire(lockservice.Descriptors) error + Acquire(lockservice.Object, session.Session) error // Release can be used to release a lock on Lockey. This // implementation interacts with the underlying server and // provides the service. - Release(lockservice.Descriptors) error + Release(lockservice.Object, session.Session) error } // Config describes the configuration for the lockservice to run on. diff --git a/internal/lockclient/errors.go b/internal/lockclient/errors.go new file mode 100644 index 0000000..025a208 --- /dev/null +++ b/internal/lockclient/errors.go @@ -0,0 +1,13 @@ +package lockclient + +// Error provides constant error strings to the driver functions. +type Error string + +func (e Error) Error() string { return string(e) } + +// Constant errors. +// Rule of thumb, all errors start with a small letter and end with no full stop. +const ( + ErrSessionNonExistent = Error("the session related to this process doesn't exist") + ErrSessionExpired = Error("session expired") +) diff --git a/internal/lockclient/id/id.go b/internal/lockclient/id/id.go new file mode 100644 index 0000000..5f2e952 --- /dev/null +++ b/internal/lockclient/id/id.go @@ -0,0 +1,64 @@ +package id + +import ( + "fmt" + "log" + "math/rand" + "sync" + "time" + + "github.com/oklog/ulid" +) + +// ID describes a general identifier. An ID has to be unique application-wide. +// IDs must not be re-used. +type ID interface { + fmt.Stringer + Bytes() []byte +} + +var _ ID = (*id)(nil) + +type id ulid.ULID + +var ( + lock sync.Mutex + randSource = rand.New(rand.NewSource(time.Now().UnixNano())) + entropy = ulid.Monotonic(randSource, 0) +) + +// Create creates a globally unique ID. This function is safe for concurrent +// use. +func Create() ID { + lock.Lock() + defer lock.Unlock() + + genID, err := ulid.New(ulid.Now(), entropy) + if err != nil { + // For this to happen, the random module would have to fail. Since we + // use Go's pseudo RNG, which just jumps around a few numbers, instead + // of using crypto/rand, and we also made this function safe for + // concurrent use, this is nearly impossible to happen. However, with + // the current version of oklog/ulid v1.3.1, this will also break after + // 2121-04-11 11:53:25.01172576 UTC. + log.Fatal(fmt.Errorf("new ulid: %w", err)) + } + return id(genID) +} + +// Parse parses an ID from a byte slice. +func Parse(idBytes []byte) (ID, error) { + parsed, err := ulid.Parse(string(idBytes)) + if err != nil { + return nil, fmt.Errorf("parse: %w", err) + } + return id(parsed), nil +} + +func (id id) String() string { + return ulid.ULID(id).String() +} + +func (id id) Bytes() []byte { + return []byte(id.String()) +} diff --git a/internal/lockclient/session/session.go b/internal/lockclient/session/session.go new file mode 100644 index 0000000..40ea12f --- /dev/null +++ b/internal/lockclient/session/session.go @@ -0,0 +1,18 @@ +package session + +import "github.com/SystemBuilders/LocKey/internal/lockclient/id" + +// Session captures all necessary parameters necessary to +// describe a session with the lockservice in the lockclient. +type Session interface { + // SessionID is the unique ID that represents this session. + // This will be used in every transaction for validating the user. + SessionID() id.ID + // ClientID is the ID of the client that will be created when + // the client is created. This acts as a second layer check along + // with the sessionID. + ClientID() id.ID + // ProcessID the unique ID assigned for the process by the client. + // This will be the third layer check in the security mechanism. + ProcessID() id.ID +} diff --git a/internal/lockclient/session/simple_session.go b/internal/lockclient/session/simple_session.go new file mode 100644 index 0000000..01ea2c7 --- /dev/null +++ b/internal/lockclient/session/simple_session.go @@ -0,0 +1,38 @@ +package session + +import ( + "github.com/SystemBuilders/LocKey/internal/lockclient/id" +) + +var _ Session = (*SimpleSession)(nil) + +// SimpleSession implements a session. +type SimpleSession struct { + sessionID id.ID + clientID id.ID + processID id.ID +} + +// SessionID returns the sessionID of the SimpleSession. +func (s *SimpleSession) SessionID() id.ID { + return s.sessionID +} + +// ClientID returns the clientID of the SimpleSession. +func (s *SimpleSession) ClientID() id.ID { + return s.clientID +} + +// ProcessID returns the processID of the SimpleSession +func (s *SimpleSession) ProcessID() id.ID { + return s.processID +} + +// NewSession returns a new instance of a session with the given parameters. +func NewSession(sessionID, clientID, processID id.ID) Session { + return &SimpleSession{ + sessionID: sessionID, + clientID: clientID, + processID: processID, + } +} diff --git a/internal/lockclient/simple_client.go b/internal/lockclient/simple_client.go index d48ef04..411e997 100644 --- a/internal/lockclient/simple_client.go +++ b/internal/lockclient/simple_client.go @@ -2,14 +2,20 @@ package lockclient import ( "bytes" + "context" "encoding/json" "errors" "io/ioutil" "net/http" "strings" "sync" + "time" + + "github.com/SystemBuilders/LocKey/internal/lockclient/id" + "github.com/rs/zerolog" "github.com/SystemBuilders/LocKey/internal/lockclient/cache" + "github.com/SystemBuilders/LocKey/internal/lockclient/session" "github.com/SystemBuilders/LocKey/internal/lockservice" ) @@ -20,114 +26,306 @@ type SimpleClient struct { config *lockservice.SimpleConfig cache *cache.LRUCache mu sync.Mutex + id id.ID + log zerolog.Logger + + // sessions holds the mapping of a process to a session. + sessions map[id.ID]session.Session + // sessionTimers maintains the timers for each session, + sessionTimers map[id.ID]chan struct{} + // sessionAcquisitions has a list of all the acquisitions + // from a particular process. This has no knowledge of + // whether the process owning the lock has an active session + // or not, this guarantee has to be ensured by the client. + sessionAcquisitions map[id.ID][]lockservice.Descriptors } // NewSimpleClient returns a new SimpleClient of the given parameters. // This client works with or without the existance of a cache. -func NewSimpleClient(config *lockservice.SimpleConfig, cache *cache.LRUCache) *SimpleClient { +func NewSimpleClient(config *lockservice.SimpleConfig, log zerolog.Logger, cache *cache.LRUCache) *SimpleClient { + clientID := id.Create() + sessions := make(map[id.ID]session.Session) + sessionTimers := make(map[id.ID]chan struct{}) + sessionAcquisitions := make(map[id.ID][]lockservice.Descriptors) return &SimpleClient{ - config: config, - cache: cache, + config: config, + cache: cache, + id: clientID, + log: log, + sessions: sessions, + sessionTimers: sessionTimers, + sessionAcquisitions: sessionAcquisitions, } } var _ Client = (*SimpleClient)(nil) +// Connect lets the user process to establish a connection with the +// client. +func (sc *SimpleClient) Connect() session.Session { + sessionID := id.Create() + processID := id.Create() + session := session.NewSession(sessionID, sc.id, processID) + sc.sessions[processID] = session + sc.startSession(processID) + sc.log. + Debug(). + Str(processID.String(), "connected"). + Msg("session created") + return session +} + // Acquire allows the user process to acquire a lock. -func (sc *SimpleClient) Acquire(d lockservice.Descriptors) error { - return sc.acquire(d) +// This returns a "session expired" error if the session expires when +// the lock is being acquired. +// +// All locks acquired during the session will be revoked if the session +// expires. +func (sc *SimpleClient) Acquire(d lockservice.Object, s session.Session) error { + sc.mu.Lock() + if _, ok := sc.sessions[s.ProcessID()]; !ok { + sc.mu.Unlock() + return ErrSessionNonExistent + } + sc.mu.Unlock() + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + close := make(chan struct{}, 1) + go func() { + for { + sc.mu.Lock() + select { + case <-sc.sessionTimers[s.ProcessID()]: + cancel() + sc.log. + Debug(). + Str(s.ProcessID().String(), "user process"). + Msg("session expired, starting graceful shutdown") + sc.mu.Unlock() + sc.gracefulSessionShutDown(s.ProcessID()) + return + case <-close: + sc.mu.Unlock() + return + default: + sc.mu.Unlock() + } + } + }() + ld := lockservice.NewLockDescriptor(d.ID(), s.ProcessID().String()) + err := sc.acquire(ctx, ld) + if err != nil { + return err + } + // Once the lock is guaranteed to be acquired, append it to the acquisitions list. + sc.mu.Lock() + sc.sessionAcquisitions[s.ProcessID()] = append(sc.sessionAcquisitions[s.ProcessID()], ld) + sc.mu.Unlock() + close <- struct{}{} + return nil } // acquire makes an HTTP call to the lockserver and acquires the lock. // This function makes the acquire call and doesn't care about the contention // on the lock service. -// The errors involved may be due the HTTP errors or the lockservice errors. +// The errors involved may be due the HTTP, cache or the lockservice errors. // -// Currently acquire doesn't order the user processes that request for the lock. -func (sc *SimpleClient) acquire(d lockservice.Descriptors) error { - // Check for existance of a cache and check - // if the element is in the cache. - if sc.cache != nil { - _, err := sc.getFromCache(lockservice.ObjectDescriptor{ObjectID: d.ID()}) - // Since there can be cache errors, we have this double check. - // We need to exit if a cache doesn't exist but proceed if the cache - // failed in persisting this element. - if err != nil && err != lockservice.ErrCheckAcquireFailure { - return err - } - } - - endPoint := sc.config.IP() + ":" + sc.config.Port() + "/acquire" - // Since the cache doesn't have the element, query the server. - testData := lockservice.LockRequest{FileID: d.ID(), UserID: d.Owner()} - requestJSON, err := json.Marshal(testData) +// This function doesn't care about sessions or ordering of the user processes and +// thus can be used for book-keeping purposes using a nil context. +// +// To avoid a race condition by returning errors from the goroutine and the +// acquire functionality, a channel is used to capture the errors. +func (sc *SimpleClient) acquire(ctx context.Context, d lockservice.Descriptors) error { - req, err := http.NewRequest("POST", endPoint, bytes.NewBuffer(requestJSON)) - if err != nil { - return err + errChan := make(chan error, 1) + if ctx != nil { + go func() { + for { + select { + case <-ctx.Done(): + errChan <- ErrSessionExpired + default: + } + } + }() } - req.Header.Set("Content-Type", "application/json") - client := &http.Client{} - resp, err := client.Do(req) + go func() { + // Check for existance of a cache and check + // if the element is in the cache. + if sc.cache != nil { + _, err := sc.getFromCache(lockservice.ObjectDescriptor{ObjectID: d.ID()}) + // Since there can be cache errors, we have this double check. + // We need to exit if a cache doesn't exist but proceed if the cache + // failed in persisting this element. + if err != nil && err != lockservice.ErrCheckAcquireFailure { + errChan <- err + return + } + } - if err != nil { - return err - } - defer resp.Body.Close() + endPoint := sc.config.IP() + ":" + sc.config.Port() + "/acquire" + // Since the cache doesn't have the element, query the server. + testData := lockservice.LockRequest{FileID: d.ID(), UserID: d.Owner()} + requestJSON, err := json.Marshal(testData) + if err != nil { + errChan <- err + return + } - body, _ := ioutil.ReadAll(resp.Body) - if resp.StatusCode != 200 { - return errors.New(strings.TrimSpace(string(body))) - } + req, err := http.NewRequest("POST", endPoint, bytes.NewBuffer(requestJSON)) + if err != nil { + errChan <- err + return + } + req.Header.Set("Content-Type", "application/json") - if sc.cache != nil { - err := sc.addToCache(d) + client := &http.Client{} + resp, err := client.Do(req) if err != nil { - return err + errChan <- err + return } - } - return nil + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + errChan <- err + return + } + if resp.StatusCode != 200 { + errChan <- errors.New(strings.TrimSpace(string(body))) + return + } + + if sc.cache != nil { + err := sc.addToCache(d) + if err != nil { + errChan <- err + return + } + } + errChan <- nil + }() + + return <-errChan } // Release makes an HTTP call to the lockserver and releases the lock. // The errors invloved may be due the HTTP errors or the lockservice errors. -func (sc *SimpleClient) Release(d lockservice.Descriptors) error { - endPoint := sc.config.IPAddr + ":" + sc.config.PortAddr + "/release" - data := lockservice.LockRequest{FileID: d.ID(), UserID: d.Owner()} - requestJSON, err := json.Marshal(data) - if err != nil { - return err +// +// Only if there is an active session by the user process, it can release the locks +// once verified that the locks belong to the user process. +func (sc *SimpleClient) Release(d lockservice.Object, s session.Session) error { + sc.mu.Lock() + if _, ok := sc.sessions[s.ProcessID()]; !ok { + sc.mu.Unlock() + return ErrSessionNonExistent } - - req, err := http.NewRequest("POST", endPoint, bytes.NewBuffer(requestJSON)) + sc.mu.Unlock() + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + close := make(chan struct{}, 1) + go func() { + for { + sc.mu.Lock() + select { + case <-sc.sessionTimers[s.ProcessID()]: + cancel() + sc.log. + Debug(). + Str(s.ProcessID().String(), "user process"). + Msg("session expired, starting graceful shutdown") + sc.mu.Unlock() + sc.gracefulSessionShutDown(s.ProcessID()) + return + case <-close: + sc.mu.Unlock() + return + default: + sc.mu.Unlock() + } + } + }() + ld := lockservice.NewLockDescriptor(d.ID(), s.ProcessID().String()) + err := sc.release(ctx, ld) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") + // Remove the descriptor that was released. + sc.removeFromSlice(s.ProcessID(), ld) + close <- struct{}{} + return nil +} - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return (err) - } - defer resp.Body.Close() +// release makes a HTTP call to the lock service and releases the lock. +// This function makes the release call and doesn't care about the contention +// on the lock service. +// The errors involved maybe the HTTP, cache or the lockservice errors. +// +// This function doesn't care about sessions or ordering of the user processes and +// thus can be used for book-keeping purposes using a nil context. +// TODO: Cache invalidation +func (sc *SimpleClient) release(ctx context.Context, d lockservice.Descriptors) (err error) { - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return err - } - if resp.StatusCode != 200 { - return lockservice.Error(strings.TrimSpace(string(body))) + errChan := make(chan error, 1) + if ctx != nil { + go func() { + for { + select { + case <-ctx.Done(): + errChan <- ErrSessionExpired + return + default: + } + } + }() } - if sc.cache != nil { - err = sc.releaseFromCache(d) + go func() { + endPoint := sc.config.IPAddr + ":" + sc.config.PortAddr + "/release" + data := lockservice.LockRequest{FileID: d.ID(), UserID: d.Owner()} + requestJSON, err := json.Marshal(data) if err != nil { - return err + errChan <- err + return } - } - return nil + + req, err := http.NewRequest("POST", endPoint, bytes.NewBuffer(requestJSON)) + if err != nil { + errChan <- err + return + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + errChan <- err + return + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + errChan <- err + return + } + if resp.StatusCode != 200 { + errChan <- lockservice.Error(strings.TrimSpace(string(body))) + return + } + + if sc.cache != nil { + if err != nil { + errChan <- err + return + } + } + errChan <- nil + }() + + return <-errChan } // StartService starts the lockservice LocKey. @@ -219,3 +417,60 @@ func (sc *SimpleClient) releaseFromCache(d lockservice.Descriptors) error { } return cache.ErrCacheDoesntExist } + +// startSession starts the session by initiating the timer for this user process. +// This is a non blocking function which runs on a different goroutine. It sends +// a signal through the "sessionTimers" map for the respective "processID" when +// the session timer ends. +// +// The function starts with creating a new channel, assigning it to the respective +// object in the map and then ends by closing the channel created. +func (sc *SimpleClient) startSession(processID id.ID) { + go func(id.ID) { + sc.log.Debug(). + Str(processID.String(), "user process"). + Msg("session timer started") + timerChan := make(chan struct{}, 1) + sc.mu.Lock() + sc.sessionTimers[processID] = timerChan + sc.mu.Unlock() + // Sessions last for 200ms. + time.Sleep(200 * time.Millisecond) + + sc.mu.Lock() + sc.sessionTimers[processID] <- struct{}{} + close(sc.sessionTimers[processID]) + delete(sc.sessionTimers, processID) + sc.mu.Unlock() + + sc.log.Debug(). + Str(processID.String(), "session timed out"). + Msg("disconnected") + sc.gracefulSessionShutDown(processID) + }(processID) +} + +// gracefulSessionShutdown releases all the locks in the lockservice once the +// session has ended. +func (sc *SimpleClient) gracefulSessionShutDown(processID id.ID) { + sc.mu.Lock() + var sessionAcquisitons = sc.sessionAcquisitions[processID] + sc.mu.Unlock() + for i := range sessionAcquisitons { + sc.release(nil, sessionAcquisitons[i]) + } + sc.mu.Lock() + delete(sc.sessions, processID) + delete(sc.sessionAcquisitions, processID) + sc.mu.Unlock() +} + +func (sc *SimpleClient) removeFromSlice(processID id.ID, d lockservice.Descriptors) { + sc.mu.Lock() + for i := range sc.sessionAcquisitions[processID] { + if sc.sessionAcquisitions[processID][i] == d { + sc.sessionAcquisitions[processID] = append(sc.sessionAcquisitions[processID][:i], sc.sessionAcquisitions[processID][i+1:]...) + } + } + sc.mu.Unlock() +} diff --git a/internal/lockclient/simple_client_test.go b/internal/lockclient/simple_client_test.go index adeed53..ef5e36d 100644 --- a/internal/lockclient/simple_client_test.go +++ b/internal/lockclient/simple_client_test.go @@ -31,40 +31,43 @@ func TestLockService(t *testing.T) { } }() - // Server takes some time to start + // Server takes some time to start. time.Sleep(100 * time.Millisecond) + + // Flow of creating a client and acquiring a lock: + // 1. Create a cache for the client. + // 2. Create a client and plug in the created cache. + // 3. Connect to the said client and hold on to the session value. + // 4. Use the session as a key for all further transactions. t.Run("acquire test release test", func(t *testing.T) { size := 5 cache := cache.NewLRUCache(size) - sc := NewSimpleClient(scfg, cache) + sc := NewSimpleClient(scfg, log, cache) + + session := sc.Connect() - d := lockservice.NewLockDescriptor("test", "owner") + d := lockservice.NewObjectDescriptor("test") - got := sc.Acquire(d) + got := sc.Acquire(d, session) var want error if got != want { t.Errorf("acquire: got %q want %q", got, want) } - d = lockservice.NewLockDescriptor("test1", "owner") - - got = sc.Acquire(d) + d = lockservice.NewObjectDescriptor("test1") + got = sc.Acquire(d, session) if got != want { t.Errorf("acquire: got %q want %q", got, want) } - d = lockservice.NewLockDescriptor("test", "owner") - - got = sc.Release(d) - + d = lockservice.NewObjectDescriptor("test") + got = sc.Release(d, session) if got != want { t.Errorf("release: got %q want %q", got, want) } - d = lockservice.NewLockDescriptor("test1", "owner") - - got = sc.Release(d) - + d = lockservice.NewObjectDescriptor("test1") + got = sc.Release(d, session) if got != want { t.Errorf("release: got %q want %q", got, want) } @@ -73,25 +76,25 @@ func TestLockService(t *testing.T) { t.Run("acquire test, acquire test, release test", func(t *testing.T) { size := 5 cache := cache.NewLRUCache(size) - sc := NewSimpleClient(scfg, cache) + sc := NewSimpleClient(scfg, log, cache) - d := lockservice.NewLockDescriptor("test", "owner") + session := sc.Connect() + d := lockservice.NewObjectDescriptor("test") - got := sc.Acquire(d) + got := sc.Acquire(d, session) var want error if got != want { t.Errorf("acquire: got %q want %q", got, want) } - got = sc.Acquire(d) + session2 := sc.Connect() + got = sc.Acquire(d, session2) want = lockservice.ErrFileacquired if got.Error() != want.Error() { t.Errorf("acquire: got %q want %q", got, want) } - d = lockservice.NewLockDescriptor("test", "owner") - - got = sc.Release(d) + got = sc.Release(d, session) want = nil if got != want { t.Errorf("release: got %q want %q", got, want) @@ -101,37 +104,65 @@ func TestLockService(t *testing.T) { t.Run("acquire test, trying to release test as another entity should fail", func(t *testing.T) { size := 2 cache := cache.NewLRUCache(size) - sc := NewSimpleClient(scfg, cache) + sc := NewSimpleClient(scfg, log, cache) - d := lockservice.NewLockDescriptor("test", "owner1") - got := sc.Acquire(d) + session := sc.Connect() + d := lockservice.NewObjectDescriptor("test") + got := sc.Acquire(d, session) var want error if got != want { t.Errorf("acquire: got %q want %q", got, want) } - d = lockservice.NewLockDescriptor("test", "owner2") - got = sc.Release(d) + session2 := sc.Connect() + got = sc.Release(d, session2) want = lockservice.ErrUnauthorizedAccess if got != want { t.Errorf("acquire: got %v want %v", got, want) } - d = lockservice.NewLockDescriptor("test2", "owner1") - got = sc.Acquire(d) + d = lockservice.NewObjectDescriptor("test2") + got = sc.Acquire(d, session) want = nil if got != want { t.Errorf("acquire: got %q want %q", got, want) } - d = lockservice.NewLockDescriptor("test", "owner1") - got = sc.Release(d) + d = lockservice.NewObjectDescriptor("test") + got = sc.Release(d, session) + want = nil + if got != want { + t.Errorf("release: got %q want %q", got, want) + } + + d = lockservice.NewObjectDescriptor("test2") + got = sc.Release(d, session) want = nil if got != want { t.Errorf("release: got %q want %q", got, want) } }) + t.Run("acquire test and release after session expiry", func(t *testing.T) { + sc := NewSimpleClient(scfg, log, nil) + session := sc.Connect() + d := lockservice.NewObjectDescriptor("test3") + + got := sc.Acquire(d, session) + var want error + if got != want { + t.Errorf("acquire: got %q want %q", got, want) + } + + // Wait for the session to expire + time.Sleep(500 * time.Millisecond) + got = sc.Release(d, session) + want = ErrSessionNonExistent + if got != want { + t.Errorf("release: got %q want %q", got, want) + } + }) + quit <- true return } @@ -157,16 +188,17 @@ func BenchmarkLocKeyWithoutCache(b *testing.B) { }() time.Sleep(100 * time.Millisecond) - sc := NewSimpleClient(scfg, nil) - d := lockservice.NewLockDescriptor("test", "owner") + sc := NewSimpleClient(scfg, log, nil) + session := sc.Connect() + d := lockservice.NewObjectDescriptor("test") for n := 0; n < b.N; n++ { - got := sc.acquire(d) + got := sc.Acquire(d, session) var want error if got != want { b.Errorf("acquire: got %q want %q", got, want) } - got = sc.Release(d) + got = sc.Release(d, session) if got != want { b.Errorf("release: got %q want %q", got, want) } @@ -196,16 +228,17 @@ func BenchmarkLocKeyWithCache(b *testing.B) { size := 5 cache := cache.NewLRUCache(size) - sc := NewSimpleClient(scfg, cache) - d := lockservice.NewLockDescriptor("test", "owner") + sc := NewSimpleClient(scfg, log, cache) + session := sc.Connect() + d := lockservice.NewObjectDescriptor("test") for n := 0; n < b.N; n++ { - got := sc.acquire(d) + got := sc.Acquire(d, session) var want error if got != want { b.Errorf("acquire: got %q want %q", got, want) } - got = sc.Release(d) + got = sc.Release(d, session) if got != want { b.Errorf("release: got %q want %q", got, want) } diff --git a/internal/lockservice/lockservice.go b/internal/lockservice/lockservice.go index d21959f..76e823a 100644 --- a/internal/lockservice/lockservice.go +++ b/internal/lockservice/lockservice.go @@ -26,3 +26,8 @@ type Descriptors interface { ID() string Owner() string } + +// Object describes any object that can be used with the lockservice. +type Object interface { + ID() string +} diff --git a/internal/lockservice/simpleLockService.go b/internal/lockservice/simpleLockService.go index bb61b5d..dc2fa9f 100644 --- a/internal/lockservice/simpleLockService.go +++ b/internal/lockservice/simpleLockService.go @@ -20,13 +20,13 @@ type SimpleConfig struct { // LockRequest is an instance of a request for a lock. type LockRequest struct { - FileID string `json:"FileID"` - UserID string `json:"UserID"` + FileID string `json:"fileID"` + UserID string `json:"userID"` } // LockCheckRequest is an instance of a lock check request. type LockCheckRequest struct { - FileID string `json:"FileID"` + FileID string `json:"fileID"` } // CheckAcquireRes is the response of a Checkacquire. @@ -55,6 +55,7 @@ type SimpleLockService struct { } var _ Descriptors = (*LockDescriptor)(nil) +var _ Object = (*ObjectDescriptor)(nil) // ObjectDescriptor describes the object that is subjected to // lock operations. @@ -62,6 +63,11 @@ type ObjectDescriptor struct { ObjectID string } +// ID returns the ID related to the object. +func (od *ObjectDescriptor) ID() string { + return od.ObjectID +} + // LockDescriptor implements the Descriptors interface. // Many descriptors can be added to this struct and the ID // can be a combination of all those descriptors.