From 912bf945de079f5ab3423b790eda9f7313948372 Mon Sep 17 00:00:00 2001 From: Maksim An Date: Mon, 20 Sep 2021 14:07:40 -0700 Subject: [PATCH 1/4] extend integrity protection of LCOW layers to SCSI devices LCOW layers can be added both as VPMem and as SCSI devices. Previous work focused on enabling integrity protection for read only VPMem layers, this change enables it for read-only SCSI devices as well. Just like in a VPMem scenario, create dm-verity target when verity information is presented to the guest during SCSI device mounting step. Additionally remove unnecessary unit test, since the guest logic has changed. Signed-off-by: Maksim An --- internal/guest/storage/scsi/scsi.go | 40 +++++++++++++++++-- internal/guest/storage/scsi/scsi_test.go | 38 +++--------------- internal/uvm/scsi.go | 2 +- .../Microsoft/hcsshim/internal/uvm/scsi.go | 2 +- 4 files changed, 43 insertions(+), 39 deletions(-) diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index c95ab9d555..2591e1121c 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -5,6 +5,7 @@ package scsi import ( "context" "fmt" + dm "github.com/Microsoft/hcsshim/internal/guest/storage/devicemapper" "io/ioutil" "os" "path/filepath" @@ -33,6 +34,7 @@ var ( const ( scsiDevicesPath = "/sys/bus/scsi/devices" + verityDeviceFmt = "verity-scsi-contr%d-lun%d-%s" ) // Mount creates a mount from the SCSI device on `controller` index `lun` to @@ -52,16 +54,43 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b trace.Int64Attribute("controller", int64(controller)), trace.Int64Attribute("lun", int64(lun))) + source, err := controllerLunToName(spnCtx, controller, lun) + if err != nil { + return err + } + if readonly { // containers only have read-only layers so only enforce for them var deviceHash string + verityHandler := func() error { + return nil + } if verityInfo != nil { deviceHash = verityInfo.RootDigest + verityHandler = func() error { + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) + if source, err = dm.CreateVerityTarget(ctx, source, dmVerityName, verityInfo); err != nil { + return err + } + defer func() { + if err != nil { + if err := dm.RemoveDevice(dmVerityName); err != nil { + log.G(spnCtx).WithError(err).WithField("verityTarget", dmVerityName).Debug("failed to cleanup verity target") + } + } + }() + return nil + } } + err = securityPolicy.EnforceDeviceMountPolicy(target, deviceHash) if err != nil { return errors.Wrapf(err, "won't mount scsi controller %d lun %d onto %s", controller, lun, target) } + + if err := verityHandler(); err != nil { + return err + } } if err := osMkdirAll(target, 0700); err != nil { @@ -72,10 +101,6 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b osRemoveAll(target) } }() - source, err := controllerLunToName(spnCtx, controller, lun) - if err != nil { - return err - } // we only care about readonly mount option when mounting the device var flags uintptr @@ -142,6 +167,13 @@ func Unmount(ctx context.Context, controller, lun uint8, target string, encrypte return errors.Wrapf(err, "unmounting scsi controller %d lun %d from %s denied by policy", controller, lun, target) } + if verityInfo != nil { + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, verityInfo.RootDigest) + if err := dm.RemoveDevice(dmVerityName); err != nil { + return errors.Wrapf(err, "failed to remove dm verity target: %s", dmVerityName) + } + } + // Unmount unencrypted device if err := storage.UnmountPath(ctx, target, true); err != nil { return errors.Wrapf(err, "unmount failed: "+target) diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index 99cc77533d..23d3182aed 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -27,6 +27,11 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return expectedErr } + + controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + return "", nil + } + err := Mount(context.Background(), 0, 0, "", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) if err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) @@ -141,39 +146,6 @@ func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { } } -func Test_Mount_Calls_RemoveAll_OnControllerToLunFailure(t *testing.T) { - clearTestDependencies() - - osMkdirAll = func(path string, perm os.FileMode) error { - return nil - } - expectedErr := errors.New("expected controller to lun failure") - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { - return "", expectedErr - } - target := "/fake/path" - removeAllCalled := false - osRemoveAll = func(path string) error { - removeAllCalled = true - if path != target { - t.Errorf("expected path: %v, got: %v", target, path) - return errors.New("unexpected path") - } - return nil - } - - // NOTE: Do NOT set unixMount because the controller to lun fails. Expect it - // not to be called. - - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { - t.Fatalf("expected err: %v, got: %v", expectedErr, err) - } - if !removeAllCalled { - t.Fatal("expected os.RemoveAll to be called on mount failure") - } -} - func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { clearTestDependencies() diff --git a/internal/uvm/scsi.go b/internal/uvm/scsi.go index 652cc5b964..f9587e908a 100644 --- a/internal/uvm/scsi.go +++ b/internal/uvm/scsi.go @@ -443,7 +443,7 @@ func (uvm *UtilityVM) addSCSIActual(ctx context.Context, addReq *addSCSIRequest) log.G(ctx).WithFields(logrus.Fields{ "hostPath": sm.HostPath, "rootDigest": v.RootDigest, - }).Debug("adding VPMem with dm-verity") + }).Debug("adding SCSI with dm-verity") } verity = v } diff --git a/test/vendor/github.com/Microsoft/hcsshim/internal/uvm/scsi.go b/test/vendor/github.com/Microsoft/hcsshim/internal/uvm/scsi.go index 652cc5b964..f9587e908a 100644 --- a/test/vendor/github.com/Microsoft/hcsshim/internal/uvm/scsi.go +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/uvm/scsi.go @@ -443,7 +443,7 @@ func (uvm *UtilityVM) addSCSIActual(ctx context.Context, addReq *addSCSIRequest) log.G(ctx).WithFields(logrus.Fields{ "hostPath": sm.HostPath, "rootDigest": v.RootDigest, - }).Debug("adding VPMem with dm-verity") + }).Debug("adding SCSI with dm-verity") } verity = v } From b229f22a73c0292ae42f5e6d80e29c3e019e8ba0 Mon Sep 17 00:00:00 2001 From: Maksim An Date: Tue, 21 Sep 2021 17:56:48 -0700 Subject: [PATCH 2/4] tests: add pmem and scsi unit tests for linear/verity targets Signed-off-by: Maksim An --- internal/guest/storage/pmem/pmem.go | 22 +- internal/guest/storage/pmem/pmem_test.go | 293 +++++++++++++++++++++++ internal/guest/storage/scsi/scsi.go | 37 ++- internal/guest/storage/scsi/scsi_test.go | 98 ++++++++ 4 files changed, 421 insertions(+), 29 deletions(-) diff --git a/internal/guest/storage/pmem/pmem.go b/internal/guest/storage/pmem/pmem.go index 8af7207bad..9291db42f3 100644 --- a/internal/guest/storage/pmem/pmem.go +++ b/internal/guest/storage/pmem/pmem.go @@ -21,9 +21,13 @@ import ( // Test dependencies var ( - osMkdirAll = os.MkdirAll - osRemoveAll = os.RemoveAll - unixMount = unix.Mount + osMkdirAll = os.MkdirAll + osRemoveAll = os.RemoveAll + unixMount = unix.Mount + mountInternal = mount + createLinearTarget = dm.CreateZeroSectorLinearTarget + veritySetup = dm.CreateVerityTarget + removeDevice = dm.RemoveDevice ) const ( @@ -32,8 +36,8 @@ const ( verityDeviceFmt = "dm-verity-pmem%d-%s" ) -// mountInternal mounts source to target via unix.Mount -func mountInternal(ctx context.Context, source, target string) (err error) { +// mount mounts source to target via unix.Mount +func mount(ctx context.Context, source, target string) (err error) { if err := osMkdirAll(target, 0700); err != nil { return err } @@ -89,12 +93,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. // device instead of the original VPMem. if mappingInfo != nil { dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) - if devicePath, err = dm.CreateZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { + if devicePath, err = createLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { return err } defer func() { if err != nil { - if err := dm.RemoveDevice(dmLinearName); err != nil { + if err := removeDevice(dmLinearName); err != nil { log.G(mCtx).WithError(err).Debugf("failed to cleanup linear target: %s", dmLinearName) } } @@ -103,12 +107,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest) - if devicePath, err = dm.CreateVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil { + if devicePath, err = veritySetup(mCtx, devicePath, dmVerityName, verityInfo); err != nil { return err } defer func() { if err != nil { - if err := dm.RemoveDevice(dmVerityName); err != nil { + if err := removeDevice(dmVerityName); err != nil { log.G(mCtx).WithError(err).Debugf("failed to cleanup verity target: %s", dmVerityName) } } diff --git a/internal/guest/storage/pmem/pmem_test.go b/internal/guest/storage/pmem/pmem_test.go index d147ede5f2..2a23eaebff 100644 --- a/internal/guest/storage/pmem/pmem_test.go +++ b/internal/guest/storage/pmem/pmem_test.go @@ -5,6 +5,7 @@ package pmem import ( "context" "fmt" + "github.com/Microsoft/hcsshim/internal/guest/prot" "os" "testing" @@ -18,6 +19,10 @@ func clearTestDependencies() { osMkdirAll = nil osRemoveAll = nil unixMount = nil + createLinearTarget = nil + veritySetup = nil + removeDevice = nil + mountInternal = mount } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -305,3 +310,291 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { return &policy.MountMonitoringSecurityPolicyEnforcer{} } + +// device mapper tests +func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + mappingInfo := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) + expectedSource := "/dev/pmem0" + expectedTarget := "/foo" + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName) + createLTCalled := false + + osMkdirAll = func(_ string, _ os.FileMode) error { + return nil + } + + mountInternal = func(_ context.Context, source, target string) error { + if source != mapperPath { + t.Errorf("expected mountInternal source %s, got %s", mapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected mountInternal target %s, got %s", expectedTarget, source) + } + return nil + } + + createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + createLTCalled = true + if source != expectedSource { + t.Errorf("expected createLinearTarget source %s, got %s", expectedSource, source) + } + if name != expectedLinearName { + t.Errorf("expected createLinearTarget name %s, got %s", expectedLinearName, name) + } + return mapperPath, nil + } + + if err := Mount( + context.Background(), 0, expectedTarget, mappingInfo, nil, openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount: %s", err) + } + if !createLTCalled { + t.Fatalf("createLinearTarget not called") + } +} + +func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest) + expectedSource := "/dev/pmem0" + expectedTarget := "/foo" + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) + veritySetupCalled := false + + mountInternal = func(_ context.Context, source, target string) error { + if source != mapperPath { + t.Errorf("expected mountInternal source %s, got %s", mapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target) + } + return nil + } + veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + veritySetupCalled = true + if source != expectedSource { + t.Errorf("expected veritySetup source %s, got %s", expectedSource, source) + } + if name != expectedVerityName { + t.Errorf("expected veritySetup name %s, got %s", expectedVerityName, name) + } + return mapperPath, nil + } + + if err := Mount( + context.Background(), 0, expectedTarget, nil, verityInfo, openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected Mount failure: %s", err) + } + if !veritySetupCalled { + t.Fatal("veritySetup not called") + } +} + +func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) { + clearTestDependencies() + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + mapping := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes) + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest) + expectedPMemDevice := "/dev/pmem0" + mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget) + mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + dmLinearCalled := false + dmVerityCalled := false + mountCalled := false + + createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + dmLinearCalled = true + if source != expectedPMemDevice { + t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source) + } + if name != expectedLinearTarget { + t.Errorf("expected createLineartarget name %s, got %s", expectedLinearTarget, name) + } + return mapperLinearPath, nil + } + veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + dmVerityCalled = true + if source != mapperLinearPath { + t.Errorf("expected veritySetup source %s, got %s", mapperLinearPath, source) + } + if name != expectedVerityTarget { + t.Errorf("expected veritySetup target name %s, got %s", expectedVerityTarget, name) + } + return mapperVerityPath, nil + } + mountInternal = func(_ context.Context, source, target string) error { + mountCalled = true + if source != mapperVerityPath { + t.Errorf("expected Mount source %s, got %s", mapperVerityPath, source) + } + return nil + } + + if err := Mount( + context.Background(), 0, "/foo", mapping, verityInfo, openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount call: %s", err) + } + if !dmLinearCalled { + t.Fatal("expected createLinearTarget call") + } + if !dmVerityCalled { + t.Fatal("expected veritySetup call") + } + if !mountCalled { + t.Fatal("expected mountInternal call") + } +} + +func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + mappingInfo := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedError := errors.New("mountInternal error") + expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget) + removeDeviceCalled := false + + createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + return mapperPath, nil + } + mountInternal = func(_ context.Context, source, target string) error { + return expectedError + } + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedTarget { + t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name) + } + return nil + } + + if err := Mount( + context.Background(), 0, "/foo", mappingInfo, nil, openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be callled") + } +} + +func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + verity := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest) + expectedError := errors.New("mountInternal error") + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + removeDeviceCalled := false + + veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + return mapperPath, nil + } + mountInternal = func(_ context.Context, _, _ string) error { + return expectedError + } + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedVerityTarget { + t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name) + } + return nil + } + + if err := Mount( + context.Background(), 0, "/foo", nil, verity, openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be called") + } +} + +func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + mapping := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + verity := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedError := errors.New("mountInternal error") + expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes) + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest) + expectedPMemDevice := "/dev/pmem0" + mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget) + mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + rmLinearCalled := false + rmVerityCalled := false + + createLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) { + if source != expectedPMemDevice { + t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source) + } + return mapperLinearPath, nil + } + veritySetup = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) { + if source != mapperLinearPath { + t.Errorf("expected veritySetup to be called with %s, got %s", mapperLinearPath, source) + } + if name != expectedVerityTarget { + t.Errorf("expected veritySetup target %s, got %s", expectedVerityTarget, name) + } + return mapperVerityPath, nil + } + removeDevice = func(name string) error { + if name != expectedLinearTarget && name != expectedVerityTarget { + t.Errorf("unexpected removeDevice target name %s", name) + } + if name == expectedLinearTarget { + rmLinearCalled = true + } + if name == expectedVerityTarget { + rmVerityCalled = true + } + return nil + } + mountInternal = func(_ context.Context, _, _ string) error { + return expectedError + } + + if err := Mount( + context.Background(), 0, "/foo", mapping, verity, openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !rmLinearCalled { + t.Fatal("expected removeDevice for linear target to be called") + } + if !rmVerityCalled { + t.Fatal("expected removeDevice for verity target to be called") + } +} diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index 2591e1121c..a1ac0db085 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -30,6 +30,10 @@ var ( // controllerLunToName is stubbed to make testing `Mount` easier. controllerLunToName = ControllerLunToName + // veritySetup is stubbed for unit testing `Mount` + veritySetup = dm.CreateVerityTarget + // removeDevice is stubbed for unit testing `Mount` + removeDevice = dm.RemoveDevice ) const ( @@ -62,25 +66,8 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b if readonly { // containers only have read-only layers so only enforce for them var deviceHash string - verityHandler := func() error { - return nil - } if verityInfo != nil { deviceHash = verityInfo.RootDigest - verityHandler = func() error { - dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) - if source, err = dm.CreateVerityTarget(ctx, source, dmVerityName, verityInfo); err != nil { - return err - } - defer func() { - if err != nil { - if err := dm.RemoveDevice(dmVerityName); err != nil { - log.G(spnCtx).WithError(err).WithField("verityTarget", dmVerityName).Debug("failed to cleanup verity target") - } - } - }() - return nil - } } err = securityPolicy.EnforceDeviceMountPolicy(target, deviceHash) @@ -88,8 +75,18 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b return errors.Wrapf(err, "won't mount scsi controller %d lun %d onto %s", controller, lun, target) } - if err := verityHandler(); err != nil { - return err + if verityInfo != nil { + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) + if source, err = veritySetup(ctx, source, dmVerityName, verityInfo); err != nil { + return err + } + defer func() { + if err != nil { + if err := removeDevice(dmVerityName); err != nil { + log.G(spnCtx).WithError(err).WithField("verityTarget", dmVerityName).Debug("failed to cleanup verity target") + } + } + }() } } @@ -169,7 +166,7 @@ func Unmount(ctx context.Context, controller, lun uint8, target string, encrypte if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, verityInfo.RootDigest) - if err := dm.RemoveDevice(dmVerityName); err != nil { + if err := removeDevice(dmVerityName); err != nil { return errors.Wrapf(err, "failed to remove dm verity target: %s", dmVerityName) } } diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index 23d3182aed..4b53b62896 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -5,6 +5,8 @@ package scsi import ( "context" "errors" + "fmt" + "github.com/Microsoft/hcsshim/internal/guest/prot" "os" "testing" @@ -18,6 +20,7 @@ func clearTestDependencies() { osRemoveAll = nil unixMount = nil controllerLunToName = nil + veritySetup = nil } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -501,3 +504,98 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { return &policy.MountMonitoringSecurityPolicyEnforcer{} } + +// dm-verity tests +func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + expectedSource := "/dev/sdb" + expectedMapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) + expectedTarget := "/foo" + veritySetupCalled := false + + controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + return expectedSource, nil + } + + osMkdirAll = func(_ string, _ os.FileMode) error { + return nil + } + + vInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + veritySetup = func(_ context.Context, source, name string, verityInfo *prot.DeviceVerityInfo) (string, error) { + veritySetupCalled = true + if source != expectedSource { + t.Errorf("expected source %s, got %s", expectedSource, source) + } + if name != expectedVerityName { + t.Errorf("expected verity target name %s, got %s", expectedVerityName, name) + } + return expectedMapperPath, nil + } + + unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { + if source != expectedMapperPath { + t.Errorf("expected unixMount source %s, got %s", expectedMapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected unixMount target %s, got %s", expectedTarget, target) + } + return nil + } + + if err := Mount( + context.Background(), 0, 0, expectedTarget, true, false, nil, vInfo, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount: %s", err) + } + if !veritySetupCalled { + t.Fatalf("expected veritySetup to be called") + } +} + +func Test_osMkdirAllFails_And_RemoveDevice_Called(t *testing.T) { + clearTestDependencies() + + expectedError := errors.New("osMkdirAll error") + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + removeDeviceCalled := false + + controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + return "/dev/sdb", nil + } + + osMkdirAll = func(_ string, _ os.FileMode) error { + return expectedError + } + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + + veritySetup = func(_ context.Context, _, _ string, _ *prot.DeviceVerityInfo) (string, error) { + return fmt.Sprintf("/dev/mapper/%s", expectedVerityName), nil + } + + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedVerityName { + t.Errorf("expected RemoveDevice name %s, got %s", expectedVerityName, name) + } + return nil + } + + if err := Mount( + context.Background(), 0, 0, "/foo", true, false, nil, verityInfo, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be called") + } +} From 49e4ed6153b391a0bff837142585f96ff2f922bf Mon Sep 17 00:00:00 2001 From: Maksim An Date: Thu, 7 Oct 2021 11:31:15 -0700 Subject: [PATCH 3/4] pr feedback #1: update function name mocks and Mount calls in tests Signed-off-by: Maksim An --- internal/guest/storage/pmem/pmem.go | 18 +- internal/guest/storage/pmem/pmem_test.go | 108 +++++++----- internal/guest/storage/scsi/scsi.go | 6 +- internal/guest/storage/scsi/scsi_test.go | 199 +++++++++++++++++++---- 4 files changed, 247 insertions(+), 84 deletions(-) diff --git a/internal/guest/storage/pmem/pmem.go b/internal/guest/storage/pmem/pmem.go index 9291db42f3..34ea570474 100644 --- a/internal/guest/storage/pmem/pmem.go +++ b/internal/guest/storage/pmem/pmem.go @@ -21,13 +21,13 @@ import ( // Test dependencies var ( - osMkdirAll = os.MkdirAll - osRemoveAll = os.RemoveAll - unixMount = unix.Mount - mountInternal = mount - createLinearTarget = dm.CreateZeroSectorLinearTarget - veritySetup = dm.CreateVerityTarget - removeDevice = dm.RemoveDevice + osMkdirAll = os.MkdirAll + osRemoveAll = os.RemoveAll + unixMount = unix.Mount + mountInternal = mount + createZeroSectorLinearTarget = dm.CreateZeroSectorLinearTarget + createVerityTargetCalled = dm.CreateVerityTarget + removeDevice = dm.RemoveDevice ) const ( @@ -93,7 +93,7 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. // device instead of the original VPMem. if mappingInfo != nil { dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) - if devicePath, err = createLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { + if devicePath, err = createZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { return err } defer func() { @@ -107,7 +107,7 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest) - if devicePath, err = veritySetup(mCtx, devicePath, dmVerityName, verityInfo); err != nil { + if devicePath, err = createVerityTargetCalled(mCtx, devicePath, dmVerityName, verityInfo); err != nil { return err } defer func() { diff --git a/internal/guest/storage/pmem/pmem_test.go b/internal/guest/storage/pmem/pmem_test.go index 2a23eaebff..1d04c73ef6 100644 --- a/internal/guest/storage/pmem/pmem_test.go +++ b/internal/guest/storage/pmem/pmem_test.go @@ -19,8 +19,8 @@ func clearTestDependencies() { osMkdirAll = nil osRemoveAll = nil unixMount = nil - createLinearTarget = nil - veritySetup = nil + createZeroSectorLinearTarget = nil + createVerityTargetCalled = nil removeDevice = nil mountInternal = mount } @@ -323,7 +323,7 @@ func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing expectedSource := "/dev/pmem0" expectedTarget := "/foo" mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName) - createLTCalled := false + createZSLTCalled := false osMkdirAll = func(_ string, _ os.FileMode) error { return nil @@ -339,28 +339,33 @@ func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing return nil } - createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { - createLTCalled = true + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + createZSLTCalled = true if source != expectedSource { - t.Errorf("expected createLinearTarget source %s, got %s", expectedSource, source) + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source) } if name != expectedLinearName { - t.Errorf("expected createLinearTarget name %s, got %s", expectedLinearName, name) + t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name) } return mapperPath, nil } if err := Mount( - context.Background(), 0, expectedTarget, mappingInfo, nil, openDoorSecurityPolicyEnforcer(), + context.Background(), + 0, + expectedTarget, + mappingInfo, + nil, + openDoorSecurityPolicyEnforcer(), ); err != nil { t.Fatalf("unexpected error during Mount: %s", err) } - if !createLTCalled { - t.Fatalf("createLinearTarget not called") + if !createZSLTCalled { + t.Fatalf("createZeroSectorLinearTarget not called") } } -func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) { +func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) { clearTestDependencies() verityInfo := &prot.DeviceVerityInfo{ @@ -370,7 +375,7 @@ func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) { expectedSource := "/dev/pmem0" expectedTarget := "/foo" mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) - veritySetupCalled := false + createVerityTargetCalledCalled := false mountInternal = func(_ context.Context, source, target string) error { if source != mapperPath { @@ -381,28 +386,33 @@ func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) { } return nil } - veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { - veritySetupCalled = true + createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalledCalled = true if source != expectedSource { - t.Errorf("expected veritySetup source %s, got %s", expectedSource, source) + t.Errorf("expected createVerityTargetCalled source %s, got %s", expectedSource, source) } if name != expectedVerityName { - t.Errorf("expected veritySetup name %s, got %s", expectedVerityName, name) + t.Errorf("expected createVerityTargetCalled name %s, got %s", expectedVerityName, name) } return mapperPath, nil } if err := Mount( - context.Background(), 0, expectedTarget, nil, verityInfo, openDoorSecurityPolicyEnforcer(), + context.Background(), + 0, + expectedTarget, + nil, + verityInfo, + openDoorSecurityPolicyEnforcer(), ); err != nil { t.Fatalf("unexpected Mount failure: %s", err) } - if !veritySetupCalled { - t.Fatal("veritySetup not called") + if !createVerityTargetCalledCalled { + t.Fatal("createVerityTargetCalled not called") } } -func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) { +func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) { clearTestDependencies() verityInfo := &prot.DeviceVerityInfo{ @@ -421,23 +431,23 @@ func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) { dmVerityCalled := false mountCalled := false - createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { dmLinearCalled = true if source != expectedPMemDevice { - t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source) + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source) } if name != expectedLinearTarget { - t.Errorf("expected createLineartarget name %s, got %s", expectedLinearTarget, name) + t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name) } return mapperLinearPath, nil } - veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { dmVerityCalled = true if source != mapperLinearPath { - t.Errorf("expected veritySetup source %s, got %s", mapperLinearPath, source) + t.Errorf("expected createVerityTargetCalled source %s, got %s", mapperLinearPath, source) } if name != expectedVerityTarget { - t.Errorf("expected veritySetup target name %s, got %s", expectedVerityTarget, name) + t.Errorf("expected createVerityTargetCalled target name %s, got %s", expectedVerityTarget, name) } return mapperVerityPath, nil } @@ -450,15 +460,20 @@ func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) { } if err := Mount( - context.Background(), 0, "/foo", mapping, verityInfo, openDoorSecurityPolicyEnforcer(), + context.Background(), + 0, + "/foo", + mapping, + verityInfo, + openDoorSecurityPolicyEnforcer(), ); err != nil { t.Fatalf("unexpected error during Mount call: %s", err) } if !dmLinearCalled { - t.Fatal("expected createLinearTarget call") + t.Fatal("expected createZeroSectorLinearTarget call") } if !dmVerityCalled { - t.Fatal("expected veritySetup call") + t.Fatal("expected createVerityTargetCalled call") } if !mountCalled { t.Fatal("expected mountInternal call") @@ -477,7 +492,7 @@ func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testin mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget) removeDeviceCalled := false - createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { return mapperPath, nil } mountInternal = func(_ context.Context, source, target string) error { @@ -492,7 +507,12 @@ func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testin } if err := Mount( - context.Background(), 0, "/foo", mappingInfo, nil, openDoorSecurityPolicyEnforcer(), + context.Background(), + 0, + "/foo", + mappingInfo, + nil, + openDoorSecurityPolicyEnforcer(), ); err != expectedError { t.Fatalf("expected Mount error %s, got %s", expectedError, err) } @@ -512,7 +532,7 @@ func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testin mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) removeDeviceCalled := false - veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { return mapperPath, nil } mountInternal = func(_ context.Context, _, _ string) error { @@ -527,7 +547,12 @@ func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testin } if err := Mount( - context.Background(), 0, "/foo", nil, verity, openDoorSecurityPolicyEnforcer(), + context.Background(), + 0, + "/foo", + nil, + verity, + openDoorSecurityPolicyEnforcer(), ); err != expectedError { t.Fatalf("expected Mount error %s, got %s", expectedError, err) } @@ -555,18 +580,18 @@ func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testin rmLinearCalled := false rmVerityCalled := false - createLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) { + createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) { if source != expectedPMemDevice { - t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source) + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source) } return mapperLinearPath, nil } - veritySetup = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) { if source != mapperLinearPath { - t.Errorf("expected veritySetup to be called with %s, got %s", mapperLinearPath, source) + t.Errorf("expected createVerityTargetCalled to be called with %s, got %s", mapperLinearPath, source) } if name != expectedVerityTarget { - t.Errorf("expected veritySetup target %s, got %s", expectedVerityTarget, name) + t.Errorf("expected createVerityTargetCalled target %s, got %s", expectedVerityTarget, name) } return mapperVerityPath, nil } @@ -587,7 +612,12 @@ func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testin } if err := Mount( - context.Background(), 0, "/foo", mapping, verity, openDoorSecurityPolicyEnforcer(), + context.Background(), + 0, + "/foo", + mapping, + verity, + openDoorSecurityPolicyEnforcer(), ); err != expectedError { t.Fatalf("expected Mount error %s, got %s", expectedError, err) } diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index a1ac0db085..1333dac43e 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -30,8 +30,8 @@ var ( // controllerLunToName is stubbed to make testing `Mount` easier. controllerLunToName = ControllerLunToName - // veritySetup is stubbed for unit testing `Mount` - veritySetup = dm.CreateVerityTarget + // createVerityTarget is stubbed for unit testing `Mount` + createVerityTarget = dm.CreateVerityTarget // removeDevice is stubbed for unit testing `Mount` removeDevice = dm.RemoveDevice ) @@ -77,7 +77,7 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) - if source, err = veritySetup(ctx, source, dmVerityName, verityInfo); err != nil { + if source, err = createVerityTarget(spnCtx, source, dmVerityName, verityInfo); err != nil { return err } defer func() { diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index 4b53b62896..853cb83df7 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -20,7 +20,7 @@ func clearTestDependencies() { osRemoveAll = nil unixMount = nil controllerLunToName = nil - veritySetup = nil + createVerityTarget = nil } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -35,8 +35,17 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { return "", nil } - err := Mount(context.Background(), 0, 0, "", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { + if err := Mount( + context.Background(), + 0, + 0, + "", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } } @@ -62,8 +71,18 @@ func Test_Mount_Mkdir_ExpectedPath(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -89,8 +108,18 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -116,8 +145,18 @@ func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), expectedController, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + expectedController, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -143,8 +182,18 @@ func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, expectedLun, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + expectedLun, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -173,8 +222,18 @@ func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { // Fake the mount failure to test remove is called return expectedErr } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } if !removeAllCalled { @@ -228,8 +287,18 @@ func Test_Mount_Valid_Target(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, expectedTarget, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + expectedTarget, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -254,8 +323,18 @@ func Test_Mount_Valid_FSType(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -280,8 +359,18 @@ func Test_Mount_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -306,8 +395,18 @@ func Test_Mount_Readonly_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + true, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -331,8 +430,18 @@ func Test_Mount_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -357,8 +466,18 @@ func Test_Mount_Readonly_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + true, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -513,7 +632,7 @@ func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing expectedSource := "/dev/sdb" expectedMapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) expectedTarget := "/foo" - veritySetupCalled := false + createVerityTargetCalled := false controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { return expectedSource, nil @@ -526,8 +645,8 @@ func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing vInfo := &prot.DeviceVerityInfo{ RootDigest: "hash", } - veritySetup = func(_ context.Context, source, name string, verityInfo *prot.DeviceVerityInfo) (string, error) { - veritySetupCalled = true + createVerityTarget = func(_ context.Context, source, name string, verityInfo *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = true if source != expectedSource { t.Errorf("expected source %s, got %s", expectedSource, source) } @@ -548,13 +667,20 @@ func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing } if err := Mount( - context.Background(), 0, 0, expectedTarget, true, false, nil, vInfo, + context.Background(), + 0, + 0, + expectedTarget, + true, + false, + nil, + vInfo, openDoorSecurityPolicyEnforcer(), ); err != nil { t.Fatalf("unexpected error during Mount: %s", err) } - if !veritySetupCalled { - t.Fatalf("expected veritySetup to be called") + if !createVerityTargetCalled { + t.Fatalf("expected createVerityTargetCalled to be called") } } @@ -577,7 +703,7 @@ func Test_osMkdirAllFails_And_RemoveDevice_Called(t *testing.T) { RootDigest: "hash", } - veritySetup = func(_ context.Context, _, _ string, _ *prot.DeviceVerityInfo) (string, error) { + createVerityTarget = func(_ context.Context, _, _ string, _ *prot.DeviceVerityInfo) (string, error) { return fmt.Sprintf("/dev/mapper/%s", expectedVerityName), nil } @@ -590,7 +716,14 @@ func Test_osMkdirAllFails_And_RemoveDevice_Called(t *testing.T) { } if err := Mount( - context.Background(), 0, 0, "/foo", true, false, nil, verityInfo, + context.Background(), + 0, + 0, + "/foo", + true, + false, + nil, + verityInfo, openDoorSecurityPolicyEnforcer(), ); err != expectedError { t.Fatalf("expected Mount error %s, got %s", expectedError, err) From adfcb488c11bbb7beb0eb69c0d7494e645f9b232 Mon Sep 17 00:00:00 2001 From: Maksim An Date: Thu, 7 Oct 2021 23:36:33 -0700 Subject: [PATCH 4/4] fix find/replace side effect Signed-off-by: Maksim An --- internal/guest/storage/pmem/pmem.go | 4 +-- internal/guest/storage/pmem/pmem_test.go | 32 ++++++++++++------------ internal/guest/storage/scsi/scsi.go | 13 +++++----- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/internal/guest/storage/pmem/pmem.go b/internal/guest/storage/pmem/pmem.go index 34ea570474..681659f061 100644 --- a/internal/guest/storage/pmem/pmem.go +++ b/internal/guest/storage/pmem/pmem.go @@ -26,7 +26,7 @@ var ( unixMount = unix.Mount mountInternal = mount createZeroSectorLinearTarget = dm.CreateZeroSectorLinearTarget - createVerityTargetCalled = dm.CreateVerityTarget + createVerityTarget = dm.CreateVerityTarget removeDevice = dm.RemoveDevice ) @@ -107,7 +107,7 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest) - if devicePath, err = createVerityTargetCalled(mCtx, devicePath, dmVerityName, verityInfo); err != nil { + if devicePath, err = createVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil { return err } defer func() { diff --git a/internal/guest/storage/pmem/pmem_test.go b/internal/guest/storage/pmem/pmem_test.go index 1d04c73ef6..be61d70b11 100644 --- a/internal/guest/storage/pmem/pmem_test.go +++ b/internal/guest/storage/pmem/pmem_test.go @@ -20,7 +20,7 @@ func clearTestDependencies() { osRemoveAll = nil unixMount = nil createZeroSectorLinearTarget = nil - createVerityTargetCalled = nil + createVerityTarget = nil removeDevice = nil mountInternal = mount } @@ -375,7 +375,7 @@ func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *t expectedSource := "/dev/pmem0" expectedTarget := "/foo" mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) - createVerityTargetCalledCalled := false + createVerityTargetCalled := false mountInternal = func(_ context.Context, source, target string) error { if source != mapperPath { @@ -386,13 +386,13 @@ func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *t } return nil } - createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { - createVerityTargetCalledCalled = true + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = true if source != expectedSource { - t.Errorf("expected createVerityTargetCalled source %s, got %s", expectedSource, source) + t.Errorf("expected createVerityTarget source %s, got %s", expectedSource, source) } if name != expectedVerityName { - t.Errorf("expected createVerityTargetCalled name %s, got %s", expectedVerityName, name) + t.Errorf("expected createVerityTarget name %s, got %s", expectedVerityName, name) } return mapperPath, nil } @@ -407,8 +407,8 @@ func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *t ); err != nil { t.Fatalf("unexpected Mount failure: %s", err) } - if !createVerityTargetCalledCalled { - t.Fatal("createVerityTargetCalled not called") + if !createVerityTargetCalled { + t.Fatal("createVerityTarget not called") } } @@ -441,13 +441,13 @@ func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *te } return mapperLinearPath, nil } - createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { dmVerityCalled = true if source != mapperLinearPath { - t.Errorf("expected createVerityTargetCalled source %s, got %s", mapperLinearPath, source) + t.Errorf("expected createVerityTarget source %s, got %s", mapperLinearPath, source) } if name != expectedVerityTarget { - t.Errorf("expected createVerityTargetCalled target name %s, got %s", expectedVerityTarget, name) + t.Errorf("expected createVerityTarget target name %s, got %s", expectedVerityTarget, name) } return mapperVerityPath, nil } @@ -473,7 +473,7 @@ func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *te t.Fatal("expected createZeroSectorLinearTarget call") } if !dmVerityCalled { - t.Fatal("expected createVerityTargetCalled call") + t.Fatal("expected createVerityTarget call") } if !mountCalled { t.Fatal("expected mountInternal call") @@ -532,7 +532,7 @@ func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testin mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) removeDeviceCalled := false - createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { return mapperPath, nil } mountInternal = func(_ context.Context, _, _ string) error { @@ -586,12 +586,12 @@ func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testin } return mapperLinearPath, nil } - createVerityTargetCalled = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) { + createVerityTarget = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) { if source != mapperLinearPath { - t.Errorf("expected createVerityTargetCalled to be called with %s, got %s", mapperLinearPath, source) + t.Errorf("expected createVerityTarget to be called with %s, got %s", mapperLinearPath, source) } if name != expectedVerityTarget { - t.Errorf("expected createVerityTargetCalled target %s, got %s", expectedVerityTarget, name) + t.Errorf("expected createVerityTarget target %s, got %s", expectedVerityTarget, name) } return mapperVerityPath, nil } diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index 1333dac43e..fbcebf3754 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -164,18 +164,19 @@ func Unmount(ctx context.Context, controller, lun uint8, target string, encrypte return errors.Wrapf(err, "unmounting scsi controller %d lun %d from %s denied by policy", controller, lun, target) } + // Unmount unencrypted device + if err := storage.UnmountPath(ctx, target, true); err != nil { + return errors.Wrapf(err, "unmount failed: "+target) + } + if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, verityInfo.RootDigest) if err := removeDevice(dmVerityName); err != nil { - return errors.Wrapf(err, "failed to remove dm verity target: %s", dmVerityName) + // Ignore failures, since the path has been unmounted at this point. + log.G(ctx).WithError(err).Debugf("failed to remove dm verity target: %s", dmVerityName) } } - // Unmount unencrypted device - if err := storage.UnmountPath(ctx, target, true); err != nil { - return errors.Wrapf(err, "unmount failed: "+target) - } - if encrypted { if err := crypt.CleanupCryptDevice(target); err != nil { return errors.Wrapf(err, "failed to cleanup dm-crypt state: "+target)