From 427afb51715a81aa20a9ef86bce8f01914f608e5 Mon Sep 17 00:00:00 2001 From: Gjermund Garaba Date: Fri, 10 Apr 2026 13:15:29 +0200 Subject: [PATCH 1/2] fix(audio): unblock chunked reader close during chunk prefetch --- audio/chunked-reader.go | 250 ++++++++++++++-------- audio/chunked-reader_integration_test.go | 119 ++++++++--- audio/chunked-reader_internal_test.go | 251 +++++++++++++++++++++++ 3 files changed, 502 insertions(+), 118 deletions(-) diff --git a/audio/chunked-reader.go b/audio/chunked-reader.go index dff60274..4f561cd3 100644 --- a/audio/chunked-reader.go +++ b/audio/chunked-reader.go @@ -1,6 +1,8 @@ package audio import ( + "context" + "errors" "fmt" "io" "net" @@ -45,10 +47,12 @@ func parseContentRange(resp *http.Response) (start int64, end int64, size int64, type chunkItem struct { *sync.Cond - data []byte fetching bool - err error +} + +func newChunkItem() *chunkItem { + return &chunkItem{Cond: sync.NewCond(&sync.Mutex{})} } type HttpChunkedReader struct { @@ -63,20 +67,35 @@ type HttpChunkedReader struct { len int64 pos int64 + prefetchMu sync.Mutex prefetchWg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc - initialLatency time.Duration - latencies []time.Duration + latMu sync.Mutex + latencies []time.Duration } func NewHttpChunkedReader(log librespot.Logger, client *http.Client, audioUrl string) (_ *HttpChunkedReader, err error) { - r := &HttpChunkedReader{log: log, client: client} + ctx, cancel := context.WithCancel(context.Background()) + r := &HttpChunkedReader{ + log: log, + client: client, + ctx: ctx, + cancel: cancel, + } r.url, err = url.Parse(audioUrl) if err != nil { return nil, fmt.Errorf("failed parsing resource url: %w", err) } + defer func() { + if err != nil { + r.cancel() + } + }() + // request the first chunk, needed for the complete content length resp, err := r.downloadChunk(0) if err != nil { @@ -92,19 +111,14 @@ func NewHttpChunkedReader(log librespot.Logger, client *http.Client, audioUrl st } // create the necessary amount of chunks - var totalChunks int64 - if r.len%DefaultChunkSize == 0 { - totalChunks = r.len / DefaultChunkSize - } else { - totalChunks = r.len/DefaultChunkSize + 1 - } + totalChunks := (r.len + DefaultChunkSize - 1) / DefaultChunkSize r.chunks = make([]*chunkItem, totalChunks) for i := int64(0); i < totalChunks; i++ { - r.chunks[i] = &chunkItem{Cond: sync.NewCond(&sync.Mutex{}), data: nil, err: nil} + r.chunks[i] = newChunkItem() } - r.chunks[0].data, err = io.ReadAll(r.measureLatency(true, resp.Body)) + r.chunks[0].data, err = io.ReadAll(r.measureLatency(resp.Body)) if err != nil { return nil, fmt.Errorf("failed reading first chunk: %w", err) } @@ -113,9 +127,26 @@ func NewHttpChunkedReader(log librespot.Logger, client *http.Client, audioUrl st return r, nil } +func (r *HttpChunkedReader) closeErr(err error) error { + if err != nil && r.isClosed() { + return net.ErrClosed + } + + return err +} + +func (r *HttpChunkedReader) isClosed() bool { + return r.ctx.Err() != nil +} + func (r *HttpChunkedReader) downloadChunk(idx int) (*http.Response, error) { + retryBackoff := backoff.WithContext( + backoff.WithMaxRetries(backoff.NewConstantBackOff(1*time.Second), 3), + r.ctx, + ) + return backoff.RetryWithData(func() (*http.Response, error) { - resp, err := r.client.Do(&http.Request{ + resp, err := r.client.Do((&http.Request{ Method: "GET", URL: r.url, Header: http.Header{ @@ -125,8 +156,13 @@ func (r *HttpChunkedReader) downloadChunk(idx int) (*http.Response, error) { min(max(r.len, DefaultChunkSize), 
int64((idx+1)*DefaultChunkSize))-1, )}, }, - }) + }).WithContext(r.ctx)) if err != nil { + err = r.closeErr(err) + if errors.Is(err, net.ErrClosed) { + return nil, backoff.Permanent(err) + } + return nil, err } @@ -136,67 +172,70 @@ func (r *HttpChunkedReader) downloadChunk(idx int) (*http.Response, error) { } return resp, nil - }, backoff.WithMaxRetries(backoff.NewConstantBackOff(1*time.Second), 3)) + }, retryBackoff) +} + +func (r *HttpChunkedReader) downloadAndRead(idx int) ([]byte, error) { + resp, err := r.downloadChunk(idx) + if err != nil { + return nil, fmt.Errorf("failed downloading chunk %d: %w", idx, r.closeErr(err)) + } + + defer func() { _ = resp.Body.Close() }() + + data, err := io.ReadAll(r.measureLatency(resp.Body)) + if err != nil { + return nil, fmt.Errorf("failed reading chunk %d: %w", idx, r.closeErr(err)) + } + + return data, nil } func (r *HttpChunkedReader) fetchChunk(idx int) ([]byte, error) { chunk := r.chunks[idx] - chunk.L.Lock() - // if the chunk is already being fetched, wait until it is done - for chunk.fetching { - for !(chunk.data != nil || chunk.err != nil) { - chunk.Wait() + chunk.L.Lock() + for { + if r.isClosed() { + chunk.L.Unlock() + return nil, net.ErrClosed } - } - - // chunk fetched, just return its data - if chunk.data != nil { - chunk.L.Unlock() - return chunk.data, nil - } - chunk.fetching = true - chunk.err = nil - chunk.L.Unlock() + if chunk.data != nil { + data := chunk.data + chunk.L.Unlock() + return data, nil + } - // download chunk - resp, err := r.downloadChunk(idx) - if err != nil { - // update chunk and signal not fetching - chunk.L.Lock() - chunk.err = err - chunk.fetching = false - chunk.Broadcast() - chunk.L.Unlock() + if !chunk.fetching { + chunk.fetching = true + chunk.L.Unlock() + break + } - return nil, fmt.Errorf("failed downloading chunk %d: %w", idx, chunk.err) + chunk.Wait() } - // ensure body gets closed - defer func() { _ = resp.Body.Close() }() - - // read the chunk data - data, err := io.ReadAll(r.measureLatency(false, resp.Body)) - if chunk.err != nil { - // update chunk and signal not fetching + data, err := r.downloadAndRead(idx) + if err != nil { chunk.L.Lock() - chunk.err = err chunk.fetching = false chunk.Broadcast() chunk.L.Unlock() - - return nil, fmt.Errorf("failed reading chunk %d: %w", idx, chunk.err) + return nil, err } - // update chunk and signal not fetching chunk.L.Lock() chunk.data = data chunk.fetching = false chunk.Broadcast() chunk.L.Unlock() - r.log.Debugf("fetched chunk %d/%d, size: %d", idx, len(r.chunks)-1, len(chunk.data)) + r.log.Debugf("fetched chunk %d/%d, size: %d", idx, len(r.chunks)-1, len(data)) + if r.isClosed() { + return nil, net.ErrClosed + } + return data, nil } @@ -206,12 +245,27 @@ func (r *HttpChunkedReader) prefetchChunks(curr int) { break } - r.prefetchWg.Add(1) - go func(i int) { - defer r.prefetchWg.Done() - _, _ = r.fetchChunk(i) - }(i) + if !r.startPrefetch(i) { + return + } + } +} + +func (r *HttpChunkedReader) startPrefetch(idx int) bool { + r.prefetchMu.Lock() + defer r.prefetchMu.Unlock() + + if r.isClosed() { + return false } + + r.prefetchWg.Add(1) + go func() { + defer r.prefetchWg.Done() + _, _ = r.fetchChunk(idx) + }() + + return true } func (r *HttpChunkedReader) Read(p []byte) (n int, err error) { @@ -221,6 +275,10 @@ func (r *HttpChunkedReader) Read(p []byte) (n int, err error) { } func (r *HttpChunkedReader) ReadAt(p []byte, pos int64) (n int, _ error) { + if r.isClosed() { + return 0, net.ErrClosed + } + chunkIdx, off := int(pos/DefaultChunkSize), 
int(pos%DefaultChunkSize) if chunkIdx >= len(r.chunks) { return 0, io.EOF @@ -289,19 +347,29 @@ func (r *HttpChunkedReader) Seek(offset int64, whence int) (int64, error) { } } -func (r *HttpChunkedReader) measureLatency(initial bool, rr io.Reader) io.Reader { +func (r *HttpChunkedReader) measureLatency(rr io.Reader) io.Reader { return &LatencyReader{ - Reader: rr, - Callback: func(latency time.Duration) { - if initial { - r.initialLatency = latency - } - - r.latencies = append(r.latencies, latency) - }, + Reader: rr, + Callback: r.recordLatency, } } +func (r *HttpChunkedReader) recordLatency(latency time.Duration) { + r.latMu.Lock() + defer r.latMu.Unlock() + + r.latencies = append(r.latencies, latency) +} + +func (r *HttpChunkedReader) latencySnapshot() []time.Duration { + r.latMu.Lock() + defer r.latMu.Unlock() + + latencies := make([]time.Duration, len(r.latencies)) + copy(latencies, r.latencies) + return latencies +} + func (r *HttpChunkedReader) Size() int64 { return r.len } @@ -311,16 +379,22 @@ func (r *HttpChunkedReader) Url() *url.URL { } func (r *HttpChunkedReader) InitialLatency() time.Duration { - return r.initialLatency + latencies := r.latencySnapshot() + if len(latencies) == 0 { + return 0 + } + + return latencies[0] } func (r *HttpChunkedReader) MaxLatency() time.Duration { - if len(r.latencies) == 0 { + latencies := r.latencySnapshot() + if len(latencies) == 0 { return 0 } - maxLatency := r.latencies[0] - for _, latency := range r.latencies { + maxLatency := latencies[0] + for _, latency := range latencies { if latency > maxLatency { maxLatency = latency } @@ -330,12 +404,13 @@ func (r *HttpChunkedReader) MaxLatency() time.Duration { } func (r *HttpChunkedReader) MinLatency() time.Duration { - if len(r.latencies) == 0 { + latencies := r.latencySnapshot() + if len(latencies) == 0 { return 0 } - minLatency := r.latencies[0] - for _, latency := range r.latencies { + minLatency := latencies[0] + for _, latency := range latencies { if latency < minLatency { minLatency = latency } @@ -345,26 +420,25 @@ func (r *HttpChunkedReader) MinLatency() time.Duration { } func (r *HttpChunkedReader) AvgLatencyMs() float64 { - if len(r.latencies) == 0 { + latencies := r.latencySnapshot() + if len(latencies) == 0 { return 0 } var sum time.Duration - for _, latency := range r.latencies { + for _, latency := range latencies { sum += latency } - return float64(sum.Milliseconds()) / float64(len(r.latencies)) + return float64(sum.Milliseconds()) / float64(len(latencies)) } func (r *HttpChunkedReader) MedianLatency() time.Duration { - if len(r.latencies) == 0 { + latencies := r.latencySnapshot() + if len(latencies) == 0 { return 0 } - latencies := make([]time.Duration, len(r.latencies)) - copy(latencies, r.latencies) - sort.Slice(latencies, func(i, j int) bool { return latencies[i] < latencies[j] }) @@ -378,21 +452,23 @@ func (r *HttpChunkedReader) MedianLatency() time.Duration { } func (r *HttpChunkedReader) TotalTime() time.Duration { + latencies := r.latencySnapshot() + var sum time.Duration - for _, latency := range r.latencies { + for _, latency := range latencies { sum += latency } return sum } func (r *HttpChunkedReader) Close() error { + r.prefetchMu.Lock() + r.cancel() + r.prefetchMu.Unlock() + for _, chunk := range r.chunks { chunk.L.Lock() - if chunk.fetching { - chunk.err = net.ErrClosed - chunk.fetching = false - chunk.Broadcast() - } + chunk.Broadcast() chunk.L.Unlock() } diff --git a/audio/chunked-reader_integration_test.go b/audio/chunked-reader_integration_test.go index 
92d6da41..0acaa1e9 100644 --- a/audio/chunked-reader_integration_test.go +++ b/audio/chunked-reader_integration_test.go @@ -5,9 +5,11 @@ package audio_test import ( "fmt" "io" + "net" "net/http" "net/http/httptest" "sync" + "sync/atomic" "testing" "time" @@ -25,6 +27,19 @@ type HttpChunkedReaderIntegrationSuite struct { server *httptest.Server } +func (suite *HttpChunkedReaderIntegrationSuite) newReader(server *httptest.Server) *audio.HttpChunkedReader { + suite.T().Helper() + + if server == nil { + server = suite.server + } + + reader, err := audio.NewHttpChunkedReader(suite.logger, server.Client(), server.URL) + suite.Require().NoError(err) + suite.T().Cleanup(func() { _ = reader.Close() }) + return reader +} + func (suite *HttpChunkedReaderIntegrationSuite) SetupTest() { suite.logger = &librespot.NullLogger{} @@ -67,9 +82,7 @@ func (suite *HttpChunkedReaderIntegrationSuite) handleHTTPRequest(w http.Respons } func (suite *HttpChunkedReaderIntegrationSuite) TestBasicReadOperations() { - reader, err := audio.NewHttpChunkedReader(suite.logger, suite.server.Client(), suite.server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(nil) // Test basic read buf := make([]byte, 1000) @@ -87,9 +100,7 @@ func (suite *HttpChunkedReaderIntegrationSuite) TestBasicReadOperations() { } func (suite *HttpChunkedReaderIntegrationSuite) TestLargeSequentialRead() { - reader, err := audio.NewHttpChunkedReader(suite.logger, suite.server.Client(), suite.server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(nil) // Read the entire file result, err := io.ReadAll(reader) @@ -99,9 +110,7 @@ func (suite *HttpChunkedReaderIntegrationSuite) TestLargeSequentialRead() { } func (suite *HttpChunkedReaderIntegrationSuite) TestRandomAccessPattern() { - reader, err := audio.NewHttpChunkedReader(suite.logger, suite.server.Client(), suite.server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(nil) positions := []int64{ 0, @@ -128,9 +137,7 @@ func (suite *HttpChunkedReaderIntegrationSuite) TestRandomAccessPattern() { } func (suite *HttpChunkedReaderIntegrationSuite) TestConcurrentReads() { - reader, err := audio.NewHttpChunkedReader(suite.logger, suite.server.Client(), suite.server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(nil) const numGoroutines = 10 const readSize = 1024 @@ -177,9 +184,7 @@ func (suite *HttpChunkedReaderIntegrationSuite) TestConcurrentReads() { } func (suite *HttpChunkedReaderIntegrationSuite) TestSeekOperations() { - reader, err := audio.NewHttpChunkedReader(suite.logger, suite.server.Client(), suite.server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(nil) // Test various seek operations testCases := []struct { @@ -213,12 +218,11 @@ func (suite *HttpChunkedReaderIntegrationSuite) TestSeekOperations() { } func (suite *HttpChunkedReaderIntegrationSuite) TestPrefetching() { - reader, err := audio.NewHttpChunkedReader(suite.logger, suite.server.Client(), suite.server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(nil) // Read from the beginning to trigger prefetching buf := make([]byte, 1000) + var err error _, err = reader.ReadAt(buf, 0) suite.Require().NoError(err) @@ -242,9 +246,7 @@ func (suite 
*HttpChunkedReaderIntegrationSuite) TestPrefetching() { } func (suite *HttpChunkedReaderIntegrationSuite) TestLatencyMetrics() { - reader, err := audio.NewHttpChunkedReader(suite.logger, suite.server.Client(), suite.server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(nil) // Trigger multiple reads to generate latency data buf := make([]byte, 1000) @@ -269,10 +271,9 @@ func (suite *HttpChunkedReaderIntegrationSuite) TestLatencyMetrics() { func (suite *HttpChunkedReaderIntegrationSuite) TestErrorRecovery() { // Create a server that fails occasionally - failCount := 0 + var failCount atomic.Int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - failCount++ - if failCount%3 == 0 { // Fail every 3rd request + if failCount.Add(1)%3 == 0 { // Fail every 3rd request w.WriteHeader(http.StatusInternalServerError) return } @@ -280,9 +281,7 @@ func (suite *HttpChunkedReaderIntegrationSuite) TestErrorRecovery() { })) defer server.Close() - reader, err := audio.NewHttpChunkedReader(suite.logger, server.Client(), server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(server) // Try to read multiple chunks - some will fail and retry buf := make([]byte, 1000) @@ -295,14 +294,72 @@ func (suite *HttpChunkedReaderIntegrationSuite) TestErrorRecovery() { } } +func (suite *HttpChunkedReaderIntegrationSuite) TestCloseCancelsInFlightRead() { + started := make(chan struct{}) + var startedOnce sync.Once + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Range") == fmt.Sprintf("bytes=%d-%d", audio.DefaultChunkSize, audio.DefaultChunkSize*2-1) { + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", audio.DefaultChunkSize, audio.DefaultChunkSize*2-1, len(suite.testData))) + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write(suite.testData[audio.DefaultChunkSize : audio.DefaultChunkSize+1]) + w.(http.Flusher).Flush() + startedOnce.Do(func() { + close(started) + }) + <-r.Context().Done() + return + } + + suite.handleHTTPRequest(w, r) + })) + defer server.Close() + + reader := suite.newReader(server) + + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 1024) + _, err := reader.ReadAt(buf, audio.DefaultChunkSize) + errCh <- err + }() + + select { + case <-started: + case <-time.After(time.Second): + suite.T().Fatal("second chunk did not start") + } + + closeErrCh := make(chan error, 1) + go func() { + closeErrCh <- reader.Close() + }() + + select { + case err := <-closeErrCh: + suite.Require().NoError(err) + case <-time.After(time.Second): + suite.T().Fatal("Close did not return") + } + + select { + case err := <-errCh: + suite.Require().ErrorIs(err, net.ErrClosed) + case <-time.After(time.Second): + suite.T().Fatal("ReadAt did not return") + } + + buf := make([]byte, 1) + _, err := reader.ReadAt(buf, 0) + suite.Require().ErrorIs(err, net.ErrClosed) +} + func (suite *HttpChunkedReaderIntegrationSuite) TestBoundaryConditions() { - reader, err := audio.NewHttpChunkedReader(suite.logger, suite.server.Client(), suite.server.URL) - suite.Require().NoError(err) - suite.T().Cleanup(func() { _ = reader.Close() }) + reader := suite.newReader(nil) // Test reading at exact chunk boundary buf := make([]byte, 100) - _, err = reader.ReadAt(buf, audio.DefaultChunkSize) + _, err := reader.ReadAt(buf, audio.DefaultChunkSize) suite.Require().NoError(err) // Test 
reading across chunk boundary diff --git a/audio/chunked-reader_internal_test.go b/audio/chunked-reader_internal_test.go index d3e89b49..7f7450ad 100644 --- a/audio/chunked-reader_internal_test.go +++ b/audio/chunked-reader_internal_test.go @@ -3,9 +3,18 @@ package audio import ( + "context" + "errors" + "io" + "net" "net/http" + "net/url" + "strings" + "sync" "testing" + "time" + librespot "github.com/devgianlu/go-librespot" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -91,3 +100,245 @@ func TestParseContentRange(t *testing.T) { }) } } + +func TestCloseCancelsConcurrentFetchChunkCallers(t *testing.T) { + transport, started := newBlockingRoundTripper() + reader := newFetchTestReader(t, transport) + + errCh := make(chan error, 2) + go func() { + _, err := reader.fetchChunk(0) + errCh <- err + }() + + select { + case <-started: + case <-time.After(time.Second): + t.Fatal("first fetchChunk did not start downloading") + } + + go func() { + _, err := reader.fetchChunk(0) + errCh <- err + }() + + closeErrCh := make(chan error, 1) + go func() { + closeErrCh <- reader.Close() + }() + + select { + case err := <-closeErrCh: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("Close did not return") + } + + for i := 0; i < 2; i++ { + select { + case err := <-errCh: + require.ErrorIs(t, err, net.ErrClosed) + case <-time.After(time.Second): + t.Fatal("fetchChunk did not return") + } + } +} + +func TestCloseCancelsInFlightFetchChunk(t *testing.T) { + transport, started := newBlockingRoundTripper() + testCloseCancelsFetchChunk(t, transport, started, time.Second) +} + +func TestCloseCancelsRetryBackoffSleep(t *testing.T) { + transport, started := newRetryThenBlockRoundTripper() + testCloseCancelsFetchChunk(t, transport, started, 250*time.Millisecond) +} + +func TestCloseBeforeChunkPublishesReturnsErrClosed(t *testing.T) { + chunkReader := strings.NewReader("chunk") + beforeEOF := make(chan struct{}) + release := make(chan struct{}) + + body := io.NopCloser(readerFunc(func(p []byte) (int, error) { + n, err := chunkReader.Read(p) + if err == io.EOF { + close(beforeEOF) + <-release + } + + return n, err + })) + + reader := newFetchTestReader(t, roundTripperFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusPartialContent, + Status: "206 Partial Content", + Body: body, + Header: http.Header{}, + }, nil + })) + + errCh := make(chan error, 1) + go func() { + _, err := reader.fetchChunk(0) + errCh <- err + }() + + select { + case <-beforeEOF: + case <-time.After(time.Second): + t.Fatal("fetchChunk did not finish reading chunk data") + } + + require.NoError(t, reader.Close()) + close(release) + + select { + case err := <-errCh: + require.ErrorIs(t, err, net.ErrClosed) + case <-time.After(time.Second): + t.Fatal("fetchChunk did not return") + } +} + +func TestCloseNormalizesBodyReadTransportErrors(t *testing.T) { + reader := newFetchTestReader(t, nil) + + started := make(chan struct{}) + var startedOnce sync.Once + body := io.NopCloser(readerFunc(func([]byte) (int, error) { + startedOnce.Do(func() { + close(started) + }) + + <-reader.ctx.Done() + return 0, errors.New("use of closed network connection") + })) + + reader.client.Transport = roundTripperFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusPartialContent, + Status: "206 Partial Content", + Body: body, + Header: http.Header{}, + }, nil + }) + + errCh := make(chan error, 1) + go func() { + _, err := 
reader.fetchChunk(0) + errCh <- err + }() + + select { + case <-started: + case <-time.After(time.Second): + t.Fatal("fetchChunk did not start reading response body") + } + + require.NoError(t, reader.Close()) + + select { + case err := <-errCh: + require.ErrorIs(t, err, net.ErrClosed) + case <-time.After(time.Second): + t.Fatal("fetchChunk did not return") + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +type readerFunc func([]byte) (int, error) + +func (fn readerFunc) Read(p []byte) (int, error) { + return fn(p) +} + +func newFetchTestReader(t *testing.T, transport http.RoundTripper) *HttpChunkedReader { + t.Helper() + + chunkURL, err := url.Parse("https://example.com/audio") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + return &HttpChunkedReader{ + log: &librespot.NullLogger{}, + client: &http.Client{Transport: transport}, + url: chunkURL, + len: DefaultChunkSize, + chunks: []*chunkItem{newChunkItem()}, + ctx: ctx, + cancel: cancel, + } +} + +func newBlockingRoundTripper() (http.RoundTripper, <-chan struct{}) { + started := make(chan struct{}) + var startedOnce sync.Once + + return roundTripperFunc(func(req *http.Request) (*http.Response, error) { + startedOnce.Do(func() { + close(started) + }) + + <-req.Context().Done() + return nil, req.Context().Err() + }), started +} + +func newRetryThenBlockRoundTripper() (http.RoundTripper, <-chan struct{}) { + started := make(chan struct{}) + attempts := 0 + + return roundTripperFunc(func(req *http.Request) (*http.Response, error) { + attempts++ + if attempts == 1 { + close(started) + return nil, errors.New("transient") + } + + <-req.Context().Done() + return nil, req.Context().Err() + }), started +} + +func testCloseCancelsFetchChunk(t *testing.T, transport http.RoundTripper, started <-chan struct{}, closeTimeout time.Duration) { + t.Helper() + + reader := newFetchTestReader(t, transport) + + errCh := make(chan error, 1) + go func() { + _, err := reader.fetchChunk(0) + errCh <- err + }() + + select { + case <-started: + case <-time.After(time.Second): + t.Fatal("fetchChunk did not start downloading") + } + + closeErrCh := make(chan error, 1) + go func() { + closeErrCh <- reader.Close() + }() + + select { + case err := <-closeErrCh: + require.NoError(t, err) + case <-time.After(closeTimeout): + t.Fatal("Close did not return") + } + + select { + case err := <-errCh: + require.ErrorIs(t, err, net.ErrClosed) + case <-time.After(2 * time.Second): + t.Fatal("fetchChunk did not return") + } +} From 9d2685cc38ef3d42a91200fc311b4fc5e4a5090d Mon Sep 17 00:00:00 2001 From: Gjermund Garaba Date: Wed, 15 Apr 2026 16:40:09 +0200 Subject: [PATCH 2/2] cr fixes --- audio/chunked-reader.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/audio/chunked-reader.go b/audio/chunked-reader.go index 4f561cd3..b9013348 100644 --- a/audio/chunked-reader.go +++ b/audio/chunked-reader.go @@ -168,7 +168,7 @@ func (r *HttpChunkedReader) downloadChunk(idx int) (*http.Response, error) { if resp.StatusCode != http.StatusPartialContent { _ = resp.Body.Close() - return nil, fmt.Errorf("invalid first chunk response status: %s", resp.Status) + return nil, fmt.Errorf("unexpected chunk response status: got %s, expected %d", resp.Status, http.StatusPartialContent) } return resp, nil
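
For reference, the unblocking pattern the two patches give the reader boils down to: Close cancels a shared context and broadcasts on the chunk's sync.Cond, so the goroutine that owns the download and every goroutine waiting on the condition all return net.ErrClosed instead of blocking. A minimal standalone sketch of that pattern follows; closableFetcher and slowFetch are illustrative names only, not go-librespot types.

package main

import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"
)

type closableFetcher struct {
	cond     *sync.Cond
	data     []byte
	fetching bool
	ctx      context.Context
	cancel   context.CancelFunc
}

func newClosableFetcher() *closableFetcher {
	ctx, cancel := context.WithCancel(context.Background())
	return &closableFetcher{cond: sync.NewCond(&sync.Mutex{}), ctx: ctx, cancel: cancel}
}

// fetch mirrors fetchChunk: one caller downloads, the rest wait on the Cond,
// and every caller re-checks the context after each wake-up.
func (f *closableFetcher) fetch() ([]byte, error) {
	f.cond.L.Lock()
	for {
		if f.ctx.Err() != nil {
			f.cond.L.Unlock()
			return nil, net.ErrClosed
		}
		if f.data != nil {
			data := f.data
			f.cond.L.Unlock()
			return data, nil
		}
		if !f.fetching {
			f.fetching = true
			f.cond.L.Unlock()
			break
		}
		f.cond.Wait()
	}

	data, err := f.slowFetch() // stands in for downloadAndRead, bounded by f.ctx
	f.cond.L.Lock()
	f.fetching = false
	if err == nil {
		f.data = data
	}
	f.cond.Broadcast()
	f.cond.L.Unlock()
	return data, err
}

// slowFetch simulates a long download that honours the shared context.
func (f *closableFetcher) slowFetch() ([]byte, error) {
	select {
	case <-time.After(5 * time.Second):
		return []byte("chunk"), nil
	case <-f.ctx.Done():
		return nil, net.ErrClosed
	}
}

// Close mirrors HttpChunkedReader.Close: cancel first, then wake every waiter.
func (f *closableFetcher) Close() error {
	f.cancel()
	f.cond.L.Lock()
	f.cond.Broadcast()
	f.cond.L.Unlock()
	return nil
}

func main() {
	f := newClosableFetcher()
	done := make(chan error, 2)
	for i := 0; i < 2; i++ {
		go func() { _, err := f.fetch(); done <- err }()
	}
	time.Sleep(100 * time.Millisecond) // one goroutine is downloading, the other is waiting
	_ = f.Close()                      // returns immediately
	fmt.Println(<-done, <-done)        // both goroutines report "use of closed network connection"
}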
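
The retry path is unblocked the same way: wrapping the policy in backoff.WithContext ties the sleep between attempts to the reader's context, which is what lets Close interrupt a download that is sitting in the constant one-second backoff (the new TestCloseCancelsRetryBackoffSleep exercises exactly this). A rough sketch under the same caveat — fetchURL and the unreachable URL are stand-ins, not the library's API:

package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"time"

	"github.com/cenkalti/backoff/v4"
)

func fetchURL(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
	policy := backoff.WithContext(
		backoff.WithMaxRetries(backoff.NewConstantBackOff(1*time.Second), 3),
		ctx,
	)

	return backoff.RetryWithData(func() (*http.Response, error) {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return nil, backoff.Permanent(err) // malformed URL, retrying will not help
		}

		resp, err := client.Do(req)
		if err != nil {
			if ctx.Err() != nil {
				return nil, backoff.Permanent(net.ErrClosed) // closed: stop retrying immediately
			}
			return nil, err // transient: retried after the constant backoff
		}
		if resp.StatusCode != http.StatusPartialContent {
			_ = resp.Body.Close()
			return nil, fmt.Errorf("unexpected chunk response status: %s", resp.Status)
		}
		return resp, nil
	}, policy)
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go func() { time.Sleep(200 * time.Millisecond); cancel() }() // plays the role of Close

	_, err := fetchURL(ctx, http.DefaultClient, "http://127.0.0.1:1/chunk")
	fmt.Println(errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled), err)
}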