diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 1bd360ef82..44875acb14 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -355,12 +355,12 @@ func modifyMappedVirtualDisk(ctx context.Context, rt prot.ModifyRequestType, mvd mountCtx, cancel := context.WithTimeout(ctx, time.Second*4) defer cancel() if mvd.MountPath != "" { - return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.MountPath, mvd.ReadOnly, mvd.Options) + return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.MountPath, mvd.ReadOnly, false, mvd.Options) } return nil case prot.MreqtRemove: if mvd.MountPath != "" { - if err := storage.UnmountPath(ctx, mvd.MountPath, true); err != nil { + if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.MountPath, false); err != nil { return err } } diff --git a/internal/guest/storage/crypt/crypt.go b/internal/guest/storage/crypt/crypt.go new file mode 100644 index 0000000000..607ce18f92 --- /dev/null +++ b/internal/guest/storage/crypt/crypt.go @@ -0,0 +1,257 @@ +// +build linux + +package crypt + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + + "github.com/Microsoft/hcsshim/internal/log" + "github.com/pkg/errors" +) + +// Test dependencies +var ( + _copyEmptySparseFilesystem = copyEmptySparseFilesystem + _createSparseEmptyFile = createSparseEmptyFile + _cryptsetupClose = cryptsetupClose + _cryptsetupFormat = cryptsetupFormat + _cryptsetupOpen = cryptsetupOpen + _generateKeyFile = generateKeyFile + _getBlockDeviceSize = getBlockDeviceSize + _ioutilTempDir = ioutil.TempDir + _mkfsExt4Command = mkfsExt4Command + _osRemoveAll = os.RemoveAll +) + +// String used to identify dm-crypt devices. The argument is a unique name based +// on the original block device path. +const cryptDeviceTemplate string = "dm-crypt-%s" + +// cryptsetupCommand runs cryptsetup with the provided arguments +func cryptsetupCommand(args []string) error { + // --debug and -v are used to increase the information printed by + // cryptsetup. By default, it doesn't print much information, which makes it + // hard to debug it when there are problems. + cmd := exec.Command("cryptsetup", append([]string{"--debug", "-v"}, args...)...) + output, err := cmd.CombinedOutput() + if err != nil { + return errors.Wrapf(err, "failed to execute cryptsetup: %s", string(output)) + } + return nil +} + +// cryptsetupFormat runs "cryptsetup luksFormat" with the right arguments to use +// dm-crypt and dm-integrity. +func cryptsetupFormat(source string, keyFilePath string) error { + formatArgs := []string{ + // Mount source using LUKS2 + "luksFormat", source, "--type", "luks2", + // Provide keyfile and prevent showing the confirmation prompt + "--key-file", keyFilePath, "--batch-mode", + // dm-crypt and dm-integrity algorithms. The dm-crypt algorithm is the + // default one used for LUKS. The dm-integrity is the one the + // documentation mentions as one of the combinations they use for + // testing: + // https://gitlab.com/cryptsetup/cryptsetup/-/blob/a0277d3ff6ab7d5c9e0534f25b4b40719e999c8e/docs/v2.0.0-ReleaseNotes#L259-261 + "--cipher", "aes-xts-plain64", "--integrity", "hmac-sha256", + // See EncryptDevice() for the reason of using --integrity-no-wipe + "--integrity-no-wipe", + // Use 4KB sectors, the documentation mentions it can improve + // performance than smaller sizes. + "--sector-size", "4096", + // Force PBKDF2 and a specific number of iterations to skip the + // benchmarking step of luksFormat. Using a KDF is required by + // cryptsetup. The reason why it is mandatory to use a KDF is that + // cryptsetup expects the user to input a passphrase and cryptsetup is + // supposed to derive a strong key from it. In our case, we already pass + // a strong key to cryptsetup, so we don't need a strong KDF. Ideally, + // it would be bypassed completely, but this isn't possible. + "--pbkdf", "pbkdf2", "--pbkdf-force-iterations", "1000"} + + return cryptsetupCommand(formatArgs) +} + +// cryptsetupOpen runs "cryptsetup luksOpen" with the right arguments. +func cryptsetupOpen(source string, deviceName string, keyFilePath string) error { + openArgs := []string{ + // Open device with the key passed to luksFormat + "luksOpen", source, deviceName, "--key-file", keyFilePath, + // Don't use a journal to increase performance + "--integrity-no-journal", "--persistent"} + + return cryptsetupCommand(openArgs) +} + +// cryptsetupClose runs "cryptsetup luksClose" with the right arguments. +func cryptsetupClose(deviceName string) error { + closeArgs := []string{"luksClose", deviceName} + + return cryptsetupCommand(closeArgs) +} + +// mkfsExt4Command runs mkfs.ext4 with the provided arguments +func mkfsExt4Command(args []string) error { + cmd := exec.Command("mkfs.ext4", args...) + output, err := cmd.CombinedOutput() + if err != nil { + return errors.Wrapf(err, "failed to execute mkfs.ext4: %s", string(output)) + } + return nil +} + +// EncryptDevice creates a dm-crypt target for a container scratch vhd. +// +// In order to mount a block device as an encrypted device: +// +// 1. Generate a random key. It doesn't matter which key it is, the aim is to +// protect the contents of the scratch disk from the host OS. It can be +// deleted after mounting the encrypted device. +// +// 2. The original block device has to be formatted with cryptsetup with the +// generated key. This results in that block device becoming an encrypted +// block device that can't be mounted directly. +// +// 3. Open the block device with cryptsetup. It is needed to assign it a device +// name. We are using names that follow `cryptDeviceTemplate`, where "%s" is +// a unique name generated from the path of the original block device. In +// this case, it's just the path of the block device with all +// non-alphanumeric characters replaced by a '-'. +// +// The kernel exposes the unencrypted block device at the path +// /dev/mapper/`cryptDeviceTemplate`. This can be mounted directly, but it +// doesn't have any format yet. +// +// 4. Format the unencrypted block device as ext4: +// +// A normal invocation of luksFormat wipes the target device. This takes +// a really long time, which isn't acceptable in our use-case. Passing the +// option --integrity-no-wipe prevents this from happening so that the +// command ends in an instant. +// +// Because of using --integrity-no-wipe, the resulting device isn't wiped and +// all the integrity tags are incorrect. This means that any attempt to read +// from it will cause an I/O error, which programs aren't prepared to handle. +// For example, mkfs.ext4 tries to read blocks before writing to them, and +// there is no way around it. When it gets an I/O error, it just exits. +// +// The solution is to create a file with the same size as the resulting +// device, format it as ext4, then use dd to copy the format to the device +// (dd won't try to read anything). +// +// However, creating a file that is several GB in size isn't a good solution +// either because doing dd of the whole file would take as long as letting +// luksFormat wipe the disk. +// +// The solution is to create a sparse file and format it. Then, it is +// possible to copy the format to the block device by doing a sparse copy +// (only copy the data parts of the file, not the holes). This makes +// formatting the device almost instantaneous. +// +// 4.1. Get size of scratch disk. +// +// 4.2. Create sparse filesystem image with the same size as the scratch +// device. It can be removed afterwards. +// +// 4.3. Format it as ext4. This way the file is only as big as the few blocks +// of the image that have the filesystem information, the ones modified +// by mkfs.ext4. +// +// 4.4. Do a sparse copy of the filesystem into the unencrypted block device. +// This updates the integrity tags. +func EncryptDevice(ctx context.Context, source string) (path string, err error) { + + uniqueName, err := getUniqueName(source) + if err != nil { + return "", errors.Wrapf(err, "failed to generate unique name: %s", source) + } + + // Create temporary directory to store the keyfile and EXT4 image + tempDir, err := _ioutilTempDir("", "dm-crypt") + if err != nil { + return "", errors.Wrapf(err, "failed to create temporary folder: %s", source) + } + + defer func() { + // Delete it on exit, it won't be needed afterwards + if err := _osRemoveAll(tempDir); err != nil { + log.G(ctx).WithError(err).Debugf("failed to delete temporary folder: %s", tempDir) + } + }() + + // 1. Generate keyfile + keyFilePath := filepath.Join(tempDir, "keyfile") + + if err = _generateKeyFile(keyFilePath, 1024); err != nil { + return "", errors.Wrapf(err, "failed to generate keyfile: %s", keyFilePath) + } + + // 2. Format device + if err = _cryptsetupFormat(source, keyFilePath); err != nil { + return "", errors.Wrapf(err, "luksFormat failed: %s", source) + } + + // 3. Open device + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + if err := _cryptsetupOpen(source, deviceName, keyFilePath); err != nil { + return "", errors.Wrapf(err, "luksOpen failed: %s", source) + } + + defer func() { + if err != nil { + if inErr := CleanupCryptDevice(source); inErr != nil { + log.G(ctx).WithError(inErr).Debug("failed to cleanup crypt device") + } + } + }() + + deviceNamePath := "/dev/mapper/" + deviceName + + // 4.1. Get actual size of the scratch device + deviceSize, err := _getBlockDeviceSize(ctx, deviceNamePath) + if err != nil { + return "", errors.Wrapf(err, "error getting size of: %s", deviceNamePath) + } + + if deviceSize == 0 { + return "", fmt.Errorf("invalid size obtained for: %s", deviceNamePath) + } + + // 4.2. Create sparse filesystem image + tempExt4File := filepath.Join(tempDir, "ext4.img") + + if err = _createSparseEmptyFile(ctx, tempExt4File, deviceSize); err != nil { + return "", errors.Wrap(err, "failed to create sparse filesystem file") + } + + // 4.3. Format it as ext4 + if err = _mkfsExt4Command([]string{tempExt4File}); err != nil { + return "", errors.Wrapf(err, "mkfs.ext4 failed to format: %s", tempExt4File) + } + + // 4.4. Sparse copy of the filesystem into the encrypted block device + if err = _copyEmptySparseFilesystem(tempExt4File, deviceNamePath); err != nil { + return "", errors.Wrap(err, "failed to do sparse copy") + } + + return deviceNamePath, nil +} + +// CleanupCryptDevice removes the dm-crypt device created by EncryptDevice +func CleanupCryptDevice(source string) error { + uniqueName, err := getUniqueName(source) + if err != nil { + return errors.Wrapf(err, "failed to generate unique name: %s", source) + } + + // Close dm-crypt device + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + if err := _cryptsetupClose(deviceName); err != nil { + return errors.Wrapf(err, "luksClose failed: %s", deviceName) + } + return nil +} diff --git a/internal/guest/storage/crypt/crypt_test.go b/internal/guest/storage/crypt/crypt_test.go new file mode 100644 index 0000000000..c7bab94486 --- /dev/null +++ b/internal/guest/storage/crypt/crypt_test.go @@ -0,0 +1,437 @@ +// +build linux + +package crypt + +import ( + "context" + "fmt" + "testing" + + "github.com/pkg/errors" +) + +const tempDir = "/tmp/dir/" + +func ioutilTempDirTest(dir string, pattern string) (string, error) { + return tempDir, nil +} + +func clearCryptTestDependencies() { + _copyEmptySparseFilesystem = nil + _createSparseEmptyFile = nil + _cryptsetupClose = nil + _cryptsetupFormat = nil + _cryptsetupOpen = nil + _generateKeyFile = nil + _getBlockDeviceSize = nil + _mkfsExt4Command = nil + _ioutilTempDir = ioutilTempDirTest + _osRemoveAll = nil +} + +func Test_Encrypt_Generate_Key_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when key generation fails for any reason. Verify that + // the generated keyfile path has a number that matches the index value. + + source := "/dev/sda" + keyfilePath := tempDir + "keyfile" + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "failed to generate keyfile: %s", keyfilePath) + + _osRemoveAll = func(path string) error { + return nil + } + _generateKeyFile = func(path string, size int64) error { + if keyfilePath != path { + t.Errorf("expected path: %v, got: %v", keyfilePath, path) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Cryptsetup_Format_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup fails to format the device. Verify that + // the arguments passed to cryptsetup are the right ones. + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + + expectedSource := "/dev/sda" + expectedKeyFilePath := tempDir + "keyfile" + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "luksFormat failed: %s", expectedSource) + + _cryptsetupFormat = func(source string, keyFilePath string) error { + if source != expectedSource { + t.Fatalf("expected source: '%s' got: '%s'", expectedSource, source) + } + if keyFilePath != expectedKeyFilePath { + t.Fatalf("expected keyfile path: '%s' got: '%s'", expectedKeyFilePath, keyFilePath) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), expectedSource) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Cryptsetup_Open_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup fails to open the device. Verify that + // the arguments passed to cryptsetup are the right ones. + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + + expectedSource := "/dev/sda" + uniqueName, _ := getUniqueName(expectedSource) + expectedDeviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + expectedKeyFilePath := tempDir + "keyfile" + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "luksOpen failed: %s", expectedSource) + + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + if source != expectedSource { + t.Fatalf("expected source: '%s' got: '%s'", expectedSource, source) + } + if deviceName != expectedDeviceName { + t.Fatalf("expected device name: '%s' got: '%s'", expectedDeviceName, deviceName) + } + if keyFilePath != expectedKeyFilePath { + t.Fatalf("expected keyfile path: '%s' got: '%s'", expectedKeyFilePath, keyFilePath) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), expectedSource) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Get_Device_Size_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup fails to get the size of the + // unencrypted block device. + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _cryptsetupClose = func(deviceName string) error { + return nil + } + + source := "/dev/sda" + uniqueName, _ := getUniqueName(source) + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + deviceNamePath := "/dev/mapper/" + deviceName + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "error getting size of: %s", deviceNamePath) + + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + return 0, customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } + + // Check that it fails when the size of the block device is zero + + expectedErr = fmt.Errorf("invalid size obtained for: %s", deviceNamePath) + + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + return 0, nil + } + + _, err = EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Create_Sparse_File_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when it isn't possible to create a sparse file, and + // make sure that _createSparseEmptyFile receives the right arguments. + + blockDeviceSize := int64(1024 * 1024 * 1024) + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _cryptsetupClose = func(deviceName string) error { + return nil + } + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + // Return a non-zero size + return blockDeviceSize, nil + } + + source := "/dev/sda" + tempExt4File := tempDir + "ext4.img" + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "failed to create sparse filesystem file") + + _createSparseEmptyFile = func(ctx context.Context, path string, size int64) error { + // Check that the path and the size are the expected ones + if path != tempExt4File { + t.Fatalf("expected path: '%v' got: '%v'", tempExt4File, path) + } + if size != blockDeviceSize { + t.Fatalf("expected size: '%v' got: '%v'", blockDeviceSize, size) + } + + return customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Mkfs_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when mkfs fails to format the unencrypted device. + // Verify that the arguments passed to it are the right ones. + + blockDeviceSize := int64(1024 * 1024 * 1024) + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _cryptsetupClose = func(deviceName string) error { + return nil + } + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + // Return a non-zero size + return blockDeviceSize, nil + } + _createSparseEmptyFile = func(ctx context.Context, path string, size int64) error { + return nil + } + + source := "/dev/sda" + tempExt4File := tempDir + "ext4.img" + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "mkfs.ext4 failed to format: %s", tempExt4File) + + _mkfsExt4Command = func(args []string) error { + if args[0] != tempExt4File { + t.Fatalf("expected args:\n'%v'\ngot:\n'%v'", tempExt4File, args[0]) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Sparse_Copy_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when the sparse copy fails. Verify that the arguments + // passed to it are the right ones. + + blockDeviceSize := int64(1024 * 1024 * 1024) + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _cryptsetupClose = func(deviceName string) error { + return nil + } + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + // Return a non-zero size + return blockDeviceSize, nil + } + _createSparseEmptyFile = func(ctx context.Context, path string, size int64) error { + return nil + } + _mkfsExt4Command = func(args []string) error { + return nil + } + + source := "/dev/sda" + tempExt4File := tempDir + "ext4.img" + uniqueName, _ := getUniqueName(source) + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + deviceNamePath := "/dev/mapper/" + deviceName + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "failed to do sparse copy") + + _copyEmptySparseFilesystem = func(source string, destination string) error { + if source != tempExt4File { + t.Fatalf("expected source: '%v' got: '%v'", tempExt4File, source) + } + if destination != deviceNamePath { + t.Fatalf("expected destination: '%v' got: '%v'", deviceNamePath, destination) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Success(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when everything goes right. + + blockDeviceSize := int64(1024 * 1024 * 1024) + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + // Return a non-zero size + return blockDeviceSize, nil + } + _createSparseEmptyFile = func(ctx context.Context, path string, size int64) error { + return nil + } + _mkfsExt4Command = func(args []string) error { + return nil + } + _copyEmptySparseFilesystem = func(source string, destination string) error { + return nil + } + + source := "/dev/sda" + uniqueName, _ := getUniqueName(source) + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + deviceNamePath := "/dev/mapper/" + deviceName + + encryptedSource, err := EncryptDevice(context.Background(), source) + if err != nil { + t.Fatalf("unexpected err: '%v'", err) + } + if deviceNamePath != encryptedSource { + t.Fatalf("expected path: '%v' got: '%v'", deviceNamePath, encryptedSource) + } +} + +func Test_Cleanup_Dm_Crypt_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup fails to remove an encrypted device. + // Verify that the arguments passed to cryptsetup are the right ones. + + source := "/dev/sda" + uniqueName, _ := getUniqueName(source) + expectedDeviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "luksClose failed: %s", expectedDeviceName) + + _cryptsetupClose = func(deviceName string) error { + if deviceName != expectedDeviceName { + t.Fatalf("expected device name: '%s' got: '%s'", expectedDeviceName, deviceName) + } + return customErr + } + + err := CleanupCryptDevice(source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Cleanup_Dm_Crypt_Success(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup succeedes to close an encrypted device. + + _cryptsetupClose = func(deviceName string) error { + return nil + } + + source := "/dev/sda" + err := CleanupCryptDevice(source) + if err != nil { + t.Fatalf("unexpected err: '%v'", err) + } +} diff --git a/internal/guest/storage/crypt/utilities.go b/internal/guest/storage/crypt/utilities.go new file mode 100644 index 0000000000..7fcecde0ac --- /dev/null +++ b/internal/guest/storage/crypt/utilities.go @@ -0,0 +1,172 @@ +// +build linux + +package crypt + +import ( + "context" + "crypto/rand" + "io" + "io/ioutil" + "os" + "regexp" + + "github.com/Microsoft/hcsshim/internal/log" + "github.com/pkg/errors" +) + +func getUniqueName(path string) (name string, err error) { + // Make a Regex to say we only want letters and numbers + reg, err := regexp.Compile("[^a-zA-Z0-9]+") + if err != nil { + return "", err + } + // Replace all non-alphanumeric characters by dashes + return reg.ReplaceAllString(path, "-"), nil +} + +// getBlockDeviceSize returns the size of the specified block device. +func getBlockDeviceSize(ctx context.Context, path string) (int64, error) { + file, err := os.Open(path) + if err != nil { + return 0, errors.Wrap(err, "error opening: "+path) + } + + defer func() { + if err := file.Close(); err != nil { + log.G(ctx).WithError(err).Debug("error closing: " + path) + } + }() + + pos, err := file.Seek(0, io.SeekEnd) + if err != nil { + return 0, errors.Wrap(err, "error seeking end of: "+path) + } + + return pos, nil +} + +// createSparseEmptyFile creates a sparse file of the specified size. The whole +// file is empty, so the size on disk is zero, only the logical size is the +// specified one. +func createSparseEmptyFile(ctx context.Context, path string, size int64) (err error) { + f, err := os.Create(path) + if err != nil { + return errors.Wrapf(err, "failed to create: %s", path) + } + + defer func() { + if err != nil { + if inErr := os.RemoveAll(path); inErr != nil { + log.G(ctx).WithError(inErr).Debug("failed to delete: " + path) + } + } + }() + + defer func() { + if err := f.Close(); err != nil { + log.G(ctx).WithError(err).Debug("failed to close: " + path) + } + }() + + if err := f.Truncate(size); err != nil { + return errors.Wrapf(err, "failed to truncate: %s", path) + } + + return nil +} + +// The following constants aren't defined in the io or os libraries. +const ( + SEEK_DATA = 3 + SEEK_HOLE = 4 +) + +// copyEmptySparseFilesystem copies data chunks of a sparse source file into a +// destination file. It skips holes. Note that this is intended to copy a +// filesystem that has just been generated, so it only contains metadata blocks. +// Because of that, the source file must end with a hole. If it ends with data, +// the last chunk of data won't be copied. +func copyEmptySparseFilesystem(source string, destination string) error { + fin, err := os.OpenFile(source, os.O_RDONLY, 0) + if err != nil { + return errors.Wrap(err, "failed to open source file") + } + defer fin.Close() + + fout, err := os.OpenFile(destination, os.O_WRONLY, 0) + if err != nil { + return errors.Wrap(err, "failed to open destination file") + } + defer fout.Close() + + finInfo, err := fin.Stat() + if err != nil { + return errors.Wrap(err, "failed to stat source file") + } + + finSize := finInfo.Size() + + var offset int64 = 0 + for { + // Exit when the end of the file is reached + if offset >= finSize { + break + } + + // Calculate bounds of the next data chunk + chunkStart, err := fin.Seek(offset, SEEK_DATA) + if (err != nil) || (chunkStart == -1) { + // No more chunks left + break + } + chunkEnd, err := fin.Seek(chunkStart, SEEK_HOLE) + if (err != nil) || (chunkEnd == -1) { + break + } + chunkSize := chunkEnd - chunkStart + offset = chunkEnd + + // Read contents of this data chunk + _, err = fin.Seek(chunkStart, os.SEEK_SET) + if err != nil { + return errors.Wrap(err, "failed to seek set in source file") + } + + chunkData := make([]byte, chunkSize) + count, err := fin.Read(chunkData) + if err != nil { + return errors.Wrap(err, "failed to read source file") + } + if int64(count) != chunkSize { + return errors.Wrap(err, "not enough data read from source file") + } + + // Write data to destination file + _, err = fout.Seek(chunkStart, os.SEEK_SET) + if err != nil { + return errors.Wrap(err, "failed to seek destination file") + } + _, err = fout.Write(chunkData) + if err != nil { + return errors.Wrap(err, "failed to write destination file") + } + } + + return nil +} + +// generateKeyFile generates a file with random values. +func generateKeyFile(path string, size int64) error { + // The crypto.rand interface generates random numbers using /dev/urandom + keyArray := make([]byte, size) + _, err := rand.Read(keyArray[:]) + if err != nil { + return errors.Wrap(err, "failed to generate key array") + } + + if err := ioutil.WriteFile(path, keyArray[:], 0644); err != nil { + return errors.Wrap(err, "failed to save key to file") + } + + return nil +} diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index 5a5841015c..5ea3f56211 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Microsoft/hcsshim/internal/guest/storage" + "github.com/Microsoft/hcsshim/internal/guest/storage/crypt" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" "github.com/pkg/errors" @@ -33,7 +34,10 @@ var ( // // `target` will be created. On mount failure the created `target` will be // automatically cleaned up. -func Mount(ctx context.Context, controller, lun uint8, target string, readonly bool, options []string) (err error) { +// +// If `encrypted` is set to true, the SCSI device will be encrypted using +// dm-crypt. +func Mount(ctx context.Context, controller, lun uint8, target string, readonly bool, encrypted bool, options []string) (err error) { ctx, span := trace.StartSpan(ctx, "scsi::Mount") defer span.End() defer func() { oc.SetSpanStatus(span, err) }() @@ -62,6 +66,15 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b flags |= unix.MS_RDONLY data = "noload" } + + if encrypted { + encryptedSource, err := crypt.EncryptDevice(ctx, source) + if err != nil { + return errors.Wrapf(err, "failed to mount encrypted device: "+source) + } + source = encryptedSource + } + for { if err := unixMount(source, target, "ext4", flags, data); err != nil { // The `source` found by controllerLunToName can take some time @@ -94,6 +107,33 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b return nil } +// Unmount unmounts a SCSI device mounted at `target`. +// +// If `encrypted` is true, it removes all its associated dm-crypto state. +func Unmount(ctx context.Context, controller, lun uint8, target string, encrypted bool) (err error) { + ctx, span := trace.StartSpan(ctx, "scsi::Unmount") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.Int64Attribute("controller", int64(controller)), + trace.Int64Attribute("lun", int64(lun)), + trace.StringAttribute("target", target)) + + // Unmount unencrypted device + if err := storage.UnmountPath(ctx, target, true); err != nil { + return errors.Wrapf(err, "unmount failed: "+target) + } + + if encrypted { + if err := crypt.CleanupCryptDevice(target); err != nil { + return errors.Wrapf(err, "failed to cleanup dm-crypt state: "+target) + } + } + + return nil +} + // ControllerLunToName finds the `/dev/sd*` path to the SCSI device on // `controller` index `lun`. func ControllerLunToName(ctx context.Context, controller, lun uint8) (_ string, err error) { diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index 3969ec1885..13359f10ee 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -25,7 +25,7 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return expectedErr } - err := Mount(context.Background(), 0, 0, "", false, nil) + err := Mount(context.Background(), 0, 0, "", false, false, nil) if err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } @@ -52,7 +52,7 @@ func Test_Mount_Mkdir_ExpectedPath(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, nil) + err := Mount(context.Background(), 0, 0, target, false, false, nil) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -79,7 +79,7 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, nil) + err := Mount(context.Background(), 0, 0, target, false, false, nil) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -106,7 +106,7 @@ func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), expectedController, 0, "/fake/path", false, nil) + err := Mount(context.Background(), expectedController, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -133,7 +133,7 @@ func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, expectedLun, "/fake/path", false, nil) + err := Mount(context.Background(), 0, expectedLun, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -163,7 +163,7 @@ func Test_Mount_Calls_RemoveAll_OnControllerToLunFailure(t *testing.T) { // NOTE: Do NOT set unixMount because the controller to lun fails. Expect it // not to be called. - err := Mount(context.Background(), 0, 0, target, false, nil) + err := Mount(context.Background(), 0, 0, target, false, false, nil) if err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } @@ -196,7 +196,7 @@ func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { // Fake the mount failure to test remove is called return expectedErr } - err := Mount(context.Background(), 0, 0, target, false, nil) + err := Mount(context.Background(), 0, 0, target, false, false, nil) if err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } @@ -225,7 +225,7 @@ func Test_Mount_Valid_Source(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -251,7 +251,7 @@ func Test_Mount_Valid_Target(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, expectedTarget, false, nil) + err := Mount(context.Background(), 0, 0, expectedTarget, false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -277,7 +277,7 @@ func Test_Mount_Valid_FSType(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -303,7 +303,7 @@ func Test_Mount_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -329,7 +329,7 @@ func Test_Mount_Readonly_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -354,7 +354,7 @@ func Test_Mount_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -380,7 +380,7 @@ func Test_Mount_Readonly_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) }