diff --git a/stream.go b/stream.go index fd13415..9883c23 100644 --- a/stream.go +++ b/stream.go @@ -166,20 +166,24 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l g.Go(func() error { defer close(lineChan) + defer resp.Body.Close() for { select { case <-ctx.Done(): - return nil + return ctx.Err() case <-done: return nil default: line, err := reader.ReadBytes('\n') if err != nil { - defer resp.Body.Close() return err } - lineChan <- line + select { + case lineChan <- line: + case <-ctx.Done(): + return ctx.Err() + } } } }) @@ -218,30 +222,27 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l }() go func() { + err := g.Wait() + defer close(sseChan) defer close(errChan) - for { - select { - case <-ctx.Done(): - return - case <-done: + if err != nil { + if errors.Is(err, io.EOF) { + // Attempt to reconnect if the connection was closed before the stream was done + r.streamPrediction(ctx, prediction, lastEvent, sseChan, errChan) return - default: - err := g.Wait() - if err != nil { - if err == io.EOF { - // Attempt to reconnect if the connection was closed before the stream was done - r.streamPrediction(ctx, prediction, lastEvent, sseChan, errChan) - continue - } + } - if errors.Is(err, context.Canceled) { - return - } + if errors.Is(err, context.Canceled) { + // Context was canceled, simply return + return + } - errChan <- err - } + select { + case errChan <- err: + default: + // errChan is full or closed } } }()