diff --git a/stream.go b/stream.go index 0639fd5..ef19170 100644 --- a/stream.go +++ b/stream.go @@ -127,6 +127,7 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) ( func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) { g, ctx := errgroup.WithContext(ctx) + done := make(chan struct{}) url := prediction.URLs["stream"] if url == "" { @@ -161,8 +162,6 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l return } - done := make(chan struct{}) - reader := bufio.NewReader(resp.Body) var buf bytes.Buffer lineChan := make(chan []byte) @@ -211,12 +210,22 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l event := SSEEvent{Type: SSETypeDefault} if err := event.decode(b); err != nil { - errChan <- err + select { + case errChan <- err: + default: + } close(done) return } - sseChan <- event + select { + case sseChan <- event: + case <-done: + return + case <-ctx.Done(): + return + } + if event.Type == SSETypeDone { close(done) return @@ -239,15 +248,11 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l return } - if errors.Is(err, context.Canceled) { - // Context was canceled, simply return - return - } - - select { - case errChan <- err: - default: - // errChan is full or closed + if !errors.Is(err, context.Canceled) { + select { + case errChan <- err: + default: + } } } }()