Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 15 additions & 23 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,20 @@ func (e *SSEEvent) String() string {
}
}

func (r *Client) sendError(err error, errChan chan error) {
select {
case errChan <- err:
default:
}
}

func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (<-chan SSEEvent, <-chan error) {
sseChan := make(chan SSEEvent, 64)
errChan := make(chan error, 64)

id, err := ParseIdentifier(identifier)
if err != nil {
errChan <- err
r.sendError(err, errChan)
return sseChan, errChan
}

Expand All @@ -115,7 +122,7 @@ func (r *Client) Stream(ctx context.Context, identifier string, input Prediction
}

if err != nil {
errChan <- err
r.sendError(err, errChan)
return sseChan, errChan
}

Expand All @@ -136,16 +143,13 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (
func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) {
url := prediction.URLs["stream"]
if url == "" {
errChan <- errors.New("streaming not supported or not enabled for this prediction")
r.sendError(errors.New("streaming not supported or not enabled for this prediction"), errChan)
return
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
select {
case errChan <- fmt.Errorf("failed to create request: %w", err):
default:
}
r.sendError(fmt.Errorf("failed to create request: %w", err), errChan)
return
}
req.Header.Set("Accept", "text/event-stream")
Expand All @@ -163,18 +167,12 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
} else {
defer resp.Body.Close()
}
select {
case errChan <- fmt.Errorf("failed to send request: %w", err):
default:
}
r.sendError(fmt.Errorf("failed to send request: %w", err), errChan)
return
}

if resp.StatusCode != http.StatusOK {
select {
case errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode):
default:
}
r.sendError(fmt.Errorf("received invalid status code: %d", resp.StatusCode), errChan)
return
}

Expand Down Expand Up @@ -229,10 +227,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l

event, err := decodeSSEEvent(b)
if err != nil {
select {
case errChan <- err:
default:
}
r.sendError(err, errChan)
continue
}

Expand Down Expand Up @@ -269,10 +264,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
}

if !errors.Is(err, context.Canceled) {
select {
case errChan <- err:
default:
}
r.sendError(err, errChan)
}
}

Expand Down