diff --git a/contrib/cdisetup/nvidia/nvidia.go b/contrib/cdisetup/nvidia/nvidia.go index 9a4d32bab84c..88b70b4196cb 100644 --- a/contrib/cdisetup/nvidia/nvidia.go +++ b/contrib/cdisetup/nvidia/nvidia.go @@ -25,10 +25,13 @@ import ( // This is example of experimental on-demand setup of a CDI devices. // This code is not currently shipping with BuildKit and will probably change. -const ( - cdiKind = "nvidia.com/gpu" - defaultVersion = "570.0" -) +const cdiKind = "nvidia.com/gpu" + +// https://github.com/ollama/ollama/blob/b816ff86c923e0290f58f2275e831fc17c29ba37/discover/gpu_linux.go#L33-L43 +var libcudaGlobs = []string{ + "/usr/lib/*-linux-gnu/libcuda.so*", + "/usr/lib/wsl/drivers/*/libcuda.so*", +} func init() { cdidevices.Register(cdiKind, &setup{}) @@ -39,8 +42,7 @@ type setup struct{} var _ cdidevices.Setup = &setup{} func (s *setup) Validate() error { - _, err := readVersion() - if err == nil { + if _, err := readVersion(); err == nil { return nil } b, err := hasNvidiaDevices() @@ -93,55 +95,94 @@ func (s *setup) Run(ctx context.Context) (err error) { return errors.Errorf("NVIDIA setup is currently only supported on Debian/Ubuntu") } - var needsDriver bool - - if _, err := os.Stat("/proc/driver/nvidia"); err != nil { - needsDriver = true + needsDriver := true + if _, err := os.Stat("/proc/driver/nvidia"); err == nil { + needsDriver = false + } else if nvidiaSmi, err := exec.LookPath("nvidia-smi"); err == nil && nvidiaSmi != "" { + if err := run(ctx, []string{nvidiaSmi, "-L"}, pw, dgst); err == nil { + needsDriver = false + } + } + if needsDriver { + if hasWSLGPU() { + return errors.Errorf("NVIDIA drivers are required for WSL with non PCI-based GPUs") + } + return errors.Errorf("NVIDIA drivers are required. Try loading NVIDIA kernel module with \"modprobe nvidia\" command") } - var arch string - switch runtime.GOARCH { - case "amd64": - arch = "x86_64" - case "arm64": - arch = "sbsa" - // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb + var dv string + if !hasLibsInstalled() && !hasWSLGPU() { + version, err := readVersion() + if err != nil { + return errors.Wrapf(err, "failed to read NVIDIA driver version") + } + var ok bool + dv, _, ok = strings.Cut(version, ".") + if !ok { + return errors.Errorf("failed to parse NVIDIA driver version %q", version) + } } - if arch == "" { - return errors.Errorf("unsupported architecture: %s", runtime.GOARCH) + if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil { + return err } - if needsDriver { - pw.Write(identity.NewID(), client.VertexWarning{ - Vertex: dgst, - Short: []byte("NVIDIA Drivers not found. Installing prebuilt drivers is not recommended"), - }) + if err := run(ctx, []string{"apt-get", "install", "-y", "gpg"}, pw, dgst); err != nil { + return err } - version, err := readVersion() - if err != nil && !needsDriver { - return errors.Wrapf(err, "failed to read NVIDIA driver version") + if err := installPackages(ctx, dv, pw, dgst); err != nil { + return err } - if version == "" { - version = defaultVersion + + if err := os.MkdirAll("/etc/cdi", 0700); err != nil { + return errors.Wrapf(err, "failed to create /etc/cdi") } - v1, _, ok := strings.Cut(version, ".") - if !ok { - return errors.Errorf("failed to parse NVIDIA driver version %q", version) + + buf := &bytes.Buffer{} + + cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate") + cmd.Stdout = buf + cmd.Stderr = newStream(pw, 2, dgst) + if err := cmd.Run(); err != nil { + return errors.Wrapf(err, "failed to generate CDI spec") } - if err := run(ctx, []string{"apt-get", "update"}, pw, dgst); err != nil { - return err + if len(buf.Bytes()) == 0 { + return errors.Errorf("nvidia-ctk output is empty") } - if err := run(ctx, []string{"apt-get", "install", "-y", "gpg"}, pw, dgst); err != nil { - return err + if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil { + return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml") } + return nil +} + +func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error { + fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " ")) + cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec + cmd.Stderr = newStream(pw, 2, dgst) + cmd.Stdout = newStream(pw, 1, dgst) + return cmd.Run() +} + +func installPackages(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error { const aptDistro = "ubuntu2404" - aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/" + var arch string + switch runtime.GOARCH { + case "amd64": + arch = "x86_64" + case "arm64": + arch = "sbsa" + // for non-sbsa could use https://nvidia.github.io/libnvidia-container/stable/deb + } + if arch == "" { + return errors.Errorf("unsupported architecture: %s", runtime.GOARCH) + } + + aptURL := "https://developer.download.nvidia.com/compute/cuda/repos/" + aptDistro + "/" + arch + "/" keyTarget := "/usr/share/keyrings/nvidia-cuda-keyring.gpg" if _, err := os.Stat(keyTarget); err != nil { @@ -174,59 +215,17 @@ func (s *setup) Run(ctx context.Context) (err error) { return err } - if needsDriver { - // this pretty much never works, is it even worth having? - // better approach could be to try to create another chroot/container that is built with same kernel packages as the host - // could nvidia-headless-no-dkms- be reusable - if err := run(ctx, []string{"apt-get", "install", "-y", "nvidia-driver-" + v1}, pw, dgst); err != nil { - return err - } - _, err := os.Stat("/proc/driver/nvidia") - if err != nil { - return errors.Wrapf(err, "failed to install NVIDIA kernel module. Please install NVIDIA drivers manually") - } - } - - if err := run(ctx, []string{"apt-get", "install", "-y", "--no-install-recommends", - "libnvidia-compute-" + v1, - "libnvidia-extra-" + v1, - "libnvidia-gl-" + v1, - "nvidia-utils-" + v1, - "nvidia-container-toolkit-base", - }, pw, dgst); err != nil { - return err - } - - if err := os.MkdirAll("/etc/cdi", 0700); err != nil { - return errors.Wrapf(err, "failed to create /etc/cdi") + pkgs := []string{"nvidia-container-toolkit-base"} + if dv != "" { + pkgs = append(pkgs, []string{ + "libnvidia-compute-" + dv, + "libnvidia-extra-" + dv, + "libnvidia-gl-" + dv, + "nvidia-utils-" + dv, + }...) } - buf := &bytes.Buffer{} - - cmd := exec.CommandContext(ctx, "nvidia-ctk", "cdi", "generate") - cmd.Stdout = buf - cmd.Stderr = newStream(pw, 2, dgst) - if err := cmd.Run(); err != nil { - return errors.Wrapf(err, "failed to generate CDI spec") - } - - if len(buf.Bytes()) == 0 { - return errors.Errorf("nvidia-ctk output is empty") - } - - if err := os.WriteFile("/etc/cdi/nvidia.yaml", buf.Bytes(), 0644); err != nil { - return errors.Wrapf(err, "failed to write /etc/cdi/nvidia.yaml") - } - - return nil -} - -func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Digest) error { - fmt.Fprintf(newStream(pw, 2, dgst), "> %s\n", strings.Join(args, " ")) - cmd := exec.CommandContext(ctx, args[0], args[1:]...) //nolint:gosec - cmd.Stderr = newStream(pw, 2, dgst) - cmd.Stdout = newStream(pw, 1, dgst) - return cmd.Run() + return run(ctx, append([]string{"apt-get", "install", "-y", "--no-install-recommends"}, pkgs...), pw, dgst) } func readVersion() (string, error) { @@ -268,6 +267,10 @@ func hasNvidiaDevices() (bool, error) { } } + if !found { + found = hasWSLGPU() + } + return found, nil } @@ -302,3 +305,19 @@ func isDebianOrUbuntu() (bool, error) { return id == "debian" || id == "ubuntu", nil } + +func hasWSLGPU() bool { + // WSL-specific GPU mapping that doesn't expose PCI info. + _, err := os.Stat("/dev/dxg") + return err == nil +} + +func hasLibsInstalled() bool { + // Check for libcuda in the standard locations to confirm NVIDIA GPU drivers + for _, p := range libcudaGlobs { + if matches, err := filepath.Glob(p); err == nil && len(matches) > 0 { + return true + } + } + return false +}