diff --git a/ext4/dmverity/dmverity.go b/ext4/dmverity/dmverity.go index c3bdef040f..5b3fa8a349 100644 --- a/ext4/dmverity/dmverity.go +++ b/ext4/dmverity/dmverity.go @@ -178,29 +178,35 @@ func ReadDMVerityInfo(vhdPath string, offsetInBytes int64) (*VerityInfo, error) return nil, errors.Errorf("failed to seek dm-verity super block: expected bytes=%d, actual=%d", offsetInBytes, s) } + return ReadDMVerityInfoReader(vhd) +} + +func ReadDMVerityInfoReader(r io.Reader) (*VerityInfo, error) { block := make([]byte, blockSize) - if s, err := vhd.Read(block); err != nil || s != blockSize { + if s, err := r.Read(block); err != nil || s != blockSize { if err != nil { - return nil, errors.Wrapf(err, "%s", ErrSuperBlockReadFailure) + return nil, fmt.Errorf("%s: %w", ErrSuperBlockReadFailure, err) } - return nil, errors.Wrapf(ErrSuperBlockReadFailure, "unexpected bytes read: expected=%d, actual=%d", blockSize, s) + return nil, fmt.Errorf("unexpected bytes read expected=%d actual=%d: %w", blockSize, s, ErrSuperBlockReadFailure) } dmvSB := &dmveritySuperblock{} b := bytes.NewBuffer(block) if err := binary.Read(b, binary.LittleEndian, dmvSB); err != nil { - return nil, errors.Wrapf(err, "%s", ErrSuperBlockParseFailure) + return nil, fmt.Errorf("%s: %w", ErrSuperBlockParseFailure, err) } + if string(bytes.Trim(dmvSB.Signature[:], "\x00")[:]) != VeritySignature { return nil, ErrNotVeritySuperBlock } - // read the merkle tree root - if s, err := vhd.Read(block); err != nil || s != blockSize { + + if s, err := r.Read(block); err != nil || s != blockSize { if err != nil { - return nil, errors.Wrapf(err, "%s", ErrRootHashReadFailure) + return nil, fmt.Errorf("%s: %w", ErrRootHashReadFailure, err) } - return nil, errors.Wrapf(ErrRootHashReadFailure, "unexpected bytes read: expected=%d, actual=%d", blockSize, s) + return nil, fmt.Errorf("unexpected bytes read expected=%d, actual=%d: %w", blockSize, s, ErrRootHashReadFailure) } + rootHash := hash2(dmvSB.Salt[:dmvSB.SaltSize], block) return &VerityInfo{ RootDigest: fmt.Sprintf("%x", rootHash), @@ -215,12 +221,21 @@ func ReadDMVerityInfo(vhdPath string, offsetInBytes int64) (*VerityInfo, error) }, nil } -// ComputeAndWriteHashDevice builds merkle tree from a given io.ReadSeeker and writes the result -// hash device (dm-verity super-block combined with merkle tree) to io.WriteSeeker. -func ComputeAndWriteHashDevice(r io.ReadSeeker, w io.WriteSeeker) error { +// ComputeAndWriteHashDevice builds merkle tree from a given io.ReadSeeker and +// writes the result hash device (dm-verity super-block combined with merkle +// tree) to io.Writer. +func ComputeAndWriteHashDevice(r io.ReadSeeker, w io.Writer) error { + // save current reader position + currBytePos, err := r.Seek(0, io.SeekCurrent) + if err != nil { + return err + } + + // reset to the beginning to find the device size if _, err := r.Seek(0, io.SeekStart); err != nil { return err } + tree, err := MerkleTree(r) if err != nil { return errors.Wrap(err, "failed to build merkle tree") @@ -230,10 +245,13 @@ func ComputeAndWriteHashDevice(r io.ReadSeeker, w io.WriteSeeker) error { if err != nil { return err } - dmVeritySB := NewDMVeritySuperblock(uint64(devSize)) - if _, err := w.Seek(0, io.SeekEnd); err != nil { + + // reset reader to initial position + if _, err := r.Seek(currBytePos, io.SeekStart); err != nil { return err } + + dmVeritySB := NewDMVeritySuperblock(uint64(devSize)) if err := binary.Write(w, binary.LittleEndian, dmVeritySB); err != nil { return errors.Wrap(err, "failed to write dm-verity super-block") } diff --git a/ext4/dmverity/dmverity_test.go b/ext4/dmverity/dmverity_test.go index 25ddea695a..03ee5ce90e 100644 --- a/ext4/dmverity/dmverity_test.go +++ b/ext4/dmverity/dmverity_test.go @@ -57,7 +57,7 @@ func TestInvalidReadEOF(t *testing.T) { if err == nil { t.Fatalf("no error returned") } - if errors.Cause(err) != io.EOF { + if errors.Unwrap(err) != io.EOF { t.Fatalf("unexpected error: %s", err) } } @@ -68,7 +68,7 @@ func TestInvalidReadNotEnoughBytes(t *testing.T) { if err == nil { t.Fatalf("no error returned") } - if errors.Cause(err) != ErrSuperBlockReadFailure || !strings.Contains(err.Error(), "unexpected bytes read") { + if errors.Unwrap(err) != ErrSuperBlockReadFailure || !strings.Contains(err.Error(), "unexpected bytes read") { t.Fatalf("unexpected error: %s", err) } } @@ -94,7 +94,7 @@ func TestNoMerkleTree(t *testing.T) { if err == nil { t.Fatalf("no error returned") } - if errors.Cause(err) != io.EOF || !strings.Contains(err.Error(), "failed to read dm-verity root hash") { + if errors.Unwrap(err) != io.EOF || !strings.Contains(err.Error(), "failed to read dm-verity root hash") { t.Fatalf("expected %q, got %q", io.EOF, err) } } diff --git a/ext4/tar2ext4/tar2ext4.go b/ext4/tar2ext4/tar2ext4.go index 6938683385..319b1517f4 100644 --- a/ext4/tar2ext4/tar2ext4.go +++ b/ext4/tar2ext4/tar2ext4.go @@ -200,7 +200,19 @@ func Convert(r io.Reader, w io.ReadWriteSeeker, options ...Option) error { return nil } -// ReadExt4SuperBlock reads and returns ext4 super block from VHD +// ReadExt4SuperBlock reads and returns ext4 super block from given device. +func ReadExt4SuperBlock(devicePath string) (*format.SuperBlock, error) { + dev, err := os.OpenFile(devicePath, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + defer dev.Close() + + return ReadExt4SuperBlockReadSeeker(dev) +} + +// ReadExt4SuperBlockReadSeeker reads and returns ext4 super block given +// an io.ReadSeeker. // // The layout on disk is as follows: // | Group 0 padding | - 1024 bytes @@ -215,22 +227,26 @@ func Convert(r io.Reader, w io.ReadWriteSeeker, options ...Option) error { // More details can be found here https://ext4.wiki.kernel.org/index.php/Ext4_Disk_Layout // // Our goal is to skip the Group 0 padding, read and return the ext4 SuperBlock -func ReadExt4SuperBlock(devicePath string) (*format.SuperBlock, error) { - dev, err := os.OpenFile(devicePath, os.O_RDONLY, 0) +func ReadExt4SuperBlockReadSeeker(rsc io.ReadSeeker) (*format.SuperBlock, error) { + // save current reader position + currBytePos, err := rsc.Seek(0, io.SeekCurrent) if err != nil { return nil, err } - defer dev.Close() - // Skip padding at the start - if _, err := dev.Seek(1024, io.SeekStart); err != nil { + if _, err := rsc.Seek(1024, io.SeekCurrent); err != nil { return nil, err } var sb format.SuperBlock - if err := binary.Read(dev, binary.LittleEndian, &sb); err != nil { + if err := binary.Read(rsc, binary.LittleEndian, &sb); err != nil { + return nil, err + } + + // reset the reader to initial position + if _, err := rsc.Seek(currBytePos, io.SeekStart); err != nil { return nil, err } - // Make sure the magic bytes are correct. + if sb.Magic != format.SuperBlockMagic { return nil, errors.New("not an ext4 file system") } @@ -246,6 +262,18 @@ func IsDeviceExt4(devicePath string) bool { return err == nil } +// Ext4FileSystemSize reads ext4 superblock and returns the size of the underlying +// ext4 file system and its block size. +func Ext4FileSystemSize(r io.ReadSeeker) (int64, int, error) { + sb, err := ReadExt4SuperBlockReadSeeker(r) + if err != nil { + return 0, 0, fmt.Errorf("failed to read ext4 superblock: %w", err) + } + blockSize := 1024 * (1 << sb.LogBlockSize) + fsSize := int64(blockSize) * int64(sb.BlocksCountLow) + return fsSize, blockSize, nil +} + // ConvertAndComputeRootDigest writes a compact ext4 file system image that contains the files in the // input tar stream, computes the resulting file image's cryptographic hashes (merkle tree) and returns // merkle tree root digest. Convert is called with minimal options: ConvertWhiteout and MaximumDiskSize diff --git a/internal/verity/verity.go b/internal/verity/verity.go index 7aef0ce65e..795a6427e8 100644 --- a/internal/verity/verity.go +++ b/internal/verity/verity.go @@ -2,6 +2,8 @@ package verity import ( "context" + "fmt" + "os" "github.com/Microsoft/hcsshim/ext4/dmverity" "github.com/Microsoft/hcsshim/ext4/tar2ext4" @@ -13,13 +15,13 @@ import ( // fileSystemSize retrieves ext4 fs SuperBlock and returns the file system size and block size func fileSystemSize(vhdPath string) (int64, int, error) { - sb, err := tar2ext4.ReadExt4SuperBlock(vhdPath) + vhd, err := os.Open(vhdPath) if err != nil { - return 0, 0, errors.Wrap(err, "failed to read ext4 super block") + return 0, 0, fmt.Errorf("failed to open VHD file: %w", err) } - blockSize := 1024 * (1 << sb.LogBlockSize) - fsSize := int64(blockSize) * int64(sb.BlocksCountLow) - return fsSize, blockSize, nil + defer vhd.Close() + + return tar2ext4.Ext4FileSystemSize(vhd) } // ReadVeritySuperBlock reads ext4 super block for a given VHD to then further read the dm-verity super block