From f43b56c0226c715c05a08bee797bc82863856271 Mon Sep 17 00:00:00 2001
From: PabloNA97
Date: Fri, 27 Sep 2024 17:49:28 +0200
Subject: [PATCH] fix: avoid nvidia-smi if it will fail
Avoid the GPU query when running in a SLURM environment without a GPU. If the node does not have NVIDIA drivers installed, this query will fail and the process will get killed.
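As a rough illustration of the intended behavior (a minimal standalone sketch, not
the patched code itself; the helper name gpu_query_is_safe is hypothetical), the
guard amounts to:

    import os

    def gpu_query_is_safe() -> bool:
        # Hypothetical sketch: only skip nvidia-smi when we are inside a
        # SLURM job (SLURM_JOB_ID set) that exposes no GPU-related variables.
        in_slurm = bool(os.environ.get("SLURM_JOB_ID"))
        gpu_visible = any(
            os.environ.get(var)
            for var in ("SLURM_JOB_GPUS", "SLURM_GPUS", "CUDA_VISIBLE_DEVICES")
        )
        return not in_slurm or gpu_visible

In the patch below this check is split into two helpers, _slurm_environment() and
_check_slurm_gpu_info(), which _get_gpu_info() consults before calling nvidia-smi.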
---
openfe/utils/system_probe.py | 55 ++++++++++++++++++++++++++++++++++++
1 file changed, 55 insertions(+)
diff --git a/openfe/utils/system_probe.py b/openfe/utils/system_probe.py
index 483d3f2dd..7b9aa869d 100644
--- a/openfe/utils/system_probe.py
+++ b/openfe/utils/system_probe.py
@@ -273,6 +273,55 @@ def _get_hostname() -> str:
return socket.gethostname()
+def _slurm_environment() -> bool:
+    """
+    Check whether the current process is running inside a SLURM job.
+
+    Returns
+    -------
+    bool
+        True if the SLURM_JOB_ID environment variable is set, False otherwise.
+    """
+
+    return bool(os.environ.get("SLURM_JOB_ID"))
+
+
+def _check_slurm_gpu_info() -> bool:
+    """
+    Check whether GPU information is available in the SLURM environment.
+
+    Returns
+    -------
+    bool
+        True if GPU information is available in the SLURM environment, False
+        otherwise.
+
+    Notes
+    -----
+    This function checks whether GPU information is exposed to the SLURM job
+    by inspecting environment variables.
+
+    It returns True if any of the following variables is set to a non-empty value:
+    - 'SLURM_JOB_GPUS'
+    - 'SLURM_GPUS'
+    - 'CUDA_VISIBLE_DEVICES'
+
+    Otherwise, it returns False.
+    """
+
+    slurm_job_gpus = os.environ.get("SLURM_JOB_GPUS")
+    slurm_gpus = os.environ.get("SLURM_GPUS")
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+
+    logging.debug(f"SLURM_JOB_GPUS: {slurm_job_gpus}")
+    logging.debug(f"SLURM_GPUS: {slurm_gpus}")
+    logging.debug(f"CUDA_VISIBLE_DEVICES: {cuda_visible_devices}")
+
+    if slurm_job_gpus or slurm_gpus or cuda_visible_devices:
+        return True
+    else:
+        return False
+
def _get_gpu_info() -> dict[str, dict[str, str]]:
"""
Get GPU information using the 'nvidia-smi' command-line utility.
@@ -336,6 +385,12 @@ def _get_gpu_info() -> dict[str, dict[str, str]]:
"utilization.memory,memory.total,driver_version,"
)
+    if _slurm_environment() and not _check_slurm_gpu_info():
+        logging.debug(
+            "SLURM environment without GPU detected; skipping nvidia-smi query."
+        )
+        return {}
+
try:
nvidia_smi_output = subprocess.check_output(
["nvidia-smi", GPU_QUERY, "--format=csv"]