From 5aee694ce7a6c6162a2c537aaa021b804f830921 Mon Sep 17 00:00:00 2001 From: "x.zhou" Date: Wed, 28 Feb 2024 20:43:04 +0800 Subject: [PATCH 1/2] feat: implement encryptedStorage --- cmd/push.go | 2 +- pkg/app/globalstorage.go | 30 ++++- pkg/encryption/aes_encryptor.go | 96 ++++++++++++++ pkg/encryption/common.go | 55 ++++++++ pkg/encryption/encryption.go | 56 ++++++++ pkg/encryption/encryption_test.go | 204 +++++++++++++++++++++++++++++ pkg/encryption/encryptor.go | 17 +++ pkg/storage/encrypted/encrypted.go | 177 +++++++++++++++++++++++++ pkg/util/discard.go | 56 +++++++- 9 files changed, 686 insertions(+), 7 deletions(-) create mode 100644 pkg/encryption/aes_encryptor.go create mode 100644 pkg/encryption/common.go create mode 100644 pkg/encryption/encryption.go create mode 100644 pkg/encryption/encryption_test.go create mode 100644 pkg/encryption/encryptor.go create mode 100644 pkg/storage/encrypted/encrypted.go diff --git a/cmd/push.go b/cmd/push.go index 7b63bd2..d636205 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -77,7 +77,7 @@ func doPush(opts *pushOptions, cmd *cobra.Command, args []string) { pr, pw := io.Pipe() go func(r io.Reader) { // discard compression header - err := c.Compress(util.DiscardN(pw, compressionHeaderSize), r) + err := c.Compress(util.DiscardNWriter(pw, compressionHeaderSize), r) pw.CloseWithError(err) }(in) in = pr diff --git a/pkg/app/globalstorage.go b/pkg/app/globalstorage.go index 05908a0..4b26f60 100644 --- a/pkg/app/globalstorage.go +++ b/pkg/app/globalstorage.go @@ -8,13 +8,17 @@ import ( "strings" "github.com/apecloud/datasafed/pkg/config" + "github.com/apecloud/datasafed/pkg/encryption" "github.com/apecloud/datasafed/pkg/storage" + "github.com/apecloud/datasafed/pkg/storage/encrypted" "github.com/apecloud/datasafed/pkg/storage/kopia" "github.com/apecloud/datasafed/pkg/storage/rclone" ) const ( backendBasePathEnv = "DATASAFED_BACKEND_BASE_PATH" + encryptionAlgorithm = "DATASAFED_ENCRYPTION_ALGORITHM" + encryptionPassPhrase = "DATASAFED_ENCRYPTION_PASS_PHRASE" kopiaRepoRootEnv = "DATASAFED_KOPIA_REPO_ROOT" kopiaPasswordEnv = "DATASAFED_KOPIA_PASSWORD" kopiaDisableCacheEnv = "DATASAFED_KOPIA_DISABLE_CACHE" @@ -36,15 +40,37 @@ func InitGlobalStorage(ctx context.Context, configFile string) error { storageConf := config.GetGlobal().GetAll(config.StorageSection) if kopiaRoot := strings.TrimSpace(os.Getenv(kopiaRepoRootEnv)); kopiaRoot != "" { - return initKopiaStorage(ctx, storageConf, basePath, kopiaRoot) + err := initKopiaStorage(ctx, storageConf, basePath, kopiaRoot) + if err != nil { + return err + } } else { st, err := createStorage(ctx, storageConf, basePath) if err != nil { return err } globalStorage = st - return nil } + + // wrap with encryptedStorage + encAlgo := os.Getenv(encryptionAlgorithm) + if encAlgo != "" { + encPass := os.Getenv(encryptionPassPhrase) + if encPass == "" { + return fmt.Errorf("encryption pass phrase should not be empty") + } + enc, err := encryption.CreateEncryptor(encAlgo, []byte(encPass)) + if err != nil { + return err + } + encSt, err := encrypted.New(ctx, enc, globalStorage) + if err != nil { + return err + } + globalStorage = encSt + } + + return nil } func GetGlobalStorage() (storage.Storage, error) { diff --git a/pkg/encryption/aes_encryptor.go b/pkg/encryption/aes_encryptor.go new file mode 100644 index 0000000..b3d3918 --- /dev/null +++ b/pkg/encryption/aes_encryptor.go @@ -0,0 +1,96 @@ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" +) + +const bufferSize = ((128 * 1024) / aes.BlockSize) * aes.BlockSize + +type aesEncryptor struct { + key []byte + newEncryptor func(block cipher.Block, iv []byte) cipher.Stream + newDecryptor func(block cipher.Block, iv []byte) cipher.Stream +} + +func (e *aesEncryptor) EncryptStream(plainText io.Reader, output io.Writer) error { + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return fmt.Errorf("rand.Read() error: %w", err) + } + n, err := output.Write(iv) + if err != nil { + return fmt.Errorf("write IV error: %w", err) + } + if n != len(iv) { + return fmt.Errorf("partially write IV, len: %d, written: %d", len(iv), n) + } + block, err := aes.NewCipher(e.key) + if err != nil { + return fmt.Errorf("aes.NewCipher() error: %w", err) + } + enc := e.newEncryptor(block, iv) + buf := make([]byte, bufferSize) + return pipeStream(plainText, "plainText", output, "cipherText", buf, func(b []byte) { + enc.XORKeyStream(b, b) + }) +} + +func (e *aesEncryptor) DecryptStream(cipherText io.Reader, output io.Writer) error { + iv := make([]byte, aes.BlockSize) + _, err := io.ReadFull(cipherText, iv) + if err != nil { + return fmt.Errorf("unable to read iv from cipherText, error: %w", err) + } + block, err := aes.NewCipher(e.key) + if err != nil { + return fmt.Errorf("aes.NewCipher() error: %w", err) + } + dec := e.newDecryptor(block, iv) + buf := make([]byte, bufferSize) + return pipeStream(cipherText, "cipherText", output, "plainText", buf, func(b []byte) { + dec.XORKeyStream(b, b) + }) +} + +func (e *aesEncryptor) Overhead() int { + return aes.BlockSize +} + +func newAESCFB(passPhrase []byte, keyLength int) (StreamEncryptor, error) { + deriveLength := keyLength + if deriveLength < minDerivedKeyLength { + deriveLength = minDerivedKeyLength + } + key, err := deriveKey(passPhrase, []byte(purposeEncryptionKey), deriveLength) + if err != nil { + return nil, fmt.Errorf("deriveKey() error: %w", err) + } + ae := &aesEncryptor{ + key: key[:keyLength], + newEncryptor: cipher.NewCFBEncrypter, + newDecryptor: cipher.NewCFBDecrypter, + } + return ae, nil +} + +func NewAES128CFB(passPhrase []byte) (StreamEncryptor, error) { + return newAESCFB(passPhrase, 16) +} + +func NewAES192CFB(passPhrase []byte) (StreamEncryptor, error) { + return newAESCFB(passPhrase, 24) +} + +func NewAES256CFB(passPhrase []byte) (StreamEncryptor, error) { + return newAESCFB(passPhrase, 32) +} + +func init() { + Register("AES128-CFB", "AES-128 with CFB mode", NewAES128CFB) + Register("AES192-CFB", "AES-192 with CFB mode", NewAES192CFB) + Register("AES256-CFB", "AES-256 with CFB mode", NewAES256CFB) +} diff --git a/pkg/encryption/common.go b/pkg/encryption/common.go new file mode 100644 index 0000000..11c297c --- /dev/null +++ b/pkg/encryption/common.go @@ -0,0 +1,55 @@ +package encryption + +import ( + "crypto/sha256" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/hkdf" +) + +const minDerivedKeyLength = 32 + +func pipeStream(in io.Reader, inName string, + out io.Writer, outName string, + buf []byte, manipulateFunc func([]byte)) error { + for { + n, err := in.Read(buf) + if n > 0 { + data := buf[:n] + manipulateFunc(data) + wn, err := out.Write(data) + if err != nil { + return fmt.Errorf("write to %s error: %w", outName, err) + } + if wn != n { + return fmt.Errorf("partially write to %s, length: %d, actual write: %d", + outName, n, wn) + } + } + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } else { + return fmt.Errorf("read from %s error: %w", inName, err) + } + } + } +} + +// deriveKey uses HKDF to derive a key of a given length and a given purpose from parameters. +func deriveKey(passPhrase []byte, purpose []byte, length int) ([]byte, error) { + if length < minDerivedKeyLength { + return nil, fmt.Errorf("derived key must be at least 32 bytes, was %v", length) + } + + key := make([]byte, length) + k := hkdf.New(sha256.New, passPhrase, purpose, nil) + _, err := io.ReadFull(k, key) + if err != nil { + return nil, err + } + + return key, nil +} diff --git a/pkg/encryption/encryption.go b/pkg/encryption/encryption.go new file mode 100644 index 0000000..873d25f --- /dev/null +++ b/pkg/encryption/encryption.go @@ -0,0 +1,56 @@ +// Package encryption manages content encryption algorithms. +package encryption + +import ( + "sort" + "strings" + + "github.com/pkg/errors" +) + +const ( + purposeEncryptionKey = "encryption" +) + +// CreateEncryptor creates an StreamEncryptor for given parameters. +func CreateEncryptor(algorithm string, passPhrase []byte) (StreamEncryptor, error) { + algorithm = strings.ToUpper(algorithm) + e := encryptors[algorithm] + if e == nil { + return nil, errors.Errorf("unknown encryption algorithm: %v", algorithm) + } + + return e.newEncryptor(passPhrase) +} + +// EncryptorFactory creates new Encryptor for given parameters. +type EncryptorFactory func(passPhrase []byte) (StreamEncryptor, error) + +// DefaultAlgorithm is the name of the default encryption algorithm. +const DefaultAlgorithm = "AES256-CFB" + +// SupportedAlgorithms returns the names of the supported encryption methods. +func SupportedAlgorithms() []string { + var result []string + for k := range encryptors { + result = append(result, k) + } + sort.Strings(result) + return result +} + +// Register registers new encryption algorithm. +func Register(name, description string, newEncryptor EncryptorFactory) { + name = strings.ToUpper(name) + encryptors[name] = &encryptorInfo{ + description, + newEncryptor, + } +} + +type encryptorInfo struct { + description string + newEncryptor EncryptorFactory +} + +var encryptors = map[string]*encryptorInfo{} diff --git a/pkg/encryption/encryption_test.go b/pkg/encryption/encryption_test.go new file mode 100644 index 0000000..411652b --- /dev/null +++ b/pkg/encryption/encryption_test.go @@ -0,0 +1,204 @@ +package encryption_test + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "testing" + + "github.com/apecloud/datasafed/pkg/encryption" + "github.com/stretchr/testify/require" +) + +func TestRoundTrip(t *testing.T) { + data := make([]byte, 100) + rand.Read(data) + + passPhrase := make([]byte, 32) + rand.Read(passPhrase) + + for _, encryptionAlgo := range encryption.SupportedAlgorithms() { + encryptionAlgo := encryptionAlgo + t.Run(encryptionAlgo, func(t *testing.T) { + e, err := encryption.CreateEncryptor(encryptionAlgo, passPhrase) + if err != nil { + t.Fatal(err) + } + + var cipherText1 bytes.Buffer + var cipherText2 bytes.Buffer + + require.NoError(t, e.EncryptStream(bytes.NewBuffer(data), &cipherText1)) + require.NoError(t, e.EncryptStream(bytes.NewBuffer(data), &cipherText2)) + + if v := cipherText1.Bytes(); bytes.Equal(v, cipherText2.Bytes()) { + t.Errorf("multiple EncryptStream returned the same ciphertext: %x", v) + } + + var plainText1 bytes.Buffer + + require.NoError(t, e.DecryptStream(&cipherText1, &plainText1)) + + if v := plainText1.Bytes(); !bytes.Equal(v, data) { + t.Errorf("EncryptStream()/DecryptStream() does not round-trip: %x %x", v, data) + } + + var plainText2 bytes.Buffer + + require.NoError(t, e.DecryptStream(bytes.NewBuffer(cipherText2.Bytes()), &plainText2)) + + if v := plainText2.Bytes(); !bytes.Equal(v, data) { + t.Errorf("EncryptStream()/DecryptStream() does not round-trip: %x %x", v, data) + } + + // TODO: enable the following logic if the algorithm is AEAD + // // flip some bits in the cipherText + // b := cipherText2.Bytes() + // b[mathrand.Intn(len(b))] ^= byte(1 + mathrand.Intn(254)) + + // plainText2.Reset() + // require.Error(t, e.DecryptStream(bytes.NewBuffer(cipherText2.Bytes()), &plainText2)) + }) + } +} + +func TestCiphertextSamples(t *testing.T) { + cases := []struct { + passPhrase []byte + payload []byte + samples map[string]string + }{ + { + passPhrase: []byte("01234567890123456789012345678901"), // 32 bytes + payload: []byte("foo"), + + // samples of base16-encoded ciphertexts of payload encrypted with passPhrase + samples: map[string]string{ + "AES128-CFB": "3f531b215a8b0774edeb5f07f451f811c6ba0b", + "AES192-CFB": "65cce058982cde6dec94ee7965c737bd2e9044", + "AES256-CFB": "cd68a8ba7e886f00326ebd9da560bfca0ad5c4", + }, + }, + { + passPhrase: []byte("abcdefghijklmnopqrstuvwxyzabcdef"), // 32 bytes + payload: []byte("quick brown fox jumps over the lazy dog"), + + // samples of base16-encoded ciphertexts of payload encrypted with passPhrase + samples: map[string]string{ + "AES128-CFB": "a4fd5ed9b98b780f09c3253dadd81e9b96b52f3fbe215ab0b43e88df82457b5eb4209bdabe4d8edf045763d17807ea559f4d1e316edc9b", + "AES192-CFB": "6d9cfd25a3b5f2299534c87cd6f61f16c153178e74496d9b6b67f351d9c2a4a7c1514a5a4b42efe945ac56baea71f1dff51df9dc40a8a4", + "AES256-CFB": "9456421484adb715d7f6b52663908dd1acf16848077df01942847cc0e835a627c8b5c704b465ea86f47afd4e359c097582e81a544fdbd1", + }, + }, + } + + for _, tc := range cases { + verifyCiphertextSamples(t, tc.passPhrase, tc.payload, tc.samples) + } +} + +func verifyCiphertextSamples(t *testing.T, passPhrase, payload []byte, samples map[string]string) { + t.Helper() + + for _, encryptionAlgo := range encryption.SupportedAlgorithms() { + enc, err := encryption.CreateEncryptor(encryptionAlgo, passPhrase) + if err != nil { + t.Fatal(err) + } + + ct := samples[encryptionAlgo] + if ct == "" { + func() { + var v bytes.Buffer + require.NoError(t, enc.EncryptStream(bytes.NewBuffer(payload), &v)) + + t.Errorf("missing ciphertext sample for %q: %q, possible one is: %q", + encryptionAlgo, payload, hex.EncodeToString(v.Bytes())) + }() + } else { + b, err := hex.DecodeString(ct) + if err != nil { + t.Errorf("invalid ciphertext for %v: %v", encryptionAlgo, err) + continue + } + + func() { + var plainText bytes.Buffer + + require.NoError(t, enc.DecryptStream(bytes.NewBuffer(b), &plainText)) + + if v := plainText.Bytes(); !bytes.Equal(v, payload) { + t.Errorf("invalid plaintext after decryption %x, want %x", v, payload) + } + }() + } + } +} + +func benchmarkEncryption(b *testing.B, algorithm string) { + passPhrase := make([]byte, 32) + rand.Read(passPhrase) + + enc, err := encryption.CreateEncryptor(algorithm, passPhrase) + require.NoError(b, err) + + // 8 MiB + plainText := bytes.Repeat([]byte{1, 2, 3, 4, 5, 6, 7, 8}, 1<<20) + + var warmupOut bytes.Buffer + require.NoError(b, enc.EncryptStream(bytes.NewBuffer(plainText), &warmupOut)) + + b.ResetTimer() + + var out bytes.Buffer + out.Grow(len(plainText) + enc.Overhead()) + for i := 0; i < b.N; i++ { + out.Reset() + enc.EncryptStream(bytes.NewBuffer(plainText), &out) + b.SetBytes(int64(len(plainText))) + } +} + +func BenchmarkEncryption(b *testing.B) { + for _, encryptionAlgo := range encryption.SupportedAlgorithms() { + encryptionAlgo := encryptionAlgo + b.Run(encryptionAlgo, func(b *testing.B) { + benchmarkEncryption(b, encryptionAlgo) + }) + } +} + +func benchmarkDecryption(b *testing.B, algorithm string) { + passPhrase := make([]byte, 32) + rand.Read(passPhrase) + + enc, err := encryption.CreateEncryptor(algorithm, passPhrase) + require.NoError(b, err) + + // 8 MiB + plainText := bytes.Repeat([]byte{1, 2, 3, 4, 5, 6, 7, 8}, 1<<20) + + var warmupOut bytes.Buffer + require.NoError(b, enc.EncryptStream(bytes.NewBuffer(plainText), &warmupOut)) + + cipherText := warmupOut.Bytes() + + b.ResetTimer() + + var out bytes.Buffer + out.Grow(len(plainText)) + for i := 0; i < b.N; i++ { + out.Reset() + enc.DecryptStream(bytes.NewBuffer(cipherText), &out) + b.SetBytes(int64(len(cipherText))) + } +} + +func BenchmarkDecryption(b *testing.B) { + for _, encryptionAlgo := range encryption.SupportedAlgorithms() { + encryptionAlgo := encryptionAlgo + b.Run(encryptionAlgo, func(b *testing.B) { + benchmarkDecryption(b, encryptionAlgo) + }) + } +} diff --git a/pkg/encryption/encryptor.go b/pkg/encryption/encryptor.go new file mode 100644 index 0000000..28bbb3f --- /dev/null +++ b/pkg/encryption/encryptor.go @@ -0,0 +1,17 @@ +package encryption + +import ( + "io" +) + +// StreamEncryptor represents an interface for encrypting and decrypting data streams. +type StreamEncryptor interface { + // EncryptStream appends the encrypted bytes corresponding to the given plaintext to a given writer. + EncryptStream(plainText io.Reader, output io.Writer) error + + // DecryptStream appends the unencrypted bytes corresponding to the given ciphertext to a given writer. + DecryptStream(cipherText io.Reader, output io.Writer) error + + // Overhead is the number of bytes of overhead added by EncryptStream() + Overhead() int +} diff --git a/pkg/storage/encrypted/encrypted.go b/pkg/storage/encrypted/encrypted.go new file mode 100644 index 0000000..aa935dd --- /dev/null +++ b/pkg/storage/encrypted/encrypted.go @@ -0,0 +1,177 @@ +package encrypted + +import ( + "context" + "errors" + "io" + "strings" + + "github.com/apecloud/datasafed/pkg/encryption" + "github.com/apecloud/datasafed/pkg/logging" + "github.com/apecloud/datasafed/pkg/storage" + "github.com/apecloud/datasafed/pkg/storage/sanitized" + "github.com/apecloud/datasafed/pkg/util" +) + +const encryptedFileSuffix = ".enc" + +var log = logging.Module("storage/encrypted") + +type encryptedStorage struct { + encryptor encryption.StreamEncryptor + underlying storage.Storage +} + +var _ storage.Storage = (*encryptedStorage)(nil) + +func New(ctx context.Context, + encryptor encryption.StreamEncryptor, + underlying storage.Storage) (storage.Storage, error) { + es := &encryptedStorage{ + encryptor: encryptor, + underlying: underlying, + } + return sanitized.New(ctx, "", es) +} + +func (s *encryptedStorage) Push(ctx context.Context, r io.Reader, rpath string) error { + pr, pw := io.Pipe() + go func() { + err := s.encryptor.EncryptStream(r, pw) + if err != nil { + pw.CloseWithError(err) + } else { + pw.Close() // EOF + } + }() + return s.underlying.Push(ctx, pr, rpath+encryptedFileSuffix) +} + +func (s *encryptedStorage) Pull(ctx context.Context, rpath string, w io.Writer) error { + errCh := make(chan error, 1) + pr, pw := io.Pipe() + go func() { + err := s.encryptor.DecryptStream(pr, w) + if err != nil { + // interrupt underlying.Pull() + pr.CloseWithError(err) + } + errCh <- err + }() + err := s.underlying.Pull(ctx, rpath+encryptedFileSuffix, pw) + if err != nil { + pw.CloseWithError(err) + } else { + pw.Close() + } + decErr := <-errCh // wait until all data are decrypted + if err != nil { + return err + } + return decErr +} + +func (s *encryptedStorage) OpenFile(ctx context.Context, rpath string, offset int64, length int64) (io.ReadCloser, error) { + rc, err := s.underlying.OpenFile(ctx, rpath+encryptedFileSuffix, 0, 0) + if err != nil { + return nil, err + } + pr, pw := io.Pipe() + go func() { + err := s.encryptor.DecryptStream(rc, pw) + if err != nil { + pw.CloseWithError(err) + } else { + pw.Close() + } + }() + var rd io.Reader = pr + if offset > 0 { + log(ctx).Warnf("[ENCRYPTED] OpenFile(): it's not efficient to handle non-zero offset(%d)", offset) + rd = util.DiscardNReader(rd, int(offset)) + } + if length > 0 { + rd = io.LimitReader(rd, length) + } + return &struct { + io.Reader + io.Closer + }{ + Reader: rd, + Closer: pr, + }, nil +} + +func (s *encryptedStorage) Remove(ctx context.Context, rpath string, recursive bool) error { + if !recursive { + return s.underlying.Remove(ctx, rpath+encryptedFileSuffix, false) + } else { + return s.underlying.Remove(ctx, rpath, true) + } +} + +func (s *encryptedStorage) Rmdir(ctx context.Context, rpath string) error { + return s.underlying.Rmdir(ctx, rpath) +} + +func (s *encryptedStorage) Mkdir(ctx context.Context, rpath string) error { + return s.underlying.Mkdir(ctx, rpath) +} + +func (s *encryptedStorage) List(ctx context.Context, rpath string, opt *storage.ListOptions, cb storage.ListCallback) error { + myCb := func(de storage.DirEntry) error { + var err error + if !de.IsDir() { + if strings.HasSuffix(de.Name(), encryptedFileSuffix) { + name := strings.TrimSuffix(de.Name(), encryptedFileSuffix) + path := strings.TrimSuffix(de.Path(), encryptedFileSuffix) + size := de.Size() - int64(s.encryptor.Overhead()) + newEntry := storage.NewStaticDirEntry(de.IsDir(), name, path, size, de.MTime()) + err = cb(newEntry) + } + // ignore files that doesn't end with encryptedFileSuffix + } else { + err = cb(de) + } + return err + } + if opt.PathIsFile { + // list a file + return s.underlying.List(ctx, rpath+encryptedFileSuffix, opt, myCb) + } else if strings.HasSuffix(rpath, "/") { + // rpath is a folder + return s.underlying.List(ctx, rpath, opt, myCb) + } else { + // try list single file first + cloneOpt := *opt + cloneOpt.PathIsFile = true + err := s.underlying.List(ctx, rpath+encryptedFileSuffix, &cloneOpt, myCb) + if err != nil { + // ignore ErrObjectNotFound + if !errors.Is(err, storage.ErrObjectNotFound) { + return err + } + } else { + return nil + } + + // try list a folder + return s.underlying.List(ctx, rpath, opt, myCb) + } +} + +func (s *encryptedStorage) Stat(ctx context.Context, rpath string) (storage.StatResult, error) { + result := storage.StatResult{} + statFunc := func(de storage.DirEntry) error { + if de.IsDir() { + result.Dirs++ + } else { + result.Files++ + result.TotalSize += de.Size() + } + return nil + } + err := s.List(ctx, rpath, &storage.ListOptions{Recursive: true}, statFunc) + result.Entries = result.Dirs + result.Files + return result, err +} diff --git a/pkg/util/discard.go b/pkg/util/discard.go index afcb60b..fd0a9cc 100644 --- a/pkg/util/discard.go +++ b/pkg/util/discard.go @@ -5,14 +5,14 @@ import ( "sync" ) -type discardN struct { +type discardNWriter struct { mu sync.Mutex out io.Writer n int skipped int } -func (d *discardN) Write(p []byte) (int, error) { +func (d *discardNWriter) Write(p []byte) (int, error) { d.mu.Lock() defer d.mu.Unlock() rest := d.n - d.skipped @@ -29,9 +29,57 @@ func (d *discardN) Write(p []byte) (int, error) { return d.out.Write(p) } -func DiscardN(out io.Writer, n int) io.Writer { - return &discardN{ +func DiscardNWriter(out io.Writer, n int) io.Writer { + return &discardNWriter{ out: out, n: n, } } + +type discardNReader struct { + mu sync.Mutex + in io.Reader + n int + skipped int +} + +func (d *discardNReader) discardLocked() error { + buf := make([]byte, 32*1024) + for { + rest := d.n - d.skipped + if rest <= 0 { + break + } + rn := len(buf) + if rn > rest { + rn = rest + } + n, err := d.in.Read(buf[:rn]) + if err != nil { + return err + } + if n > 0 { + d.skipped += n + } + } + return nil +} + +func (d *discardNReader) Read(p []byte) (int, error) { + d.mu.Lock() + defer d.mu.Unlock() + if d.n-d.skipped > 0 { + err := d.discardLocked() + if err != nil { + return 0, err + } + } + return d.in.Read(p) +} + +func DiscardNReader(in io.Reader, n int) io.Reader { + return &discardNReader{ + in: in, + n: n, + } +} From 74a5808bd6c18d584bd68497f45701b1af02df69 Mon Sep 17 00:00:00 2001 From: "x.zhou" Date: Thu, 29 Feb 2024 14:47:51 +0800 Subject: [PATCH 2/2] rename algorithm --- pkg/encryption/aes_encryptor.go | 6 +++--- pkg/encryption/encryption.go | 3 --- pkg/encryption/encryption_test.go | 12 ++++++------ 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/pkg/encryption/aes_encryptor.go b/pkg/encryption/aes_encryptor.go index b3d3918..035bff6 100644 --- a/pkg/encryption/aes_encryptor.go +++ b/pkg/encryption/aes_encryptor.go @@ -90,7 +90,7 @@ func NewAES256CFB(passPhrase []byte) (StreamEncryptor, error) { } func init() { - Register("AES128-CFB", "AES-128 with CFB mode", NewAES128CFB) - Register("AES192-CFB", "AES-192 with CFB mode", NewAES192CFB) - Register("AES256-CFB", "AES-256 with CFB mode", NewAES256CFB) + Register("AES-128-CFB", "AES-128 with CFB mode", NewAES128CFB) + Register("AES-192-CFB", "AES-192 with CFB mode", NewAES192CFB) + Register("AES-256-CFB", "AES-256 with CFB mode", NewAES256CFB) } diff --git a/pkg/encryption/encryption.go b/pkg/encryption/encryption.go index 873d25f..7c6d764 100644 --- a/pkg/encryption/encryption.go +++ b/pkg/encryption/encryption.go @@ -26,9 +26,6 @@ func CreateEncryptor(algorithm string, passPhrase []byte) (StreamEncryptor, erro // EncryptorFactory creates new Encryptor for given parameters. type EncryptorFactory func(passPhrase []byte) (StreamEncryptor, error) -// DefaultAlgorithm is the name of the default encryption algorithm. -const DefaultAlgorithm = "AES256-CFB" - // SupportedAlgorithms returns the names of the supported encryption methods. func SupportedAlgorithms() []string { var result []string diff --git a/pkg/encryption/encryption_test.go b/pkg/encryption/encryption_test.go index 411652b..12490ce 100644 --- a/pkg/encryption/encryption_test.go +++ b/pkg/encryption/encryption_test.go @@ -74,9 +74,9 @@ func TestCiphertextSamples(t *testing.T) { // samples of base16-encoded ciphertexts of payload encrypted with passPhrase samples: map[string]string{ - "AES128-CFB": "3f531b215a8b0774edeb5f07f451f811c6ba0b", - "AES192-CFB": "65cce058982cde6dec94ee7965c737bd2e9044", - "AES256-CFB": "cd68a8ba7e886f00326ebd9da560bfca0ad5c4", + "AES-128-CFB": "3f531b215a8b0774edeb5f07f451f811c6ba0b", + "AES-192-CFB": "65cce058982cde6dec94ee7965c737bd2e9044", + "AES-256-CFB": "cd68a8ba7e886f00326ebd9da560bfca0ad5c4", }, }, { @@ -85,9 +85,9 @@ func TestCiphertextSamples(t *testing.T) { // samples of base16-encoded ciphertexts of payload encrypted with passPhrase samples: map[string]string{ - "AES128-CFB": "a4fd5ed9b98b780f09c3253dadd81e9b96b52f3fbe215ab0b43e88df82457b5eb4209bdabe4d8edf045763d17807ea559f4d1e316edc9b", - "AES192-CFB": "6d9cfd25a3b5f2299534c87cd6f61f16c153178e74496d9b6b67f351d9c2a4a7c1514a5a4b42efe945ac56baea71f1dff51df9dc40a8a4", - "AES256-CFB": "9456421484adb715d7f6b52663908dd1acf16848077df01942847cc0e835a627c8b5c704b465ea86f47afd4e359c097582e81a544fdbd1", + "AES-128-CFB": "a4fd5ed9b98b780f09c3253dadd81e9b96b52f3fbe215ab0b43e88df82457b5eb4209bdabe4d8edf045763d17807ea559f4d1e316edc9b", + "AES-192-CFB": "6d9cfd25a3b5f2299534c87cd6f61f16c153178e74496d9b6b67f351d9c2a4a7c1514a5a4b42efe945ac56baea71f1dff51df9dc40a8a4", + "AES-256-CFB": "9456421484adb715d7f6b52663908dd1acf16848077df01942847cc0e835a627c8b5c704b465ea86f47afd4e359c097582e81a544fdbd1", }, }, }