From 5159997ba61868277eafe39b7dc2b26bc0ea27dc Mon Sep 17 00:00:00 2001 From: Gjermund Garaba Date: Mon, 6 Apr 2026 12:56:22 +0200 Subject: [PATCH] fix: harden dealer connection lifecycle Prevent panics and races during dealer close/reconnect by: - Initializing stop channels in the constructor instead of on each connect - Adding a permanent closed state that rejects Connect/writes/receivers after Close - Using non-blocking signal sends and clearing stale signals before receiving - Extracting safe read/write/close helpers that hold connMu properly - Resetting pong deadline on reconnect to avoid false timeouts --- dealer/dealer.go | 152 +++++++++++++++++++++++++++++++++--------- dealer/dealer_test.go | 94 ++++++++++++++++++++++++++ dealer/recv.go | 31 ++++++--- 3 files changed, 233 insertions(+), 44 deletions(-) create mode 100644 dealer/dealer_test.go diff --git a/dealer/dealer.go b/dealer/dealer.go index c10cad48..3c61687b 100644 --- a/dealer/dealer.go +++ b/dealer/dealer.go @@ -3,6 +3,7 @@ package dealer import ( "context" "encoding/json" + "errors" "fmt" "math" "net/http" @@ -19,6 +20,8 @@ const ( timeout = 10 * time.Second ) +var ErrDealerClosed = errors.New("dealer closed") + type Dealer struct { log librespot.Logger @@ -29,6 +32,7 @@ type Dealer struct { conn *websocket.Conn + closed bool stop bool pingTickerStop chan struct{} recvLoopStop chan struct{} @@ -58,6 +62,8 @@ func NewDealer(log librespot.Logger, client *http.Client, dealerAddr librespot.G log: log, addr: dealerAddr, accessToken: accessToken, + pingTickerStop: make(chan struct{}, 1), + recvLoopStop: make(chan struct{}, 1), requestReceivers: map[string]requestReceiver{}, } } @@ -66,6 +72,10 @@ func (d *Dealer) Connect(ctx context.Context) error { d.connMu.Lock() defer d.connMu.Unlock() + if d.closed { + return ErrDealerClosed + } + if d.conn != nil && !d.stop { d.log.Debugf("dealer connection already opened") return nil @@ -75,8 +85,6 @@ func (d *Dealer) Connect(ctx context.Context) error { } func (d *Dealer) connect(ctx context.Context) error { - d.recvLoopStop = make(chan struct{}, 1) - d.pingTickerStop = make(chan struct{}, 1) d.stop = false accessToken, err := d.accessToken(ctx, false) @@ -106,27 +114,25 @@ func (d *Dealer) connect(ctx context.Context) error { func (d *Dealer) Close() { d.connMu.Lock() - defer d.connMu.Unlock() - + d.closed = true d.stop = true + conn := d.conn + d.connMu.Unlock() - if d.conn == nil { - return - } + d.signalStop() - d.recvLoopStop <- struct{}{} - d.pingTickerStop <- struct{}{} - _ = d.conn.Close(websocket.StatusGoingAway, "") + if conn != nil { + _ = conn.Close(websocket.StatusGoingAway, "") + } } func (d *Dealer) startReceiving() { d.recvLoopOnce.Do(func() { + d.clearStopSignals() d.log.Tracef("starting dealer recv loop") - go d.recvLoop() - - // set last pong in the future - d.lastPong = time.Now().Add(pingInterval) + d.resetPongDeadline() go d.pingTicker() + go d.recvLoop() }) } @@ -139,27 +145,23 @@ loop: case <-d.pingTickerStop: break loop case <-ticker.C: - d.lastPongLock.Lock() - timePassed := time.Since(d.lastPong) - d.lastPongLock.Unlock() + timePassed := d.timeSinceLastPong() if timePassed > pingInterval+timeout { d.log.Errorf("did not receive last pong from dealer, %.0fs passed", timePassed.Seconds()) // closing the connection should make the read on the "recvLoop" fail, // continue hoping for a new connection - _ = d.conn.Close(websocket.StatusServiceRestart, "") + d.closeConn(websocket.StatusServiceRestart) continue } ctx, cancel := context.WithTimeout(context.Background(), timeout) - d.connMu.RLock() - err := d.conn.Write(ctx, websocket.MessageText, []byte("{\"type\":\"ping\"}")) - d.connMu.RUnlock() + conn, err := d.writeConn(ctx, websocket.MessageText, []byte("{\"type\":\"ping\"}")) cancel() d.log.Tracef("sent dealer ping") if err != nil { - if d.stop { + if d.isStopped() { // break early without logging if we should stop break loop } @@ -168,7 +170,7 @@ loop: // closing the connection should make the read on the "recvLoop" fail, // continue hoping for a new connection - _ = d.conn.Close(websocket.StatusServiceRestart, "") + d.closeConnRef(conn, websocket.StatusServiceRestart) continue } } @@ -185,10 +187,10 @@ loop: break loop default: // no need to hold the connMu since reconnection happens in this routine - msgType, messageBytes, err := d.conn.Read(context.Background()) + msgType, messageBytes, err := d.readConn(context.Background()) // don't log closed error if we're stopping - if d.stop && websocket.CloseStatus(err) == websocket.StatusGoingAway { + if d.isStopped() && websocket.CloseStatus(err) == websocket.StatusGoingAway { d.log.Debugf("dealer connection closed") break loop } else if err != nil { @@ -229,10 +231,10 @@ loop: } // always close as we might end up here because of application errors - _ = d.conn.Close(websocket.StatusInternalError, "") + d.closeConn(websocket.StatusInternalError) // if we shouldn't stop, try to reconnect - if !d.stop { + if !d.isStopped() { d.connMu.Lock() if err := backoff.Retry(d.reconnect, backoff.NewExponentialBackOff()); err != nil { d.log.WithError(err).Errorf("failed reconnecting dealer") @@ -273,9 +275,7 @@ func (d *Dealer) sendReply(key string, success bool) error { } ctx, cancel := context.WithTimeout(context.Background(), timeout) - d.connMu.RLock() - err = d.conn.Write(ctx, websocket.MessageText, replyBytes) - d.connMu.RUnlock() + _, err = d.writeConn(ctx, websocket.MessageText, replyBytes) cancel() if err != nil { return fmt.Errorf("failed sending dealer reply: %w", err) @@ -289,12 +289,98 @@ func (d *Dealer) reconnect() error { return err } - d.lastPongLock.Lock() - d.lastPong = time.Now() - d.lastPongLock.Unlock() + d.resetPongDeadline() // restart the recv loop go d.recvLoop() d.log.Debugf("re-established dealer connection") return nil } + +func (d *Dealer) resetPongDeadline() { + d.lastPongLock.Lock() + d.lastPong = time.Now().Add(pingInterval) + d.lastPongLock.Unlock() +} + +func (d *Dealer) timeSinceLastPong() time.Duration { + d.lastPongLock.Lock() + defer d.lastPongLock.Unlock() + return time.Since(d.lastPong) +} + +func (d *Dealer) closeConn(status websocket.StatusCode) { + d.connMu.RLock() + conn := d.conn + d.connMu.RUnlock() + + d.closeConnRef(conn, status) +} + +func (d *Dealer) closeConnRef(conn *websocket.Conn, status websocket.StatusCode) { + if conn != nil { + _ = conn.Close(status, "") + } +} + +func (d *Dealer) writeConn(ctx context.Context, typ websocket.MessageType, payload []byte) (*websocket.Conn, error) { + d.connMu.RLock() + + if d.closed { + d.connMu.RUnlock() + return nil, ErrDealerClosed + } + + conn := d.conn + + if conn == nil { + d.connMu.RUnlock() + return nil, fmt.Errorf("dealer connection not established") + } + + err := conn.Write(ctx, typ, payload) + d.connMu.RUnlock() + return conn, err +} + +func (d *Dealer) readConn(ctx context.Context) (websocket.MessageType, []byte, error) { + d.connMu.RLock() + conn := d.conn + d.connMu.RUnlock() + + if conn == nil { + return 0, nil, fmt.Errorf("dealer connection not established") + } + + return conn.Read(ctx) +} + +func (d *Dealer) signalStop() { + select { + case d.recvLoopStop <- struct{}{}: + default: + } + + select { + case d.pingTickerStop <- struct{}{}: + default: + } +} + +func (d *Dealer) clearStopSignals() { + select { + case <-d.recvLoopStop: + default: + } + + select { + case <-d.pingTickerStop: + default: + } +} + +func (d *Dealer) isStopped() bool { + d.connMu.RLock() + defer d.connMu.RUnlock() + return d.stop +} diff --git a/dealer/dealer_test.go b/dealer/dealer_test.go new file mode 100644 index 00000000..5a240695 --- /dev/null +++ b/dealer/dealer_test.go @@ -0,0 +1,94 @@ +package dealer + +import ( + "context" + "errors" + "testing" + "testing/synctest" + "time" + + "github.com/coder/websocket" + librespot "github.com/devgianlu/go-librespot" +) + +func TestPingTickerDoesNotPanicWhenConnNil(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + d := &Dealer{ + log: &librespot.NullLogger{}, + pingTickerStop: make(chan struct{}, 1), + } + + panicCh := make(chan any, 1) + go func() { + defer func() { + panicCh <- recover() + }() + d.pingTicker() + }() + + time.Sleep(pingInterval + timeout + time.Nanosecond) + synctest.Wait() + + select { + case p := <-panicCh: + if p != nil { + t.Fatalf("pingTicker panicked when conn was nil: %v", p) + } + default: + } + + d.pingTickerStop <- struct{}{} + synctest.Wait() + + select { + case p := <-panicCh: + if p != nil { + t.Fatalf("pingTicker panicked when conn was nil: %v", p) + } + default: + t.Fatal("pingTicker did not stop") + } + }) +} + +func TestCloseStopsPingTickerWhenConnNil(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + d := &Dealer{ + log: &librespot.NullLogger{}, + pingTickerStop: make(chan struct{}, 1), + } + + done := make(chan struct{}) + go func() { + defer close(done) + d.pingTicker() + }() + + synctest.Wait() + d.Close() + synctest.Wait() + + stopped := false + select { + case <-done: + stopped = true + default: + } + + d.pingTickerStop <- struct{}{} + synctest.Wait() + + if !stopped { + t.Fatal("pingTicker did not stop when closing with nil conn") + } + }) +} + +func TestWriteConnRejectsClosedDealer(t *testing.T) { + d := &Dealer{closed: true} + + _, err := d.writeConn(context.Background(), websocket.MessageText, nil) + if !errors.Is(err, ErrDealerClosed) { + t.Fatalf("expected ErrDealerClosed, got %v", err) + } +} diff --git a/dealer/recv.go b/dealer/recv.go index 5199bca6..58bcb6e1 100644 --- a/dealer/recv.go +++ b/dealer/recv.go @@ -187,15 +187,20 @@ func (d *Dealer) ReceiveMessage(uriPrefixes ...string) <-chan Message { panic("uri prefixes list cannot be empty") } - d.messageReceiversLock.Lock() - defer d.messageReceiversLock.Unlock() + d.connMu.RLock() + if d.closed { + d.connMu.RUnlock() + c := make(chan Message) + close(c) + return c + } - // create new receiver + d.messageReceiversLock.Lock() c := make(chan Message) d.messageReceivers = append(d.messageReceivers, messageReceiver{uriPrefixes, c}) - - // start receiving if necessary d.startReceiving() + d.messageReceiversLock.Unlock() + d.connMu.RUnlock() return c } @@ -242,20 +247,24 @@ func (d *Dealer) handleRequest(rawMsg *RawMessage) { } func (d *Dealer) ReceiveRequest(uri string) <-chan Request { + d.connMu.RLock() + if d.closed { + d.connMu.RUnlock() + c := make(chan Request) + close(c) + return c + } + d.requestReceiversLock.Lock() defer d.requestReceiversLock.Unlock() + defer d.connMu.RUnlock() - // check that there isn't another receiver for this uri - _, ok := d.requestReceivers[uri] - if ok { + if _, ok := d.requestReceivers[uri]; ok { panic(fmt.Sprintf("cannot have more request receivers for %s", uri)) } - // create new receiver c := make(chan Request) d.requestReceivers[uri] = requestReceiver{c} - - // start receiving if necessary d.startReceiving() return c