From 8f8a4e0942d9cac471a5058860cbc431461deffb Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Sep 2024 06:35:05 -0700 Subject: [PATCH 1/6] Refactor SSE creation to stop propagating empty events --- stream.go | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/stream.go b/stream.go index ef19170..230b140 100644 --- a/stream.go +++ b/stream.go @@ -41,8 +41,11 @@ type SSEEvent struct { Data string } -func (e *SSEEvent) decode(b []byte) error { +// decodeSSEEvent parses the raw SSE event data and returns an SSEEvent pointer and an error. +func decodeSSEEvent(b []byte) (*SSEEvent, error) { chunks := [][]byte{} + e := &SSEEvent{Type: SSETypeDefault} + for _, line := range bytes.Split(b, []byte("\n")) { // Parse field and value from line parts := bytes.SplitN(line, []byte{':'}, 2) @@ -56,7 +59,7 @@ func (e *SSEEvent) decode(b []byte) error { if len(parts) == 2 { value = parts[1] // Trim leading space if present - value, _ = bytes.CutPrefix(value, []byte(" ")) + value = bytes.TrimPrefix(value, []byte(" ")) } switch field { @@ -73,11 +76,16 @@ func (e *SSEEvent) decode(b []byte) error { data := bytes.Join(chunks, []byte("\n")) if !utf8.Valid(data) { - return ErrInvalidUTF8Data + return nil, ErrInvalidUTF8Data } e.Data = string(data) - return nil + // Return nil if event data is empty and event type is not "done" + if e.Data == "" && e.Type != SSETypeDone { + return nil, nil + } + + return e, nil } func (e *SSEEvent) String() string { @@ -208,18 +216,22 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l b := buf.Bytes() buf.Reset() - event := SSEEvent{Type: SSETypeDefault} - if err := event.decode(b); err != nil { + event, err := decodeSSEEvent(b) + if err != nil { select { case errChan <- err: default: } - close(done) - return + continue + } + + if event == nil { + // Skip empty events + continue } select { - case sseChan <- event: + case sseChan <- *event: case <-done: return case <-ctx.Done(): From 981436b23ca1d1a9a70ec077621fda26c63c5f43 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Sep 2024 06:59:13 -0700 Subject: [PATCH 2/6] This works --- stream.go | 45 ++++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/stream.go b/stream.go index 230b140..086f2a8 100644 --- a/stream.go +++ b/stream.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "sync" "unicode/utf8" "golang.org/x/sync/errgroup" @@ -90,7 +91,13 @@ func decodeSSEEvent(b []byte) (*SSEEvent, error) { func (e *SSEEvent) String() string { switch e.Type { - case "output": + case SSETypeDone: + return "" + case SSETypeError: + return e.Data + case SSETypeLogs: + return e.Data + case SSETypeOutput: return e.Data default: return "" @@ -134,18 +141,18 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) ( } func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) { - g, ctx := errgroup.WithContext(ctx) - done := make(chan struct{}) - url := prediction.URLs["stream"] if url == "" { errChan <- errors.New("streaming not supported or not enabled for this prediction") return } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, url, nil) if err != nil { - errChan <- fmt.Errorf("failed to create request: %w", err) + select { + case errChan <- fmt.Errorf("failed to create request: %w", err): + default: + } return } req.Header.Set("Accept", "text/event-stream") @@ -157,16 +164,22 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } resp, err := r.c.Do(req) - if err != nil || resp == nil { + if err != nil { if resp != nil { - resp.Body.Close() + defer resp.Body.Close() + } + select { + case errChan <- fmt.Errorf("failed to send request: %w", err): + default: } - errChan <- fmt.Errorf("failed to send request: %w", err) return } if resp.StatusCode != http.StatusOK { - errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode) + select { + case errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode): + default: + } return } @@ -174,6 +187,10 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l var buf bytes.Buffer lineChan := make(chan []byte) + g, ctx := errgroup.WithContext(ctx) + done := make(chan struct{}) + closeOnce := sync.Once{} + g.Go(func() error { defer close(lineChan) defer resp.Body.Close() @@ -250,9 +267,6 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l go func() { err := g.Wait() - defer close(sseChan) - defer close(errChan) - if err != nil { if errors.Is(err, io.EOF) { // Attempt to reconnect if the connection was closed before the stream was done @@ -267,5 +281,10 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } } } + + closeOnce.Do(func() { + close(sseChan) + close(errChan) + }) }() } From 9b52ee2ab3da932c31537c0fd3e5e402e2ca58d3 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Sep 2024 07:54:23 -0700 Subject: [PATCH 3/6] Remove unnecessary once for closing channels --- stream.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/stream.go b/stream.go index 086f2a8..e7f02be 100644 --- a/stream.go +++ b/stream.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net/http" - "sync" "unicode/utf8" "golang.org/x/sync/errgroup" @@ -189,7 +188,6 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l g, ctx := errgroup.WithContext(ctx) done := make(chan struct{}) - closeOnce := sync.Once{} g.Go(func() error { defer close(lineChan) @@ -282,9 +280,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } } - closeOnce.Do(func() { - close(sseChan) - close(errChan) - }) + close(sseChan) + close(errChan) }() } From 290d2e4d788e2b5dee838e8781a9b888a03a4d7f Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Sep 2024 07:55:54 -0700 Subject: [PATCH 4/6] Simplify String() method --- stream.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/stream.go b/stream.go index e7f02be..89029da 100644 --- a/stream.go +++ b/stream.go @@ -90,12 +90,6 @@ func decodeSSEEvent(b []byte) (*SSEEvent, error) { func (e *SSEEvent) String() string { switch e.Type { - case SSETypeDone: - return "" - case SSETypeError: - return e.Data - case SSETypeLogs: - return e.Data case SSETypeOutput: return e.Data default: From 74a41810ee42732a96771fa46b1fa8eab3947394 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Sep 2024 07:57:11 -0700 Subject: [PATCH 5/6] Use existing context instead of TODO --- stream.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stream.go b/stream.go index 89029da..401cd27 100644 --- a/stream.go +++ b/stream.go @@ -140,7 +140,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l return } - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { select { case errChan <- fmt.Errorf("failed to create request: %w", err): From 91173cbcedf5ff6532e8ffd8d7bac8e2ca340e37 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Sep 2024 08:03:07 -0700 Subject: [PATCH 6/6] Fix linting error --- stream.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/stream.go b/stream.go index 401cd27..fbe94d5 100644 --- a/stream.go +++ b/stream.go @@ -157,8 +157,10 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } resp, err := r.c.Do(req) - if err != nil { - if resp != nil { + if err != nil || resp == nil { + if resp == nil { + err = errors.New("received nil response") + } else { defer resp.Body.Close() } select {