Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 50 additions & 38 deletions engibench/utils/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Sequence
import os
import subprocess
import tempfile


def pull(image: str) -> None:
Expand Down Expand Up @@ -217,40 +218,59 @@ def is_available(cls) -> bool:
return False


class Singularity(ContainerRuntime):
"""Singularity / Apptainer."""
DOCKER_PREFIX = "docker://"

name = "singularity"
executable = "singularity"

class Apptainer(ContainerRuntime):
"""Apptainer."""

name = "apptainer"
executable = "apptainer"

@classmethod
def pull(cls, image: str) -> None:
"""Pull an image.
def _set_apptainer_env(cls) -> None:
"""Set Apptainer environment variables."""
# See https://scicomp.ethz.ch/wiki/Apptainer#Settings
# Set cache directory to SCRATCH if available, otherwise use default
scratch_dir = os.environ.get("SCRATCH")
if scratch_dir:
# stores apptainer images in your $SCRATCH directory
os.environ["APPTAINER_CACHEDIR"] = f"{scratch_dir}/.apptainer"

# uses the local temporary directory to store temporary data when building images
os.environ["APPTAINER_TMPDIR"] = os.environ.get("TMPDIR", tempfile.gettempdir())

Args:
image: Container image to pull.
"""
# Convert to docker URI if needed
if "://" not in image:
docker_uri = "docker://" + image
else:
docker_uri = image
# Extract just the image part if it's already a docker URI
if docker_uri.startswith("docker://"):
image = docker_uri[len("docker://") :]
@classmethod
def sif_filename(cls, image: str) -> str:
"""Construct the sif filename from an image specifier."""
# Extract just the image part if it's a docker URI
image = image.removeprefix(DOCKER_PREFIX)

# Parse the image name to match Singularity's naming convention
# For "mdolab/public:u22-gcc-ompi-stable", Singularity creates "public_u22-gcc-ompi-stable.sif"
image_name = image.split("/")[-1] if "/" in image else image
image_name = image.rsplit("/", 1)[-1] if "/" in image else image

# Replace ":" with "_" in the image name
sif_filename = image_name.replace(":", "_") + ".sif"
return image_name.replace(":", "_") + ".sif"

@classmethod
def pull(cls, image: str) -> None:
"""Pull an image.

Args:
image: Container image to pull.
"""
# Set Apptainer environment variables
cls._set_apptainer_env()
# Get sif filename
sif_filename = cls.sif_filename(image)

# Check if the image already exists
if os.path.exists(sif_filename):
print(f"Image file already exists: {sif_filename} - skipping pull")
return

# Convert to docker URI if needed
docker_uri = DOCKER_PREFIX + image if "://" not in image else image
# Image doesn't exist, proceed with pull
subprocess.run([cls.executable, "pull", docker_uri], check=True)

Expand All @@ -272,32 +292,24 @@ def run(
env: Mapping of environment variable names and values to set inside the container.
name: Optional name for the container (not supported by all runtimes).
"""
# Create a mutable working copy to add required system mounts
working_mounts = list(mounts)

# HPC/Singularity containers require explicit /tmp mounting to prevent memory issues
# and ensure application compatibility. This is container configuration, not insecure temp file creation.
if working_mounts: # Only add /tmp mount if we have existing mounts
# Use the first mount's host path for /tmp (existing logic)
tmp_host_path = working_mounts[0][0]
working_mounts.append((tmp_host_path, "/tmp")) # noqa: S108
else:
# Handle the empty mounts case - perhaps use a default temp directory
# or skip the /tmp mount altogether
pass

mount_args = (["--mount", f"type=bind,src={src},target={target}"] for src, target in working_mounts)
# Set Apptainer environment variables
cls._set_apptainer_env()

# Get sif filename
sif_image = cls.sif_filename(image)

# Reconstruct mount and env args
mount_args = (["--mount", f"type=bind,src={src},target={target}"] for src, target in mounts)
env_args = (["--env", f"{var}={value}"] for var, value in (env or {}).items())
if "://" not in image:
image = "docker://" + image

return subprocess.run(
[
cls.executable,
"run",
"--compat",
*(arg for args in mount_args for arg in args),
*(arg for args in env_args for arg in args),
image,
sif_image,
*command,
],
check=False,
Expand Down
Loading