Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 44 additions & 21 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ type SSEEvent struct {
Data string
}

func (e *SSEEvent) decode(b []byte) error {
// decodeSSEEvent parses the raw SSE event data and returns an SSEEvent pointer and an error.
func decodeSSEEvent(b []byte) (*SSEEvent, error) {
chunks := [][]byte{}
e := &SSEEvent{Type: SSETypeDefault}

for _, line := range bytes.Split(b, []byte("\n")) {
// Parse field and value from line
parts := bytes.SplitN(line, []byte{':'}, 2)
Expand All @@ -56,7 +59,7 @@ func (e *SSEEvent) decode(b []byte) error {
if len(parts) == 2 {
value = parts[1]
// Trim leading space if present
value, _ = bytes.CutPrefix(value, []byte(" "))
value = bytes.TrimPrefix(value, []byte(" "))
}

switch field {
Expand All @@ -73,16 +76,21 @@ func (e *SSEEvent) decode(b []byte) error {

data := bytes.Join(chunks, []byte("\n"))
if !utf8.Valid(data) {
return ErrInvalidUTF8Data
return nil, ErrInvalidUTF8Data
}
e.Data = string(data)

return nil
// Return nil if event data is empty and event type is not "done"
if e.Data == "" && e.Type != SSETypeDone {
return nil, nil
}

return e, nil
}

func (e *SSEEvent) String() string {
switch e.Type {
case "output":
case SSETypeOutput:
return e.Data
default:
return ""
Expand Down Expand Up @@ -126,9 +134,6 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (
}

func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) {
g, ctx := errgroup.WithContext(ctx)
done := make(chan struct{})

url := prediction.URLs["stream"]
if url == "" {
errChan <- errors.New("streaming not supported or not enabled for this prediction")
Expand All @@ -137,7 +142,10 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
errChan <- fmt.Errorf("failed to create request: %w", err)
select {
case errChan <- fmt.Errorf("failed to create request: %w", err):
default:
}
return
}
req.Header.Set("Accept", "text/event-stream")
Expand All @@ -150,22 +158,33 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l

resp, err := r.c.Do(req)
if err != nil || resp == nil {
if resp != nil {
resp.Body.Close()
if resp == nil {
err = errors.New("received nil response")
} else {
defer resp.Body.Close()
}
select {
case errChan <- fmt.Errorf("failed to send request: %w", err):
default:
}
errChan <- fmt.Errorf("failed to send request: %w", err)
return
}

if resp.StatusCode != http.StatusOK {
errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode)
select {
case errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode):
default:
}
return
}

reader := bufio.NewReader(resp.Body)
var buf bytes.Buffer
lineChan := make(chan []byte)

g, ctx := errgroup.WithContext(ctx)
done := make(chan struct{})

g.Go(func() error {
defer close(lineChan)
defer resp.Body.Close()
Expand Down Expand Up @@ -208,18 +227,22 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
b := buf.Bytes()
buf.Reset()

event := SSEEvent{Type: SSETypeDefault}
if err := event.decode(b); err != nil {
event, err := decodeSSEEvent(b)
if err != nil {
select {
case errChan <- err:
default:
}
close(done)
return
continue
}

if event == nil {
// Skip empty events
continue
}

select {
case sseChan <- event:
case sseChan <- *event:
case <-done:
return
case <-ctx.Done():
Expand All @@ -238,9 +261,6 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
go func() {
err := g.Wait()

defer close(sseChan)
defer close(errChan)

if err != nil {
if errors.Is(err, io.EOF) {
// Attempt to reconnect if the connection was closed before the stream was done
Expand All @@ -255,5 +275,8 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
}
}
}

close(sseChan)
close(errChan)
}()
}