Select the CUDA apt repository directory from the host architecture.

The nv-driver template previously downloaded cuda-keyring from a URL with a
hardcoded "x86_64" path segment. The script now sets CUDA_ARCH from `uname -m`,
rewrites "aarch64" to "sbsa", and uses ${CUDA_ARCH} in the download URL.
TestNVDriverTemplate_CUDARepoArch renders the template and asserts that the
output contains the detection, the mapping and the parameterized URL, and no
longer contains the hardcoded x86_64 path.

NOTE(review): the line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch; confirm the context lines against the
target files before applying.

diff --git a/pkg/provisioner/templates/nv-driver.go b/pkg/provisioner/templates/nv-driver.go
index be91861fe..48531a48f 100644
--- a/pkg/provisioner/templates/nv-driver.go
+++ b/pkg/provisioner/templates/nv-driver.go
@@ -113,8 +113,13 @@ holodeck_progress "$COMPONENT" 3 5 "Adding CUDA repository"
 if [[ ! -f /etc/apt/sources.list.d/cuda*.list ]] || \
    [[ ! -f /usr/share/keyrings/cuda-archive-keyring.gpg ]]; then
     distribution=$(. /etc/os-release; echo "${ID}${VERSION_ID}" | sed -e 's/\.//g')
+    # Determine CUDA repo architecture (NVIDIA uses "sbsa" for arm64 servers)
+    CUDA_ARCH="$(uname -m)"
+    if [[ "$CUDA_ARCH" == "aarch64" ]]; then
+        CUDA_ARCH="sbsa"
+    fi
     holodeck_retry 3 "$COMPONENT" wget -q \
-        "https://developer.download.nvidia.com/compute/cuda/repos/$distribution/x86_64/cuda-keyring_1.1-1_all.deb"
+        "https://developer.download.nvidia.com/compute/cuda/repos/$distribution/${CUDA_ARCH}/cuda-keyring_1.1-1_all.deb"
     sudo dpkg -i cuda-keyring_1.1-1_all.deb
     rm -f cuda-keyring_1.1-1_all.deb
     holodeck_retry 3 "$COMPONENT" sudo apt-get update
diff --git a/pkg/provisioner/templates/nv-driver_test.go b/pkg/provisioner/templates/nv-driver_test.go
index d03e6d523..bf1a0315d 100644
--- a/pkg/provisioner/templates/nv-driver_test.go
+++ b/pkg/provisioner/templates/nv-driver_test.go
@@ -178,3 +178,33 @@ func TestNVDriverTemplate(t *testing.T) {
 		})
 	}
 }
+
+func TestNVDriverTemplate_CUDARepoArch(t *testing.T) {
+	driver := &NvDriver{
+		Branch: defaultNVBranch,
+	}
+
+	var output bytes.Buffer
+	err := driver.Execute(&output, v1alpha1.Environment{})
+	require.NoError(t, err)
+
+	outStr := output.String()
+
+	// Must NOT contain hardcoded x86_64 in the CUDA repo URL
+	require.NotContains(t, outStr, "cuda/repos/$distribution/x86_64/",
+		"Template must not hardcode x86_64 in the CUDA repository URL")
+
+	// Must contain runtime architecture detection
+	require.Contains(t, outStr, `CUDA_ARCH="$(uname -m)"`,
+		"Template must detect architecture at runtime via uname -m")
+
+	// Must contain aarch64 -> sbsa mapping
+	require.Contains(t, outStr, `if [[ "$CUDA_ARCH" == "aarch64" ]]; then`,
+		"Template must check for aarch64 architecture")
+	require.Contains(t, outStr, `CUDA_ARCH="sbsa"`,
+		"Template must map aarch64 to sbsa for NVIDIA CUDA repos")
+
+	// Must use CUDA_ARCH variable in the wget URL
+	require.Contains(t, outStr, "${CUDA_ARCH}/cuda-keyring",
+		"Template must use CUDA_ARCH variable in the wget URL")
+}