A CUDA API interception library that simulates GPU devices in non-GPU environments, enabling basic operations for PyTorch and other deep learning frameworks.
- CUDA Driver API - Device management, memory allocation, kernel launch
- CUDA Runtime API - cudaMalloc/Free, cudaMemcpy, Stream, Event
- cuBLAS/cuBLASLt - Matrix operations (GEMM, PyTorch 2.x compatible)
- NVML API - GPU information queries
- Python API Wrapper - `import fakegpu; fakegpu.init()` enables FakeGPU from inside Python
- PyTorch Support - Basic tensor ops, linear layers, neural networks
- GPU Tool Compatibility - Works with existing GPU status monitoring tools (nvidia-smi, gpustat, etc.); see the example below
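To give a taste of what the stubs surface: once FakeGPU is preloaded (setup shown in the run instructions further down), standard PyTorch device queries return the simulated hardware. The device count and name below assume the default build configuration described under Core Design.

```python
# Run under the FakeGPU preload (setup shown in the usage section below).
import torch

print(torch.cuda.is_available())      # True, even without physical GPUs
print(torch.cuda.device_count())      # 8 with the default build
print(torch.cuda.get_device_name(0))  # "Fake NVIDIA A100-SXM4-80GB"
```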
- Detailed Reporting - More comprehensive documentation and analysis reports
- Multi-Node GPU Communication - Simulate cross-node GPU communication (NCCL, etc.)
- Enhanced Testing - Expand the test suite to cover more languages and runtime environments
- Preset GPU Info - Add more preset GPU hardware configurations
- Multi-Architecture & Data Types - Support different GPU architectures and various data storage/memory types
```bash
cmake -S . -B build
cmake --build build
```

Generated libraries:

- `build/libcuda.so.1` - CUDA Driver API
- `build/libcudart.so.12` - CUDA Runtime API
- `build/libcublas.so.12` - cuBLAS/cuBLASLt API
- `build/libnvidia-ml.so.1` - NVML API
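As an optional sanity check, you can load one of the stub libraries directly with `ctypes` and call a standard CUDA Runtime entry point (this sketch assumes you run from the repository root; adjust the path otherwise):

```python
# Optional sanity check: load the stub runtime directly and query the device count.
# Assumes the build output is in ./build.
import ctypes

cudart = ctypes.CDLL("./build/libcudart.so.12")
count = ctypes.c_int(0)
status = cudart.cudaGetDeviceCount(ctypes.byref(count))  # cudaError_t cudaGetDeviceCount(int*)
print("status:", status, "devices:", count.value)
```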
Comparison test (recommended):

```bash
./test/run_comparison.sh
```

Runs identical tests on both a real GPU and FakeGPU to verify correctness.
PyTorch test:

```bash
LD_LIBRARY_PATH=./build:$LD_LIBRARY_PATH \
LD_PRELOAD=./build/libcublas.so.12:./build/libcudart.so.12:./build/libcuda.so.1:./build/libnvidia-ml.so.1 \
python3 test/test_comparison.py --mode fake
```

```python
import torch
# All PyTorch CUDA operations are intercepted by FakeGPU
device = torch.device('cuda:0')
x = torch.randn(100, 100, device=device)
y = torch.randn(100, 100, device=device)
z = x @ y # Matrix multiplication
# Simple neural network
model = torch.nn.Linear(100, 50).to(device)
output = model(x)
```

Runtime requires preloading all libraries:

```bash
LD_LIBRARY_PATH=./build:$LD_LIBRARY_PATH \
LD_PRELOAD=./build/libcublas.so.12:./build/libcudart.so.12:./build/libcuda.so.1:./build/libnvidia-ml.so.1 \
python your_script.py
```

Python wrapper (no need to start Python with LD_PRELOAD):

```python
import fakegpu
# Call early (before importing torch / CUDA-using libraries)
fakegpu.init()
import torch
```
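`fakegpu.init()` has to run before any CUDA-using library loads its bindings. One plausible way such a hook can work is to re-exec the interpreter with the stubs preloaded; the sketch below illustrates that idea only and is not the actual `fakegpu` source (the `FAKEGPU_ACTIVE` guard variable is hypothetical):

```python
# Hypothetical sketch of an init hook, not the actual fakegpu implementation:
# re-exec the current interpreter with the stub libraries in LD_PRELOAD.
import os
import sys

def init(build_dir="./build"):
    if os.environ.get("FAKEGPU_ACTIVE") == "1":  # hypothetical guard variable
        return  # already re-exec'd with the stubs preloaded
    libs = ["libcublas.so.12", "libcudart.so.12", "libcuda.so.1", "libnvidia-ml.so.1"]
    env = dict(os.environ,
               FAKEGPU_ACTIVE="1",
               LD_PRELOAD=":".join(os.path.join(build_dir, lib) for lib in libs),
               LD_LIBRARY_PATH=build_dir + ":" + os.environ.get("LD_LIBRARY_PATH", ""))
    os.execve(sys.executable, [sys.executable] + sys.argv, env)
```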
Shortcut runner:

```bash
./fgpu python your_script.py
# Optional: FAKEGPU_BUILD_DIR=/path/to/build ./fgpu python your_script.py
```

Python runner (installs the fakegpu console script):

```bash
fakegpu python your_script.py
# or: python -m fakegpu python your_script.py
```

GPU tools (nvidia-smi):
```bash
# FakeGPU-simulated devices via NVML stubs
./fgpu nvidia-smi
# Temperatures may show N/A because the TemperatureV struct is not fully emulated yet.
```
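The same NVML stubs back programmatic queries. A hedged example using the `pynvml` bindings (from the `nvidia-ml-py` package, an assumed extra dependency, not part of this repo), run under `./fgpu` so that `libnvidia-ml.so.1` resolves to the stub:

```python
# Run as: ./fgpu python nvml_query.py  (pynvml assumed installed separately)
import pynvml

pynvml.nvmlInit()
for i in range(pynvml.nvmlDeviceGetCount()):
    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
    print(i, pynvml.nvmlDeviceGetName(handle))  # fake device names from the stub
pynvml.nvmlShutdown()
```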
| Test | Status | Description |
|---|---|---|
| Tensor creation | ✓ | Basic memory allocation |
| Element-wise ops | ✓ | Add, multiply, trigonometric |
| Matrix multiplication | ✓ | cuBLAS/cuBLASLt GEMM |
| Linear layer | ✓ | PyTorch nn.Linear |
| Neural network | ✓ | Multi-layer forward pass |
| Memory transfer | ✓ | CPU ↔ GPU data copy |
```
FakeGPU
├── src/
│   ├── core/     # Global state and device management
│   ├── cuda/     # CUDA Driver/Runtime API stubs
│   ├── cublas/   # cuBLAS/cuBLASLt API stubs
│   ├── nvml/     # NVML API stubs
│   └── monitor/  # Resource monitoring and reporting
└── test/         # Test scripts
```
Core Design:

- Uses `LD_PRELOAD` to intercept CUDA API calls
- Device memory backed by system RAM (malloc/free)
- Matrix operations return random values (no actual computation; see the demonstration after this list)
- Kernel launches are no-ops (logging only)
- Default build exposes eight `Fake NVIDIA A100-SXM4-80GB` devices to mirror common server nodes
- GPU parameters are edited in YAML under `profiles/*.yaml`; CMake embeds these files at build time, so no runtime file lookup is needed. Add or tweak a file, rerun `cmake -S . -B build`, and the new profiles are compiled in.
- Presets cover multiple compute capabilities (Maxwell→Blackwell) and feed the existing helpers (`GpuProfile::GTX980/P100/V100/T4/A40/A100/H100/L40S/B100/B200`), which now prefer the YAML data and fall back to code defaults if parsing fails.
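Because GEMM results are random rather than computed, numerical checks against a CPU reference fail by design. A minimal demonstration (assuming the script is launched under the FakeGPU preload, e.g. via `./fgpu`):

```python
# Assumes this runs under the FakeGPU preload (e.g. ./fgpu python demo.py).
import torch

a = torch.ones(4, 4)
b = torch.ones(4, 4)
ref = a @ b                         # real CPU result: every entry is 4.0
fake = (a.cuda() @ b.cuda()).cpu()  # FakeGPU GEMM: random values
print(torch.allclose(ref, fake))    # expected: False
```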
- ❌ No real GPU computation (kernels are no-ops)
- ❌ Complex models (e.g., Transformers) may require APIs that are not yet implemented
- ❌ No multi-GPU synchronization
⚠️ For testing and development environments only
- ✅ Running GPU code tests in CI/CD environments
- ✅ Debugging deep learning code on machines without GPUs
- ✅ Validating CUDA API call logic
- ✅ Prototyping and unit testing
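For instance, a hypothetical smoke-test module like the one below can run in a GPU-less CI job (assuming it is launched under the preload, e.g. `./fgpu python -m pytest`, against the default eight-device build):

```python
# Hypothetical CI smoke test; run under the FakeGPU preload, e.g.:
#   ./fgpu python -m pytest test_smoke.py
import torch

def test_fake_devices_visible():
    assert torch.cuda.is_available()
    assert torch.cuda.device_count() == 8  # default build: eight fake A100s

def test_linear_layer_forward_shape():
    model = torch.nn.Linear(100, 50).to('cuda')
    out = model(torch.randn(8, 100, device='cuda'))
    assert out.shape == (8, 50)  # shapes are real even though values are fake
```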
- CMake 3.14+
- C++17 compiler
- Python 3.8+ (for testing)
- PyTorch 2.x (optional, for testing)
MIT License
- Test Guide - Detailed testing instructions
- cuBLASLt Implementation - cuBLASLt support details