From b8ed23d7979b262dbf89d8d384c7d15073d6c9a6 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 12:40:59 +0200 Subject: [PATCH 001/535] Add python bindings for common module Signed-off-by: Jan Bielak --- setup.py | 99 ++++++++++++++++++++++- transformer_engine/common/pybind.cpp | 114 +++++++++++++++++++++++++++ 2 files changed, 212 insertions(+), 1 deletion(-) create mode 100644 transformer_engine/common/pybind.cpp diff --git a/setup.py b/setup.py index 4a344191de..c7c1f5d137 100644 --- a/setup.py +++ b/setup.py @@ -464,6 +464,103 @@ def setup_common_extension() -> CMakeExtension: def _all_files_in_dir(path): return list(path.iterdir()) +def setup_common_pybind_extension() -> setuptools.Extension: + """Setup CUDA extension for common library""" + + # Source files + src_dir = root_path / "transformer_engine" / "common" + sources = [ + src_dir / "transformer_engine.cpp", + src_dir / "pybind.cpp", + src_dir / "transpose" / "cast_transpose.cu", + src_dir / "transpose" / "transpose.cu", + src_dir / "transpose" / "cast_transpose_fusion.cu", + src_dir / "transpose" / "transpose_fusion.cu", + src_dir / "transpose" / "multi_cast_transpose.cu", + src_dir / "activation" / "gelu.cu", + src_dir / "fused_attn" / "fused_attn_f16_max512_seqlen.cu", + src_dir / "fused_attn" / "fused_attn_f16_arbitrary_seqlen.cu", + src_dir / "activation" / "relu.cu", + src_dir / "activation" / "swiglu.cu", + src_dir / "fused_attn" / "fused_attn_fp8.cu", + src_dir / "fused_attn" / "fused_attn.cpp", + src_dir / "fused_attn" / "utils.cu", + src_dir / "gemm" / "cublaslt_gemm.cu", + src_dir / "layer_norm" / "ln_api.cpp", + src_dir / "layer_norm" / "ln_bwd_semi_cuda_kernel.cu", + src_dir / "layer_norm" / "ln_fwd_cuda_kernel.cu", + src_dir / "rmsnorm" / "rmsnorm_api.cpp", + src_dir / "rmsnorm" / "rmsnorm_bwd_semi_cuda_kernel.cu", + src_dir / "rmsnorm" / "rmsnorm_fwd_cuda_kernel.cu", + src_dir / "util" / "cast.cu", + src_dir / "util" / "cuda_driver.cpp", + src_dir / "util" / "cuda_runtime.cpp", + src_dir / "util" / "rtc.cpp", + src_dir / "util" / "system.cpp", + src_dir / "fused_softmax" / "scaled_masked_softmax.cu", + src_dir / "fused_softmax" / "scaled_upper_triang_masked_softmax.cu", + src_dir / "fused_softmax" / "scaled_masked_softmax.cu", + src_dir / "fused_softmax" / "scaled_upper_triang_masked_softmax.cu", + ] + + # Header files + include_dirs = [ + src_dir / "include", + root_path / "3rdparty" / "cudnn-frontend" / "include", + ] + + # Compiler flags + cxx_flags = ["-O3"] + nvcc_flags = [ + "-O3", + "-gencode", + "arch=compute_70,code=sm_70", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + # Version-dependent CUDA options + try: + version = cuda_version() + except FileNotFoundError: + print("Could not determine CUDA Toolkit version") + else: + if version >= (11, 2): + nvcc_flags.extend(["--threads", "4"]) + if version >= (11, 0): + nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) + if version >= (11, 8): + nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + + # userbuffers support + if with_userbuffers(): + if os.getenv("MPI_HOME"): + mpi_home = Path(os.getenv("MPI_HOME")) + include_dirs.append(mpi_home / "include") + cxx_flags.append("-DNVTE_WITH_USERBUFFERS") + nvcc_flags.append("-DNVTE_WITH_USERBUFFERS") + + # Construct PyTorch CUDA extension + sources = [str(path) for path in sources] + include_dirs = [str(path) for path in include_dirs] + from torch.utils.cpp_extension import CUDAExtension + return CUDAExtension( + name="transformer_engine_cuda", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={ + "cxx": cxx_flags, + "nvcc": nvcc_flags, + }, + ) + def setup_pytorch_extension() -> setuptools.Extension: """Setup CUDA extension for PyTorch support""" @@ -611,7 +708,7 @@ def main(): setup_requires, install_requires, test_requires = setup_requirements() # Extensions - ext_modules = [setup_common_extension()] + ext_modules = [setup_common_extension(), setup_common_pybind_extension()] if "pytorch" in frameworks(): ext_modules.append(setup_pytorch_extension()) diff --git a/transformer_engine/common/pybind.cpp b/transformer_engine/common/pybind.cpp new file mode 100644 index 0000000000..9c596d4416 --- /dev/null +++ b/transformer_engine/common/pybind.cpp @@ -0,0 +1,114 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("nvte_gelu", &nvte_gelu) + m.def("nvte_dgelu", &nvte_dgelu) + m.def("nvte_geglu", &nvte_geglu) + m.def("nvte_dgeglu", &nvte_dgeglu) + m.def("nvte_relu", &nvte_relu) + m.def("nvte_drelu", &nvte_drelu) + m.def("nvte_swiglu", &nvte_swiglu) + m.def("nvte_dswiglu", &nvte_dswiglu) + m.def("nvte_reglu", &nvte_reglu) + m.def("nvte_dreglu", &nvte_dreglu) + + m.def("nvte_fp8_quantize", &nvte_fp8_quantize) + m.def("nvte_fp8_dequantize", &nvte_fp8_dequantize) + + m.def("nvte_get_fused_attn_backend", &nvte_get_fused_attn_backend) + m.def("nvte_fused_attn_fwd_qkvpacked", &nvte_fused_attn_fwd_qkvpacked) + m.def("nvte_fused_attn_bwd_qkvpacked", &nvte_fused_attn_bwd_qkvpacked) + m.def("nvte_fused_attn_fwd_kvpacked", &nvte_fused_attn_fwd_kvpacked) + m.def("nvte_fused_attn_bwd_kvpacked", &nvte_fused_attn_bwd_kvpacked) + + m.def("nvte_cublas_gemm", &nvte_cublas_gemm) + + m.def("nvte_layernorm_fwd", &nvte_layernorm_fwd) + m.def("nvte_layernorm1p_fwd", &nvte_layernorm1p_fwd) + m.def("nvte_layernorm_bwd", &nvte_layernorm_bwd) + m.def("nvte_layernorm1p_bwd", &nvte_layernorm1p_bwd) + + m.def("nvte_rmsnorm_fwd", &nvte_rmsnorm_fwd) + m.def("nvte_rmsnorm_bwd", &nvte_rmsnorm_bwd) + + m.def("nvte_scaled_softmax_forward", &nvte_scaled_softmax_forward) + m.def("nvte_scaled_softmax_backward", &nvte_scaled_softmax_backward) + m.def("nvte_scaled_masked_softmax_forward", &nvte_scaled_masked_softmax_forward) + m.def("nvte_scaled_masked_softmax_backward", &nvte_scaled_masked_softmax_backward) + m.def("nvte_scaled_upper_triang_masked_softmax_forward", &nvte_scaled_upper_triang_masked_softmax_forward) + m.def("nvte_scaled_upper_triang_masked_softmax_backward", &nvte_scaled_upper_triang_masked_softmax_backward) + + m.def("nvte_create_tensor", &nvte_create_tensor) + m.def("nvte_destroy_tensor", &nvte_destroy_tensor) + m.def("nvte_tensor_type", &nvte_tensor_type) + m.def("nvte_tensor_shape", &nvte_tensor_shape) + m.def("nvte_tensor_data", &nvte_tensor_data) + m.def("nvte_tensor_amax", &nvte_tensor_amax) + m.def("nvte_tensor_scale", &nvte_tensor_scale) + m.def("nvte_tensor_scale_inv", &nvte_tensor_scale_inv) + m.def("nvte_tensor_pack_create", &nvte_tensor_pack_create) + m.def("nvte_tensor_pack_destroy", &nvte_tensor_pack_destroy) + + m.def("nvte_cast_transpose", &nvte_cast_transpose) + m.def("nvte_transpose", &nvte_transpose) + m.def("nvte_cast_transpose_dbias", &nvte_cast_transpose_dbias) + m.def("nvte_fp8_transpose_dbias", &nvte_fp8_transpose_dbias) + m.def("nvte_cast_transpose_dbias_dgelu", &nvte_cast_transpose_dbias_dgelu) + m.def("nvte_multi_cast_transpose", &nvte_multi_cast_transpose) + m.def("nvte_dgeglu_cast_transpose", &nvte_dgeglu_cast_transpose) + + py::enum_(m, "NVTEDType") + .value("kNVTEByte", kNVTEByte) + .value("kNVTEInt32", kNVTEInt32) + .value("kNVTEInt64", kNVTEInt64) + .value("kNVTEFloat32", kNVTEFloat32) + .value("kNVTEFloat16", kNVTEFloat16) + .value("kNVTEBFloat16", kNVTEBFloat16) + .value("kNVTEFloat8E4M3", kNVTEFloat8E4M3) + .value("kNVTEFloat8E5M2", kNVTEFloat8E5M2); + + py::enum_(m, "NVTE_Fused_Attn_Backend") + .value("NVTE_No_Backend", NVTE_No_Backend) + .value("NVTE_F16_max512_seqlen", NVTE_F16_max512_seqlen) + .value("NVTE_F16_arbitrary_seqlen", NVTE_F16_arbitrary_seqlen) + .value("NVTE_FP8", NVTE_FP8); + + py::enum_(m, "NVTE_QKV_Layout") + .value("NVTE_NOT_INTERLEAVED", NVTE_NOT_INTERLEAVED) + .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_INTERLEAVED) + .value("NVTE_KV_INTERLEAVED", NVTE_KV_INTERLEAVED); + + py::enum_(m, "NVTE_Bias_Type") + .value("NVTE_NO_BIAS", NVTE_NO_BIAS) + .value("NVTE_PRE_SCALE_BIAS", NVTE_PRE_SCALE_BIAS) + .value("NVTE_POST_SCALE_BIAS", NVTE_POST_SCALE_BIAS); + + py::enum_(m, "NVTE_Mask_Type") + .value("NVTE_NO_MASK", NVTE_NO_MASK) + .value("NVTE_PADDING_MASK", NVTE_PADDING_MASK) + .value("NVTE_CAUSAL_MASK", NVTE_CAUSAL_MASK); + + py::class_(m, "NVTEShape") + .def(py::init<>()) + .def_readwrite("data", &NVTEShape::data) + .def_readwrite("ndim", &NVTEShape::ndim) + + py::class_(m, "NVTETensorPack") + .def(py::init<>()) + .def_readwrite("tensors", &NVTETensorPack::tensors) + .def_readwrite("size", &NVTETensorPack::size) +} \ No newline at end of file From b11974b8a36ee0a6e8c16735d31141a681e8af9b Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 12:54:05 +0200 Subject: [PATCH 002/535] fix duplicate file name Signed-off-by: Jan Bielak --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index c7c1f5d137..691ea65a5d 100644 --- a/setup.py +++ b/setup.py @@ -499,8 +499,6 @@ def setup_common_pybind_extension() -> setuptools.Extension: src_dir / "util" / "system.cpp", src_dir / "fused_softmax" / "scaled_masked_softmax.cu", src_dir / "fused_softmax" / "scaled_upper_triang_masked_softmax.cu", - src_dir / "fused_softmax" / "scaled_masked_softmax.cu", - src_dir / "fused_softmax" / "scaled_upper_triang_masked_softmax.cu", ] # Header files From 2eca14599bedc2a42bf0b0eb192edd1425a108df Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 13:19:43 +0200 Subject: [PATCH 003/535] omit unnecessary files in build Signed-off-by: Jan Bielak --- setup.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/setup.py b/setup.py index 691ea65a5d..66ec079695 100644 --- a/setup.py +++ b/setup.py @@ -470,41 +470,14 @@ def setup_common_pybind_extension() -> setuptools.Extension: # Source files src_dir = root_path / "transformer_engine" / "common" sources = [ - src_dir / "transformer_engine.cpp", src_dir / "pybind.cpp", - src_dir / "transpose" / "cast_transpose.cu", - src_dir / "transpose" / "transpose.cu", - src_dir / "transpose" / "cast_transpose_fusion.cu", - src_dir / "transpose" / "transpose_fusion.cu", - src_dir / "transpose" / "multi_cast_transpose.cu", - src_dir / "activation" / "gelu.cu", - src_dir / "fused_attn" / "fused_attn_f16_max512_seqlen.cu", - src_dir / "fused_attn" / "fused_attn_f16_arbitrary_seqlen.cu", - src_dir / "activation" / "relu.cu", - src_dir / "activation" / "swiglu.cu", - src_dir / "fused_attn" / "fused_attn_fp8.cu", - src_dir / "fused_attn" / "fused_attn.cpp", - src_dir / "fused_attn" / "utils.cu", - src_dir / "gemm" / "cublaslt_gemm.cu", - src_dir / "layer_norm" / "ln_api.cpp", - src_dir / "layer_norm" / "ln_bwd_semi_cuda_kernel.cu", - src_dir / "layer_norm" / "ln_fwd_cuda_kernel.cu", - src_dir / "rmsnorm" / "rmsnorm_api.cpp", - src_dir / "rmsnorm" / "rmsnorm_bwd_semi_cuda_kernel.cu", - src_dir / "rmsnorm" / "rmsnorm_fwd_cuda_kernel.cu", - src_dir / "util" / "cast.cu", - src_dir / "util" / "cuda_driver.cpp", - src_dir / "util" / "cuda_runtime.cpp", - src_dir / "util" / "rtc.cpp", - src_dir / "util" / "system.cpp", - src_dir / "fused_softmax" / "scaled_masked_softmax.cu", - src_dir / "fused_softmax" / "scaled_upper_triang_masked_softmax.cu", ] # Header files include_dirs = [ src_dir / "include", root_path / "3rdparty" / "cudnn-frontend" / "include", + root_path / "transformer_engine" ] # Compiler flags From d63885fc91772211f795859000dd315685626d0f Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 13:24:42 +0200 Subject: [PATCH 004/535] add missing include Signed-off-by: Jan Bielak --- transformer_engine/common/pybind.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/pybind.cpp b/transformer_engine/common/pybind.cpp index 9c596d4416..8f151c94d9 100644 --- a/transformer_engine/common/pybind.cpp +++ b/transformer_engine/common/pybind.cpp @@ -14,6 +14,9 @@ #include #include +#include +namespace py = pybind11; + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nvte_gelu", &nvte_gelu) m.def("nvte_dgelu", &nvte_dgelu) From 6a18645ac3e7528e79c85020af93f777b11a661f Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 13:29:55 +0200 Subject: [PATCH 005/535] fix pybind.cpp Signed-off-by: Jan Bielak --- transformer_engine/common/pybind.cpp | 112 +++++++++++++-------------- 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/transformer_engine/common/pybind.cpp b/transformer_engine/common/pybind.cpp index 8f151c94d9..b0afa53170 100644 --- a/transformer_engine/common/pybind.cpp +++ b/transformer_engine/common/pybind.cpp @@ -18,61 +18,61 @@ namespace py = pybind11; PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("nvte_gelu", &nvte_gelu) - m.def("nvte_dgelu", &nvte_dgelu) - m.def("nvte_geglu", &nvte_geglu) - m.def("nvte_dgeglu", &nvte_dgeglu) - m.def("nvte_relu", &nvte_relu) - m.def("nvte_drelu", &nvte_drelu) - m.def("nvte_swiglu", &nvte_swiglu) - m.def("nvte_dswiglu", &nvte_dswiglu) - m.def("nvte_reglu", &nvte_reglu) - m.def("nvte_dreglu", &nvte_dreglu) - - m.def("nvte_fp8_quantize", &nvte_fp8_quantize) - m.def("nvte_fp8_dequantize", &nvte_fp8_dequantize) - - m.def("nvte_get_fused_attn_backend", &nvte_get_fused_attn_backend) - m.def("nvte_fused_attn_fwd_qkvpacked", &nvte_fused_attn_fwd_qkvpacked) - m.def("nvte_fused_attn_bwd_qkvpacked", &nvte_fused_attn_bwd_qkvpacked) - m.def("nvte_fused_attn_fwd_kvpacked", &nvte_fused_attn_fwd_kvpacked) - m.def("nvte_fused_attn_bwd_kvpacked", &nvte_fused_attn_bwd_kvpacked) - - m.def("nvte_cublas_gemm", &nvte_cublas_gemm) - - m.def("nvte_layernorm_fwd", &nvte_layernorm_fwd) - m.def("nvte_layernorm1p_fwd", &nvte_layernorm1p_fwd) - m.def("nvte_layernorm_bwd", &nvte_layernorm_bwd) - m.def("nvte_layernorm1p_bwd", &nvte_layernorm1p_bwd) - - m.def("nvte_rmsnorm_fwd", &nvte_rmsnorm_fwd) - m.def("nvte_rmsnorm_bwd", &nvte_rmsnorm_bwd) - - m.def("nvte_scaled_softmax_forward", &nvte_scaled_softmax_forward) - m.def("nvte_scaled_softmax_backward", &nvte_scaled_softmax_backward) - m.def("nvte_scaled_masked_softmax_forward", &nvte_scaled_masked_softmax_forward) - m.def("nvte_scaled_masked_softmax_backward", &nvte_scaled_masked_softmax_backward) - m.def("nvte_scaled_upper_triang_masked_softmax_forward", &nvte_scaled_upper_triang_masked_softmax_forward) - m.def("nvte_scaled_upper_triang_masked_softmax_backward", &nvte_scaled_upper_triang_masked_softmax_backward) - - m.def("nvte_create_tensor", &nvte_create_tensor) - m.def("nvte_destroy_tensor", &nvte_destroy_tensor) - m.def("nvte_tensor_type", &nvte_tensor_type) - m.def("nvte_tensor_shape", &nvte_tensor_shape) - m.def("nvte_tensor_data", &nvte_tensor_data) - m.def("nvte_tensor_amax", &nvte_tensor_amax) - m.def("nvte_tensor_scale", &nvte_tensor_scale) - m.def("nvte_tensor_scale_inv", &nvte_tensor_scale_inv) - m.def("nvte_tensor_pack_create", &nvte_tensor_pack_create) - m.def("nvte_tensor_pack_destroy", &nvte_tensor_pack_destroy) - - m.def("nvte_cast_transpose", &nvte_cast_transpose) - m.def("nvte_transpose", &nvte_transpose) - m.def("nvte_cast_transpose_dbias", &nvte_cast_transpose_dbias) - m.def("nvte_fp8_transpose_dbias", &nvte_fp8_transpose_dbias) - m.def("nvte_cast_transpose_dbias_dgelu", &nvte_cast_transpose_dbias_dgelu) - m.def("nvte_multi_cast_transpose", &nvte_multi_cast_transpose) - m.def("nvte_dgeglu_cast_transpose", &nvte_dgeglu_cast_transpose) + m.def("nvte_gelu", &nvte_gelu); + m.def("nvte_dgelu", &nvte_dgelu); + m.def("nvte_geglu", &nvte_geglu); + m.def("nvte_dgeglu", &nvte_dgeglu); + m.def("nvte_relu", &nvte_relu); + m.def("nvte_drelu", &nvte_drelu); + m.def("nvte_swiglu", &nvte_swiglu); + m.def("nvte_dswiglu", &nvte_dswiglu); + m.def("nvte_reglu", &nvte_reglu); + m.def("nvte_dreglu", &nvte_dreglu); + + m.def("nvte_fp8_quantize", &nvte_fp8_quantize); + m.def("nvte_fp8_dequantize", &nvte_fp8_dequantize); + + m.def("nvte_get_fused_attn_backend", &nvte_get_fused_attn_backend); + m.def("nvte_fused_attn_fwd_qkvpacked", &nvte_fused_attn_fwd_qkvpacked); + m.def("nvte_fused_attn_bwd_qkvpacked", &nvte_fused_attn_bwd_qkvpacked); + m.def("nvte_fused_attn_fwd_kvpacked", &nvte_fused_attn_fwd_kvpacked); + m.def("nvte_fused_attn_bwd_kvpacked", &nvte_fused_attn_bwd_kvpacked); + + m.def("nvte_cublas_gemm", &nvte_cublas_gemm); + + m.def("nvte_layernorm_fwd", &nvte_layernorm_fwd); + m.def("nvte_layernorm1p_fwd", &nvte_layernorm1p_fwd); + m.def("nvte_layernorm_bwd", &nvte_layernorm_bwd); + m.def("nvte_layernorm1p_bwd", &nvte_layernorm1p_bwd); + + m.def("nvte_rmsnorm_fwd", &nvte_rmsnorm_fwd); + m.def("nvte_rmsnorm_bwd", &nvte_rmsnorm_bwd); + + m.def("nvte_scaled_softmax_forward", &nvte_scaled_softmax_forward); + m.def("nvte_scaled_softmax_backward", &nvte_scaled_softmax_backward); + m.def("nvte_scaled_masked_softmax_forward", &nvte_scaled_masked_softmax_forward); + m.def("nvte_scaled_masked_softmax_backward", &nvte_scaled_masked_softmax_backward); + m.def("nvte_scaled_upper_triang_masked_softmax_forward", &nvte_scaled_upper_triang_masked_softmax_forward); + m.def("nvte_scaled_upper_triang_masked_softmax_backward", &nvte_scaled_upper_triang_masked_softmax_backward); + + m.def("nvte_create_tensor", &nvte_create_tensor); + m.def("nvte_destroy_tensor", &nvte_destroy_tensor); + m.def("nvte_tensor_type", &nvte_tensor_type); + m.def("nvte_tensor_shape", &nvte_tensor_shape); + m.def("nvte_tensor_data", &nvte_tensor_data); + m.def("nvte_tensor_amax", &nvte_tensor_amax); + m.def("nvte_tensor_scale", &nvte_tensor_scale); + m.def("nvte_tensor_scale_inv", &nvte_tensor_scale_inv); + m.def("nvte_tensor_pack_create", &nvte_tensor_pack_create); + m.def("nvte_tensor_pack_destroy", &nvte_tensor_pack_destroy); + + m.def("nvte_cast_transpose", &nvte_cast_transpose); + m.def("nvte_transpose", &nvte_transpose); + m.def("nvte_cast_transpose_dbias", &nvte_cast_transpose_dbias); + m.def("nvte_fp8_transpose_dbias", &nvte_fp8_transpose_dbias); + m.def("nvte_cast_transpose_dbias_dgelu", &nvte_cast_transpose_dbias_dgelu); + m.def("nvte_multi_cast_transpose", &nvte_multi_cast_transpose); + m.def("nvte_dgeglu_cast_transpose", &nvte_dgeglu_cast_transpose); py::enum_(m, "NVTEDType") .value("kNVTEByte", kNVTEByte) @@ -105,7 +105,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("NVTE_PADDING_MASK", NVTE_PADDING_MASK) .value("NVTE_CAUSAL_MASK", NVTE_CAUSAL_MASK); - py::class_(m, "NVTEShape") + py::class_(m, "NVTEShape") .def(py::init<>()) .def_readwrite("data", &NVTEShape::data) .def_readwrite("ndim", &NVTEShape::ndim) From bd4a93d2fdf6bc3effd0158aaf38e9a34208f8f0 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 15:26:09 +0200 Subject: [PATCH 006/535] fix extension Signed-off-by: Jan Bielak --- setup.py | 37 +++--- .../sequential/cpp_extensions}/pybind.cpp | 124 +++++++++++++++--- 2 files changed, 122 insertions(+), 39 deletions(-) rename transformer_engine/{common => pytorch/sequential/cpp_extensions}/pybind.cpp (57%) diff --git a/setup.py b/setup.py index 66ec079695..6e6b65ed35 100644 --- a/setup.py +++ b/setup.py @@ -464,20 +464,23 @@ def setup_common_extension() -> CMakeExtension: def _all_files_in_dir(path): return list(path.iterdir()) -def setup_common_pybind_extension() -> setuptools.Extension: - """Setup CUDA extension for common library""" +def setup_pytorch_extension() -> setuptools.Extension: + """Setup CUDA extension for PyTorch support""" # Source files - src_dir = root_path / "transformer_engine" / "common" + src_dir = root_path / "transformer_engine" / "pytorch" / "csrc" + extensions_dir = src_dir / "extensions" sources = [ - src_dir / "pybind.cpp", - ] + src_dir / "common.cu", + src_dir / "ts_fp8_op.cpp", + ] + \ + _all_files_in_dir(extensions_dir) # Header files include_dirs = [ - src_dir / "include", + root_path / "transformer_engine" / "common" / "include", + root_path / "transformer_engine" / "pytorch" / "csrc", root_path / "3rdparty" / "cudnn-frontend" / "include", - root_path / "transformer_engine" ] # Compiler flags @@ -523,9 +526,10 @@ def setup_common_pybind_extension() -> setuptools.Extension: include_dirs = [str(path) for path in include_dirs] from torch.utils.cpp_extension import CUDAExtension return CUDAExtension( - name="transformer_engine_cuda", + name="transformer_engine_extensions", sources=sources, include_dirs=include_dirs, + # libraries=["transformer_engine"], ### TODO (tmoon) Debug linker errors extra_compile_args={ "cxx": cxx_flags, "nvcc": nvcc_flags, @@ -536,18 +540,15 @@ def setup_pytorch_extension() -> setuptools.Extension: """Setup CUDA extension for PyTorch support""" # Source files - src_dir = root_path / "transformer_engine" / "pytorch" / "csrc" - extensions_dir = src_dir / "extensions" + src_dir = root_path / "transformer_engine" / "pytorch" / "sequential" / "cpp_extensions" sources = [ - src_dir / "common.cu", - src_dir / "ts_fp8_op.cpp", - ] + \ - _all_files_in_dir(extensions_dir) + src_dir / "pybind.cpp" + ] # Header files include_dirs = [ root_path / "transformer_engine" / "common" / "include", - root_path / "transformer_engine" / "pytorch" / "csrc", + root_path / "transformer_engine", root_path / "3rdparty" / "cudnn-frontend" / "include", ] @@ -594,10 +595,9 @@ def setup_pytorch_extension() -> setuptools.Extension: include_dirs = [str(path) for path in include_dirs] from torch.utils.cpp_extension import CUDAExtension return CUDAExtension( - name="transformer_engine_extensions", + name="transformer_engine_cuda", sources=sources, include_dirs=include_dirs, - # libraries=["transformer_engine"], ### TODO (tmoon) Debug linker errors extra_compile_args={ "cxx": cxx_flags, "nvcc": nvcc_flags, @@ -679,9 +679,10 @@ def main(): setup_requires, install_requires, test_requires = setup_requirements() # Extensions - ext_modules = [setup_common_extension(), setup_common_pybind_extension()] + ext_modules = [setup_common_extension()] if "pytorch" in frameworks(): ext_modules.append(setup_pytorch_extension()) + ext_modules.append(setup_sequential_extension()) if "paddle" in frameworks(): ext_modules.append(setup_paddle_extension()) diff --git a/transformer_engine/common/pybind.cpp b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp similarity index 57% rename from transformer_engine/common/pybind.cpp rename to transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp index b0afa53170..a25360843f 100644 --- a/transformer_engine/common/pybind.cpp +++ b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp @@ -1,9 +1,28 @@ /************************************************************************* - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. * * See LICENSE for license information. ************************************************************************/ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include @@ -13,11 +32,78 @@ #include #include #include - -#include +#include namespace py = pybind11; -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +struct Tensor { + NVTETensor impl; + + static void *getDataPtr(at::Tensor t) { + if (t.numel() > 0) { + return t.data_ptr(); + } else { + return nullptr; + } + } + + Tensor(NVTEDType dtype, at::Tensor data, at::Tensor amax, at::Tensor scale, + at::Tensor scale_inv) { + NVTEShape shape{data.sizes().data(), data.sizes().size()}; + impl = nvte_create_tensor(getDataPtr(data), shape, dtype, getDataPtr(amax), + getDataPtr(scale), getDataPtr(scale_inv)); + } + ~Tensor() { nvte_destroy_tensor(impl); } +}; + +struct TensorPack : NVTETensorPack { + TensorPack(const std::vector &tensors_) : tensors{}, size{} { + size = tensors_.size(); + if (size > MAX_SIZE) { + throw std::runtime_error("TensorPack size exceeds MAX_SIZE"); + } + for (size_t i = 0; i < size; ++i) { + tensors[i] = tensors_[i].impl; + } + nvte_tensor_pack_create(this); + } + ~TensorPack() { nvte_tensor_pack_destroy(this); } +}; + +template struct trait { + using type = T; +}; + +template struct wrapped_arg : trait {}; +struct wrapped_arg : trait {}; +struct wrapped_arg : trait> {}; + +template using wrapped_arg_t = typename wrapped_arg::type; + +template decltype(auto) unwrap_arg(T &&arg) { + if constexpr (std::is_same_v < std::decay_t, wrapped_arg_t) { + return arg.impl; + } else if constexpr (std::is_same_v, + wrapped_arg_t>) { + return TensorPack(arg); + } else { + { return arg; } + } + + template + constexpr auto wrap(Ret(func)(Args && ..., LastArg &&)) noexcept { + if constexpr (std::is_same_v, cudaStream_t>) { + return [func](wrapped_arg_t... args) -> Ret { + return func(unwrap_arg(args)..., at::cuda::getCurrentCUDAStream()); + }; + } else { + return [func](wrapped_arg_t... args, + wrapped_arg_t last_arg) -> Ret { + return func(unwrap_arg(args)..., unwrap_arg(last_arg)); + }; + } + } + + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nvte_gelu", &nvte_gelu); m.def("nvte_dgelu", &nvte_dgelu); m.def("nvte_geglu", &nvte_geglu); @@ -28,33 +114,30 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nvte_dswiglu", &nvte_dswiglu); m.def("nvte_reglu", &nvte_reglu); m.def("nvte_dreglu", &nvte_dreglu); - m.def("nvte_fp8_quantize", &nvte_fp8_quantize); m.def("nvte_fp8_dequantize", &nvte_fp8_dequantize); - m.def("nvte_get_fused_attn_backend", &nvte_get_fused_attn_backend); m.def("nvte_fused_attn_fwd_qkvpacked", &nvte_fused_attn_fwd_qkvpacked); m.def("nvte_fused_attn_bwd_qkvpacked", &nvte_fused_attn_bwd_qkvpacked); m.def("nvte_fused_attn_fwd_kvpacked", &nvte_fused_attn_fwd_kvpacked); m.def("nvte_fused_attn_bwd_kvpacked", &nvte_fused_attn_bwd_kvpacked); - m.def("nvte_cublas_gemm", &nvte_cublas_gemm); - m.def("nvte_layernorm_fwd", &nvte_layernorm_fwd); m.def("nvte_layernorm1p_fwd", &nvte_layernorm1p_fwd); m.def("nvte_layernorm_bwd", &nvte_layernorm_bwd); m.def("nvte_layernorm1p_bwd", &nvte_layernorm1p_bwd); - m.def("nvte_rmsnorm_fwd", &nvte_rmsnorm_fwd); m.def("nvte_rmsnorm_bwd", &nvte_rmsnorm_bwd); - m.def("nvte_scaled_softmax_forward", &nvte_scaled_softmax_forward); m.def("nvte_scaled_softmax_backward", &nvte_scaled_softmax_backward); - m.def("nvte_scaled_masked_softmax_forward", &nvte_scaled_masked_softmax_forward); - m.def("nvte_scaled_masked_softmax_backward", &nvte_scaled_masked_softmax_backward); - m.def("nvte_scaled_upper_triang_masked_softmax_forward", &nvte_scaled_upper_triang_masked_softmax_forward); - m.def("nvte_scaled_upper_triang_masked_softmax_backward", &nvte_scaled_upper_triang_masked_softmax_backward); - + m.def("nvte_scaled_masked_softmax_forward", + &nvte_scaled_masked_softmax_forward); + m.def("nvte_scaled_masked_softmax_backward", + &nvte_scaled_masked_softmax_backward); + m.def("nvte_scaled_upper_triang_masked_softmax_forward", + &nvte_scaled_upper_triang_masked_softmax_forward); + m.def("nvte_scaled_upper_triang_masked_softmax_backward", + &nvte_scaled_upper_triang_masked_softmax_backward); m.def("nvte_create_tensor", &nvte_create_tensor); m.def("nvte_destroy_tensor", &nvte_destroy_tensor); m.def("nvte_tensor_type", &nvte_tensor_type); @@ -108,10 +191,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_(m, "NVTEShape") .def(py::init<>()) .def_readwrite("data", &NVTEShape::data) - .def_readwrite("ndim", &NVTEShape::ndim) + .def_readwrite("ndim", &NVTEShape::ndim); - py::class_(m, "NVTETensorPack") - .def(py::init<>()) - .def_readwrite("tensors", &NVTETensorPack::tensors) - .def_readwrite("size", &NVTETensorPack::size) -} \ No newline at end of file + py::class_(m, "NVTETensor") + .def(py::init()) + } From 2e5d965fb09350b0e82fc17abdf529fd913e31a0 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 15:33:22 +0200 Subject: [PATCH 007/535] fix function name Signed-off-by: Jan Bielak --- setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 6e6b65ed35..bd4d6ce645 100644 --- a/setup.py +++ b/setup.py @@ -536,9 +536,7 @@ def setup_pytorch_extension() -> setuptools.Extension: }, ) -def setup_pytorch_extension() -> setuptools.Extension: - """Setup CUDA extension for PyTorch support""" - +def setup_sequential_extension() -> setuptools.Extension: # Source files src_dir = root_path / "transformer_engine" / "pytorch" / "sequential" / "cpp_extensions" sources = [ From 832f097c72062e7027d4da717b448578cbd68771 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 15:39:18 +0200 Subject: [PATCH 008/535] fix missing brace Signed-off-by: Jan Bielak --- .../sequential/cpp_extensions/pybind.cpp | 217 +++++++++--------- 1 file changed, 109 insertions(+), 108 deletions(-) diff --git a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp index a25360843f..c7606e3bf3 100644 --- a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp +++ b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp @@ -88,112 +88,113 @@ template decltype(auto) unwrap_arg(T &&arg) { } else { { return arg; } } - - template - constexpr auto wrap(Ret(func)(Args && ..., LastArg &&)) noexcept { - if constexpr (std::is_same_v, cudaStream_t>) { - return [func](wrapped_arg_t... args) -> Ret { - return func(unwrap_arg(args)..., at::cuda::getCurrentCUDAStream()); - }; - } else { - return [func](wrapped_arg_t... args, - wrapped_arg_t last_arg) -> Ret { - return func(unwrap_arg(args)..., unwrap_arg(last_arg)); - }; - } - } - - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("nvte_gelu", &nvte_gelu); - m.def("nvte_dgelu", &nvte_dgelu); - m.def("nvte_geglu", &nvte_geglu); - m.def("nvte_dgeglu", &nvte_dgeglu); - m.def("nvte_relu", &nvte_relu); - m.def("nvte_drelu", &nvte_drelu); - m.def("nvte_swiglu", &nvte_swiglu); - m.def("nvte_dswiglu", &nvte_dswiglu); - m.def("nvte_reglu", &nvte_reglu); - m.def("nvte_dreglu", &nvte_dreglu); - m.def("nvte_fp8_quantize", &nvte_fp8_quantize); - m.def("nvte_fp8_dequantize", &nvte_fp8_dequantize); - m.def("nvte_get_fused_attn_backend", &nvte_get_fused_attn_backend); - m.def("nvte_fused_attn_fwd_qkvpacked", &nvte_fused_attn_fwd_qkvpacked); - m.def("nvte_fused_attn_bwd_qkvpacked", &nvte_fused_attn_bwd_qkvpacked); - m.def("nvte_fused_attn_fwd_kvpacked", &nvte_fused_attn_fwd_kvpacked); - m.def("nvte_fused_attn_bwd_kvpacked", &nvte_fused_attn_bwd_kvpacked); - m.def("nvte_cublas_gemm", &nvte_cublas_gemm); - m.def("nvte_layernorm_fwd", &nvte_layernorm_fwd); - m.def("nvte_layernorm1p_fwd", &nvte_layernorm1p_fwd); - m.def("nvte_layernorm_bwd", &nvte_layernorm_bwd); - m.def("nvte_layernorm1p_bwd", &nvte_layernorm1p_bwd); - m.def("nvte_rmsnorm_fwd", &nvte_rmsnorm_fwd); - m.def("nvte_rmsnorm_bwd", &nvte_rmsnorm_bwd); - m.def("nvte_scaled_softmax_forward", &nvte_scaled_softmax_forward); - m.def("nvte_scaled_softmax_backward", &nvte_scaled_softmax_backward); - m.def("nvte_scaled_masked_softmax_forward", - &nvte_scaled_masked_softmax_forward); - m.def("nvte_scaled_masked_softmax_backward", - &nvte_scaled_masked_softmax_backward); - m.def("nvte_scaled_upper_triang_masked_softmax_forward", - &nvte_scaled_upper_triang_masked_softmax_forward); - m.def("nvte_scaled_upper_triang_masked_softmax_backward", - &nvte_scaled_upper_triang_masked_softmax_backward); - m.def("nvte_create_tensor", &nvte_create_tensor); - m.def("nvte_destroy_tensor", &nvte_destroy_tensor); - m.def("nvte_tensor_type", &nvte_tensor_type); - m.def("nvte_tensor_shape", &nvte_tensor_shape); - m.def("nvte_tensor_data", &nvte_tensor_data); - m.def("nvte_tensor_amax", &nvte_tensor_amax); - m.def("nvte_tensor_scale", &nvte_tensor_scale); - m.def("nvte_tensor_scale_inv", &nvte_tensor_scale_inv); - m.def("nvte_tensor_pack_create", &nvte_tensor_pack_create); - m.def("nvte_tensor_pack_destroy", &nvte_tensor_pack_destroy); - - m.def("nvte_cast_transpose", &nvte_cast_transpose); - m.def("nvte_transpose", &nvte_transpose); - m.def("nvte_cast_transpose_dbias", &nvte_cast_transpose_dbias); - m.def("nvte_fp8_transpose_dbias", &nvte_fp8_transpose_dbias); - m.def("nvte_cast_transpose_dbias_dgelu", &nvte_cast_transpose_dbias_dgelu); - m.def("nvte_multi_cast_transpose", &nvte_multi_cast_transpose); - m.def("nvte_dgeglu_cast_transpose", &nvte_dgeglu_cast_transpose); - - py::enum_(m, "NVTEDType") - .value("kNVTEByte", kNVTEByte) - .value("kNVTEInt32", kNVTEInt32) - .value("kNVTEInt64", kNVTEInt64) - .value("kNVTEFloat32", kNVTEFloat32) - .value("kNVTEFloat16", kNVTEFloat16) - .value("kNVTEBFloat16", kNVTEBFloat16) - .value("kNVTEFloat8E4M3", kNVTEFloat8E4M3) - .value("kNVTEFloat8E5M2", kNVTEFloat8E5M2); - - py::enum_(m, "NVTE_Fused_Attn_Backend") - .value("NVTE_No_Backend", NVTE_No_Backend) - .value("NVTE_F16_max512_seqlen", NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_FP8); - - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_NOT_INTERLEAVED", NVTE_NOT_INTERLEAVED) - .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_INTERLEAVED) - .value("NVTE_KV_INTERLEAVED", NVTE_KV_INTERLEAVED); - - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_POST_SCALE_BIAS); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_CAUSAL_MASK); - - py::class_(m, "NVTEShape") - .def(py::init<>()) - .def_readwrite("data", &NVTEShape::data) - .def_readwrite("ndim", &NVTEShape::ndim); - - py::class_(m, "NVTETensor") - .def(py::init()) +} + +template +constexpr auto wrap(Ret(func)(Args &&..., LastArg &&)) noexcept { + if constexpr (std::is_same_v, cudaStream_t>) { + return [func](wrapped_arg_t... args) -> Ret { + return func(unwrap_arg(args)..., at::cuda::getCurrentCUDAStream()); + }; + } else { + return [func](wrapped_arg_t... args, + wrapped_arg_t last_arg) -> Ret { + return func(unwrap_arg(args)..., unwrap_arg(last_arg)); + }; } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("nvte_gelu", &nvte_gelu); + m.def("nvte_dgelu", &nvte_dgelu); + m.def("nvte_geglu", &nvte_geglu); + m.def("nvte_dgeglu", &nvte_dgeglu); + m.def("nvte_relu", &nvte_relu); + m.def("nvte_drelu", &nvte_drelu); + m.def("nvte_swiglu", &nvte_swiglu); + m.def("nvte_dswiglu", &nvte_dswiglu); + m.def("nvte_reglu", &nvte_reglu); + m.def("nvte_dreglu", &nvte_dreglu); + m.def("nvte_fp8_quantize", &nvte_fp8_quantize); + m.def("nvte_fp8_dequantize", &nvte_fp8_dequantize); + m.def("nvte_get_fused_attn_backend", &nvte_get_fused_attn_backend); + m.def("nvte_fused_attn_fwd_qkvpacked", &nvte_fused_attn_fwd_qkvpacked); + m.def("nvte_fused_attn_bwd_qkvpacked", &nvte_fused_attn_bwd_qkvpacked); + m.def("nvte_fused_attn_fwd_kvpacked", &nvte_fused_attn_fwd_kvpacked); + m.def("nvte_fused_attn_bwd_kvpacked", &nvte_fused_attn_bwd_kvpacked); + m.def("nvte_cublas_gemm", &nvte_cublas_gemm); + m.def("nvte_layernorm_fwd", &nvte_layernorm_fwd); + m.def("nvte_layernorm1p_fwd", &nvte_layernorm1p_fwd); + m.def("nvte_layernorm_bwd", &nvte_layernorm_bwd); + m.def("nvte_layernorm1p_bwd", &nvte_layernorm1p_bwd); + m.def("nvte_rmsnorm_fwd", &nvte_rmsnorm_fwd); + m.def("nvte_rmsnorm_bwd", &nvte_rmsnorm_bwd); + m.def("nvte_scaled_softmax_forward", &nvte_scaled_softmax_forward); + m.def("nvte_scaled_softmax_backward", &nvte_scaled_softmax_backward); + m.def("nvte_scaled_masked_softmax_forward", + &nvte_scaled_masked_softmax_forward); + m.def("nvte_scaled_masked_softmax_backward", + &nvte_scaled_masked_softmax_backward); + m.def("nvte_scaled_upper_triang_masked_softmax_forward", + &nvte_scaled_upper_triang_masked_softmax_forward); + m.def("nvte_scaled_upper_triang_masked_softmax_backward", + &nvte_scaled_upper_triang_masked_softmax_backward); + m.def("nvte_create_tensor", &nvte_create_tensor); + m.def("nvte_destroy_tensor", &nvte_destroy_tensor); + m.def("nvte_tensor_type", &nvte_tensor_type); + m.def("nvte_tensor_shape", &nvte_tensor_shape); + m.def("nvte_tensor_data", &nvte_tensor_data); + m.def("nvte_tensor_amax", &nvte_tensor_amax); + m.def("nvte_tensor_scale", &nvte_tensor_scale); + m.def("nvte_tensor_scale_inv", &nvte_tensor_scale_inv); + m.def("nvte_tensor_pack_create", &nvte_tensor_pack_create); + m.def("nvte_tensor_pack_destroy", &nvte_tensor_pack_destroy); + + m.def("nvte_cast_transpose", &nvte_cast_transpose); + m.def("nvte_transpose", &nvte_transpose); + m.def("nvte_cast_transpose_dbias", &nvte_cast_transpose_dbias); + m.def("nvte_fp8_transpose_dbias", &nvte_fp8_transpose_dbias); + m.def("nvte_cast_transpose_dbias_dgelu", &nvte_cast_transpose_dbias_dgelu); + m.def("nvte_multi_cast_transpose", &nvte_multi_cast_transpose); + m.def("nvte_dgeglu_cast_transpose", &nvte_dgeglu_cast_transpose); + + py::enum_(m, "NVTEDType") + .value("kNVTEByte", kNVTEByte) + .value("kNVTEInt32", kNVTEInt32) + .value("kNVTEInt64", kNVTEInt64) + .value("kNVTEFloat32", kNVTEFloat32) + .value("kNVTEFloat16", kNVTEFloat16) + .value("kNVTEBFloat16", kNVTEBFloat16) + .value("kNVTEFloat8E4M3", kNVTEFloat8E4M3) + .value("kNVTEFloat8E5M2", kNVTEFloat8E5M2); + + py::enum_(m, "NVTE_Fused_Attn_Backend") + .value("NVTE_No_Backend", NVTE_No_Backend) + .value("NVTE_F16_max512_seqlen", NVTE_F16_max512_seqlen) + .value("NVTE_F16_arbitrary_seqlen", NVTE_F16_arbitrary_seqlen) + .value("NVTE_FP8", NVTE_FP8); + + py::enum_(m, "NVTE_QKV_Layout") + .value("NVTE_NOT_INTERLEAVED", NVTE_NOT_INTERLEAVED) + .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_INTERLEAVED) + .value("NVTE_KV_INTERLEAVED", NVTE_KV_INTERLEAVED); + + py::enum_(m, "NVTE_Bias_Type") + .value("NVTE_NO_BIAS", NVTE_NO_BIAS) + .value("NVTE_PRE_SCALE_BIAS", NVTE_PRE_SCALE_BIAS) + .value("NVTE_POST_SCALE_BIAS", NVTE_POST_SCALE_BIAS); + + py::enum_(m, "NVTE_Mask_Type") + .value("NVTE_NO_MASK", NVTE_NO_MASK) + .value("NVTE_PADDING_MASK", NVTE_PADDING_MASK) + .value("NVTE_CAUSAL_MASK", NVTE_CAUSAL_MASK); + + py::class_(m, "NVTEShape") + .def(py::init<>()) + .def_readwrite("data", &NVTEShape::data) + .def_readwrite("ndim", &NVTEShape::ndim); + + py::class_(m, "NVTETensor") + .def( + py::init()) +} From ee2cfa172914ec33e50435f26112d890b3a7c5ed Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 15:45:04 +0200 Subject: [PATCH 009/535] add wrapper Signed-off-by: Jan Bielak --- .../sequential/cpp_extensions/pybind.cpp | 96 +++++++++---------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp index c7606e3bf3..e910f11f48 100644 --- a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp +++ b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp @@ -105,58 +105,58 @@ constexpr auto wrap(Ret(func)(Args &&..., LastArg &&)) noexcept { } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("nvte_gelu", &nvte_gelu); - m.def("nvte_dgelu", &nvte_dgelu); - m.def("nvte_geglu", &nvte_geglu); - m.def("nvte_dgeglu", &nvte_dgeglu); - m.def("nvte_relu", &nvte_relu); - m.def("nvte_drelu", &nvte_drelu); - m.def("nvte_swiglu", &nvte_swiglu); - m.def("nvte_dswiglu", &nvte_dswiglu); - m.def("nvte_reglu", &nvte_reglu); - m.def("nvte_dreglu", &nvte_dreglu); - m.def("nvte_fp8_quantize", &nvte_fp8_quantize); - m.def("nvte_fp8_dequantize", &nvte_fp8_dequantize); - m.def("nvte_get_fused_attn_backend", &nvte_get_fused_attn_backend); - m.def("nvte_fused_attn_fwd_qkvpacked", &nvte_fused_attn_fwd_qkvpacked); - m.def("nvte_fused_attn_bwd_qkvpacked", &nvte_fused_attn_bwd_qkvpacked); - m.def("nvte_fused_attn_fwd_kvpacked", &nvte_fused_attn_fwd_kvpacked); - m.def("nvte_fused_attn_bwd_kvpacked", &nvte_fused_attn_bwd_kvpacked); - m.def("nvte_cublas_gemm", &nvte_cublas_gemm); - m.def("nvte_layernorm_fwd", &nvte_layernorm_fwd); - m.def("nvte_layernorm1p_fwd", &nvte_layernorm1p_fwd); - m.def("nvte_layernorm_bwd", &nvte_layernorm_bwd); - m.def("nvte_layernorm1p_bwd", &nvte_layernorm1p_bwd); - m.def("nvte_rmsnorm_fwd", &nvte_rmsnorm_fwd); - m.def("nvte_rmsnorm_bwd", &nvte_rmsnorm_bwd); - m.def("nvte_scaled_softmax_forward", &nvte_scaled_softmax_forward); - m.def("nvte_scaled_softmax_backward", &nvte_scaled_softmax_backward); + m.def("nvte_gelu", wrap(nvte_gelu)); + m.def("nvte_dgelu", wrap(nvte_dgelu)); + m.def("nvte_geglu", wrap(nvte_geglu)); + m.def("nvte_dgeglu", wrap(nvte_dgeglu)); + m.def("nvte_relu", wrap(nvte_relu)); + m.def("nvte_drelu", wrap(nvte_drelu)); + m.def("nvte_swiglu", wrap(nvte_swiglu)); + m.def("nvte_dswiglu", wrap(nvte_dswiglu)); + m.def("nvte_reglu", wrap(nvte_reglu)); + m.def("nvte_dreglu", wrap(nvte_dreglu)); + m.def("nvte_fp8_quantize", wrap(nvte_fp8_quantize)); + m.def("nvte_fp8_dequantize", wrap(nvte_fp8_dequantize)); + m.def("nvte_get_fused_attn_backend", wrap(nvte_get_fused_attn_backend)); + m.def("nvte_fused_attn_fwd_qkvpacked", wrap(nvte_fused_attn_fwd_qkvpacked)); + m.def("nvte_fused_attn_bwd_qkvpacked", wrap(nvte_fused_attn_bwd_qkvpacked)); + m.def("nvte_fused_attn_fwd_kvpacked", wrap(nvte_fused_attn_fwd_kvpacked)); + m.def("nvte_fused_attn_bwd_kvpacked", wrap(nvte_fused_attn_bwd_kvpacked)); + m.def("nvte_cublas_gemm", wrap(nvte_cublas_gemm)); + m.def("nvte_layernorm_fwd", wrap(nvte_layernorm_fwd)); + m.def("nvte_layernorm1p_fwd", wrap(nvte_layernorm1p_fwd)); + m.def("nvte_layernorm_bwd", wrap(nvte_layernorm_bwd)); + m.def("nvte_layernorm1p_bwd", wrap(nvte_layernorm1p_bwd)); + m.def("nvte_rmsnorm_fwd", wrap(nvte_rmsnorm_fwd)); + m.def("nvte_rmsnorm_bwd", wrap(nvte_rmsnorm_bwd)); + m.def("nvte_scaled_softmax_forward", wrap(nvte_scaled_softmax_forward)); + m.def("nvte_scaled_softmax_backward", wrap(nvte_scaled_softmax_backward)); m.def("nvte_scaled_masked_softmax_forward", - &nvte_scaled_masked_softmax_forward); + wrap(nvte_scaled_masked_softmax_forward)); m.def("nvte_scaled_masked_softmax_backward", - &nvte_scaled_masked_softmax_backward); + wrap(nvte_scaled_masked_softmax_backward)); m.def("nvte_scaled_upper_triang_masked_softmax_forward", - &nvte_scaled_upper_triang_masked_softmax_forward); + wrap(nvte_scaled_upper_triang_masked_softmax_forward)); m.def("nvte_scaled_upper_triang_masked_softmax_backward", - &nvte_scaled_upper_triang_masked_softmax_backward); - m.def("nvte_create_tensor", &nvte_create_tensor); - m.def("nvte_destroy_tensor", &nvte_destroy_tensor); - m.def("nvte_tensor_type", &nvte_tensor_type); - m.def("nvte_tensor_shape", &nvte_tensor_shape); - m.def("nvte_tensor_data", &nvte_tensor_data); - m.def("nvte_tensor_amax", &nvte_tensor_amax); - m.def("nvte_tensor_scale", &nvte_tensor_scale); - m.def("nvte_tensor_scale_inv", &nvte_tensor_scale_inv); - m.def("nvte_tensor_pack_create", &nvte_tensor_pack_create); - m.def("nvte_tensor_pack_destroy", &nvte_tensor_pack_destroy); - - m.def("nvte_cast_transpose", &nvte_cast_transpose); - m.def("nvte_transpose", &nvte_transpose); - m.def("nvte_cast_transpose_dbias", &nvte_cast_transpose_dbias); - m.def("nvte_fp8_transpose_dbias", &nvte_fp8_transpose_dbias); - m.def("nvte_cast_transpose_dbias_dgelu", &nvte_cast_transpose_dbias_dgelu); - m.def("nvte_multi_cast_transpose", &nvte_multi_cast_transpose); - m.def("nvte_dgeglu_cast_transpose", &nvte_dgeglu_cast_transpose); + wrap(nvte_scaled_upper_triang_masked_softmax_backward)); + m.def("nvte_create_tensor", wrap(nvte_create_tensor)); + m.def("nvte_destroy_tensor", wrap(nvte_destroy_tensor)); + m.def("nvte_tensor_type", wrap(nvte_tensor_type)); + m.def("nvte_tensor_shape", wrap(nvte_tensor_shape)); + m.def("nvte_tensor_data", wrap(nvte_tensor_data)); + m.def("nvte_tensor_amax", wrap(nvte_tensor_amax)); + m.def("nvte_tensor_scale", wrap(nvte_tensor_scale)); + m.def("nvte_tensor_scale_inv", wrap(nvte_tensor_scale_inv)); + m.def("nvte_tensor_pack_create", wrap(nvte_tensor_pack_create)); + m.def("nvte_tensor_pack_destroy", wrap(nvte_tensor_pack_destroy)); + m.def("nvte_cast_transpose", wrap(nvte_cast_transpose)); + m.def("nvte_transpose", wrap(nvte_transpose)); + m.def("nvte_cast_transpose_dbias", wrap(nvte_cast_transpose_dbias)); + m.def("nvte_fp8_transpose_dbias", wrap(nvte_fp8_transpose_dbias)); + m.def("nvte_cast_transpose_dbias_dgelu", + wrap(nvte_cast_transpose_dbias_dgelu)); + m.def("nvte_multi_cast_transpose", wrap(nvte_multi_cast_transpose)); + m.def("nvte_dgeglu_cast_transpose", wrap(nvte_dgeglu_cast_transpose)); py::enum_(m, "NVTEDType") .value("kNVTEByte", kNVTEByte) From eedf751f0baa578c8aeedc57d5603a9365c39282 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 15:46:35 +0200 Subject: [PATCH 010/535] add missing conversion Signed-off-by: Jan Bielak --- transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp index e910f11f48..15e2173fc2 100644 --- a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp +++ b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp @@ -66,6 +66,7 @@ struct TensorPack : NVTETensorPack { } nvte_tensor_pack_create(this); } + operator NVTETensorPack *() { return this; } ~TensorPack() { nvte_tensor_pack_destroy(this); } }; From 3b4bb4f405031c18cf57671719f0c0689d77345d Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 17:29:05 +0200 Subject: [PATCH 011/535] fix pybind Signed-off-by: Jan Bielak --- .../sequential/cpp_extensions/pybind.cpp | 47 ++++---- .../sequential/cpp_extensions/type_list.h | 102 ++++++++++++++++++ 2 files changed, 131 insertions(+), 18 deletions(-) create mode 100644 transformer_engine/pytorch/sequential/cpp_extensions/type_list.h diff --git a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp index 15e2173fc2..bebfa01a38 100644 --- a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp +++ b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp @@ -5,6 +5,7 @@ * See LICENSE for license information. ************************************************************************/ +#include "type_list.h" #include #include #include @@ -33,14 +34,15 @@ #include #include #include + namespace py = pybind11; struct Tensor { NVTETensor impl; - static void *getDataPtr(at::Tensor t) { + static float *getDataPtr(at::Tensor t) { if (t.numel() > 0) { - return t.data_ptr(); + return reinterpret_cast(t.data_ptr()); } else { return nullptr; } @@ -48,7 +50,7 @@ struct Tensor { Tensor(NVTEDType dtype, at::Tensor data, at::Tensor amax, at::Tensor scale, at::Tensor scale_inv) { - NVTEShape shape{data.sizes().data(), data.sizes().size()}; + NVTEShape shape{(static_cast(data.sizes().data()), data.sizes().size()}; impl = nvte_create_tensor(getDataPtr(data), shape, dtype, getDataPtr(amax), getDataPtr(scale), getDataPtr(scale_inv)); } @@ -56,7 +58,7 @@ struct Tensor { }; struct TensorPack : NVTETensorPack { - TensorPack(const std::vector &tensors_) : tensors{}, size{} { + TensorPack(const std::vector &tensors_) : NVTETensorPack{} { size = tensors_.size(); if (size > MAX_SIZE) { throw std::runtime_error("TensorPack size exceeds MAX_SIZE"); @@ -75,13 +77,13 @@ template struct trait { }; template struct wrapped_arg : trait {}; -struct wrapped_arg : trait {}; -struct wrapped_arg : trait> {}; +template <> struct wrapped_arg : trait {}; +template <> struct wrapped_arg : trait> {}; template using wrapped_arg_t = typename wrapped_arg::type; template decltype(auto) unwrap_arg(T &&arg) { - if constexpr (std::is_same_v < std::decay_t, wrapped_arg_t) { + if constexpr (std::is_same_v, wrapped_arg_t>) { return arg.impl; } else if constexpr (std::is_same_v, wrapped_arg_t>) { @@ -91,16 +93,25 @@ template decltype(auto) unwrap_arg(T &&arg) { } } -template -constexpr auto wrap(Ret(func)(Args &&..., LastArg &&)) noexcept { - if constexpr (std::is_same_v, cudaStream_t>) { - return [func](wrapped_arg_t... args) -> Ret { - return func(unwrap_arg(args)..., at::cuda::getCurrentCUDAStream()); - }; +template +constexpr auto _wrap_no_last(Ret(func)(Args...), type_list, + LastGetterT last_func) noexcept { + return [func, last_func](wrapped_arg_t... args) -> Ret { + return func(unwrap_arg(args)..., last_func()); + }; +} + +template +constexpr auto wrap(Ret(func)(Args...)) noexcept { + using LastArg = typename type_list::back_t; + if constexpr (std::is_same_v) { + using stripped = typename type_list::template pop_back<>; + return _wrap_no_last<>(func, stripped(), + []() { return at::cuda::getCurrentCUDAStream(); }); } else { - return [func](wrapped_arg_t... args, - wrapped_arg_t last_arg) -> Ret { - return func(unwrap_arg(args)..., unwrap_arg(last_arg)); + return [func](wrapped_arg_t... args) -> Ret { + return func(unwrap_arg(args)...); }; } } @@ -196,6 +207,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def_readwrite("ndim", &NVTEShape::ndim); py::class_(m, "NVTETensor") - .def( - py::init()) + .def(py::init()); } diff --git a/transformer_engine/pytorch/sequential/cpp_extensions/type_list.h b/transformer_engine/pytorch/sequential/cpp_extensions/type_list.h new file mode 100644 index 0000000000..5023f1dae4 --- /dev/null +++ b/transformer_engine/pytorch/sequential/cpp_extensions/type_list.h @@ -0,0 +1,102 @@ +#include +#include + +template struct type_list; + +template struct type_list_front; +template struct type_list_back; +template struct type_list_reverse_list; +template struct type_list_index; +template struct type_list_cat_list; +template struct type_list_pop_front_list; +template struct type_list_pop_back_list; + +template +struct type_list_front> { + using type = First; +}; + +template +struct type_list_pop_front_list, 0> { + using type = type_list; +}; + +template +struct type_list_pop_front_list, N> { + using type = typename type_list_pop_front_list, N - 1>::type; +}; + +template +struct type_list_index, I> { +private: + using stripped = typename type_list_pop_front_list, I>::type; + +public: + using type = typename type_list_front::type; +}; + +template +struct type_list_cat_list, type_list> { + using type = type_list; +}; + +template +struct type_list_reverse_list> { +private: + using ts_reversed = typename type_list_reverse_list>::type; + using back_list = type_list; + +public: + using type = typename type_list_cat_list::type; +}; +template <> struct type_list_reverse_list> { + using type = type_list<>; +}; + +template struct type_list_back> { +private: + using reversed = typename type_list_reverse_list>::type; + +public: + using type = typename type_list_front::type; +}; + +template +struct type_list_pop_back_list, N> { +private: + using reversed = typename type_list_reverse_list>::type; + using stripped = typename type_list_pop_front_list::type; + +public: + using type = typename type_list_reverse_list::type; +}; + +template +using type_list_front_t = typename type_list_front::type; +template +using type_list_back_t = typename type_list_back::type; +template +using type_list_reverse_list_t = typename type_list_reverse_list::type; +template +using type_list_index_t = typename type_list_index::type; +template +using type_list_cat_list_t = typename type_list_cat_list::type; +template +using type_list_pop_front_list_t = + typename type_list_pop_front_list::type; +template +using type_list_pop_back_list_t = typename type_list_pop_back_list::type; + +template struct type_list { + using front = type_list>; + using front_t = type_list_index_t; + + using back = type_list>; + using back_t = type_list_index_t; + + template + using pop_front = type_list_pop_front_list_t; + + template + using pop_back = type_list_pop_back_list_t; +}; From 49b3e1394225e8658e3a369c502ef8b785a3946b Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 17:39:07 +0200 Subject: [PATCH 012/535] fix missing brace Signed-off-by: Jan Bielak --- .../pytorch/sequential/cpp_extensions/pybind.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp index bebfa01a38..469cec755b 100644 --- a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp +++ b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp @@ -5,7 +5,6 @@ * See LICENSE for license information. ************************************************************************/ -#include "type_list.h" #include #include #include @@ -35,6 +34,8 @@ #include #include +#include "type_list.h" + namespace py = pybind11; struct Tensor { @@ -50,7 +51,8 @@ struct Tensor { Tensor(NVTEDType dtype, at::Tensor data, at::Tensor amax, at::Tensor scale, at::Tensor scale_inv) { - NVTEShape shape{(static_cast(data.sizes().data()), data.sizes().size()}; + NVTEShape shape{static_cast(data.sizes().data()), + data.sizes().size()}; impl = nvte_create_tensor(getDataPtr(data), shape, dtype, getDataPtr(amax), getDataPtr(scale), getDataPtr(scale_inv)); } From 11464306fb8db525e5539bfc2990c364385efd98 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 7 Aug 2023 19:45:33 +0200 Subject: [PATCH 013/535] fix templates Signed-off-by: Jan Bielak --- .../sequential/cpp_extensions/pybind.cpp | 31 +++---- .../sequential/cpp_extensions/type_list.h | 80 ++++++++++++++++++- 2 files changed, 96 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp index 469cec755b..743ab75571 100644 --- a/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp +++ b/transformer_engine/pytorch/sequential/cpp_extensions/pybind.cpp @@ -51,8 +51,7 @@ struct Tensor { Tensor(NVTEDType dtype, at::Tensor data, at::Tensor amax, at::Tensor scale, at::Tensor scale_inv) { - NVTEShape shape{static_cast(data.sizes().data()), - data.sizes().size()}; + NVTEShape shape{(size_t *)(data.sizes().data()), data.sizes().size()}; impl = nvte_create_tensor(getDataPtr(data), shape, dtype, getDataPtr(amax), getDataPtr(scale), getDataPtr(scale_inv)); } @@ -95,22 +94,26 @@ template decltype(auto) unwrap_arg(T &&arg) { } } -template -constexpr auto _wrap_no_last(Ret(func)(Args...), type_list, - LastGetterT last_func) noexcept { - return [func, last_func](wrapped_arg_t... args) -> Ret { - return func(unwrap_arg(args)..., last_func()); +template +constexpr auto +remove_cuda_stream_arg_helper(Ret(func)(Args...), type_list, + type_list) noexcept { + return [func](wrapped_arg_t... prefixArgs, + wrapped_arg_t... suffixArgs) -> Ret { + return func(unwrap_arg(prefixArgs)..., at::cuda::getCurrentCUDAStream(), + unwrap_arg(suffixArgs)...); }; } template constexpr auto wrap(Ret(func)(Args...)) noexcept { - using LastArg = typename type_list::back_t; - if constexpr (std::is_same_v) { - using stripped = typename type_list::template pop_back<>; - return _wrap_no_last<>(func, stripped(), - []() { return at::cuda::getCurrentCUDAStream(); }); + using tl = type_list; + if constexpr (tl::template contains) { + constexpr size_t stream_arg_idx = tl::template find; + using prefix = tl::template pop_back; + using suffix = tl::template pop_front; + return remove_cuda_stream_arg_helper(func, prefix(), suffix()); } else { return [func](wrapped_arg_t... args) -> Ret { return func(unwrap_arg(args)...); @@ -138,7 +141,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nvte_fused_attn_bwd_kvpacked", wrap(nvte_fused_attn_bwd_kvpacked)); m.def("nvte_cublas_gemm", wrap(nvte_cublas_gemm)); m.def("nvte_layernorm_fwd", wrap(nvte_layernorm_fwd)); - m.def("nvte_layernorm1p_fwd", wrap(nvte_layernorm1p_fwd)); + m.def("nvte_layernorm1p_fwd", wrap()); m.def("nvte_layernorm_bwd", wrap(nvte_layernorm_bwd)); m.def("nvte_layernorm1p_bwd", wrap(nvte_layernorm1p_bwd)); m.def("nvte_rmsnorm_fwd", wrap(nvte_rmsnorm_fwd)); diff --git a/transformer_engine/pytorch/sequential/cpp_extensions/type_list.h b/transformer_engine/pytorch/sequential/cpp_extensions/type_list.h index 5023f1dae4..7b5459761d 100644 --- a/transformer_engine/pytorch/sequential/cpp_extensions/type_list.h +++ b/transformer_engine/pytorch/sequential/cpp_extensions/type_list.h @@ -1,4 +1,5 @@ #include +#include #include template struct type_list; @@ -10,6 +11,11 @@ template struct type_list_index; template struct type_list_cat_list; template struct type_list_pop_front_list; template struct type_list_pop_back_list; +template struct type_list_contains; +template typename Pred> struct type_list_any; +template struct type_list_find; +template typename Pred> +struct type_list_first; template struct type_list_front> { @@ -20,7 +26,9 @@ template struct type_list_pop_front_list, 0> { using type = type_list; }; - +template <> struct type_list_pop_front_list, 0> { + using type = type_list<>; +}; template struct type_list_pop_front_list, N> { using type = typename type_list_pop_front_list, N - 1>::type; @@ -71,6 +79,49 @@ struct type_list_pop_back_list, N> { using type = typename type_list_reverse_list::type; }; +template typename Pred> +struct type_list_any, Pred> { + static constexpr bool value = (Pred::value || ...); +}; + +template typename Pred> +struct type_list_first, Pred> { +private: + static constexpr bool values[] = {Pred::value...}; + +public: + static constexpr size_t value = []() { + for (size_t i = 0; i < sizeof(values) / sizeof(bool); ++i) { + if (values[i]) { + return i; + } + } + return sizeof(values) / sizeof(bool); + }(); +}; + +template +struct type_list_contains, T> { +private: + template struct pred { + static constexpr bool value = std::is_same_v; + }; + +public: + static constexpr bool value = type_list_any, pred>::value; +}; + +template +struct type_list_find, T> { + template struct pred { + static constexpr bool value = std::is_same_v; + }; + +public: + static constexpr size_t value = + type_list_first, pred>::value; +}; + template using type_list_front_t = typename type_list_front::type; template @@ -86,6 +137,14 @@ using type_list_pop_front_list_t = typename type_list_pop_front_list::type; template using type_list_pop_back_list_t = typename type_list_pop_back_list::type; +template +constexpr bool type_list_contains_v = type_list_contains::value; +template typename Pred> +constexpr bool type_list_any_v = type_list_any::value; +template +constexpr size_t type_list_find_v = type_list_find::value; +template typename Pred> +constexpr size_t type_list_first_v = type_list_first::value; template struct type_list { using front = type_list>; @@ -94,9 +153,28 @@ template struct type_list { using back = type_list>; using back_t = type_list_index_t; + using reverse = type_list_reverse_list_t; + + template using get = type_list_index_t; + template using pop_front = type_list_pop_front_list_t; template using pop_back = type_list_pop_back_list_t; + + template + static constexpr bool contains = type_list_contains_v; + + template