From 0286f454d911182629fbdb9abb431a98ecb8c905 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Sep 2024 05:14:08 -0700 Subject: [PATCH 1/2] Non-blocking send to errChan --- stream.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/stream.go b/stream.go index fbe94d5..d44f383 100644 --- a/stream.go +++ b/stream.go @@ -136,7 +136,10 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) ( func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) { url := prediction.URLs["stream"] if url == "" { - errChan <- errors.New("streaming not supported or not enabled for this prediction") + select { + case errChan <- errors.New("streaming not supported or not enabled for this prediction"): + default: + } return } From 48642e46f324127d82f89af9b0632099dfa840d9 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Sep 2024 05:26:46 -0700 Subject: [PATCH 2/2] Create sendError helper method --- stream.go | 41 +++++++++++++++-------------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/stream.go b/stream.go index d44f383..04ce8c5 100644 --- a/stream.go +++ b/stream.go @@ -97,13 +97,20 @@ func (e *SSEEvent) String() string { } } +func (r *Client) sendError(err error, errChan chan error) { + select { + case errChan <- err: + default: + } +} + func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (<-chan SSEEvent, <-chan error) { sseChan := make(chan SSEEvent, 64) errChan := make(chan error, 64) id, err := ParseIdentifier(identifier) if err != nil { - errChan <- err + r.sendError(err, errChan) return sseChan, errChan } @@ -115,7 +122,7 @@ func (r *Client) Stream(ctx context.Context, identifier string, input Prediction } if err != nil { - errChan <- err + r.sendError(err, errChan) return sseChan, errChan } @@ -136,19 +143,13 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) ( func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) { url := prediction.URLs["stream"] if url == "" { - select { - case errChan <- errors.New("streaming not supported or not enabled for this prediction"): - default: - } + r.sendError(errors.New("streaming not supported or not enabled for this prediction"), errChan) return } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - select { - case errChan <- fmt.Errorf("failed to create request: %w", err): - default: - } + r.sendError(fmt.Errorf("failed to create request: %w", err), errChan) return } req.Header.Set("Accept", "text/event-stream") @@ -166,18 +167,12 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } else { defer resp.Body.Close() } - select { - case errChan <- fmt.Errorf("failed to send request: %w", err): - default: - } + r.sendError(fmt.Errorf("failed to send request: %w", err), errChan) return } if resp.StatusCode != http.StatusOK { - select { - case errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode): - default: - } + r.sendError(fmt.Errorf("received invalid status code: %d", resp.StatusCode), errChan) return } @@ -232,10 +227,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l event, err := decodeSSEEvent(b) if err != nil { - select { - case errChan <- err: - default: - } + r.sendError(err, errChan) continue } @@ -272,10 +264,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } if !errors.Is(err, context.Canceled) { - select { - case errChan <- err: - default: - } + r.sendError(err, errChan) } }