Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build2cmake/src/config/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ pub enum Language {
#[default]
Cuda,
CudaHipify,
Metal,
}

impl Display for Language {
    /// Formats the language as the kebab-case identifier used in
    /// `build.toml` configuration files.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Language::Cuda => "cuda",
            Language::CudaHipify => "cuda-hipify",
            Language::Metal => "metal",
        };
        f.write_str(name)
    }
}
3 changes: 3 additions & 0 deletions build2cmake/src/config/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ pub struct Kernel {
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub enum Backend {
Cuda,
Metal,
Rocm,
}

impl Display for Backend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Backend::Cuda => write!(f, "cuda"),
Backend::Metal => write!(f, "metal"),
Backend::Rocm => write!(f, "rocm"),
}
}
Expand All @@ -105,6 +107,7 @@ impl FromStr for Backend {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"cuda" => Ok(Backend::Cuda),
"metal" => Ok(Backend::Metal),
"rocm" => Ok(Backend::Rocm),
_ => Err(format!("Unknown backend: {}", s)),
}
Expand Down
27 changes: 12 additions & 15 deletions build2cmake/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@ use eyre::{bail, ensure, Context, Result};
use minijinja::Environment;

mod torch;
use torch::write_torch_ext;

mod torch_universal;
use torch::{write_torch_ext, write_torch_ext_metal, write_torch_universal_ext};

mod config;
use config::{Backend, Build, BuildCompat};

mod fileset;
use fileset::FileSet;
use torch_universal::write_torch_universal_ext;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
Expand Down Expand Up @@ -108,16 +105,15 @@ fn generate_torch(
env.set_trim_blocks(true);
minijinja_embed::load_templates!(&mut env);

match (backend, build.general.universal) {
(None, true) => write_torch_universal_ext(&env, &build, target_dir, force, ops_id)?,
let backend = match (backend, build.general.universal) {
(None, true) => return write_torch_universal_ext(&env, &build, target_dir, force, ops_id),
(Some(backend), true) => bail!("Universal kernel, cannot generate for backend {}", backend),
// TODO: add check if that type of backend has at least one kernel.
(Some(backend), false) => {
if !build.has_kernel_with_backend(&backend) {
bail!("No kernels found for backend {}", backend);
}

write_torch_ext(&env, &build, target_dir, force, ops_id)?
backend
}
(None, false) => {
let mut kernel_backends = build.backends();
Expand All @@ -139,15 +135,16 @@ fn generate_torch(
);
}

match backend {
Backend::Cuda | Backend::Rocm => {
write_torch_ext(&env, &build, target_dir, force, ops_id)?
}
}
backend
}
}
};

Ok(())
match backend {
Backend::Cuda | Backend::Rocm => {
write_torch_ext(&env, backend, &build, target_dir, force, ops_id)
}
Backend::Metal => write_torch_ext_metal(&env, &build, target_dir, force, ops_id),
}
}

fn update_build(build_toml: PathBuf) -> Result<()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,16 @@ if(GPU_LANG STREQUAL "CUDA")
if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA")
list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()

add_compile_definitions(CUDA_KERNEL)
elseif(GPU_LANG STREQUAL "HIP")
set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
# TODO: remove this once we can set specific archs per source file set.
override_gpu_arches(GPU_ARCHES
${GPU_LANG}
"${${GPU_LANG}_SUPPORTED_ARCHS}")

add_compile_definitions(ROCM_KERNEL)
else()
override_gpu_arches(GPU_ARCHES
${GPU_LANG}
Expand Down
13 changes: 13 additions & 0 deletions build2cmake/src/templates/metal/kernel.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Collect the source files for this kernel into a per-kernel variable.
set({{kernel_name}}_SRC
{{ sources }}
)

{% if includes %}
# Attach the kernel's extra include directories to exactly these sources
# (rather than globally), so kernels with clashing headers stay isolated.
# TODO: check if CLion support this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
{{'${' + kernel_name + '_SRC}'}}
PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}")
{% endif %}

# Append this kernel's sources to the global SRC list consumed by
# torch-extension.cmake when the extension target is defined.
list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}})
28 changes: 28 additions & 0 deletions build2cmake/src/templates/metal/preamble.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
cmake_minimum_required(VERSION 3.26)
project({{name}} LANGUAGES CXX)

# Metal kernels target recent macOS; pin the minimum deployment version.
set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version")

# Only install files belonging to this project, not files from
# dependencies pulled in via FetchContent/find_package.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

include(FetchContent)
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

if(DEFINED Python_EXECUTABLE)
  # Allow passing through the interpreter (e.g. from setup.py).
  find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
  if (NOT Python_FOUND)
    # Fix: the message previously referenced the undefined variable
    # ${EXECUTABLE}; report the interpreter that was actually requested.
    message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.")
  endif()
else()
  find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
endif()

# Extend CMAKE_PREFIX_PATH from torch's installed cmake config so that
# find_package(Torch) below can succeed.
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

find_package(Torch REQUIRED)

# Let C++ sources detect at compile time that they are built for the
# Metal backend (mirrors CUDA_KERNEL / ROCM_KERNEL in the other preambles).
add_compile_definitions(METAL_KERNEL)
121 changes: 121 additions & 0 deletions build2cmake/src/templates/metal/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import logging
import os
from shutil import which, move
import subprocess
import sys
from pathlib import Path

from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext

logger = logging.getLogger(__name__)


def is_sccache_available() -> bool:
    """Return ``True`` when the ``sccache`` compiler cache is on ``PATH``."""
    return bool(which("sccache"))


