From 69994fcda2b604aea51737de887c205cdfc48306 Mon Sep 17 00:00:00 2001 From: Antonio Nino Diaz Date: Mon, 28 Jun 2021 12:00:02 +0100 Subject: [PATCH] Add support to encrypt SCSI scratch disks with dm-crypt This protects the files generated by the guest from the host OS, as they are encrypted by a key that the host doesn't know. This commit adds a new argument to the scsi.Mount() function, `encrypted`, that makes the SCSI drive be mounted using dm-crypt. It also uses dm-integrity for integrity checking. This makes the boot process a couple of seconds slower. Also, it adds scsi.Unmount(), which also has the `encrypted` argument, and it does the necessary cleanup for a drive that has been mounted as an encrypted drive. All the pre-existing SCSI tests have been fixed to work with the new scsi.Mount() function prototype. New tests have been added for the new code. This is all disabled for now, it has to be enabled in a future patch. Important note: This depends on cryptsetup and mkfs.ext4. Also, the kernel must be compiled with dm-crypt and dm-integrity support. --- internal/guest/runtime/hcsv2/uvm.go | 4 +- internal/guest/storage/crypt/crypt.go | 257 ++++++++++++ internal/guest/storage/crypt/crypt_test.go | 437 +++++++++++++++++++++ internal/guest/storage/crypt/utilities.go | 172 ++++++++ internal/guest/storage/scsi/scsi.go | 42 +- internal/guest/storage/scsi/scsi_test.go | 28 +- 6 files changed, 923 insertions(+), 17 deletions(-) create mode 100644 internal/guest/storage/crypt/crypt.go create mode 100644 internal/guest/storage/crypt/crypt_test.go create mode 100644 internal/guest/storage/crypt/utilities.go diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 1bd360ef82..44875acb14 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -355,12 +355,12 @@ func modifyMappedVirtualDisk(ctx context.Context, rt prot.ModifyRequestType, mvd mountCtx, cancel := context.WithTimeout(ctx, time.Second*4) defer cancel() if mvd.MountPath != "" { - return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.MountPath, mvd.ReadOnly, mvd.Options) + return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.MountPath, mvd.ReadOnly, false, mvd.Options) } return nil case prot.MreqtRemove: if mvd.MountPath != "" { - if err := storage.UnmountPath(ctx, mvd.MountPath, true); err != nil { + if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.MountPath, false); err != nil { return err } } diff --git a/internal/guest/storage/crypt/crypt.go b/internal/guest/storage/crypt/crypt.go new file mode 100644 index 0000000000..607ce18f92 --- /dev/null +++ b/internal/guest/storage/crypt/crypt.go @@ -0,0 +1,257 @@ +// +build linux + +package crypt + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + + "github.com/Microsoft/hcsshim/internal/log" + "github.com/pkg/errors" +) + +// Test dependencies +var ( + _copyEmptySparseFilesystem = copyEmptySparseFilesystem + _createSparseEmptyFile = createSparseEmptyFile + _cryptsetupClose = cryptsetupClose + _cryptsetupFormat = cryptsetupFormat + _cryptsetupOpen = cryptsetupOpen + _generateKeyFile = generateKeyFile + _getBlockDeviceSize = getBlockDeviceSize + _ioutilTempDir = ioutil.TempDir + _mkfsExt4Command = mkfsExt4Command + _osRemoveAll = os.RemoveAll +) + +// String used to identify dm-crypt devices. The argument is a unique name based +// on the original block device path. +const cryptDeviceTemplate string = "dm-crypt-%s" + +// cryptsetupCommand runs cryptsetup with the provided arguments +func cryptsetupCommand(args []string) error { + // --debug and -v are used to increase the information printed by + // cryptsetup. By default, it doesn't print much information, which makes it + // hard to debug it when there are problems. + cmd := exec.Command("cryptsetup", append([]string{"--debug", "-v"}, args...)...) + output, err := cmd.CombinedOutput() + if err != nil { + return errors.Wrapf(err, "failed to execute cryptsetup: %s", string(output)) + } + return nil +} + +// cryptsetupFormat runs "cryptsetup luksFormat" with the right arguments to use +// dm-crypt and dm-integrity. +func cryptsetupFormat(source string, keyFilePath string) error { + formatArgs := []string{ + // Mount source using LUKS2 + "luksFormat", source, "--type", "luks2", + // Provide keyfile and prevent showing the confirmation prompt + "--key-file", keyFilePath, "--batch-mode", + // dm-crypt and dm-integrity algorithms. The dm-crypt algorithm is the + // default one used for LUKS. The dm-integrity is the one the + // documentation mentions as one of the combinations they use for + // testing: + // https://gitlab.com/cryptsetup/cryptsetup/-/blob/a0277d3ff6ab7d5c9e0534f25b4b40719e999c8e/docs/v2.0.0-ReleaseNotes#L259-261 + "--cipher", "aes-xts-plain64", "--integrity", "hmac-sha256", + // See EncryptDevice() for the reason of using --integrity-no-wipe + "--integrity-no-wipe", + // Use 4KB sectors, the documentation mentions it can improve + // performance than smaller sizes. + "--sector-size", "4096", + // Force PBKDF2 and a specific number of iterations to skip the + // benchmarking step of luksFormat. Using a KDF is required by + // cryptsetup. The reason why it is mandatory to use a KDF is that + // cryptsetup expects the user to input a passphrase and cryptsetup is + // supposed to derive a strong key from it. In our case, we already pass + // a strong key to cryptsetup, so we don't need a strong KDF. Ideally, + // it would be bypassed completely, but this isn't possible. + "--pbkdf", "pbkdf2", "--pbkdf-force-iterations", "1000"} + + return cryptsetupCommand(formatArgs) +} + +// cryptsetupOpen runs "cryptsetup luksOpen" with the right arguments. +func cryptsetupOpen(source string, deviceName string, keyFilePath string) error { + openArgs := []string{ + // Open device with the key passed to luksFormat + "luksOpen", source, deviceName, "--key-file", keyFilePath, + // Don't use a journal to increase performance + "--integrity-no-journal", "--persistent"} + + return cryptsetupCommand(openArgs) +} + +// cryptsetupClose runs "cryptsetup luksClose" with the right arguments. +func cryptsetupClose(deviceName string) error { + closeArgs := []string{"luksClose", deviceName} + + return cryptsetupCommand(closeArgs) +} + +// mkfsExt4Command runs mkfs.ext4 with the provided arguments +func mkfsExt4Command(args []string) error { + cmd := exec.Command("mkfs.ext4", args...) + output, err := cmd.CombinedOutput() + if err != nil { + return errors.Wrapf(err, "failed to execute mkfs.ext4: %s", string(output)) + } + return nil +} + +// EncryptDevice creates a dm-crypt target for a container scratch vhd. +// +// In order to mount a block device as an encrypted device: +// +// 1. Generate a random key. It doesn't matter which key it is, the aim is to +// protect the contents of the scratch disk from the host OS. It can be +// deleted after mounting the encrypted device. +// +// 2. The original block device has to be formatted with cryptsetup with the +// generated key. This results in that block device becoming an encrypted +// block device that can't be mounted directly. +// +// 3. Open the block device with cryptsetup. It is needed to assign it a device +// name. We are using names that follow `cryptDeviceTemplate`, where "%s" is +// a unique name generated from the path of the original block device. In +// this case, it's just the path of the block device with all +// non-alphanumeric characters replaced by a '-'. +// +// The kernel exposes the unencrypted block device at the path +// /dev/mapper/`cryptDeviceTemplate`. This can be mounted directly, but it +// doesn't have any format yet. +// +// 4. Format the unencrypted block device as ext4: +// +// A normal invocation of luksFormat wipes the target device. This takes +// a really long time, which isn't acceptable in our use-case. Passing the +// option --integrity-no-wipe prevents this from happening so that the +// command ends in an instant. +// +// Because of using --integrity-no-wipe, the resulting device isn't wiped and +// all the integrity tags are incorrect. This means that any attempt to read +// from it will cause an I/O error, which programs aren't prepared to handle. +// For example, mkfs.ext4 tries to read blocks before writing to them, and +// there is no way around it. When it gets an I/O error, it just exits. +// +// The solution is to create a file with the same size as the resulting +// device, format it as ext4, then use dd to copy the format to the device +// (dd won't try to read anything). +// +// However, creating a file that is several GB in size isn't a good solution +// either because doing dd of the whole file would take as long as letting +// luksFormat wipe the disk. +// +// The solution is to create a sparse file and format it. Then, it is +// possible to copy the format to the block device by doing a sparse copy +// (only copy the data parts of the file, not the holes). This makes +// formatting the device almost instantaneous. +// +// 4.1. Get size of scratch disk. +// +// 4.2. Create sparse filesystem image with the same size as the scratch +// device. It can be removed afterwards. +// +// 4.3. Format it as ext4. This way the file is only as big as the few blocks +// of the image that have the filesystem information, the ones modified +// by mkfs.ext4. +// +// 4.4. Do a sparse copy of the filesystem into the unencrypted block device. +// This updates the integrity tags. +func EncryptDevice(ctx context.Context, source string) (path string, err error) { + + uniqueName, err := getUniqueName(source) + if err != nil { + return "", errors.Wrapf(err, "failed to generate unique name: %s", source) + } + + // Create temporary directory to store the keyfile and EXT4 image + tempDir, err := _ioutilTempDir("", "dm-crypt") + if err != nil { + return "", errors.Wrapf(err, "failed to create temporary folder: %s", source) + } + + defer func() { + // Delete it on exit, it won't be needed afterwards + if err := _osRemoveAll(tempDir); err != nil { + log.G(ctx).WithError(err).Debugf("failed to delete temporary folder: %s", tempDir) + } + }() + + // 1. Generate keyfile + keyFilePath := filepath.Join(tempDir, "keyfile") + + if err = _generateKeyFile(keyFilePath, 1024); err != nil { + return "", errors.Wrapf(err, "failed to generate keyfile: %s", keyFilePath) + } + + // 2. Format device + if err = _cryptsetupFormat(source, keyFilePath); err != nil { + return "", errors.Wrapf(err, "luksFormat failed: %s", source) + } + + // 3. Open device + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + if err := _cryptsetupOpen(source, deviceName, keyFilePath); err != nil { + return "", errors.Wrapf(err, "luksOpen failed: %s", source) + } + + defer func() { + if err != nil { + if inErr := CleanupCryptDevice(source); inErr != nil { + log.G(ctx).WithError(inErr).Debug("failed to cleanup crypt device") + } + } + }() + + deviceNamePath := "/dev/mapper/" + deviceName + + // 4.1. Get actual size of the scratch device + deviceSize, err := _getBlockDeviceSize(ctx, deviceNamePath) + if err != nil { + return "", errors.Wrapf(err, "error getting size of: %s", deviceNamePath) + } + + if deviceSize == 0 { + return "", fmt.Errorf("invalid size obtained for: %s", deviceNamePath) + } + + // 4.2. Create sparse filesystem image + tempExt4File := filepath.Join(tempDir, "ext4.img") + + if err = _createSparseEmptyFile(ctx, tempExt4File, deviceSize); err != nil { + return "", errors.Wrap(err, "failed to create sparse filesystem file") + } + + // 4.3. Format it as ext4 + if err = _mkfsExt4Command([]string{tempExt4File}); err != nil { + return "", errors.Wrapf(err, "mkfs.ext4 failed to format: %s", tempExt4File) + } + + // 4.4. Sparse copy of the filesystem into the encrypted block device + if err = _copyEmptySparseFilesystem(tempExt4File, deviceNamePath); err != nil { + return "", errors.Wrap(err, "failed to do sparse copy") + } + + return deviceNamePath, nil +} + +// CleanupCryptDevice removes the dm-crypt device created by EncryptDevice +func CleanupCryptDevice(source string) error { + uniqueName, err := getUniqueName(source) + if err != nil { + return errors.Wrapf(err, "failed to generate unique name: %s", source) + } + + // Close dm-crypt device + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + if err := _cryptsetupClose(deviceName); err != nil { + return errors.Wrapf(err, "luksClose failed: %s", deviceName) + } + return nil +} diff --git a/internal/guest/storage/crypt/crypt_test.go b/internal/guest/storage/crypt/crypt_test.go new file mode 100644 index 0000000000..c7bab94486 --- /dev/null +++ b/internal/guest/storage/crypt/crypt_test.go @@ -0,0 +1,437 @@ +// +build linux + +package crypt + +import ( + "context" + "fmt" + "testing" + + "github.com/pkg/errors" +) + +const tempDir = "/tmp/dir/" + +func ioutilTempDirTest(dir string, pattern string) (string, error) { + return tempDir, nil +} + +func clearCryptTestDependencies() { + _copyEmptySparseFilesystem = nil + _createSparseEmptyFile = nil + _cryptsetupClose = nil + _cryptsetupFormat = nil + _cryptsetupOpen = nil + _generateKeyFile = nil + _getBlockDeviceSize = nil + _mkfsExt4Command = nil + _ioutilTempDir = ioutilTempDirTest + _osRemoveAll = nil +} + +func Test_Encrypt_Generate_Key_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when key generation fails for any reason. Verify that + // the generated keyfile path has a number that matches the index value. + + source := "/dev/sda" + keyfilePath := tempDir + "keyfile" + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "failed to generate keyfile: %s", keyfilePath) + + _osRemoveAll = func(path string) error { + return nil + } + _generateKeyFile = func(path string, size int64) error { + if keyfilePath != path { + t.Errorf("expected path: %v, got: %v", keyfilePath, path) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Cryptsetup_Format_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup fails to format the device. Verify that + // the arguments passed to cryptsetup are the right ones. + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + + expectedSource := "/dev/sda" + expectedKeyFilePath := tempDir + "keyfile" + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "luksFormat failed: %s", expectedSource) + + _cryptsetupFormat = func(source string, keyFilePath string) error { + if source != expectedSource { + t.Fatalf("expected source: '%s' got: '%s'", expectedSource, source) + } + if keyFilePath != expectedKeyFilePath { + t.Fatalf("expected keyfile path: '%s' got: '%s'", expectedKeyFilePath, keyFilePath) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), expectedSource) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Cryptsetup_Open_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup fails to open the device. Verify that + // the arguments passed to cryptsetup are the right ones. + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + + expectedSource := "/dev/sda" + uniqueName, _ := getUniqueName(expectedSource) + expectedDeviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + expectedKeyFilePath := tempDir + "keyfile" + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "luksOpen failed: %s", expectedSource) + + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + if source != expectedSource { + t.Fatalf("expected source: '%s' got: '%s'", expectedSource, source) + } + if deviceName != expectedDeviceName { + t.Fatalf("expected device name: '%s' got: '%s'", expectedDeviceName, deviceName) + } + if keyFilePath != expectedKeyFilePath { + t.Fatalf("expected keyfile path: '%s' got: '%s'", expectedKeyFilePath, keyFilePath) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), expectedSource) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Get_Device_Size_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup fails to get the size of the + // unencrypted block device. + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _cryptsetupClose = func(deviceName string) error { + return nil + } + + source := "/dev/sda" + uniqueName, _ := getUniqueName(source) + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + deviceNamePath := "/dev/mapper/" + deviceName + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "error getting size of: %s", deviceNamePath) + + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + return 0, customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } + + // Check that it fails when the size of the block device is zero + + expectedErr = fmt.Errorf("invalid size obtained for: %s", deviceNamePath) + + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + return 0, nil + } + + _, err = EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Create_Sparse_File_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when it isn't possible to create a sparse file, and + // make sure that _createSparseEmptyFile receives the right arguments. + + blockDeviceSize := int64(1024 * 1024 * 1024) + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _cryptsetupClose = func(deviceName string) error { + return nil + } + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + // Return a non-zero size + return blockDeviceSize, nil + } + + source := "/dev/sda" + tempExt4File := tempDir + "ext4.img" + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "failed to create sparse filesystem file") + + _createSparseEmptyFile = func(ctx context.Context, path string, size int64) error { + // Check that the path and the size are the expected ones + if path != tempExt4File { + t.Fatalf("expected path: '%v' got: '%v'", tempExt4File, path) + } + if size != blockDeviceSize { + t.Fatalf("expected size: '%v' got: '%v'", blockDeviceSize, size) + } + + return customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Mkfs_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when mkfs fails to format the unencrypted device. + // Verify that the arguments passed to it are the right ones. + + blockDeviceSize := int64(1024 * 1024 * 1024) + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _cryptsetupClose = func(deviceName string) error { + return nil + } + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + // Return a non-zero size + return blockDeviceSize, nil + } + _createSparseEmptyFile = func(ctx context.Context, path string, size int64) error { + return nil + } + + source := "/dev/sda" + tempExt4File := tempDir + "ext4.img" + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "mkfs.ext4 failed to format: %s", tempExt4File) + + _mkfsExt4Command = func(args []string) error { + if args[0] != tempExt4File { + t.Fatalf("expected args:\n'%v'\ngot:\n'%v'", tempExt4File, args[0]) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Sparse_Copy_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when the sparse copy fails. Verify that the arguments + // passed to it are the right ones. + + blockDeviceSize := int64(1024 * 1024 * 1024) + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _cryptsetupClose = func(deviceName string) error { + return nil + } + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + // Return a non-zero size + return blockDeviceSize, nil + } + _createSparseEmptyFile = func(ctx context.Context, path string, size int64) error { + return nil + } + _mkfsExt4Command = func(args []string) error { + return nil + } + + source := "/dev/sda" + tempExt4File := tempDir + "ext4.img" + uniqueName, _ := getUniqueName(source) + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + deviceNamePath := "/dev/mapper/" + deviceName + + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "failed to do sparse copy") + + _copyEmptySparseFilesystem = func(source string, destination string) error { + if source != tempExt4File { + t.Fatalf("expected source: '%v' got: '%v'", tempExt4File, source) + } + if destination != deviceNamePath { + t.Fatalf("expected destination: '%v' got: '%v'", deviceNamePath, destination) + } + return customErr + } + + _, err := EncryptDevice(context.Background(), source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Encrypt_Success(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when everything goes right. + + blockDeviceSize := int64(1024 * 1024 * 1024) + + _generateKeyFile = func(path string, size int64) error { + return nil + } + _osRemoveAll = func(path string) error { + return nil + } + _cryptsetupFormat = func(source string, keyFilePath string) error { + return nil + } + _cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error { + return nil + } + _getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) { + // Return a non-zero size + return blockDeviceSize, nil + } + _createSparseEmptyFile = func(ctx context.Context, path string, size int64) error { + return nil + } + _mkfsExt4Command = func(args []string) error { + return nil + } + _copyEmptySparseFilesystem = func(source string, destination string) error { + return nil + } + + source := "/dev/sda" + uniqueName, _ := getUniqueName(source) + deviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + deviceNamePath := "/dev/mapper/" + deviceName + + encryptedSource, err := EncryptDevice(context.Background(), source) + if err != nil { + t.Fatalf("unexpected err: '%v'", err) + } + if deviceNamePath != encryptedSource { + t.Fatalf("expected path: '%v' got: '%v'", deviceNamePath, encryptedSource) + } +} + +func Test_Cleanup_Dm_Crypt_Error(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup fails to remove an encrypted device. + // Verify that the arguments passed to cryptsetup are the right ones. + + source := "/dev/sda" + uniqueName, _ := getUniqueName(source) + expectedDeviceName := fmt.Sprintf(cryptDeviceTemplate, uniqueName) + customErr := errors.New("expected error message") + expectedErr := errors.Wrapf(customErr, "luksClose failed: %s", expectedDeviceName) + + _cryptsetupClose = func(deviceName string) error { + if deviceName != expectedDeviceName { + t.Fatalf("expected device name: '%s' got: '%s'", expectedDeviceName, deviceName) + } + return customErr + } + + err := CleanupCryptDevice(source) + if err.Error() != expectedErr.Error() { + t.Fatalf("expected err:\n'%v'\ngot:\n'%v'", expectedErr, err) + } +} + +func Test_Cleanup_Dm_Crypt_Success(t *testing.T) { + clearCryptTestDependencies() + + // Test what happens when cryptsetup succeedes to close an encrypted device. + + _cryptsetupClose = func(deviceName string) error { + return nil + } + + source := "/dev/sda" + err := CleanupCryptDevice(source) + if err != nil { + t.Fatalf("unexpected err: '%v'", err) + } +} diff --git a/internal/guest/storage/crypt/utilities.go b/internal/guest/storage/crypt/utilities.go new file mode 100644 index 0000000000..7fcecde0ac --- /dev/null +++ b/internal/guest/storage/crypt/utilities.go @@ -0,0 +1,172 @@ +// +build linux + +package crypt + +import ( + "context" + "crypto/rand" + "io" + "io/ioutil" + "os" + "regexp" + + "github.com/Microsoft/hcsshim/internal/log" + "github.com/pkg/errors" +) + +func getUniqueName(path string) (name string, err error) { + // Make a Regex to say we only want letters and numbers + reg, err := regexp.Compile("[^a-zA-Z0-9]+") + if err != nil { + return "", err + } + // Replace all non-alphanumeric characters by dashes + return reg.ReplaceAllString(path, "-"), nil +} + +// getBlockDeviceSize returns the size of the specified block device. +func getBlockDeviceSize(ctx context.Context, path string) (int64, error) { + file, err := os.Open(path) + if err != nil { + return 0, errors.Wrap(err, "error opening: "+path) + } + + defer func() { + if err := file.Close(); err != nil { + log.G(ctx).WithError(err).Debug("error closing: " + path) + } + }() + + pos, err := file.Seek(0, io.SeekEnd) + if err != nil { + return 0, errors.Wrap(err, "error seeking end of: "+path) + } + + return pos, nil +} + +// createSparseEmptyFile creates a sparse file of the specified size. The whole +// file is empty, so the size on disk is zero, only the logical size is the +// specified one. +func createSparseEmptyFile(ctx context.Context, path string, size int64) (err error) { + f, err := os.Create(path) + if err != nil { + return errors.Wrapf(err, "failed to create: %s", path) + } + + defer func() { + if err != nil { + if inErr := os.RemoveAll(path); inErr != nil { + log.G(ctx).WithError(inErr).Debug("failed to delete: " + path) + } + } + }() + + defer func() { + if err := f.Close(); err != nil { + log.G(ctx).WithError(err).Debug("failed to close: " + path) + } + }() + + if err := f.Truncate(size); err != nil { + return errors.Wrapf(err, "failed to truncate: %s", path) + } + + return nil +} + +// The following constants aren't defined in the io or os libraries. +const ( + SEEK_DATA = 3 + SEEK_HOLE = 4 +) + +// copyEmptySparseFilesystem copies data chunks of a sparse source file into a +// destination file. It skips holes. Note that this is intended to copy a +// filesystem that has just been generated, so it only contains metadata blocks. +// Because of that, the source file must end with a hole. If it ends with data, +// the last chunk of data won't be copied. +func copyEmptySparseFilesystem(source string, destination string) error { + fin, err := os.OpenFile(source, os.O_RDONLY, 0) + if err != nil { + return errors.Wrap(err, "failed to open source file") + } + defer fin.Close() + + fout, err := os.OpenFile(destination, os.O_WRONLY, 0) + if err != nil { + return errors.Wrap(err, "failed to open destination file") + } + defer fout.Close() + + finInfo, err := fin.Stat() + if err != nil { + return errors.Wrap(err, "failed to stat source file") + } + + finSize := finInfo.Size() + + var offset int64 = 0 + for { + // Exit when the end of the file is reached + if offset >= finSize { + break + } + + // Calculate bounds of the next data chunk + chunkStart, err := fin.Seek(offset, SEEK_DATA) + if (err != nil) || (chunkStart == -1) { + // No more chunks left + break + } + chunkEnd, err := fin.Seek(chunkStart, SEEK_HOLE) + if (err != nil) || (chunkEnd == -1) { + break + } + chunkSize := chunkEnd - chunkStart + offset = chunkEnd + + // Read contents of this data chunk + _, err = fin.Seek(chunkStart, os.SEEK_SET) + if err != nil { + return errors.Wrap(err, "failed to seek set in source file") + } + + chunkData := make([]byte, chunkSize) + count, err := fin.Read(chunkData) + if err != nil { + return errors.Wrap(err, "failed to read source file") + } + if int64(count) != chunkSize { + return errors.Wrap(err, "not enough data read from source file") + } + + // Write data to destination file + _, err = fout.Seek(chunkStart, os.SEEK_SET) + if err != nil { + return errors.Wrap(err, "failed to seek destination file") + } + _, err = fout.Write(chunkData) + if err != nil { + return errors.Wrap(err, "failed to write destination file") + } + } + + return nil +} + +// generateKeyFile generates a file with random values. +func generateKeyFile(path string, size int64) error { + // The crypto.rand interface generates random numbers using /dev/urandom + keyArray := make([]byte, size) + _, err := rand.Read(keyArray[:]) + if err != nil { + return errors.Wrap(err, "failed to generate key array") + } + + if err := ioutil.WriteFile(path, keyArray[:], 0644); err != nil { + return errors.Wrap(err, "failed to save key to file") + } + + return nil +} diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index 5a5841015c..5ea3f56211 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Microsoft/hcsshim/internal/guest/storage" + "github.com/Microsoft/hcsshim/internal/guest/storage/crypt" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" "github.com/pkg/errors" @@ -33,7 +34,10 @@ var ( // // `target` will be created. On mount failure the created `target` will be // automatically cleaned up. -func Mount(ctx context.Context, controller, lun uint8, target string, readonly bool, options []string) (err error) { +// +// If `encrypted` is set to true, the SCSI device will be encrypted using +// dm-crypt. +func Mount(ctx context.Context, controller, lun uint8, target string, readonly bool, encrypted bool, options []string) (err error) { ctx, span := trace.StartSpan(ctx, "scsi::Mount") defer span.End() defer func() { oc.SetSpanStatus(span, err) }() @@ -62,6 +66,15 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b flags |= unix.MS_RDONLY data = "noload" } + + if encrypted { + encryptedSource, err := crypt.EncryptDevice(ctx, source) + if err != nil { + return errors.Wrapf(err, "failed to mount encrypted device: "+source) + } + source = encryptedSource + } + for { if err := unixMount(source, target, "ext4", flags, data); err != nil { // The `source` found by controllerLunToName can take some time @@ -94,6 +107,33 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b return nil } +// Unmount unmounts a SCSI device mounted at `target`. +// +// If `encrypted` is true, it removes all its associated dm-crypto state. +func Unmount(ctx context.Context, controller, lun uint8, target string, encrypted bool) (err error) { + ctx, span := trace.StartSpan(ctx, "scsi::Unmount") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + span.AddAttributes( + trace.Int64Attribute("controller", int64(controller)), + trace.Int64Attribute("lun", int64(lun)), + trace.StringAttribute("target", target)) + + // Unmount unencrypted device + if err := storage.UnmountPath(ctx, target, true); err != nil { + return errors.Wrapf(err, "unmount failed: "+target) + } + + if encrypted { + if err := crypt.CleanupCryptDevice(target); err != nil { + return errors.Wrapf(err, "failed to cleanup dm-crypt state: "+target) + } + } + + return nil +} + // ControllerLunToName finds the `/dev/sd*` path to the SCSI device on // `controller` index `lun`. func ControllerLunToName(ctx context.Context, controller, lun uint8) (_ string, err error) { diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index 3969ec1885..13359f10ee 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -25,7 +25,7 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return expectedErr } - err := Mount(context.Background(), 0, 0, "", false, nil) + err := Mount(context.Background(), 0, 0, "", false, false, nil) if err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } @@ -52,7 +52,7 @@ func Test_Mount_Mkdir_ExpectedPath(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, nil) + err := Mount(context.Background(), 0, 0, target, false, false, nil) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -79,7 +79,7 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, nil) + err := Mount(context.Background(), 0, 0, target, false, false, nil) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -106,7 +106,7 @@ func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), expectedController, 0, "/fake/path", false, nil) + err := Mount(context.Background(), expectedController, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -133,7 +133,7 @@ func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, expectedLun, "/fake/path", false, nil) + err := Mount(context.Background(), 0, expectedLun, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -163,7 +163,7 @@ func Test_Mount_Calls_RemoveAll_OnControllerToLunFailure(t *testing.T) { // NOTE: Do NOT set unixMount because the controller to lun fails. Expect it // not to be called. - err := Mount(context.Background(), 0, 0, target, false, nil) + err := Mount(context.Background(), 0, 0, target, false, false, nil) if err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } @@ -196,7 +196,7 @@ func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { // Fake the mount failure to test remove is called return expectedErr } - err := Mount(context.Background(), 0, 0, target, false, nil) + err := Mount(context.Background(), 0, 0, target, false, false, nil) if err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } @@ -225,7 +225,7 @@ func Test_Mount_Valid_Source(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -251,7 +251,7 @@ func Test_Mount_Valid_Target(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, expectedTarget, false, nil) + err := Mount(context.Background(), 0, 0, expectedTarget, false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -277,7 +277,7 @@ func Test_Mount_Valid_FSType(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -303,7 +303,7 @@ func Test_Mount_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -329,7 +329,7 @@ func Test_Mount_Readonly_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -354,7 +354,7 @@ func Test_Mount_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -380,7 +380,7 @@ func Test_Mount_Readonly_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, nil) + err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) }