diff --git a/stream.go b/stream.go index ef19170..fbe94d5 100644 --- a/stream.go +++ b/stream.go @@ -41,8 +41,11 @@ type SSEEvent struct { Data string } -func (e *SSEEvent) decode(b []byte) error { +// decodeSSEEvent parses the raw SSE event data and returns an SSEEvent pointer and an error. +func decodeSSEEvent(b []byte) (*SSEEvent, error) { chunks := [][]byte{} + e := &SSEEvent{Type: SSETypeDefault} + for _, line := range bytes.Split(b, []byte("\n")) { // Parse field and value from line parts := bytes.SplitN(line, []byte{':'}, 2) @@ -56,7 +59,7 @@ func (e *SSEEvent) decode(b []byte) error { if len(parts) == 2 { value = parts[1] // Trim leading space if present - value, _ = bytes.CutPrefix(value, []byte(" ")) + value = bytes.TrimPrefix(value, []byte(" ")) } switch field { @@ -73,16 +76,21 @@ func (e *SSEEvent) decode(b []byte) error { data := bytes.Join(chunks, []byte("\n")) if !utf8.Valid(data) { - return ErrInvalidUTF8Data + return nil, ErrInvalidUTF8Data } e.Data = string(data) - return nil + // Return nil if event data is empty and event type is not "done" + if e.Data == "" && e.Type != SSETypeDone { + return nil, nil + } + + return e, nil } func (e *SSEEvent) String() string { switch e.Type { - case "output": + case SSETypeOutput: return e.Data default: return "" @@ -126,9 +134,6 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) ( } func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) { - g, ctx := errgroup.WithContext(ctx) - done := make(chan struct{}) - url := prediction.URLs["stream"] if url == "" { errChan <- errors.New("streaming not supported or not enabled for this prediction") @@ -137,7 +142,10 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - errChan <- fmt.Errorf("failed to create request: %w", err) + select { + case errChan <- fmt.Errorf("failed to create request: %w", err): + default: + } return } req.Header.Set("Accept", "text/event-stream") @@ -150,15 +158,23 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l resp, err := r.c.Do(req) if err != nil || resp == nil { - if resp != nil { - resp.Body.Close() + if resp == nil { + err = errors.New("received nil response") + } else { + defer resp.Body.Close() + } + select { + case errChan <- fmt.Errorf("failed to send request: %w", err): + default: } - errChan <- fmt.Errorf("failed to send request: %w", err) return } if resp.StatusCode != http.StatusOK { - errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode) + select { + case errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode): + default: + } return } @@ -166,6 +182,9 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l var buf bytes.Buffer lineChan := make(chan []byte) + g, ctx := errgroup.WithContext(ctx) + done := make(chan struct{}) + g.Go(func() error { defer close(lineChan) defer resp.Body.Close() @@ -208,18 +227,22 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l b := buf.Bytes() buf.Reset() - event := SSEEvent{Type: SSETypeDefault} - if err := event.decode(b); err != nil { + event, err := decodeSSEEvent(b) + if err != nil { select { case errChan <- err: default: } - close(done) - return + continue + } + + if event == nil { + // Skip empty events + continue } select { - case sseChan <- event: + case sseChan <- *event: case <-done: return case <-ctx.Done(): @@ -238,9 +261,6 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l go func() { err := g.Wait() - defer close(sseChan) - defer close(errChan) - if err != nil { if errors.Is(err, io.EOF) { // Attempt to reconnect if the connection was closed before the stream was done @@ -255,5 +275,8 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l } } } + + close(sseChan) + close(errChan) }() }