Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions azurebs/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ type AzBlobstore struct {
storageClient StorageClient
}

// singleBlobPutThreshold is the largest file size, in bytes (32 MB), that Put
// sends with a single Upload call; anything larger goes through UploadStream.
const singleBlobPutThreshold int64 = 32 * 1024 * 1024

// getFileSize reports the size in bytes of the already opened file source.
//
// It returns 0 and a wrapped error when the file's metadata cannot be read.
func getFileSize(source *os.File) (int64, error) {
	info, statErr := source.Stat()
	if statErr == nil {
		return info.Size(), nil
	}
	return 0, fmt.Errorf("failed to get file stat: %w", statErr)
}

// New wraps storageClient in an AzBlobstore.
//
// The returned error is always nil.
func New(storageClient StorageClient) (AzBlobstore, error) {
	blobstore := AzBlobstore{storageClient: storageClient}
	return blobstore, nil
}
Expand All @@ -30,24 +41,36 @@ func (client *AzBlobstore) Put(sourceFilePath string, dest string) error {
return err
}
defer source.Close() //nolint:errcheck

md5, err := client.storageClient.Upload(source, dest)
fileSize, err := getFileSize(source)
if err != nil {
return fmt.Errorf("upload failure: %w", err)
return err
}
if fileSize <= singleBlobPutThreshold {
md5, err := client.storageClient.Upload(source, dest)
if err != nil {
return fmt.Errorf("upload failure: %w", err)
}

if !bytes.Equal(sourceMD5, md5) {
slog.Error("Upload failed due to MD5 mismatch, deleting blob", "blob", dest, "expected_md5", fmt.Sprintf("%x", sourceMD5), "received_md5", fmt.Sprintf("%x", md5))
if !bytes.Equal(sourceMD5, md5) {
slog.Error("Upload failed due to MD5 mismatch, deleting blob", "blob", dest, "expected_md5", fmt.Sprintf("%x", sourceMD5), "received_md5", fmt.Sprintf("%x", md5))

err := client.storageClient.Delete(dest)
if err != nil {
slog.Error("Failed to delete blob after MD5 mismatch", "blob", dest, "error", err)
err := client.storageClient.Delete(dest)
if err != nil {
slog.Error("Failed to delete blob after MD5 mismatch", "blob", dest, "error", err)

}
return fmt.Errorf("MD5 mismatch: expected %x, got %x", sourceMD5, md5)
}

slog.Debug("MD5 verification passed", "blob", dest, "md5", fmt.Sprintf("%x", md5))

} else {
err := client.storageClient.UploadStream(source, dest)
if err != nil {
return fmt.Errorf("upload failure: %w", err)
}
return fmt.Errorf("MD5 mismatch: expected %x, got %x", sourceMD5, md5)
}

slog.Debug("MD5 verification passed", "blob", dest, "md5", fmt.Sprintf("%x", md5))
return nil
}

Expand Down
24 changes: 24 additions & 0 deletions azurebs/client/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client_test

import (
"bytes"
"errors"
"os"
"runtime"
Expand Down Expand Up @@ -32,6 +33,29 @@ var _ = Describe("Client", func() {
Expect(dest).To(Equal("target/blob"))
})

It("uploads a file with UploadStream", func() {
	storageClient := clientfakes.FakeStorageClient{}

	azBlobstore, err := client.New(&storageClient)
	Expect(err).ToNot(HaveOccurred())

	// Fail fast if the fixture cannot be created instead of dereferencing
	// a nil *os.File below.
	file, err := os.CreateTemp("", "tmpfile-test-upload")
	Expect(err).ToNot(HaveOccurred())
	defer os.Remove(file.Name()) //nolint:errcheck

	// 64MB is above the 32MB single-put threshold, so Put must take the
	// streaming path.
	contentSize := 1024 * 1024 * 64 // 64MB

	content := bytes.Repeat([]byte("x"), contentSize)
	_, err = file.Write(content)
	Expect(err).ToNot(HaveOccurred())
	// Put reopens the file by path, so release our handle first.
	Expect(file.Close()).To(Succeed())

	err = azBlobstore.Put(file.Name(), "target/blob")
	Expect(err).ToNot(HaveOccurred())

	Expect(storageClient.UploadStreamCallCount()).To(Equal(1))
	source, dest := storageClient.UploadStreamArgsForCall(0)

	Expect(source).To(BeAssignableToTypeOf((*os.File)(nil)))
	Expect(dest).To(Equal("target/blob"))
})

It("skips the upload if the md5 cannot be calculated from the file", func() {
storageClient := clientfakes.FakeStorageClient{}

Expand Down
94 changes: 74 additions & 20 deletions azurebs/client/clientfakes/fake_storage_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

96 changes: 78 additions & 18 deletions azurebs/client/storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ type StorageClient interface {
dest string,
) ([]byte, error)

UploadStream(
source io.ReadSeekCloser,
dest string,
) error

Download(
source string,
dest *os.File,
Expand Down Expand Up @@ -68,6 +73,36 @@ type StorageClient interface {
EnsureContainerExists() error
}

// Tuning values passed to the block blob client's UploadStream call.
const (
	// blockSize is the BlockSize upload option: 4 MB per block.
	blockSize int64 = 4 * 1024 * 1024

	// maxConcurrency is the Concurrency upload option (number of go routines).
	maxConcurrency = 5
)

func createContext(dsc DefaultStorageClient) (context.Context, context.CancelFunc, error) {
var ctx context.Context
var cancel context.CancelFunc

if dsc.storageConfig.Timeout != "" {
timeoutInt, err := strconv.Atoi(dsc.storageConfig.Timeout)
timeout := time.Duration(timeoutInt) * time.Second
if timeout < 1 && err == nil {
slog.Info("Invalid time, need at least 1 second", "timeout", dsc.storageConfig.Timeout)
return nil, nil, fmt.Errorf("invalid time: %w", err)
}
if err != nil {
slog.Info("Invalid timeout format, need seconds as number e.g. 30s", "timeout", dsc.storageConfig.Timeout)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

30s or 30?

return nil, nil, fmt.Errorf("invalid timeout format: %w", err)
}
ctx, cancel = context.WithTimeout(context.Background(), timeout)
} else {
ctx, cancel = context.WithCancel(context.Background())
}

return ctx, cancel, nil

}

type DefaultStorageClient struct {
credential *azblob.SharedKeyCredential
serviceURL string
Expand All @@ -91,33 +126,23 @@ func (dsc DefaultStorageClient) Upload(
) ([]byte, error) {
blobURL := fmt.Sprintf("%s/%s", dsc.serviceURL, dest)

var ctx context.Context
var cancel context.CancelFunc

if dsc.storageConfig.Timeout != "" {
timeoutInt, err := strconv.Atoi(dsc.storageConfig.Timeout)
timeout := time.Duration(timeoutInt) * time.Second
if timeout < 1 && err == nil {
slog.Info("Invalid time, need at least 1 second", "timeout", dsc.storageConfig.Timeout)
return nil, fmt.Errorf("invalid time: %w", err)
}
if err != nil {
slog.Info("Invalid timeout format, need seconds as number e.g. 30s", "timeout", dsc.storageConfig.Timeout)
return nil, fmt.Errorf("invalid timeout format: %w", err)
}
slog.Info("Uploading blob to container", "container", dsc.storageConfig.ContainerName, "blob", dest, "url", blobURL, "timeout", timeout.String())

ctx, cancel = context.WithTimeout(context.Background(), timeout)
slog.Info("Uploading blob to container", "container", dsc.storageConfig.ContainerName, "blob", dest, "url", blobURL, "timeout", dsc.storageConfig.Timeout)
} else {
slog.Info("Uploading blob to container", "container", dsc.storageConfig.ContainerName, "blob", dest, "url", blobURL)
ctx, cancel = context.WithCancel(context.Background())
}

ctx, cancel, err := createContext(dsc)
if err != nil {
return nil, err
}
defer cancel()

client, err := blockblob.NewClientWithSharedKeyCredential(blobURL, dsc.credential, nil)
if err != nil {
return nil, err
}

uploadResponse, err := client.Upload(ctx, source, nil)
if err != nil {
if dsc.storageConfig.Timeout != "" && errors.Is(err, context.DeadlineExceeded) {
Expand All @@ -127,7 +152,42 @@ func (dsc DefaultStorageClient) Upload(
}

slog.Info("Successfully uploaded blob", "container", dsc.storageConfig.ContainerName, "blob", dest)
return uploadResponse.ContentMD5, err
return uploadResponse.ContentMD5, nil
}

// UploadStream sends source to the blob named dest through the block blob
// client's UploadStream call, using blockSize and maxConcurrency as the
// BlockSize and Concurrency options.
//
// The operation runs under the context produced by createContext. It returns
// nil on success; a dedicated timeout error when a timeout is configured and
// the context deadline was exceeded; otherwise the underlying error (wrapped
// for upload failures).
func (dsc DefaultStorageClient) UploadStream(
	source io.ReadSeekCloser,
	dest string,
) error {
	blobURL := fmt.Sprintf("%s/%s", dsc.serviceURL, dest)
	timeoutSet := dsc.storageConfig.Timeout != ""

	// The timeout attribute is only logged when a timeout is configured.
	logAttrs := []any{"container", dsc.storageConfig.ContainerName, "blob", dest, "url", blobURL}
	if timeoutSet {
		logAttrs = append(logAttrs, "timeout", dsc.storageConfig.Timeout)
	}
	slog.Info("UploadStreaming blob to container", logAttrs...)

	ctx, cancel, err := createContext(dsc)
	if err != nil {
		return err
	}
	defer cancel()

	client, err := blockblob.NewClientWithSharedKeyCredential(blobURL, dsc.credential, nil)
	if err != nil {
		return err
	}

	streamOpts := &azblob.UploadStreamOptions{BlockSize: blockSize, Concurrency: maxConcurrency}
	if _, err = client.UploadStream(ctx, source, streamOpts); err != nil {
		switch {
		case timeoutSet && errors.Is(err, context.DeadlineExceeded):
			return fmt.Errorf("upload failed: timeout of %s reached while uploading %s", dsc.storageConfig.Timeout, dest)
		default:
			return fmt.Errorf("upload failure: %w", err)
		}
	}

	slog.Info("Successfully uploaded blob", "container", dsc.storageConfig.ContainerName, "blob", dest)
	return nil
}

func (dsc DefaultStorageClient) Download(
Expand Down
Loading
Loading