diff --git a/internal/pkg/provisioner/storageos.go b/internal/pkg/provisioner/storageos.go index 8d1fda6b..8b0fd0c1 100644 --- a/internal/pkg/provisioner/storageos.go +++ b/internal/pkg/provisioner/storageos.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" storagev1 "k8s.io/api/storage/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -111,12 +112,6 @@ func IsProvisionedStorageClass(sc *storagev1.StorageClass, provisioners ...strin // IsStorageOSVolume returns true if the volume's PVC was provisioned by // StorageOS. The namespace of the Pod/PVC must be provided. func IsStorageOSVolume(k8s client.Client, volume corev1.Volume, namespace string) (bool, error) { - return IsProvisionedVolume(k8s, volume, namespace, DriverName) -} - -// IsProvisionedVolume returns true if the volume's PVC was provided by one of -// the given provisioners. -func IsProvisionedVolume(k8s client.Client, volume corev1.Volume, namespace string, provisioners ...string) (bool, error) { // Ensure that the volume has a claim. if volume.PersistentVolumeClaim == nil { return false, nil @@ -129,8 +124,11 @@ func IsProvisionedVolume(k8s client.Client, volume corev1.Volume, namespace stri Namespace: namespace, } if err := k8s.Get(context.Background(), key, pvc); err != nil { - return false, errors.Wrap(err, "failed to get PVC") + if apierrors.IsNotFound(err) { + return false, nil + } + return false, errors.Wrap(err, "failed to get PVCxxxx") } - return IsProvisionedPVC(k8s, *pvc, namespace, provisioners...) + return IsProvisionedPVC(k8s, *pvc, namespace, DriverName) } diff --git a/internal/pkg/provisioner/storageos_test.go b/internal/pkg/provisioner/storageos_test.go index 729fa3be..5e768ef9 100644 --- a/internal/pkg/provisioner/storageos_test.go +++ b/internal/pkg/provisioner/storageos_test.go @@ -161,18 +161,85 @@ func TestHasStorageOSAnnotation(t *testing.T) { } } -func TestIsProvisionedVolume(t *testing.T) { +func Test_IsStorageOSVolume(t *testing.T) { t.Parallel() - provisioned, err := IsProvisionedVolume(nil, corev1.Volume{}, "") + stosSC := storagev1.StorageClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "fast", + }, + Provisioner: DriverName, + } + nonStosSC := storagev1.StorageClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "slow", + }, + Provisioner: "foo-provisioner", + } + + stosPVC := createPVC("storageos", "default", stosSC.Name, false) + nonStosPVC := createPVC("non-storageos", "default", nonStosSC.Name, false) - if provisioned { - t.Errorf("IsProvisionedVolume() = %t", provisioned) - return + tests := []struct { + name string + volume corev1.Volume + want bool + wantErr bool + }{ + { + name: "no PVC", + volume: corev1.Volume{}, + want: false, + }, + { + name: "PVC not exists", + volume: corev1.Volume{ + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: "non-exists", + }, + }, + }, + want: false, + }, + { + name: "non StorageOS PVC", + volume: corev1.Volume{ + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: nonStosPVC.Name, + }, + }, + }, + want: false, + }, + { + name: "StorageOS PVC", + volume: corev1.Volume{ + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: stosPVC.Name, + }, + }, + }, + want: true, + }, } - if err != nil { - t.Errorf("IsProvisionedVolume() error = %v, not allowed", err) - return + for _, tt := range tests { + var tt = tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + k8s := fake.NewClientBuilder().WithObjects(&stosSC, &nonStosSC, &stosPVC, &nonStosPVC).Build() + + got, gotErr := IsStorageOSVolume(k8s, tt.volume, "default") + if (gotErr != nil) != tt.wantErr { + t.Errorf("IsStorageOSVolume() error = %v, wantErr %t", gotErr, tt.wantErr) + } + if got != tt.want { + t.Errorf("IsStorageOSVolume() = %v, want %v", got, tt.want) + } + }) } }