def is_ccache_available() -> bool:
    """Return ``True`` when the ``ccache`` compiler cache is on ``PATH``."""
    return bool(which("ccache"))


def is_ninja_available() -> bool:
    """Return ``True`` when the ``ninja`` build tool is on ``PATH``."""
    return bool(which("ninja"))


class CMakeExtension(Extension):
    """A setuptools extension whose compilation is delegated to CMake.

    The source list is intentionally empty: CMake discovers and builds
    the actual sources. ``py_limited_api=True`` requests a stable-ABI
    (abi3) extension module.
    """

    def __init__(self, name: str, sourcedir: str = "") -> None:
        super().__init__(name, sources=[], py_limited_api=True)
        # Resolve to an absolute path up front so later changes of the
        # working directory cannot break the CMake configure step.
        resolved = Path(sourcedir).resolve()
        self.sourcedir = os.fspath(resolved)


class CMakeBuild(build_ext):
    """``build_ext`` command that configures and builds via CMake.

    Each extension is configured into its own build-temp directory, then
    built with ``cmake --build``; the resulting shared library is placed
    where setuptools expects to find it.
    """

    def build_extension(self, ext: CMakeExtension) -> None:
        # Final location of the built extension module; CMake writes the
        # library directly into its parent directory.
        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
        extdir = ext_fullpath.parent.resolve()

        # `self.debug` (setuptools flag) wins; otherwise fall back to the
        # DEBUG environment variable.
        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
        cfg = "Debug" if debug else "Release"

        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
        # from Python.
        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
        ]
        build_args = []
        # Allow callers (e.g. CI) to inject extra CMake options.
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

        # Prefer Ninja when no generator was requested and the `ninja`
        # PyPI package is importable; it ships its own binary.
        if not cmake_generator or cmake_generator == "Ninja":
            try:
                import ninja

                ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
                cmake_args += [
                    "-GNinja",
                    f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
                ]
            except ImportError:
                pass

        # Use a compiler cache when one is installed; sccache takes
        # precedence over ccache.
        if is_sccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache",
            ]
        elif is_ccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache",
            ]

        # Build parallelism: MAX_JOBS overrides CPU detection.
        num_jobs = os.getenv("MAX_JOBS", None)
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                # back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                num_jobs = os.cpu_count()

        # Bug fix: num_jobs was computed but never forwarded, so the build
        # ran with CMake's default parallelism and MAX_JOBS had no effect.
        if num_jobs is not None:
            build_args += [f"-j{num_jobs}"]

        build_temp = Path(self.build_temp) / ext.name
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
        )
        subprocess.run(
            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
        )


# Package metadata for the generated Metal torch extension. `{{ ... }}`
# placeholders are filled in by build2cmake when the project is generated.
setup(
    name="{{ name }}",
    # The version is just a stub, it's not used by the final build artefact.
    version="0.1.0",
    # A single CMake-driven extension; sources are discovered by CMake.
    ext_modules=[CMakeExtension("{{ name }}.{{ ops_name }}")],
    cmdclass={"build_ext": CMakeBuild},
    # The Python package lives under torch-ext/ in the generated layout.
    packages=find_packages(where="torch-ext", include=["{{ name }}*"]),
    package_dir={"": "torch-ext"},
    {% if data_globs %}
    package_data={"{{ name }}": [ {{ data_globs }} ]},
    {% endif %}
    zip_safe=False,
    install_requires=["torch"],
    python_requires=">=3.9",
)
16 changes: 16 additions & 0 deletions build2cmake/src/templates/metal/torch-binding.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# NOTE(review): kept from the CUDA template but disabled — Metal builds do
# not use the torch GPU compiler flags; confirm before deleting.
#get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
#list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})

# Source files for the torch binding layer of this extension.
set(TORCH_{{name}}_SRC
{{ src|join(' ') }}
)

{% if includes %}
# Attach extra include directories to exactly these sources (not globally).
# TODO: check if CLion support this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
{{'${TORCH_' + name + '_SRC}'}}
PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}")
{% endif %}

# Append the binding sources to the global SRC list consumed when the
# extension target is defined.
list(APPEND SRC {{'"${TORCH_' + name + '_SRC}"'}})
9 changes: 9 additions & 0 deletions build2cmake/src/templates/metal/torch-extension.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Define the final Torch extension target from all sources collected in SRC.
# NOTE(review): `define_gpu_extension_target` is provided by the shared
# torch cmake utilities included by the preamble — confirm GPU_LANG,
# GPU_FLAGS and GPU_ARCHES are set for the Metal path before this runs.
# USE_SABI 3 builds against the stable ABI; WITH_SOABI appends the
# platform SOABI suffix to the module filename.
define_gpu_extension_target(
{{ ops_name }}
DESTINATION {{ ops_name }}
LANGUAGE ${GPU_LANG}
SOURCES ${SRC}
COMPILE_FLAGS ${GPU_FLAGS}
ARCHITECTURES ${GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
7 changes: 7 additions & 0 deletions build2cmake/src/templates/metal/utils.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
# NOTE(review): depends on a `run_python` helper defined elsewhere in the
# generated project's cmake modules — confirm it is included before use.
macro (append_cmake_prefix_path PKG EXPR)
run_python(_PREFIX_PATH
"import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
# This is a macro, so _PREFIX_PATH and the CMAKE_PREFIX_PATH change are
# visible in the caller's scope — that propagation is intentional here.
list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
endmacro()
Loading