From 312fd285b62ef7184c580fe137b4c1d6673c847f Mon Sep 17 00:00:00 2001
From: Maxime Visonneau
Date: Tue, 16 Apr 2024 12:09:11 +0200
Subject: [PATCH] wip: attempt to support multi-threaded per layer downloads

---
 remotes/docker/fetcher.go | 210 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 209 insertions(+), 1 deletion(-)

diff --git a/remotes/docker/fetcher.go b/remotes/docker/fetcher.go
index ecf245933f7a6..3f2002bd339e1 100644
--- a/remotes/docker/fetcher.go
+++ b/remotes/docker/fetcher.go
@@ -24,7 +24,9 @@ import (
 	"io"
 	"net/http"
 	"net/url"
+	"strconv"
 	"strings"
+	"sync"
 
 	"github.com/containerd/containerd/errdefs"
 	"github.com/containerd/containerd/images"
@@ -128,7 +130,16 @@ func (r dockerFetcher) Fetch(ctx context.Context, desc ocispec.Descriptor) (io.R
 				return nil, err
 			}
 
-			rc, err := r.open(ctx, req, desc.MediaType, offset)
+			var rc io.ReadCloser
+			switch {
+			case offset == 0 && desc.MediaType == images.MediaTypeDockerSchema2LayerGzip:
+				// multiThreadedGet always downloads from byte 0, so it can
+				// only serve requests that do not ask for a seek offset.
+				rc, err = multiThreadedGet(ctx, *req)
+			default:
+				rc, err = r.open(ctx, req, desc.MediaType, offset)
+			}
+
 			if err != nil {
 				// Store the error for referencing later
 				if firstErr == nil {
@@ -311,3 +322,200 @@ func (r dockerFetcher) open(ctx context.Context, req *request, mediatype string,
 
 	return resp.Body, nil
 }
+
+// open_ requests the inclusive byte range [start, end] of the resource
+// addressed by req and returns a reader that yields exactly that range.
+//
+// req is received by value, but its header is cloned before being modified
+// so that concurrent callers sharing the same request never write to the
+// same header.
+//
+// On error the response body is closed and a nil reader is returned.
+//
+// TODO: Integrate with the one above or rearrange
+func open_(ctx context.Context, req request, start, end int) (_ io.ReadCloser, retErr error) {
+	// NOTE(review): assumes req.header is an http.Header (a map shared by
+	// every copy of req) -- confirm against the request type.
+	req.header = req.header.Clone()
+	req.header.Set("Accept", "*/*")
+	req.header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
+
+	resp, err := req.doWithRetries(ctx, nil)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		if retErr != nil {
+			resp.Body.Close()
+		}
+	}()
+
+	if resp.StatusCode > 299 {
+		// TODO(stevvooe): When doing a offset specific request, we should
+		// really distinguish between a 206 and a 200. In the case of 200, we
+		// can discard the bytes, hiding the seek behavior from the
+		// implementation.
+
+		if resp.StatusCode == http.StatusNotFound {
+			return nil, fmt.Errorf("content at %v not found: %w", req.String(), errdefs.ErrNotFound)
+		}
+		var registryErr Errors
+		if err := json.NewDecoder(resp.Body).Decode(&registryErr); err != nil || registryErr.Len() < 1 {
+			return nil, fmt.Errorf("unexpected status code %v: %v", req.String(), resp.Status)
+		}
+		return nil, fmt.Errorf("unexpected status code %v: %s - Server message: %s", req.String(), resp.Status, registryErr.Error())
+	}
+
+	if cr := resp.Header.Get("content-range"); cr != "" {
+		// The server answered with a range; make sure it is the one that
+		// was asked for. The trailing "/" keeps e.g. "0-9" from matching
+		// "0-99".
+		if !strings.HasPrefix(cr, fmt.Sprintf("bytes %d-%d/", start, end)) {
+			return nil, fmt.Errorf("unhandled content range in response: %v", cr)
+		}
+		return resp.Body, nil
+	}
+
+	// No Content-Range header: the body is treated as the whole resource,
+	// so skip to the requested offset and stop after the requested length.
+	// TODO: Should any cases where use of content range
+	// without the proper header be considered?
+	// 206 responses?
+	if start > 0 {
+		// Could use buffer pool here but this case should be rare
+		n, err := io.Copy(io.Discard, io.LimitReader(resp.Body, int64(start)))
+		if err != nil {
+			return nil, fmt.Errorf("failed to discard to offset: %w", err)
+		}
+		if n != int64(start) {
+			return nil, errors.New("unable to discard to offset")
+		}
+	}
+
+	return struct {
+		io.Reader
+		io.Closer
+	}{
+		Reader: io.LimitReader(resp.Body, int64(end-start+1)),
+		Closer: resp.Body,
+	}, nil
+}
+
+// chunk of the layer
+type chunk struct {
+	rc  io.ReadCloser
+	err error
+}
+
+// multiReadCloser combines multiple io.ReadClosers into a single io.ReadCloser
+type multiReadCloser struct {
+	reader  io.Reader
+	closers []io.Closer
+}
+
+// Read delegates reading to the wrapped io.Reader
+func (mrc *multiReadCloser) Read(p []byte) (n int, err error) {
+	return mrc.reader.Read(p)
+}
+
+// Close goes through all closers and closes them
+func (mrc *multiReadCloser) Close() error {
+	var allErrors error
+	for _, closer := range mrc.closers {
+		if err := closer.Close(); err != nil {
+			allErrors = errors.Join(allErrors, err)
+		}
+	}
+	return allErrors
+}
+
+// multiThreadedGet downloads the resource addressed by req as up to
+// numWorkers consecutive byte ranges fetched concurrently, and returns a
+// reader that yields those ranges back to back, in order.
+//
+// The total size is taken from the Content-Length of a HEAD request.
+// Closing the returned reader closes every underlying response body.
+func multiThreadedGet(ctx context.Context, req request) (io.ReadCloser, error) {
+	headReq := req
+	headReq.method = "HEAD"
+
+	headResp, err := headReq.doWithRetries(ctx, nil)
+	if err != nil {
+		return nil, err
+	}
+	// Only the headers are needed; release the response right away.
+	headResp.Body.Close()
+
+	if headResp.StatusCode > 299 {
+		return nil, fmt.Errorf("unexpected status code %v: %v", headReq.String(), headResp.Status)
+	}
+
+	// TODO: Also check whether Accept-Ranges is defined?
+	contentLength, err := strconv.Atoi(headResp.Header.Get("Content-Length"))
+	if err != nil {
+		return nil, fmt.Errorf("determining content length: %w", err)
+	}
+	if contentLength <= 0 {
+		return nil, fmt.Errorf("determining content length: invalid value %d", contentLength)
+	}
+
+	const numWorkers = 4
+
+	// Never start more workers than there are bytes: chunkSize would be
+	// zero and every requested range would be invalid.
+	workers := numWorkers
+	if contentLength < workers {
+		workers = contentLength
+	}
+
+	var (
+		chunkSize = contentLength / workers
+		chunks    = make([]chunk, workers)
+		wg        sync.WaitGroup
+	)
+
+	wg.Add(workers)
+	for i := 0; i < workers; i++ {
+		go func(idx int) {
+			defer wg.Done()
+
+			start := idx * chunkSize
+			end := start + chunkSize - 1
+			if idx == workers-1 {
+				// The last chunk also takes the remainder of the division.
+				end = contentLength - 1
+			}
+
+			// Each goroutine writes only its own slot and wg.Wait orders
+			// these writes before the reads below, so no lock is needed.
+			chunks[idx].rc, chunks[idx].err = open_(ctx, req, start, end)
+		}(i)
+	}
+
+	// Wait for all goroutines to finish
+	wg.Wait()
+
+	var (
+		readers = make([]io.Reader, 0, workers)
+		closers = make([]io.Closer, 0, workers)
+		openErr error
+	)
+	for _, c := range chunks {
+		if c.err != nil {
+			openErr = errors.Join(openErr, c.err)
+			continue
+		}
+		readers = append(readers, c.rc)
+		closers = append(closers, c.rc)
+	}
+
+	mrc := &multiReadCloser{reader: io.MultiReader(readers...), closers: closers}
+	if openErr != nil {
+		// Do not leak the bodies of the chunks that did open. The close
+		// error is dropped on purpose: openErr is the one worth reporting.
+		_ = mrc.Close()
+		return nil, openErr
+	}
+
+	return mrc, nil
+}