Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cuda_bindings/tests/cufile.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
// NOTE : Application can override custom configuration via export CUFILE_ENV_PATH_JSON=<filepath>
// e.g : export CUFILE_ENV_PATH_JSON="/home/<xxx>/cufile.json"


"execution" : {
// max number of work items in the queue
"max_io_queue_depth": 128,
// max number of host threads per gpu to spawn for parallel IO
"max_io_threads" : 4,
// enable support for parallel IO
"parallel_io" : true,
// minimum IO threshold before splitting the IO
"min_io_threshold_size_kb" : 8192,
// maximum parallelism for a single request
"max_request_parallelism" : 4
}
}
80 changes: 62 additions & 18 deletions cuda_bindings/tests/test_cufile.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,99 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import ctypes
import errno
import logging
import os
import pathlib
import platform
import tempfile
from contextlib import suppress
from functools import cache

import pytest

import cuda.bindings.driver as cuda

# Configure the root logger to show INFO level and above, using a compact
# "LEVEL: message" line format for every record emitted by this module.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s: %(message)s",
    # force=True removes any handlers already attached to the root logger
    # before applying this configuration, so these settings always take effect.
    force=True,  # Override any existing logging configuration
)

try:
from cuda.bindings import cufile
except ImportError:
cufile = None


def platform_is_wsl():
    """Check if running on Windows Subsystem for Linux (WSL).

    Returns:
        bool: True when ``platform.system()`` reports "Linux" and the text of
        /proc/version contains "microsoft" (case-insensitive). False otherwise,
        including when /proc/version cannot be read.
    """
    if platform.system() != "Linux":
        return False
    try:
        kernel_banner = pathlib.Path("/proc/version").read_text()
    except OSError:
        # This helper is called at module import time; an unreadable
        # /proc/version must not abort collection of the whole test module.
        return False
    return "microsoft" in kernel_banner.lower()


# Skip the whole module when the cufile bindings could not be imported. The
# reason names the actual condition checked (the import failed) rather than
# assuming a particular platform.
if cufile is None:
    pytest.skip(
        "skipping cuFile tests: cuda.bindings.cufile is not available (e.g. on Windows)",
        allow_module_level=True,
    )

# cuFile tests are not run under Windows Subsystem for Linux.
if platform_is_wsl():
    pytest.skip("skipping cuFile tests on WSL", allow_module_level=True)


@pytest.fixture(scope="module")
def cufile_env_json():
    """Point CUFILE_ENV_PATH_JSON at a cuFile config while the fixture is active.

    The system-wide /etc/cufile.json is preferred; when it does not exist, the
    cufile.json located in the same directory as this test file is used. After
    the yield, the variable is restored to its previous value, or removed if it
    was not set beforehand.
    """
    env_key = "CUFILE_ENV_PATH_JSON"
    saved = os.environ.get(env_key)

    system_config = "/etc/cufile.json"
    if os.path.exists(system_config):
        config_path = system_config
    else:
        # Fall back to the config shipped alongside this test file.
        here = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(here, "cufile.json")

    logging.info(f"Using cuFile config: {config_path}")
    os.environ[env_key] = config_path
    yield
    # Teardown: put the environment back the way it was found.
    if saved is None:
        os.environ.pop(env_key, None)
    else:
        os.environ[env_key] = saved


@cache
def cufileLibraryAvailable():
    """Check if cuFile library is available on the system.

    The result is memoized with ``functools.cache``, so the library is probed
    (and the outcome logged) only once per process.

    Returns:
        bool: True when ``cufile.get_version()`` succeeds, False when it raises
        any exception.
    """
    try:
        # Try to get cuFile library version - this will fail if library is not available
        version = cufile.get_version()
    except Exception as e:
        # Deliberately broad: any failure to query the version means "unavailable".
        logging.warning(f"cuFile library not available: {e}")
        return False
    # Report through logging only; emitting the same message via print as well
    # would duplicate every line in the test output.
    logging.info(f"cuFile library available, version: {version}")
    return True


@cache
def cufileVersionLessThan(target):
    """Check if cuFile library version is less than target version.

    Results are memoized per ``target`` with ``functools.cache``.

    Args:
        target: Version to compare against; it is compared with ``<`` to the
            value returned by ``cufile.get_version()`` and must be hashable
            because it is used as a cache key.

    Returns:
        bool: True when the installed version is lower than ``target`` or when
        the version cannot be determined; False otherwise.
    """
    try:
        # Get cuFile library version
        version = cufile.get_version()
        # Report through logging only; emitting the same message via print as
        # well would duplicate every line in the test output.
        logging.info(f"cuFile library version: {version}")
        # Check if version is less than target
        if version < target:
            logging.warning(f"cuFile library version {version} is less than required {target}")
            return True
        return False
    except Exception as e:
        logging.error(f"Error checking cuFile version: {e}")
        return True  # Assume old version if any error occurs


