diff --git a/pkg/inventory/db.go b/pkg/inventory/db.go index 467005c9..70320ec3 100644 --- a/pkg/inventory/db.go +++ b/pkg/inventory/db.go @@ -295,7 +295,7 @@ func (db *DB) netdevToDRAdev(link netlink.Link) (*resourceapi.Device, error) { func addPCIAttributes(device *resourceapi.Device, ifName string, path string) { device.Attributes["dra.net/virtual"] = resourceapi.DeviceAttribute{BoolValue: ptr.To(false)} - address, err := bdfAddress(ifName, path) + address, err := pciAddressForNetInterface(ifName) if err != nil { klog.Infof("Could not get bdf address : %v", err) } else { diff --git a/pkg/inventory/sysfs.go b/pkg/inventory/sysfs.go index 1c287b1e..02ab2867 100644 --- a/pkg/inventory/sysfs.go +++ b/pkg/inventory/sysfs.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "path/filepath" + "regexp" "strconv" "strings" @@ -40,6 +41,10 @@ const ( sysdevPath = "/sys/devices" ) +// pciAddressRegex is used to identify a PCI address within a string. +// It matches patterns like "0000:00:04.0" or "00:04.0". +var pciAddressRegex = regexp.MustCompile(`^(?:([0-9a-fA-F]{4}):)?([0-9a-fA-F]{2}):([0-9a-fA-F]{2})\.([0-9a-fA-F])$`) + func realpath(ifName string, syspath string) string { linkPath := filepath.Join(syspath, ifName) dst, err := os.Readlink(linkPath) @@ -141,42 +146,58 @@ type pciRoot struct { bus string } -func bdfAddress(ifName string, path string) (*pciAddress, error) { +// parsePCIAddress takes a string and attempts to extract and parse a PCI address from it. +func parsePCIAddress(s string) (*pciAddress, error) { + matches := pciAddressRegex.FindStringSubmatch(s) + if matches == nil { + return nil, fmt.Errorf("could not find PCI address in string: %s", s) + } address := &pciAddress{} - // https://docs.kernel.org/PCI/sysfs-pci.html - // realpath /sys/class/net/ens4/device - // /sys/devices/pci0000:00/0000:00:04.0/virtio1 - // The topmost element describes the PCI domain and bus number. - // PCI domain: 0000 Bus: 00 Device: 04 Function: 0 - sysfsPath := realpath(ifName, path) - bfd := strings.Split(sysfsPath, "/") - if len(bfd) < 5 { - return nil, fmt.Errorf("could not find corresponding PCI address: %v", bfd) - } - - klog.V(4).Infof("pci address: %s", bfd[4]) - pci := strings.Split(bfd[4], ":") - // Simple BDF notation - switch len(pci) { - case 2: - address.bus = pci[0] - f := strings.Split(pci[1], ".") - if len(f) != 2 { - return nil, fmt.Errorf("could not find corresponding PCI device and function: %v", pci) - } - address.device = f[0] - address.function = f[1] - case 3: - address.domain = pci[0] - address.bus = pci[1] - f := strings.Split(pci[2], ".") - if len(f) != 2 { - return nil, fmt.Errorf("could not find corresponding PCI device and function: %v", pci) + + // When pciAddressRegex matches, it is expected to return 5 elements. (First + // is the complete matched string itself, and the next 4 are the submatches + // corresponding to Domain:Bus:Device.Function). Examples: + // - "0000:00:04.0" -> ["0000:00:04.0" "0000" "00" "04" "0"] + // - "00:05.0" -> ["0000:00:05.0" "" "00" "05" "0"] + if len(matches) == 5 { + address.domain = matches[1] + address.bus = matches[2] + address.device = matches[3] + address.function = matches[4] + } else { + return nil, fmt.Errorf("invalid PCI address format: %s", s) + } + + return address, nil +} + +// pciAddressFromPath takes a full sysfs path and traverses it upwards to find +// the first component that contains a valid PCI address. +func pciAddressFromPath(path string) (*pciAddress, error) { + parts := strings.Split(path, "/") + for len(parts) > 0 { + current := parts[len(parts)-1] + addr, err := parsePCIAddress(current) + if err == nil { + return addr, nil } - address.device = f[0] - address.function = f[1] - default: - return nil, fmt.Errorf("could not find corresponding PCI address: %v", pci) + parts = parts[:len(parts)-1] + } + return nil, fmt.Errorf("could not find PCI address in path: %s", path) +} + +// pciAddressForNetInterface finds the PCI address for a given network interface name. +func pciAddressForNetInterface(ifName string) (*pciAddress, error) { + // First, find the absolute path of the device in the sysfs, which typically + // looks like: + // /sys/devices/pci0000:8c/0000:8c:00.0/0000:8d:00.0/0000:8e:02.0/0000:91:00.0/net/eth0 + // Then, use pciAddressFromPath() to traverse the path upwards, checking + // each component to find the first one that matches the format of a PCI + // address. + sysfsPath := realpath(ifName, sysnetPath) + address, err := pciAddressFromPath(sysfsPath) + if err != nil { + return nil, fmt.Errorf("could not find PCI address for interface %q: %w", ifName, err) } return address, nil } diff --git a/pkg/inventory/sysfs_test.go b/pkg/inventory/sysfs_test.go new file mode 100644 index 00000000..d0d6ae76 --- /dev/null +++ b/pkg/inventory/sysfs_test.go @@ -0,0 +1,138 @@ +/* +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package inventory + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestParsePCIAddress(t *testing.T) { + testCases := []struct { + name string + input string + want *pciAddress + wantErr bool + }{ + { + name: "valid with domain", + input: "0000:00:04.0", + want: &pciAddress{ + domain: "0000", + bus: "00", + device: "04", + function: "0", + }, + wantErr: false, + }, + { + name: "valid without domain", + input: "00:04.0", + want: &pciAddress{ + domain: "", + bus: "00", + device: "04", + function: "0", + }, + wantErr: false, + }, + { + name: "invalid format", + input: "not-a-pci-address", + wantErr: true, + }, + { + name: "empty string", + input: "", + wantErr: true, + }, + { + name: "embedded in string", + input: "pci-0000:8c:00.0-device", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := parsePCIAddress(tc.input) + if (err != nil) != tc.wantErr { + t.Fatalf("pciAddressFromString() error = %v, wantErr %v", err, tc.wantErr) + return + } + if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(pciAddress{})); diff != "" { + t.Errorf("pciAddressFromString() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestPCIAddressFromPath(t *testing.T) { + testCases := []struct { + name string + input string + want *pciAddress + wantErr bool + }{ + { + name: "simple path", + input: "/sys/devices/pci0000:00/0000:00:04.0/virtio1/net/eth0", + want: &pciAddress{ + domain: "0000", + bus: "00", + device: "04", + function: "0", + }, + wantErr: false, + }, + { + name: "hierarchical path", + input: "/sys/devices/pci0000:8c/0000:8c:00.0/0000:8d:00.0/0000:8e:02.0/0000:91:00.0/net/eth3", + want: &pciAddress{ + domain: "0000", + bus: "91", + device: "00", + function: "0", + }, + wantErr: false, + }, + { + name: "no pci address in path", + input: "/sys/devices/virtual/net/lo", + wantErr: true, + }, + { + name: "empty path", + input: "", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := pciAddressFromPath(tc.input) + if (err != nil) != tc.wantErr { + t.Errorf("pciAddressFromPath() error = %v, wantErr %v", err, tc.wantErr) + return + } + if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(pciAddress{})); diff != "" { + t.Errorf("pciAddressFromPath() mismatch (-want +got):\n%s", diff) + } + }) + } +}