diff --git a/build2cmake/src/config/v1.rs b/build2cmake/src/config/v1.rs index d3daf6dc..60ea2da3 100644 --- a/build2cmake/src/config/v1.rs +++ b/build2cmake/src/config/v1.rs @@ -51,6 +51,7 @@ pub enum Language { #[default] Cuda, CudaHipify, + Metal, } impl Display for Language { @@ -58,6 +59,7 @@ impl Display for Language { match self { Language::Cuda => f.write_str("cuda"), Language::CudaHipify => f.write_str("cuda-hipify"), + Language::Metal => f.write_str("metal"), } } } diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index 5ccc1127..8c70fbca 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -87,6 +87,7 @@ pub struct Kernel { #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub enum Backend { Cuda, + Metal, Rocm, } @@ -94,6 +95,7 @@ impl Display for Backend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Backend::Cuda => write!(f, "cuda"), + Backend::Metal => write!(f, "metal"), Backend::Rocm => write!(f, "rocm"), } } @@ -105,6 +107,7 @@ impl FromStr for Backend { fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "cuda" => Ok(Backend::Cuda), + "metal" => Ok(Backend::Metal), "rocm" => Ok(Backend::Rocm), _ => Err(format!("Unknown backend: {}", s)), } diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index 6a624e99..e18ee88c 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -9,16 +9,13 @@ use eyre::{bail, ensure, Context, Result}; use minijinja::Environment; mod torch; -use torch::write_torch_ext; - -mod torch_universal; +use torch::{write_torch_ext, write_torch_ext_metal, write_torch_universal_ext}; mod config; use config::{Backend, Build, BuildCompat}; mod fileset; use fileset::FileSet; -use torch_universal::write_torch_universal_ext; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -108,16 +105,15 @@ fn generate_torch( env.set_trim_blocks(true); minijinja_embed::load_templates!(&mut env); - 
match (backend, build.general.universal) { - (None, true) => write_torch_universal_ext(&env, &build, target_dir, force, ops_id)?, + let backend = match (backend, build.general.universal) { + (None, true) => return write_torch_universal_ext(&env, &build, target_dir, force, ops_id), (Some(backend), true) => bail!("Universal kernel, cannot generate for backend {}", backend), - // TODO: add check if that type of backend has at least one kernel. (Some(backend), false) => { if !build.has_kernel_with_backend(&backend) { bail!("No kernels found for backend {}", backend); } - write_torch_ext(&env, &build, target_dir, force, ops_id)? + backend } (None, false) => { let mut kernel_backends = build.backends(); @@ -139,15 +135,16 @@ fn generate_torch( ); } - match backend { - Backend::Cuda | Backend::Rocm => { - write_torch_ext(&env, &build, target_dir, force, ops_id)? - } - } + backend } - } + }; - Ok(()) + match backend { + Backend::Cuda | Backend::Rocm => { + write_torch_ext(&env, backend, &build, target_dir, force, ops_id) + } + Backend::Metal => write_torch_ext_metal(&env, &build, target_dir, force, ops_id), + } } fn update_build(build_toml: PathBuf) -> Result<()> { diff --git a/build2cmake/src/templates/dep-cutlass.cmake b/build2cmake/src/templates/cuda/dep-cutlass.cmake similarity index 100% rename from build2cmake/src/templates/dep-cutlass.cmake rename to build2cmake/src/templates/cuda/dep-cutlass.cmake diff --git a/build2cmake/src/cmake/hipify.py b/build2cmake/src/templates/cuda/hipify.py similarity index 100% rename from build2cmake/src/cmake/hipify.py rename to build2cmake/src/templates/cuda/hipify.py diff --git a/build2cmake/src/templates/kernel.cmake b/build2cmake/src/templates/cuda/kernel.cmake similarity index 100% rename from build2cmake/src/templates/kernel.cmake rename to build2cmake/src/templates/cuda/kernel.cmake diff --git a/build2cmake/src/templates/preamble.cmake b/build2cmake/src/templates/cuda/preamble.cmake similarity index 96% rename from 
build2cmake/src/templates/preamble.cmake rename to build2cmake/src/templates/cuda/preamble.cmake index 26f881b8..dbb92ad1 100644 --- a/build2cmake/src/templates/preamble.cmake +++ b/build2cmake/src/templates/cuda/preamble.cmake @@ -60,12 +60,16 @@ if(GPU_LANG STREQUAL "CUDA") if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA") list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}") endif() + + add_compile_definitions(CUDA_KERNEL) elseif(GPU_LANG STREQUAL "HIP") set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}") # TODO: remove this once we can set specific archs per source file set. override_gpu_arches(GPU_ARCHES ${GPU_LANG} "${${GPU_LANG}_SUPPORTED_ARCHS}") + + add_compile_definitions(ROCM_KERNEL) else() override_gpu_arches(GPU_ARCHES ${GPU_LANG} diff --git a/build2cmake/src/templates/setup.py b/build2cmake/src/templates/cuda/setup.py similarity index 100% rename from build2cmake/src/templates/setup.py rename to build2cmake/src/templates/cuda/setup.py diff --git a/build2cmake/src/templates/torch-binding.cmake b/build2cmake/src/templates/cuda/torch-binding.cmake similarity index 100% rename from build2cmake/src/templates/torch-binding.cmake rename to build2cmake/src/templates/cuda/torch-binding.cmake diff --git a/build2cmake/src/templates/torch-extension.cmake b/build2cmake/src/templates/cuda/torch-extension.cmake similarity index 100% rename from build2cmake/src/templates/torch-extension.cmake rename to build2cmake/src/templates/cuda/torch-extension.cmake diff --git a/build2cmake/src/templates/metal/kernel.cmake b/build2cmake/src/templates/metal/kernel.cmake new file mode 100644 index 00000000..59423d04 --- /dev/null +++ b/build2cmake/src/templates/metal/kernel.cmake @@ -0,0 +1,13 @@ +set({{kernel_name}}_SRC + {{ sources }} +) + +{% if includes %} +# TODO: check if CLion support this: +# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories +set_source_files_properties( + {{'${' + kernel_name + '_SRC}'}} + PROPERTIES INCLUDE_DIRECTORIES "{{ 
includes }}") {% endif %} + +list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}}) \ No newline at end of file diff --git a/build2cmake/src/templates/metal/preamble.cmake b/build2cmake/src/templates/metal/preamble.cmake new file mode 100644 index 00000000..5087d087 --- /dev/null +++ b/build2cmake/src/templates/metal/preamble.cmake @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.26) +project({{name}} LANGUAGES CXX) + +set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version") + +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) + +include(FetchContent) +file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists +message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") + +include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) + +if(DEFINED Python_EXECUTABLE) + # Allow passing through the interpreter (e.g. from setup.py). + find_package(Python COMPONENTS Development Development.SABIModule Interpreter) + if (NOT Python_FOUND) + message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.") + endif() +else() + find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter) +endif() + +append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") + +find_package(Torch REQUIRED) + +add_compile_definitions(METAL_KERNEL) diff --git a/build2cmake/src/templates/metal/setup.py b/build2cmake/src/templates/metal/setup.py new file mode 100644 index 00000000..4d078fb5 --- /dev/null +++ b/build2cmake/src/templates/metal/setup.py @@ -0,0 +1,121 @@ +import logging +import os +from shutil import which, move +import subprocess +import sys +from pathlib import Path + +from setuptools import Extension, find_packages, setup +from setuptools.command.build_ext import build_ext + +logger = logging.getLogger(__name__) + + +def is_sccache_available() -> bool: + return which("sccache") is not None + + +def is_ccache_available() -> bool: + return which("ccache") is not None + + +def 
is_ninja_available() -> bool: + return which("ninja") is not None + + +class CMakeExtension(Extension): + def __init__(self, name: str, sourcedir: str = "") -> None: + super().__init__(name, sources=[], py_limited_api=True) + self.sourcedir = os.fspath(Path(sourcedir).resolve()) + + +class CMakeBuild(build_ext): + def build_extension(self, ext: CMakeExtension) -> None: + ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) + extdir = ext_fullpath.parent.resolve() + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", + f"-DPython_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + build_args = [] + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + if not cmake_generator or cmake_generator == "Ninja": + try: + import ninja + + ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" + cmake_args += [ + "-GNinja", + f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + ] + except ImportError: + pass + + if is_sccache_available(): + cmake_args += [ + "-DCMAKE_C_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache", + "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache", + "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache", + ] + elif is_ccache_available(): + cmake_args += [ + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache", + "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache", + "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache", + ] + + num_jobs = os.getenv("MAX_JOBS", None) + if num_jobs is not None: + 
num_jobs = int(num_jobs) + logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) + else: + try: + # os.sched_getaffinity() isn't universally available, so fall + # back to os.cpu_count() if we get an error here. + num_jobs = len(os.sched_getaffinity(0)) + except AttributeError: + num_jobs = os.cpu_count() + + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + subprocess.run( + ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True + ) + subprocess.run( + ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True + ) + + +setup( + name="{{ name }}", + # The version is just a stub, it's not used by the final build artefact. + version="0.1.0", + ext_modules=[CMakeExtension("{{ name }}.{{ ops_name }}")], + cmdclass={"build_ext": CMakeBuild}, + packages=find_packages(where="torch-ext", include=["{{ name }}*"]), + package_dir={"": "torch-ext"}, +{% if data_globs %} + package_data={"{{ name }}": [ {{ data_globs }} ]}, +{% endif %} + zip_safe=False, + install_requires=["torch"], + python_requires=">=3.9", +) diff --git a/build2cmake/src/templates/metal/torch-binding.cmake b/build2cmake/src/templates/metal/torch-binding.cmake new file mode 100644 index 00000000..79872f04 --- /dev/null +++ b/build2cmake/src/templates/metal/torch-binding.cmake @@ -0,0 +1,16 @@ +#get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG}) +#list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS}) + +set(TORCH_{{name}}_SRC + {{ src|join(' ') }} +) + +{% if includes %} +# TODO: check if CLion support this: +# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories +set_source_files_properties( + {{'${TORCH_' + name + '_SRC}'}} + PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") +{% endif %} + +list(APPEND SRC {{'"${TORCH_' + name + '_SRC}"'}}) diff --git a/build2cmake/src/templates/metal/torch-extension.cmake b/build2cmake/src/templates/metal/torch-extension.cmake new file mode 
100644 index 00000000..55824350 --- /dev/null +++ b/build2cmake/src/templates/metal/torch-extension.cmake @@ -0,0 +1,9 @@ +define_gpu_extension_target( + {{ ops_name }} + DESTINATION {{ ops_name }} + LANGUAGE ${GPU_LANG} + SOURCES ${SRC} + COMPILE_FLAGS ${GPU_FLAGS} + ARCHITECTURES ${GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) \ No newline at end of file diff --git a/build2cmake/src/templates/metal/utils.cmake b/build2cmake/src/templates/metal/utils.cmake new file mode 100644 index 00000000..c6355d5f --- /dev/null +++ b/build2cmake/src/templates/metal/utils.cmake @@ -0,0 +1,7 @@ +# Run `EXPR` in python after importing `PKG`. Use the result of this to extend +# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. +macro (append_cmake_prefix_path PKG EXPR) + run_python(_PREFIX_PATH + "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") + list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) +endmacro() \ No newline at end of file diff --git a/build2cmake/src/templates/_ops-universal.py b/build2cmake/src/templates/universal/_ops.py similarity index 100% rename from build2cmake/src/templates/_ops-universal.py rename to build2cmake/src/templates/universal/_ops.py diff --git a/build2cmake/src/templates/pyproject_universal.toml b/build2cmake/src/templates/universal/pyproject.toml similarity index 100% rename from build2cmake/src/templates/pyproject_universal.toml rename to build2cmake/src/templates/universal/pyproject.toml diff --git a/build2cmake/src/cmake/utils.cmake b/build2cmake/src/templates/utils.cmake similarity index 100% rename from build2cmake/src/cmake/utils.cmake rename to build2cmake/src/templates/utils.cmake diff --git a/build2cmake/src/torch.rs b/build2cmake/src/torch/cuda.rs similarity index 82% rename from build2cmake/src/torch.rs rename to build2cmake/src/torch/cuda.rs index e71f6b38..34b8ed19 100644 --- a/build2cmake/src/torch.rs +++ b/build2cmake/src/torch/cuda.rs @@ -1,52 +1,19 @@ use std::collections::HashSet; use std::io::Write; -use 
std::path::{Path, PathBuf}; +use std::path::PathBuf; use eyre::{bail, Context, Result}; -use git2::Repository; use itertools::Itertools; use minijinja::{context, Environment}; -use rand::Rng; +use super::kernel_ops_identifier; use crate::config::{Backend, Build, Dependencies, Kernel, Torch}; use crate::FileSet; -static CMAKE_UTILS: &str = include_str!("cmake/utils.cmake"); -static REGISTRATION_H: &str = include_str!("templates/registration.h"); -static HIPIFY: &str = include_str!("cmake/hipify.py"); -static CUDA_SUPPORTED_ARCHS_JSON: &str = include_str!("cuda_supported_archs.json"); - -fn random_identifier() -> String { - // Generate a random string when no ops_id is provided - let mut rng = rand::thread_rng(); - let build_id: u64 = rng.gen(); - base32::encode( - base32::Alphabet::Rfc4648Lower { padding: false }, - &build_id.to_le_bytes(), - ) -} - -fn git_identifier(target_dir: impl AsRef) -> Result { - let repo = Repository::discover(target_dir.as_ref()).context("Cannot open git repository")?; - let head = repo.head()?; - let commit = head.peel_to_commit()?; - let rev = commit.tree_id().to_string().chars().take(7).collect(); - let dirty = !repo.statuses(None)?.is_empty(); - Ok(if dirty { format!("{rev}_dirty") } else { rev }) -} - -pub fn kernel_ops_identifier( - target_dir: impl AsRef, - name: &str, - ops_id: Option, -) -> String { - let identifier = ops_id.unwrap_or_else(|| match git_identifier(target_dir.as_ref()) { - Ok(rev) => rev, - Err(_) => random_identifier(), - }); - - format!("_{name}_{identifier}") -} +static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); +static REGISTRATION_H: &str = include_str!("../templates/registration.h"); +static HIPIFY: &str = include_str!("../templates/cuda/hipify.py"); +static CUDA_SUPPORTED_ARCHS_JSON: &str = include_str!("../cuda_supported_archs.json"); fn cuda_supported_archs() -> String { let supported_archs: Vec = serde_json::from_str(CUDA_SUPPORTED_ARCHS_JSON) @@ -56,6 +23,7 @@ fn cuda_supported_archs() 
-> String { pub fn write_torch_ext( env: &Environment, + backend: Backend, build: &Build, target_dir: PathBuf, force: bool, @@ -72,6 +40,7 @@ pub fn write_torch_ext( write_cmake( env, + backend, build, torch_ext, &build.general.name, @@ -131,7 +100,7 @@ fn write_setup_py( let data_globs = torch.data_globs().map(|globs| globs.join(", ")); - env.get_template("setup.py") + env.get_template("cuda/setup.py") .wrap_err("Cannot get setup.py template")? .render_to_write( context! { @@ -174,6 +143,7 @@ fn write_ops_py( fn write_cmake( env: &Environment, + backend: Backend, build: &Build, torch: &Torch, name: &str, @@ -202,7 +172,11 @@ fn write_cmake( render_binding(env, torch, name, cmake_writer)?; - for (kernel_name, kernel) in &build.kernels { + for (kernel_name, kernel) in build + .kernels + .iter() + .filter(|(_, kernel)| kernel.backend == backend) + { render_kernel(env, kernel_name, kernel, cmake_writer)?; } @@ -217,7 +191,7 @@ pub fn render_binding( name: &str, write: &mut impl Write, ) -> Result<()> { - env.get_template("torch-binding.cmake") + env.get_template("cuda/torch-binding.cmake") .wrap_err("Cannot get Torch binding template")? .render_to_write( context! { @@ -243,7 +217,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu for dep in deps { match dep { Dependencies::Cutlass2_10 => { - env.get_template("dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -254,7 +228,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu .wrap_err("Cannot render CUTLASS dependency template")?; } Dependencies::Cutlass3_5 => { - env.get_template("dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! 
{ @@ -265,7 +239,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu .wrap_err("Cannot render CUTLASS dependency template")?; } Dependencies::Cutlass3_6 => { - env.get_template("dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -276,7 +250,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu .wrap_err("Cannot render CUTLASS dependency template")?; } Dependencies::Cutlass3_8 => { - env.get_template("dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -308,7 +282,7 @@ pub fn render_kernel( .collect_vec() .join("\n"); - env.get_template("kernel.cmake") + env.get_template("cuda/kernel.cmake") .wrap_err("Cannot get kernel template")? .render_to_write( context! { @@ -329,7 +303,7 @@ pub fn render_kernel( } pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Write) -> Result<()> { - env.get_template("torch-extension.cmake") + env.get_template("cuda/torch-extension.cmake") .wrap_err("Cannot get Torch extension template")? .render_to_write( context! { @@ -345,7 +319,7 @@ pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Writ } pub fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> { - env.get_template("preamble.cmake") + env.get_template("cuda/preamble.cmake") .wrap_err("Cannot get CMake prelude template")? .render_to_write( context! 
{ diff --git a/build2cmake/src/torch/metal.rs b/build2cmake/src/torch/metal.rs new file mode 100644 index 00000000..23876735 --- /dev/null +++ b/build2cmake/src/torch/metal.rs @@ -0,0 +1,263 @@ +use std::{io::Write, path::PathBuf}; + +use eyre::{bail, Context, Result}; +use itertools::Itertools; +use minijinja::{context, Environment}; + +use super::kernel_ops_identifier; +use crate::{ + config::{Build, Kernel, Torch}, + fileset::FileSet, +}; + +static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); +static REGISTRATION_H: &str = include_str!("../templates/registration.h"); + +pub fn write_torch_ext_metal( + env: &Environment, + build: &Build, + target_dir: PathBuf, + force: bool, + ops_id: Option, +) -> Result<()> { + let torch_ext = match build.torch.as_ref() { + Some(torch_ext) => torch_ext, + None => bail!("Build configuration does not have `torch` section"), + }; + + let mut file_set = FileSet::default(); + + let ops_name = kernel_ops_identifier(&target_dir, &build.general.name, ops_id); + + write_cmake( + env, + build, + torch_ext, + &build.general.name, + &ops_name, + &mut file_set, + )?; + + write_setup_py( + env, + torch_ext, + &build.general.name, + &ops_name, + &mut file_set, + )?; + + write_ops_py(env, &build.general.name, &ops_name, &mut file_set)?; + + write_pyproject_toml(env, &mut file_set)?; + + write_torch_registration_macros(&mut file_set)?; + + file_set.write(&target_dir, force)?; + + Ok(()) +} + +fn write_cmake( + env: &Environment, + build: &Build, + torch: &Torch, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let mut utils_path = PathBuf::new(); + utils_path.push("cmake"); + utils_path.push("utils.cmake"); + file_set + .entry(utils_path.clone()) + .extend_from_slice(CMAKE_UTILS.as_bytes()); + + let cmake_writer = file_set.entry("CMakeLists.txt"); + + render_preamble(env, name, cmake_writer)?; + + // Add deps once we have any non-CUDA deps. 
+ // render_deps(env, build, cmake_writer)?; + + render_binding(env, torch, name, cmake_writer)?; + + for (kernel_name, kernel) in build.kernels.iter().filter(|(_, kernel)| kernel.backend == crate::config::Backend::Metal) { + render_kernel(env, kernel_name, kernel, cmake_writer)?; + } + + render_extension(env, ops_name, cmake_writer)?; + + Ok(()) +} + +fn render_binding( + env: &Environment, + torch: &Torch, + name: &str, + write: &mut impl Write, +) -> Result<()> { + env.get_template("metal/torch-binding.cmake") + .wrap_err("Cannot get Torch binding template")? + .render_to_write( + context! { + includes => torch.include.as_ref().map(prefix_and_join_includes), + name => name, + src => torch.src + }, + &mut *write, + ) + .wrap_err("Cannot render Torch binding template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Write) -> Result<()> { + env.get_template("metal/torch-extension.cmake") + .wrap_err("Cannot get Torch extension template")? + .render_to_write( + context! { + ops_name => ops_name, + }, + &mut *write, + ) + .wrap_err("Cannot render Torch extension template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +pub fn render_kernel( + env: &Environment, + kernel_name: &str, + kernel: &Kernel, + write: &mut impl Write, +) -> Result<()> { + // Easier to do in Rust than Jinja. + let sources = kernel + .src + .iter() + .map(|src| format!("\"{src}\"")) + .collect_vec() + .join("\n"); + + env.get_template("metal/kernel.cmake") + .wrap_err("Cannot get kernel template")? + .render_to_write( + context! { + includes => kernel.include.as_ref().map(prefix_and_join_includes), + kernel_name => kernel_name, + sources => sources, + }, + &mut *write, + ) + .wrap_err("Cannot render kernel template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> { + env.get_template("metal/preamble.cmake") + .wrap_err("Cannot get CMake prelude template")? + .render_to_write( + context! 
{ + name => name, + }, + &mut *write, + ) + .wrap_err("Cannot render CMake prelude template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +fn write_ops_py( + env: &Environment, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let mut path = PathBuf::new(); + path.push("torch-ext"); + path.push(name); + path.push("_ops.py"); + let writer = file_set.entry(path); + + env.get_template("_ops.py") + .wrap_err("Cannot get _ops.py template")? + .render_to_write( + context! { + ops_name => ops_name, + }, + writer, + ) + .wrap_err("Cannot render kernel template")?; + + Ok(()) +} + +fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> { + let writer = file_set.entry("pyproject.toml"); + + env.get_template("pyproject.toml") + .wrap_err("Cannot get pyproject.toml template")? + .render_to_write(context! {}, writer) + .wrap_err("Cannot render kernel template")?; + + Ok(()) +} + +fn write_setup_py( + env: &Environment, + torch: &Torch, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let writer = file_set.entry("setup.py"); + + let data_globs = torch.data_globs().map(|globs| globs.join(", ")); + + env.get_template("metal/setup.py") + .wrap_err("Cannot get setup.py template")? + .render_to_write( + context! 
{ + data_globs => data_globs, + ops_name => ops_name, + name => name, + version => "0.1.0", + }, + writer, + ) + .wrap_err("Cannot render kernel template")?; + + Ok(()) +} + +fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { + let mut path = PathBuf::new(); + path.push("torch-ext"); + path.push("registration.h"); + file_set + .entry(path) + .extend_from_slice(REGISTRATION_H.as_bytes()); + + Ok(()) +} + +fn prefix_and_join_includes(includes: impl AsRef<[S]>) -> String +where + S: AsRef, +{ + includes + .as_ref() + .iter() + .map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref())) + .collect_vec() + .join(";") +} diff --git a/build2cmake/src/torch/mod.rs b/build2cmake/src/torch/mod.rs new file mode 100644 index 00000000..2eaed180 --- /dev/null +++ b/build2cmake/src/torch/mod.rs @@ -0,0 +1,11 @@ +mod cuda; +pub use cuda::write_torch_ext; + +mod metal; +pub use metal::write_torch_ext_metal; + +mod ops_identifier; +pub(crate) use ops_identifier::kernel_ops_identifier; + +mod universal; +pub use universal::write_torch_universal_ext; diff --git a/build2cmake/src/torch/ops_identifier.rs b/build2cmake/src/torch/ops_identifier.rs new file mode 100644 index 00000000..3fe49e45 --- /dev/null +++ b/build2cmake/src/torch/ops_identifier.rs @@ -0,0 +1,37 @@ +use std::path::Path; + +use eyre::{Result, WrapErr}; +use git2::Repository; +use rand::Rng; + +fn random_identifier() -> String { + // Generate a random string when no ops_id is provided + let mut rng = rand::thread_rng(); + let build_id: u64 = rng.gen(); + base32::encode( + base32::Alphabet::Rfc4648Lower { padding: false }, + &build_id.to_le_bytes(), + ) +} + +fn git_identifier(target_dir: impl AsRef) -> Result { + let repo = Repository::discover(target_dir.as_ref()).context("Cannot open git repository")?; + let head = repo.head()?; + let commit = head.peel_to_commit()?; + let rev = commit.tree_id().to_string().chars().take(7).collect(); + let dirty = !repo.statuses(None)?.is_empty(); + 
Ok(if dirty { format!("{rev}_dirty") } else { rev }) +} + +pub fn kernel_ops_identifier( + target_dir: impl AsRef, + name: &str, + ops_id: Option, +) -> String { + let identifier = ops_id.unwrap_or_else(|| match git_identifier(target_dir.as_ref()) { + Ok(rev) => rev, + Err(_) => random_identifier(), + }); + + format!("_{name}_{identifier}") +} diff --git a/build2cmake/src/torch_universal.rs b/build2cmake/src/torch/universal.rs similarity index 95% rename from build2cmake/src/torch_universal.rs rename to build2cmake/src/torch/universal.rs index 55511508..f0361d73 100644 --- a/build2cmake/src/torch_universal.rs +++ b/build2cmake/src/torch/universal.rs @@ -45,7 +45,7 @@ fn write_ops_py( path.push("_ops.py"); let writer = file_set.entry(path); - env.get_template("_ops-universal.py") + env.get_template("universal/_ops.py") .wrap_err("Cannot get _ops-universal.py template")? .render_to_write( context! { @@ -68,7 +68,7 @@ fn write_pyproject_toml( let data_globs = torch.and_then(|torch| torch.data_globs().map(|globs| globs.join(", "))); - env.get_template("pyproject_universal.toml") + env.get_template("universal/pyproject.toml") .wrap_err("Cannot get universal pyproject.toml template")? .render_to_write( context! 
{ diff --git a/examples/relu/build.toml b/examples/relu/build.toml index bfb3b8f7..efed5a32 100644 --- a/examples/relu/build.toml +++ b/examples/relu/build.toml @@ -11,7 +11,14 @@ src = [ [kernel.activation] backend = "cuda" depends = ["torch"] -src = ["relu_kernel/relu.cu"] +src = ["relu_cuda/relu.cu"] + +[kernel.activation_metal] +backend = "metal" +src = [ + "relu_metal/relu.mm", +] +depends = [ "torch" ] [kernel.activation_rocm] backend = "rocm" @@ -27,4 +34,4 @@ rocm-archs = [ "gfx1101", ] depends = ["torch"] -src = ["relu_kernel/relu.cu"] +src = ["relu_cuda/relu.cu"] diff --git a/examples/relu/relu_kernel/relu.cu b/examples/relu/relu_cuda/relu.cu similarity index 52% rename from examples/relu/relu_kernel/relu.cu rename to examples/relu/relu_cuda/relu.cu index af82ff39..6bbe3160 100644 --- a/examples/relu/relu_kernel/relu.cu +++ b/examples/relu/relu_cuda/relu.cu @@ -5,8 +5,7 @@ #include __global__ void relu_kernel(float *__restrict__ out, - float const *__restrict__ input, - const int d) { + float const *__restrict__ input, const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { auto x = input[token_idx * d + idx]; @@ -14,13 +13,25 @@ __global__ void relu_kernel(float *__restrict__ out, } } -void relu(torch::Tensor &out, - torch::Tensor const &input) -{ +void relu(torch::Tensor &out, torch::Tensor const &input) { + TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Float && input.scalar_type() == at::ScalarType::Float, "relu_kernel only supports float32"); + TORCH_CHECK(input.sizes() == out.sizes(), + "Tensors must have the same shape. Got input shape: ", + input.sizes(), " and output shape: ", out.sizes()); + + TORCH_CHECK(input.scalar_type() == out.scalar_type(), + "Tensors must have the same data type. 
Got input dtype: ", + input.scalar_type(), " and output dtype: ", out.scalar_type()); + + TORCH_CHECK(input.device() == out.device(), + "Tensors must be on the same device. Got input device: ", + input.device(), " and output device: ", out.device()); + + int d = input.size(-1); + int64_t num_tokens = input.numel() / d; + dim3 grid(num_tokens); diff --git a/examples/relu/relu_metal/relu.mm b/examples/relu/relu_metal/relu.mm new file mode 100644 index 00000000..ac156da8 --- /dev/null +++ b/examples/relu/relu_metal/relu.mm @@ -0,0 +1,119 @@ +#include + +#import +#import +#include + +char const *CUSTOM_KERNEL = R"( + #include + using namespace metal; + + kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]], + device float *outC [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + // Explicitly write to output + outC[index] = max(0.0f, inA[index]); + } + + kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]], + device half *outC [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + // Explicitly write to output + outC[index] = max(static_cast(0.0), inA[index]); + } +)"; + +static inline id getMTLBufferStorage(const torch::Tensor &tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +torch::Tensor &dispatchReluKernel(torch::Tensor const &input, + torch::Tensor &output) { + @autoreleasepool { + id device = MTLCreateSystemDefaultDevice(); + NSError *error = nil; + + int numThreads = input.numel(); + + id customKernelLibrary = [device + newLibraryWithSource:[NSString stringWithUTF8String:CUSTOM_KERNEL] + options:nil + error:&error]; + TORCH_CHECK(customKernelLibrary, + "Failed to create custom kernel library, error: ", + error.localizedDescription.UTF8String); + + std::string kernel_name = + std::string("relu_forward_kernel_") + + (input.scalar_type() == torch::kFloat ? 
"float" : "half"); + id customReluFunction = [customKernelLibrary + newFunctionWithName:[NSString + stringWithUTF8String:kernel_name.c_str()]]; + TORCH_CHECK(customReluFunction, + "Failed to create function state object for ", + kernel_name.c_str()); + + id reluPSO = + [device newComputePipelineStateWithFunction:customReluFunction + error:&error]; + TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String); + + id commandBuffer = torch::mps::get_command_buffer(); + TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference"); + + dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue(); + + dispatch_sync(serialQueue, ^() { + id computeEncoder = + [commandBuffer computeCommandEncoder]; + TORCH_CHECK(computeEncoder, "Failed to create compute command encoder"); + + [computeEncoder setComputePipelineState:reluPSO]; + [computeEncoder setBuffer:getMTLBufferStorage(input) + offset:input.storage_offset() * input.element_size() + atIndex:0]; + [computeEncoder setBuffer:getMTLBufferStorage(output) + offset:output.storage_offset() * output.element_size() + atIndex:1]; + + MTLSize gridSize = MTLSizeMake(numThreads, 1, 1); + + NSUInteger threadGroupSize = reluPSO.maxTotalThreadsPerThreadgroup; + if (threadGroupSize > numThreads) { + threadGroupSize = numThreads; + } + MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1); + + [computeEncoder dispatchThreads:gridSize + threadsPerThreadgroup:threadgroupSize]; + + [computeEncoder endEncoding]; + + torch::mps::commit(); + }); + } + + return output; +} + +void relu(torch::Tensor &out, const torch::Tensor &input) { + TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(input.scalar_type() == torch::kFloat || + input.scalar_type() == torch::kHalf, + "Unsupported data type: ", input.scalar_type()); + + TORCH_CHECK(input.sizes() == out.sizes(), + "Tensors must have the same shape. 
Got input shape: ", + input.sizes(), " and output shape: ", out.sizes()); + + TORCH_CHECK(input.scalar_type() == out.scalar_type(), + "Tensors must have the same data type. Got input dtype: ", + input.scalar_type(), " and output dtype: ", out.scalar_type()); + + TORCH_CHECK(input.device() == out.device(), + "Tensors must be on the same device. Got input device: ", + input.device(), " and output device: ", out.device()); + + dispatchReluKernel(input, out); +} diff --git a/examples/relu/torch-ext/torch_binding.cpp b/examples/relu/torch-ext/torch_binding.cpp index eb15d63f..4f75d886 100644 --- a/examples/relu/torch-ext/torch_binding.cpp +++ b/examples/relu/torch-ext/torch_binding.cpp @@ -5,7 +5,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("relu(Tensor! out, Tensor input) -> ()"); +#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) ops.impl("relu", torch::kCUDA, &relu); +#elif defined(METAL_KERNEL) + ops.impl("relu", torch::kMPS, relu); +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)