diff --git a/connection.go b/connection.go index e183eff..086c490 100644 --- a/connection.go +++ b/connection.go @@ -50,6 +50,10 @@ type Connection struct { pingIdLock sync.Mutex pingId uint32 pingChans map[uint32]chan error + + shutdownLock sync.Mutex + shutdownChan chan error + hasShutdown bool } // NewConnection creates a new spdy connection from an existing @@ -91,6 +95,8 @@ func NewConnection(conn net.Conn, server bool) (*Connection, error) { pingId: pid, pingChans: make(map[uint32]chan error), + + shutdownChan: make(chan error), } return session, nil @@ -451,12 +457,7 @@ func (s *Connection) handleGoAwayFrame(frame *spdy.GoAwayFrame) error { } // Do not block frame handler waiting for closure - go func() { - err := s.waitClose(s.goAwayTimeout) - if err != nil { - fmt.Errorf("close error: %s", err) - } - }() + go s.shutdown(s.goAwayTimeout) return nil } @@ -511,7 +512,16 @@ func (s *Connection) CreateStream(headers http.Header, parent *Stream, fin bool) return stream, s.sendStream(stream, fin) } -func (s *Connection) waitClose(closeTimeout time.Duration) (err error) { +func (s *Connection) shutdown(closeTimeout time.Duration) { + // TODO Ensure this isn't called multiple times + s.shutdownLock.Lock() + if s.hasShutdown { + s.shutdownLock.Unlock() + return + } + s.hasShutdown = true + s.shutdownLock.Unlock() + var timeout <-chan time.Time if closeTimeout > time.Duration(0) { timeout = time.After(closeTimeout) @@ -528,6 +538,7 @@ func (s *Connection) waitClose(closeTimeout time.Duration) (err error) { close(streamsClosed) }() + var err error select { case <-streamsClosed: // No active streams, close should be safe @@ -539,11 +550,25 @@ func (s *Connection) waitClose(closeTimeout time.Duration) (err error) { <-streamsClosed } + if err != nil { + duration := 10 * time.Minute + time.AfterFunc(duration, func() { + select { + case err, ok := <-s.shutdownChan: + if ok { + fmt.Errorf("Unhandled close error after %s: %s", duration, err) + } + default: + } + }) + s.shutdownChan <- err + } + close(s.shutdownChan) + return } -// Closes spdy connection by sending GoAway frame and waiting for -// streams to finish. +// Closes spdy connection by sending GoAway frame and initiating shutdown func (s *Connection) Close() error { s.receiveIdLock.Lock() if s.goneAway { @@ -570,7 +595,48 @@ func (s *Connection) Close() error { return err } - return s.waitClose(s.closeTimeout) + go s.shutdown(s.closeTimeout) + + return nil +} + +// CloseWait closes the connection and waits for shutdown +// to finish. Note the underlying network Connection +// is not closed until the end of shutdown. +func (s *Connection) CloseWait() error { + closeErr := s.Close() + if closeErr != nil { + return closeErr + } + shutdownErr, ok := <-s.shutdownChan + if ok { + return shutdownErr + } + return nil +} + +// Wait waits for the connection to finish shutdown or for +// the wait timeout duration to expire. This needs to be +// called either after Close has been called or the GOAWAYFRAME +// has been received. If the wait timeout is 0, this function +// will block until shutdown finishes. If wait is never called +// and a shutdown error occurs, that error will be logged as an +// unhandled error. +func (s *Connection) Wait(waitTimeout time.Duration) error { + var timeout <-chan time.Time + if waitTimeout > time.Duration(0) { + timeout = time.After(waitTimeout) + } + + select { + case err, ok := <-s.shutdownChan: + if ok { + return err + } + case <-timeout: + return ErrTimeout + } + return nil } // NotifyClose registers a channel to be called when the remote