diff --git a/internal/guest/network/network.go b/internal/guest/network/network.go index 4cb66f302b..ec4f1fc9bd 100644 --- a/internal/guest/network/network.go +++ b/internal/guest/network/network.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "github.com/Microsoft/hcsshim/internal/guest/storage" + "github.com/Microsoft/hcsshim/internal/guest/storage/pci" "github.com/Microsoft/hcsshim/internal/guest/storage/vmbus" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" @@ -19,6 +21,14 @@ import ( "go.opencensus.io/trace" ) +// mock out calls for testing +var ( + pciFindDeviceFullPath = pci.FindDeviceFullPath + storageWaitForFileMatchingPattern = storage.WaitForFileMatchingPattern + vmbusWaitForDevicePath = vmbus.WaitForDevicePath + ioutilReadDir = ioutil.ReadDir +) + // maxDNSSearches is limited to 6 in `man 5 resolv.conf` const maxDNSSearches = 6 @@ -104,20 +114,33 @@ func MergeValues(first, second []string) []string { // Windows host) to its corresponding interface name (e.g. "eth0"). // // Will retry the operation until `ctx` is exceeded or canceled. -func InstanceIDToName(ctx context.Context, id string) (_ string, err error) { +func InstanceIDToName(ctx context.Context, id string, vpciAssigned bool) (_ string, err error) { ctx, span := trace.StartSpan(ctx, "network::InstanceIDToName") defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - id = strings.ToLower(id) - span.AddAttributes(trace.StringAttribute("adapterInstanceID", id)) + vmBusID := strings.ToLower(id) + span.AddAttributes(trace.StringAttribute("adapterInstanceID", vmBusID)) - vmBusSubPath := filepath.Join(id, "net") - devicePath, err := vmbus.WaitForDevicePath(ctx, vmBusSubPath) + netDevicePath := "" + if vpciAssigned { + pciDevicePath, err := pciFindDeviceFullPath(ctx, vmBusID) + if err != nil { + return "", err + } + pciNetDirPattern := filepath.Join(pciDevicePath, "net") + netDevicePath, err = storageWaitForFileMatchingPattern(ctx, pciNetDirPattern) + } else { + vmBusNetSubPath := filepath.Join(vmBusID, "net") + netDevicePath, err = vmbusWaitForDevicePath(ctx, vmBusNetSubPath) + } + if err != nil { + return "", errors.Wrapf(err, "failed to find adapter %v sysfs path", vmBusID) + } var deviceDirs []os.FileInfo for { - deviceDirs, err = ioutil.ReadDir(devicePath) + deviceDirs, err = ioutilReadDir(netDevicePath) if err != nil { if os.IsNotExist(err) { select { @@ -128,16 +151,16 @@ func InstanceIDToName(ctx context.Context, id string) (_ string, err error) { continue } } else { - return "", errors.Wrapf(err, "failed to read vmbus network device from /sys filesystem for adapter %s", id) + return "", errors.Wrapf(err, "failed to read vmbus network device from /sys filesystem for adapter %s", vmBusID) } } break } if len(deviceDirs) == 0 { - return "", errors.Errorf("no interface name found for adapter %s", id) + return "", errors.Errorf("no interface name found for adapter %s", vmBusID) } if len(deviceDirs) > 1 { - return "", errors.Errorf("multiple interface names found for adapter %s", id) + return "", errors.Errorf("multiple interface names found for adapter %s", vmBusID) } ifname := deviceDirs[0].Name() log.G(ctx).WithField("ifname", ifname).Debug("resolved ifname") diff --git a/internal/guest/network/network_test.go b/internal/guest/network/network_test.go index 08c3781f67..21e3728a8c 100644 --- a/internal/guest/network/network_test.go +++ b/internal/guest/network/network_test.go @@ -4,7 +4,11 @@ package network import ( "context" + "io/fs" + "os" + "path/filepath" "testing" + "time" ) func Test_GenerateResolvConfContent(t *testing.T) { @@ -165,3 +169,94 @@ ff02::2 ip6-allrouters }) } } + +// create a test FileInfo so we can return back a value to ReadDir +type testFileInfo struct { + FileName string + IsDirectory bool +} + +func (t *testFileInfo) Name() string { + return t.FileName +} +func (t *testFileInfo) Size() int64 { + return 0 +} +func (t *testFileInfo) Mode() fs.FileMode { + if t.IsDirectory { + return fs.ModeDir + } + return 0 +} +func (t *testFileInfo) ModTime() time.Time { + return time.Now() +} +func (t *testFileInfo) IsDir() bool { + return t.IsDirectory +} +func (t *testFileInfo) Sys() interface{} { + return nil +} + +var _ = (os.FileInfo)(&testFileInfo{}) + +func Test_InstanceIDToName(t *testing.T) { + ctx, _ := context.WithTimeout(context.Background(), 5*time.Second) + + vmBusGUID := "1111-2222-3333-4444" + testIfName := "test-eth0" + + vmbusWaitForDevicePath = func(_ context.Context, vmBusGUIDPattern string) (string, error) { + vmBusPath := filepath.Join("/sys/bus/vmbus/devices", vmBusGUIDPattern) + return vmBusPath, nil + } + + storageWaitForFileMatchingPattern = func(_ context.Context, pattern string) (string, error) { + return pattern, nil + } + + ioutilReadDir = func(dirname string) ([]os.FileInfo, error) { + info := &testFileInfo{ + FileName: testIfName, + IsDirectory: false, + } + return []fs.FileInfo{info}, nil + } + actualIfName, err := InstanceIDToName(ctx, vmBusGUID, false) + if err != nil { + t.Fatalf("expected no error, instead got %v", err) + } + if actualIfName != testIfName { + t.Fatalf("expected to get %v ifname, instead got %v", testIfName, actualIfName) + } +} + +func Test_InstanceIDToName_VPCI(t *testing.T) { + ctx, _ := context.WithTimeout(context.Background(), 5*time.Second) + + vmBusGUID := "1111-2222-3333-4444" + testIfName := "test-eth0-vpci" + + pciFindDeviceFullPath = func(_ context.Context, vmBusGUID string) (string, error) { + return filepath.Join("/sys/bus/vmbus/devices", vmBusGUID), nil + } + + storageWaitForFileMatchingPattern = func(_ context.Context, pattern string) (string, error) { + return pattern, nil + } + + ioutilReadDir = func(dirname string) ([]os.FileInfo, error) { + info := &testFileInfo{ + FileName: testIfName, + IsDirectory: false, + } + return []os.FileInfo{info}, nil + } + actualIfName, err := InstanceIDToName(ctx, vmBusGUID, true) + if err != nil { + t.Fatalf("expected no error, instead got %v", err) + } + if actualIfName != testIfName { + t.Fatalf("expected to get %v ifname, instead got %v", testIfName, actualIfName) + } +} diff --git a/internal/guest/prot/protocol.go b/internal/guest/prot/protocol.go index acec8a7806..4ae4272345 100644 --- a/internal/guest/prot/protocol.go +++ b/internal/guest/prot/protocol.go @@ -754,6 +754,7 @@ type NetworkAdapterV2 struct { DNSServerList string `json:",omitempty"` EnableLowMetric bool `json:",omitempty"` EncapOverhead uint16 `json:",omitempty"` + VPCIAssigned bool `json:",omitempty"` } // MappedVirtualDisk represents a disk on the host which is mapped into a diff --git a/internal/guest/runtime/hcsv2/network.go b/internal/guest/runtime/hcsv2/network.go index 5d1c1d3395..9ff739c47e 100644 --- a/internal/guest/runtime/hcsv2/network.go +++ b/internal/guest/runtime/hcsv2/network.go @@ -162,7 +162,7 @@ func (n *namespace) AddAdapter(ctx context.Context, adp *prot.NetworkAdapterV2) resolveCtx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - ifname, err := networkInstanceIDToName(resolveCtx, adp.ID) + ifname, err := networkInstanceIDToName(resolveCtx, adp.ID, adp.VPCIAssigned) if err != nil { return err } diff --git a/internal/guest/runtime/hcsv2/network_test.go b/internal/guest/runtime/hcsv2/network_test.go index 90f24aadd2..e73d256c4a 100644 --- a/internal/guest/runtime/hcsv2/network_test.go +++ b/internal/guest/runtime/hcsv2/network_test.go @@ -97,7 +97,7 @@ func Test_removeNetworkNamespace_HasAdapters(t *testing.T) { ns := getOrAddNetworkNamespace(t.Name()) - networkInstanceIDToName = func(ctx context.Context, id string) (string, error) { + networkInstanceIDToName = func(ctx context.Context, id string, _ bool) (string, error) { return "/dev/sdz", nil } err := ns.AddAdapter(context.Background(), &prot.NetworkAdapterV2{ID: "test"}) diff --git a/internal/guest/storage/pci/pci.go b/internal/guest/storage/pci/pci.go index e75b253186..f77222db95 100644 --- a/internal/guest/storage/pci/pci.go +++ b/internal/guest/storage/pci/pci.go @@ -24,16 +24,24 @@ func WaitForPCIDeviceFromVMBusGUID(ctx context.Context, vmBusGUID string) error // FindDeviceBusLocationFromVMBusGUID finds device bus location by // reading /sys/bus/vmbus/devices//... for pci specific directories func FindDeviceBusLocationFromVMBusGUID(ctx context.Context, vmBusGUID string) (string, error) { - pciDir, err := findVMBusPCIDir(ctx, vmBusGUID) + fullPath, err := FindDeviceFullPath(ctx, vmBusGUID) if err != nil { return "", err } - pciDeviceLocation, err := findVMBusPCIDevice(ctx, pciDir) + _, busFile := filepath.Split(fullPath) + return busFile, nil +} + +// FindDeviceFullPath finds the full PCI device path in the form of +// /sys/bus/vmbus/devices//pciXXXX:XX/XXXX:XX* +func FindDeviceFullPath(ctx context.Context, vmBusGUID string) (string, error) { + pciDir, err := findVMBusPCIDir(ctx, vmBusGUID) if err != nil { return "", err } - return pciDeviceLocation, nil + + return findVMBusPCIDevice(ctx, pciDir) } // findVMBusPCIDir waits for the pci bus directory matching pattern @@ -50,7 +58,6 @@ func findVMBusPCIDevice(ctx context.Context, pciDirFullPath string) (string, err // trim /sys/bus/vmbus/devices//pciXXXX:XX to XXXX:XX _, pciDirName := filepath.Split(pciDirFullPath) busPrefix := strings.TrimPrefix(pciDirName, "pci") - // under /sys/bus/vmbus/devices//pciXXXX:XX/ look for directory matching XXXX:XX* pattern busPathPattern := filepath.Join(pciDirFullPath, fmt.Sprintf("%s*", busPrefix)) busFileFullPath, err := storageWaitForFileMatchingPattern(ctx, busPathPattern) @@ -58,7 +65,5 @@ func findVMBusPCIDevice(ctx context.Context, pciDirFullPath string) (string, err return "", err } - // return the resulting XXXX:XX:YY.Y pci bus location - _, busFile := filepath.Split(busFileFullPath) - return busFile, nil + return busFileFullPath, nil } diff --git a/internal/guestrequest/types.go b/internal/guestrequest/types.go index 6d81dc11c8..e3e09f5201 100644 --- a/internal/guestrequest/types.go +++ b/internal/guestrequest/types.go @@ -102,6 +102,7 @@ type LCOWNetworkAdapter struct { DNSServerList string `json:",omitempty"` EnableLowMetric bool `json:",omitempty"` EncapOverhead uint16 `json:",omitempty"` + VPCIAssigned bool `json:",omitempty"` } type LCOWContainerConstraints struct { diff --git a/test/vendor/github.com/Microsoft/hcsshim/internal/guestrequest/types.go b/test/vendor/github.com/Microsoft/hcsshim/internal/guestrequest/types.go index 6d81dc11c8..e3e09f5201 100644 --- a/test/vendor/github.com/Microsoft/hcsshim/internal/guestrequest/types.go +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/guestrequest/types.go @@ -102,6 +102,7 @@ type LCOWNetworkAdapter struct { DNSServerList string `json:",omitempty"` EnableLowMetric bool `json:",omitempty"` EncapOverhead uint16 `json:",omitempty"` + VPCIAssigned bool `json:",omitempty"` } type LCOWContainerConstraints struct {