@cache
def isSupportedFilesystem():
"""Check if the current filesystem is supported (ext4 or xfs)."""
try:
Expand All @@ -65,14 +109,14 @@ def isSupportedFilesystem():
current_dir = os.path.abspath(".")
if current_dir.startswith(mount_point):
fs_type_lower = fs_type.lower()
print(f"Current filesystem type: {fs_type_lower}")
logging.info(f"Current filesystem type: {fs_type_lower}")
return fs_type_lower in ["ext4", "xfs"]

# If we get here, we couldn't determine the filesystem type
print("Could not determine filesystem type from /proc/mounts")
logging.warning("Could not determine filesystem type from /proc/mounts")
return False
except Exception as e:
print(f"Error checking filesystem type: {e}")
logging.error(f"Error checking filesystem type: {e}")
return False


Expand Down Expand Up @@ -730,7 +774,7 @@ def test_cufile_read_write_large():


@pytest.mark.skipif(not isSupportedFilesystem(), reason="cuFile handle_register requires ext4 or xfs filesystem")
def test_cufile_write_async():
def test_cufile_write_async(cufile_env_json):
"""Test cuFile asynchronous write operations."""
# Initialize CUDA
(err,) = cuda.cuInit(0)
Expand Down Expand Up @@ -823,7 +867,7 @@ def test_cufile_write_async():


@pytest.mark.skipif(not isSupportedFilesystem(), reason="cuFile handle_register requires ext4 or xfs filesystem")
def test_cufile_read_async():
def test_cufile_read_async(cufile_env_json):
"""Test cuFile asynchronous read operations."""
# Initialize CUDA
(err,) = cuda.cuInit(0)
Expand Down Expand Up @@ -929,7 +973,7 @@ def test_cufile_read_async():


@pytest.mark.skipif(not isSupportedFilesystem(), reason="cuFile handle_register requires ext4 or xfs filesystem")
def test_cufile_async_read_write():
def test_cufile_async_read_write(cufile_env_json):
"""Test cuFile asynchronous read and write operations in sequence."""
# Initialize CUDA
(err,) = cuda.cuInit(0)
Expand Down Expand Up @@ -1788,13 +1832,13 @@ def test_set_get_parameter_string():
retrieved_value_raw = cufile.get_parameter_string(cufile.StringConfigParameter.LOGGING_LEVEL, 256)
# Use safe_decode_string to handle null terminators and padding
retrieved_value = safe_decode_string(retrieved_value_raw.encode("utf-8"))
print(f"Logging level test: set {logging_level}, got {retrieved_value}")
logging.info(f"Logging level test: set {logging_level}, got {retrieved_value}")
# The retrieved value should be a string, so we can compare directly
assert retrieved_value == logging_level, (
f"Logging level mismatch: set {logging_level}, got {retrieved_value}"
)
except Exception as e:
print(f"Logging level test failed: {e}")
logging.error(f"Logging level test failed: {e}")
# Re-raise the exception to make the test fail
raise

Expand All @@ -1810,11 +1854,11 @@ def test_set_get_parameter_string():
retrieved_value_raw = cufile.get_parameter_string(cufile.StringConfigParameter.ENV_LOGFILE_PATH, 256)
# Use safe_decode_string to handle null terminators and padding
retrieved_value = safe_decode_string(retrieved_value_raw.encode("utf-8"))
print(f"Log file path test: set {logfile_path}, got {retrieved_value}")
logging.info(f"Log file path test: set {logfile_path}, got {retrieved_value}")
# The retrieved value should be a string, so we can compare directly
assert retrieved_value == logfile_path, f"Log file path mismatch: set {logfile_path}, got {retrieved_value}"
except Exception as e:
print(f"Log file path test failed: {e}")
logging.error(f"Log file path test failed: {e}")
# Re-raise the exception to make the test fail
raise

Expand All @@ -1828,11 +1872,11 @@ def test_set_get_parameter_string():
retrieved_value_raw = cufile.get_parameter_string(cufile.StringConfigParameter.LOG_DIR, 256)
# Use safe_decode_string to handle null terminators and padding
retrieved_value = safe_decode_string(retrieved_value_raw.encode("utf-8"))
print(f"Log directory test: set {log_dir}, got {retrieved_value}")
logging.info(f"Log directory test: set {log_dir}, got {retrieved_value}")
# The retrieved value should be a string, so we can compare directly
assert retrieved_value == log_dir, f"Log directory mismatch: set {log_dir}, got {retrieved_value}"
except Exception as e:
print(f"Log directory test failed: {e}")
logging.error(f"Log directory test failed: {e}")
# Re-raise the exception to make the test fail
raise

Expand Down
Loading