From f43b56c0226c715c05a08bee797bc82863856271 Mon Sep 17 00:00:00 2001
From: PabloNA97
Date: Fri, 27 Sep 2024 17:49:28 +0200
Subject: [PATCH] fix: avoid nvidia-smi if it will fail

Avoid the GPU query when running in a SLURM environment without a GPU.
If the node does not have NVIDIA drivers, this query will fail and the
process will get killed.
---
 openfe/utils/system_probe.py | 55 ++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/openfe/utils/system_probe.py b/openfe/utils/system_probe.py
index 483d3f2dd..7b9aa869d 100644
--- a/openfe/utils/system_probe.py
+++ b/openfe/utils/system_probe.py
@@ -273,6 +273,55 @@ def _get_hostname() -> str:
     return socket.gethostname()
 
 
+def _slurm_environment() -> bool:
+    """
+    Check if the current environment is managed by SLURM.
+    """
+
+    slurm_job_id = os.environ.get("SLURM_JOB_ID")
+
+    if slurm_job_id:
+        return True
+    else:
+        return False
+
+
+def _check_slurm_gpu_info() -> bool:
+    """
+    Check if GPU information is available in the SLURM environment.
+
+    Returns
+    -------
+    bool
+        True if GPU information is available in the SLURM environment, False
+        otherwise.
+
+    Notes
+    -----
+    This function checks whether GPU information is available in the SLURM
+    environment by inspecting environment variables.
+
+    It returns True if any of the following environment variables are present:
+    - 'SLURM_JOB_GPUS'
+    - 'SLURM_GPUS'
+    - 'CUDA_VISIBLE_DEVICES'
+
+    Otherwise, it returns False.
+    """
+
+    slurm_job_gpus = os.environ.get("SLURM_JOB_GPUS")
+    slurm_gpus = os.environ.get("SLURM_GPUS")
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+
+    logging.debug(f"SLURM_JOB_GPUS: {slurm_job_gpus}")
+    logging.debug(f"SLURM_GPUS: {slurm_gpus}")
+    logging.debug(f"CUDA_VISIBLE_DEVICES: {cuda_visible_devices}")
+
+    if slurm_job_gpus or slurm_gpus or cuda_visible_devices:
+        return True
+    else:
+        return False
+
 def _get_gpu_info() -> dict[str, dict[str, str]]:
     """
     Get GPU information using the 'nvidia-smi' command-line utility.
@@ -336,6 +385,12 @@ def _get_gpu_info() -> dict[str, dict[str, str]]:
         "utilization.memory,memory.total,driver_version,"
     )
 
+    if _slurm_environment() and not _check_slurm_gpu_info():
+        logging.debug(
+            "SLURM environment detected, but GPU information is not available."
+        )
+        return {}
+
     try:
         nvidia_smi_output = subprocess.check_output(
             ["nvidia-smi", GPU_QUERY, "--format=csv"]
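
For reviewers: a minimal sketch (not part of the patch) of how the new guard behaves on a CPU-only SLURM node. It assumes the patched openfe.utils.system_probe module is importable; the job ID value is made up for illustration.

# Simulate a SLURM allocation with no GPU: SLURM_JOB_ID is set, but none of
# the GPU-related variables the new check looks for are present.
import os

from openfe.utils.system_probe import _get_gpu_info

os.environ["SLURM_JOB_ID"] = "12345"  # hypothetical job ID
for var in ("SLURM_JOB_GPUS", "SLURM_GPUS", "CUDA_VISIBLE_DEVICES"):
    os.environ.pop(var, None)

# With the patch applied, _get_gpu_info() returns {} without running
# nvidia-smi, so nodes lacking NVIDIA drivers no longer kill the process.
print(_get_gpu_info())  # -> {}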