diff --git a/stream.go b/stream.go index fbe94d5..04ce8c5 100644 --- a/stream.go +++ b/stream.go @@ -97,13 +97,20 @@ func (e *SSEEvent) String() string { } } +func (r *Client) sendError(err error, errChan chan error) { + select { + case errChan <- err: + default: + } +} + func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (<-chan SSEEvent, <-chan error) { sseChan := make(chan SSEEvent, 64) errChan := make(chan error, 64) id, err := ParseIdentifier(identifier) if err != nil { - errChan <- err + r.sendError(err, errChan) return sseChan, errChan } @@ -115,7 +122,7 @@ func (r *Client) Stream(ctx context.Context, identifier string, input Prediction } if err != nil { - errChan <- err + r.sendError(err, errChan) return sseChan, errChan } @@ -136,16 +143,13 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) ( func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) { url := prediction.URLs["stream"] if url == "" { - errChan <- errors.New("streaming not supported or not enabled for this prediction") + r.sendError(errors.New("streaming not supported or not enabled for this prediction"), errChan) return } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - select { - case errChan <- fmt.Errorf("failed to create request: %w", err): - default: - } + r.sendError(fmt.Errorf("failed to create request: %w", err), errChan) return } req.Header.Set("Accept", "text/event-stream") @@ -163,18 +167,12 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } else { defer resp.Body.Close() } - select { - case errChan <- fmt.Errorf("failed to send request: %w", err): - default: - } + r.sendError(fmt.Errorf("failed to send request: %w", err), errChan) return } if resp.StatusCode != http.StatusOK { - select { - case errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode): - default: - } + r.sendError(fmt.Errorf("received invalid status code: %d", resp.StatusCode), errChan) return } @@ -229,10 +227,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l event, err := decodeSSEEvent(b) if err != nil { - select { - case errChan <- err: - default: - } + r.sendError(err, errChan) continue } @@ -269,10 +264,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } if !errors.Is(err, context.Canceled) { - select { - case errChan <- err: - default: - } + r.sendError(err, errChan) } }