From 16f064e4cdcc4aaca494b20482f96a4fc74da68a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 19 Mar 2023 12:06:55 -0400 Subject: [PATCH 001/178] Builds and can start debugging again. --- CMakeLists.txt | 287 +++++++++++++++++++------------------ csrc/compute_at_map.cpp | 4 +- csrc/disjoint_set.h | 191 +++++++++++++++++------- csrc/ir_utils.cpp | 30 ++-- csrc/ir_utils.h | 5 + csrc/lower2device.cpp | 54 ++++--- csrc/transform_iter.cpp | 282 ++++++++++++++++++++---------------- csrc/transform_iter.h | 206 +++++++++++++++++++++----- csrc/type.cpp | 8 +- csrc/type.h | 4 +- csrc/utils.cpp | 1 + csrc/utils.h | 1 + test/test_gpu3.cpp | 35 ----- test/test_gpu_indexing.cpp | 270 +++++++++++++++++++++++++++++++--- 14 files changed, 936 insertions(+), 442 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6325ac765c6..2001d097f6f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,142 +15,144 @@ else() endif() # --- project - file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/nvfuser") set(NVFUSER_ROOT ${PROJECT_SOURCE_DIR}) set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc") + # TODO: have TORCH_ROOT setup as a variable instead # currently we are expecting nvfuser to be added from the pytorch root cmake file. set(TORCH_ROOT "${CMAKE_SOURCE_DIR}") set(TORCH_INSTALL_LIB_DIR ${TORCH_ROOT}/torch/lib) # --- build nvfuser_codegen library - set(NVFUSER_SRCS) set(NVFUSER_CODEGEN ${PROJECT_NAME}_codegen) list(APPEND NVFUSER_SRCS - ${NVFUSER_SRCS_DIR}/compute_at.cpp - ${NVFUSER_SRCS_DIR}/inlining.cpp - ${NVFUSER_SRCS_DIR}/compute_at_map.cpp - ${NVFUSER_SRCS_DIR}/codegen.cpp - ${NVFUSER_SRCS_DIR}/contiguity.cpp - ${NVFUSER_SRCS_DIR}/dispatch.cpp - ${NVFUSER_SRCS_DIR}/expr_evaluator.cpp - ${NVFUSER_SRCS_DIR}/expr_simplifier.cpp - ${NVFUSER_SRCS_DIR}/executor.cpp - ${NVFUSER_SRCS_DIR}/executor_kernel_arg.cpp - ${NVFUSER_SRCS_DIR}/executor_params.cpp - ${NVFUSER_SRCS_DIR}/evaluator_common.cpp - ${NVFUSER_SRCS_DIR}/executor_utils.cpp - ${NVFUSER_SRCS_DIR}/fusion.cpp - ${NVFUSER_SRCS_DIR}/graph_fuser.cpp - ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp - ${NVFUSER_SRCS_DIR}/index_compute.cpp - ${NVFUSER_SRCS_DIR}/lower_index_compute.cpp - ${NVFUSER_SRCS_DIR}/instrumentation.cpp - ${NVFUSER_SRCS_DIR}/ir_base_nodes.cpp - ${NVFUSER_SRCS_DIR}/ir_builder.cpp - ${NVFUSER_SRCS_DIR}/ir_cloner.cpp - ${NVFUSER_SRCS_DIR}/ir_container.cpp - ${NVFUSER_SRCS_DIR}/ir_graphviz.cpp - ${NVFUSER_SRCS_DIR}/ir_nodes.cpp - ${NVFUSER_SRCS_DIR}/ir_iostream.cpp - ${NVFUSER_SRCS_DIR}/ir_utils.cpp - ${NVFUSER_SRCS_DIR}/iter_visitor.cpp - ${NVFUSER_SRCS_DIR}/kernel.cpp - ${NVFUSER_SRCS_DIR}/kernel_cache.cpp - ${NVFUSER_SRCS_DIR}/kernel_db/kernel_db.cpp - ${NVFUSER_SRCS_DIR}/kernel_db/utils.cpp - ${NVFUSER_SRCS_DIR}/kernel_ir.cpp - ${NVFUSER_SRCS_DIR}/kernel_ir_dispatch.cpp - ${NVFUSER_SRCS_DIR}/lower_alias_memory.cpp - ${NVFUSER_SRCS_DIR}/lower_allocation.cpp - ${NVFUSER_SRCS_DIR}/lower_double_buffer.cpp - ${NVFUSER_SRCS_DIR}/lower_divisible_split.cpp - ${NVFUSER_SRCS_DIR}/lower_expr_sort.cpp - ${NVFUSER_SRCS_DIR}/lower_fused_reduction.cpp - ${NVFUSER_SRCS_DIR}/lower_fusion_simplifier.cpp - ${NVFUSER_SRCS_DIR}/lower_index.cpp - ${NVFUSER_SRCS_DIR}/lower_scalar_hoist.cpp - ${NVFUSER_SRCS_DIR}/lower_insert_syncs.cpp - ${NVFUSER_SRCS_DIR}/lower_instrument.cpp - ${NVFUSER_SRCS_DIR}/lower_loop_rotation.cpp - ${NVFUSER_SRCS_DIR}/lower_loops.cpp - ${NVFUSER_SRCS_DIR}/lower_magic_zero.cpp - ${NVFUSER_SRCS_DIR}/lower_misaligned_vectorization.cpp - ${NVFUSER_SRCS_DIR}/lower_predicate.cpp - ${NVFUSER_SRCS_DIR}/lower_predicate_elimination.cpp - ${NVFUSER_SRCS_DIR}/lower_replace_size.cpp - ${NVFUSER_SRCS_DIR}/lower_shift.cpp - ${NVFUSER_SRCS_DIR}/lower_sync_information.cpp - ${NVFUSER_SRCS_DIR}/lower_thread_predicate.cpp - ${NVFUSER_SRCS_DIR}/lower_trivial_broadcast.cpp - ${NVFUSER_SRCS_DIR}/lower_unroll.cpp - ${NVFUSER_SRCS_DIR}/lower_utils.cpp - ${NVFUSER_SRCS_DIR}/lower_validation.cpp - ${NVFUSER_SRCS_DIR}/lower_vectorize_welford.cpp - ${NVFUSER_SRCS_DIR}/lower_warp_reduce.cpp - ${NVFUSER_SRCS_DIR}/lower2device.cpp - ${NVFUSER_SRCS_DIR}/lower_bank_conflict.cpp - ${NVFUSER_SRCS_DIR}/manager.cpp - ${NVFUSER_SRCS_DIR}/maxinfo_propagator.cpp - ${NVFUSER_SRCS_DIR}/multidevice/aggregate_dag.cpp - ${NVFUSER_SRCS_DIR}/multidevice/multidevice_runtime.cpp - ${NVFUSER_SRCS_DIR}/multidevice/multicluster_fusion.cpp - ${NVFUSER_SRCS_DIR}/multidevice/ProcessGroupBuilder.cpp - ${NVFUSER_SRCS_DIR}/mutator.cpp - ${NVFUSER_SRCS_DIR}/non_divisible_split.cpp - ${NVFUSER_SRCS_DIR}/ops/alias.cpp - ${NVFUSER_SRCS_DIR}/ops/arith.cpp - ${NVFUSER_SRCS_DIR}/ops/composite.cpp - ${NVFUSER_SRCS_DIR}/ops/normalization.cpp - ${NVFUSER_SRCS_DIR}/ops/utils.cpp - ${NVFUSER_SRCS_DIR}/parallel_dimension_map.cpp - ${NVFUSER_SRCS_DIR}/parallel_type_bitmap.cpp - ${NVFUSER_SRCS_DIR}/parser.cpp - ${NVFUSER_SRCS_DIR}/partial_split_map.cpp - ${NVFUSER_SRCS_DIR}/partition.cpp - ${NVFUSER_SRCS_DIR}/predicate_compute.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_state.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp - ${NVFUSER_SRCS_DIR}/register_interface.cpp - ${NVFUSER_SRCS_DIR}/root_domain_map.cpp - ${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp - ${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp - ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp - ${NVFUSER_SRCS_DIR}/scheduler/normalization.cpp - ${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp - ${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp - ${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp - ${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp - ${NVFUSER_SRCS_DIR}/scheduler/registry.cpp - ${NVFUSER_SRCS_DIR}/scheduler/utils.cpp - ${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp - ${NVFUSER_SRCS_DIR}/swizzle.cpp - ${NVFUSER_SRCS_DIR}/sys_utils.cpp - ${NVFUSER_SRCS_DIR}/type_inference.cpp - ${NVFUSER_SRCS_DIR}/type_promotion.cpp - ${NVFUSER_SRCS_DIR}/fusion_segmenter.cpp - ${NVFUSER_SRCS_DIR}/tensor_view.cpp - ${NVFUSER_SRCS_DIR}/transform_iter.cpp - ${NVFUSER_SRCS_DIR}/transform_replay.cpp - ${NVFUSER_SRCS_DIR}/transform_rfactor.cpp - ${NVFUSER_SRCS_DIR}/transform_view.cpp - ${NVFUSER_SRCS_DIR}/type.cpp - ${NVFUSER_SRCS_DIR}/utils.cpp - ${NVFUSER_SRCS_DIR}/mma_type.cpp - ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp + ${NVFUSER_SRCS_DIR}/compute_at.cpp + ${NVFUSER_SRCS_DIR}/inlining.cpp + ${NVFUSER_SRCS_DIR}/compute_at_map.cpp + ${NVFUSER_SRCS_DIR}/codegen.cpp + ${NVFUSER_SRCS_DIR}/contiguity.cpp + ${NVFUSER_SRCS_DIR}/dispatch.cpp + ${NVFUSER_SRCS_DIR}/expr_evaluator.cpp + ${NVFUSER_SRCS_DIR}/expr_simplifier.cpp + ${NVFUSER_SRCS_DIR}/executor.cpp + ${NVFUSER_SRCS_DIR}/executor_kernel_arg.cpp + ${NVFUSER_SRCS_DIR}/executor_params.cpp + ${NVFUSER_SRCS_DIR}/evaluator_common.cpp + ${NVFUSER_SRCS_DIR}/executor_utils.cpp + ${NVFUSER_SRCS_DIR}/fusion.cpp + ${NVFUSER_SRCS_DIR}/graph_fuser.cpp + ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp + ${NVFUSER_SRCS_DIR}/index_compute.cpp + ${NVFUSER_SRCS_DIR}/lower_index_compute.cpp + ${NVFUSER_SRCS_DIR}/instrumentation.cpp + ${NVFUSER_SRCS_DIR}/id_graphs.cpp + ${NVFUSER_SRCS_DIR}/ir_base_nodes.cpp + ${NVFUSER_SRCS_DIR}/ir_builder.cpp + ${NVFUSER_SRCS_DIR}/ir_cloner.cpp + ${NVFUSER_SRCS_DIR}/ir_container.cpp + ${NVFUSER_SRCS_DIR}/ir_graphviz.cpp + ${NVFUSER_SRCS_DIR}/ir_nodes.cpp + ${NVFUSER_SRCS_DIR}/ir_iostream.cpp + ${NVFUSER_SRCS_DIR}/ir_utils.cpp + ${NVFUSER_SRCS_DIR}/iter_visitor.cpp + ${NVFUSER_SRCS_DIR}/kernel.cpp + ${NVFUSER_SRCS_DIR}/kernel_cache.cpp + ${NVFUSER_SRCS_DIR}/kernel_db/kernel_db.cpp + ${NVFUSER_SRCS_DIR}/kernel_db/utils.cpp + ${NVFUSER_SRCS_DIR}/kernel_ir.cpp + ${NVFUSER_SRCS_DIR}/kernel_ir_dispatch.cpp + ${NVFUSER_SRCS_DIR}/lower_alias_memory.cpp + ${NVFUSER_SRCS_DIR}/lower_allocation.cpp + ${NVFUSER_SRCS_DIR}/lower_double_buffer.cpp + ${NVFUSER_SRCS_DIR}/lower_divisible_split.cpp + ${NVFUSER_SRCS_DIR}/lower_expr_sort.cpp + ${NVFUSER_SRCS_DIR}/lower_fused_reduction.cpp + ${NVFUSER_SRCS_DIR}/lower_fusion_simplifier.cpp + ${NVFUSER_SRCS_DIR}/lower_index.cpp + ${NVFUSER_SRCS_DIR}/lower_scalar_hoist.cpp + ${NVFUSER_SRCS_DIR}/lower_insert_syncs.cpp + ${NVFUSER_SRCS_DIR}/lower_instrument.cpp + ${NVFUSER_SRCS_DIR}/lower_loop_rotation.cpp + ${NVFUSER_SRCS_DIR}/lower_loops.cpp + ${NVFUSER_SRCS_DIR}/lower_magic_zero.cpp + ${NVFUSER_SRCS_DIR}/lower_misaligned_vectorization.cpp + ${NVFUSER_SRCS_DIR}/lower_predicate.cpp + ${NVFUSER_SRCS_DIR}/lower_predicate_elimination.cpp + ${NVFUSER_SRCS_DIR}/lower_replace_size.cpp + ${NVFUSER_SRCS_DIR}/lower_shift.cpp + ${NVFUSER_SRCS_DIR}/lower_sync_information.cpp + ${NVFUSER_SRCS_DIR}/lower_thread_predicate.cpp + ${NVFUSER_SRCS_DIR}/lower_trivial_broadcast.cpp + ${NVFUSER_SRCS_DIR}/lower_unroll.cpp + ${NVFUSER_SRCS_DIR}/lower_utils.cpp + ${NVFUSER_SRCS_DIR}/lower_validation.cpp + ${NVFUSER_SRCS_DIR}/lower_vectorize_welford.cpp + ${NVFUSER_SRCS_DIR}/lower_warp_reduce.cpp + ${NVFUSER_SRCS_DIR}/lower2device.cpp + ${NVFUSER_SRCS_DIR}/lower_bank_conflict.cpp + ${NVFUSER_SRCS_DIR}/manager.cpp + ${NVFUSER_SRCS_DIR}/maxinfo_propagator.cpp + ${NVFUSER_SRCS_DIR}/multidevice/aggregate_dag.cpp + ${NVFUSER_SRCS_DIR}/multidevice/multidevice_runtime.cpp + ${NVFUSER_SRCS_DIR}/multidevice/multicluster_fusion.cpp + ${NVFUSER_SRCS_DIR}/multidevice/ProcessGroupBuilder.cpp + ${NVFUSER_SRCS_DIR}/mutator.cpp + ${NVFUSER_SRCS_DIR}/non_divisible_split.cpp + ${NVFUSER_SRCS_DIR}/ops/alias.cpp + ${NVFUSER_SRCS_DIR}/ops/arith.cpp + ${NVFUSER_SRCS_DIR}/ops/composite.cpp + ${NVFUSER_SRCS_DIR}/ops/normalization.cpp + ${NVFUSER_SRCS_DIR}/ops/utils.cpp + ${NVFUSER_SRCS_DIR}/parallel_dimension_map.cpp + ${NVFUSER_SRCS_DIR}/parallel_type_bitmap.cpp + ${NVFUSER_SRCS_DIR}/parser.cpp + ${NVFUSER_SRCS_DIR}/partial_split_map.cpp + ${NVFUSER_SRCS_DIR}/partition.cpp + ${NVFUSER_SRCS_DIR}/predicate_compute.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/fusion_state.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp + ${NVFUSER_SRCS_DIR}/register_interface.cpp + ${NVFUSER_SRCS_DIR}/root_domain_map.cpp + ${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp + ${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp + ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp + ${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp + ${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp + ${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp + ${NVFUSER_SRCS_DIR}/scheduler/registry.cpp + ${NVFUSER_SRCS_DIR}/scheduler/utils.cpp + ${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp + ${NVFUSER_SRCS_DIR}/swizzle.cpp + ${NVFUSER_SRCS_DIR}/sys_utils.cpp + ${NVFUSER_SRCS_DIR}/type_inference.cpp + ${NVFUSER_SRCS_DIR}/type_promotion.cpp + ${NVFUSER_SRCS_DIR}/fusion_segmenter.cpp + ${NVFUSER_SRCS_DIR}/tensor_view.cpp + ${NVFUSER_SRCS_DIR}/transform_iter.cpp + ${NVFUSER_SRCS_DIR}/transform_replay.cpp + ${NVFUSER_SRCS_DIR}/transform_rfactor.cpp + ${NVFUSER_SRCS_DIR}/transform_view.cpp + ${NVFUSER_SRCS_DIR}/type.cpp + ${NVFUSER_SRCS_DIR}/utils.cpp + ${NVFUSER_SRCS_DIR}/mma_type.cpp + ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp ) add_library(${NVFUSER_CODEGEN} SHARED ${NVFUSER_SRCS}) + # Note: for some reason, torch_compile_options gives us segfaults when we run test_jit -#torch_compile_options(${NVFUSER_CODEGEN}) +# torch_compile_options(${NVFUSER_CODEGEN}) target_compile_options(${NVFUSER_CODEGEN} PRIVATE -Wall -Wno-unused-function) if(NOT USE_ROCM) target_compile_options(${NVFUSER_CODEGEN} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") + # NB: This must be target_compile_definitions, not target_compile_options, # as the latter is not respected by nvcc target_compile_definitions(${NVFUSER_CODEGEN} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") @@ -160,12 +162,14 @@ else() target_compile_definitions(${NVFUSER_CODEGEN} PRIVATE USE_ROCM __HIP_PLATFORM_HCC__ - ) + ) endif() target_link_libraries(${NVFUSER_CODEGEN} PRIVATE torch ${TORCHLIB_FLAVOR}) + # For kernel_db, linking STL Filesystem Library for backward compatability with C++14 target_link_libraries(${NVFUSER_CODEGEN} PRIVATE stdc++fs) + if(NOT USE_ROCM) target_link_libraries(${NVFUSER_CODEGEN} PRIVATE ${CUDA_NVRTC_LIB} torch::nvtoolsext) target_include_directories(${NVFUSER_CODEGEN} PRIVATE ${CUDA_INCLUDE_DIRS}) @@ -173,39 +177,44 @@ else() target_link_libraries(${NVFUSER_CODEGEN} PRIVATE ${ROCM_HIPRTC_LIB}) target_include_directories(${NVFUSER_CODEGEN} PRIVATE ${Caffe2_HIP_INCLUDE}) endif() + if(NOT MSVC) target_compile_options(${NVFUSER_CODEGEN} PRIVATE -Werror) endif() + target_include_directories(${NVFUSER_CODEGEN} PUBLIC - "$" - "$" - ) + "$" + "$" +) set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD 17) install(TARGETS ${NVFUSER_CODEGEN} EXPORT NvfuserTargets DESTINATION "${TORCH_INSTALL_LIB_DIR}") + # installing nvfuser headers install(DIRECTORY "${NVFUSER_SRCS_DIR}/" - DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/nvfuser" - FILES_MATCHING PATTERN "*.h" ) + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/nvfuser" + FILES_MATCHING PATTERN "*.h") + # installing nvfuser python tests install(DIRECTORY "${NVFUSER_ROOT}/python_tests/" - DESTINATION "${TORCH_ROOT}/test/_nvfuser" - FILES_MATCHING PATTERN "*.py" ) + DESTINATION "${TORCH_ROOT}/test/_nvfuser" + FILES_MATCHING PATTERN "*.py") # --- build nvfuser_python library - if(BUILD_PYTHON) set(NVFUSER "${PROJECT_NAME}") set(NVFUSER_PYTHON_SRCS) list(APPEND NVFUSER_PYTHON_SRCS - ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp ) add_library(${NVFUSER} MODULE ${NVFUSER_PYTHON_SRCS}) torch_compile_options(${NVFUSER}) + if(NOT USE_ROCM) target_compile_options(${NVFUSER} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") + # NB: This must be target_compile_definitions, not target_compile_options, # as the latter is not respected by nvcc target_compile_definitions(${NVFUSER} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") @@ -216,7 +225,7 @@ if(BUILD_PYTHON) target_compile_definitions(${NVFUSER} PRIVATE USE_ROCM __HIP_PLATFORM_HCC__ - ) + ) target_include_directories(${NVFUSER_CODEGEN} PRIVATE ${Caffe2_HIP_INCLUDE}) endif() @@ -229,6 +238,7 @@ if(BUILD_PYTHON) # avoid using Python3_add_library, copied from functorch set_target_properties(${NVFUSER} PROPERTIES PREFIX "" DEBUG_POSTFIX "") + if(NOT MSVC) target_compile_options(${NVFUSER} PRIVATE -Werror) set_target_properties(${NVFUSER} PROPERTIES SUFFIX ".so") @@ -237,22 +247,23 @@ if(BUILD_PYTHON) endif() set_target_properties(${NVFUSER} PROPERTIES LIBRARY_OUTPUT_DIRECTORY - ${CMAKE_BINARY_DIR}/nvfuser) + ${CMAKE_BINARY_DIR}/nvfuser) set_target_properties(${NVFUSER} PROPERTIES INSTALL_RPATH "${_rpath_portable_origin}/../torch/lib") if(TORCH_PYTHON_LINK_FLAGS AND NOT TORCH_PYTHON_LINK_FLAGS STREQUAL "") message(STATUS "somehow this is happening") set_target_properties(${NVFUSER} PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS}) endif() + install(TARGETS ${NVFUSER} EXPORT NvfuserTargets DESTINATION ${TORCH_ROOT}/nvfuser/) # setup python API version add_custom_command( OUTPUT ${NVFUSER_ROOT}/python/version.py COMMAND - "${PYTHON_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_ROOT}/tools/gen_nvfuser_version.py').touch()\" + "${PYTHON_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_ROOT}/tools/gen_nvfuser_version.py') .touch() \" COMMAND - "${PYTHON_EXECUTABLE}" ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py + "${PYTHON_EXECUTABLE}" ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py DEPENDS ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py WORKING_DIRECTORY ${NVFUSER_ROOT}/tools/ ) @@ -264,8 +275,8 @@ if(BUILD_PYTHON) # install nvfuser python files install(DIRECTORY "${NVFUSER_ROOT}/python/" - DESTINATION "${TORCH_ROOT}/nvfuser" - FILES_MATCHING PATTERN "*.py" ) + DESTINATION "${TORCH_ROOT}/nvfuser" + FILES_MATCHING PATTERN "*.py") file(WRITE "${TORCH_ROOT}/nvfuser/.gitignore" "*") endif() @@ -310,6 +321,7 @@ file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/include/nvfuser_resources") # "stringify" NVFUSER runtime sources # (generate C++ header files embedding the original input as a string literal) set(NVFUSER_STRINGIFY_TOOL "${NVFUSER_ROOT}/tools/stringify_file.py") + foreach(src ${NVFUSER_RUNTIME_FILES}) get_filename_component(filename ${src} NAME_WE) set(dst "${CMAKE_BINARY_DIR}/include/nvfuser_resources/${filename}.h") @@ -369,15 +381,16 @@ if(BUILD_TEST) list(APPEND JIT_TEST_CU_SRCS ${NVFUSER_ROOT}/test/test_gpu_rng.cu) add_executable(${NVFUSER_TESTS} - ${TORCH_ROOT}/test/cpp/common/main.cpp - ${TORCH_ROOT}/test/cpp/jit/test_utils.cpp - ${JIT_TEST_SRCS} - ${JIT_TEST_CU_SRCS}) + ${TORCH_ROOT}/test/cpp/common/main.cpp + ${TORCH_ROOT}/test/cpp/jit/test_utils.cpp + ${JIT_TEST_SRCS} + ${JIT_TEST_CU_SRCS}) torch_compile_options(${NVFUSER_TESTS}) target_compile_definitions(${NVFUSER_TESTS} PRIVATE USE_GTEST) target_include_directories(${NVFUSER_TESTS} PRIVATE "${NVFUSER_ROOT}" "${TORCH_ROOT}/torch/csrc/api/include/") target_link_libraries(${NVFUSER_TESTS} PRIVATE ${NVFUSER_CODEGEN} torch ${TORCHLIB_FLAVOR} gtest_main gmock_main) + if(NOT MSVC) set_property(SOURCE ${JIT_TEST_SRCS} APPEND PROPERTY COMPILE_OPTIONS "-Werror") endif() @@ -421,11 +434,11 @@ if(BUILD_NVFUSER_BENCHMARK) install(TARGETS ${NVFUSER_BENCHMARK} DESTINATION bin) target_link_libraries(${NVFUSER_BENCHMARK} PRIVATE torch_library benchmark ${NVFUSER_CODEGEN}) target_include_directories(${NVFUSER_BENCHMARK} PRIVATE ${NVFUSER_ROOT}) + if(NOT MSVC) target_compile_options_if_supported(nvfuser_bench -Werror) target_compile_options_if_supported(nvfuser_bench -Wno-deprecated-copy) endif() - endif() # -- install nvfuser cmake config files and symlink to build binaries diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 854c8d1d602..9addae93b84 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -1470,8 +1470,10 @@ const DisjointSets& ComputeAtMap::getIdSets( return id_graph_.permissiveNodes(); case IdMappingMode::PERMISSIVE_RESIZE: return id_graph_.permissiveResizeNodes(); + default: + TORCH_INTERNAL_ASSERT( + false, "Error with mapping mode, didn't expect mode: ", mode); } - TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided."); } bool ComputeAtMap::idExistsInMap(IterDomain* id) const { diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index f01aeb3b61a..d9da203ada9 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -1,10 +1,3 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on #pragma once #include @@ -36,13 +29,38 @@ std::string abstractToString(T ref) { // Vector like class that will prevent adding duplicate entries by also // maintaing a set +// +// TODO: Can we support std::back_inserter with this class? template > class VectorOfUniqueEntries { public: VectorOfUniqueEntries() = default; - VectorOfUniqueEntries(const std::initializer_list& x) - : vector_(x), set_(x) {} + VectorOfUniqueEntries(const std::initializer_list& initializer) { + for (auto entry : initializer) { + pushBack(entry); + } + } + + VectorOfUniqueEntries(const VectorOfUniqueEntries& other) { + vector_ = other.vector(); + set_ = other.set(); + } + + VectorOfUniqueEntries& operator=(const VectorOfUniqueEntries& other) { + if (this != &other) { + vector_ = other.vector(); + set_ = other.set(); + } + return *this; + } + + template + VectorOfUniqueEntries(InputIt first, InputIt last) { + while (first != last) { + pushBack(*first++); + } + } // Returns if a node was actually added bool pushBack(T entry) { @@ -53,6 +71,15 @@ class VectorOfUniqueEntries { return false; } + // Returns if a node was actually added + bool pushFront(T entry) { + if (set_.emplace(entry).second) { + vector_.insert(vector_.begin(), entry); + return true; + } + return false; + } + // Returns if any node was added bool pushBack(const VectorOfUniqueEntries& other) { bool any_added = false; @@ -62,11 +89,53 @@ class VectorOfUniqueEntries { return any_added; } + // Returns a new VectorOfUniqueEntries with entries that are in both this and + // other, order is preserved as this. + VectorOfUniqueEntries intersect( + const VectorOfUniqueEntries& other) { + VectorOfUniqueEntries intersection; + for (auto entry : vector()) { + if (other.has(entry)) { + intersection.pushBack(entry); + } + } + return intersection; + } + + // Returns a new VectorOfUniqueEntries with entries that are in this but not + // in other. + VectorOfUniqueEntries subtract( + const VectorOfUniqueEntries& other) const { + VectorOfUniqueEntries subtraction; + for (auto entry : vector()) { + if (!other.has(entry)) { + subtraction.pushBack(entry); + } + } + return subtraction; + } + + // Returns a new VectorOfUniqueEntries with entries that are either in this or + // other. + VectorOfUniqueEntries computeUnion( + const VectorOfUniqueEntries& other) const { + const VectorOfUniqueEntries& this_ref = *this; + VectorOfUniqueEntries union_(this_ref); + for (auto entry : other.vector()) { + union_.pushBack(entry); + } + return union_; + } + // Returns a const vector useful for iterating on const std::vector& vector() const { return vector_; } + const std::unordered_set& set() const { + return set_; + } + // Returns first element in vector T front() const { return vector_.front(); @@ -85,6 +154,14 @@ class VectorOfUniqueEntries { return v; } + // Remove and returns the last element in vector + T popFront() { + T v = vector_.front(); + set_.erase(v); + vector_.erase(vector_.begin()); + return v; + } + // Returns if this container is empty bool empty() const { return vector_.empty(); @@ -141,7 +218,7 @@ class VectorOfUniqueEntries { return vector_.end(); } - std::string toString() { + std::string toString() const { std::stringstream ss; ss << "{ "; for (auto entry : vector()) { @@ -210,64 +287,78 @@ class DisjointSets { } // Initializes a new set for provided entry - // - // TODO: Return iterator - void initializeSet(T entry) { - if (disjoint_set_maps_.find(entry) != disjoint_set_maps_.end()) { - return; + std::pair< + typename std::unordered_map< + T, + std::shared_ptr>, + Hash>::iterator, + bool> + initializeSet(T entry) { + auto disjoint_set_maps_it = disjoint_set_maps_.find(entry); + if (disjoint_set_maps_it != disjoint_set_maps_.end()) { + return std::make_pair(disjoint_set_maps_it, false); } disjoint_sets_.push_back( std::make_shared>()); disjoint_sets_.back()->pushBack(entry); - disjoint_set_maps_.emplace(std::make_pair(entry, disjoint_sets_.back())); + return disjoint_set_maps_.emplace( + std::make_pair(entry, disjoint_sets_.back())); } // Adds all of the disjoint set belonging to entry1 to the disjoint set // belonging to entry0, maps all entries of disjoint set belonging to entry1 // to entry0, removes original disjoint set belonging to entry1. void mapEntries(T entry0, T entry1) { + if (entry0 == entry1) { + return; + } + auto set_it_0 = disjoint_set_maps_.find(entry0); auto set_it_1 = disjoint_set_maps_.find(entry1); - // Track if we need to reset iterators, optimize for case where both entries - // exist - bool invalid_iterators = false; - if (set_it_0 == disjoint_set_maps_.end()) { - initializeSet(entry0); - invalid_iterators = true; - } + auto set_0_found = set_it_0 != disjoint_set_maps_.end(); + auto set_1_found = set_it_1 != disjoint_set_maps_.end(); - if (set_it_1 == disjoint_set_maps_.end()) { - initializeSet(entry1); - invalid_iterators = true; + // Sets already joined + if (set_0_found && set_1_found && set_it_0->second == set_it_1->second) { + return; } - // TODO: We can avoid refinding one iterator if initialize set returns an - // iterator, though if we insert entry1 we'd have to refind entry0 as it - // could invalidate all iterators - if (invalid_iterators) { - set_it_0 = disjoint_set_maps_.find(entry0); + // Make and map new set + disjoint_sets_.push_back( + std::make_shared>()); + auto new_set = disjoint_sets_.back(); + + if (set_0_found) { + auto set_0 = set_it_0->second; + for (auto set_0_entry : *set_0) { + TORCH_INTERNAL_ASSERT(set_0_entry != entry1); + new_set->pushBack(set_0_entry); + disjoint_set_maps_[set_0_entry] = new_set; + } + disjoint_sets_.erase( + std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set_0)); + // Erase invalidates iterators, regrab. set_it_1 = disjoint_set_maps_.find(entry1); + set_1_found = set_it_1 != disjoint_set_maps_.end(); + } else { + new_set->pushBack(entry0); + disjoint_set_maps_[entry0] = new_set; } - auto set0_shared_ptr = set_it_0->second; - auto set1_shared_ptr = set_it_1->second; - - // If the sets are already the same, do nothing - if (set0_shared_ptr == set1_shared_ptr) { - return; - } - - // Place everything in set1 into set0 and remap all entries in set1 to set0 - for (auto entry : set1_shared_ptr->vector()) { - set0_shared_ptr->pushBack(entry); - disjoint_set_maps_[entry] = set0_shared_ptr; + if (set_1_found) { + auto set_1 = set_it_1->second; + for (auto set_1_entry : *set_1) { + new_set->pushBack(set_1_entry); + disjoint_set_maps_[set_1_entry] = new_set; + } + disjoint_sets_.erase( + std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set_1)); + } else { + new_set->pushBack(entry1); + disjoint_set_maps_[entry1] = new_set; } - - // set1 no longer needed as its entries are copied into set0 - disjoint_sets_.erase(std::find( - disjoint_sets_.begin(), disjoint_sets_.end(), set1_shared_ptr)); } // Will assert if provided entry0 is not in any disjoint set, otherwise @@ -323,11 +414,7 @@ class DisjointSets { const std::string sep(" "); for (auto s_ptr : disjoint_sets_) { auto& set = *s_ptr; - ss << sep << "{\n"; - for (auto entry : set.vector()) { - ss << sep << sep << abstractToString(entry) << "\n"; - } - ss << sep << "}\n"; + ss << sep << abstractToString(set) << "\n"; } ss << "}"; return ss.str(); diff --git a/csrc/ir_utils.cpp b/csrc/ir_utils.cpp index 83ed969d47a..be7fc54bc76 100644 --- a/csrc/ir_utils.cpp +++ b/csrc/ir_utils.cpp @@ -223,15 +223,9 @@ TensorView* rfactorHelper( namespace { template -std::vector uniqueEntries(const std::vector& tv_deuqe) { - std::vector unique_entries; - std::unordered_set inserted; - for (auto tv_entry : tv_deuqe) { - if (inserted.emplace(tv_entry).second) { - unique_entries.emplace_back(tv_entry); - } - } - return unique_entries; +std::vector uniqueEntries(const std::vector& tv_vector) { + VectorOfUniqueEntries unique_vector(tv_vector.begin(), tv_vector.end()); + return unique_vector.vector(); } } // namespace @@ -376,6 +370,24 @@ std::vector allTvs(Fusion* fusion) { return uniqueEntries(all_tvs); } +std::vector allTvsOfExprs(const std::vector& exprs) { + std::vector all_tvs; + std::unordered_set added; + for (auto expr : exprs) { + auto input_tvs = ir_utils::filterByType(expr->inputs()); + auto output_tvs = ir_utils::filterByType(expr->outputs()); + for (bool input : {true, false}) { + auto& tvs = input ? input_tvs : output_tvs; + for (auto tv : tvs) { + if (added.emplace(tv).second) { + all_tvs.push_back(tv); + } + } + } + } + return all_tvs; +} + std::vector allTvsExcept( Fusion* fusion, const std::unordered_set& except) { diff --git a/csrc/ir_utils.h b/csrc/ir_utils.h index 70e3052de92..9eb5131e2e4 100644 --- a/csrc/ir_utils.h +++ b/csrc/ir_utils.h @@ -303,8 +303,13 @@ TORCH_CUDA_CU_API std::vector outputTvsOf( std::vector tvs); // returns all tensor views in fusion that are used between outputs and inputs. +// List is topologically sorted. TORCH_CUDA_CU_API std::vector allTvs(Fusion* fusion); +// returns all tensor views used in the provided expressions +TORCH_CUDA_CU_API std::vector allTvsOfExprs( + const std::vector& exprs); + // returns all tensor views in fusion that are used between outputs and inputs // except the specified set. TORCH_CUDA_CU_API std::vector allTvsExcept( diff --git a/csrc/lower2device.cpp b/csrc/lower2device.cpp index 140daf8c3a8..68c322b40e9 100644 --- a/csrc/lower2device.cpp +++ b/csrc/lower2device.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -252,6 +253,7 @@ void assignRNGOffset(Fusion* fusion) { void dumpExprsIfEnabled( const std::vector& exprs, std::string pass_name, + bool force_expr_disable = true, bool force_enable = false) { auto enabled_by_env = [&pass_name]() { if (!isDebugDumpEnabled(DebugDumpOption::LowerVerbose)) { @@ -262,8 +264,12 @@ void dumpExprsIfEnabled( args.empty() || std::find(args.begin(), args.end(), pass_name) != args.end()); }; - if (force_enable || enabled_by_env()) { + bool name_only = isDebugDumpEnabled(DebugDumpOption::LowerNameOnly); + if (name_only || force_enable || enabled_by_env()) { std::cout << "After " << pass_name << ":" << std::endl; + if (name_only || force_expr_disable) { + return; + } for (auto exp : exprs) { std::cout << exp->toString() << std::endl; } @@ -308,17 +314,19 @@ void GpuLower::lower(Fusion* fusion) { // prepare for lowering validateIr(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateIr"); + dumpExprsIfEnabled(fusion_->exprs(), "validateIr", true); // Checks if any TIDx dim is marked as padded to a warp. Also checks if we can // determine the padding is explicitly a single warp. collectPaddedParallelDims(); - dumpExprsIfEnabled(fusion_->exprs(), "collectPaddedParallelDims"); + dumpExprsIfEnabled(fusion_->exprs(), "collectPaddedParallelDims", true); // Replaces integers that are tensor sizes by named scalars as "T0.size[0]" replaceSymbolicSizes(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "replaceSymbolicSizes"); + IterDomainGraphs test(fusion_); + // Build what's refered to as the compute at map. This map contains the // mappings of all iteration domains across the fusion. There are three types // of mappings Permissive, Exact, and Loop, see compute_at_map.h/cpp for more @@ -326,7 +334,7 @@ void GpuLower::lower(Fusion* fusion) { compute_at_map_ = std::make_shared(fusion_); resolveComputeWith(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "resolveComputeWith"); + dumpExprsIfEnabled(fusion_->exprs(), "resolveComputeWith", true); if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) { std::cout << compute_at_map_->toString() << std::endl; @@ -336,34 +344,35 @@ void GpuLower::lower(Fusion* fusion) { // Uses compute_at_map, find all splits that are enforced to be divisible divisible_splits_ = getAllDivisibleSplits(fusion_, compute_at_map_.get()); - dumpExprsIfEnabled(fusion_->exprs(), "getAllDivisibleSplits"); + dumpExprsIfEnabled(fusion_->exprs(), "getAllDivisibleSplits", true); // Used in parallel dimension map concretized_broadcast_domains_ = std::make_shared(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build ConcretizedBroadcastDomains"); + dumpExprsIfEnabled( + fusion_->exprs(), "build ConcretizedBroadcastDomains", true); parallelDimensionMap().build(fusion_); if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) { std::cout << "Parallel dimension map:" << std::endl; std::cout << parallel_dimension_map_.toString() << std::endl; } - dumpExprsIfEnabled(fusion_->exprs(), "build parallelDimensionMap"); + dumpExprsIfEnabled(fusion_->exprs(), "build parallelDimensionMap", true); // Validate mma data format and compatibility if any on the fusion. validateMma(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateMma"); + dumpExprsIfEnabled(fusion_->exprs(), "validateMma", true); // Validate swizzle usage on the fusion schedule. validateSwizzle(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateSwizzle"); + dumpExprsIfEnabled(fusion_->exprs(), "validateSwizzle", true); validateResize(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "validateResize"); // Compute thread predicates. Depends on parallel_dimension_map_ thread_pred_map_.build(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build thread_pred_map_"); + dumpExprsIfEnabled(fusion_->exprs(), "build thread_pred_map_", true); // Fuse cetain patterns of reductions, such as a grid reduction // followed by a grid broadcast. Only depends on parallelization and @@ -374,26 +383,27 @@ void GpuLower::lower(Fusion* fusion) { // Scan the whole fusion and build mappings about halo extensions of // all IterDomains halo_info_ = std::make_shared(fusion_, compute_at_map_); - dumpExprsIfEnabled(fusion_->exprs(), "build HaloInfo"); + dumpExprsIfEnabled(fusion_->exprs(), "build HaloInfo", true); // Want to run this after parallel map and halo info map are // created. vectorized_accesses_ and vectorized_set_info_ are filled. validateAndCollectVectorizeInfo(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateAndCollectVectorizeInfo"); + dumpExprsIfEnabled(fusion_->exprs(), "validateAndCollectVectorizeInfo", true); // Depends on ComputeAtMap and HaloInfo. validateAndConvertIterDomainGrouping(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateAndConvertIterDomainGrouping"); + dumpExprsIfEnabled( + fusion_->exprs(), "validateAndConvertIterDomainGrouping", true); // Assumes all grouped reductions are convered to // GroupedReductionOp, which is done by // validateAndConvertIterDomainGrouping validateGroupedReductions(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateGroupedReductions"); + dumpExprsIfEnabled(fusion_->exprs(), "validateGroupedReductions", true); // all of the lookup TVs are fusion inputs validateLookupTV(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateLookupTV"); + dumpExprsIfEnabled(fusion_->exprs(), "validateLookupTV", true); // Depends on thread_pred_map_, validates parallelization collects which // tensor views need WAR or RAW syncs @@ -401,27 +411,27 @@ void GpuLower::lower(Fusion* fusion) { if (isDebugDumpEnabled(DebugDumpOption::SyncMap)) { std::cout << sync_map_->toString() << std::endl; } - dumpExprsIfEnabled(fusion_->exprs(), "SyncMap"); + dumpExprsIfEnabled(fusion_->exprs(), "SyncMap", true); partialSplitMap().build(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build partialSplitMap"); + dumpExprsIfEnabled(fusion_->exprs(), "build partialSplitMap", true); validatePartialSplit(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validatePartialSplit"); + dumpExprsIfEnabled(fusion_->exprs(), "validatePartialSplit", true); nonDivisibleSplitInfo().build(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build nonDivisibleSplitInfo"); + dumpExprsIfEnabled(fusion_->exprs(), "build nonDivisibleSplitInfo", true); // Detects all exprssions that don't need predicates. Depends on // nonDivisibleSplitInfo. pred_elimination_ = std::make_unique(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); + dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination", true); doubleBufferInfo().build(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build doubleBufferInfo"); + dumpExprsIfEnabled(fusion_->exprs(), "build doubleBufferInfo", true); compute_at_map_->allocateIndexVariables(); - dumpExprsIfEnabled(fusion_->exprs(), "allocateIndexVariables"); + dumpExprsIfEnabled(fusion_->exprs(), "allocateIndexVariables", true); // Run our passes keeping the lowered expressions and forwarding // them diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index ce708f76edc..3b03ae31895 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -14,6 +14,64 @@ namespace nvfuser { +Expr* ReplayTransform::replayAs( + const std::vector& ordered_inputs, + const Expr* expression_to_match) { + ReplayTransform replay(ordered_inputs, expression_to_match); + return replay.replayed_expr_; +} + +ReplayTransform::ReplayTransform( + const std::vector& ordered_inputs, + const Expr* expression_to_match) + : input_ids_(ordered_inputs) { + OptOutConstDispatch::handle(expression_to_match); +} + +// We're going to replay this split operation on the corresponding ID +void ReplayTransform::handle(const Split* split) { + TORCH_INTERNAL_ASSERT( + input_ids_.size() == 1, + "Expected one input to match split: ", + split->toString()); + replayed_expr_ = IterDomain::split( + input_ids_[0], + split->factor(), + split->innerSplit(), + split->startOffset(), + split->stopOffset()) + .first->definition(); +} + +// We're going to replay this merge operation on the corresponding IDs +void ReplayTransform::handle(const Merge* merge) { + TORCH_INTERNAL_ASSERT( + input_ids_.size() == 2, + "Expected two inputs to match merge: ", + merge->toString()); + replayed_expr_ = + IterDomain::merge(input_ids_[0], input_ids_[1])->definition(); +} + +// We're going to replay this swizzle operation on the corresponding IDs +// if replaying swizzle is enabled. +void ReplayTransform::handle(const Swizzle2D* swizzle_2d) { + TORCH_INTERNAL_ASSERT( + input_ids_.size() == 2, + "Expected two inputs to match swizzle: ", + swizzle_2d->toString()); + replayed_expr_ = IterDomain::swizzle( + swizzle_2d->swizzleType(), + input_ids_[0], + input_ids_[1], + swizzle_2d->swizzleMode()) + .first->definition(); +} + + +void ReplayTransform::handle(const Resize* resize) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} // Transform dispatch void ReplayTransformations::handle(Expr* e) { auto is_supported_expr = e->isOneOf(); @@ -666,142 +724,112 @@ int BestEffortReplay::findFirstMismatchedID( return std::min(td1->nDims(), td2->nDims()); } -namespace { +ForwardingInfo::ForwardingInfo( + const TensorView* producer, + const TensorView* consumer) { + // Active indicates the TV that has axes the other TV does not. For + // broadcast this is the consumer squeeze the producer. + // + // Either producer or consumer maps depending on operation + std::unordered_map* active_forwarding_map = nullptr; + std::unordered_map>* + active_compliment_map = nullptr; + + // Either squeeze or broadcast dimension flags depending on operation + const std::vector* active_dim_flags = nullptr; + + // Either producer or consumer depending on operation + std::vector active_root_dom; + const TensorView* active_tv = nullptr; + + if (auto bop = dynamic_cast(consumer->definition())) { + active_forwarding_map = &consumer_forwarding_map; + active_compliment_map = &consumer_compliment_map; + active_dim_flags = &bop->getBroadcastDimFlags(); + active_root_dom = consumer->getRootDomain(); + active_tv = consumer; + } else if (auto sop = dynamic_cast(consumer->definition())) { + active_forwarding_map = &producer_forwarding_map; + active_compliment_map = &producer_compliment_map; + active_dim_flags = &sop->getSqueezeDimFlags(); + active_root_dom = + TensorDomain::noReductions(producer->getMaybeRFactorDomain()); + active_tv = producer; + } else { + return; + } -// Maps that track information relevant to best effort replay about newly added -// or squeezed broadcast axes -// -// For example if we have consumer: T0[i0, b1, b2, i3] and producer: -// T1[i0, i3] -// -// If consumer transformations are: -// -> T[i0, b1o, b1i, b2o, b2i, i3] -// -> T[i0*b1i, b1o, b2o, b2i, i3] -// -> T[i0*b1i*b2o, b1o, b2i, i3] -// -> T[i0*b1i*b2o*i3, b1o, b2i] -// -// forwarding_map would forward i0->i0*b1i and i0*b1i->i0*b1i*b2o -// compliment_map would have the entry i0->b1i and i0*b1i->b2o -// -// The first is to fast forward transformations in consumer involving broadcast -// axes not in producer. The compliment map is to use later to compute what leaf -// nodes we may have after the forwarding process is finished. Leaf nodes are -// only important for replayCasP, so look there to see how this is done. Forward -// map is used for replayCasP and replayPasC. -struct ForwardingInfo { - public: - // Map IterDomain* axes that can safely be forwarded to their output. - std::unordered_map producer_forwarding_map; - std::unordered_map consumer_forwarding_map; - - // Given a forward id map id_input -> id_forwarded - // Track the other inputs in the expr that id_input is an input to. These will - // be used to adjust the replay's leaf tracking. Don't need to track one to - // many as currently transformations on IterDomains can only have maximum 2 - // inputs, but maybe in the future we'll have more. - std::unordered_map> - producer_compliment_map; - std::unordered_map> - consumer_compliment_map; - - ForwardingInfo(const TensorView* producer, const TensorView* consumer) { - // Either producer or consumer maps depending on operation - std::unordered_map* active_forwarding_map = - nullptr; - std::unordered_map>* - active_compliment_map = nullptr; - - // Either squeeze or broadcast dimension flags depending on operation - const std::vector* active_dim_flags = nullptr; - - // Either producer or consumer depending on operation - std::vector active_root_dom; - const TensorView* active_tv = nullptr; - - if (auto bop = dynamic_cast(consumer->definition())) { - active_forwarding_map = &consumer_forwarding_map; - active_compliment_map = &consumer_compliment_map; - active_dim_flags = &bop->getBroadcastDimFlags(); - active_root_dom = consumer->getRootDomain(); - active_tv = consumer; - } else if (auto sop = dynamic_cast(consumer->definition())) { - active_forwarding_map = &producer_forwarding_map; - active_compliment_map = &producer_compliment_map; - active_dim_flags = &sop->getSqueezeDimFlags(); - active_root_dom = - TensorDomain::noReductions(producer->getMaybeRFactorDomain()); - active_tv = producer; - } else { - return; - } + TORCH_INTERNAL_ASSERT(active_root_dom.size() == active_dim_flags->size()); - // Collect which root ids are only in active_tv but not in the inactive - // tensor. - std::unordered_set forwarded_ids; - TORCH_INTERNAL_ASSERT(active_root_dom.size() == active_dim_flags->size()); - for (auto i : c10::irange(active_dim_flags->size())) { - if (active_dim_flags->at(i)) { - forwarded_ids.emplace(active_root_dom.at(i)); - } + // Collect which root ids are only in active_tv but not in the inactive + // tensor. + // + // Initialize which id's should beforwarded. + std::unordered_set forwarded_ids; + for (auto i : c10::irange(active_dim_flags->size())) { + if (active_dim_flags->at(i)) { + forwarded_ids.emplace(active_root_dom.at(i)); } + } - // We have root axes in active_tv that don't exist in the inactive tensor, - // now forward those to include all id's in active_tv comprised of only axes - // not in the inactive tensor. - std::vector active_tv_history = StmtSort::getExprs( - FusionGuard::getCurFusion(), - std::vector( - active_tv->domain()->domain().begin(), - active_tv->domain()->domain().end())); - - auto isIdOnlyInActiveTv = [&forwarded_ids](IterDomain* input_id) { - return forwarded_ids.count(input_id) > 0; - }; - - for (auto expr : active_tv_history) { - auto input_ids = ir_utils::filterByType(expr->inputs()); - // If expr inputs are all in forwarded_ids, then so are all outputs - if (std::all_of(input_ids.begin(), input_ids.end(), isIdOnlyInActiveTv)) { - for (auto output_ids : - ir_utils::filterByType(expr->outputs())) { - forwarded_ids.emplace(output_ids); - } - } else if ( - expr->isA() && - std::any_of(input_ids.begin(), input_ids.end(), isIdOnlyInActiveTv)) { - auto merge_expr = expr->as(); - // If - // - one of the inputs is made of id's in active_tv that don't map to - // the inactive tensor, - // - && the other input maps to an id in both the active and inactive - // tensor - // - && this is a merge - // - // For the sake of BestEffortReplay we can forward the input mapping - // to both the active and inactive tensor to the output of the - // expression - std::vector forwarded_ids; - std::vector compliment_ids; - - for (auto input_id : input_ids) { - if (!isIdOnlyInActiveTv(input_id)) { - forwarded_ids.emplace_back(input_id); - active_forwarding_map->emplace( - std::make_pair(input_id, merge_expr->out())); - } else { - compliment_ids.push_back(input_id); - } - } + // We have root axes in active_tv that don't exist in the inactive tensor, + // now forward those to include all id's in active_tv comprised of only axes + // not in the inactive tensor. + std::vector active_tv_history = StmtSort::getExprs( + FusionGuard::getCurFusion(), + std::vector( + active_tv->domain()->domain().begin(), + active_tv->domain()->domain().end())); - // Set up compliment map - for (auto forwarded_id : forwarded_ids) { - active_compliment_map->emplace( - std::make_pair(forwarded_id, compliment_ids)); + auto isInForwardIdSet = [&forwarded_ids](IterDomain* input_id) { + return forwarded_ids.count(input_id) > 0; + }; + + for (auto expr : active_tv_history) { + auto input_ids = ir_utils::filterByType(expr->inputs()); + // If expr inputs are all in forwarded_ids, then so are all outputs + if (std::all_of(input_ids.begin(), input_ids.end(), isInForwardIdSet)) { + for (auto output_ids : + ir_utils::filterByType(expr->outputs())) { + forwarded_ids.emplace(output_ids); + } + } else if ( + expr->isA() && + std::any_of(input_ids.begin(), input_ids.end(), isInForwardIdSet)) { + auto merge_expr = expr->as(); + // If + // - one of the inputs is made of id's in active_tv that don't map to + // the inactive tensor, + // - && the other input maps to an id in both the active and inactive + // tensor + // - && this is a merge + // + // For the sake of BestEffortReplay we can forward the input mapping + // to both the active and inactive tensor to the output of the + // expression + std::vector forwarded_ids; + std::vector compliment_ids; + + for (auto input_id : input_ids) { + if (!isInForwardIdSet(input_id)) { + forwarded_ids.emplace_back(input_id); + active_forwarding_map->emplace( + std::make_pair(input_id, merge_expr->out())); + } else { + compliment_ids.push_back(input_id); } } + + // Set up compliment map + for (auto forwarded_id : forwarded_ids) { + active_compliment_map->emplace( + std::make_pair(forwarded_id, compliment_ids)); + } } } -}; +} + +namespace { // Trace chain of swizzles until reaching // an IterDomain that's either a leaf or diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 076ece21f95..ade428f9542 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -32,6 +32,43 @@ struct id_int_lt { } // namespace +class ReplayTransform : OptOutConstDispatch { + public: + // Replays expression_to_match with the provided ordered_inputs. Inputs should + // be ordered as they would be used in provided expression. Returns new + // replayed expression. + static Expr* replayAs( + const std::vector& ordered_inputs, + const Expr* expression_to_match); + + private: + ReplayTransform() = delete; + + ReplayTransform( + const std::vector& ordered_inputs, + const Expr* expression_to_match); + + using OptOutConstDispatch::handle; + + // We're going to replay this split operation on the corresponding ID + void handle(const Split* split) override; + + // We're going to replay this merge operation on the corresponding IDs + void handle(const Merge* merge) override; + + // We're going to replay this swizzle operation on the corresponding IDs + // if replaying swizzle is enabled. + void handle(const Swizzle2D* swizzle_2d) override; + + + // We're going to replay this resize operation on the corresponding IDs + // if replaying resize is enabled. + void handle(const Resize* resize) override; + + Expr* replayed_expr_ = nullptr; + const std::vector& input_ids_; +}; + // Uses the history of _target_domain, and replays that history using the // provided map. // @@ -150,7 +187,72 @@ class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor { bool ran_replay_ = false; // Mark if replay has been run }; +// Maps that track information relevant to best effort replay about newly added +// or squeezed broadcast axes +// +// For example if we have consumer: T0[i0, b1, b2, i3] and producer: +// T1[i0, i3] +// +// If consumer transformations are: +// -> T[i0, b1o, b1i, b2o, b2i, i3] +// -> T[i0*b1i, b1o, b2o, b2i, i3] +// -> T[i0*b1i*b2o, b1o, b2i, i3] +// -> T[i0*b1i*b2o*i3, b1o, b2i] +// +// forwarding_map would forward i0->i0*b1i and i0*b1i->i0*b1i*b2o +// compliment_map would have the entry i0->b1i and i0*b1i->b2o +// +// The first is to fast forward transformations in consumer involving broadcast +// axes not in producer. The compliment map is to use later to compute what leaf +// nodes we may have after the forwarding process is finished. Leaf nodes are +// only important for replayCasP, so look there to see how this is done. Forward +// map is used for replayCasP and replayPasC. +class ForwardingInfo { + public: + // Map IterDomain* axes that can safely be forwarded to their output. + std::unordered_map producer_forwarding_map; + std::unordered_map consumer_forwarding_map; + + // Given a forward id map id_input -> id_forwarded + // Track the other inputs in the expr that id_input is an input to. These will + // be used to adjust the replay's leaf tracking. Don't need to track one to + // many as currently transformations on IterDomains can only have maximum 2 + // inputs, but maybe in the future we'll have more. + std::unordered_map> + producer_compliment_map; + std::unordered_map> + consumer_compliment_map; + + ForwardingInfo(const TensorView* producer, const TensorView* consumer); + + ForwardingInfo() = delete; +}; + /* + * Short Description: + * + * Given an Expr in target_domain, check if its inputs are in replay_map. If so, + * check if the mapped domain in replay_map are recorded to be transformed by an + * "equivelent" operation in replay_domain's history. If so, "forward" the + * operation and update replay_map to map the outputs of the expressions across + * target_domain and reference_domain. + * + * replay_map maps root IDs in the history of target_domain to root IDs in the + * history replay_domain. PasC and CasP is just a convenient mechanism to have + * BestEffortReplay make this base root mapping. + * + * Note: See ForwardingInfo in transform_iter.cpp for more information on + * forwarding. + * + * Side note potentially for the future: In theory we could actually disconnect + * T4's view from it's rfactor domain. This would allow rfactor domains to be + * "reversible". The way this would have to be implemented is that there just + * needs to be a path of transformations from a tensors leaf domains, to its + * root domains, and its rfactor domain. It shouldn't really matter if those + * connections are forward or backward through transformations. The only thing + * that really matters is they're connected. This is left for future work as it + * could have significant impact on other parts of the system like how loops are + * generated and expressions are sorted. * Motivation: * * Consider the following program: @@ -165,44 +267,73 @@ class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor { * T1[I0, R1i] = T4[I0, R1orf, I1irf] * T2[I0] = T1[I0, R1i] * - * There's an issue when we call replayCasP on - * T4[I0, R1o, I1i] = T0[I0, I1] + * There's an issue when we want to replay T4 to have transformations similar to + * those on T0. Primarily T0's "rfactor" domain has a strict match requirement + * on T4's root domain. If transformations on top of T0 don't match T4's + * transformations (from T4's root domain to T4's rfactor domain), T4 cannot be + * replayed like T0 on those domains as they would generate incorrect code in + * the system today. * - * This would try to replay T4 as T0, and it could include the rfactor domains. - * For example we compute T0 inline with T4. The way computeAt is setup this - * would call replayPasC(T0, T4, -1) then repalyCasP(T4, T0, -1) + * T0 doesn't have this constraint if we want to replay T0 as T4, so this is + * directional based on rfactor. Therefore to replay T0 transformations onto T4 + * we want to make sure those transformations are consistent with T4 (between + * T4's root and rfactor domain). Best Effort Replay does not actually add any + * transformations to the tensors provided. However, it will provide information + * to determine producers's transformations are consistent consumers + * transformations (or the other way around). Best Effort Replay will return + * discovered mappings between tensors that it detects to be matching based on + * provided initial information (or just through p2c/c2p root domain mappings). * - * We might assume that the only way we will hit this is if we call - * T4->computeAt(T0...) so it might be safe to assume that the right - * transformations would be replayed. However, we want to preserve the rfactor - * domain, so since it would replay T4 at root, it would produce iterdomains - * that wouldn't corresopnd to those in rfactor. Also, I don't know if this - * assumption is correct. + * Transformations have a concept of "permissiveness" used for broadcast and + * squeeze. For example: * - * Therefore, we will assume it is not correct, and we will validate here that - * if we replay a domain that it would transform it in a way consistent with - * any defined RFactor domains, then we will update the replay map so that - * RFactor roots are mapped to intermediate IterDomains in the target and start - * replay from there. + * T1[I0, B1] = T0[I0] + * T2[I0, I1] = T1[I0, B1] * + * We may want to replay T1 and T0 based on transformations on T2. These + * transformations may involve B1. We could even have: * - * SHORT DESCRIPTION: + * T2->merge(0, 1)->split(0, 128) * - * This class will validate/do the above. It will also run through - * transformations in target according to replay_map. If equal transformations - * already exist in replay_domain history, we will not redo those - * transformations, but instead update replay_map to reflect forwarding the - * existing transformations. This later part is the "best effort" replay. Though - * we include rfactor replay and validation here. + * resulting in: * - * Given an Expr in target_domain, check if its inputs are in replay_map. If so, - * check if the mapped domain in replay_map are recorded to be transformed by an - * equivelent operation in replay_domain's history. If so, "forward" the - * operation and update replay_map to the outputs of target_domain's output(s), - * to the output of the equivlent expr's outputs in relpay_domain's history. + * T2[(I0*I1)/128, 128] * - * replay_map maps root IDs in the history of target_domain to root IDs in the - * history replay_domain + * T0 doesn't have I1 so it can't technicaly be transformed in an exactly + * consistent way. However, it may still be desired to "inline" T0 into T1 and + * in result T1 into T2. It may further be desired to bind BIDx and TIDx to the + * two dimensions in the problem. This example doesn't "technically" result in + * thread to thread communication, but since our scope in mind is a shared + * global memory it results in duplicate reads. These duplicate reads are + * automatically cached in our memory hierarchy. So in a way there is implicit + * communication in that a memory location is read by multiple threads. + * + * This is where forwarding and permissiveness come into play. When we transform + * T1 with the first merge, we will mark the result I0*B1 of T1 to be + * "permissively" mapped to I0 of T0, so when we perform the split, we split + * T0's I0 dimension to I0/128 and 128. This is to help us mark inlining and + * paralellization across these dimensions so we can effectively reason about + * the "not full" dimension in T0. This is where the concept of forward map in + * BestEffortReplay comes in. + * + * Permissiveness can also be considered "symmetric" across broadcast and + * squeeze as they are similar operations, however broadcast and squeeze do have + * different implications since squeeze doesn't result in the implicit + * communication described in the previous paragraph. However, as far as + * forwarding is concerned they're symmetric. Indexing/parallelization has + * significant logic dedicated to broadcast resolutions (unlike squeeze). + * + * This class provides a mechanism to annalyze all of the above concepts. It + * can also run through transformations in target according to a manually + * specified IterDomain to IterDomain replay_map. If equal transformations + * already exist in replay_domain history, we will not redo those + * transformations, but instead update replay_map to reflect forwarding the + * existing transformations based on a notion of expresions being "equal" (input + * IterDomains mapped and transformation expression parameters matching, or the + * iter domain that doesn't match is in a forwarding map). The replay map is the + * "best effort" part of BestEffortReplay, it doesn't actually perform new + * transformations to enforce matching, it just detects existing matching + * transforms. However, we still include rfactor validation within. */ class TORCH_CUDA_CU_API BestEffortReplay { @@ -213,17 +344,20 @@ class TORCH_CUDA_CU_API BestEffortReplay { std::unordered_map leaf_ids_; std::vector forwarded_ids_; - // Need to track which id's have been forwarded. Later need to make sure leaf - // nodes to produce compliment axes are properly tracked. i.e. + // Need to track which id's have been forwarded. Later will need to make sure + // leaf nodes to produce "compliment" axes are properly tracked. i.e. // T[i0, b1, b2, i3] // -> T[i0, b1o, b1i, b2o, b2i, i3] // -> T[i0*b1i*b2o, b1o, b2i, i3] // -> T[i0*b1i*b2o*i3, b1o, b2i] // If we forwarded i0 -> i0*b1i*b2o*i3, we need to know that b1o and b2i - // are leaf nodes even though their split wasn't part of targets replay. + // are leaf nodes even though their split wasn't part of targets replay. These + // are important IterDomains to track for transformation replays as otherwise + // we could easily drop axes we need by accident // Counter to make sure best effort replay leaf_ids can be grabbed - // deterministicly + // deterministicly, important to make sure replays are run to run + // deterministic. size_t counter = 0; // Determine if current replay will ignore swizzle ops. @@ -261,6 +395,10 @@ class TORCH_CUDA_CU_API BestEffortReplay { // I02->I12 // } // + // TODO: Reevaluate swizzle and transform replays. We have some concepts on + // iter domain mapping we should formalize. It would be good to have these + // options accessible while specified in a consistent manner. + // https://github.com/ftxj/pytorch/pull/1#pullrequestreview-1210168522 bool skip_replay_swizzle_ = true; bool skip_target_swizzle_ = true; diff --git a/csrc/type.cpp b/csrc/type.cpp index 64e370ac0d7..d28058e7332 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -675,10 +675,14 @@ static const char* id_map_mode_type2string(IdMappingMode t) { return "exact"; case IdMappingMode::ALMOSTEXACT: return "almost_exact"; - case IdMappingMode::PERMISSIVE: - return "permissive"; + case IdMappingMode::INDEX: + return "index"; case IdMappingMode::LOOP: return "loop"; + case IdMappingMode::PERMISSIVE: + return "permissive"; + case IdMappingMode::PERMISSIVE_RESIZE: + return "permissive_resize"; default: // Don't try to print t as it would recursively call this function TORCH_INTERNAL_ASSERT(false, "Unexpected IdMappingMode Type."); diff --git a/csrc/type.h b/csrc/type.h index 8b71e32980a..4aa04f8d72f 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -374,13 +374,15 @@ enum class IdMappingMode { EXACT, ALMOSTEXACT, LOOP, + INDEX, PERMISSIVE, PERMISSIVE_RESIZE }; -static constexpr std::array kIdMappingModes = { +static constexpr std::array kIdMappingModes = { IdMappingMode::EXACT, IdMappingMode::ALMOSTEXACT, + IdMappingMode::INDEX, IdMappingMode::LOOP, IdMappingMode::PERMISSIVE, IdMappingMode::PERMISSIVE_RESIZE}; diff --git a/csrc/utils.cpp b/csrc/utils.cpp index aeb67dde0d8..3df6cfc2fd4 100644 --- a/csrc/utils.cpp +++ b/csrc/utils.cpp @@ -135,6 +135,7 @@ auto parseDebugDumpOptions() { {"bank_conflict", DebugDumpOption::BankConflictInfo}, {"sync_map", DebugDumpOption::SyncMap}, {"lower_verbose", DebugDumpOption::LowerVerbose}, + {"lower_name_only", DebugDumpOption::LowerNameOnly}, {"expr_simplify", DebugDumpOption::ExprSimplification}, {"expr_sort", DebugDumpOption::ExprSort}, {"loop_rotation", DebugDumpOption::LoopRotation}}; diff --git a/csrc/utils.h b/csrc/utils.h index 7b69aa2498c..aa354e38cb8 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -76,6 +76,7 @@ enum class DebugDumpOption { BankConflictInfo, //! Dump bank confliction info SyncMap, //! RAW dependency info LowerVerbose, //! Print all passes' transform in GpuLower::lower + LowerNameOnly, //! Print pass names as they're finished ExprSimplification, //! Print all passes' transform in simplifyExpr ExprSort, //! Print merging decisions on expression sorting LoopRotation, //! Print loop rotation log diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 3213f1d3348..0b39caae178 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -5287,41 +5287,6 @@ TEST_F(NVFuserTest, FusionScheduleTransposeRepro1_CUDA) { &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__); } -// Repro for issue #1873 -TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - auto tv1 = makeContigTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - auto tv2 = set(tv0); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - tv4->merge(0); - tv4->split(0, 32); - - tv0->computeAt(tv4, 1); - - tv2->split(-1, 8); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({123}, options); - at::Tensor t1 = at::randn({3, 123}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - - auto outputs = fe.runFusion({t0, t1}); - - auto tv_ref = t0 + t1; - - testValidate(&fusion, outputs, {t0, t1}, {tv_ref}, __LINE__, __FILE__); -} - TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) { // https://github.com/csarofeen/pytorch/issues/1926 std::unique_ptr fusion_ptr = std::make_unique(); diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index ccc9c22cd91..b3e7ec5c786 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -77,6 +78,7 @@ TEST_F(NVFuserTest, FusionIndexing1_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } +// Same as 1 but merge starting from inner most dimension TEST_F(NVFuserTest, FusionIndexing2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -131,6 +133,7 @@ TEST_F(NVFuserTest, FusionIndexing2_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } +// Same compute as 1 and 2 but use a scheduler. TEST_F(NVFuserTest, FusionIndexing3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -165,6 +168,7 @@ TEST_F(NVFuserTest, FusionIndexing3_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } +// Same as 3 but use 3 dimensions and concrete sizes TEST_F(NVFuserTest, FusionIndexing4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -372,8 +376,8 @@ TEST_F(NVFuserTest, FusionIndexing8_CUDA) { &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); } +// Same as 5 but using implicit broadcast TEST_F(NVFuserTest, FusionIndexing9_CUDA) { - // Same as 7 but with outer splits instead of inner Fusion fusion; FusionGuard fg(&fusion); @@ -791,44 +795,266 @@ TEST_F(NVFuserTest, FusionIndexing17_CUDA) { &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -// Repro of issue #2560 +// TODO: Finish and enable test TEST_F(NVFuserTest, FusionIndexing18_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - auto tv0 = makeSymbolicTensor(1); + TensorView* tv0 = makeConcreteTensor({5, 7, 11, 13}); fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - auto tv2 = broadcast(tv0, {false, true}); - auto tv3 = add(tv2, tv1); - auto tv4 = sum(tv3, {0, 1}); + auto tv1 = set(tv0); + + auto tv2 = makeConcreteTensor({5, 11}); + fusion.addInput(tv2); + + auto tv3 = broadcast(tv2, {false, true, false, true}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + // // tv4[5, 7, 11, 13] = tv3[5, b1, 11, b3] + tv1[5, 7, 11, 13] + tv4->merge(0, 3); + // tv4[5*13, 7, 11] + tv4->split(0, 3); + // tv4[5*13//3, 3, 7, 11] + tv4->merge(2, 3)->split(2, 2); + // tv4[5*13//3, 3, 7*11//2, 2] + // tv4->merge(0, 2); + // // tv4[(5*13//3)*(7*11//2), 3, 2] + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + inlineAllAt(tv4, 1, false); + fusion.printKernel(); + // std::cout<definition()->toString()<merge(0)->merge(0); + // tv10[7*11*13] + tv10->split(0, 5)->split(0, 3); + // tv10[7*11*13//5//3, 3, 5] + + TransformPropagatorWithCheck propagator(tv10); + MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + + std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; + for (auto tensor : tensors_to_inline) { + tensor->inlineAt(1); + } + + fusion.print(); + fusion.printKernel(); +} + +// TODO: Finish and enable test +// +// Progressive loop promotion. producer gets promoted in consumer, consumer is +// promoted in a different way to its consumer. +TEST_F(NVFuserTest, FusionIndexing20_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({5}); + fusion.addInput(tv0); + + // [5] + auto tv1 = set(tv0); + auto tv2 = broadcast(tv1, {true, false}); + // [1, 5] + auto tv3 = makeConcreteTensor({3, 5}); + fusion.addInput(tv3); + auto tv4 = add(tv3, tv2); + // [3, 5] + + auto tv5 = broadcast(tv4, {false, false, true}); + // [3, 5, 1] + auto tv6 = makeConcreteTensor({3, 5, 7}); + fusion.addInput(tv6); + auto tv7 = add(tv5, tv6); + // [3, 5, 7] + fusion.addOutput(tv7); + + tv4->merge(0)->split(0, 3, false); + // [3, 5] + // [3, 3*5/3] + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + // tv0->tv1->tv2(b)->tv4->tv5(b)->tv7 + + tv1->inlineAt(1); + tv2->inlineAt(1); + tv4->inlineAt(1); + + tv5->merge(1)->split(1, 5, false); + // [3, 3*5/3, 7] + tv7->merge(1)->split(1, 5, false); + // [3, 5, (3*5/3)*7/5] + tv5->inlineAt(2); + + fusion.printKernel(); +} + +// Repro for issue #1873 +TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + auto tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = set(tv0); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); fusion.addOutput(tv4); tv4->merge(0); - tv4->split(0, 4); - auto tv5 = tv4->rFactor({1}); + tv4->split(0, 32); - MaxRootDomainInfoSpanningTree tree(tv5); - TransformPropagator tp(tv5); - tree.traverse(&tp); + tv0->computeAt(tv4, 1); - inlineAllAt(tv4, 1, true); + tv2->split(-1, 8); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(1); - at::Tensor t0 = at::randn({5}, options); - at::Tensor t1 = at::randn({5, 3}, options); - std::vector inputs = {t0, t1}; + at::Tensor t0 = at::randn({123}, options); + at::Tensor t1 = at::randn({3, 123}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + + auto outputs = fe.runFusion({t0, t1}); + + auto tv_ref = t0 + t1; + + testValidate(&fusion, outputs, {t0, t1}, {tv_ref}, __LINE__, __FILE__); +} + +// Broadcast inline 3 times and merge all domains +TEST_F(NVFuserTest, FusionMultiPromotion_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // [y] + auto tv0 = makeSymbolicTensor(1); + // [w, x, y, z] + auto tv1 = makeSymbolicTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // y + auto tv2 = broadcast(tv0, {true, false}); + // w, y, z + auto tv3 = broadcast(tv2, {false, false, true}); + // w, y, z + auto tv4 = broadcast(tv3, {false, true, false, false}); + // w, x, y, z + auto tv5 = add(tv4, tv1); + + fusion.addOutput(tv5); + + tv5->merge(1)->merge(1)->merge(0)->split(0, 11); + + tv0->computeAt(tv5, 1); + tv1->computeAt(tv5, 1); FusionExecutor fe; - fe.compileFusion(&fusion, inputs); - auto cg_outputs = fe.runFusion(inputs); - auto ref = (t0.unsqueeze(-1) + t1).sum(); + at::Tensor t0 = at::randn({y}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); + + auto t4 = t0.unsqueeze(0).unsqueeze(0).unsqueeze(-1); + auto aten_output = t4.add(t1); + + std::vector aten_inputs = {t0, t1}; + + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +// Broadcast and concretize same domain in two different ways and try to merge +// their loops remains unsupported. +TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + // [w] + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // [w, x] + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + // [w, y] + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = set(tv0); + // [w] + auto tv4 = broadcast(tv3, {false, true}); + // [w, 1] + auto tv5 = add(tv4, tv2); + // [w, x] + fusion.addOutput(tv5); + + // [w] + auto tv6 = broadcast(tv3, {false, true}); + // [w, 1] + auto tv7 = add(tv6, tv2); + // [y] + + for (auto tv : std::vector{tv4, tv5, tv6, tv7}) { + tv->merge(0); + } + + for (auto tv : std::vector{tv3, tv4, tv6}) { + tv->inlineAt(1); + } - testValidate(fe.kernel(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); + ASSERT_ANY_THROW(fusion.printKernel()); } } // namespace nvfuser From 45959ef7ad2d82a6ef867376ab9cb50ac1c9cf58 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 19 Mar 2023 17:18:21 -0400 Subject: [PATCH 002/178] Helps if I include the right files. --- csrc/id_graphs.cpp | 3164 ++++++++++++++++++++++++++++++++++++++++++++ csrc/id_graphs.h | 465 +++++++ 2 files changed, 3629 insertions(+) create mode 100644 csrc/id_graphs.cpp create mode 100644 csrc/id_graphs.h diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp new file mode 100644 index 00000000000..67ce0d6de81 --- /dev/null +++ b/csrc/id_graphs.cpp @@ -0,0 +1,3164 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace nvfuser { + +namespace debug_print { +// A few compressed printing utilities to show critical uniqueness information. +// i.e. being able to tell slight differences between groups we're working with. + +// Sometimes it can be helpful to directly check the pointer addresses of the +// groups. As one group might look exactly like another group but are in +// different disjoint sets. Leaving commented out by default. + +// template +// std::string ptrStringShort(const T* ptr) { +// std::stringstream ss; +// ss << ptr; +// return "0x." + ss.str().substr(9); +// } + +std::string idsStringShort(const VectorOfUniqueEntries& id_group) { + std::vector names; + for (auto id : id_group) { + names.push_back(id->name()); + } + std::sort(names.begin(), names.end()); + + std::stringstream ss; + ss << "{" << names << "}"; + return ss.str(); +} + +std::string idGroupStringShort(const IdGroup& id_group) { + std::stringstream ss; + ss << /* ptrStringShort(id_group.get()) << */ "(idg)" + << idsStringShort(*id_group); + return ss.str(); +} + +std::string idGroupsStringShortInline(const IdGroups& id_groups) { + // Track position in id_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto id_group : id_groups) { + unsigned int min_id_name = std::numeric_limits::max(); + for (auto id : *id_group) { + if (id->name() < min_id_name) { + min_id_name = id->name(); + } + } + group_name_info.push_back(std::make_pair(min_id_name, pos++)); + } + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + std::stringstream ss; + ss << /* ptrStringShort(&id_groups) <<*/ "(idgs){"; + bool first = true; + for (auto i : c10::irange(group_name_info.size())) { + if (first) { + first = false; + } else { + ss << ", "; + } + auto pos = group_name_info[i].second; + ss << idGroupStringShort(id_groups.vector()[pos]); + } + + ss << "}"; + return ss.str(); +} + +std::string idGroupsStringShort(const IdGroups& id_groups) { + std::stringstream ss; + + // Track position in id_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto id_group : id_groups) { + unsigned int min_id_name = std::numeric_limits::max(); + for (auto id : *id_group) { + if (id->name() < min_id_name) { + min_id_name = id->name(); + } + } + group_name_info.push_back(std::make_pair(min_id_name, pos++)); + } + + ss << /* ptrStringShort(&id_groups) <<*/ "(idgs){\n"; + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + for (auto i : c10::irange(group_name_info.size())) { + auto pos = group_name_info[i].second; + ss << " " << idGroupStringShort(id_groups.vector()[pos]) << "\n"; + } + + ss << "}"; + return ss.str(); +} + +std::string exprGroupStringShort(ExprGroup expr_group) { + std::vector names; + for (auto expr : *expr_group) { + names.push_back(expr->name()); + } + std::sort(names.begin(), names.end()); + + std::stringstream ss; + ss << /* ptrStringShort(&expr_group) <<*/ "(exprg){" << names << "}"; + return ss.str(); +} + +std::string exprGroupStringShort( + const IdGraph& id_graph, + ExprGroup expr_group) { + std::stringstream ss; + auto inputs = id_graph.inputGroups(expr_group); + auto outputs = id_graph.outputGroups(expr_group); + ss << idGroupsStringShortInline(inputs) << " -" + << exprGroupStringShort(expr_group) << "-> " + << idGroupsStringShortInline(outputs); + return ss.str(); +} + +std::string exprGroupsStringShort( + const IdGraph& id_graph, + ExprGroups expr_groups) { + // Track position in expr_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto expr_group : expr_groups) { + unsigned int min_expr_name = std::numeric_limits::max(); + for (auto expr : *expr_group) { + if (expr->name() < min_expr_name) { + min_expr_name = expr->name(); + } + } + group_name_info.push_back(std::make_pair(min_expr_name, pos++)); + } + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + std::stringstream ss; + ss << /* ptrStringShort(&expr_groups) <<*/ "(exprs) {"; + bool first = true; + for (auto i : c10::irange(group_name_info.size())) { + if (first) { + first = false; + } else { + ss << ", "; + } + auto pos = group_name_info[i].second; + ss << exprGroupStringShort(id_graph, expr_groups.vector()[pos]); + } + + ss << "}"; + return ss.str(); +} + +std::string definitionsToString(const IdGraph& id_graph) { + std::stringstream ss; + ExprGroups defs; + for (auto id_group : id_graph.disjointIdSets().disjointSets()) { + auto definition_pair = id_graph.iterDomainGroupDefinitions(id_group); + if (definition_pair.second) { + for (auto expr_group : definition_pair.first) { + defs.pushBack(expr_group); + } + } + } + for (auto expr : defs) { + ss << exprGroupStringShort(id_graph, expr) << std::endl; + } + return ss.str(); +} + +std::string usesToString(const IdGraph& id_graph) { + std::stringstream ss; + + for (auto id_group : id_graph.disjointIdSets().disjointSets()) { + auto uses_pair = id_graph.iterDomainGroupUses(id_group); + ss << idGroupStringShort(id_group) << std::endl; + if (uses_pair.second) { + for (auto expr_group : uses_pair.first) { + ss << " " << exprGroupStringShort(id_graph, expr_group) << std::endl; + } + } + } + return ss.str(); +} + +} // namespace debug_print + +namespace { + +bool transformAtributesMatch(Expr* first, Expr* second) { + if (first == nullptr || second == nullptr) { + return false; + } + + TORCH_INTERNAL_ASSERT( + first->isA() || first->isA() || first->isA(), + "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n", + first->toString()); + + if (typeid(*first) != typeid(*second)) { + return false; + } + + if (first->isA()) { + auto first_split = first->as(); + auto second_split = second->as(); + if (!first_split->factor()->sameAs(second_split->factor()) || + first_split->innerSplit() != second_split->innerSplit() || + !first_split->startOffset()->sameAs(second_split->startOffset()) || + !first_split->stopOffset()->sameAs(second_split->stopOffset())) { + return false; + } + } + + if (first->isA()) { + auto first_swizzle = first->as(); + auto second_swizzle = second->as(); + if (first_swizzle->swizzleMode() != second_swizzle->swizzleMode() || + first_swizzle->swizzleType() != second_swizzle->swizzleType()) { + return false; + } + } + + return true; +} +} // namespace + +void IdGraphVisitor::traverse() { + IdGroups all_ids; + ExprGroups all_exprs; + { + if (sub_selection_.empty()) { + all_ids = IdGroups( + graph().disjointIdSets().disjointSets().begin(), + graph().disjointIdSets().disjointSets().end()); + } else { + for (auto id : sub_selection_) { + auto disjoint_pair = graph().disjointIdSet(id); + if (disjoint_pair.second) { + all_ids.pushBack(disjoint_pair.first); + } + } + } + + if (sub_selection_.empty()) { + all_exprs = ExprGroups( + graph().disjointExprSets().disjointSets().begin(), + graph().disjointExprSets().disjointSets().end()); + } else { + for (auto id_group : all_ids) { + for (auto def : graph().uniqueDefinitions(id_group)) { + if (all_exprs.has(def)) { + continue; + } + auto inp_groups = graph().inputGroups(def); + auto out_groups = graph().outputGroups(def); + if (inp_groups.subtract(all_ids).empty() && + out_groups.subtract(all_ids).empty()) { + all_exprs.pushBack(def); + } + } + } + } + } + // There could be IterDomains in from or to that are between other from and + // to nodes. Make sure to clear those out. + IdGroups terminating_inputs; + IdGroups terminating_outputs; + + { + IdGroups not_inputs; + IdGroups not_outputs; + for (auto expr_group : all_exprs) { + auto inp_groups = graph().inputGroups(expr_group); + auto out_groups = graph().outputGroups(expr_group); + + if (inp_groups.intersect(out_groups).size() > 0) { + // Expression is just a loop to its current group, ignore + continue; + } + + not_inputs.pushBack(out_groups); + not_outputs.pushBack(inp_groups); + } + + terminating_inputs = + IdGroups(all_ids.begin(), all_ids.end()).subtract(not_inputs); + + terminating_outputs = + IdGroups(all_ids.begin(), all_ids.end()).subtract(not_outputs); + } + + IdGroups to_visit_ids = terminating_inputs; + IdGroups visited_ids; + + ExprGroups to_visit_exprs; + ExprGroups visited_exprs; + + auto is_expr_ready = [&](ExprGroup expr_group) { + auto inp_groups = graph().inputGroups(expr_group); + return std::all_of( + inp_groups.begin(), inp_groups.end(), [&](IdGroup id_group) { + return visited_ids.has(id_group) || id_group->empty(); + }); + }; + + auto is_id_ready = [&](IdGroup id_group) { + auto unique_defs = graph().uniqueDefinitions(id_group); + return std::all_of( + unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { + return expr_group->empty() || visited_exprs.has(expr_group) || + IdGraph::isTrivialExpr(expr_group->front()).size(); + }); + }; + + while (to_visit_ids.size() > 0 || to_visit_exprs.size() > 0) { + // Process expressions first as all definitions of iter domains have to be + // processed before we can process that iter domain. + + // Detect if nothing has been processed which would put us in an infinite + // loop + bool something_was_processed = false; + ExprGroups still_to_visit_exprs; + + while (to_visit_exprs.size() > 0) { + auto current_expr_group = to_visit_exprs.popFront(); + if (visited_exprs.has(current_expr_group)) { + continue; + } + + if (is_expr_ready(current_expr_group)) { + handle(current_expr_group); + + something_was_processed = true; + visited_exprs.pushBack(current_expr_group); + + auto out_groups = graph().outputGroups(current_expr_group); + for (auto out_group : out_groups) { + to_visit_ids.pushBack(out_group); + } + } else { + still_to_visit_exprs.pushBack(current_expr_group); + } + } + + std::swap(to_visit_exprs, still_to_visit_exprs); + + IdGroups still_to_visit_ids; + while (to_visit_ids.size() > 0) { + auto current_id_group = to_visit_ids.popFront(); + if (visited_ids.has(current_id_group)) { + continue; + } + + if (is_id_ready(current_id_group)) { + handle(current_id_group); + + something_was_processed = true; + visited_ids.pushBack(current_id_group); + + if (!terminating_outputs.has(current_id_group)) { + auto uses_pair = graph().iterDomainGroupUses(current_id_group); + if (uses_pair.second) { + to_visit_exprs.pushBack(uses_pair.first); + } + } + + } else { + still_to_visit_ids.pushBack(current_id_group); + } + } + + TORCH_INTERNAL_ASSERT( + something_was_processed || + (to_visit_ids.size() == 0 && to_visit_exprs.size() == 0), + "Infinite loop entered."); + } +} + +IdGraph::IdGraph(const IdGraph& other) { + disjoint_ids_ = other.disjoint_ids_; + disjoint_exprs_ = other.disjoint_exprs_; + id_uses_ = other.id_uses_; + id_definitions_ = other.id_definitions_; + view_rfactor_ids_ = other.view_rfactor_ids_; + + for (auto orig_unique_def_pair : other.unique_definitions_) { + auto orig_id_group = orig_unique_def_pair.first; + auto orig_expr_groups = orig_unique_def_pair.second; + + auto new_id_group_pair = disjointIdSet(orig_id_group->front()); + TORCH_INTERNAL_ASSERT(new_id_group_pair.second); + auto new_id_group = new_id_group_pair.first; + + ExprGroups new_expr_groups; + for (auto orig_expr_group : orig_expr_groups) { + auto new_expr_group_pair = disjointExprSet(orig_expr_group->front()); + TORCH_INTERNAL_ASSERT(new_expr_group_pair.second); + new_expr_groups.pushBack(new_expr_group_pair.first); + } + + unique_definitions_[new_id_group] = new_expr_groups; + } + + for (auto orig_unique_use_pair : other.unique_uses_) { + auto orig_id_group = orig_unique_use_pair.first; + auto orig_expr_groups = orig_unique_use_pair.second; + + auto new_id_group_pair = disjointIdSet(orig_id_group->front()); + TORCH_INTERNAL_ASSERT(new_id_group_pair.second); + auto new_id_group = new_id_group_pair.first; + + ExprGroups new_expr_groups; + for (auto orig_expr_group : orig_expr_groups) { + auto new_expr_group_pair = disjointExprSet(orig_expr_group->front()); + TORCH_INTERNAL_ASSERT(new_expr_group_pair.second); + new_expr_groups.pushBack(new_expr_group_pair.first); + } + + unique_uses_[new_id_group] = new_expr_groups; + } +} + +IdGraph& IdGraph::operator=(const IdGraph& other) { + disjoint_ids_.clear(); + disjoint_exprs_.clear(); + unique_definitions_.clear(); + unique_uses_.clear(); + id_uses_.clear(); + id_definitions_.clear(); + view_rfactor_ids_.clear(); + IdGraph copy(other); + std::swap(*this, copy); + return *this; +} + +const DisjointSets& IdGraph::disjointIdSets() const { + return disjoint_ids_; +} + +DisjointSets& IdGraph::disjointIdSets() { + return disjoint_ids_; +} + +std::pair IdGraph::disjointIdSet(IterDomain* id) const { + auto disjoint_set_it = disjoint_ids_.disjointSetMap().find(id); + if (disjoint_set_it == disjoint_ids_.disjointSetMap().end()) { + return std::make_pair(IdGroup(nullptr), false); + } + return std::make_pair(disjoint_set_it->second, true); +} + +const DisjointSets& IdGraph::disjointExprSets() const { + return disjoint_exprs_; +} + +DisjointSets& IdGraph::disjointExprSets() { + return disjoint_exprs_; +} + +std::pair IdGraph::disjointExprSet(Expr* expr) const { + auto disjoint_set_it = disjoint_exprs_.disjointSetMap().find(expr); + if (disjoint_set_it == disjoint_exprs_.disjointSetMap().end()) { + return std::make_pair(ExprGroup(nullptr), false); + } + return std::make_pair(disjoint_set_it->second, true); +} + +ExprGroups IdGraph::toGroups(const VectorOfUniqueEntries& exprs) const { + ExprGroups expr_groups; + for (auto expr : exprs) { + auto disjoint_set_pair = disjointExprSet(expr); + if (disjoint_set_pair.second) { + expr_groups.pushBack(disjoint_set_pair.first); + } + } + return expr_groups; +} + +IdGroups IdGraph::toGroups( + const VectorOfUniqueEntries& ids) const { + IdGroups id_groups; + for (auto id : ids) { + auto disjoint_set_pair = disjointIdSet(id); + if (disjoint_set_pair.second) { + id_groups.pushBack(disjoint_set_pair.first); + } + } + return id_groups; +} + +IdGroups IdGraph::outputGroups(ExprGroup expr) const { + VectorOfUniqueEntries id_outputs; + for (auto id_output : + ir_utils::filterByType(expr->front()->outputs())) { + id_outputs.pushBack(id_output); + } + return toGroups(id_outputs); +} + +IdGroups IdGraph::inputGroups(ExprGroup expr) const { + VectorOfUniqueEntries id_inputs; + for (auto id_input : + ir_utils::filterByType(expr->front()->inputs())) { + id_inputs.pushBack(id_input); + } + return toGroups(id_inputs); +} + +ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { + ExprGroups to_visit; + for (auto of_id_group : of) { + auto group_uses_pair = iterDomainGroupUses(of_id_group); + if (group_uses_pair.second) { + to_visit.pushBack(group_uses_pair.first); + } + } + + ExprGroups visited; + while (to_visit.size() > 0) { + auto current_expr = to_visit.popFront(); + visited.pushBack(current_expr); + auto output_ids = outputGroups(current_expr); + for (auto output_id : output_ids) { + auto group_uses_pair = iterDomainGroupUses(output_id); + if (!group_uses_pair.second) { + continue; + } + for (auto group_use : group_uses_pair.first) { + if (visited.has(group_use)) { + continue; + } + to_visit.pushBack(group_use); + } + } + } + + return visited; +} + +ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { + ExprGroups to_visit; + for (auto of_id_group : of) { + auto group_defs_pair = iterDomainGroupDefinitions(of_id_group); + if (group_defs_pair.second) { + to_visit.pushBack(group_defs_pair.first); + } + } + + ExprGroups visited; + while (to_visit.size() > 0) { + auto current_expr = to_visit.popFront(); + visited.pushBack(current_expr); + auto input_ids = inputGroups(current_expr); + for (auto input_id : input_ids) { + auto group_defs_pair = iterDomainGroupDefinitions(input_id); + if (!group_defs_pair.second) { + continue; + } + for (auto group_def : group_defs_pair.first) { + if (visited.has(group_def)) { + continue; + } + to_visit.pushBack(group_def); + } + } + } + + return visited; +} + +ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) + const { + auto all_uses_of_from = allUsesOf(from); + auto all_definitions_of_to = allDefinitionsOf(to); + + // All of the expressions between from and to. Not all will be used as we + // just want to define each iter domain group once. + auto all_exprs = all_uses_of_from.intersect(all_definitions_of_to); + + // There could be IterDomains in from or to that are between other from and + // to nodes. Make sure to clear those out. + IdGroups terminating_inputs; + IdGroups terminating_outputs; + { + IdGroups not_inputs; + IdGroups not_outputs; + IdGroups all_id_groups; + + for (auto expr_group : all_exprs) { + auto inp_groups = inputGroups(expr_group); + auto out_groups = outputGroups(expr_group); + if (inp_groups.intersect(out_groups).size() > 0) { + // Expression is just a loop to its current group, ignore + continue; + } + + all_id_groups.pushBack(inp_groups); + + if (inp_groups.empty()) { + not_outputs.pushBack(inp_groups); + } + + all_id_groups.pushBack(out_groups); + + if (out_groups.empty()) { + not_inputs.pushBack(out_groups); + } + } + terminating_inputs = all_id_groups.subtract(not_inputs); + terminating_outputs = all_id_groups.subtract(not_outputs); + } + + // Track all expressions to get from outputs to this IterDomain. We + // traverse backwards as that's the direction of indexing expressions. An + // index is assigned to each leaf of a domain and as we traverse backwards + // we're effectively accumulating indexing math. We'll only keep the fewest + // expression lists to get to the iter domain. + std::unordered_map required_ind_exprs_ids; + std::unordered_map required_ind_exprs_exprs; + + // Return if all output IterDomain groups of an expression group have + // already been visited + auto outputsVisited = [&](ExprGroup expr) { + for (auto id_group : outputGroups(expr)) { + if (required_ind_exprs_ids.find(id_group) == + required_ind_exprs_ids.end()) { + return false; + } + } + return true; + }; + + auto allIdUsesVisisted = [&](IdGroup id) { + auto uses_pair = iterDomainGroupUses(id); + if (!uses_pair.second) { + return true; + } + for (auto use_group : uses_pair.first) { + if (all_exprs.has(use_group)) { + if (required_ind_exprs_exprs.find(use_group) == + required_ind_exprs_exprs.end()) { + return false; + } + } + } + return true; + }; + + // Returns all expression groups in required_ind_exprs_ids of outputs + auto requiredExprsOutputs = [&](ExprGroup expr) { + ExprGroups all_output_required_exprs; + for (auto id_group : outputGroups(expr)) { + auto id_group_exprs_it = required_ind_exprs_ids.find(id_group); + TORCH_INTERNAL_ASSERT( + id_group_exprs_it != required_ind_exprs_ids.end(), + "Failure in Iter Domain Graph index resolution, count expected for group: ", + id_group->toString()); + all_output_required_exprs.pushBack(id_group_exprs_it->second); + } + return all_output_required_exprs; + }; + + auto processExpr = [&](ExprGroup expr) { + if (!outputsVisited(expr)) { + return false; + } + // Accumulate expressions from all outputs add this expression and set it + // as current expressions required indexing expressions. + required_ind_exprs_exprs[expr] = requiredExprsOutputs(expr); + return true; + }; + + auto processId = [&](IdGroup id) { + // Track if we've grabed any of the uses required indexing expressions. + bool initialized = false; + // Expression group of all indexing expressions required for this iter + // domain coming back from any of its uses. + ExprGroups min_groups; + + auto uses_pair = iterDomainGroupUses(id); + if (!uses_pair.second) { + // No expressions required for this iter domain, it must be a + // terminating output. + required_ind_exprs_ids[id] = min_groups; + return true; + } + + // Only worry about expressions between inputs and outputs we're + // looking at. + for (auto use_group : uses_pair.first.intersect(all_exprs)) { + auto use_required_ind_exprs_it = required_ind_exprs_exprs.find(use_group); + if (use_required_ind_exprs_it == required_ind_exprs_exprs.end()) { + // If there isn't an entry for the use expression it wasn't + // processed, so don't try to process this iter domain yet. + return false; + } + if (!initialized) { + // If first use found initialize the minimum expression group + min_groups = + use_required_ind_exprs_it->second.computeUnion({use_group}); + initialized = true; + } else if ( + use_required_ind_exprs_it->second.size() + 1 < min_groups.size()) { + // If current use has fewer expressions use that, make sure to add the + // use expression. + min_groups = + use_required_ind_exprs_it->second.computeUnion({use_group}); + } + } + required_ind_exprs_ids[id] = min_groups; + return true; + }; + + IdGroups to_visit_ids = terminating_outputs; + ExprGroups to_visit_exprs; + + while (to_visit_ids.size() > 0 || to_visit_exprs.size() > 0) { + // Process expressions first as all uses of iter domains have to be + // processed before we can process that iter domain. + + // Try to detect when nothing has been processed which would put us in an + // infinite loop + bool something_was_processed = false; + ExprGroups still_to_visit_exprs; + while (to_visit_exprs.size() > 0) { + auto currently_visiting = to_visit_exprs.popFront(); + if (required_ind_exprs_exprs.find(currently_visiting) != + required_ind_exprs_exprs.end()) { + continue; + } + if (processExpr(currently_visiting)) { + something_was_processed = true; + auto inp_groups = inputGroups(currently_visiting); + for (auto inp_group : inp_groups) { + to_visit_ids.pushBack(inp_group); + } + } else { + still_to_visit_exprs.pushBack(currently_visiting); + } + } + + std::swap(to_visit_exprs, still_to_visit_exprs); + + IdGroups still_to_visit_ids; + while (to_visit_ids.size() > 0) { + auto currently_visiting = to_visit_ids.popFront(); + if (required_ind_exprs_ids.find(currently_visiting) != + required_ind_exprs_ids.end()) { + continue; + } + + if (processId(currently_visiting)) { + something_was_processed = true; + auto definitions_pair = iterDomainGroupDefinitions(currently_visiting); + if (definitions_pair.second) { + for (auto def : definitions_pair.first) { + if (!all_exprs.has(def)) { + continue; + } + if (required_ind_exprs_exprs.find(def) == + required_ind_exprs_exprs.end()) { + to_visit_exprs.pushBack(def); + } + } + } + } else { + still_to_visit_ids.pushBack(currently_visiting); + } + } + + TORCH_INTERNAL_ASSERT( + something_was_processed || + (to_visit_ids.size() == 0 && to_visit_exprs.size() == 0), + "Infinite loop entered."); + } + + // We want to traverse the expressions registered in required_ind_exprs_ids, + // let's create a strict "uses path" + std::unordered_map uses_path; + for (auto entry : required_ind_exprs_ids) { + auto id = entry.first; + auto traverse_exprs = entry.second; + auto all_uses = iterDomainGroupUses(id); + if (all_uses.second) { + uses_path[id] = traverse_exprs.intersect(all_uses.first); + } else { + uses_path[id] = {}; + continue; + } + } + + // Topologically sort the uses_path. + ExprGroups sorted_exprs; + ExprGroups to_visit; + + for (auto inp : terminating_inputs) { + auto use_it = uses_path.find(inp); + TORCH_INTERNAL_ASSERT( + use_it != uses_path.end(), + "Invalid calculation of exprs between, no use found of a provided terminating input: ", + inp->toString(), + " expressions cannot be computed."); + auto uses = use_it->second; + for (auto use : uses) { + to_visit.pushBack(use); + } + } + + IdGroups visited = terminating_inputs; + + while (to_visit.size() > 0) { + bool something_processed = false; + ExprGroups still_to_visit; + while (to_visit.size() > 0) { + auto currently_visiting = to_visit.popFront(); + auto inputs = inputGroups(currently_visiting); + if (std::all_of(inputs.begin(), inputs.end(), [&](IdGroup inp_id) { + return visited.has(inp_id); + })) { + something_processed = true; + sorted_exprs.pushBack(currently_visiting); + auto outputs = outputGroups(currently_visiting); + for (auto out_id : outputs) { + visited.pushBack(out_id); + auto use_pair = iterDomainGroupUses(out_id); + if (!use_pair.second) { + continue; + } + still_to_visit.pushBack(use_pair.first.intersect(all_exprs)); + } + } else { + still_to_visit.pushBack(currently_visiting); + } + } + std::swap(to_visit, still_to_visit); + TORCH_INTERNAL_ASSERT(something_processed, "Infinite loop entered."); + } + + return sorted_exprs; +} + +std::unordered_map> IdGraph:: + buildMapBetween( + const std::vector& from, + const std::vector& to) const { + std::unordered_map from_ids2set; + + for (auto from_id : from) { + auto from_disjoint_set_pair = disjointIdSet(from_id); + if (!from_disjoint_set_pair.second) { + continue; + } + from_ids2set[from_id] = from_disjoint_set_pair.first; + } + + // Map from the sets associated with the IterDomains in to, to those iter + // domains + std::unordered_map> set2to_ids; + + for (auto to_id : to) { + auto to_disjoint_set_pair = disjointIdSet(to_id); + if (!to_disjoint_set_pair.second) { + continue; + } + auto to_set = to_disjoint_set_pair.first; + auto set2to_ids_it = set2to_ids.find(to_set); + + if (set2to_ids_it == set2to_ids.end()) { + set2to_ids[to_set] = {to_id}; + } else { + set2to_ids[to_set].pushBack(to_id); + } + } + + std::unordered_map> + from_ids2to_ids; + for (auto from_id : from) { + from_ids2to_ids[from_id] = VectorOfUniqueEntries(); + + auto from_it = from_ids2set.find(from_id); + TORCH_INTERNAL_ASSERT(from_it != from_ids2set.end()); + + auto from_set = from_it->second; + auto to_entry_it = set2to_ids.find(from_set); + if (to_entry_it == set2to_ids.end()) { + continue; + } + from_ids2to_ids[from_id] = to_entry_it->second; + } + return from_ids2to_ids; +} + +std::unordered_map> IdGraph:: + buildMapBetween( + const VectorOfUniqueEntries& from, + const VectorOfUniqueEntries& to) const { + return buildMapBetween(from.vector(), to.vector()); +} + +std::pair IdGraph::iterDomainGroupDefinitions( + IdGroup id_group) const { + auto null_return = std::make_pair(ExprGroups(), false); + + if (id_group == nullptr) { + return null_return; + } + + auto definitions_it = unique_definitions_.find(id_group); + if (definitions_it == unique_definitions_.end()) { + return null_return; + } + + return std::make_pair(definitions_it->second, true); +} + +std::pair IdGraph::iterDomainGroupUses( + IdGroup id_group) const { + auto null_return = std::make_pair(ExprGroups(), false); + + if (id_group == nullptr) { + return null_return; + } + + auto uses_it = unique_uses_.find(id_group); + if (uses_it == unique_uses_.end()) { + return null_return; + } + + return std::make_pair(uses_it->second, true); +} + +// TODO: Improve and extend to include other information. +std::string IdGraph::toString() const { + std::stringstream ss; + ss << "IdGraph { \n"; + ss << "Disjoint Id Set " << disjoint_ids_.toString() << std::endl; + ss << " } IdGraph\n" << std::endl; + return ss.str(); +} + +std::vector> IdGraph::isTrivialExpr(Expr* expr) { + std::vector> mapped_ids; + if (auto merge = dynamic_cast(expr)) { + if (merge->inner()->extent()->isOneInt()) { + mapped_ids.push_back({merge->outer(), merge->out()}); + } + if (merge->outer()->extent()->isOneInt()) { + mapped_ids.push_back({merge->inner(), merge->out()}); + } + } else if (auto split = dynamic_cast(expr)) { + if (split->factor()->isOneInt() && split->startOffset()->isZeroInt() && + split->stopOffset()->isZeroInt()) { + if (split->innerSplit()) { + mapped_ids.push_back({split->in(), split->outer()}); + } else { + mapped_ids.push_back({split->in(), split->inner()}); + } + } + } else if (auto swizzle = dynamic_cast(expr)) { + if (swizzle->swizzleType() == Swizzle2DType::NoSwizzle || + swizzle->swizzleMode() == SwizzleMode::NoSwizzle) { + mapped_ids.push_back({swizzle->inX(), swizzle->outX()}); + mapped_ids.push_back({swizzle->inY(), swizzle->outY()}); + } + } + return mapped_ids; +} + +// TODO: Add explicit id_definitions_ and id_uses_ +void IdGraph::initializeId( + IterDomain* id, + const VectorOfUniqueEntries& definitions, + const VectorOfUniqueEntries& uses) { + auto id_disjoint_set = disjointIdSets().initializeSet(id).first->second; + + ExprGroups def_groups; + for (auto def : definitions) { + auto expr_set = disjointExprSets().initializeSet(def).first->second; + def_groups.pushBack(expr_set); + } + unique_definitions_[id_disjoint_set] = def_groups; + + ExprGroups use_groups; + for (auto use : uses) { + auto expr_set = disjointExprSets().initializeSet(use).first->second; + use_groups.pushBack(expr_set); + } + unique_uses_[id_disjoint_set] = use_groups; +} + +bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { + if (!transformAtributesMatch(first, second)) { + return false; + } + + auto first_ids = ir_utils::filterByType( + forward ? first->inputs() : first->outputs()) + .vector(); + + auto second_ids = ir_utils::filterByType( + forward ? second->inputs() : second->outputs()) + .vector(); + + TORCH_INTERNAL_ASSERT( + first_ids.size() == second_ids.size(), + "Expected number of ", + (forward ? "inputs" : "outputs"), + " to match for\n", + first->toString(), + second->toString()); + + { + std::vector> zipped_ids; + + std::transform( + first_ids.begin(), + first_ids.end(), + second_ids.begin(), + std::back_inserter(zipped_ids), + [](IterDomain* first, IterDomain* second) { + return std::make_pair(first, second); + }); + + if (std::any_of( + zipped_ids.begin(), + zipped_ids.end(), + [&](std::pair id_pair) { + return !disjointIdSets().permissiveAreMapped( + id_pair.first, id_pair.second); + })) { + return false; + } + } + + // Special handling for backprop of merge + if (first->isA() && !forward) { + // Can't back prop through merge without making sure one input actually + // matches. This can be done on a map or extent basis. + auto merge0 = first->as(); + auto merge1 = second->as(); + + auto extent_0o = merge0->outer()->extent(); + auto extent_0i = merge0->inner()->extent(); + auto extent_1o = merge1->outer()->extent(); + auto extent_1i = merge1->inner()->extent(); + + auto extent_0_match = extent_0o->sameAs(extent_1o) || + (extent_0o->isConstInt() && extent_1o->isConstInt() && + extent_0o->evaluateInt() == extent_1o->evaluateInt()) || + disjointIdSets().permissiveAreMapped(merge0->outer(), merge1->outer()); + + auto extent_1_match = extent_0i->sameAs(extent_1i) || + (extent_0i->isConstInt() && extent_1i->isConstInt() && + extent_0i->evaluateInt() == extent_1i->evaluateInt()) || + disjointIdSets().permissiveAreMapped(merge0->inner(), merge1->inner()); + + if (!(extent_0_match || extent_1_match)) { + return false; + } + } + + return true; +} + +ExprGroups IdGraph::uniqueDefinitions(IdGroup group) const { + auto unique_defs_it = unique_definitions_.find(group); + TORCH_INTERNAL_ASSERT( + unique_defs_it != unique_definitions_.end(), + "Definition not found for IdGroup: ", + group->toString()); + return unique_defs_it->second; +} + +ExprGroups IdGraph::uniqueUses(IdGroup group) const { + auto unique_uses_it = unique_uses_.find(group); + TORCH_INTERNAL_ASSERT( + unique_uses_it != unique_definitions_.end(), + "Uses not found for IdGroup: ", + group->toString()); + return unique_uses_it->second; +} + +void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { + if (expr0 == expr1) { + return; + } + + if (disjointExprSets().strictAreMapped(expr0, expr1)) { + return; + } + + // TODO: make these class functions for convenience, there are too many + // asserts in this file. + auto assert_get_expr_group = [&](Expr* expr) { + auto expr_group_pair = disjointExprSet(expr); + TORCH_INTERNAL_ASSERT( + expr_group_pair.second, "Could not find entry for expression: ", expr); + return expr_group_pair.first; + }; + + auto assert_get_id_group = [&](IterDomain* id) { + auto id_group_pair = disjointIdSet(id); + TORCH_INTERNAL_ASSERT( + id_group_pair.second, "Could not find entry for IterDomain: ", id); + return id_group_pair.first; + }; + + ExprGroup expr0_orig_group = assert_get_expr_group(expr0); + ExprGroup expr1_orig_group = assert_get_expr_group(expr1); + + disjointExprSets().mapEntries(expr0, expr1); + + auto expr_new_group = assert_get_expr_group(expr0); + + // Update unique uses of producers + IdGroups producers; + for (auto expr : std::vector{expr0, expr1}) { + for (auto input_id : ir_utils::filterByType(expr->inputs())) { + producers.pushBack(assert_get_id_group(input_id)); + } + } + + for (auto producer_group : producers) { + uniqueUses().at(producer_group).erase(expr0_orig_group); + uniqueUses().at(producer_group).erase(expr1_orig_group); + uniqueUses().at(producer_group).pushBack(expr_new_group); + } + + // Update unique definitinos of consumers + IdGroups consumers; + for (auto expr : std::vector{expr0, expr1}) { + for (auto output_id : ir_utils::filterByType(expr->outputs())) { + consumers.pushBack(assert_get_id_group(output_id)); + } + } + + for (auto consumer_group : consumers) { + uniqueDefinitions().at(consumer_group).erase(expr0_orig_group); + uniqueDefinitions().at(consumer_group).erase(expr1_orig_group); + uniqueDefinitions().at(consumer_group).pushBack(expr_new_group); + } +} + +void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { + if (id0 == id1) { + return; + } + + if (disjointIdSets().strictAreMapped(id0, id1)) { + return; + } + // Definitions and uses are based on the groups of id0 and id1, don't merge + // them into a single group until we grab all definitions and uses for later + // processing. + auto orig_id_group0 = disjointIdSet(id0).first; + auto orig_id_group1 = disjointIdSet(id1).first; + ExprGroups orig_defs0 = uniqueDefinitions(orig_id_group0); + ExprGroups orig_defs1 = uniqueDefinitions(orig_id_group1); + ExprGroups orig_uses0 = uniqueUses(orig_id_group0); + ExprGroups orig_uses1 = uniqueUses(orig_id_group1); + + // Map the iter domains together before we traverse across definitions and + // uses. Traversing definitions and uses could use the new property of id0 and + // id1 being mapped. + disjointIdSets().mapEntries(id0, id1); + auto new_id_group = disjointIdSet(id0).first; + + unique_definitions_.erase(orig_id_group0); + unique_definitions_.erase(orig_id_group1); + unique_uses_.erase(orig_id_group0); + unique_uses_.erase(orig_id_group1); + + unique_definitions_[new_id_group] = orig_defs0.computeUnion(orig_defs1); + unique_uses_[new_id_group] = orig_uses0.computeUnion(orig_uses1); + + // Propagate on uses + if (orig_uses0.size() > 0 || orig_uses1.size() > 0) { + if (orig_uses0.size() > 0 && orig_uses1.size() > 0) { + for (auto use_group_1 : orig_uses1) { + if (orig_uses0.has(use_group_1)) { + continue; + } + + for (auto use_group_0 : orig_uses0) { + auto use0 = use_group_0->front(); + auto use1 = use_group_1->front(); + if (exprsMap(use0, use1, true)) { + mapExprs(use0, use1); + mapThroughExpr(use0, use1, true); + } + } + } + } + } + + // Propagate on definitions + if (orig_defs0.size() > 0 || orig_defs1.size() > 0) { + if (orig_defs0.size() > 0 && orig_defs1.size() > 0) { + for (auto def_group_1 : orig_defs1) { + if (orig_defs0.has(def_group_1)) { + continue; + } + + for (auto def_group_0 : orig_defs0) { + auto def0 = def_group_0->front(); + auto def1 = def_group_1->front(); + if (exprsMap(def0, def1, false)) { + mapExprs(def0, def1); + mapThroughExpr(def0, def1, false); + } + } + } + } + } +} + +bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { + if (first == nullptr || second == nullptr) { + return false; + } + + if (!exprsMap(first, second, forward)) { + return false; + } + + auto first_ids = ir_utils::filterByType( + forward ? first->outputs() : first->inputs()) + .vector(); + auto second_ids = ir_utils::filterByType( + forward ? second->outputs() : second->inputs()) + .vector(); + TORCH_INTERNAL_ASSERT( + first_ids.size() == second_ids.size(), + "This should be unreachable, if transformation expressions match, their number of inputs and outputs should as well.\n However found:\n", + first->toString(), + "\nand\n", + second->toString()); + for (auto out_i : c10::irange(first_ids.size())) { + mapIds(first_ids[out_i], second_ids[out_i]); + } + + return true; +} + +void IterDomainGraphs::assertNoSelfMapping() { + TORCH_INTERNAL_ASSERT( + !hasSelfMapping(), + "Unsupported domain mapping detected in ", + std::get<0>(*self_mapping_info_)->toString(), + ". ", + std::get<3>(*self_mapping_info_), + " domains, ", + std::get<1>(*self_mapping_info_)->toString(), + " and ", + std::get<2>(*self_mapping_info_)->toString(), + ", are mapped with each other."); +} + +void IdGraph::mapThroughLoopSwizzles() { + for (auto use_pairs : unique_uses_) { + auto use_groups = use_pairs.second; + for (auto use_group : use_groups) { + for (auto use : *use_group) { + if (auto swizzle_2d = dynamic_cast(use)) { + // Map each input to its corresponding output on the given + // disjoint set if this is a loop swizzle. Loop swizzles don't impact + // indexing, only iteration order. + if (swizzle_2d->swizzleMode() == SwizzleMode::Loop) { + mapIds(swizzle_2d->inX(), swizzle_2d->outX()); + mapIds(swizzle_2d->inY(), swizzle_2d->outY()); + } + } + } + } + } +} + +IterDomainGraphs::IterDomainGraphs( + const std::vector& exprs, + const std::vector& additional_tvs, + bool allow_self_mapping) { + build(exprs, additional_tvs); + + if (!allow_self_mapping) { + assertNoSelfMapping(); + } +} + +IterDomainGraphs::IterDomainGraphs( + const std::vector& exprs, + bool allow_self_mapping) + : IterDomainGraphs(exprs, {}, allow_self_mapping) {} + +IterDomainGraphs::IterDomainGraphs(Fusion* fusion, bool allow_self_mapping) { + std::vector inputs_and_outputs; + { + auto inp_tvs = ir_utils::filterByType(fusion->inputs()); + inputs_and_outputs.insert( + inputs_and_outputs.begin(), inp_tvs.begin(), inp_tvs.end()); + } + { + auto out_tvs = ir_utils::filterByType(fusion->outputs()); + inputs_and_outputs.insert( + inputs_and_outputs.begin(), out_tvs.begin(), out_tvs.end()); + } + + build(fusion->exprs(), inputs_and_outputs); + + if (!allow_self_mapping) { + assertNoSelfMapping(); + } +} + +const IdGraph& IterDomainGraphs::idGraph(IdMappingMode mode) const { + auto graph_it = id_graphs_.find(mode); + TORCH_INTERNAL_ASSERT(graph_it != id_graphs_.end()); + return graph_it->second; +} + +IdGraph& IterDomainGraphs::idGraph(IdMappingMode mode) { + auto graph_it = id_graphs_.find(mode); + TORCH_INTERNAL_ASSERT(graph_it != id_graphs_.end()); + return graph_it->second; +} + +Expr* IterDomainGraphs::idUse(IterDomain* id) const { + auto use_it = id_uses_.find(id); + if (use_it == id_uses_.end()) { + return nullptr; + } + return use_it->second.front(); +} + +Expr* IterDomainGraphs::idDef(IterDomain* id) const { + auto def_it = id_definitions_.find(id); + if (def_it == id_definitions_.end()) { + return nullptr; + } + return def_it->second.front(); +} + +namespace { + +// Returns the first pair of id's in ids detected to match eachother on the +// permissive map of the ID graph. TODO: what this is really looking for is if +// there's any overlapping between the iter domains in the provided set. +// +// i.e. if we have: +// tv0 = arange(6).view({3, 2}) +// tv1 = tv0[3, 2].t() +// tv2 = tv0[3, 2].view({2, 3}) +// tv3 = tv1 + tv2 +// +// Then we can see this overlap in the tv3 expression as: +// +// tv0 = { {0, 1, 2}, +// {3, 4, 5} } +// +// tv1 = { {0, 3}, +// {1, 4}, +// {2, 5} } +// +// tv2 = { {0, 1}, +// {2, 3}, +// {4, 5} } +// +// The elements in tv1 {3, 1, 4, 2}, map respectively to the elements in tv2 +// {1, 2, 3, 4}. The reason this is so important is it means that generating +// tv3 is no longer a trivially parallelizable problem (if we include the dag +// all the way to tv0). So tv0's axes cannot be inlined across both the tv0 +// and tv1 path. This breaks some assumptions we have today in schedulers that +// will assume tv2 can be trivially inlined/parallelized. Instead we'd need to +// take into consideration the effective communication going on here, so that +// we pull multiple values of tv0 to compute tv3. +c10::optional> detectMappablePair( + const std::vector& ids, + const IterDomainGraphs& id_graph, + IdMappingMode mode) { + for (auto id1 : ids) { + for (auto id2 : ids) { + if (id1 == id2) { + continue; + } + if (id_graph.idGraph(mode).disjointIdSets().permissiveAreMapped( + id1, id2)) { + return std::make_pair(id1, id2); + } + } + } + + return {}; +} + +// It is assumed that for any tensor represented by a list of domains, +// those domains should never be mapped with each other. It may be +// possible to lift this assumption, but it's unclear if it could +// matter in practice. +c10::optional> +findFirstSelfMapping( + const std::vector& all_tvs, + const IterDomainGraphs& id_graph) { + for (auto tv : all_tvs) { + // For each tensor, make sure root, rfactor and leaf domains + // should not include domains that are mapped with another domain + // in the same set of domains. This may be overly conservative, + // and it maybe enough to check the root domains. + + // Root domains + auto self_mappped_root_pair = + detectMappablePair(tv->getRootDomain(), id_graph, IdMappingMode::EXACT); + if (self_mappped_root_pair.has_value()) { + return std::make_tuple( + tv, + self_mappped_root_pair->first, + self_mappped_root_pair->second, + "Root"); + } + + // Rfactor domains + if (tv->hasRFactor()) { + auto self_mappped_rf_pair = detectMappablePair( + tv->getRFactorDomain(), id_graph, IdMappingMode::EXACT); + if (self_mappped_rf_pair.has_value()) { + return std::make_tuple( + tv, + self_mappped_rf_pair->first, + self_mappped_rf_pair->second, + "RFactor"); + } + } + + // Leaf domains + auto self_mappped_leaf_pair = detectMappablePair( + tv->domain()->domain(), id_graph, IdMappingMode::LOOP); + if (self_mappped_leaf_pair.has_value()) { + return std::make_tuple( + tv, + self_mappped_leaf_pair->first, + self_mappped_leaf_pair->second, + "Leaf"); + } + } + return c10::nullopt; +} + +} // namespace + +void IterDomainGraphs::buildIterDomainDefinitionsAndUses( + const std::vector& all_tvs) { + for (auto tv : all_tvs) { + VectorOfUniqueEntries root_domain_ids{ + tv->getRootDomain().begin(), tv->getRootDomain().end()}; + + auto all_ids = ir_utils::allIDsOf(tv); + + // Check is this domain is a consumer of a view-like operation + bool view_like_domain = tv->domain()->hasViewLikeRFactor(); + + for (auto id : all_ids) { + // Check if this id is a view like rfactor id + if (view_like_domain && id->isRFactorProduct()) { + // If the tensor domain is a view like domain, and the iteration + // domain is marked as an rfactor product and is in the rfactor + // domain, it's a view like rfactor iteration domain + const auto& rfactor_domain = tv->domain()->getMaybeRFactorDomain(); + if (std::find(rfactor_domain.begin(), rfactor_domain.end(), id) != + rfactor_domain.end()) { + view_rfactor_ids_.emplace(id); + } + } + + if (id_definitions_.find(id) == id_definitions_.end()) { + id_definitions_[id] = {}; + } + + if (id_uses_.find(id) == id_uses_.end()) { + id_uses_[id] = {}; + } + + auto def = id->definition(); + + if (def == nullptr || root_domain_ids.has(id)) { + continue; + } + + if (id_definitions_.find(id) == id_definitions_.end()) { + id_definitions_[id] = {}; + } + id_definitions_.at(id).pushBack(def); + + auto inp_ids = ir_utils::filterByType(def->inputs()); + for (auto inp_id : inp_ids) { + if (id_uses_.find(inp_id) == id_uses_.end()) { + id_uses_[inp_id] = {}; + } + id_uses_.at(inp_id).pushBack(def); + } + } + } +} + +// TODO: Extend to include other information. +std::string IterDomainGraphs::toString() const { + std::stringstream ss; + ss << "IterDomainGraphs { \n"; + // for (auto set : disjoint_ids_) { + // ss << "Set " << set.first << ": " << std::endl; + // ss << set.second.toString() << std::endl; + // } + ss << " } IterDomainGraphs\n" << std::endl; + return ss.str(); +} + +// Replay Expr but with the inputs provided. +Expr* IterDomainGraphs::addReplayAs( + const std::vector& new_inputs, + Expr* expr) { + // Figure out which graphs are already initialized to make sure we add the new + // expression to them. + std::vector initialized_modes; + for (auto mode : kIdMappingModes) { + auto graph_it = id_graphs_.find(mode); + if (graph_it == id_graphs_.end()) { + continue; + } + + auto& graph = graph_it->second; + if (graph.disjointIdSets().disjointSetMap().empty()) { + continue; + } + + initialized_modes.push_back(mode); + } + + auto orig_inputs = ir_utils::filterByType(expr->inputs()); + std::vector orig_input_ids( + orig_inputs.begin(), orig_inputs.end()); + + { + TORCH_INTERNAL_ASSERT( + new_inputs.size() == orig_input_ids.size(), + "Invalid number of inputs: ", + new_inputs.size(), + " does not match number of iter domain inputs for ", + expr->toString()); + + VectorOfUniqueEntries all_inputs{ + orig_input_ids.begin(), orig_input_ids.end()}; + + all_inputs.pushBack(VectorOfUniqueEntries{ + new_inputs.begin(), new_inputs.end()}); + + for (auto mode : initialized_modes) { + for (auto inp : all_inputs) { + TORCH_INTERNAL_ASSERT( + idGraph(mode).disjointIdSet(inp).second, + "All inputs for replay need to be initialized in all graphs, ", + inp->toString(), + " was not found in mode: ", + mode); + } + } + } + + // Create the new expression with provided inputs + auto replay = ReplayTransform::replayAs(new_inputs, expr); + + for (auto out_id : ir_utils::filterByType(replay->outputs())) { + id_definitions_[out_id] = {replay}; + id_uses_[out_id] = {}; + } + + // Add the expression to the uses of the inputs + for (auto inp_id : ir_utils::filterByType(replay->inputs())) { + id_uses_.at(inp_id).pushBack(replay); + } + + // Initialize output iter domains in the graphs + for (auto mode : initialized_modes) { + idGraph(mode).disjointExprSets().initializeSet(replay); + auto replay_group = idGraph(mode).disjointExprSet(replay).first; + + // Initialize output ids in map + for (auto out_id : ir_utils::filterByType(replay->outputs())) { + idGraph(mode).initializeId(out_id, {replay}, {}); + } + + // Update uses of the inputs in the graphs + for (auto inp_id : ir_utils::filterByType(replay->inputs())) { + auto inp_group = idGraph(mode).disjointIdSet(inp_id).first; + idGraph(mode).uniqueUses().at(inp_group).pushBack(replay_group); + } + + // Propagate through all the uses of the iter domain groups of the inputs + // with the new expression. + auto& graph = idGraph(mode); + // Gather all use expressions from inputs + VectorOfUniqueEntries representative_uses; + for (auto inp : new_inputs) { + auto uses_pair = + graph.iterDomainGroupUses(graph.disjointIdSet(inp).first); + if (uses_pair.second) { + for (auto use_group : uses_pair.first) { + representative_uses.pushBack(use_group->front()); + } + } + } + + for (auto expr : representative_uses) { + if (graph.exprsMap(expr, replay, true)) { + graph.mapExprs(expr, replay); + graph.mapThroughExpr(expr, replay, true); + } + } + } + + return replay; +} + +IdGraph IterDomainGraphs::initializeIdGraph() { + IdGraph id_graph; + + for (auto definition_entry : id_definitions_) { + auto id = definition_entry.first; + auto defs = definition_entry.second; + auto uses_it = id_uses_.find(id); + TORCH_INTERNAL_ASSERT( + uses_it != id_uses_.end(), + "Failed to initialize id: ", + id->toString(), + " as it's missing a definition entry."); + id_graph.initializeId(id, defs, uses_it->second); + } + + return id_graph; +} + +void IterDomainGraphs::buildExactMap(const std::vector& exprs) { + for (auto expr : exprs) { + TensorView* c_tv = ir_utils::getTvOutput(expr); + + auto all_tv_outputs = ir_utils::filterByType(expr->outputs()); + + // Map siblings, as all other tv output domains must match the first tv + // outputs domain. + std::deque other_tv_outputs( + all_tv_outputs.begin(), all_tv_outputs.end()); + other_tv_outputs.pop_front(); + + for (auto other_tv_output : other_tv_outputs) { + // Sibling tv's must be exactly mapped with eachother so simply zip + // their leaf iter domains. + + TORCH_INTERNAL_ASSERT( + other_tv_output->getRootDomain().size() == + c_tv->getRootDomain().size(), + "Multiple outputs with mismatched TV domains is not supported."); + + for (auto domain_i : c10::irange(c_tv->getRootDomain().size())) { + auto c_id = c_tv->getRootDomain()[domain_i]; + auto o_id = other_tv_output->getRootDomain()[domain_i]; + idGraph(IdMappingMode::EXACT).mapIds(o_id, c_id); + } + } + + // Map producer-consumer relationships based on the root domain map + auto tv_inputs = ir_utils::filterByType(expr->inputs()); + for (auto p_tv : tv_inputs) { + // For exact mapings do not map any broadcast dimensions to + // non-broadcast dimensions. Prevent any broadcasted axes being mapped + // to non-broadcasted axes. + auto exact_c2p_root_map = + PairwiseRootDomainMap(p_tv, c_tv, true) + .mapConsumerToProducer(c_tv->domain(), p_tv->domain()); + + for (auto c_id : getSortedKeys(exact_c2p_root_map, Statement::lessThan)) { + auto p_id = exact_c2p_root_map.at(c_id); + idGraph(IdMappingMode::EXACT).mapIds(c_id, p_id); + } + } + + idGraph(IdMappingMode::EXACT).mapThroughLoopSwizzles(); + } +} + +void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { + idGraph(IdMappingMode::PERMISSIVE) = idGraph(IdMappingMode::ALMOSTEXACT); + + for (auto expr : exprs) { + // Multiple outputs are already mapped, we can ignore all but the first + // consumer given they have to be replayed in the same exact way + // Multiple outputs are already mapped, we can ignore all but the first + // consumer given they have to be replayed in the same exact way + TensorView* c_tv = ir_utils::getTvOutput(expr); + + auto tv_inputs = ir_utils::filterByType(expr->inputs()); + + for (auto p_tv : tv_inputs) { + auto p_ids_vec = ir_utils::allIDsOf(p_tv); + auto c_ids_vec = ir_utils::allIDsOf(c_tv); + std::unordered_set p_ids(p_ids_vec.begin(), p_ids_vec.end()); + std::unordered_set c_ids(c_ids_vec.begin(), c_ids_vec.end()); + + ForwardingInfo permissive_forwarding(p_tv, c_tv); + for (auto entry : permissive_forwarding.producer_forwarding_map) { + idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); + } + + // TODO: Should this just get rolled up in the forwarding map now? + for (auto entry : permissive_forwarding.producer_compliment_map) { + for (auto entry_2 : entry.second) { + idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); + } + } + + for (auto entry : permissive_forwarding.consumer_forwarding_map) { + idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); + } + + // TODO: Should this just get rolled up in the forwarding map now? + for (auto entry : permissive_forwarding.consumer_compliment_map) { + for (auto entry_2 : entry.second) { + idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); + } + } + + auto permissive_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv); + + for (auto entry : permissive_c2p_root_map.mapConsumerToProducer( + c_tv->domain(), p_tv->domain())) { + idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); + } + } + } + idGraph(IdMappingMode::PERMISSIVE).mapThroughLoopSwizzles(); +} + +void IterDomainGraphs::buildAlmostExactMap() { + // Build almost exact map by forwarding through broadcast axes + idGraph(IdMappingMode::ALMOSTEXACT) = idGraph(IdMappingMode::EXACT); + + VectorOfUniqueEntries exprs; + for (auto expr : + idGraph(IdMappingMode::ALMOSTEXACT).disjointExprSets().disjointSets()) { + exprs.pushBack(expr->front()); + } + ExprGroups trivial_expr_groups; + + // Map through trivial expressions + for (auto expr : exprs) { + auto mapped_ids = IdGraph::isTrivialExpr(expr); + for (auto mapped_id_group : mapped_ids) { + for (auto id : mapped_id_group) { + trivial_expr_groups.pushBack( + idGraph(IdMappingMode::ALMOSTEXACT).disjointExprSet(expr).first); + idGraph(IdMappingMode::ALMOSTEXACT).mapIds(mapped_id_group.front(), id); + } + } + } + + // TODO: Clear out expressions that map inputs and outputs to the same group + // from definitions and uses. They shouldn't be important in traversal. + // Similar to what's drafted in buildIndexMap +} + +void IterDomainGraphs::validateAndPropagatePType() const { + for (const auto& loop_disjoint_set : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + ParallelType common_ptype = ParallelType::Serial; + for (auto id : loop_disjoint_set->vector()) { + auto id_ptype = id->getParallelType(); + TORCH_INTERNAL_ASSERT( + id_ptype == common_ptype || id_ptype == ParallelType::Serial || + common_ptype == ParallelType::Serial, + "Issue validating parallel type disjoint ptype is, ", + common_ptype, + " but found in the set the id: ", + id->toString()); + common_ptype = + common_ptype == ParallelType::Serial ? id_ptype : common_ptype; + } + + for (auto id : loop_disjoint_set->vector()) { + id->parallelize(common_ptype); + } + } +} + +void IterDomainGraphs::build( + const std::vector& exprs, + const std::vector& additional_tvs) { + // Initialize the required sets as if a permissive relationship is never + // found, then querying an empty permissive map will fail later. + // Initialize disjoint sets + for (auto mode : kIdMappingModes) { + id_graphs_[mode] = IdGraph(); + } + + std::vector tv_exprs; + + std::copy_if( + exprs.begin(), exprs.end(), std::back_inserter(tv_exprs), [](Expr* expr) { + TORCH_INTERNAL_ASSERT(expr != nullptr); + return ir_utils::isTvOp(expr); + }); + + auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs); + if (additional_tvs.size() > 0) { + std::unordered_set all_added_tvs( + all_tvs.begin(), all_tvs.end()); + for (auto additional_tv : additional_tvs) { + if (all_added_tvs.find(additional_tv) == all_added_tvs.end()) { + all_tvs.push_back(additional_tv); + } + } + } + + if (all_tvs.empty()) { + return; + } + + FusionGuard fg(all_tvs.front()->fusion()); + // Add uses and definitions to all iter domains. + buildIterDomainDefinitionsAndUses(all_tvs); + + // Initialize the maps with all the IterDomains used in the provded + // expressions. + idGraph(IdMappingMode::EXACT) = initializeIdGraph(); + + buildExactMap(tv_exprs); + + buildAlmostExactMap(); + + buildPermissiveMap(tv_exprs); + + // Only build loop map during lowering + if (FusionGuard::getCurFusion()->isA()) { + FusionGuard::getCurFusion()->print(); + // Find loops that need to be promoted because of broadcast resolution, + // figure out what that resolution should look like, compute IDs for it if + // necessary. + buildLoopPromotionMap(tv_exprs); + + TORCH_INTERNAL_ASSERT(false); + + validateAndPropagatePType(); + } + + // Debug, make sure there's no self mapping in TensorView's during lowering + // that would invalidate lowering assumptions. + self_mapping_info_ = findFirstSelfMapping(all_tvs, *this); +} + +namespace { + +// Returns the root producer iteration domains that are resolved by provided +// consumer +std::unordered_map resolvedRootBroadcasts( + TensorView* producer, + TensorView* consumer) { + auto p2c_map = + PairwiseRootDomainMap(producer, consumer) + .mapProducerToConsumer(producer->domain(), consumer->domain()); + + std::unordered_map resolved_bcast_map; + for (const auto& kv : p2c_map) { + auto p_id = kv.first; + // Ignore non-broadcast dims + if (!p_id->isBroadcast()) { + continue; + } + auto c_id = kv.second; + // If the consumer ID is a reduction (i.e., a trivial + // reduction), do not consider it's concretized. + if (c_id->isBroadcast() || c_id->isReduction()) { + continue; + } + resolved_bcast_map[p_id] = c_id; + } + return resolved_bcast_map; +} + +} // namespace + +std::unordered_map IterDomainGraphs:: + buildCoveredAlmostExact() { + // Helper functions. + auto producerIdGroups = [&](IdGroup id_group) { + IdGroups producer_groups; + auto definition_pair_it = idGraph(IdMappingMode::ALMOSTEXACT) + .iterDomainGroupDefinitions(id_group); + if (!definition_pair_it.second) { + return producer_groups; + } + for (auto def_group : definition_pair_it.first) { + auto inp_groups = + idGraph(IdMappingMode::ALMOSTEXACT).inputGroups(def_group); + producer_groups.pushBack(inp_groups); + } + return producer_groups; + }; + + auto consumerIdGroups = [&](IdGroup id_group) { + IdGroups consumer_groups; + auto uses_pair_it = + idGraph(IdMappingMode::ALMOSTEXACT).iterDomainGroupUses(id_group); + if (!uses_pair_it.second) { + return consumer_groups; + } + for (auto use_group : uses_pair_it.first) { + auto out_groups = + idGraph(IdMappingMode::ALMOSTEXACT).outputGroups(use_group); + consumer_groups.pushBack(out_groups); + } + return consumer_groups; + }; + + // Start at terminating inputs of the almost exact graph and almost exact + // entries that are rfactor nodes. Propagate and accumulate these nodes + // through consumers. + // + // The almost exact entries covered by an iteration domain is effectively + // all the iteration domains this domain relies on. Initialize broadcast + // entries to not cover any domains. + std::unordered_map covered_almost_exact_entries; + + // We will traverse over the almost exact set expressions. Save where we + // want to start traversal: + IdGroups to_visit; + // Initialize covered groups + for (auto almost_exact_set : + idGraph(IdMappingMode::ALMOSTEXACT).disjointIdSets().disjointSets()) { + // what broadcast domains cover doesn't matter + if (std::all_of( + almost_exact_set->begin(), + almost_exact_set->end(), + [&](IterDomain* id) { return id->isBroadcast(); })) { + covered_almost_exact_entries[almost_exact_set] = {}; + continue; + } + + // Initialize rfactor domains to cover themselves only + if (std::any_of( + almost_exact_set->begin(), + almost_exact_set->end(), + [&](IterDomain* id) { + return viewRfactorIds().find(id) != viewRfactorIds().end(); + })) { + covered_almost_exact_entries[almost_exact_set] = {almost_exact_set}; + to_visit.pushBack(consumerIdGroups(almost_exact_set)); + continue; + } + + // Initialize any groups that don't have a definition except (potentialy) + // ones that traverse back to this set. + auto def_pair = idGraph(IdMappingMode::ALMOSTEXACT) + .iterDomainGroupDefinitions(almost_exact_set); + if (!def_pair.second) { + covered_almost_exact_entries[almost_exact_set] = {almost_exact_set}; + to_visit.pushBack(consumerIdGroups(almost_exact_set)); + continue; + } + + for (auto def : def_pair.first) { + // If all definitions are self mapping (can happen with + // merging our splitting with a broadcast/ dim of size 1) + // then this group is an input. + auto inp_groups = idGraph(IdMappingMode::ALMOSTEXACT).inputGroups(def); + if (std::find(inp_groups.begin(), inp_groups.end(), almost_exact_set) == + inp_groups.end()) { + goto loop_continue; + } + } + + covered_almost_exact_entries[almost_exact_set] = {almost_exact_set}; + to_visit.pushBack(consumerIdGroups(almost_exact_set)); + + loop_continue:; + } + + // Starting from the initialized inputs propagate forward from those inputs to + // mark what every iter domain in the graph covers. This will be used in later + // analysis. + while (to_visit.size() > 0) { + IdGroups still_to_visit; + bool something_processed = false; + while (to_visit.size() > 0) { + auto currently_visiting = to_visit.popFront(); + if (covered_almost_exact_entries.find(currently_visiting) != + covered_almost_exact_entries.end()) { + continue; + } + auto producer_ids = producerIdGroups(currently_visiting); + producer_ids.erase(currently_visiting); + IdGroups currently_visiting_covered; + for (auto producer_id : producer_ids) { + auto producer_covered_it = + covered_almost_exact_entries.find(producer_id); + if (producer_covered_it == covered_almost_exact_entries.end()) { + still_to_visit.pushBack(currently_visiting); + goto inner_while_continue; + } + for (auto entry : producer_covered_it->second) { + if (currently_visiting_covered.has(entry)) { + continue; + } + } + currently_visiting_covered.pushBack(producer_covered_it->second); + } + covered_almost_exact_entries[currently_visiting] = + currently_visiting_covered; + to_visit.pushBack(consumerIdGroups(currently_visiting)); + something_processed = true; + + inner_while_continue:; + } + TORCH_INTERNAL_ASSERT( + still_to_visit.empty() || something_processed, + "Entered infinite loop."); + std::swap(still_to_visit, to_visit); + } + return covered_almost_exact_entries; +} + +void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { + idGraph(IdMappingMode::LOOP) = initializeIdGraph(); + + std::unordered_map> + p2c_root_broadcast_resolution_map; + + // Track all of the p2c mappings through the fusion within those inlined + // domains. + std::unordered_map> + p2c_ca_permissive_maps; + + VectorOfUniqueEntries ordered_p_ca_ids; + + auto accumulateInMap = + [](std::unordered_map>& + map, + IterDomain* key, + IterDomain* new_value) { + auto entry_it = map.find(key); + if (map.find(key) == map.end()) { + map[key] = {new_value}; + } else { + auto& value = entry_it->second; + value.pushBack(new_value); + } + }; + + auto accumulateInMapVec = + [](std::unordered_map>& + map, + IterDomain* key, + const VectorOfUniqueEntries& new_values) { + auto entry_it = map.find(key); + if (map.find(key) == map.end()) { + map[key] = new_values; + } else { + auto& value = entry_it->second; + value.pushBack(new_values); + } + }; + + for (auto expr : exprs) { + for (auto producer : ir_utils::filterByType(expr->inputs())) { + auto producer_root = producer->getMaybeRFactorDomain(); + auto producer_domain = producer->domain()->domain(); + + // Grab all iteration domains in producer that its compute at iter domains + // depend on. + VectorOfUniqueEntries all_producer_ca_deps; + { + auto ca_dep_vals = DependencyCheck::getAllValsBetween( + {producer_root.begin(), producer_root.end()}, + {producer_domain.begin(), + producer_domain.begin() + producer->getComputeAtPosition()}); + + auto ca_deps_filter = ir_utils::filterByType(ca_dep_vals); + + all_producer_ca_deps.insert( + ca_deps_filter.begin(), ca_deps_filter.end()); + } + + ordered_p_ca_ids.pushBack(all_producer_ca_deps); + + for (auto consumer : + ir_utils::filterByType(expr->outputs())) { + auto resolved_bcast_map = resolvedRootBroadcasts(producer, consumer); + for (auto entry : resolved_bcast_map) { + accumulateInMap( + p2c_root_broadcast_resolution_map, entry.first, entry.second); + for (auto other_exact_bcast : *idGraph(IdMappingMode::EXACT) + .disjointIdSet(entry.first) + .first) { + if (all_producer_ca_deps.has(other_exact_bcast)) { + accumulateInMap( + p2c_root_broadcast_resolution_map, + other_exact_bcast, + entry.second); + } + } + } + + auto p2c_ca_permissive_map = idGraph(IdMappingMode::PERMISSIVE) + .buildMapBetween( + all_producer_ca_deps.vector(), + ir_utils::allIDsOf(consumer)); + + for (auto entry : p2c_ca_permissive_map) { + if (entry.second.size() == 0) { + continue; + } + accumulateInMapVec(p2c_ca_permissive_maps, entry.first, entry.second); + } + } + } + } + + // Make sure this is called in a deterministic order + for (auto p_id : ordered_p_ca_ids) { + auto entry_it = p2c_ca_permissive_maps.find(p_id); + if (entry_it == p2c_ca_permissive_maps.end()) { + continue; + } + auto c_ids = entry_it->second; + for (auto c_id : c_ids) { + idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); + } + } + + // Terminal loop ids are iteration domains in each loop group that: + // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a + // consumer TV's iter domain maps to this domain in a way that that domain + // is also in the same loop group + // 2) Don't have a direct IterDomain consumer within the group + VectorOfUniqueEntries terminal_loop_ids; + + // Case (1) + VectorOfUniqueEntries p2c_ca_terminal_loop_ids; + // Case(2) + VectorOfUniqueEntries id_consumer_terminal_loop_ids; + + for (auto group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + if (group->size() == 1) { + p2c_ca_terminal_loop_ids.pushBack(group->front()); + id_consumer_terminal_loop_ids.pushBack(group->front()); + } + + // Don't select producer iter domains + for (auto loop_id : *group) { + if (p2c_ca_permissive_maps.find(loop_id) != + p2c_ca_permissive_maps.end()) { + continue; + } + + p2c_ca_terminal_loop_ids.pushBack(loop_id); + + auto uses_it = id_uses_.find(loop_id); + if (uses_it == id_uses_.end()) { + id_consumer_terminal_loop_ids.pushBack(loop_id); + continue; + } + + // If there's an output group that is not in the same group, then it's id + // consumer terminal. Also if there's no output groups it's id consumer + // terminal. + bool all_outs_in_loop_group = uses_it->second.size() == 0 ? false : true; + for (auto use : uses_it->second) { + for (auto out_id : ir_utils::filterByType(use->outputs())) { + auto out_loop_set_pair = + idGraph(IdMappingMode::LOOP).disjointIdSet(out_id); + TORCH_INTERNAL_ASSERT(out_loop_set_pair.second); + if (group != out_loop_set_pair.first) { + all_outs_in_loop_group = false; + } + } + } + + if (!all_outs_in_loop_group) { + id_consumer_terminal_loop_ids.pushBack(loop_id); + } + } + } + + terminal_loop_ids = + p2c_ca_terminal_loop_ids.intersect(id_consumer_terminal_loop_ids); + + std::cout << "Loop graph: " << std::endl; + { + IdGroups groups; + for (auto group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + groups.pushBack(group); + } + std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + } + + // std::cout << "p2c ca terminal: " << p2c_ca_terminal_loop_ids.toString() + // << std::endl; + // std::cout << "id consumer terminal: " + // << id_consumer_terminal_loop_ids.toString() << std::endl; + // std::cout << "Terminal: " << terminal_loop_ids.toString() << std::endl; + + std::cout << "Almost Exact graph: " << std::endl; + { + IdGroups groups; + for (auto group : + idGraph(IdMappingMode::ALMOSTEXACT).disjointIdSets().disjointSets()) { + groups.pushBack(group); + } + std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + } + + auto intersection_exact_loop_graph = initializeIdGraph(); + + // Make an intersection of the exact and loop map. This will group together + // entries in each loop group that are exact with eachother. This provides a + // better graph to do promotion and replays. + + // It's tempting to use the intersection of the almost exact and loop, but we + // need to model broadcast promotion, and if we have two tensors like: + // + // T1[i0, b1] = T0[i0] + // T2[i0, b1] = T0[i0] + // + // Then resolution of: + // T4 = T1[i0, b1] + T3[i0, i1] + // T6 = T2[i0, b1] + T5[i0, i2] + // + // The almost exact map will map T1's and T2's b1 together, but they're being + // resolved to i1 and i2 respectively. So we want to have separate entries so + // we can have an easy to process promotion map. + // + // Loop is a permissive like map, it could have many entries, use the exact + // map as the one we iterate on to reduce complexity as it hopefully has + // smaller groups and this algorithm scales with the number of groups * + // (number of entries in groups ^ 2) + + for (auto exact_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + auto set_size = exact_group->size(); + for (auto id0_i : c10::irange(set_size)) { + auto id0 = exact_group->vector()[id0_i]; + for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { + auto id1 = exact_group->vector()[id1_i]; + // id0 and id1 map in the almost exact map, if they also map in the loop + // graph, then add the mapping to the inersection + if (idGraph(IdMappingMode::LOOP) + .disjointIdSets() + .strictAreMapped(id0, id1)) { + intersection_exact_loop_graph.mapIds(id0, id1); + } + } + } + } + + std::cout << "Intersection exact - loop: " << std::endl; + { + IdGroups groups; + for (auto group : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + groups.pushBack(group); + } + std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + } + + // Promotion logic is going to be on the intersection of the exact and loop + // graph. We will generate a map on the entries of this graph so it's + // important to not modify this graph moving forward, as that would invalidate + // the map. + // + // iel stands for Intersection of the Exact and Loop graphs. + std::unordered_map iel_promotion_map; + + // Find terminating inputs to start traversal from in the iel graph. This + // graph is more strict than exact, so we can simply make sure there's no + // definitions in the group, or the group has an rfactor domain. + IdGroups terminating_inputs; + + for (auto iel_group : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + auto iel_group_defs = + intersection_exact_loop_graph.uniqueDefinitions(iel_group); + if (iel_group_defs.empty()) { + terminating_inputs.pushBack(iel_group); + continue; + } + + if (std::any_of(iel_group->begin(), iel_group->end(), [&](IterDomain* id) { + return viewRfactorIds().find(id) != viewRfactorIds().end(); + })) { + terminating_inputs.pushBack(iel_group); + } + } + + // This should probably work just on terminating inputs, as we shouldn't be + // able to modify a broadcast domain between root and rfactor which would be + // required to resolve a non input broadcast domain. But for now leaving it as + // traversal on all broadcast groups. + for (auto iel_group : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + if (!iel_group->front()->isBroadcast()) { + continue; + } + + // Collect all the exact groups of the resolutions of the broadcast id's + IdGroups resolved_exact_groups; + for (auto bcast_id : *iel_group) { + auto p2c_root_broadcast_resolution_map_it = + p2c_root_broadcast_resolution_map.find(bcast_id); + + if (p2c_root_broadcast_resolution_map_it == + p2c_root_broadcast_resolution_map.end()) { + continue; + } + + resolved_exact_groups.pushBack( + idGraph(IdMappingMode::EXACT) + .toGroups(p2c_root_broadcast_resolution_map_it->second)); + } + + // Collect all the exact groups in the loop set containing this iel_group + auto loop_group_pair = + idGraph(IdMappingMode::LOOP).disjointIdSet(iel_group->front()); + TORCH_INTERNAL_ASSERT(loop_group_pair.second); + auto loop_group = loop_group_pair.first; + auto loop_covered_exact_groups = + idGraph(IdMappingMode::EXACT).toGroups(*loop_group); + + // The intersection of the exact groups that the broadcast domains can be + // broadcasted to, and those that exist within the same loop are is the + // promotion needed for this iel_group. + auto loop_exact_resolved_intersection = + resolved_exact_groups.intersect(loop_covered_exact_groups); + + if (loop_exact_resolved_intersection.empty()) { + // No resolution + continue; + } + + if (loop_exact_resolved_intersection.size() > 1) { + std::stringstream err_msg; + + err_msg + << "Invalid multiple broadcast resolution within shared loops detected, group:\n " + << iel_group->toString() << "\nIs being broadcasted to:"; + + for (auto entry : loop_exact_resolved_intersection) { + err_msg << "\n " << entry->toString(); + } + TORCH_INTERNAL_ASSERT(false, err_msg.str()); + } + + // loop_exact_resolved_intersection.size() == 1 + auto exact_resolution_group = loop_exact_resolved_intersection.front(); + + VectorOfUniqueEntries resolved_ids = + exact_resolution_group->intersect(*loop_group); + auto promoted_iel_groups = + intersection_exact_loop_graph.toGroups(resolved_ids); + + if (promoted_iel_groups.size() == 0) { + continue; + } + + if (promoted_iel_groups.size() > 1) { + std::stringstream err_msg; + + err_msg + << "Invalid multiple broadcast resolution within shared loops detected, group:\n " + << iel_group->toString() << "\nIs being broadcasted to:"; + + for (auto entry : promoted_iel_groups) { + err_msg << "\n " << entry->toString(); + } + TORCH_INTERNAL_ASSERT(false, err_msg.str()); + } + + iel_promotion_map[iel_group] = promoted_iel_groups.front()->front(); + } + + std::cout << "Initial promotion map:" << std::endl; + + for (auto iel_group : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + auto entry_it = iel_promotion_map.find(iel_group); + if (entry_it == iel_promotion_map.end()) { + continue; + } + std::cout << entry_it->second->toString() << " <- " + << entry_it->first->toString() << std::endl; + } + + std::cout << "Loop graph: " << std::endl; + { + IdGroups groups; + for (auto group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + groups.pushBack(group); + } + std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + } + + // Initialize traversal of the iel graph and build promotions + IdGroups visited_ids = terminating_inputs; + + ExprGroups visited_exprs; + ExprGroups to_visit_exprs; + + for (auto terminating_input : terminating_inputs) { + to_visit_exprs.pushBack( + intersection_exact_loop_graph.uniqueUses(terminating_input)); + } + + while (to_visit_exprs.size() > 0) { + // Try to detect when nothing has been processed which would put us in an + // infinite loop + bool something_was_processed = false; + ExprGroups still_to_visit_exprs; + while (to_visit_exprs.size() > 0) { + auto currently_visiting_expr_group = to_visit_exprs.popFront(); + if (visited_exprs.has(currently_visiting_expr_group)) { + // Expr group already processed + continue; + } + + // Make sure all input groups have been processed, otherwise can't process + // this expr group + auto input_groups = intersection_exact_loop_graph.inputGroups( + currently_visiting_expr_group); + + bool all_inputs_processed = true; + for (auto inp : input_groups) { + if (!visited_ids.has(inp)) { + all_inputs_processed = false; + } + } + + if (!all_inputs_processed) { + // Not all input groups were processed, queue this expr up to be + // processed later + still_to_visit_exprs.pushBack(currently_visiting_expr_group); + continue; + } + + // This expr group is ready to be processed, mark it as visited as we are + // actively visiting it + visited_exprs.pushBack(currently_visiting_expr_group); + something_was_processed = true; + + // Mark outputs as visited, as we need to successfully visit this expr + auto out_groups = intersection_exact_loop_graph.outputGroups( + currently_visiting_expr_group); + + visited_ids.pushBack(out_groups); + + // Queue up output uses to be visited + for (auto out_group : out_groups) { + to_visit_exprs.pushBack( + intersection_exact_loop_graph.uniqueUses(out_group).subtract( + visited_exprs)); + } + + // Check if any inputs need promotion indicating this expr group needs to + // be replayed with promoted inputs + std::vector promoted_inputs; + bool an_input_was_promoted = false; + + for (auto inp : input_groups) { + auto inp_promo_it = iel_promotion_map.find(inp); + if (inp_promo_it == iel_promotion_map.end()) { + promoted_inputs.push_back(inp->front()); + } else { + promoted_inputs.push_back(inp_promo_it->second); + an_input_was_promoted = true; + } + } + + if (!an_input_was_promoted) { + // No inputs need promotion so just continue + continue; + } + + for (auto inp : input_groups) { + auto inp_promo_it = iel_promotion_map.find(inp); + if (inp_promo_it == iel_promotion_map.end()) { + std::cout << "IEL inp: " << debug_print::idGroupStringShort(inp) + << std::endl; + } else { + std::cout << "Promoted input: " + << debug_print::idGroupStringShort(inp) << " -> " + << inp_promo_it->second->toString() << std::endl; + } + } + + // TODO: Only replay if necessary? + // Expr* replay; + + // Replay expression with promoted inputs + Expr* replay = + addReplayAs(promoted_inputs, currently_visiting_expr_group->front()); + std::cout << "REPLAY:\n " << currently_visiting_expr_group->front() + << " " << replay->toString() << std::endl; + + // static int debug_count = 0; + // debug_count++; + + // if(debug_count == 10){ + // std::cout << "Loop map: " << std::endl; + // for (auto group : + // idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + // std::cout << debug_print::idGroupStringShort(group) << std::endl; + // } + + // TORCH_INTERNAL_ASSERT(false); + // } + + // Mark outputs as having a promoted iter domain + auto replay_out_ids = + ir_utils::filterByType(replay->outputs()).vector(); + + TORCH_INTERNAL_ASSERT(replay_out_ids.size() == out_groups.size()); + + for (auto i : c10::irange(replay_out_ids.size())) { + iel_promotion_map[out_groups.vector()[i]] = replay_out_ids[i]; + } + } + + std::swap(to_visit_exprs, still_to_visit_exprs); + + // Make sure something was processed in this iteration otherwise throw an + // infinite loop error. + if (!something_was_processed && to_visit_exprs.size() > 0) { + std::stringstream err_msg; + err_msg << "Infinite loop entered, visited ids:" << std::endl; + err_msg << debug_print::idGroupsStringShort(visited_ids) << std::endl; + err_msg << "Exprs visited:" << std::endl; + err_msg << debug_print::exprGroupsStringShort( + intersection_exact_loop_graph, visited_exprs) + << std::endl; + err_msg << "Exprs to visit:" << std::endl; + err_msg << debug_print::exprGroupsStringShort( + intersection_exact_loop_graph, to_visit_exprs) + << std::endl; + TORCH_INTERNAL_ASSERT(false, err_msg.str()); + } + } + + // std::cout << "Filled promotion map:" << std::endl; + // for (auto entry : iel_promotion_map) { + // std::cout << entry.second->toString() << " <- " << entry.first->toString() + // << std::endl; + // } + + // Map from an exact iter domain group, to all the exact iter domain groups it + // covers + std::unordered_map exact_covered_ids; + + for (auto id_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + // Initialize inputs + if (idGraph(IdMappingMode::EXACT).uniqueDefinitions(id_group).empty()) { + exact_covered_ids[id_group] = {id_group}; + } + + // Initialize rfactor groups + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); + })) { + exact_covered_ids[id_group] = {id_group}; + } + + // Initialize broadcast groups to empty + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return id->isBroadcast(); + })) { + exact_covered_ids[id_group] = {}; + } + } + + // Traverse expressions in exact map to populate exact_covered_ids entries. + { + ExprGroups all_expr_groups( + idGraph(IdMappingMode::EXACT).disjointExprSets().disjointSets().begin(), + idGraph(IdMappingMode::EXACT).disjointExprSets().disjointSets().end()); + + while (!all_expr_groups.empty()) { + ExprGroups still_to_visit; + + bool something_visited = false; + while (!all_expr_groups.empty()) { + ExprGroup currently_visiting = all_expr_groups.popBack(); + + auto input_groups = + idGraph(IdMappingMode::EXACT).inputGroups(currently_visiting); + + // Make sure expression group is ready to process + bool ready_to_visit = true; + for (auto inp_group : input_groups) { + if (exact_covered_ids.find(inp_group) == exact_covered_ids.end()) { + ready_to_visit = false; + } + } + + // If not ready re-enqueue and continue + if (!ready_to_visit) { + still_to_visit.pushBack(currently_visiting); + continue; + } + + something_visited = true; + // Visit expression + IdGroups covered; + for (auto inp_group : input_groups) { + covered.pushBack(exact_covered_ids.at(inp_group)); + } + + for (auto output_group : + idGraph(IdMappingMode::EXACT).outputGroups(currently_visiting)) { + exact_covered_ids[output_group] = covered; + } + } + + std::swap(still_to_visit, all_expr_groups); + + if (!something_visited) { + std::cout << "Not visited:" << std::endl; + debug_print::exprGroupsStringShort( + idGraph(IdMappingMode::EXACT), all_expr_groups); + } + TORCH_INTERNAL_ASSERT( + something_visited || all_expr_groups.empty(), + "Entered infinite loops, error traversing on exact map."); + } + } + + std::cout << "Covered exact entries:" << std::endl; + for(auto exact_group : idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()){ + auto exact_covered_id_it = exact_covered_ids.find(exact_group); + if(exact_covered_id_it == exact_covered_ids.end()){ + continue; + } + + std::cout << debug_print::idGroupStringShort(exact_group) << " -> " + << debug_print::idGroupsStringShort(exact_covered_id_it->second) << std::endl; + } + + std::unordered_map loop_promotion_map; + + for (auto loop_group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + if (loop_group->size() == 1) { + loop_promotion_map[loop_group] = loop_group->front(); + continue; + } + + // We need to check the exact groups the terminal id's are in, but for + // promotion we want an iter domain within the loop group. Since exact group + // can traverse loop group boundaires, save a vector of the group and + // the iter domain. + std::vector> exact_promoted_terminal_ids; + for (auto loop_id : *loop_group) { + if (terminal_loop_ids.has(loop_id)) { + auto iel_set_pair = + intersection_exact_loop_graph.disjointIdSet(loop_id); + TORCH_INTERNAL_ASSERT(iel_set_pair.second); + auto iel_group = iel_set_pair.first; + auto iel_promo_it = iel_promotion_map.find(iel_group); + if (iel_promo_it == iel_promotion_map.end()) { + auto promo_id_exact_it = + idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); + TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); + exact_promoted_terminal_ids.push_back( + std::make_pair(promo_id_exact_it.first, loop_id)); + } else { + auto promo_id_exact_it = + idGraph(IdMappingMode::EXACT).disjointIdSet(iel_promo_it->second); + TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); + exact_promoted_terminal_ids.push_back( + std::make_pair(promo_id_exact_it.first, iel_promo_it->second)); + } + } + } + + // All exact groups with iter domains in this loop group + IdGroups exact_groups; + for (auto loop_id : *loop_group) { + auto exact_set_pair = + idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); + TORCH_INTERNAL_ASSERT(exact_set_pair.second); + exact_groups.pushBack(exact_set_pair.first); + } + + // All exact groups covered by all iter domains in this loop group + IdGroups loop_group_covered_ids; + for (auto exact_group : exact_groups) { + auto covered_it = exact_covered_ids.find(exact_group); + TORCH_INTERNAL_ASSERT(covered_it != exact_covered_ids.end()); + loop_group_covered_ids.pushBack(covered_it->second); + } + + IterDomain* loop_promotion_id = nullptr; + + for (auto entry : exact_promoted_terminal_ids) { + auto terminal_id_group = entry.first; + auto terminal_id = entry.second; + auto covered_it = exact_covered_ids.find(terminal_id_group); + TORCH_INTERNAL_ASSERT(covered_it != exact_covered_ids.end()); + if (loop_group_covered_ids.subtract(covered_it->second).size() == 0) { + loop_promotion_id = terminal_id; + } + } + + if (loop_promotion_id == nullptr) { + std::stringstream err_msg; + err_msg << "\nCould not find promotion for loop group:\n "; + err_msg << debug_print::idGroupsStringShort(loop_group_covered_ids); + err_msg + << "\nHowever, none of the iter domains that this group promotes to:\n"; + for (auto entry : exact_promoted_terminal_ids) { + auto terminal_id_group = entry.first; + err_msg << " " << debug_print::idGroupStringShort(terminal_id_group); + err_msg << "\ncover these groups\n"; + } + TORCH_INTERNAL_ASSERT(false, err_msg.str()); + } + + loop_promotion_map[loop_group] = loop_promotion_id; + } + + std::cout << "Loop graph: " << std::endl; + for (auto group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + std::cout << debug_print::idGroupStringShort(group) << std::endl; + } + + std::cout << "Loop promotion map: " << std::endl; + for (auto group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + if (loop_promotion_map.find(group) == loop_promotion_map.end()) { + continue; + } + std::cout << debug_print::idGroupStringShort(group) << " -> " + << loop_promotion_map.at(group)->toString() << std::endl; + } + + std::cout << "All exprs in loop map" << std::endl; + + std::cout << "\n\nTraversal test" << std::endl; + + IdGraphStmtSort loop_stmt_sort(idGraph(IdMappingMode::LOOP)); + for (auto loop_expr : loop_stmt_sort.exprs()) { + std::cout << " " + << debug_print::exprGroupStringShort( + idGraph(IdMappingMode::LOOP), loop_expr) + << std::endl; + } + + TORCH_INTERNAL_ASSERT(false); + + std::cout << "IEL Graph PRE: " << std::endl; + { + IdGroups groups; + for (auto group : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + groups.pushBack(group); + } + std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + } + + iel_promotion_map.clear(); + + // Reinitialize the IEL graph, entries have been added since it's been built. + intersection_exact_loop_graph = initializeIdGraph(); + for (auto exact_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + auto set_size = exact_group->size(); + for (auto id0_i : c10::irange(set_size)) { + auto id0 = exact_group->vector()[id0_i]; + for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { + auto id1 = exact_group->vector()[id1_i]; + // id0 and id1 map in the almost exact map, if they also map in the loop + // graph, then add the mapping to the inersection + if (idGraph(IdMappingMode::LOOP) + .disjointIdSets() + .strictAreMapped(id0, id1)) { + intersection_exact_loop_graph.mapIds(id0, id1); + } + } + } + } + + std::cout << "IEL Graph POST: " << std::endl; + for (auto entry : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + std::cout << debug_print::idGroupStringShort(entry) << std::endl; + } + + TORCH_INTERNAL_ASSERT(false); + + for (auto iel_group : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + auto loop_group_pair = + idGraph(IdMappingMode::LOOP).disjointIdSet(iel_group->front()); + TORCH_INTERNAL_ASSERT(loop_group_pair.second); + auto loop_group = loop_group_pair.first; + + auto promo_entry_it = loop_promotion_map.find(loop_group); + + if (promo_entry_it == loop_promotion_map.end()) { + continue; + } + + auto promo_id = promo_entry_it->second; + + if (idGraph(IdMappingMode::ALMOSTEXACT) + .disjointIdSets() + .strictAreMapped(promo_id, iel_group->front())) { + continue; + } + + // Only promote terminal consumers in the loop groups, otherwise we could + // re-promote transformations. iel promotion map is going to be used to + // replay transformations depending on promoted inlined iter domains. We + // don't want to replay transformations within loop groups, really just + // across them. + if (id_consumer_terminal_loop_ids.has(iel_group->front())) { + iel_promotion_map[iel_group] = promo_id; + } + } + + std::cout << "IEL promotion map init2:" << std::endl; + for (auto entry : iel_promotion_map) { + std::cout << entry.second->toString() << " <- " << entry.first->toString() + << std::endl; + } + + TORCH_INTERNAL_ASSERT(false); + // Finish the promotion map, so far the iel promotion map is only replayed for + // iter domains within the inlined iter domains. However, branches off of the + // inlined iter domains could require replay. + { + ExprGroups all_expr_groups( + intersection_exact_loop_graph.disjointExprSets().disjointSets().begin(), + intersection_exact_loop_graph.disjointExprSets().disjointSets().end()); + + IdGroups visited; + // Initialize inputs + for (auto id_group : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + // Initialize inputs + if (intersection_exact_loop_graph.uniqueDefinitions(id_group).empty()) { + visited.pushBack(id_group); + } + + // Initialize rfactor groups + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); + })) { + visited.pushBack(id_group); + } + + // Initialize broadcast groups to empty + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return id->isBroadcast(); + })) { + visited.pushBack(id_group); + } + } + + while (!all_expr_groups.empty()) { + ExprGroups still_to_visit; + + bool something_visited = false; + while (!all_expr_groups.empty()) { + ExprGroup currently_visiting = all_expr_groups.popBack(); + + auto input_groups = + intersection_exact_loop_graph.inputGroups(currently_visiting); + + // Make sure expression group is ready to process + if (std::any_of( + input_groups.begin(), + input_groups.end(), + [&visited](IdGroup id_group) { + return !visited.has(id_group); + })) { + still_to_visit.pushBack(currently_visiting); + continue; + } + + something_visited = true; + + auto output_groups = + intersection_exact_loop_graph.outputGroups(currently_visiting); + + for (auto out_group : output_groups) { + visited.pushBack(out_group); + } + + // // If all the output groups are resolved by inlined promotion, they + // // shouldn't be replayed here. + // if (std::all_of( + // output_groups.begin(), + // output_groups.end(), + // [&ordered_p_ca_ids](IdGroup out_group) { + // return std::any_of( + // out_group->begin(), + // out_group->end(), + // [&ordered_p_ca_ids](IterDomain* out_group_id) { + // return ordered_p_ca_ids.has(out_group_id); + // }); + // })) { + // continue; + // } + + std::vector promoted_inputs; + + bool input_is_promoted = false; + for (auto inp_group : input_groups) { + auto inp_promo_it = iel_promotion_map.find(inp_group); + if (inp_promo_it == iel_promotion_map.end()) { + promoted_inputs.push_back(inp_group->front()); + } else { + promoted_inputs.push_back(inp_promo_it->second); + input_is_promoted = true; + } + } + + if (!input_is_promoted) { + continue; + } + + if (std::none_of( + output_groups.begin(), + output_groups.end(), + [&iel_promotion_map](IdGroup out_group) { + return iel_promotion_map.find(out_group) == + iel_promotion_map.end(); + })) { + continue; + } + + Expr* replay = + addReplayAs(promoted_inputs, currently_visiting->front()); + + std::cout << "REPLAY2:\n " << currently_visiting->front() << " " + << replay->toString() << std::endl; + + // Mark outputs as having a promoted iter domain + auto replay_out_ids = + ir_utils::filterByType(replay->outputs()).vector(); + + TORCH_INTERNAL_ASSERT(replay_out_ids.size() == output_groups.size()); + + for (auto i : c10::irange(replay_out_ids.size())) { + iel_promotion_map[output_groups.vector()[i]] = replay_out_ids[i]; + } + } + + std::swap(still_to_visit, all_expr_groups); + + TORCH_INTERNAL_ASSERT( + something_visited || all_expr_groups.empty(), + "Entered infinite loops, error traversing on exact map."); + } + } + std::cout << "IEL promotion map:" << std::endl; + for (auto entry : iel_promotion_map) { + std::cout << entry.second->toString() << " <- " << entry.first->toString() + << std::endl; + } + + TORCH_INTERNAL_ASSERT(false); +} + +void IterDomainGraphs::buildIndexMap(const std::vector& all_tvs) { + // Initialize map at loop leaf nodes. This needs to be done just like we + // would in "initializeId" for the exact map. Unlike AlmostExact and + // Permissive, index map is not a superset of exact map. + for (auto loop_group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + for (auto id : *loop_group) { + auto id_disjoint_set = idGraph(IdMappingMode::INDEX) + .disjointIdSets() + .initializeSet(id) + .first->second; + + auto def_it = id_definitions_.find(id); + if (def_it != id_definitions_.end()) { + auto defs = def_it->second; + ExprGroups expr_groups; + for (auto def : defs) { + auto expr_set = idGraph(IdMappingMode::INDEX) + .disjointExprSets() + .initializeSet(def) + .first->second; + expr_groups.pushBack(expr_set); + } + idGraph(IdMappingMode::INDEX).uniqueDefinitions()[id_disjoint_set] = + expr_groups; + } else { + id_definitions_[id] = {}; + idGraph(IdMappingMode::INDEX).uniqueDefinitions()[id_disjoint_set] = {}; + } + + auto use_it = id_uses_.find(id); + if (use_it != id_uses_.end()) { + auto uses = use_it->second; + ExprGroups expr_groups; + for (auto use : uses) { + auto expr_set = idGraph(IdMappingMode::INDEX) + .disjointExprSets() + .initializeSet(use) + .first->second; + expr_groups.pushBack(expr_set); + } + idGraph(IdMappingMode::INDEX).uniqueUses()[id_disjoint_set] = + expr_groups; + } else { + id_uses_[id] = {}; + idGraph(IdMappingMode::INDEX).uniqueUses()[id_disjoint_set] = {}; + } + } + } + + // Below is the same as building the almost exact map. It just maps through + // trivial expressions and removes their traversal from definition/uses + VectorOfUniqueEntries exprs; + for (auto expr : + idGraph(IdMappingMode::INDEX).disjointExprSets().disjointSets()) { + exprs.pushBack(expr->front()); + } + ExprGroups trivial_expr_groups; + + // Map through trivial expressions + for (auto expr : exprs) { + auto mapped_ids = IdGraph::isTrivialExpr(expr); + for (auto mapped_id_group : mapped_ids) { + for (auto id : mapped_id_group) { + trivial_expr_groups.pushBack( + idGraph(IdMappingMode::INDEX).disjointExprSet(expr).first); + idGraph(IdMappingMode::INDEX).mapIds(mapped_id_group.front(), id); + } + } + } + + // Clear out expressions that map inputs and outputs to the same group from + // definitions and uses. They shouldn't be important in traversal. Iterate + // on a copy as we're updating the map as we traverse. + std::unordered_map defs_copy = + idGraph(IdMappingMode::INDEX).uniqueDefinitions(); + for (auto& id_2_expr_group_map_entry : defs_copy) { + ExprGroups expr_groups_new; + for (auto& expr_group : id_2_expr_group_map_entry.second) { + if (!trivial_expr_groups.has(expr_group)) { + expr_groups_new.pushBack(expr_group); + } + } + + if (expr_groups_new.size() == id_2_expr_group_map_entry.second.size()) { + continue; + } + + idGraph(IdMappingMode::INDEX) + .uniqueDefinitions()[id_2_expr_group_map_entry.first] = expr_groups_new; + } + + std::unordered_map uses_copy = + idGraph(IdMappingMode::INDEX).uniqueUses(); + for (auto& id_2_expr_group_map_entry : uses_copy) { + ExprGroups expr_groups_new; + for (auto expr_group : id_2_expr_group_map_entry.second) { + if (!trivial_expr_groups.has(expr_group)) { + expr_groups_new.pushBack(expr_group); + } + } + + if (expr_groups_new.size() == id_2_expr_group_map_entry.second.size()) { + continue; + } + if (!expr_groups_new.empty()) { + for (auto i : c10::irange(100)) { + if (i > 0) { + expr_groups_new.pushBack(expr_groups_new.front()); + } + } + } + + idGraph(IdMappingMode::INDEX) + .uniqueUses()[id_2_expr_group_map_entry.first] = expr_groups_new; + } + + for (auto loop_group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + auto loop_promotion_it = loop_promotion_map_.find(loop_group); + } + IdGroups processed; + + for (auto tv : all_tvs) { + if (tv->isFusionInput()) { + continue; + } + for (auto id : tv->domain()->domain()) { + auto loop_group_pair = idGraph(IdMappingMode::LOOP).disjointIdSet(id); + TORCH_INTERNAL_ASSERT( + loop_group_pair.second, + "Loop group not found for leaf id: ", + id->toString()); + auto loop_group = loop_group_pair.first; + if (processed.has(loop_group)) { + continue; + } + processed.pushBack(loop_group); + + auto loop_promotion_it = loop_promotion_map_.find(loop_group); + TORCH_INTERNAL_ASSERT(loop_promotion_it != loop_promotion_map_.end()); + IterDomain* promoted_id = loop_promotion_it->second; + + for (auto loop_group_id : *loop_group) { + if (loop_group_id == promoted_id) { + continue; + } + if (idGraph(IdMappingMode::ALMOSTEXACT) + .disjointIdSets() + .permissiveAreMapped(loop_group_id, promoted_id)) { + idGraph(IdMappingMode::INDEX).mapIds(loop_group_id, promoted_id); + } + } + } + } +} + +} // namespace nvfuser diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h new file mode 100644 index 00000000000..c0d76456875 --- /dev/null +++ b/csrc/id_graphs.h @@ -0,0 +1,465 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace nvfuser { + +using IdGroup = std::shared_ptr>; +using IdGroups = VectorOfUniqueEntries; +using ExprGroup = std::shared_ptr>; +using ExprGroups = VectorOfUniqueEntries; + +class TORCH_CUDA_CU_API IdGraph { + public: + IdGraph() = default; + + IdGraph(const IdGraph& other); + IdGraph(IdGraph&& other) = default; + + IdGraph& operator=(const IdGraph& other); + IdGraph& operator=(IdGraph&& other) = default; + + // Returns the disjoint IterDomain set. + const DisjointSets& disjointIdSets() const; + + DisjointSets& disjointIdSets(); + + // Returns + // { + // (1) The disjoint set of the provided Iter Domain if it exists, + // otherwise a null shared ptr + // (2) If the disjoint set of the provided Iter Domain exists + // } + std::pair disjointIdSet(IterDomain* id) const; + + // Returns the disjoint Expr set. + const DisjointSets& disjointExprSets() const; + + DisjointSets& disjointExprSets(); + + // Same as getDisjointIdSet but for the Expression sets. + std::pair disjointExprSet(Expr* expr) const; + + // Convert unique vector of expressions to unique vector of its groups + ExprGroups toGroups(const VectorOfUniqueEntries& exprs) const; + + // Convert unique vector of IterDomain to unique vector of its groups + IdGroups toGroups(const VectorOfUniqueEntries& ids) const; + + // Return output iter domain groups of provided expr + IdGroups outputGroups(ExprGroup expr) const; + + // Return input iter domain groups of provided expr + IdGroups inputGroups(ExprGroup expr) const; + + // Traverses uses of the IdGroups in 'of' and returns all ExprGroups + // that have a use in their definition of provided of IdGroups. + ExprGroups allUsesOf(const IdGroups& of) const; + + // Traverses definitions of the IdGroups in 'of' and returns all ExprGroups + // used in this history of defining the 'of' IdGroups. + ExprGroups allDefinitionsOf(const IdGroups& of) const; + + // Return sorted expressions to go from the provided IterDomains in from to + // the provided IterDomains in to with provided mode. Minimal expressions to + // get from 'from' to 'to' returned. + ExprGroups getExprsBetween(const IdGroups& from, const IdGroups& to) const; + + // Supports one to many mappings, uses the disjoint sets of the provided mode + // to produce mappings between from and to. If multiple IterDomains in to map + // to a single iter domain in from, the order of the IterDomains in value of + // the map is preserved to be the order provided in to. + std::unordered_map> + buildMapBetween( + const std::vector& from, + const std::vector& to) const; + + // Alias of the above on unique vector entries + std::unordered_map> + buildMapBetween( + const VectorOfUniqueEntries& from, + const VectorOfUniqueEntries& to) const; + + //! Returns + //! (1) The expressions associated with the definitions of the provided + //! IterDomain group in the provided mapping mode (if it exists). + //! (2) If there is a definitions entry of the provided IterDomain group in + //! the provided mapping mode. + //! First entry in the returned pair is a vector of vector of expressions. The + //! inner vector is proven to be equivalent based on the provided mode. The + //! outer vector are expression groups that are not equivalent based on the + //! provided mode, but produce one of the IterDomains within the same disjoint + //! Iter Domain set based on the provided mode. + //! TODO: Change name to start with get + std::pair iterDomainGroupDefinitions( + IdGroup id_group) const; + + //! Same as iterDomainGroupDefinitions but for uses instead of definitions + //! TODO: Change name to start with get + std::pair iterDomainGroupUses(IdGroup id_group) const; + + std::string toString() const; + + // Checks if the expression is a trivial operation where an input is simply an + // output of the transformation. Returns the mapped iter domains if found. + static std::vector> isTrivialExpr(Expr* expr); + + // Initializes entries for the provided IterDomain in the IterDomainGraphs + void initializeId( + IterDomain* id, + const VectorOfUniqueEntries& definitions, + const VectorOfUniqueEntries& uses); + + // Returns if first and second are expressions through which the provided + // id_map have matching inputs (if forward), or outputs (if not forward). + // Returning true means the expressions are "the same", in terms they modify + // matching original extents, by the same amount. + bool exprsMap( + Expr* first, + Expr* second, + bool forward + // , std::vector second_input_or_output_override + ) const; + + // If entry exists in id_definitions for provided group in provided mode, + // returns that entry, otherwise goes through all iter domains in the group + // and accumulates their id_definitions_ entries + ExprGroups uniqueDefinitions(IdGroup group) const; + + // If entry exists in id_uses for provided group in provided mode, + // returns that entry, otherwise goes through all iter domains in the group + // and accumulates their id_uses_ entries + ExprGroups uniqueUses(IdGroup group) const; + + std::unordered_map& uniqueUses() { + return unique_uses_; + } + + std::unordered_map& uniqueDefinitions() { + return unique_definitions_; + } + + // Set id0 and id1 to mapped in disjointIdsSet[mode], attempt to propagate + // new mapping through id0/id1 definitions/uses. + void mapIds(IterDomain* id0, IterDomain* id1); + + // Map expr0 and expr1 with eachother, update unique_definitions_ unique_uses_ + void mapExprs(Expr* expr0, Expr* expr1); + + // Checks if expr's are considered "the same" where sameness inputs and + // outputs in the same position across expressions map with provided + // MappingMode. If the expressions are determined the same then + // if forward + // will map outputs + // else + // will map inputs + // in the provided mode. + // Returns if expressions were mapped through. + bool mapThroughExpr(Expr* first, Expr* second, bool forward); + + // Map through loop swizzles, as input/output IterDomains are exact, only the + // order they're traversed differs. + void mapThroughLoopSwizzles(); + + private: + // Keeps a disjoint set entry for all IterDomain for all mapping mode types. + // + // Using an array here might be nice, but it seems hard to use an enum as an + // array key + // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum + DisjointSets disjoint_ids_; + + // Keeps a disjoint set entry for all Expressions for all mapping mode types. + DisjointSets disjoint_exprs_; + + std::unordered_map unique_definitions_; + + std::unordered_map unique_uses_; + + // If multiple transformations occur IterDomains could have multiple uses, + // however only one should be active in the given Fusion. When we resolve loop + // promotions during lowering, we can generate new iter domains from existing + // ones, so there can be multiple uses generated. Tracks all the active iter + // domain uses. + std::unordered_map> id_uses_; + + // Make sure we don't blindly use definitions as we don't want to grab + // transformations before a tensor view's root domain. + std::unordered_map> id_definitions_; + + // Hold a set of IterDomains that are considered view rfactor ids. This + // identification is particularly important to understand if split operations + // are divisible or not. + // + // TODO: This should just be in IterDomainGraphs, not here. + std::unordered_set view_rfactor_ids_; +}; + +// Iterates through an IterDomain Graph in topological order, calling handle on +// all Id and all Expr groups in a forward topological order. +// +// Warning: Expr groups that have an input and output in the same IdGroup are +// ignored. +// +// Warning: This is not a great iterator if there's a desire to minimize paths +// traveled to simply visit all IdGroups in order. See ExprsBetween to see how +// we might minimize paths. +class TORCH_CUDA_CU_API IdGraphVisitor { + protected: + // If sub_selection is assumed to be a set of iter domains by which form a + // sub-regrion of the IdGraph provided. Only that sub-region will be visited. + IdGraphVisitor( + IdGraph& id_graph, + const VectorOfUniqueEntries sub_selection = {}) + : id_graph_(id_graph), sub_selection_(sub_selection) {} + + virtual void handle(IdGroup id_group) = 0; + virtual void handle(ExprGroup expr_group) = 0; + + void traverse(); + + IdGraph& graph() { + return id_graph_; + }; + + IdGraphVisitor() = delete; + + IdGraphVisitor(const IdGraphVisitor& other) = default; + IdGraphVisitor& operator=(const IdGraphVisitor& other) = default; + + IdGraphVisitor(IdGraphVisitor&& other) = default; + IdGraphVisitor& operator=(IdGraphVisitor&& other) = default; + + virtual ~IdGraphVisitor() = default; + + private: + IdGraph& id_graph_; + const VectorOfUniqueEntries sub_selection_; +}; + +// Statement sorting based on IdGraphVisitor, see warnings to IdGraph Visitor. +class IdGraphStmtSort : public IdGraphVisitor { + public: + IdGraphStmtSort( + IdGraph& id_graph, + const VectorOfUniqueEntries sub_selection = {}) + : IdGraphVisitor(id_graph, sub_selection) { + IdGraphVisitor::traverse(); + } + + ExprGroups exprs() { + return sorted_exprs; + } + + IdGroups ids() { + return sorted_ids; + } + + ~IdGraphStmtSort() override = default; + + protected: + using IdGraphVisitor::handle; + void handle(IdGroup id_group) override { + sorted_ids.pushBack(id_group); + } + + void handle(ExprGroup expr_group) override { + sorted_exprs.pushBack(expr_group); + } + + ExprGroups sorted_exprs; + IdGroups sorted_ids; +}; + +// There's three modes of these iter domain mappings all uniquely important in +// the lowering process. +// +// For EXACT/PERMISSIVE mode consider: +// +// consumer[i0, b1] = producer[i0] +// consumer->merge(0) (consumer will now be [i0 * b1]) +// When producer is replayed as consumer (the direction we use for mapping) +// with BestEffortReplay forward_bcast_mismatch = True the producer to +// consumer map will have both a mapping of consumer(i0) to producer(i0) as +// well as consumer(i0*b1) to producer(i0). This latter mapping is important +// for loop nest mappings as the consumer will generate a loop based on i0*b1 +// and the producer may be computeAt inside this loop nest. However, for +// indexing we do not want these two maps as producer may be indexed as i0*i1 +// depending on the loop nest structure and how it was built. Therefore we +// really need to carry (at least) two sets of maps around for lowering. +// +// LOOP mode is important if we have something like: +// consumer[i0o, threadIdx.x{i0i}] = producer[i0o, threadIdx.y{i0i}](computeAt +// = 1) which can easily happen when using shared memory. We want to make sure +// that the iteration domain used for loop construction (concreteId) has the +// proper parallelization strategy. In parallel mode we do typical iteration +// domain mapping, however we remove from it any iteration domains outside the +// computeAt of producer when mapping. This guarentees we won't map +// IterDomains that could have different parallelization strategies. We also +// propagate the parallel strategy in parallel mode so all mapped IDs that +// must have the same parallel type, do. +// +// IdMappingMode::LOOP +// Only maps leaf axes to left of compute at +// Forward broadcast axes in replay +// IdMappingMode::PERMISSIVE +// Forward broadcast axes in replay +// Map all iteration domains +// Always contain root mappings (otherwise they could have been forwarded in +// broadcast) +// IdMappingMode::EXACT +// Don't map any broadcast axes to non-broadcast axes +// Do not forward through any broadcast IDs +// IdMappingMode::AlmostExact +// Forward through broadcast axes, but not through to a non-broadcast axis +// i.e. id{b1*i0}, id{i0} are mapped +// id{i1*i0}, id{i0} are not mapped (this part is the difference from +// PERMISSIVE) +// Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped +// +class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { + public: + IterDomainGraphs( + const std::vector& exprs, + const std::vector& additional_tvs, + bool allow_self_mapping = false); + + IterDomainGraphs( + const std::vector& exprs, + bool allow_self_mapping = false); + + // Same as the above constructor with fusion->exprs() excpet fusion may have + // some dangling inputs/outputs that are expected to have IterDomain entries + // even though there's no possible connections from them. + IterDomainGraphs(Fusion* fusion, bool allow_self_mapping = false); + + // Returns iter domain graph of provided mode. + const IdGraph& idGraph(IdMappingMode mode) const; + IdGraph& idGraph(IdMappingMode mode); + + // IterDomains from the original fusion are only allowed to be used once in + // the IterDomain graph, id->uses() are not directly used as there's no bounds + // check that would prevent a use from being defined that's not part of the + // actual fusion definition. + // + // Note, any iter domains used during something like loop or concrete id + // resolution could actually have multiple Expr* uses, and uses on disjoint id + // sets should be used, not this. + // + // TODO: Refactor or remove? + Expr* idUse(IterDomain* id) const; + Expr* idDef(IterDomain* id) const; + + // TODO: Seems a bit unfortunate that this isn't IterDomain local information. + const std::unordered_set& viewRfactorIds() const { + return view_rfactor_ids_; + } + + // Returns if a self mapping was detected that would invalidate assumptions of + // the overall lowering system. + // + // TODO: Can we make this more of an alias analysis? + // Ref: https://github.com/csarofeen/pytorch/pull/1954#discussion_r961940498 + bool hasSelfMapping() const { + return self_mapping_info_.has_value(); + } + + // Update the LOOP ID disjoint sets with resolved computeWith + void updateComputeWith(TensorView* compute_with_tv); + + std::string toString() const; + + // Replay Expr but with the inputs provided. IterDomainGraphss will be updated + // for all maps that have entries, adding the output iter domains of the + // replayed expression and adding potential mappings through the expression. + Expr* addReplayAs(const std::vector& new_inputs, Expr* expr); + + protected: + // Sometimes fusion inputs or outputs are disconnected from expressions, in + // those cases we still may want to send in some additional tensor views from + // the Fusion that don't have expressions associated with them. + void build( + const std::vector& exprs, + const std::vector& additional_tvs); + + // ======= START Iteration domain build process in order called ======= + + // Fills id_uses_ and id_definitions_ for all IterDomains active in the + // fusion. + void buildIterDomainDefinitionsAndUses( + const std::vector& all_tvs); + + // Iterates over all IterDomains in id_definitions_ and calls initializeID on + // a new IdGraph and returns it. + IdGraph initializeIdGraph(); + + // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs + // and first output of expr + void buildExactMap(const std::vector& exprs); + + // Fills disjoint_ids_[IdMappingMode::ALMOSTEXACT]. Initialize AlmostExact as + // Exact entries, then map anything that's either merged with a size-1 or + // split by a size-1 dimension. + void buildAlmostExactMap(); + + // Fills disjoint_ids_[IdMappingMode::PERMISSIVE]. Initialize PermissiveMap as + // AlmostExact entries, then map through broadcasts + void buildPermissiveMap(const std::vector& exprs); + + //! Run through disjoint sets in the LOOP map, make sure there's only one + //! non-serial parallel type in each disjoint set, set the parallel type of + //! all IterDomains in the disjoint set to that PType. + void validateAndPropagatePType() const; + + void buildLoopPromotionMap(const std::vector& exprs); + + // Returns the terminal rfactor or input iter domains each group in the almost + // exact map covers (in the almost exact map). This effectively returns all + // the input almost exact iter domain groups for each almost exact iter domain + // group. RFactor axes are considered an "input" as all broadcast dimensions + // have to be resolved by or before the rfactor iter domain. + std::unordered_map buildCoveredAlmostExact(); + + void buildIndexMap(const std::vector& all_tvs); + + // ======= END Iteration domain build process in order called ======= + + // Errors if self mapping occurs + void assertNoSelfMapping(); + + // Keeps a disjoint set entry for all IterDomain for all mapping mode types. + // + // Using an array here might be nice, but it seems hard to use an enum as an + // array key + // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum + std::unordered_map id_graphs_; + + // If multiple transformations occur IterDomains could have multiple uses, + // however only one should be active in the given Fusion. When we resolve loop + // promotions during lowering, we can generate new iter domains from existing + // ones, so there can be multiple uses generated. Tracks all the active iter + // domain uses. + std::unordered_map> id_uses_; + + // Make sure we don't blindly use definitions as we don't want to grab + // transformations before a tensor view's root domain. + std::unordered_map> id_definitions_; + + // Debug information to hold if a self mapping in a TensorView is found. + c10::optional> + self_mapping_info_ = c10::nullopt; + + std::unordered_map loop_promotion_map_; + + std::unordered_set view_rfactor_ids_; +}; + +using DoubleBufferIndices = std::unordered_map; + +} // namespace nvfuser From eee9bf7371f0dedb08ee42cf1f3315020e6ed8de Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 23 Mar 2023 09:48:02 -0400 Subject: [PATCH 003/178] Forward replay now works, need to implement backward index replay. --- csrc/id_graphs.cpp | 986 ++++++++++++++++++++----------------- test/test_gpu_indexing.cpp | 33 ++ 2 files changed, 562 insertions(+), 457 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 67ce0d6de81..b2b31ab5123 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -1902,6 +1902,7 @@ std::unordered_map resolvedRootBroadcasts( if (c_id->isBroadcast() || c_id->isReduction()) { continue; } + resolved_bcast_map[p_id] = c_id; } return resolved_bcast_map; @@ -2103,7 +2104,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { {producer_root.begin(), producer_root.end()}, {producer_domain.begin(), producer_domain.begin() + producer->getComputeAtPosition()}); - auto ca_deps_filter = ir_utils::filterByType(ca_dep_vals); all_producer_ca_deps.insert( @@ -2157,6 +2157,24 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } + std::cout << "p2c_root_broadcast_resolution_map" << std::endl; + for (auto p_id : ordered_p_ca_ids) { + if (p2c_root_broadcast_resolution_map.find(p_id) != + p2c_root_broadcast_resolution_map.end()) { + std::cout << p_id->toString() << " -> " + << p2c_root_broadcast_resolution_map.at(p_id).toString(); + } + } + + std::cout << "p2c_ca_permissive_maps" << std::endl; + for (auto p_id : ordered_p_ca_ids) { + if (p2c_ca_permissive_maps.find(p_id) != p2c_ca_permissive_maps.end()) { + std::cout << p_id->toString() << " -> " + << p2c_ca_permissive_maps.at(p_id).toString() << std::endl; + ; + } + } + // Terminal loop ids are iteration domains in each loop group that: // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a // consumer TV's iter domain maps to this domain in a way that that domain @@ -2215,15 +2233,15 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { terminal_loop_ids = p2c_ca_terminal_loop_ids.intersect(id_consumer_terminal_loop_ids); - std::cout << "Loop graph: " << std::endl; - { - IdGroups groups; - for (auto group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - groups.pushBack(group); - } - std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - } + // std::cout << "Loop graph: " << std::endl; + // { + // IdGroups groups; + // for (auto group : + // idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + // groups.pushBack(group); + // } + // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + // } // std::cout << "p2c ca terminal: " << p2c_ca_terminal_loop_ids.toString() // << std::endl; @@ -2231,15 +2249,16 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // << id_consumer_terminal_loop_ids.toString() << std::endl; // std::cout << "Terminal: " << terminal_loop_ids.toString() << std::endl; - std::cout << "Almost Exact graph: " << std::endl; - { - IdGroups groups; - for (auto group : - idGraph(IdMappingMode::ALMOSTEXACT).disjointIdSets().disjointSets()) { - groups.pushBack(group); - } - std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - } + // std::cout << "Almost Exact graph: " << std::endl; + // { + // IdGroups groups; + // for (auto group : + // idGraph(IdMappingMode::ALMOSTEXACT).disjointIdSets().disjointSets()) + // { + // groups.pushBack(group); + // } + // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + // } auto intersection_exact_loop_graph = initializeIdGraph(); @@ -2284,15 +2303,15 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - std::cout << "Intersection exact - loop: " << std::endl; - { - IdGroups groups; - for (auto group : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - groups.pushBack(group); - } - std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - } + // std::cout << "Intersection exact - loop: " << std::endl; + // { + // IdGroups groups; + // for (auto group : + // intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + // groups.pushBack(group); + // } + // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + // } // Promotion logic is going to be on the intersection of the exact and loop // graph. We will generate a map on the entries of this graph so it's @@ -2302,27 +2321,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // iel stands for Intersection of the Exact and Loop graphs. std::unordered_map iel_promotion_map; - // Find terminating inputs to start traversal from in the iel graph. This - // graph is more strict than exact, so we can simply make sure there's no - // definitions in the group, or the group has an rfactor domain. - IdGroups terminating_inputs; - - for (auto iel_group : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - auto iel_group_defs = - intersection_exact_loop_graph.uniqueDefinitions(iel_group); - if (iel_group_defs.empty()) { - terminating_inputs.pushBack(iel_group); - continue; - } - - if (std::any_of(iel_group->begin(), iel_group->end(), [&](IterDomain* id) { - return viewRfactorIds().find(id) != viewRfactorIds().end(); - })) { - terminating_inputs.pushBack(iel_group); - } - } - // This should probably work just on terminating inputs, as we shouldn't be // able to modify a broadcast domain between root and rfactor which would be // required to resolve a non input broadcast domain. But for now leaving it as @@ -2421,167 +2419,120 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { << entry_it->first->toString() << std::endl; } - std::cout << "Loop graph: " << std::endl; - { - IdGroups groups; - for (auto group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - groups.pushBack(group); - } - std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - } - - // Initialize traversal of the iel graph and build promotions - IdGroups visited_ids = terminating_inputs; - - ExprGroups visited_exprs; - ExprGroups to_visit_exprs; - - for (auto terminating_input : terminating_inputs) { - to_visit_exprs.pushBack( - intersection_exact_loop_graph.uniqueUses(terminating_input)); - } - - while (to_visit_exprs.size() > 0) { - // Try to detect when nothing has been processed which would put us in an - // infinite loop - bool something_was_processed = false; - ExprGroups still_to_visit_exprs; - while (to_visit_exprs.size() > 0) { - auto currently_visiting_expr_group = to_visit_exprs.popFront(); - if (visited_exprs.has(currently_visiting_expr_group)) { - // Expr group already processed - continue; - } - - // Make sure all input groups have been processed, otherwise can't process - // this expr group - auto input_groups = intersection_exact_loop_graph.inputGroups( - currently_visiting_expr_group); - - bool all_inputs_processed = true; - for (auto inp : input_groups) { - if (!visited_ids.has(inp)) { - all_inputs_processed = false; - } - } - - if (!all_inputs_processed) { - // Not all input groups were processed, queue this expr up to be - // processed later - still_to_visit_exprs.pushBack(currently_visiting_expr_group); - continue; - } - - // This expr group is ready to be processed, mark it as visited as we are - // actively visiting it - visited_exprs.pushBack(currently_visiting_expr_group); - something_was_processed = true; + // std::cout << "Loop graph: " << std::endl; + // { + // IdGroups groups; + // for (auto group : + // idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + // groups.pushBack(group); + // } + // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; + // } - // Mark outputs as visited, as we need to successfully visit this expr - auto out_groups = intersection_exact_loop_graph.outputGroups( - currently_visiting_expr_group); + IdGraphStmtSort iel_stmt_sort(intersection_exact_loop_graph); - visited_ids.pushBack(out_groups); + for (auto iel_expr : iel_stmt_sort.exprs()) { + auto input_groups = intersection_exact_loop_graph.inputGroups(iel_expr); + // Check if any inputs need promotion indicating this expr group needs to + // be replayed with promoted inputs + std::vector promoted_inputs; + bool an_input_was_promoted = false; - // Queue up output uses to be visited - for (auto out_group : out_groups) { - to_visit_exprs.pushBack( - intersection_exact_loop_graph.uniqueUses(out_group).subtract( - visited_exprs)); + for (auto inp : input_groups) { + auto inp_promo_it = iel_promotion_map.find(inp); + if (inp_promo_it == iel_promotion_map.end()) { + promoted_inputs.push_back(inp->front()); + } else { + promoted_inputs.push_back(inp_promo_it->second); + an_input_was_promoted = true; } + } - // Check if any inputs need promotion indicating this expr group needs to - // be replayed with promoted inputs - std::vector promoted_inputs; - bool an_input_was_promoted = false; + if (!an_input_was_promoted) { + // No inputs need promotion so just continue + continue; + } - for (auto inp : input_groups) { - auto inp_promo_it = iel_promotion_map.find(inp); - if (inp_promo_it == iel_promotion_map.end()) { - promoted_inputs.push_back(inp->front()); - } else { - promoted_inputs.push_back(inp_promo_it->second); - an_input_was_promoted = true; + for (auto inp : input_groups) { + auto inp_promo_it = iel_promotion_map.find(inp); + if (inp_promo_it == iel_promotion_map.end()) { + std::cout << "IEL inp: " << debug_print::idGroupStringShort(inp) + << std::endl; + } else { + std::cout << "Promoted input: " << debug_print::idGroupStringShort(inp) + << " -> " << inp_promo_it->second->toString() << std::endl; + } + } + + Expr* replay = nullptr; + + // Before replaying, check if there's already an expression like this, if so + // use that for promotion. + ExprGroups promoted_input_uses; + for (auto inp_id : promoted_inputs) { + auto inp_exact_group = + idGraph(IdMappingMode::EXACT).toGroups({inp_id}).front(); + promoted_input_uses.pushBack( + idGraph(IdMappingMode::EXACT).uniqueUses(inp_exact_group)); + } + + for (auto exact_use_group : promoted_input_uses) { + std::cout << "Check use: " << exact_use_group->front()->toString(); + if (transformAtributesMatch( + iel_expr->front(), exact_use_group->front())) { + std::cout << "Attributes match" << std::endl; + auto exact_use_inps = ir_utils::filterByType( + exact_use_group->front()->inputs()) + .vector(); + bool inps_match = true; + for (auto inp_i : c10::irange(exact_use_inps.size())) { + inps_match = inps_match && + idGraph(IdMappingMode::EXACT) + .disjointIdSets() + .strictAreMapped( + exact_use_inps[inp_i], promoted_inputs[inp_i]); + if (!idGraph(IdMappingMode::EXACT) + .disjointIdSets() + .strictAreMapped( + exact_use_inps[inp_i], promoted_inputs[inp_i])) { + std::cout << exact_use_inps[inp_i]->toString() << " doesn't match " + << promoted_inputs[inp_i]->toString() << std::endl; + } } - } - - if (!an_input_was_promoted) { - // No inputs need promotion so just continue - continue; - } - - for (auto inp : input_groups) { - auto inp_promo_it = iel_promotion_map.find(inp); - if (inp_promo_it == iel_promotion_map.end()) { - std::cout << "IEL inp: " << debug_print::idGroupStringShort(inp) - << std::endl; - } else { - std::cout << "Promoted input: " - << debug_print::idGroupStringShort(inp) << " -> " - << inp_promo_it->second->toString() << std::endl; + if (inps_match) { + replay = exact_use_group->front(); + break; } } + } - // TODO: Only replay if necessary? - // Expr* replay; - - // Replay expression with promoted inputs - Expr* replay = - addReplayAs(promoted_inputs, currently_visiting_expr_group->front()); - std::cout << "REPLAY:\n " << currently_visiting_expr_group->front() - << " " << replay->toString() << std::endl; - - // static int debug_count = 0; - // debug_count++; - - // if(debug_count == 10){ - // std::cout << "Loop map: " << std::endl; - // for (auto group : - // idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - // std::cout << debug_print::idGroupStringShort(group) << std::endl; - // } + if (replay == nullptr) { + replay = addReplayAs(promoted_inputs, iel_expr->front()); + std::cout << "REPLAY:\n " << iel_expr->front() << " " + << replay->toString() << std::endl; + } - // TORCH_INTERNAL_ASSERT(false); - // } + auto out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); - // Mark outputs as having a promoted iter domain - auto replay_out_ids = - ir_utils::filterByType(replay->outputs()).vector(); + // Mark outputs as having a promoted iter domain + auto replay_out_ids = + ir_utils::filterByType(replay->outputs()).vector(); - TORCH_INTERNAL_ASSERT(replay_out_ids.size() == out_groups.size()); + TORCH_INTERNAL_ASSERT(replay_out_ids.size() == out_groups.size()); - for (auto i : c10::irange(replay_out_ids.size())) { - iel_promotion_map[out_groups.vector()[i]] = replay_out_ids[i]; - } + for (auto i : c10::irange(replay_out_ids.size())) { + iel_promotion_map[out_groups.vector()[i]] = replay_out_ids[i]; + std::cout << "Mapping: " << out_groups.vector()[i]->toString() << " -> " + << replay_out_ids[i]->toString() << std::endl; } + } - std::swap(to_visit_exprs, still_to_visit_exprs); - - // Make sure something was processed in this iteration otherwise throw an - // infinite loop error. - if (!something_was_processed && to_visit_exprs.size() > 0) { - std::stringstream err_msg; - err_msg << "Infinite loop entered, visited ids:" << std::endl; - err_msg << debug_print::idGroupsStringShort(visited_ids) << std::endl; - err_msg << "Exprs visited:" << std::endl; - err_msg << debug_print::exprGroupsStringShort( - intersection_exact_loop_graph, visited_exprs) - << std::endl; - err_msg << "Exprs to visit:" << std::endl; - err_msg << debug_print::exprGroupsStringShort( - intersection_exact_loop_graph, to_visit_exprs) + std::cout << "Filled promotion map:" << std::endl; + for (auto entry : iel_promotion_map) { + std::cout << entry.second->toString() << " <- " << entry.first->toString() << std::endl; - TORCH_INTERNAL_ASSERT(false, err_msg.str()); - } } - // std::cout << "Filled promotion map:" << std::endl; - // for (auto entry : iel_promotion_map) { - // std::cout << entry.second->toString() << " <- " << entry.first->toString() - // << std::endl; - // } - // Map from an exact iter domain group, to all the exact iter domain groups it // covers std::unordered_map exact_covered_ids; @@ -2608,86 +2559,52 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - // Traverse expressions in exact map to populate exact_covered_ids entries. - { - ExprGroups all_expr_groups( - idGraph(IdMappingMode::EXACT).disjointExprSets().disjointSets().begin(), - idGraph(IdMappingMode::EXACT).disjointExprSets().disjointSets().end()); + IdGraphStmtSort exact_stmt_sort(idGraph(IdMappingMode::EXACT)); - while (!all_expr_groups.empty()) { - ExprGroups still_to_visit; + for (auto exact_expr : exact_stmt_sort.exprs()) { + auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); - bool something_visited = false; - while (!all_expr_groups.empty()) { - ExprGroup currently_visiting = all_expr_groups.popBack(); - - auto input_groups = - idGraph(IdMappingMode::EXACT).inputGroups(currently_visiting); - - // Make sure expression group is ready to process - bool ready_to_visit = true; - for (auto inp_group : input_groups) { - if (exact_covered_ids.find(inp_group) == exact_covered_ids.end()) { - ready_to_visit = false; - } - } - - // If not ready re-enqueue and continue - if (!ready_to_visit) { - still_to_visit.pushBack(currently_visiting); - continue; - } - - something_visited = true; - // Visit expression - IdGroups covered; - for (auto inp_group : input_groups) { - covered.pushBack(exact_covered_ids.at(inp_group)); - } - - for (auto output_group : - idGraph(IdMappingMode::EXACT).outputGroups(currently_visiting)) { - exact_covered_ids[output_group] = covered; - } - } - - std::swap(still_to_visit, all_expr_groups); + IdGroups covered; + for (auto inp_group : input_groups) { + covered.pushBack(exact_covered_ids.at(inp_group)); + } - if (!something_visited) { - std::cout << "Not visited:" << std::endl; - debug_print::exprGroupsStringShort( - idGraph(IdMappingMode::EXACT), all_expr_groups); - } - TORCH_INTERNAL_ASSERT( - something_visited || all_expr_groups.empty(), - "Entered infinite loops, error traversing on exact map."); + for (auto output_group : + idGraph(IdMappingMode::EXACT).outputGroups(exact_expr)) { + exact_covered_ids[output_group] = covered; } } std::cout << "Covered exact entries:" << std::endl; - for(auto exact_group : idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()){ + for (auto exact_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { auto exact_covered_id_it = exact_covered_ids.find(exact_group); - if(exact_covered_id_it == exact_covered_ids.end()){ + if (exact_covered_id_it == exact_covered_ids.end()) { continue; } std::cout << debug_print::idGroupStringShort(exact_group) << " -> " - << debug_print::idGroupsStringShort(exact_covered_id_it->second) << std::endl; + << debug_print::idGroupsStringShort(exact_covered_id_it->second) + << std::endl; } - std::unordered_map loop_promotion_map; + // Loop promotion map is to prepare for IterDomain replays. Since these + // replays will modify the loop map, we operate on a copy of the loop map, + // not the original one. - for (auto loop_group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + auto loop_graph_copy = idGraph(IdMappingMode::LOOP); + std::unordered_map loop_graph_copy_promotion_map; + + for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_group->size() == 1) { - loop_promotion_map[loop_group] = loop_group->front(); + loop_graph_copy_promotion_map[loop_group] = loop_group->front(); continue; } // We need to check the exact groups the terminal id's are in, but for - // promotion we want an iter domain within the loop group. Since exact group - // can traverse loop group boundaires, save a vector of the group and - // the iter domain. + // promotion we want an iter domain within the loop group. Since exact + // group can traverse loop group boundaires, save a vector of the group + // and the iter domain. std::vector> exact_promoted_terminal_ids; for (auto loop_id : *loop_group) { if (terminal_loop_ids.has(loop_id)) { @@ -2755,253 +2672,408 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { TORCH_INTERNAL_ASSERT(false, err_msg.str()); } - loop_promotion_map[loop_group] = loop_promotion_id; - } - - std::cout << "Loop graph: " << std::endl; - for (auto group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - std::cout << debug_print::idGroupStringShort(group) << std::endl; + loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; } - std::cout << "Loop promotion map: " << std::endl; - for (auto group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - if (loop_promotion_map.find(group) == loop_promotion_map.end()) { - continue; - } - std::cout << debug_print::idGroupStringShort(group) << " -> " - << loop_promotion_map.at(group)->toString() << std::endl; - } - - std::cout << "All exprs in loop map" << std::endl; + // std::cout << "Loop graph copy: " << std::endl; + // for (auto group : + // loop_graph_copy.disjointIdSets().disjointSets()) { + // std::cout << debug_print::idGroupStringShort(group) << std::endl; + // } - std::cout << "\n\nTraversal test" << std::endl; + // std::cout << "Loop graph copy promotion map: " << std::endl; + // for (auto group : + // loop_graph_copy.disjointIdSets().disjointSets()) { + // if (loop_graph_copy_promotion_map.find(group) == + // loop_graph_copy_promotion_map.end()) { + // continue; + // } + // std::cout << debug_print::idGroupStringShort(group) << " -> " + // << loop_graph_copy_promotion_map.at(group)->toString() << + // std::endl; + // } - IdGraphStmtSort loop_stmt_sort(idGraph(IdMappingMode::LOOP)); - for (auto loop_expr : loop_stmt_sort.exprs()) { - std::cout << " " - << debug_print::exprGroupStringShort( - idGraph(IdMappingMode::LOOP), loop_expr) - << std::endl; - } + // std::cout << "All exprs in loop map" << std::endl; + + // iel_promotion_map.clear(); + + // // Reinitialize the IEL graph, entries have been added since it's been + // built. intersection_exact_loop_graph = initializeIdGraph(); for (auto + // exact_group : + // idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + // auto set_size = exact_group->size(); + // for (auto id0_i : c10::irange(set_size)) { + // auto id0 = exact_group->vector()[id0_i]; + // for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { + // auto id1 = exact_group->vector()[id1_i]; + // // id0 and id1 map in the almost exact map, if they also map in the + // loop + // // graph, then add the mapping to the inersection + // if (idGraph(IdMappingMode::LOOP) + // .disjointIdSets() + // .strictAreMapped(id0, id1)) { + // intersection_exact_loop_graph.mapIds(id0, id1); + // } + // } + // } + // } - TORCH_INTERNAL_ASSERT(false); + // std::cout << "IEL Graph POST: " << std::endl; + // for (auto entry : + // intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + // std::cout << debug_print::idGroupStringShort(entry) << std::endl; + // } - std::cout << "IEL Graph PRE: " << std::endl; - { - IdGroups groups; - for (auto group : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - groups.pushBack(group); - } - std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - } + // // Initialize IterDomain promotions based on loop group, onto the + // intersection + // // exact loop graph + // for(auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()){ + // auto promo_it = loop_graph_copy_promotion_map.find(loop_group); + // if ( promo_it == + // loop_graph_copy_promotion_map.end()) { + // continue; + // } + // auto promo_id = promo_it->second; + // auto iel_groups = intersection_exact_loop_graph.toGroups(*loop_group); + // for(auto iel_group : iel_groups){ + // if (!idGraph(IdMappingMode::ALMOSTEXACT) + // .disjointIdSets() + // .strictAreMapped(promo_id, iel_group->front())) { + // iel_promotion_map[iel_group] = promo_id; + // } + // } + // } + // Reset the promotion map for the second pass iel_promotion_map.clear(); - // Reinitialize the IEL graph, entries have been added since it's been built. - intersection_exact_loop_graph = initializeIdGraph(); - for (auto exact_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - auto set_size = exact_group->size(); - for (auto id0_i : c10::irange(set_size)) { - auto id0 = exact_group->vector()[id0_i]; - for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { - auto id1 = exact_group->vector()[id1_i]; - // id0 and id1 map in the almost exact map, if they also map in the loop - // graph, then add the mapping to the inersection - if (idGraph(IdMappingMode::LOOP) - .disjointIdSets() - .strictAreMapped(id0, id1)) { - intersection_exact_loop_graph.mapIds(id0, id1); - } - } - } - } - - std::cout << "IEL Graph POST: " << std::endl; - for (auto entry : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - std::cout << debug_print::idGroupStringShort(entry) << std::endl; - } + std::cout << "\n\n Forward replay iel graph:" << std::endl; - TORCH_INTERNAL_ASSERT(false); + IdGraphStmtSort iel_stmt_sort2(intersection_exact_loop_graph); + for (auto iel_expr : iel_stmt_sort2.exprs()) { + auto iel_inp_groups = intersection_exact_loop_graph.inputGroups(iel_expr); - for (auto iel_group : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - auto loop_group_pair = - idGraph(IdMappingMode::LOOP).disjointIdSet(iel_group->front()); - TORCH_INTERNAL_ASSERT(loop_group_pair.second); - auto loop_group = loop_group_pair.first; + auto iel_out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); - auto promo_entry_it = loop_promotion_map.find(loop_group); + // When replaying the transformations a second time we want to take loop + // promotion into consideration. However, we don't want to blindly apply + // loop promotion to all iter domains within a loop group as it would + // replay the transformations within that loop group on the promoted id of + // that loop group. + // + // Instead only promote an input if the inputs are of a different loop + // group than the outputs. Then we want to promote the inputs to compute + // the output. - if (promo_entry_it == loop_promotion_map.end()) { - continue; + IdGroups inp_loop_groups; + for (auto iel_inp_group : iel_inp_groups) { + inp_loop_groups.pushBack( + loop_graph_copy.toGroups({iel_inp_group->front()}).front()); } - auto promo_id = promo_entry_it->second; - - if (idGraph(IdMappingMode::ALMOSTEXACT) - .disjointIdSets() - .strictAreMapped(promo_id, iel_group->front())) { - continue; + IdGroups out_loop_groups; + for (auto iel_out_group : iel_out_groups) { + out_loop_groups.pushBack( + loop_graph_copy.toGroups({iel_out_group->front()}).front()); } - // Only promote terminal consumers in the loop groups, otherwise we could - // re-promote transformations. iel promotion map is going to be used to - // replay transformations depending on promoted inlined iter domains. We - // don't want to replay transformations within loop groups, really just - // across them. - if (id_consumer_terminal_loop_ids.has(iel_group->front())) { - iel_promotion_map[iel_group] = promo_id; - } - } + bool loop_promote_inputs = + !inp_loop_groups.subtract(out_loop_groups).empty(); - std::cout << "IEL promotion map init2:" << std::endl; - for (auto entry : iel_promotion_map) { - std::cout << entry.second->toString() << " <- " << entry.first->toString() - << std::endl; - } - - TORCH_INTERNAL_ASSERT(false); - // Finish the promotion map, so far the iel promotion map is only replayed for - // iter domains within the inlined iter domains. However, branches off of the - // inlined iter domains could require replay. - { - ExprGroups all_expr_groups( - intersection_exact_loop_graph.disjointExprSets().disjointSets().begin(), - intersection_exact_loop_graph.disjointExprSets().disjointSets().end()); + std::vector promoted_inputs; - IdGroups visited; - // Initialize inputs - for (auto id_group : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - // Initialize inputs - if (intersection_exact_loop_graph.uniqueDefinitions(id_group).empty()) { - visited.pushBack(id_group); - } + bool input_is_promoted = false; - // Initialize rfactor groups - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); - })) { - visited.pushBack(id_group); - } - - // Initialize broadcast groups to empty - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return id->isBroadcast(); - })) { - visited.pushBack(id_group); + // Promote inputs for replay + for (auto iel_inp_group : iel_inp_groups) { + // Prefer loop promotion + auto loop_copy_group = + loop_graph_copy.toGroups({iel_inp_group->front()}).front(); + auto inp_loop_promo_it = + loop_graph_copy_promotion_map.find(loop_copy_group); + if (loop_promote_inputs && + inp_loop_promo_it != loop_graph_copy_promotion_map.end()) { + promoted_inputs.push_back(inp_loop_promo_it->second); + input_is_promoted = true; + } else { + auto inp_promo_it = iel_promotion_map.find(iel_inp_group); + if (inp_promo_it == iel_promotion_map.end()) { + promoted_inputs.push_back(iel_inp_group->front()); + } else { + promoted_inputs.push_back(inp_promo_it->second); + input_is_promoted = true; + } } } - while (!all_expr_groups.empty()) { - ExprGroups still_to_visit; - - bool something_visited = false; - while (!all_expr_groups.empty()) { - ExprGroup currently_visiting = all_expr_groups.popBack(); - - auto input_groups = - intersection_exact_loop_graph.inputGroups(currently_visiting); - - // Make sure expression group is ready to process - if (std::any_of( - input_groups.begin(), - input_groups.end(), - [&visited](IdGroup id_group) { - return !visited.has(id_group); - })) { - still_to_visit.pushBack(currently_visiting); - continue; - } - - something_visited = true; - - auto output_groups = - intersection_exact_loop_graph.outputGroups(currently_visiting); - - for (auto out_group : output_groups) { - visited.pushBack(out_group); - } + if (!input_is_promoted) { + continue; + } - // // If all the output groups are resolved by inlined promotion, they - // // shouldn't be replayed here. - // if (std::all_of( - // output_groups.begin(), - // output_groups.end(), - // [&ordered_p_ca_ids](IdGroup out_group) { - // return std::any_of( - // out_group->begin(), - // out_group->end(), - // [&ordered_p_ca_ids](IterDomain* out_group_id) { - // return ordered_p_ca_ids.has(out_group_id); - // }); - // })) { - // continue; - // } - - std::vector promoted_inputs; - - bool input_is_promoted = false; - for (auto inp_group : input_groups) { - auto inp_promo_it = iel_promotion_map.find(inp_group); - if (inp_promo_it == iel_promotion_map.end()) { - promoted_inputs.push_back(inp_group->front()); - } else { - promoted_inputs.push_back(inp_promo_it->second); - input_is_promoted = true; + Expr* replay = nullptr; + + // Before replaying, check if there's already an expression like this, if so + // use that for promotion. + ExprGroups promoted_input_uses; + for (auto inp_id : promoted_inputs) { + auto inp_exact_group = + idGraph(IdMappingMode::EXACT).toGroups({inp_id}).front(); + promoted_input_uses.pushBack( + idGraph(IdMappingMode::EXACT).uniqueUses(inp_exact_group)); + } + + for (auto exact_use_group : promoted_input_uses) { + std::cout << "Check use: " << exact_use_group->front()->toString(); + if (transformAtributesMatch( + iel_expr->front(), exact_use_group->front())) { + std::cout << "Attributes match" << std::endl; + auto exact_use_inps = ir_utils::filterByType( + exact_use_group->front()->inputs()) + .vector(); + bool inps_match = true; + for (auto inp_i : c10::irange(exact_use_inps.size())) { + inps_match = inps_match && + idGraph(IdMappingMode::EXACT) + .disjointIdSets() + .strictAreMapped( + exact_use_inps[inp_i], promoted_inputs[inp_i]); + if (!idGraph(IdMappingMode::EXACT) + .disjointIdSets() + .strictAreMapped( + exact_use_inps[inp_i], promoted_inputs[inp_i])) { + std::cout << exact_use_inps[inp_i]->toString() << " doesn't match " + << promoted_inputs[inp_i]->toString() << std::endl; } } - - if (!input_is_promoted) { - continue; - } - - if (std::none_of( - output_groups.begin(), - output_groups.end(), - [&iel_promotion_map](IdGroup out_group) { - return iel_promotion_map.find(out_group) == - iel_promotion_map.end(); - })) { - continue; + if (inps_match) { + replay = exact_use_group->front(); + break; } + } + } - Expr* replay = - addReplayAs(promoted_inputs, currently_visiting->front()); + if (replay == nullptr) { + replay = addReplayAs(promoted_inputs, iel_expr->front()); + std::cout << "REPLAY2:\n " << iel_expr->front() << " " + << replay->toString() << std::endl; + } - std::cout << "REPLAY2:\n " << currently_visiting->front() << " " - << replay->toString() << std::endl; + auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); - // Mark outputs as having a promoted iter domain - auto replay_out_ids = - ir_utils::filterByType(replay->outputs()).vector(); + // Mark outputs as having a promoted iter domain + auto replay_out_ids = + ir_utils::filterByType(replay->outputs()).vector(); - TORCH_INTERNAL_ASSERT(replay_out_ids.size() == output_groups.size()); + TORCH_INTERNAL_ASSERT(replay_out_ids.size() == output_groups.size()); - for (auto i : c10::irange(replay_out_ids.size())) { - iel_promotion_map[output_groups.vector()[i]] = replay_out_ids[i]; - } + for (auto i : c10::irange(replay_out_ids.size())) { + if (!idGraph(IdMappingMode::EXACT) + .disjointIdSets() + .strictAreMapped( + replay_out_ids[i], output_groups.vector()[i]->front())) { + iel_promotion_map[output_groups.vector()[i]] = replay_out_ids[i]; } - - std::swap(still_to_visit, all_expr_groups); - - TORCH_INTERNAL_ASSERT( - something_visited || all_expr_groups.empty(), - "Entered infinite loops, error traversing on exact map."); } + + std::cout << " " + << debug_print::exprGroupStringShort( + intersection_exact_loop_graph, iel_expr) + << std::endl; } - std::cout << "IEL promotion map:" << std::endl; + + std::cout << "Filled promotion map2:" << std::endl; for (auto entry : iel_promotion_map) { std::cout << entry.second->toString() << " <- " << entry.first->toString() << std::endl; } TORCH_INTERNAL_ASSERT(false); + + // for (auto iel_group : + // intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + // auto loop_group_pair = + // idGraph(IdMappingMode::LOOP).disjointIdSet(iel_group->front()); + // TORCH_INTERNAL_ASSERT(loop_group_pair.second); + // auto loop_group = loop_group_pair.first; + + // auto promo_entry_it = loop_promotion_map.find(loop_group); + + // if (promo_entry_it == loop_promotion_map.end()) { + // continue; + // } + + // auto promo_id = promo_entry_it->second; + + // if (idGraph(IdMappingMode::ALMOSTEXACT) + // .disjointIdSets() + // .strictAreMapped(promo_id, iel_group->front())) { + // continue; + // } + + // // Only promote terminal consumers in the loop groups, otherwise we + // could + // // re-promote transformations. iel promotion map is going to be used to + // // replay transformations depending on promoted inlined iter domains. + // We + // // don't want to replay transformations within loop groups, really just + // // across them. + // if (id_consumer_terminal_loop_ids.has(iel_group->front())) { + // iel_promotion_map[iel_group] = promo_id; + // } + // } + + // std::cout << "IEL promotion map init2:" << std::endl; + // for (auto entry : iel_promotion_map) { + // std::cout << entry.second->toString() << " <- " << + // entry.first->toString() + // << std::endl; + // } + + // TORCH_INTERNAL_ASSERT(false); + // // Finish the promotion map, so far the iel promotion map is only + // replayed for + // // iter domains within the inlined iter domains. However, branches off of + // the + // // inlined iter domains could require replay. + // { + // ExprGroups all_expr_groups( + // intersection_exact_loop_graph.disjointExprSets().disjointSets().begin(), + // intersection_exact_loop_graph.disjointExprSets().disjointSets().end()); + + // IdGroups visited; + // // Initialize inputs + // for (auto id_group : + // intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + // // Initialize inputs + // if + // (intersection_exact_loop_graph.uniqueDefinitions(id_group).empty()) { + // visited.pushBack(id_group); + // } + + // // Initialize rfactor groups + // if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* + // id) { + // return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); + // })) { + // visited.pushBack(id_group); + // } + + // // Initialize broadcast groups to empty + // if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* + // id) { + // return id->isBroadcast(); + // })) { + // visited.pushBack(id_group); + // } + // } + + // while (!all_expr_groups.empty()) { + // ExprGroups still_to_visit; + + // bool something_visited = false; + // while (!all_expr_groups.empty()) { + // ExprGroup currently_visiting = all_expr_groups.popBack(); + + // auto input_groups = + // intersection_exact_loop_graph.inputGroups(currently_visiting); + + // // Make sure expression group is ready to process + // if (std::any_of( + // input_groups.begin(), + // input_groups.end(), + // [&visited](IdGroup id_group) { + // return !visited.has(id_group); + // })) { + // still_to_visit.pushBack(currently_visiting); + // continue; + // } + + // something_visited = true; + + // auto output_groups = + // intersection_exact_loop_graph.outputGroups(currently_visiting); + + // for (auto out_group : output_groups) { + // visited.pushBack(out_group); + // } + + // // // If all the output groups are resolved by inlined promotion, + // they + // // // shouldn't be replayed here. + // // if (std::all_of( + // // output_groups.begin(), + // // output_groups.end(), + // // [&ordered_p_ca_ids](IdGroup out_group) { + // // return std::any_of( + // // out_group->begin(), + // // out_group->end(), + // // [&ordered_p_ca_ids](IterDomain* out_group_id) { + // // return ordered_p_ca_ids.has(out_group_id); + // // }); + // // })) { + // // continue; + // // } + + // std::vector promoted_inputs; + + // bool input_is_promoted = false; + // for (auto inp_group : input_groups) { + // auto inp_promo_it = iel_promotion_map.find(inp_group); + // if (inp_promo_it == iel_promotion_map.end()) { + // promoted_inputs.push_back(inp_group->front()); + // } else { + // promoted_inputs.push_back(inp_promo_it->second); + // input_is_promoted = true; + // } + // } + + // if (!input_is_promoted) { + // continue; + // } + + // if (std::none_of( + // output_groups.begin(), + // output_groups.end(), + // [&iel_promotion_map](IdGroup out_group) { + // return iel_promotion_map.find(out_group) == + // iel_promotion_map.end(); + // })) { + // continue; + // } + + // Expr* replay = + // addReplayAs(promoted_inputs, currently_visiting->front()); + + // std::cout << "REPLAY2:\n " << currently_visiting->front() << " " + // << replay->toString() << std::endl; + + // // Mark outputs as having a promoted iter domain + // auto replay_out_ids = + // ir_utils::filterByType(replay->outputs()).vector(); + + // TORCH_INTERNAL_ASSERT(replay_out_ids.size() == + // output_groups.size()); + + // for (auto i : c10::irange(replay_out_ids.size())) { + // iel_promotion_map[output_groups.vector()[i]] = replay_out_ids[i]; + // } + // } + + // std::swap(still_to_visit, all_expr_groups); + + // TORCH_INTERNAL_ASSERT( + // something_visited || all_expr_groups.empty(), + // "Entered infinite loops, error traversing on exact map."); + // } + // } + // std::cout << "IEL promotion map:" << std::endl; + // for (auto entry : iel_promotion_map) { + // std::cout << entry.second->toString() << " <- " << + // entry.first->toString() + // << std::endl; + // // } + + // // TORCH_INTERNAL_ASSERT(false); } void IterDomainGraphs::buildIndexMap(const std::vector& all_tvs) { diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index b3e7ec5c786..7ad77fcbcf1 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -1057,4 +1057,37 @@ TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } +// TODO: All the above tests are merges followed by splits, we should make some +// more complex examples even though merging then spliting is the most likely +// use case. In multi-gpu it may be the exact opposite where we split out the +// outer most iter domain to the multi-gpu dimension, then schedule. + +TEST_F(NVFuserTest, FusionIndexSplitMerge_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + // [w] + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // [w, x] + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + tv3->split(0, 3); + tv3->split(2, 4); + tv3->merge(1); + tv3->split(1, 5); + + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + inlineAllAt(tv3, 2, false); + + fusion.printKernel(); +} + + } // namespace nvfuser From f745410dd9d4f1b4f7159b170ac14e4fe38347e8 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 25 Mar 2023 14:06:11 -0400 Subject: [PATCH 004/178] Finish computing loop group promotions. --- csrc/id_graphs.cpp | 370 ++++++++++++++++++++++++--------------------- csrc/id_graphs.h | 3 + 2 files changed, 202 insertions(+), 171 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index b2b31ab5123..8a46e2cd9d1 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -390,11 +390,11 @@ void IdGraphVisitor::traverse() { to_visit_exprs.pushBack(uses_pair.first); } } - } else { still_to_visit_ids.pushBack(current_id_group); } } + std::swap(to_visit_ids, still_to_visit_ids); TORCH_INTERNAL_ASSERT( something_was_processed || @@ -2184,7 +2184,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Case (1) VectorOfUniqueEntries p2c_ca_terminal_loop_ids; - // Case(2) + // Case (2) VectorOfUniqueEntries id_consumer_terminal_loop_ids; for (auto group : @@ -2260,8 +2260,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; // } - auto intersection_exact_loop_graph = initializeIdGraph(); - // Make an intersection of the exact and loop map. This will group together // entries in each loop group that are exact with eachother. This provides a // better graph to do promotion and replays. @@ -2285,6 +2283,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // smaller groups and this algorithm scales with the number of groups * // (number of entries in groups ^ 2) + auto intersection_exact_loop_graph = initializeIdGraph(); for (auto exact_group : idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { auto set_size = exact_group->size(); @@ -2661,13 +2660,16 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (loop_promotion_id == nullptr) { std::stringstream err_msg; err_msg << "\nCould not find promotion for loop group:\n "; - err_msg << debug_print::idGroupsStringShort(loop_group_covered_ids); - err_msg - << "\nHowever, none of the iter domains that this group promotes to:\n"; + err_msg << debug_print::idGroupStringShort(loop_group); + err_msg << "\nnone of the terminal iter domains of this group:\n "; for (auto entry : exact_promoted_terminal_ids) { auto terminal_id_group = entry.first; - err_msg << " " << debug_print::idGroupStringShort(terminal_id_group); - err_msg << "\ncover these groups\n"; + err_msg << " " << debug_print::idGroupStringShort(terminal_id_group) + << std::endl; + } + err_msg << "iter domains in this group cover all id groups:\n"; + for (auto covered_group : loop_group_covered_ids) { + err_msg << " " << debug_print::idGroupStringShort(covered_group); } TORCH_INTERNAL_ASSERT(false, err_msg.str()); } @@ -2887,193 +2889,219 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { << std::endl; } - TORCH_INTERNAL_ASSERT(false); + // Need to update the iel_graph again since we've added operations to the + // exact and loop map. + // *************** START: Code copied verbatim from above ******************** + intersection_exact_loop_graph = initializeIdGraph(); + for (auto exact_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + auto set_size = exact_group->size(); + for (auto id0_i : c10::irange(set_size)) { + auto id0 = exact_group->vector()[id0_i]; + for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { + auto id1 = exact_group->vector()[id1_i]; + // id0 and id1 map in the almost exact map, if they also map in the loop + // graph, then add the mapping to the inersection + if (idGraph(IdMappingMode::LOOP) + .disjointIdSets() + .strictAreMapped(id0, id1)) { + intersection_exact_loop_graph.mapIds(id0, id1); + } + } + } + } + // *************** STOP: Code copied verbatim from above ******************** - // for (auto iel_group : - // intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - // auto loop_group_pair = - // idGraph(IdMappingMode::LOOP).disjointIdSet(iel_group->front()); - // TORCH_INTERNAL_ASSERT(loop_group_pair.second); - // auto loop_group = loop_group_pair.first; + // *************** START: Code copied verbatim from above ******************** + exact_covered_ids.clear(); - // auto promo_entry_it = loop_promotion_map.find(loop_group); + for (auto id_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + // Initialize inputs + if (idGraph(IdMappingMode::EXACT).uniqueDefinitions(id_group).empty()) { + exact_covered_ids[id_group] = {id_group}; + } - // if (promo_entry_it == loop_promotion_map.end()) { - // continue; - // } + // Initialize rfactor groups + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); + })) { + exact_covered_ids[id_group] = {id_group}; + } - // auto promo_id = promo_entry_it->second; + // Initialize broadcast groups to empty + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return id->isBroadcast(); + })) { + exact_covered_ids[id_group] = {}; + } + } - // if (idGraph(IdMappingMode::ALMOSTEXACT) - // .disjointIdSets() - // .strictAreMapped(promo_id, iel_group->front())) { - // continue; - // } + IdGraphStmtSort exact_stmt_sort2(idGraph(IdMappingMode::EXACT)); - // // Only promote terminal consumers in the loop groups, otherwise we - // could - // // re-promote transformations. iel promotion map is going to be used to - // // replay transformations depending on promoted inlined iter domains. - // We - // // don't want to replay transformations within loop groups, really just - // // across them. - // if (id_consumer_terminal_loop_ids.has(iel_group->front())) { - // iel_promotion_map[iel_group] = promo_id; - // } - // } + for (auto exact_expr : exact_stmt_sort2.exprs()) { + auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); - // std::cout << "IEL promotion map init2:" << std::endl; - // for (auto entry : iel_promotion_map) { - // std::cout << entry.second->toString() << " <- " << - // entry.first->toString() - // << std::endl; - // } + IdGroups covered; + for (auto inp_group : input_groups) { + covered.pushBack(exact_covered_ids.at(inp_group)); + } - // TORCH_INTERNAL_ASSERT(false); - // // Finish the promotion map, so far the iel promotion map is only - // replayed for - // // iter domains within the inlined iter domains. However, branches off of - // the - // // inlined iter domains could require replay. - // { - // ExprGroups all_expr_groups( - // intersection_exact_loop_graph.disjointExprSets().disjointSets().begin(), - // intersection_exact_loop_graph.disjointExprSets().disjointSets().end()); + for (auto output_group : + idGraph(IdMappingMode::EXACT).outputGroups(exact_expr)) { + exact_covered_ids[output_group] = covered; + } + } - // IdGroups visited; - // // Initialize inputs - // for (auto id_group : - // intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - // // Initialize inputs - // if - // (intersection_exact_loop_graph.uniqueDefinitions(id_group).empty()) { - // visited.pushBack(id_group); - // } + std::cout << "Covered exact entries:" << std::endl; + for (auto exact_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + auto exact_covered_id_it = exact_covered_ids.find(exact_group); + if (exact_covered_id_it == exact_covered_ids.end()) { + continue; + } - // // Initialize rfactor groups - // if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* - // id) { - // return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); - // })) { - // visited.pushBack(id_group); - // } + std::cout << debug_print::idGroupStringShort(exact_group) << " -> " + << debug_print::idGroupsStringShort(exact_covered_id_it->second) + << std::endl; + } - // // Initialize broadcast groups to empty - // if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* - // id) { - // return id->isBroadcast(); - // })) { - // visited.pushBack(id_group); - // } - // } + // Loop promotion map is to prepare for IterDomain replays. Since these + // replays will modify the loop map, we operate on a copy of the loop map, + // not the original one. - // while (!all_expr_groups.empty()) { - // ExprGroups still_to_visit; - - // bool something_visited = false; - // while (!all_expr_groups.empty()) { - // ExprGroup currently_visiting = all_expr_groups.popBack(); - - // auto input_groups = - // intersection_exact_loop_graph.inputGroups(currently_visiting); - - // // Make sure expression group is ready to process - // if (std::any_of( - // input_groups.begin(), - // input_groups.end(), - // [&visited](IdGroup id_group) { - // return !visited.has(id_group); - // })) { - // still_to_visit.pushBack(currently_visiting); - // continue; - // } + loop_graph_copy = idGraph(IdMappingMode::LOOP); + loop_graph_copy_promotion_map.clear(); - // something_visited = true; + for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { + if (loop_group->size() == 1) { + loop_graph_copy_promotion_map[loop_group] = loop_group->front(); + continue; + } - // auto output_groups = - // intersection_exact_loop_graph.outputGroups(currently_visiting); + // We need to check the exact groups the terminal id's are in, but for + // promotion we want an iter domain within the loop group. Since exact + // group can traverse loop group boundaires, save a vector of the group + // and the iter domain. + std::vector> exact_promoted_terminal_ids; + for (auto loop_id : *loop_group) { + // *************** START DIFF ******************** + // This is different as there's iter domains not based on the original + // producer-consumer relationships, so finding terminal id's can be a bit + // different here. - // for (auto out_group : output_groups) { - // visited.pushBack(out_group); - // } + // If there's an entry in the p2c_ca_permissive map, this loop_id is not a + // promotion candidate. + if (p2c_ca_permissive_maps.find(loop_id) != + p2c_ca_permissive_maps.end()) { + continue; + } - // // // If all the output groups are resolved by inlined promotion, - // they - // // // shouldn't be replayed here. - // // if (std::all_of( - // // output_groups.begin(), - // // output_groups.end(), - // // [&ordered_p_ca_ids](IdGroup out_group) { - // // return std::any_of( - // // out_group->begin(), - // // out_group->end(), - // // [&ordered_p_ca_ids](IterDomain* out_group_id) { - // // return ordered_p_ca_ids.has(out_group_id); - // // }); - // // })) { - // // continue; - // // } - - // std::vector promoted_inputs; - - // bool input_is_promoted = false; - // for (auto inp_group : input_groups) { - // auto inp_promo_it = iel_promotion_map.find(inp_group); - // if (inp_promo_it == iel_promotion_map.end()) { - // promoted_inputs.push_back(inp_group->front()); - // } else { - // promoted_inputs.push_back(inp_promo_it->second); - // input_is_promoted = true; - // } - // } + // Grab all the output groups of uses in the iel graph. + TORCH_INTERNAL_ASSERT( + intersection_exact_loop_graph.disjointIdSet(loop_id).second); + auto iel_group = + intersection_exact_loop_graph.disjointIdSet(loop_id).first; + auto iel_uses = intersection_exact_loop_graph.uniqueUses(iel_group); - // if (!input_is_promoted) { - // continue; - // } + IdGroups iel_output_groups; + for (auto iel_use : iel_uses) { + iel_output_groups.pushBack( + intersection_exact_loop_graph.outputGroups(iel_use)); + } - // if (std::none_of( - // output_groups.begin(), - // output_groups.end(), - // [&iel_promotion_map](IdGroup out_group) { - // return iel_promotion_map.find(out_group) == - // iel_promotion_map.end(); - // })) { - // continue; - // } + // Convert the iel output groups into loop groups + IdGroups loop_output_groups; + for (auto iel_group : iel_output_groups) { + TORCH_INTERNAL_ASSERT( + intersection_exact_loop_graph.disjointIdSet(iel_group->front()) + .second); + loop_output_groups.pushBack( + intersection_exact_loop_graph.disjointIdSet(iel_group->front()) + .first); + } - // Expr* replay = - // addReplayAs(promoted_inputs, currently_visiting->front()); + // If all outputs of the uses of this id in the iel graph are within the + // same loop group, then it's not a promotion candidate. + if (loop_output_groups.size() == 1 && + loop_output_groups.front() == loop_group) { + continue; + } - // std::cout << "REPLAY2:\n " << currently_visiting->front() << " " - // << replay->toString() << std::endl; + // This id is a promotion candidate + auto promo_id_exact_it = + idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); + TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); + exact_promoted_terminal_ids.push_back( + std::make_pair(promo_id_exact_it.first, loop_id)); + } + // *************** STOP DIFF ******************** - // // Mark outputs as having a promoted iter domain - // auto replay_out_ids = - // ir_utils::filterByType(replay->outputs()).vector(); + // All exact groups with iter domains in this loop group + IdGroups exact_groups; + for (auto loop_id : *loop_group) { + auto exact_set_pair = + idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); + TORCH_INTERNAL_ASSERT(exact_set_pair.second); + exact_groups.pushBack(exact_set_pair.first); + } - // TORCH_INTERNAL_ASSERT(replay_out_ids.size() == - // output_groups.size()); + // All exact groups covered by all iter domains in this loop group + IdGroups loop_group_covered_ids; + for (auto exact_group : exact_groups) { + auto covered_it = exact_covered_ids.find(exact_group); + TORCH_INTERNAL_ASSERT(covered_it != exact_covered_ids.end()); + loop_group_covered_ids.pushBack(covered_it->second); + } - // for (auto i : c10::irange(replay_out_ids.size())) { - // iel_promotion_map[output_groups.vector()[i]] = replay_out_ids[i]; - // } - // } + IterDomain* loop_promotion_id = nullptr; - // std::swap(still_to_visit, all_expr_groups); + for (auto entry : exact_promoted_terminal_ids) { + auto terminal_id_group = entry.first; + auto terminal_id = entry.second; + auto covered_it = exact_covered_ids.find(terminal_id_group); + TORCH_INTERNAL_ASSERT(covered_it != exact_covered_ids.end()); + if (loop_group_covered_ids.subtract(covered_it->second).size() == 0) { + loop_promotion_id = terminal_id; + } + } - // TORCH_INTERNAL_ASSERT( - // something_visited || all_expr_groups.empty(), - // "Entered infinite loops, error traversing on exact map."); - // } - // } - // std::cout << "IEL promotion map:" << std::endl; - // for (auto entry : iel_promotion_map) { - // std::cout << entry.second->toString() << " <- " << - // entry.first->toString() - // << std::endl; - // // } - - // // TORCH_INTERNAL_ASSERT(false); + if (loop_promotion_id == nullptr) { + std::stringstream err_msg; + err_msg << "\nCould not find promotion for loop group:\n "; + err_msg << debug_print::idGroupStringShort(loop_group); + err_msg << "\nnone of the terminal iter domains of this group:\n "; + for (auto entry : exact_promoted_terminal_ids) { + auto terminal_id_group = entry.first; + err_msg << " " << debug_print::idGroupStringShort(terminal_id_group) + << std::endl; + } + err_msg << "iter domains in this group cover all id groups:\n"; + for (auto covered_group : loop_group_covered_ids) { + err_msg << " " << debug_print::idGroupStringShort(covered_group); + } + TORCH_INTERNAL_ASSERT(false, err_msg.str()); + } + + loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; + } + + // *************** STOP: Code copied verbatim from above ******************** + + std::cout << "Loop graph copy promotion map: " << std::endl; + for (auto group : loop_graph_copy.disjointIdSets().disjointSets()) { + if (loop_graph_copy_promotion_map.find(group) == + loop_graph_copy_promotion_map.end()) { + continue; + } + std::cout << debug_print::idGroupStringShort(group) << " -> " + << loop_graph_copy_promotion_map.at(group)->toString() + << std::endl; + } + + auto index_graph = initializeIdGraph(); + + TORCH_INTERNAL_ASSERT(false); } void IterDomainGraphs::buildIndexMap(const std::vector& all_tvs) { diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index c0d76456875..c5d465ec2af 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -277,6 +277,8 @@ class IdGraphStmtSort : public IdGraphVisitor { IdGroups sorted_ids; }; +// TODO: Comment is stale, update. +// // There's three modes of these iter domain mappings all uniquely important in // the lowering process. // @@ -380,6 +382,7 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // replayed expression and adding potential mappings through the expression. Expr* addReplayAs(const std::vector& new_inputs, Expr* expr); + // TODO: Should this not be private? protected: // Sometimes fusion inputs or outputs are disconnected from expressions, in // those cases we still may want to send in some additional tensor views from From 6edbba38efa8cc7f9e9b3d716fafb9a243a4612d Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 27 Mar 2023 12:30:40 -0400 Subject: [PATCH 005/178] Prepare for backward graph replays. --- csrc/id_graphs.cpp | 197 +++++++++++++++++++++++++++++++++++++++- csrc/id_graphs.h | 24 +++++ csrc/transform_iter.cpp | 85 ++++++++++++++++- csrc/transform_iter.h | 40 +++++++- 4 files changed, 337 insertions(+), 9 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 8a46e2cd9d1..db6188d0e81 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -47,7 +47,7 @@ std::string idGroupStringShort(const IdGroup& id_group) { return ss.str(); } -std::string idGroupsStringShortInline(const IdGroups& id_groups) { +std::string idGroupsStringShortInline(const std::vector& id_groups) { // Track position in id_groups and its min iter domain name in the set std::vector> group_name_info; @@ -76,14 +76,18 @@ std::string idGroupsStringShortInline(const IdGroups& id_groups) { ss << ", "; } auto pos = group_name_info[i].second; - ss << idGroupStringShort(id_groups.vector()[pos]); + ss << idGroupStringShort(id_groups[pos]); } ss << "}"; return ss.str(); } -std::string idGroupsStringShort(const IdGroups& id_groups) { +std::string idGroupsStringShortInline(const IdGroups& id_groups) { + return idGroupsStringShortInline(id_groups.vector()); +} + +std::string idGroupsStringShort(const std::vector& id_groups) { std::stringstream ss; // Track position in id_groups and its min iter domain name in the set @@ -108,13 +112,17 @@ std::string idGroupsStringShort(const IdGroups& id_groups) { for (auto i : c10::irange(group_name_info.size())) { auto pos = group_name_info[i].second; - ss << " " << idGroupStringShort(id_groups.vector()[pos]) << "\n"; + ss << " " << idGroupStringShort(id_groups[pos]) << "\n"; } ss << "}"; return ss.str(); } +std::string idGroupsStringShort(const IdGroups& id_groups) { + return idGroupsStringShort(id_groups.vector()); +} + std::string exprGroupStringShort(ExprGroup expr_group) { std::vector names; for (auto expr : *expr_group) { @@ -1644,6 +1652,143 @@ Expr* IterDomainGraphs::addReplayAs( return replay; } +// Generate a new expr with the IterDomain outputs provided and IterDomain +// inputs that exactly match expr->inputs +Expr* IterDomainGraphs::addReplayAsBackward( + const std::vector& new_outputs, + Expr* expr) { + // Figure out which graphs are already initialized to make sure we add the new + // expression to them. + std::vector initialized_modes; + for (auto mode : kIdMappingModes) { + auto graph_it = id_graphs_.find(mode); + if (graph_it == id_graphs_.end()) { + continue; + } + + auto& graph = graph_it->second; + if (graph.disjointIdSets().disjointSetMap().empty()) { + continue; + } + + initialized_modes.push_back(mode); + } + + auto orig_outputs = ir_utils::filterByType(expr->outputs()); + std::vector orig_output_ids( + orig_outputs.begin(), orig_outputs.end()); + + { + TORCH_INTERNAL_ASSERT( + new_outputs.size() == orig_output_ids.size(), + "Invalid number of outputs: ", + new_outputs.size(), + " does not match number of iter domain outputs for ", + expr->toString()); + + VectorOfUniqueEntries all_outputs{ + orig_output_ids.begin(), orig_output_ids.end()}; + + all_outputs.pushBack(VectorOfUniqueEntries{ + new_outputs.begin(), new_outputs.end()}); + + for (auto mode : initialized_modes) { + for (auto inp : all_outputs) { + TORCH_INTERNAL_ASSERT( + idGraph(mode).disjointIdSet(inp).second, + "All outputs for replay need to be initialized in all graphs, ", + inp->toString(), + " was not found in mode: ", + mode); + } + } + } + + // Create the new expression with provided outputs + auto replay = BackwardTransformCloner::clone(new_outputs, expr); + + for (auto out_id : ir_utils::filterByType(replay->outputs())) { + id_definitions_[out_id].pushBack(replay); + } + + // Add the expression to the uses of the inputs + for (auto inp_id : ir_utils::filterByType(replay->inputs())) { + id_definitions_[inp_id] = {}; + id_uses_[inp_id] = {replay}; + } + + // Initialize output iter domains in the graphs + for (auto mode : initialized_modes) { + idGraph(mode).disjointExprSets().initializeSet(replay); + auto replay_group = idGraph(mode).disjointExprSet(replay).first; + + // Initialize input ids in map + for (auto inp_id : ir_utils::filterByType(replay->inputs())) { + idGraph(mode).initializeId(inp_id, {}, {replay}); + } + + // Update definitions in the graph of the outputs + for (auto out_id : ir_utils::filterByType(replay->outputs())) { + auto out_group = idGraph(mode).disjointIdSet(out_id).first; + idGraph(mode).uniqueDefinitions().at(out_group).pushBack(replay_group); + } + + // Propagate through all the defintions of the iter domain groups of the + // outputs with the new expression. + auto& graph = idGraph(mode); + // Gather all use expressions from inputs + VectorOfUniqueEntries representative_defs; + for (auto out : new_outputs) { + auto defs_pair = + graph.iterDomainGroupDefinitions(graph.disjointIdSet(out).first); + if (defs_pair.second) { + for (auto def_group : defs_pair.first) { + representative_defs.pushBack(def_group->front()); + } + } + } + + for (auto expr : representative_defs) { + if (graph.exprsMap(expr, replay, false)) { + graph.mapExprs(expr, replay); + graph.mapThroughExpr(expr, replay, false); + } + } + } + + return replay; +} + +// Clone provided iter domain and return the new copy. Map that copy in relevant +// maps. +IterDomain* IterDomainGraphs::cloneIterDomain(IterDomain* id) { + // Figure out which graphs are already initialized to make sure we add the new + // expression to them. + std::vector initialized_modes; + for (auto mode : kIdMappingModes) { + auto graph_it = id_graphs_.find(mode); + if (graph_it == id_graphs_.end()) { + continue; + } + + auto& graph = graph_it->second; + if (graph.disjointIdSets().disjointSetMap().empty()) { + continue; + } + + initialized_modes.push_back(mode); + } + + auto id_copy = id->cloneWithoutRFactor(); + + for (auto mode : initialized_modes) { + idGraph(mode).initializeId(id_copy, {}, {}); + idGraph(mode).mapIds(id, id_copy); + } + + return id_copy; +} + IdGraph IterDomainGraphs::initializeIdGraph() { IdGraph id_graph; @@ -3099,6 +3244,50 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { << std::endl; } + // Mark all iter domains that should share a loop nest, ignoring promotion for + // now + auto original_loop_graph = initializeIdGraph(); + + for (auto expr : exprs) { + std::vector producer_leaves; + VectorOfUniqueEntries all_p_ids; + for (auto producer : ir_utils::filterByType(expr->inputs())) { + all_p_ids.insert( + producer->domain()->domain().begin(), + producer->domain()->domain().begin() + + producer->getComputeAtPosition()); + producer_leaves.insert( + producer_leaves.end(), + producer->domain()->domain().begin(), + producer->domain()->domain().begin() + + producer->getComputeAtPosition()); + } + + std::vector consumer_leaves; + for (auto consumer : ir_utils::filterByType(expr->outputs())) { + consumer_leaves.insert( + consumer_leaves.end(), + consumer->domain()->domain().begin(), + consumer->domain()->domain().begin() + + consumer->getMaxProducerPosition()); + } + + auto p2c_loop_map = idGraph(IdMappingMode::LOOP) + .buildMapBetween(producer_leaves, consumer_leaves); + // Make sure we call mapIds deterministically + for (auto p_id : all_p_ids) { + auto p2c_loop_map_it = p2c_loop_map.find(p_id); + if (p2c_loop_map_it == p2c_loop_map.end()) { + continue; + } + auto c_ids = p2c_loop_map_it->second; + + for (auto c_id : c_ids) { + original_loop_graph.mapIds(p_id, c_id); + } + } + } + auto index_graph = initializeIdGraph(); TORCH_INTERNAL_ASSERT(false); diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index c5d465ec2af..e46cefd7a6b 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -382,6 +382,30 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // replayed expression and adding potential mappings through the expression. Expr* addReplayAs(const std::vector& new_inputs, Expr* expr); + // Similar to addReplayAs, but in the reverse direction. Also addReplayAs can + // generate output ids by using the IterDomain::transform functions. For + // backwards because of merge the input iter domains of the transform are just + // cloned with IterDomain::cloneWithoutRFactor, and the transform Expr is + // generated with IrBuilder copying over all the attributes. + Expr* addReplayAsBackward( + const std::vector& new_outputs, + Expr* expr); + + // Make a new expr matching that provided but using the outputs provided. + // IterDomainGraphss will be updated for all maps that have entries. Adding + // the input iter domains of the replayed expression and adding potential + // mappings through the expressions. Input domains will match exactly in all + // properties as those in expr. This is unlike addReplayAs which will produce + // new outputs using transformations directly. + Expr* addBackwardsReplayAs( + const std::vector& new_outputs, + Expr* expr); + + // Make an exact copy of provided IterDomain (without rfactor set), and map + // the copy to the original in all registered IdGraphs. IterDomain copy will + // not have any registered uses or definitions. + IterDomain* cloneIterDomain(IterDomain* id); + // TODO: Should this not be private? protected: // Sometimes fusion inputs or outputs are disconnected from expressions, in diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 3b03ae31895..5d0246aee58 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -7,6 +7,8 @@ // clang-format on #include +#include + #include #include @@ -68,10 +70,17 @@ void ReplayTransform::handle(const Swizzle2D* swizzle_2d) { .first->definition(); } - void ReplayTransform::handle(const Resize* resize) { - TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); + TORCH_INTERNAL_ASSERT( + input_ids_.size() == 1, + "Expected one input to match resize: ", + resize->toString()); + replayed_expr_ = + IterDomain::resize( + input_ids_[0], resize->leftExpand(), resize->rightExpand()) + ->definition(); } + // Transform dispatch void ReplayTransformations::handle(Expr* e) { auto is_supported_expr = e->isOneOf(); @@ -80,6 +89,78 @@ void ReplayTransformations::handle(Expr* e) { IterVisitor::handle(e); } +Expr* BackwardTransformCloner::clone( + const std::vector& ordered_outputs, + const Expr* expression_to_match) { + BackwardTransformCloner replay(ordered_outputs, expression_to_match); + return replay.new_expr_; +} + +BackwardTransformCloner::BackwardTransformCloner( + const std::vector& ordered_outputs, + const Expr* expression_to_match) + : output_ids_(ordered_outputs) { + OptOutConstDispatch::handle(expression_to_match); +} + +// We're going to replay this split operation on the corresponding ID +void BackwardTransformCloner::handle(const Split* split) { + TORCH_INTERNAL_ASSERT( + output_ids_.size() == 2, + "Expected two outputs to match split: ", + split->toString()); + + new_expr_ = IrBuilder::create( + output_ids_[0], + output_ids_[1], + split->in()->cloneWithoutRFactor(), + split->factor(), + split->innerSplit(), + split->startOffset(), + split->stopOffset()); +} + +// We're going to replay this merge operation on the corresponding IDs +void BackwardTransformCloner::handle(const Merge* merge) { + TORCH_INTERNAL_ASSERT( + output_ids_.size() == 1, + "Expected one output to match merge: ", + merge->toString()); + + new_expr_ = IrBuilder::create( + output_ids_[0], + merge->outer()->cloneWithoutRFactor(), + merge->inner()->cloneWithoutRFactor()); +} + +// We're going to replay this swizzle operation on the corresponding IDs +// if replaying swizzle is enabled. +void BackwardTransformCloner::handle(const Swizzle2D* swizzle_2d) { + TORCH_INTERNAL_ASSERT( + output_ids_.size() == 2, + "Expected two outputs to match swizzle: ", + swizzle_2d->toString()); + new_expr_ = IrBuilder::create( + output_ids_[0], + output_ids_[1], + swizzle_2d->inX()->cloneWithoutRFactor(), + swizzle_2d->inY()->cloneWithoutRFactor(), + swizzle_2d->swizzleType(), + swizzle_2d->swizzleMode()); +} + +void BackwardTransformCloner::handle(const Resize* resize) { + TORCH_INTERNAL_ASSERT( + output_ids_.size() == 1, + "Expected one output to match resize: ", + resize->toString()); + new_expr_ = IrBuilder::create( + output_ids_[0], + resize->in()->cloneWithoutRFactor(), + resize->leftExpand(), + resize->rightExpand()); +} + // We're going to replay this split operation on the corresponding ID void ReplayTransformations::handle(Split* s) { // Grab our input to the split node diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index ade428f9542..0df6d7bd9a5 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -32,7 +32,7 @@ struct id_int_lt { } // namespace -class ReplayTransform : OptOutConstDispatch { +class ReplayTransform : OptInConstDispatch { public: // Replays expression_to_match with the provided ordered_inputs. Inputs should // be ordered as they would be used in provided expression. Returns new @@ -48,7 +48,7 @@ class ReplayTransform : OptOutConstDispatch { const std::vector& ordered_inputs, const Expr* expression_to_match); - using OptOutConstDispatch::handle; + using OptInConstDispatch::handle; // We're going to replay this split operation on the corresponding ID void handle(const Split* split) override; @@ -60,7 +60,6 @@ class ReplayTransform : OptOutConstDispatch { // if replaying swizzle is enabled. void handle(const Swizzle2D* swizzle_2d) override; - // We're going to replay this resize operation on the corresponding IDs // if replaying resize is enabled. void handle(const Resize* resize) override; @@ -69,6 +68,41 @@ class ReplayTransform : OptOutConstDispatch { const std::vector& input_ids_; }; +class BackwardTransformCloner : OptInConstDispatch { + public: + // Generates a copy of expression_to_match with provided output + // IterDomains, cloning the inputs in expression_to_match. + static Expr* clone( + const std::vector& ordered_outputs, + const Expr* expression_to_match); + + private: + BackwardTransformCloner() = delete; + + BackwardTransformCloner( + const std::vector& ordered_outputs, + const Expr* expression_to_match); + + using OptInConstDispatch::handle; + + // We're going to replay this split operation on the corresponding ID + void handle(const Split* split) override; + + // We're going to replay this merge operation on the corresponding IDs + void handle(const Merge* merge) override; + + // We're going to replay this swizzle operation on the corresponding IDs + // if replaying swizzle is enabled. + void handle(const Swizzle2D* swizzle_2d) override; + + // We're going to replay this resize operation on the corresponding IDs + // if replaying resize is enabled. + void handle(const Resize* resize) override; + + Expr* new_expr_ = nullptr; + const std::vector& output_ids_; +}; + // Uses the history of _target_domain, and replays that history using the // provided map. // From 1a2b261b4b5465ccdc0fd5e740ba02ac277031af Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 5 Apr 2023 09:51:26 -0400 Subject: [PATCH 006/178] Build promoted tensor domains. --- csrc/id_graphs.cpp | 246 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 231 insertions(+), 15 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index db6188d0e81..41eed00e7b0 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -1879,6 +1879,7 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { } // TODO: Should this just get rolled up in the forwarding map now? + // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); @@ -1890,6 +1891,7 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { } // TODO: Should this just get rolled up in the forwarding map now? + // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.consumer_compliment_map) { for (auto entry_2 : entry.second) { idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); @@ -1999,11 +2001,13 @@ void IterDomainGraphs::build( // expressions. idGraph(IdMappingMode::EXACT) = initializeIdGraph(); + std::cout << "buildExactMap" << std::endl; buildExactMap(tv_exprs); - + std::cout << "buildAlmostExactMap" << std::endl; buildAlmostExactMap(); - + std::cout << "buildPermissiveMap" << std::endl; buildPermissiveMap(tv_exprs); + std::cout << "built non lowering graphs" << std::endl; // Only build loop map during lowering if (FusionGuard::getCurFusion()->isA()) { @@ -3243,37 +3247,84 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { << loop_graph_copy_promotion_map.at(group)->toString() << std::endl; } + // Indexing traversal must start at leaf nodes of TensorViews as that's where + // the loop indices are defined. For indexing we need to propagate leaves to + // root domains. We want the indexing graph easy to traverse. Easy to traverse + // means that we start at terminating outputs of this graph and propagate to + // terminating inputs. We shouldn't have to worry about which paths each time + // we traverse the index graph as we may do it many times. + + // The IEL Map cannot be traversed for indexing, because the loop map is + // really only used to model broadcast promotion. We could have multiple paths + // from leaf nodes to an intermediate IEL entry. Meaning: + + // T0 root[i0, i1] T0 leaf domain [i0*i1//32, 4, 8] + // T1 root[i0, i1] T0 leaf domain [i0*i1//32, 8, 4] + + // Even though T0 and T1 are inlined on the outer most dimension, indexing + // into their roots is different. Yet, their roots would be in the same IEL + // entries. + + // The index graph should provide a direct model of what indices are reused, + // i.e. if two ID's in the IndexMap map to eachother, they should use the same + // index math. Therefore, roughly what we need to do is: + + // - Figure out which leaves share exact indexing and map them together: + // (1) Promoted producer-consumer leaf nodes are almost exact. + // (2) Producer-consumer leaf nodes are inlined with eachother, and they're + // almost exact. + + // - Start at the promoted leaf nodes of each tensor view + + // - If those promoted leaf nodes are *ALMOST EXACT* mapped from + // producer-consumer they can be mapped in the index map + + // - Traversing backward from each tensor view's leaf nodes, we directly reach + // the root nodes of that tensor view - // Mark all iter domains that should share a loop nest, ignoring promotion for - // now - auto original_loop_graph = initializeIdGraph(); + // - During the backward traversal, for an expression, if the output iter + // domains are mapped in the index map, their inputs should be mapped as well. + // So as we build the index map, we could also be accumulating mapped iter + // domains. + + // Mark all iter domains that share a loop nest and are almost exact mapped. + // Ignores promotion. + auto index_graph = initializeIdGraph(); for (auto expr : exprs) { - std::vector producer_leaves; + // Iter domains in producer that are inlined with consumer iter domains + std::vector producer_inlined_leaves; + + // Copy of all the producer id's for determinism VectorOfUniqueEntries all_p_ids; for (auto producer : ir_utils::filterByType(expr->inputs())) { all_p_ids.insert( producer->domain()->domain().begin(), producer->domain()->domain().begin() + producer->getComputeAtPosition()); - producer_leaves.insert( - producer_leaves.end(), + producer_inlined_leaves.insert( + producer_inlined_leaves.end(), producer->domain()->domain().begin(), producer->domain()->domain().begin() + producer->getComputeAtPosition()); } - std::vector consumer_leaves; + // Grab potentially inlined iter domains in consumers + std::vector consumer_inlined_leaves; for (auto consumer : ir_utils::filterByType(expr->outputs())) { - consumer_leaves.insert( - consumer_leaves.end(), + consumer_inlined_leaves.insert( + consumer_inlined_leaves.end(), consumer->domain()->domain().begin(), consumer->domain()->domain().begin() + consumer->getMaxProducerPosition()); } - auto p2c_loop_map = idGraph(IdMappingMode::LOOP) - .buildMapBetween(producer_leaves, consumer_leaves); + // Almost exact map from producer inlined iter domains to all the consumer + // domains they could be inlined into. Build an almost exact map between + // those. + auto p2c_loop_map = + idGraph(IdMappingMode::ALMOSTEXACT) + .buildMapBetween(producer_inlined_leaves, consumer_inlined_leaves); // Make sure we call mapIds deterministically for (auto p_id : all_p_ids) { auto p2c_loop_map_it = p2c_loop_map.find(p_id); @@ -3283,12 +3334,177 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { auto c_ids = p2c_loop_map_it->second; for (auto c_id : c_ids) { - original_loop_graph.mapIds(p_id, c_id); + index_graph.mapIds(p_id, c_id); } } } - auto index_graph = initializeIdGraph(); + // Doing the same as above on promoted iter domains is a bit tricky, because + // there's a promoted IterDomian per IEL group, we need a promoted IterDomain + // per index group. So let's figure out which leaf domains share a promoted + // iter domain, so we don't have to build a promoted iter domain for every + // leaf, then try to rejoin them. + + // TODO: I think we need to validate that for each tensor view leaf domains, + // no two leaves within a tensor domain map to another leaf in the same tensor + // domain in the IEL graph. + + // Which non-promoted iter domains, share their promoted iterdomains + DisjointSets shared_promoted_id; + + for (auto expr : exprs) { + std::unordered_map> + promo_id_to_producer_ids; + std::unordered_map> + promo_id_to_consumer_ids; + + // Copy of all promo ids for determinism + VectorOfUniqueEntries all_promo_ids; + + for (auto producer : ir_utils::filterByType(expr->inputs())) { + for (auto p_id : producer->domain()->domain()) { + // Initialize all entries + shared_promoted_id.initializeSet(p_id); + + auto loop_copy_p_group_pair = loop_graph_copy.disjointIdSet(p_id); + TORCH_INTERNAL_ASSERT(loop_copy_p_group_pair.second); + auto loop_copy_p_group = loop_copy_p_group_pair.first; + + auto promo_id_it = + loop_graph_copy_promotion_map.find(loop_copy_p_group); + TORCH_INTERNAL_ASSERT( + promo_id_it != loop_graph_copy_promotion_map.end()); + + promo_id_to_producer_ids[promo_id_it->second].pushBack(p_id); + all_promo_ids.pushBack(promo_id_it->second); + } + } + + for (auto consumer : ir_utils::filterByType(expr->outputs())) { + for (auto c_id : consumer->domain()->domain()) { + // Initialize all entries + shared_promoted_id.initializeSet(c_id); + + auto loop_copy_c_group_pair = loop_graph_copy.disjointIdSet(c_id); + TORCH_INTERNAL_ASSERT(loop_copy_c_group_pair.second); + auto loop_copy_c_group = loop_copy_c_group_pair.first; + + auto promo_id_it = + loop_graph_copy_promotion_map.find(loop_copy_c_group); + TORCH_INTERNAL_ASSERT( + promo_id_it != loop_graph_copy_promotion_map.end()); + + promo_id_to_consumer_ids[promo_id_it->second].pushBack(c_id); + all_promo_ids.pushBack(promo_id_it->second); + } + } + + for (auto promo_id : all_promo_ids) { + auto p_ids_it = promo_id_to_producer_ids.find(promo_id); + if (p_ids_it == promo_id_to_producer_ids.end()) { + continue; + } + auto p_ids = p_ids_it->second; + + auto c_ids_it = promo_id_to_consumer_ids.find(promo_id); + if (c_ids_it == promo_id_to_consumer_ids.end()) { + continue; + } + auto c_ids = c_ids_it->second; + + if (c_ids.size() && p_ids.size()) { + for (auto p_id : p_ids) { + shared_promoted_id.mapEntries(p_ids.front(), p_id); + } + for (auto c_id : c_ids) { + shared_promoted_id.mapEntries(p_ids.front(), c_id); + } + } + } + } + std::cout << "Leaf iter domains that share a promoted iter domain." + << std::endl; + for (auto disjoint_set : shared_promoted_id.disjointSets()) { + std::cout << disjoint_set->toString() << std::endl; + } + + // Map from leaf iter domains to their potentially promoted iter domain used + // for indexing. + std::unordered_map leaf_promotion_map; + + // If a promoted iter domain was generated by replays, it won't be connected + // in the index graph. We can reuse these iter domains directly instead of + // having to make a clone of them. However, we can only use them once for a + // group. + VectorOfUniqueEntries used_promo_ids; + + for (auto id_group : shared_promoted_id.disjointSets()) { + auto first_id = id_group->front(); + auto loop_copy_group_pair = loop_graph_copy.disjointIdSet(first_id); + TORCH_INTERNAL_ASSERT(loop_copy_group_pair.second); + auto loop_copy_group = loop_copy_group_pair.first; + + auto promo_id_it = loop_graph_copy_promotion_map.find(loop_copy_group); + TORCH_INTERNAL_ASSERT(promo_id_it != loop_graph_copy_promotion_map.end()); + + IterDomain* promo_id = promo_id_it->second; + + // Promoted id is already part of the group, just use that. + if (std::find(id_group->begin(), id_group->end(), promo_id) != + id_group->end()) { + for (auto id : *id_group) { + leaf_promotion_map[id] = promo_id; + } + continue; + } + + // Promo id generated from running replay, we can use it for one of the + // index groups. + if (!shared_promoted_id.mappingExists(promo_id) && + !used_promo_ids.has(promo_id)) { + used_promo_ids.pushBack(promo_id); + for (auto id : *id_group) { + leaf_promotion_map[id] = promo_id; + } + continue; + } + + // Need to take a copy of the promo_id as it's already dedicated to an index + // group. + promo_id = cloneIterDomain(promo_id); + for (auto id : *id_group) { + leaf_promotion_map[id] = promo_id; + } + } + + std::cout << "Iter domain group to their promoted iter domain." << std::endl; + for (auto id_group : shared_promoted_id.disjointSets()) { + std::cout << id_group->toString() << "\n -> " + << leaf_promotion_map.at(id_group->front()) << std::endl; + } + + // Could pass this into the function, but just using this for now. + auto all_tvs = ir_utils::allTvsOfExprs(exprs); + + auto promoted_domain = [&](TensorDomain* td) { + std::vector promoted_leaves; + for (auto id : td->domain()) { + auto promo_it = leaf_promotion_map.find(id); + TORCH_INTERNAL_ASSERT(promo_it != leaf_promotion_map.end()); + promoted_leaves.push_back(promo_it->second); + } + return promoted_leaves; + }; + + std::cout << "Promoted tensor view domains:" << std::endl; + // Need to replay all of the indexing expressions to make sure roots are + // connected to domains. + for (auto tv : all_tvs) { + // replay from root to promoted leaves. + std::cout << "TV" << tv->name() << " " << promoted_domain(tv->domain()) + << "\n <- " + << "TV" << tv->name() << tv->domain()->toString() << std::endl; + } TORCH_INTERNAL_ASSERT(false); } From 8b3fe64ffeb0c91f85d4685fa06f4ae9db60fff1 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 5 Apr 2023 15:40:10 -0400 Subject: [PATCH 007/178] First attempt at replaying the index operations. --- csrc/id_graphs.cpp | 173 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 163 insertions(+), 10 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 41eed00e7b0..0f4f61db67b 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -830,11 +830,11 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) for (auto inp : terminating_inputs) { auto use_it = uses_path.find(inp); - TORCH_INTERNAL_ASSERT( - use_it != uses_path.end(), - "Invalid calculation of exprs between, no use found of a provided terminating input: ", - inp->toString(), - " expressions cannot be computed."); + if (use_it == uses_path.end()) { + // This can happen for a trivial traversal where inputs and outputs are + // exactly the same. + continue; + } auto uses = use_it->second; for (auto use : uses) { to_visit.pushBack(use); @@ -1781,6 +1781,9 @@ IterDomain* IterDomainGraphs::cloneIterDomain(IterDomain* id) { auto id_copy = id->cloneWithoutRFactor(); + id_uses_[id_copy] = {}; + id_definitions_[id_copy] = {}; + for (auto mode : initialized_modes) { idGraph(mode).initializeId(id_copy, {}, {}); idGraph(mode).mapIds(id, id_copy); @@ -3486,7 +3489,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Could pass this into the function, but just using this for now. auto all_tvs = ir_utils::allTvsOfExprs(exprs); - auto promoted_domain = [&](TensorDomain* td) { + auto get_promoted_domain = [&](TensorDomain* td) { std::vector promoted_leaves; for (auto id : td->domain()) { auto promo_it = leaf_promotion_map.find(id); @@ -3496,14 +3499,164 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { return promoted_leaves; }; + idGraph(IdMappingMode::INDEX) = initializeIdGraph(); + + // Track every expression required for indexing + VectorOfUniqueEntries all_index_exprs; + // Track every iter domain required for indexing + VectorOfUniqueEntries all_index_ids; + std::cout << "Promoted tensor view domains:" << std::endl; - // Need to replay all of the indexing expressions to make sure roots are - // connected to domains. + // Need to "replay" all of the indexing expressions to make sure roots are + // connected to the promoted leaves, in a way we can index directly on the + // index graph. + + auto& ae_graph = idGraph(IdMappingMode::ALMOSTEXACT); for (auto tv : all_tvs) { + auto promoted_domain = get_promoted_domain(tv->domain()); // replay from root to promoted leaves. - std::cout << "TV" << tv->name() << " " << promoted_domain(tv->domain()) - << "\n <- " + std::cout << "\n\nTV" << tv->name() << " " << promoted_domain << "\n <- " << "TV" << tv->name() << tv->domain()->toString() << std::endl; + + VectorOfUniqueEntries root_ids{ + tv->getRootDomain().begin(), tv->getRootDomain().end()}; + + auto ae_root_groups = ae_graph.toGroups(root_ids); + + std::unordered_map ae_group_2_id; + + for (auto root_i : c10::irange(ae_root_groups.size())) { + ae_group_2_id[ae_root_groups.vector()[root_i]] = + root_ids.vector()[root_i]; + } + + VectorOfUniqueEntries leaf_ids{ + promoted_domain.begin(), promoted_domain.end()}; + + auto ae_leaf_groups = ae_graph.toGroups(leaf_ids); + + // Get indexing transformations + auto indexing_transforms = + ae_graph.getExprsBetween(ae_root_groups, ae_leaf_groups); + + // Replay indexing transformations on the root_ids + for (ExprGroup ae_expr : indexing_transforms) { + std::cout << "Almost exact expr: " << ae_expr->front()->toString(); + + // Replay mostly copied for a third time. + auto input_groups = ae_graph.inputGroups(ae_expr); + + // Inputs "promoted" with the ae_group_2_id map. + // + // if there isn't an entry in ae_group_2_id, then we have a resolved + // merged in broadcast, we need to clone that input. Would be nice to see + // if the dangling input has already been added already through another + // indexing path that this overlaps with, however having an additional Id + // and expression per case doesn't seem too bad right now. + std::vector promoted_inputs; + bool an_input_was_promoted = false; + + for (auto inp_group : input_groups) { + auto inp_promo_it = ae_group_2_id.find(inp_group); + if (inp_promo_it == ae_group_2_id.end()) { + // Clone dangling input, this is unique for index graph compared to + // the other replays. + promoted_inputs.push_back(cloneIterDomain(inp_group->front())); + } else { + promoted_inputs.push_back(inp_promo_it->second); + an_input_was_promoted = true; + } + } + + if (!an_input_was_promoted) { + // No inputs need promotion so just continue + continue; + } + + // Debug print the promotion + for (auto inp_i : c10::irange(input_groups.size())) { + auto inp_group = input_groups.vector()[inp_i]; + auto inp_promo_it = ae_group_2_id.find(inp_group); + if (inp_promo_it == ae_group_2_id.end()) { + std::cout << "Cloned input: " << promoted_inputs[inp_i] << std::endl; + } else { + std::cout << "\"Promoted\" input: " << promoted_inputs[inp_i] + << std::endl; + } + } + + Expr* replay = nullptr; + + // Before replaying, check if there's already an expression like this, if + // so use that for promotion. + ExprGroups promoted_input_uses; + for (auto inp_id : promoted_inputs) { + auto index_group = + idGraph(IdMappingMode::INDEX).toGroups({inp_id}).front(); + promoted_input_uses.pushBack( + idGraph(IdMappingMode::INDEX).uniqueUses(index_group)); + } + + for (auto index_use_group : promoted_input_uses) { + std::cout << "Check use: " << index_use_group->front()->toString(); + if (transformAtributesMatch( + ae_expr->front(), index_use_group->front())) { + std::cout << " Attributes match" << std::endl; + auto index_use_inputs = ir_utils::filterByType( + index_use_group->front()->inputs()) + .vector(); + bool inps_match = true; + for (auto inp_i : c10::irange(index_use_inputs.size())) { + inps_match = inps_match && + idGraph(IdMappingMode::INDEX) + .disjointIdSets() + .strictAreMapped( + index_use_inputs[inp_i], promoted_inputs[inp_i]); + if (!idGraph(IdMappingMode::INDEX) + .disjointIdSets() + .strictAreMapped( + index_use_inputs[inp_i], promoted_inputs[inp_i])) { + std::cout << " " << index_use_inputs[inp_i]->toString() + << " doesn't match " + << promoted_inputs[inp_i]->toString() << std::endl; + } + } + if (inps_match) { + std::cout << " Inputs match" << std::endl; + replay = index_use_group->front(); + break; + } else { + std::cout << " Inputs don't match" << std::endl; + } + } else { + std::cout << " Attributes don't match" << std::endl; + } + } + + if (replay == nullptr) { + std::cout << "Replay: " << ae_expr->front(); + std::cout << "With promoted inputs: " << promoted_inputs << std::endl; + replay = addReplayAs(promoted_inputs, ae_expr->front()); + std::cout << "REPLAY3 :\n " << replay->toString() << std::endl; + } + + auto out_groups = + idGraph(IdMappingMode::ALMOSTEXACT).outputGroups(ae_expr); + + // Mark outputs as having a promoted iter domain + auto replay_out_ids = + ir_utils::filterByType(replay->outputs()).vector(); + + TORCH_INTERNAL_ASSERT(replay_out_ids.size() == out_groups.size()); + + for (auto i : c10::irange(replay_out_ids.size())) { + ae_group_2_id[out_groups.vector()[i]] = replay_out_ids[i]; + std::cout << "Mapping: " << out_groups.vector()[i]->toString() << " -> " + << replay_out_ids[i]->toString() << std::endl; + } + + std::cout << std::endl; + } } TORCH_INTERNAL_ASSERT(false); From 4d5c60459e524d34df0a9561a85074da54cfae5d Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 5 Apr 2023 16:06:39 -0400 Subject: [PATCH 008/178] Print indexing expressoins and iter domains. --- csrc/id_graphs.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 0f4f61db67b..b83ddd92605 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -3640,6 +3640,16 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::cout << "REPLAY3 :\n " << replay->toString() << std::endl; } + all_index_exprs.pushBack(replay); + + { + auto in_ids = ir_utils::filterByType(replay->inputs()); + all_index_ids.insert(in_ids.begin(), in_ids.end()); + + auto out_ids = ir_utils::filterByType(replay->outputs()); + all_index_ids.insert(out_ids.begin(), out_ids.end()); + } + auto out_groups = idGraph(IdMappingMode::ALMOSTEXACT).outputGroups(ae_expr); @@ -3659,6 +3669,14 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } + std::cout << "Indexing expressions: " << std::endl; + for (auto expr : all_index_exprs) { + std::cout << expr->toString(); + } + + std::cout << "All indexing iter domains: " << all_index_ids.toString() + << std::endl; + TORCH_INTERNAL_ASSERT(false); } From 0e47e5ca268f6d094ebce7161ca982effa5ed812 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 6 Apr 2023 10:37:44 -0400 Subject: [PATCH 009/178] Clean up some printing. --- csrc/id_graphs.cpp | 284 +++++---------------------------------------- 1 file changed, 31 insertions(+), 253 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index b83ddd92605..be625d8c98d 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -2309,24 +2309,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - std::cout << "p2c_root_broadcast_resolution_map" << std::endl; - for (auto p_id : ordered_p_ca_ids) { - if (p2c_root_broadcast_resolution_map.find(p_id) != - p2c_root_broadcast_resolution_map.end()) { - std::cout << p_id->toString() << " -> " - << p2c_root_broadcast_resolution_map.at(p_id).toString(); - } - } - - std::cout << "p2c_ca_permissive_maps" << std::endl; - for (auto p_id : ordered_p_ca_ids) { - if (p2c_ca_permissive_maps.find(p_id) != p2c_ca_permissive_maps.end()) { - std::cout << p_id->toString() << " -> " - << p2c_ca_permissive_maps.at(p_id).toString() << std::endl; - ; - } - } - // Terminal loop ids are iteration domains in each loop group that: // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a // consumer TV's iter domain maps to this domain in a way that that domain @@ -2385,33 +2367,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { terminal_loop_ids = p2c_ca_terminal_loop_ids.intersect(id_consumer_terminal_loop_ids); - // std::cout << "Loop graph: " << std::endl; - // { - // IdGroups groups; - // for (auto group : - // idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - // groups.pushBack(group); - // } - // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - // } - - // std::cout << "p2c ca terminal: " << p2c_ca_terminal_loop_ids.toString() - // << std::endl; - // std::cout << "id consumer terminal: " - // << id_consumer_terminal_loop_ids.toString() << std::endl; - // std::cout << "Terminal: " << terminal_loop_ids.toString() << std::endl; - - // std::cout << "Almost Exact graph: " << std::endl; - // { - // IdGroups groups; - // for (auto group : - // idGraph(IdMappingMode::ALMOSTEXACT).disjointIdSets().disjointSets()) - // { - // groups.pushBack(group); - // } - // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - // } - // Make an intersection of the exact and loop map. This will group together // entries in each loop group that are exact with eachother. This provides a // better graph to do promotion and replays. @@ -2454,16 +2409,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - // std::cout << "Intersection exact - loop: " << std::endl; - // { - // IdGroups groups; - // for (auto group : - // intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - // groups.pushBack(group); - // } - // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - // } - // Promotion logic is going to be on the intersection of the exact and loop // graph. We will generate a map on the entries of this graph so it's // important to not modify this graph moving forward, as that would invalidate @@ -2566,22 +2511,13 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (entry_it == iel_promotion_map.end()) { continue; } - std::cout << entry_it->second->toString() << " <- " + std::cout << " " << entry_it->second->toString() << " <- " << entry_it->first->toString() << std::endl; } - // std::cout << "Loop graph: " << std::endl; - // { - // IdGroups groups; - // for (auto group : - // idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - // groups.pushBack(group); - // } - // std::cout << debug_print::idGroupsStringShort(groups) << std::endl; - // } - IdGraphStmtSort iel_stmt_sort(intersection_exact_loop_graph); + std::cout<<"Initial promotion replay:"<& exprs) { continue; } - for (auto inp : input_groups) { - auto inp_promo_it = iel_promotion_map.find(inp); - if (inp_promo_it == iel_promotion_map.end()) { - std::cout << "IEL inp: " << debug_print::idGroupStringShort(inp) - << std::endl; - } else { - std::cout << "Promoted input: " << debug_print::idGroupStringShort(inp) - << " -> " << inp_promo_it->second->toString() << std::endl; - } - } - Expr* replay = nullptr; // Before replaying, check if there's already an expression like this, if so @@ -2628,10 +2553,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } for (auto exact_use_group : promoted_input_uses) { - std::cout << "Check use: " << exact_use_group->front()->toString(); if (transformAtributesMatch( iel_expr->front(), exact_use_group->front())) { - std::cout << "Attributes match" << std::endl; auto exact_use_inps = ir_utils::filterByType( exact_use_group->front()->inputs()) .vector(); @@ -2642,13 +2565,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { .disjointIdSets() .strictAreMapped( exact_use_inps[inp_i], promoted_inputs[inp_i]); - if (!idGraph(IdMappingMode::EXACT) - .disjointIdSets() - .strictAreMapped( - exact_use_inps[inp_i], promoted_inputs[inp_i])) { - std::cout << exact_use_inps[inp_i]->toString() << " doesn't match " - << promoted_inputs[inp_i]->toString() << std::endl; - } } if (inps_match) { replay = exact_use_group->front(); @@ -2659,8 +2575,10 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); - std::cout << "REPLAY:\n " << iel_expr->front() << " " - << replay->toString() << std::endl; + std::cout << " ***REPLAY***:\n " << iel_expr->front() << " As:" + << replay->toString(); + } else { + std::cout << " Matched replay found: " << replay->toString(); } auto out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -2673,17 +2591,9 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { for (auto i : c10::irange(replay_out_ids.size())) { iel_promotion_map[out_groups.vector()[i]] = replay_out_ids[i]; - std::cout << "Mapping: " << out_groups.vector()[i]->toString() << " -> " - << replay_out_ids[i]->toString() << std::endl; } } - std::cout << "Filled promotion map:" << std::endl; - for (auto entry : iel_promotion_map) { - std::cout << entry.second->toString() << " <- " << entry.first->toString() - << std::endl; - } - // Map from an exact iter domain group, to all the exact iter domain groups it // covers std::unordered_map exact_covered_ids; @@ -2726,19 +2636,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - std::cout << "Covered exact entries:" << std::endl; - for (auto exact_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - auto exact_covered_id_it = exact_covered_ids.find(exact_group); - if (exact_covered_id_it == exact_covered_ids.end()) { - continue; - } - - std::cout << debug_print::idGroupStringShort(exact_group) << " -> " - << debug_print::idGroupsStringShort(exact_covered_id_it->second) - << std::endl; - } - // Loop promotion map is to prepare for IterDomain replays. Since these // replays will modify the loop map, we operate on a copy of the loop map, // not the original one. @@ -2811,7 +2708,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (loop_promotion_id == nullptr) { std::stringstream err_msg; - err_msg << "\nCould not find promotion for loop group:\n "; + err_msg + << "\n ERROR Loop promotion map build. Could not find promotion for loop group:\n "; err_msg << debug_print::idGroupStringShort(loop_group); err_msg << "\nnone of the terminal iter domains of this group:\n "; for (auto entry : exact_promoted_terminal_ids) { @@ -2829,79 +2727,10 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; } - // std::cout << "Loop graph copy: " << std::endl; - // for (auto group : - // loop_graph_copy.disjointIdSets().disjointSets()) { - // std::cout << debug_print::idGroupStringShort(group) << std::endl; - // } - - // std::cout << "Loop graph copy promotion map: " << std::endl; - // for (auto group : - // loop_graph_copy.disjointIdSets().disjointSets()) { - // if (loop_graph_copy_promotion_map.find(group) == - // loop_graph_copy_promotion_map.end()) { - // continue; - // } - // std::cout << debug_print::idGroupStringShort(group) << " -> " - // << loop_graph_copy_promotion_map.at(group)->toString() << - // std::endl; - // } - - // std::cout << "All exprs in loop map" << std::endl; - - // iel_promotion_map.clear(); - - // // Reinitialize the IEL graph, entries have been added since it's been - // built. intersection_exact_loop_graph = initializeIdGraph(); for (auto - // exact_group : - // idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - // auto set_size = exact_group->size(); - // for (auto id0_i : c10::irange(set_size)) { - // auto id0 = exact_group->vector()[id0_i]; - // for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { - // auto id1 = exact_group->vector()[id1_i]; - // // id0 and id1 map in the almost exact map, if they also map in the - // loop - // // graph, then add the mapping to the inersection - // if (idGraph(IdMappingMode::LOOP) - // .disjointIdSets() - // .strictAreMapped(id0, id1)) { - // intersection_exact_loop_graph.mapIds(id0, id1); - // } - // } - // } - // } - - // std::cout << "IEL Graph POST: " << std::endl; - // for (auto entry : - // intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - // std::cout << debug_print::idGroupStringShort(entry) << std::endl; - // } - - // // Initialize IterDomain promotions based on loop group, onto the - // intersection - // // exact loop graph - // for(auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()){ - // auto promo_it = loop_graph_copy_promotion_map.find(loop_group); - // if ( promo_it == - // loop_graph_copy_promotion_map.end()) { - // continue; - // } - // auto promo_id = promo_it->second; - // auto iel_groups = intersection_exact_loop_graph.toGroups(*loop_group); - // for(auto iel_group : iel_groups){ - // if (!idGraph(IdMappingMode::ALMOSTEXACT) - // .disjointIdSets() - // .strictAreMapped(promo_id, iel_group->front())) { - // iel_promotion_map[iel_group] = promo_id; - // } - // } - // } - // Reset the promotion map for the second pass iel_promotion_map.clear(); - std::cout << "\n\n Forward replay iel graph:" << std::endl; + std::cout << "\n\nForward replay iel graph:" << std::endl; IdGraphStmtSort iel_stmt_sort2(intersection_exact_loop_graph); for (auto iel_expr : iel_stmt_sort2.exprs()) { @@ -2936,7 +2765,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::vector promoted_inputs; - bool input_is_promoted = false; + bool an_input_was_promoted = false; // Promote inputs for replay for (auto iel_inp_group : iel_inp_groups) { @@ -2948,19 +2777,19 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (loop_promote_inputs && inp_loop_promo_it != loop_graph_copy_promotion_map.end()) { promoted_inputs.push_back(inp_loop_promo_it->second); - input_is_promoted = true; + an_input_was_promoted = true; } else { auto inp_promo_it = iel_promotion_map.find(iel_inp_group); if (inp_promo_it == iel_promotion_map.end()) { promoted_inputs.push_back(iel_inp_group->front()); } else { promoted_inputs.push_back(inp_promo_it->second); - input_is_promoted = true; + an_input_was_promoted = true; } } } - if (!input_is_promoted) { + if (!an_input_was_promoted) { continue; } @@ -2977,10 +2806,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } for (auto exact_use_group : promoted_input_uses) { - std::cout << "Check use: " << exact_use_group->front()->toString(); if (transformAtributesMatch( iel_expr->front(), exact_use_group->front())) { - std::cout << "Attributes match" << std::endl; auto exact_use_inps = ir_utils::filterByType( exact_use_group->front()->inputs()) .vector(); @@ -3008,8 +2835,10 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); - std::cout << "REPLAY2:\n " << iel_expr->front() << " " - << replay->toString() << std::endl; + std::cout << " ***REPLAY2***:\n " << iel_expr->front() << " As:" + << replay->toString(); + } else { + std::cout << " Matched replay found: " << replay->toString(); } auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -3028,17 +2857,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { iel_promotion_map[output_groups.vector()[i]] = replay_out_ids[i]; } } - - std::cout << " " - << debug_print::exprGroupStringShort( - intersection_exact_loop_graph, iel_expr) - << std::endl; - } - - std::cout << "Filled promotion map2:" << std::endl; - for (auto entry : iel_promotion_map) { - std::cout << entry.second->toString() << " <- " << entry.first->toString() - << std::endl; } // Need to update the iel_graph again since we've added operations to the @@ -3105,19 +2923,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - std::cout << "Covered exact entries:" << std::endl; - for (auto exact_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - auto exact_covered_id_it = exact_covered_ids.find(exact_group); - if (exact_covered_id_it == exact_covered_ids.end()) { - continue; - } - - std::cout << debug_print::idGroupStringShort(exact_group) << " -> " - << debug_print::idGroupsStringShort(exact_covered_id_it->second) - << std::endl; - } - // Loop promotion map is to prepare for IterDomain replays. Since these // replays will modify the loop map, we operate on a copy of the loop map, // not the original one. @@ -3125,6 +2930,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { loop_graph_copy = idGraph(IdMappingMode::LOOP); loop_graph_copy_promotion_map.clear(); + std::cout << "Find promoted ids within loop groups." << std::endl; + for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_group->size() == 1) { loop_graph_copy_promotion_map[loop_group] = loop_group->front(); @@ -3506,7 +3313,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Track every iter domain required for indexing VectorOfUniqueEntries all_index_ids; - std::cout << "Promoted tensor view domains:" << std::endl; + std::cout << "Building promoted tensor view domains:" << std::endl; // Need to "replay" all of the indexing expressions to make sure roots are // connected to the promoted leaves, in a way we can index directly on the // index graph. @@ -3515,8 +3322,9 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { for (auto tv : all_tvs) { auto promoted_domain = get_promoted_domain(tv->domain()); // replay from root to promoted leaves. - std::cout << "\n\nTV" << tv->name() << " " << promoted_domain << "\n <- " - << "TV" << tv->name() << tv->domain()->toString() << std::endl; + std::cout << "\n\n Processing: TV" << tv->name() + << "\n Promoted: " << promoted_domain + << "\n Original: " << tv->domain()->toString() << std::endl; VectorOfUniqueEntries root_ids{ tv->getRootDomain().begin(), tv->getRootDomain().end()}; @@ -3539,10 +3347,9 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { auto indexing_transforms = ae_graph.getExprsBetween(ae_root_groups, ae_leaf_groups); + std::cout<<" Replaying path to domain:"<front()->toString(); - // Replay mostly copied for a third time. auto input_groups = ae_graph.inputGroups(ae_expr); @@ -3573,18 +3380,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { continue; } - // Debug print the promotion - for (auto inp_i : c10::irange(input_groups.size())) { - auto inp_group = input_groups.vector()[inp_i]; - auto inp_promo_it = ae_group_2_id.find(inp_group); - if (inp_promo_it == ae_group_2_id.end()) { - std::cout << "Cloned input: " << promoted_inputs[inp_i] << std::endl; - } else { - std::cout << "\"Promoted\" input: " << promoted_inputs[inp_i] - << std::endl; - } - } - Expr* replay = nullptr; // Before replaying, check if there's already an expression like this, if @@ -3598,10 +3393,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } for (auto index_use_group : promoted_input_uses) { - std::cout << "Check use: " << index_use_group->front()->toString(); if (transformAtributesMatch( ae_expr->front(), index_use_group->front())) { - std::cout << " Attributes match" << std::endl; auto index_use_inputs = ir_utils::filterByType( index_use_group->front()->inputs()) .vector(); @@ -3612,32 +3405,21 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { .disjointIdSets() .strictAreMapped( index_use_inputs[inp_i], promoted_inputs[inp_i]); - if (!idGraph(IdMappingMode::INDEX) - .disjointIdSets() - .strictAreMapped( - index_use_inputs[inp_i], promoted_inputs[inp_i])) { - std::cout << " " << index_use_inputs[inp_i]->toString() - << " doesn't match " - << promoted_inputs[inp_i]->toString() << std::endl; - } } if (inps_match) { - std::cout << " Inputs match" << std::endl; replay = index_use_group->front(); break; - } else { - std::cout << " Inputs don't match" << std::endl; } - } else { - std::cout << " Attributes don't match" << std::endl; } } if (replay == nullptr) { - std::cout << "Replay: " << ae_expr->front(); - std::cout << "With promoted inputs: " << promoted_inputs << std::endl; + std::cout << " Replay: " << ae_expr->front(); + std::cout << " With promoted inputs: " << promoted_inputs + << std::endl; replay = addReplayAs(promoted_inputs, ae_expr->front()); - std::cout << "REPLAY3 :\n " << replay->toString() << std::endl; + std::cout << " ***REPLAY3***:\n " << ae_expr->front() + << " As:" << replay->toString(); } all_index_exprs.pushBack(replay); @@ -3661,20 +3443,16 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { for (auto i : c10::irange(replay_out_ids.size())) { ae_group_2_id[out_groups.vector()[i]] = replay_out_ids[i]; - std::cout << "Mapping: " << out_groups.vector()[i]->toString() << " -> " - << replay_out_ids[i]->toString() << std::endl; } - - std::cout << std::endl; } } - std::cout << "Indexing expressions: " << std::endl; + std::cout << "All indexing expressions that need to be processed: " << std::endl; for (auto expr : all_index_exprs) { std::cout << expr->toString(); } - std::cout << "All indexing iter domains: " << all_index_ids.toString() + std::cout << "All iter domains that would be indexed: " << all_index_ids.toString() << std::endl; TORCH_INTERNAL_ASSERT(false); From 91fb637a76d0e92498d5854c794663cb1309d90c Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 6 Apr 2023 19:37:07 -0400 Subject: [PATCH 010/178] Fix for iel promotion replay. --- csrc/id_graphs.cpp | 70 ++++++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index be625d8c98d..e74f024c1c9 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -2452,8 +2452,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { idGraph(IdMappingMode::EXACT).toGroups(*loop_group); // The intersection of the exact groups that the broadcast domains can be - // broadcasted to, and those that exist within the same loop are is the - // promotion needed for this iel_group. + // broadcasted to, and those that exist within the same loop groop are is + // the promotion needed for this iel_group. auto loop_exact_resolved_intersection = resolved_exact_groups.intersect(loop_covered_exact_groups); @@ -2475,7 +2475,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { TORCH_INTERNAL_ASSERT(false, err_msg.str()); } - // loop_exact_resolved_intersection.size() == 1 + // loop_exact_resolved_intersection.size() must be 1 at this point auto exact_resolution_group = loop_exact_resolved_intersection.front(); VectorOfUniqueEntries resolved_ids = @@ -2517,9 +2517,10 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGraphStmtSort iel_stmt_sort(intersection_exact_loop_graph); - std::cout<<"Initial promotion replay:"< promoted_inputs; @@ -2542,32 +2543,32 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { Expr* replay = nullptr; - // Before replaying, check if there's already an expression like this, if so - // use that for promotion. - ExprGroups promoted_input_uses; - for (auto inp_id : promoted_inputs) { - auto inp_exact_group = - idGraph(IdMappingMode::EXACT).toGroups({inp_id}).front(); - promoted_input_uses.pushBack( - idGraph(IdMappingMode::EXACT).uniqueUses(inp_exact_group)); - } + auto promoted_input_groups = intersection_exact_loop_graph.toGroups( + VectorOfUniqueEntries{ + promoted_inputs.begin(), promoted_inputs.end()}); - for (auto exact_use_group : promoted_input_uses) { - if (transformAtributesMatch( - iel_expr->front(), exact_use_group->front())) { - auto exact_use_inps = ir_utils::filterByType( - exact_use_group->front()->inputs()) - .vector(); + // Before replaying, check if there's already an expression like this, if so + // use that for promotion. We would need the iel entries for non-promoted + // inputs to match exactly to reuse the expression. + ExprGroups non_promoted_input_uses; + for (auto iel_group : promoted_input_groups.intersect(input_groups)) { + non_promoted_input_uses.pushBack( + intersection_exact_loop_graph.uniqueUses(iel_group)); + } + + for (auto iel_use_group : non_promoted_input_uses) { + if (transformAtributesMatch(iel_expr->front(), iel_use_group->front())) { + auto use_inps = + ir_utils::filterByType(iel_use_group->front()->inputs()) + .vector(); bool inps_match = true; - for (auto inp_i : c10::irange(exact_use_inps.size())) { + for (auto inp_i : c10::irange(use_inps.size())) { inps_match = inps_match && - idGraph(IdMappingMode::EXACT) - .disjointIdSets() - .strictAreMapped( - exact_use_inps[inp_i], promoted_inputs[inp_i]); + intersection_exact_loop_graph.disjointIdSets().strictAreMapped( + use_inps[inp_i], promoted_inputs[inp_i]); } if (inps_match) { - replay = exact_use_group->front(); + replay = iel_use_group->front(); break; } } @@ -2575,8 +2576,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); - std::cout << " ***REPLAY***:\n " << iel_expr->front() << " As:" - << replay->toString(); + std::cout << " ***REPLAY***:\n " << iel_expr->front() + << " As:" << replay->toString(); } else { std::cout << " Matched replay found: " << replay->toString(); } @@ -2835,8 +2836,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); - std::cout << " ***REPLAY2***:\n " << iel_expr->front() << " As:" - << replay->toString(); + std::cout << " ***REPLAY2***:\n " << iel_expr->front() + << " As:" << replay->toString(); } else { std::cout << " Matched replay found: " << replay->toString(); } @@ -3347,7 +3348,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { auto indexing_transforms = ae_graph.getExprsBetween(ae_root_groups, ae_leaf_groups); - std::cout<<" Replaying path to domain:"<& exprs) { } } - std::cout << "All indexing expressions that need to be processed: " << std::endl; + std::cout << "All indexing expressions that need to be processed: " + << std::endl; for (auto expr : all_index_exprs) { std::cout << expr->toString(); } - std::cout << "All iter domains that would be indexed: " << all_index_ids.toString() - << std::endl; + std::cout << "All iter domains that would be indexed: " + << all_index_ids.toString() << std::endl; TORCH_INTERNAL_ASSERT(false); } From 1175da2ebea650e378d0bc1ae9ee6e0ea7df4586 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 8 Apr 2023 14:51:07 -0400 Subject: [PATCH 011/178] Fix indexing expression generation. Some minor cleanup and refactoring. --- csrc/disjoint_set.h | 20 +++ csrc/id_graphs.cpp | 329 ++++++++++++++++---------------------------- csrc/id_graphs.h | 33 +++-- 3 files changed, 151 insertions(+), 231 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 67bc08c7b56..bd82bc73a2e 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -398,6 +398,26 @@ class DisjointSets { return disjoint_set_maps_.find(entry) != disjoint_set_maps_.end(); } + // Erases element if it exists in the disjoint set, returns if element found. + bool erase(T entry) { + auto entry_it = disjoint_set_maps_.find(entry); + if (entry_it == disjoint_set_maps_.end()) { + return false; + } + + auto set = entry_it->second; + if (set->size() == 1 && set->front() == entry) { + disjoint_set_maps_.erase(entry); + disjoint_sets_.erase( + std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set)); + } else { + disjoint_set_maps_.erase(entry); + set->erase(entry); + } + + return true; + } + // Returns a deterministic list of all entries that have been added to any // disjoint set. // diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index e74f024c1c9..1d08fc9ca30 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -13,7 +13,7 @@ namespace nvfuser { -namespace debug_print { +namespace debug_string { // A few compressed printing utilities to show critical uniqueness information. // i.e. being able to tell slight differences between groups we're working with. @@ -170,15 +170,10 @@ std::string exprGroupsStringShort( std::stringstream ss; ss << /* ptrStringShort(&expr_groups) <<*/ "(exprs) {"; - bool first = true; for (auto i : c10::irange(group_name_info.size())) { - if (first) { - first = false; - } else { - ss << ", "; - } auto pos = group_name_info[i].second; - ss << exprGroupStringShort(id_graph, expr_groups.vector()[pos]); + ss << " " << exprGroupStringShort(id_graph, expr_groups.vector()[pos]) + << "\n"; } ss << "}"; @@ -217,7 +212,7 @@ std::string usesToString(const IdGraph& id_graph) { return ss.str(); } -} // namespace debug_print +} // namespace debug_string namespace { @@ -414,8 +409,6 @@ void IdGraphVisitor::traverse() { IdGraph::IdGraph(const IdGraph& other) { disjoint_ids_ = other.disjoint_ids_; disjoint_exprs_ = other.disjoint_exprs_; - id_uses_ = other.id_uses_; - id_definitions_ = other.id_definitions_; view_rfactor_ids_ = other.view_rfactor_ids_; for (auto orig_unique_def_pair : other.unique_definitions_) { @@ -460,8 +453,6 @@ IdGraph& IdGraph::operator=(const IdGraph& other) { disjoint_exprs_.clear(); unique_definitions_.clear(); unique_uses_.clear(); - id_uses_.clear(); - id_definitions_.clear(); view_rfactor_ids_.clear(); IdGraph copy(other); std::swap(*this, copy); @@ -631,13 +622,13 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) all_id_groups.pushBack(inp_groups); - if (inp_groups.empty()) { + if (!inp_groups.empty()) { not_outputs.pushBack(inp_groups); } all_id_groups.pushBack(out_groups); - if (out_groups.empty()) { + if (!out_groups.empty()) { not_inputs.pushBack(out_groups); } } @@ -1109,7 +1100,7 @@ ExprGroups IdGraph::uniqueDefinitions(IdGroup group) const { ExprGroups IdGraph::uniqueUses(IdGroup group) const { auto unique_uses_it = unique_uses_.find(group); TORCH_INTERNAL_ASSERT( - unique_uses_it != unique_definitions_.end(), + unique_uses_it != unique_uses_.end(), "Uses not found for IdGroup: ", group->toString()); return unique_uses_it->second; @@ -1291,6 +1282,53 @@ void IterDomainGraphs::assertNoSelfMapping() { ", are mapped with each other."); } +void IdGraph::mapThroughTrivialExprs() { + // Grab all expressions + std::vector exprs; + + for (auto expr_group : disjointExprSets().disjointSets()) { + for (auto expr : *expr_group) { + exprs.push_back(expr); + } + } + + for (auto expr : exprs) { + // If not trivial continue + auto mapped_ids = IdGraph::isTrivialExpr(expr); + if (mapped_ids.empty()) { + continue; + } + + // Map through trivial expressions + for (auto mapped_id_group : mapped_ids) { + for (auto id : mapped_id_group) { + mapIds(mapped_id_group.front(), id); + } + } + } +} + +void IdGraph::removeTrivialExprs() { + ExprGroups trivial_expr_groups; + for (auto expr_group : disjointExprSets().disjointSets()) { + auto inp_groups = inputGroups(expr_group); + auto out_groups = outputGroups(expr_group); + if (inp_groups.intersect(out_groups).size()) { + trivial_expr_groups.pushBack(expr_group); + } + } + + // Clear out expressions that map inputs and outputs to the same group + // from definitions and uses. They shouldn't be important in traversal, and + // will break the terminal input/terminal output logic of traversal. Similar + // to what's drafted in buildIndexMap + for (auto trivial_expr_group : trivial_expr_groups) { + // Complexity of erase not good as both disjoint set and vector of unique + // entries require a vector find to erase an entry. + eraseExprGroup(trivial_expr_group); + } +} + void IdGraph::mapThroughLoopSwizzles() { for (auto use_pairs : unique_uses_) { auto use_groups = use_pairs.second; @@ -1310,6 +1348,30 @@ void IdGraph::mapThroughLoopSwizzles() { } } +// Complexity here is not great. We might want a better complexity version when +// erasing multiple expr_groups. +void IdGraph::eraseExprGroup(ExprGroup expr_group) { + // Erase entries that exist in unique_definitions_ and unique_uses_ + for (auto id_group : disjointIdSets().disjointSets()) { + // Make sure the entries exists + TORCH_INTERNAL_ASSERT( + unique_definitions_.find(id_group) != unique_definitions_.end(), + "Broken definitions, couldn't find entry for id group, ", + debug_string::idGroupStringShort(id_group)); + TORCH_INTERNAL_ASSERT( + unique_uses_.find(id_group) != unique_uses_.end(), + "Broken uses, couldn't find entry for id group, ", + debug_string::idGroupStringShort(id_group)); + + unique_definitions_[id_group].erase(expr_group); + unique_uses_[id_group].erase(expr_group); + } + + for (auto expr : *expr_group) { + disjoint_exprs_.erase(expr); + } +} + IterDomainGraphs::IterDomainGraphs( const std::vector& exprs, const std::vector& additional_tvs, @@ -1915,29 +1977,8 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { void IterDomainGraphs::buildAlmostExactMap() { // Build almost exact map by forwarding through broadcast axes idGraph(IdMappingMode::ALMOSTEXACT) = idGraph(IdMappingMode::EXACT); - - VectorOfUniqueEntries exprs; - for (auto expr : - idGraph(IdMappingMode::ALMOSTEXACT).disjointExprSets().disjointSets()) { - exprs.pushBack(expr->front()); - } - ExprGroups trivial_expr_groups; - - // Map through trivial expressions - for (auto expr : exprs) { - auto mapped_ids = IdGraph::isTrivialExpr(expr); - for (auto mapped_id_group : mapped_ids) { - for (auto id : mapped_id_group) { - trivial_expr_groups.pushBack( - idGraph(IdMappingMode::ALMOSTEXACT).disjointExprSet(expr).first); - idGraph(IdMappingMode::ALMOSTEXACT).mapIds(mapped_id_group.front(), id); - } - } - } - - // TODO: Clear out expressions that map inputs and outputs to the same group - // from definitions and uses. They shouldn't be important in traversal. - // Similar to what's drafted in buildIndexMap + idGraph(IdMappingMode::ALMOSTEXACT).mapThroughTrivialExprs(); + idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); } void IterDomainGraphs::validateAndPropagatePType() const { @@ -2550,6 +2591,12 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Before replaying, check if there's already an expression like this, if so // use that for promotion. We would need the iel entries for non-promoted // inputs to match exactly to reuse the expression. + // + // Unfortunately this doesn't actually seem to save any replays because + // we're not adding the replayed expression to the iel graph since we're + // traversing the iel graph. + // + // TODO: Can we reduce the number of new expressions generated here? ExprGroups non_promoted_input_uses; for (auto iel_group : promoted_input_groups.intersect(input_groups)) { non_promoted_input_uses.pushBack( @@ -2711,16 +2758,16 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::stringstream err_msg; err_msg << "\n ERROR Loop promotion map build. Could not find promotion for loop group:\n "; - err_msg << debug_print::idGroupStringShort(loop_group); + err_msg << debug_string::idGroupStringShort(loop_group); err_msg << "\nnone of the terminal iter domains of this group:\n "; for (auto entry : exact_promoted_terminal_ids) { auto terminal_id_group = entry.first; - err_msg << " " << debug_print::idGroupStringShort(terminal_id_group) + err_msg << " " << debug_string::idGroupStringShort(terminal_id_group) << std::endl; } err_msg << "iter domains in this group cover all id groups:\n"; for (auto covered_group : loop_group_covered_ids) { - err_msg << " " << debug_print::idGroupStringShort(covered_group); + err_msg << " " << debug_string::idGroupStringShort(covered_group); } TORCH_INTERNAL_ASSERT(false, err_msg.str()); } @@ -2819,13 +2866,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { .disjointIdSets() .strictAreMapped( exact_use_inps[inp_i], promoted_inputs[inp_i]); - if (!idGraph(IdMappingMode::EXACT) - .disjointIdSets() - .strictAreMapped( - exact_use_inps[inp_i], promoted_inputs[inp_i])) { - std::cout << exact_use_inps[inp_i]->toString() << " doesn't match " - << promoted_inputs[inp_i]->toString() << std::endl; - } } if (inps_match) { replay = exact_use_group->front(); @@ -3029,16 +3069,16 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (loop_promotion_id == nullptr) { std::stringstream err_msg; err_msg << "\nCould not find promotion for loop group:\n "; - err_msg << debug_print::idGroupStringShort(loop_group); + err_msg << debug_string::idGroupStringShort(loop_group); err_msg << "\nnone of the terminal iter domains of this group:\n "; for (auto entry : exact_promoted_terminal_ids) { auto terminal_id_group = entry.first; - err_msg << " " << debug_print::idGroupStringShort(terminal_id_group) + err_msg << " " << debug_string::idGroupStringShort(terminal_id_group) << std::endl; } err_msg << "iter domains in this group cover all id groups:\n"; for (auto covered_group : loop_group_covered_ids) { - err_msg << " " << debug_print::idGroupStringShort(covered_group); + err_msg << " " << debug_string::idGroupStringShort(covered_group); } TORCH_INTERNAL_ASSERT(false, err_msg.str()); } @@ -3054,7 +3094,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { loop_graph_copy_promotion_map.end()) { continue; } - std::cout << debug_print::idGroupStringShort(group) << " -> " + std::cout << debug_string::idGroupStringShort(group) << " -> " << loop_graph_copy_promotion_map.at(group)->toString() << std::endl; } @@ -3314,18 +3354,29 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Track every iter domain required for indexing VectorOfUniqueEntries all_index_ids; + // The almost exact map could have new trivial expression groups from the + // replays, which are expressions that have an input mapped to an output of + // that expression. getExprsBetween protects against these, but they can also + // just be removed. + idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); + std::cout << "Building promoted tensor view domains:" << std::endl; // Need to "replay" all of the indexing expressions to make sure roots are // connected to the promoted leaves, in a way we can index directly on the // index graph. - + // + // Since we're performing replays we need to copy the graph we're iterating + // on. auto& ae_graph = idGraph(IdMappingMode::ALMOSTEXACT); + for (auto tv : all_tvs) { auto promoted_domain = get_promoted_domain(tv->domain()); // replay from root to promoted leaves. - std::cout << "\n\n Processing: TV" << tv->name() - << "\n Promoted: " << promoted_domain - << "\n Original: " << tv->domain()->toString() << std::endl; + std::cout << "\n\n Processing: TV" << tv->name() << "\n Root: TV" + << tv->getRootDomain() << "\n Promoted: " + << promoted_domain + // << "\n Original: " << tv->domain()->toString() + << std::endl; VectorOfUniqueEntries root_ids{ tv->getRootDomain().begin(), tv->getRootDomain().end()}; @@ -3454,167 +3505,17 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::cout << expr->toString(); } + std::cout << "All indexing expressions (on the index graph): " << std::endl; + auto index_expr_groups = + idGraph(IdMappingMode::INDEX).toGroups(all_index_exprs); + std::cout << debug_string::exprGroupsStringShort( + idGraph(IdMappingMode::INDEX), index_expr_groups) + << std::endl; + std::cout << "All iter domains that would be indexed: " << all_index_ids.toString() << std::endl; TORCH_INTERNAL_ASSERT(false); } -void IterDomainGraphs::buildIndexMap(const std::vector& all_tvs) { - // Initialize map at loop leaf nodes. This needs to be done just like we - // would in "initializeId" for the exact map. Unlike AlmostExact and - // Permissive, index map is not a superset of exact map. - for (auto loop_group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - for (auto id : *loop_group) { - auto id_disjoint_set = idGraph(IdMappingMode::INDEX) - .disjointIdSets() - .initializeSet(id) - .first->second; - - auto def_it = id_definitions_.find(id); - if (def_it != id_definitions_.end()) { - auto defs = def_it->second; - ExprGroups expr_groups; - for (auto def : defs) { - auto expr_set = idGraph(IdMappingMode::INDEX) - .disjointExprSets() - .initializeSet(def) - .first->second; - expr_groups.pushBack(expr_set); - } - idGraph(IdMappingMode::INDEX).uniqueDefinitions()[id_disjoint_set] = - expr_groups; - } else { - id_definitions_[id] = {}; - idGraph(IdMappingMode::INDEX).uniqueDefinitions()[id_disjoint_set] = {}; - } - - auto use_it = id_uses_.find(id); - if (use_it != id_uses_.end()) { - auto uses = use_it->second; - ExprGroups expr_groups; - for (auto use : uses) { - auto expr_set = idGraph(IdMappingMode::INDEX) - .disjointExprSets() - .initializeSet(use) - .first->second; - expr_groups.pushBack(expr_set); - } - idGraph(IdMappingMode::INDEX).uniqueUses()[id_disjoint_set] = - expr_groups; - } else { - id_uses_[id] = {}; - idGraph(IdMappingMode::INDEX).uniqueUses()[id_disjoint_set] = {}; - } - } - } - - // Below is the same as building the almost exact map. It just maps through - // trivial expressions and removes their traversal from definition/uses - VectorOfUniqueEntries exprs; - for (auto expr : - idGraph(IdMappingMode::INDEX).disjointExprSets().disjointSets()) { - exprs.pushBack(expr->front()); - } - ExprGroups trivial_expr_groups; - - // Map through trivial expressions - for (auto expr : exprs) { - auto mapped_ids = IdGraph::isTrivialExpr(expr); - for (auto mapped_id_group : mapped_ids) { - for (auto id : mapped_id_group) { - trivial_expr_groups.pushBack( - idGraph(IdMappingMode::INDEX).disjointExprSet(expr).first); - idGraph(IdMappingMode::INDEX).mapIds(mapped_id_group.front(), id); - } - } - } - - // Clear out expressions that map inputs and outputs to the same group from - // definitions and uses. They shouldn't be important in traversal. Iterate - // on a copy as we're updating the map as we traverse. - std::unordered_map defs_copy = - idGraph(IdMappingMode::INDEX).uniqueDefinitions(); - for (auto& id_2_expr_group_map_entry : defs_copy) { - ExprGroups expr_groups_new; - for (auto& expr_group : id_2_expr_group_map_entry.second) { - if (!trivial_expr_groups.has(expr_group)) { - expr_groups_new.pushBack(expr_group); - } - } - - if (expr_groups_new.size() == id_2_expr_group_map_entry.second.size()) { - continue; - } - - idGraph(IdMappingMode::INDEX) - .uniqueDefinitions()[id_2_expr_group_map_entry.first] = expr_groups_new; - } - - std::unordered_map uses_copy = - idGraph(IdMappingMode::INDEX).uniqueUses(); - for (auto& id_2_expr_group_map_entry : uses_copy) { - ExprGroups expr_groups_new; - for (auto expr_group : id_2_expr_group_map_entry.second) { - if (!trivial_expr_groups.has(expr_group)) { - expr_groups_new.pushBack(expr_group); - } - } - - if (expr_groups_new.size() == id_2_expr_group_map_entry.second.size()) { - continue; - } - if (!expr_groups_new.empty()) { - for (auto i : c10::irange(100)) { - if (i > 0) { - expr_groups_new.pushBack(expr_groups_new.front()); - } - } - } - - idGraph(IdMappingMode::INDEX) - .uniqueUses()[id_2_expr_group_map_entry.first] = expr_groups_new; - } - - for (auto loop_group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - auto loop_promotion_it = loop_promotion_map_.find(loop_group); - } - IdGroups processed; - - for (auto tv : all_tvs) { - if (tv->isFusionInput()) { - continue; - } - for (auto id : tv->domain()->domain()) { - auto loop_group_pair = idGraph(IdMappingMode::LOOP).disjointIdSet(id); - TORCH_INTERNAL_ASSERT( - loop_group_pair.second, - "Loop group not found for leaf id: ", - id->toString()); - auto loop_group = loop_group_pair.first; - if (processed.has(loop_group)) { - continue; - } - processed.pushBack(loop_group); - - auto loop_promotion_it = loop_promotion_map_.find(loop_group); - TORCH_INTERNAL_ASSERT(loop_promotion_it != loop_promotion_map_.end()); - IterDomain* promoted_id = loop_promotion_it->second; - - for (auto loop_group_id : *loop_group) { - if (loop_group_id == promoted_id) { - continue; - } - if (idGraph(IdMappingMode::ALMOSTEXACT) - .disjointIdSets() - .permissiveAreMapped(loop_group_id, promoted_id)) { - idGraph(IdMappingMode::INDEX).mapIds(loop_group_id, promoted_id); - } - } - } - } -} - } // namespace nvfuser diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index e46cefd7a6b..098cd1b5941 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -127,14 +127,12 @@ class TORCH_CUDA_CU_API IdGraph { // , std::vector second_input_or_output_override ) const; - // If entry exists in id_definitions for provided group in provided mode, - // returns that entry, otherwise goes through all iter domains in the group - // and accumulates their id_definitions_ entries + // Returns entry in unique_definitions_ for provided group in provided mode, + // otherwise errors if no entry is found. ExprGroups uniqueDefinitions(IdGroup group) const; - // If entry exists in id_uses for provided group in provided mode, - // returns that entry, otherwise goes through all iter domains in the group - // and accumulates their id_uses_ entries + // Returns entry in unique_uses_ for provided group in provided mode, + // otherwise errors if no entry is found. ExprGroups uniqueUses(IdGroup group) const; std::unordered_map& uniqueUses() { @@ -167,7 +165,19 @@ class TORCH_CUDA_CU_API IdGraph { // order they're traversed differs. void mapThroughLoopSwizzles(); + // Maps iter domain pairs returned by calling that return mappings from + // IdGraph::isTrivialExpr on every expression in the graph. + void mapThroughTrivialExprs(); + + // Removes expressions from unique_definitions_ and unique_uses_ that return + // mappings from IdGraph::isTrivialExpr + void removeTrivialExprs(); + private: + // Removes the provided expression group from unique_definitions_ and + // unique_uses_ breaking traversal through them. + void eraseExprGroup(ExprGroup expr_group); + // Keeps a disjoint set entry for all IterDomain for all mapping mode types. // // Using an array here might be nice, but it seems hard to use an enum as an @@ -182,17 +192,6 @@ class TORCH_CUDA_CU_API IdGraph { std::unordered_map unique_uses_; - // If multiple transformations occur IterDomains could have multiple uses, - // however only one should be active in the given Fusion. When we resolve loop - // promotions during lowering, we can generate new iter domains from existing - // ones, so there can be multiple uses generated. Tracks all the active iter - // domain uses. - std::unordered_map> id_uses_; - - // Make sure we don't blindly use definitions as we don't want to grab - // transformations before a tensor view's root domain. - std::unordered_map> id_definitions_; - // Hold a set of IterDomains that are considered view rfactor ids. This // identification is particularly important to understand if split operations // are divisible or not. From 7891741f3256f40f37157d155c9aeb973707bf13 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 13 Apr 2023 09:49:11 -0400 Subject: [PATCH 012/178] First working index graph example. --- csrc/disjoint_set.h | 3 + csrc/id_graphs.cpp | 457 ++++++++++++++++++++++++++++------------ csrc/id_graphs.h | 26 ++- csrc/transform_iter.cpp | 145 +++++++++---- csrc/transform_iter.h | 25 ++- 5 files changed, 463 insertions(+), 193 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index bd82bc73a2e..d376646db85 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -456,6 +456,9 @@ class DisjointSets { disjoint_set_maps_; // Keep a list of disjoint_sets that's deterministic to iterate over + // + // TODO: Should this just be a + // VectorOfUniqueEntries>> disjoint_sets_; }; diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 1d08fc9ca30..b11658f97b6 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -123,6 +123,13 @@ std::string idGroupsStringShort(const IdGroups& id_groups) { return idGroupsStringShort(id_groups.vector()); } +std::string idGroups(const IdGraph& id_graph) { + IdGroups id_groups( + id_graph.disjointIdSets().disjointSets().begin(), + id_graph.disjointIdSets().disjointSets().end()); + return idGroupsStringShort(id_groups); +} + std::string exprGroupStringShort(ExprGroup expr_group) { std::vector names; for (auto expr : *expr_group) { @@ -180,6 +187,13 @@ std::string exprGroupsStringShort( return ss.str(); } +std::string exprGroups(const IdGraph& id_graph) { + ExprGroups expr_groups( + id_graph.disjointExprSets().disjointSets().begin(), + id_graph.disjointExprSets().disjointSets().end()); + return exprGroupsStringShort(id_graph, expr_groups); +} + std::string definitionsToString(const IdGraph& id_graph) { std::stringstream ss; ExprGroups defs; @@ -636,6 +650,12 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) terminating_outputs = all_id_groups.subtract(not_outputs); } + std::cout << "Term inp: " + << debug_string::idGroupsStringShort(terminating_inputs) + << std::endl; + std::cout << "Term out: " + << debug_string::idGroupsStringShort(terminating_outputs) + << std::endl; // Track all expressions to get from outputs to this IterDomain. We // traverse backwards as that's the direction of indexing expressions. An // index is assigned to each leaf of a domain and as we traverse backwards @@ -1663,13 +1683,14 @@ Expr* IterDomainGraphs::addReplayAs( auto replay = ReplayTransform::replayAs(new_inputs, expr); for (auto out_id : ir_utils::filterByType(replay->outputs())) { - id_definitions_[out_id] = {replay}; - id_uses_[out_id] = {}; + id_definitions_[out_id].pushBack(replay); + id_uses_[out_id]; } // Add the expression to the uses of the inputs for (auto inp_id : ir_utils::filterByType(replay->inputs())) { - id_uses_.at(inp_id).pushBack(replay); + id_definitions_[inp_id]; + id_uses_[inp_id].pushBack(replay); } // Initialize output iter domains in the graphs @@ -1703,10 +1724,10 @@ Expr* IterDomainGraphs::addReplayAs( } } - for (auto expr : representative_uses) { - if (graph.exprsMap(expr, replay, true)) { - graph.mapExprs(expr, replay); - graph.mapThroughExpr(expr, replay, true); + for (auto rep_use : representative_uses) { + if (graph.exprsMap(rep_use, replay, true)) { + graph.mapExprs(rep_use, replay); + graph.mapThroughExpr(rep_use, replay, true); } } } @@ -1716,9 +1737,10 @@ Expr* IterDomainGraphs::addReplayAs( // Generate a new expr with the IterDomain outputs provided and IterDomain // inputs that exactly match expr->inputs -Expr* IterDomainGraphs::addReplayAsBackward( - const std::vector& new_outputs, - Expr* expr) { + +Expr* IterDomainGraphs::addExprWithReplacement( + const std::unordered_map& old_2_new_ids, + Expr* old_expr) { // Figure out which graphs are already initialized to make sure we add the new // expression to them. std::vector initialized_modes; @@ -1736,47 +1758,75 @@ Expr* IterDomainGraphs::addReplayAsBackward( initialized_modes.push_back(mode); } - auto orig_outputs = ir_utils::filterByType(expr->outputs()); - std::vector orig_output_ids( - orig_outputs.begin(), orig_outputs.end()); + // We will fill this map for every IterDomain in input and output. + std::unordered_map replacement_map = old_2_new_ids; - { + // Validate replacement map. Make sure the keys are an input or output + for (auto replacement_entry : replacement_map) { TORCH_INTERNAL_ASSERT( - new_outputs.size() == orig_output_ids.size(), - "Invalid number of outputs: ", - new_outputs.size(), - " does not match number of iter domain outputs for ", - expr->toString()); + std::find( + old_expr->inputs().begin(), + old_expr->inputs().end(), + replacement_entry.first) != old_expr->inputs().end() || + std::find( + old_expr->outputs().begin(), + old_expr->outputs().end(), + replacement_entry.first) != old_expr->outputs().end(), + "Wanted to replace ", + replacement_entry.first->toString(), + " however the is not an input or output of:\n", + old_expr->toString()); + } + + // If all inputs and or all output were replaced + bool all_inps_replaced = true; + bool all_outs_replaced = true; + { + for (auto inp_id : ir_utils::filterByType(old_expr->inputs())) { + if (replacement_map.find(inp_id) == replacement_map.end()) { + all_inps_replaced = false; + replacement_map[inp_id] = inp_id->cloneWithoutRFactor(); + } + } - VectorOfUniqueEntries all_outputs{ - orig_output_ids.begin(), orig_output_ids.end()}; + for (auto out_id : + ir_utils::filterByType(old_expr->outputs())) { + if (replacement_map.find(out_id) == replacement_map.end()) { + all_outs_replaced = false; + replacement_map[out_id] = out_id->cloneWithoutRFactor(); + } + } - all_outputs.pushBack(VectorOfUniqueEntries{ - new_outputs.begin(), new_outputs.end()}); + TORCH_INTERNAL_ASSERT( + (all_inps_replaced || all_outs_replaced), + "Either all the inputs or all the outputs need to be replaced when using this function."); for (auto mode : initialized_modes) { - for (auto inp : all_outputs) { + for (auto inp_or_out_id : all_inps_replaced + ? ir_utils::filterByType(old_expr->inputs()) + : ir_utils::filterByType(old_expr->outputs())) { TORCH_INTERNAL_ASSERT( - idGraph(mode).disjointIdSet(inp).second, - "All outputs for replay need to be initialized in all graphs, ", - inp->toString(), - " was not found in mode: ", + idGraph(mode).disjointIdSet(inp_or_out_id).second, + "Expected ", + inp_or_out_id->toString(), + " to be initialized in graph mode: ", mode); } } } // Create the new expression with provided outputs - auto replay = BackwardTransformCloner::clone(new_outputs, expr); + auto replay = ReplacementTransformCloner::clone(replacement_map, old_expr); for (auto out_id : ir_utils::filterByType(replay->outputs())) { id_definitions_[out_id].pushBack(replay); + id_uses_[out_id]; } // Add the expression to the uses of the inputs for (auto inp_id : ir_utils::filterByType(replay->inputs())) { - id_definitions_[inp_id] = {}; - id_uses_[inp_id] = {replay}; + id_definitions_[inp_id]; + id_uses_[inp_id].pushBack(replay); } // Initialize output iter domains in the graphs @@ -1784,40 +1834,81 @@ Expr* IterDomainGraphs::addReplayAsBackward( idGraph(mode).disjointExprSets().initializeSet(replay); auto replay_group = idGraph(mode).disjointExprSet(replay).first; - // Initialize input ids in map for (auto inp_id : ir_utils::filterByType(replay->inputs())) { - idGraph(mode).initializeId(inp_id, {}, {replay}); + if (!idGraph(mode).disjointIdSets().mappingExists(inp_id)) { + // inp_id is not initialized in the map, initialize it + idGraph(mode).initializeId(inp_id, {}, {replay}); + } else { + // inp_id is already initialized add the replay as a unique use of its + // group. + auto inp_group = idGraph(mode).disjointIdSet(inp_id).first; + idGraph(mode).uniqueUses()[inp_group].pushBack(replay_group); + } } // Update definitions in the graph of the outputs for (auto out_id : ir_utils::filterByType(replay->outputs())) { - auto out_group = idGraph(mode).disjointIdSet(out_id).first; - idGraph(mode).uniqueDefinitions().at(out_group).pushBack(replay_group); + if (!idGraph(mode).disjointIdSets().mappingExists(out_id)) { + // out_id is not initialized in the map, initialize it + idGraph(mode).initializeId(out_id, {replay}, {}); + } else { + // out_id is already initialized, add the replay as a unique definition + // of its group + auto out_group = idGraph(mode).disjointIdSet(out_id).first; + idGraph(mode).uniqueDefinitions().at(out_group).pushBack(replay_group); + } } - // Propagate through all the defintions of the iter domain groups of the - // outputs with the new expression. + // We expect that inputs or outputs were replaced by iter domains that + // already exist in the graphs. If the inputs were replaced we want to + // replay forward through the newly added expression. If the outputs were + // replaced we want to replay backwards (towards inputs) instead. auto& graph = idGraph(mode); // Gather all use expressions from inputs - VectorOfUniqueEntries representative_defs; - for (auto out : new_outputs) { - auto defs_pair = - graph.iterDomainGroupDefinitions(graph.disjointIdSet(out).first); - if (defs_pair.second) { - for (auto def_group : defs_pair.first) { - representative_defs.pushBack(def_group->front()); + + if (all_inps_replaced) { + VectorOfUniqueEntries representative_uses; + for (auto in : ir_utils::filterByType(replay->inputs())) { + auto uses_pair = + graph.iterDomainGroupUses(graph.disjointIdSet(in).first); + if (uses_pair.second) { + for (auto def_group : uses_pair.first) { + representative_uses.pushBack(def_group->front()); + } } } - } - for (auto expr : representative_defs) { - if (graph.exprsMap(expr, replay, false)) { - graph.mapExprs(expr, replay); - graph.mapThroughExpr(expr, replay, false); + representative_uses.erase(replay); + + for (auto rep_use : representative_uses) { + if (graph.exprsMap(rep_use, replay, true)) { + graph.mapExprs(rep_use, replay); + graph.mapThroughExpr(rep_use, replay, true); + } + } + + if (all_outs_replaced) { + VectorOfUniqueEntries representative_defs; + for (auto out : ir_utils::filterByType(replay->outputs())) { + auto defs_pair = + graph.iterDomainGroupDefinitions(graph.disjointIdSet(out).first); + if (defs_pair.second) { + for (auto def_group : defs_pair.first) { + representative_defs.pushBack(def_group->front()); + } + } + } + representative_defs.erase(replay); + + for (auto rep_def : representative_defs) { + if (graph.exprsMap(rep_def, replay, false)) { + graph.mapExprs(rep_def, replay); + graph.mapThroughExpr(rep_def, replay, false); + } + } } } } - return replay; } @@ -3337,6 +3428,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Could pass this into the function, but just using this for now. auto all_tvs = ir_utils::allTvsOfExprs(exprs); + // TODO: This needs to be available as a member function auto get_promoted_domain = [&](TensorDomain* td) { std::vector promoted_leaves; for (auto id : td->domain()) { @@ -3367,115 +3459,206 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // // Since we're performing replays we need to copy the graph we're iterating // on. - auto& ae_graph = idGraph(IdMappingMode::ALMOSTEXACT); + auto ae_graph = idGraph(IdMappingMode::ALMOSTEXACT); for (auto tv : all_tvs) { + // We don't have to process inputs at this point as they're already + // allocated on a global + if (tv->isFusionInput()) { + continue; + } + auto promoted_domain = get_promoted_domain(tv->domain()); // replay from root to promoted leaves. std::cout << "\n\n Processing: TV" << tv->name() << "\n Root: TV" - << tv->getRootDomain() << "\n Promoted: " - << promoted_domain - // << "\n Original: " << tv->domain()->toString() + << tv->getRootDomain() << "\n Promoted: " << promoted_domain << std::endl; - VectorOfUniqueEntries root_ids{ - tv->getRootDomain().begin(), tv->getRootDomain().end()}; + // The promoted leaf iter domains are where indexing starts. We're going to + // start at those expressions and replay transformations for this tensor + // view working back to root domains. We want to intercept the history of + // the transformations local to the tensor view where possible. + // + // So effectively what we have to do is map the ae graph to the history of + // the tensor view as well as the promoted iter domains. We start traversal + // at the promoted iter domains and will intercept the tensor view history + // as possible. + // + // We must be able to interecept the provided tensor view at the rfactor and + // root domains, otherwise we wouldn't be able to allocate or index into the + // buffer at tensor view (rfactor domain) or it's producer (root domain). - auto ae_root_groups = ae_graph.toGroups(root_ids); + // Grab all the domains and convert them to their ae groups. + auto all_ids_v = ir_utils::allIDsOf(tv); + auto all_ids = + VectorOfUniqueEntries(all_ids_v.begin(), all_ids_v.end()); + // Add the promoted domain ids + for (auto promoted_id : promoted_domain) { + all_ids.pushBack(promoted_id); + } + + // Create a map from the ae group to the iter domain as when we replay we'll + // replace the ae iter domain in the replay with the id in this map. std::unordered_map ae_group_2_id; - for (auto root_i : c10::irange(ae_root_groups.size())) { - ae_group_2_id[ae_root_groups.vector()[root_i]] = - root_ids.vector()[root_i]; + for (auto tv_id : all_ids) { + // Use emplace here as it multiple tv_ids could map to the same ae_group. + // Emplace will simply grab the first one that appears. + ae_group_2_id.emplace( + std::make_pair(ae_graph.toGroups({tv_id}).front(), tv_id)); } - VectorOfUniqueEntries leaf_ids{ - promoted_domain.begin(), promoted_domain.end()}; + auto ae_leaf_groups = ae_graph.toGroups(VectorOfUniqueEntries{ + promoted_domain.begin(), promoted_domain.end()}); + + // Don't support multiple leaf domains promoted to the same ae graph at this + // point. + TORCH_INTERNAL_ASSERT( + ae_leaf_groups.size() == promoted_domain.size(), + "Multiple leaf domains that map almost exactly is not supported at this point."); - auto ae_leaf_groups = ae_graph.toGroups(leaf_ids); + auto ae_root_groups = ae_graph.toGroups(VectorOfUniqueEntries{ + tv->getRootDomain().begin(), tv->getRootDomain().end()}); - // Get indexing transformations - auto indexing_transforms = - ae_graph.getExprsBetween(ae_root_groups, ae_leaf_groups); + // Make a copy of the expressions so we can reverse them + auto reverse_indexing_transforms = + ae_graph.getExprsBetween(ae_root_groups, ae_leaf_groups).vector(); - std::cout << " Replaying path to domain:" << std::endl; - // Replay indexing transformations on the root_ids - for (ExprGroup ae_expr : indexing_transforms) { - // Replay mostly copied for a third time. - auto input_groups = ae_graph.inputGroups(ae_expr); + std::reverse( + reverse_indexing_transforms.begin(), reverse_indexing_transforms.end()); - // Inputs "promoted" with the ae_group_2_id map. + // Replay indexing transformations start on leaf nodes propagating back to + // the root domain + for (ExprGroup ae_expr : reverse_indexing_transforms) { + // Outputs must be promoted with the ae_group_2_id map. Inputs may be + // promoted when we intercept the history of the TV with the replay. // // if there isn't an entry in ae_group_2_id, then we have a resolved - // merged in broadcast, we need to clone that input. Would be nice to see - // if the dangling input has already been added already through another - // indexing path that this overlaps with, however having an additional Id - // and expression per case doesn't seem too bad right now. - std::vector promoted_inputs; - bool an_input_was_promoted = false; - - for (auto inp_group : input_groups) { - auto inp_promo_it = ae_group_2_id.find(inp_group); - if (inp_promo_it == ae_group_2_id.end()) { - // Clone dangling input, this is unique for index graph compared to - // the other replays. - promoted_inputs.push_back(cloneIterDomain(inp_group->front())); - } else { - promoted_inputs.push_back(inp_promo_it->second); - an_input_was_promoted = true; - } - } + // merged in broadcast, and that resolved iter domain will need to be + // cloned. Would be nice to see if the dangling input has already been + // added already through another indexing path that this overlaps with, + // however having an additional ID and expression per case doesn't seem + // too bad right now. - if (!an_input_was_promoted) { - // No inputs need promotion so just continue - continue; + auto ae_output_groups = ae_graph.outputGroups(ae_expr); + + std::vector promoted_outputs; + for (auto out_group : ae_output_groups) { + auto out_promo_it = ae_group_2_id.find(out_group); + TORCH_INTERNAL_ASSERT( + out_promo_it != ae_group_2_id.end(), + "Expected promoted iter domain for: ", + debug_string::idGroupStringShort(out_group)); + promoted_outputs.push_back(out_promo_it->second); } Expr* replay = nullptr; // Before replaying, check if there's already an expression like this, if // so use that for promotion. - ExprGroups promoted_input_uses; - for (auto inp_id : promoted_inputs) { + ExprGroups promoted_output_defs; + for (auto out_id : promoted_outputs) { auto index_group = - idGraph(IdMappingMode::INDEX).toGroups({inp_id}).front(); - promoted_input_uses.pushBack( - idGraph(IdMappingMode::INDEX).uniqueUses(index_group)); - } - - for (auto index_use_group : promoted_input_uses) { - if (transformAtributesMatch( - ae_expr->front(), index_use_group->front())) { - auto index_use_inputs = ir_utils::filterByType( - index_use_group->front()->inputs()) - .vector(); - bool inps_match = true; - for (auto inp_i : c10::irange(index_use_inputs.size())) { - inps_match = inps_match && - idGraph(IdMappingMode::INDEX) - .disjointIdSets() - .strictAreMapped( - index_use_inputs[inp_i], promoted_inputs[inp_i]); - } - if (inps_match) { - replay = index_use_group->front(); - break; + idGraph(IdMappingMode::INDEX).toGroups({out_id}).front(); + promoted_output_defs.pushBack( + idGraph(IdMappingMode::INDEX).uniqueDefinitions(index_group)); + } + + for (auto index_def_group : promoted_output_defs) { + // This enforces that inputs and outputs are all almost exact mapping + if (!idGraph(IdMappingMode::ALMOSTEXACT) + .disjointExprSets() + .strictAreMapped(index_def_group->front(), ae_expr->front())) { + continue; + } + + // Check that the outputs we need on the replay match in the index map + // with this expression. + auto index_def_outputs = ir_utils::filterByType( + index_def_group->front()->outputs()) + .vector(); + bool outs_match = true; + for (auto inp_i : c10::irange(index_def_outputs.size())) { + outs_match = outs_match && + idGraph(IdMappingMode::INDEX) + .disjointIdSets() + .strictAreMapped( + index_def_outputs[inp_i], promoted_outputs[inp_i]); + } + + if (!outs_match) { + continue; + } + + // Outputs all match in the index map, but need to make sure the inputs + // do as well. + auto index_def_inputs = ir_utils::filterByType( + index_def_group->front()->inputs()) + .vector(); + + bool inps_match = true; + for (auto inp_id : index_def_inputs) { + IterDomain* promoted_inp = nullptr; + auto ae_inp_group = ae_graph.toGroups({inp_id}).front(); + auto promoted_inp_it = ae_group_2_id.find(ae_inp_group); + if (promoted_inp_it == ae_group_2_id.end()) { + // This input is already almost exact mapped, and we don't need this + // input to map exactly in the index map. + continue; + } else { + promoted_inp = promoted_inp_it->second; } + + inps_match = inps_match && + idGraph(IdMappingMode::INDEX) + .disjointIdSets() + .strictAreMapped(inp_id, promoted_inp); + } + + if (!inps_match) { + continue; } + + replay = index_def_group->front(); + break; } if (replay == nullptr) { - std::cout << " Replay: " << ae_expr->front(); - std::cout << " With promoted inputs: " << promoted_inputs + std::vector ae_inps_outs = + ir_utils::filterByType(ae_expr->front()->inputs()) + .vector(); + auto outs = + ir_utils::filterByType(ae_expr->front()->outputs()); + ae_inps_outs.insert(ae_inps_outs.end(), outs.begin(), outs.end()); + + std::unordered_map replacement_map; + for (auto id : ae_inps_outs) { + auto ae_group = ae_graph.toGroups({id}).front(); + auto promoted_it = ae_group_2_id.find(ae_group); + if (promoted_it == ae_group_2_id.end()) { + replacement_map[id] = id->cloneWithoutRFactor(); + } else { + replacement_map[id] = promoted_it->second; + } + } + + // std::cout << " Replay: " << ae_expr->front(); + // std::cout << " With promoted inputs: " << promoted_inputs + // << std::endl; + replay = addExprWithReplacement(replacement_map, ae_expr->front()); + // std::cout << " ***REPLAY3***:\n " << ae_expr->front() + // << " As:" << replay->toString(); + std::cout << " ***REPLAY3***:\n " + << " " << replay->toString(); + std::cout << debug_string::idGroups(idGraph(IdMappingMode::INDEX)) << std::endl; - replay = addReplayAs(promoted_inputs, ae_expr->front()); - std::cout << " ***REPLAY3***:\n " << ae_expr->front() - << " As:" << replay->toString(); + } else { + std::cout << " ***MATCHED3***:\n " + << " " << replay->toString(); } all_index_exprs.pushBack(replay); - { auto in_ids = ir_utils::filterByType(replay->inputs()); all_index_ids.insert(in_ids.begin(), in_ids.end()); @@ -3484,17 +3667,17 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { all_index_ids.insert(out_ids.begin(), out_ids.end()); } - auto out_groups = - idGraph(IdMappingMode::ALMOSTEXACT).outputGroups(ae_expr); - - // Mark outputs as having a promoted iter domain - auto replay_out_ids = - ir_utils::filterByType(replay->outputs()).vector(); + std::vector ae_inps = + ir_utils::filterByType(ae_expr->front()->inputs()) + .vector(); + std::vector replay_inps = + ir_utils::filterByType(replay->inputs()).vector(); + TORCH_INTERNAL_ASSERT(ae_inps.size() == replay_inps.size()); - TORCH_INTERNAL_ASSERT(replay_out_ids.size() == out_groups.size()); - - for (auto i : c10::irange(replay_out_ids.size())) { - ae_group_2_id[out_groups.vector()[i]] = replay_out_ids[i]; + for (auto inp_i : c10::irange(ae_inps.size())) { + auto ae_group = ae_graph.toGroups({ae_inps[inp_i]}).front(); + // Only replace if entry does not exist. + ae_group_2_id.emplace(std::make_pair(ae_group, replay_inps[inp_i])); } } } @@ -3512,8 +3695,12 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { idGraph(IdMappingMode::INDEX), index_expr_groups) << std::endl; - std::cout << "All iter domains that would be indexed: " - << all_index_ids.toString() << std::endl; + std::cout << "All iter domains (on the index graph): " << std::endl; + auto index_id_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_ids); + std::cout << debug_string::idGroupsStringShort(index_id_groups) << std::endl; + + // std::cout << "All iter domains that would be indexed: " + // << all_index_ids.toString() << std::endl; TORCH_INTERNAL_ASSERT(false); } diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index 098cd1b5941..48a3241e726 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -46,6 +46,12 @@ class TORCH_CUDA_CU_API IdGraph { // Same as getDisjointIdSet but for the Expression sets. std::pair disjointExprSet(Expr* expr) const; + // TODO: Audit usage of toGroups: + // Being used when only a single expr or id, break that into a separate + // function. + // There may be an assumption that the size of incoming vector is same as + // output, that is not the case. + // Convert unique vector of expressions to unique vector of its groups ExprGroups toGroups(const VectorOfUniqueEntries& exprs) const; @@ -381,14 +387,18 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // replayed expression and adding potential mappings through the expression. Expr* addReplayAs(const std::vector& new_inputs, Expr* expr); - // Similar to addReplayAs, but in the reverse direction. Also addReplayAs can - // generate output ids by using the IterDomain::transform functions. For - // backwards because of merge the input iter domains of the transform are just - // cloned with IterDomain::cloneWithoutRFactor, and the transform Expr is - // generated with IrBuilder copying over all the attributes. - Expr* addReplayAsBackward( - const std::vector& new_outputs, - Expr* expr); + // Similar to addReplayAs, but clones the expr exactly instead of replaying it + // forward. It's up to the calling code to make sure the replacements are + // valid for the provided expr. It's generally recommended that the + // IterDomains exactly match those in the expr. + // + // "forward" dictates the same argument for mapThroughExpr. If forward the + // function will apply mapThroughExpr forward if inputs map in each + // initialized map. Else does the same but backwards through the expression + // from outputs. + Expr* addExprWithReplacement( + const std::unordered_map& old_2_new_ids, + Expr* old_expr); // Make a new expr matching that provided but using the outputs provided. // IterDomainGraphss will be updated for all maps that have entries. Adding diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 14524ca98e7..43095c8fe50 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -89,31 +89,52 @@ void ReplayTransformations::handle(Expr* e) { IterVisitor::handle(e); } -Expr* BackwardTransformCloner::clone( - const std::vector& ordered_outputs, +Expr* ReplacementTransformCloner::clone( + const std::unordered_map& + provided_expr_val_2_replacement_val, const Expr* expression_to_match) { - BackwardTransformCloner replay(ordered_outputs, expression_to_match); + ReplacementTransformCloner replay( + provided_expr_val_2_replacement_val, expression_to_match); return replay.new_expr_; } -BackwardTransformCloner::BackwardTransformCloner( - const std::vector& ordered_outputs, +ReplacementTransformCloner::ReplacementTransformCloner( + const std::unordered_map& + provided_expr_val_2_replacement_val, const Expr* expression_to_match) - : output_ids_(ordered_outputs) { + : provided_expr_val_2_replacement_val_( + provided_expr_val_2_replacement_val) { OptOutConstDispatch::handle(expression_to_match); } // We're going to replay this split operation on the corresponding ID -void BackwardTransformCloner::handle(const Split* split) { - TORCH_INTERNAL_ASSERT( - output_ids_.size() == 2, - "Expected two outputs to match split: ", - split->toString()); +void ReplacementTransformCloner::handle(const Split* split) { + // Replace or clone + auto split_in = split->in(); + split_in = provided_expr_val_2_replacement_val_.find(split_in) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(split_in) + : split_in->cloneWithoutRFactor(); + + auto split_outer = split->outer(); + split_outer = provided_expr_val_2_replacement_val_.find(split_outer) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(split_outer) + : split_outer->cloneWithoutRFactor(); + + auto split_inner = split->inner(); + split_inner = provided_expr_val_2_replacement_val_.find(split_inner) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(split_inner) + : split_inner->cloneWithoutRFactor(); + + // TODO: Should we check inner/outer matches the factor if + // innerSplit()/!innerSplit()? new_expr_ = IrBuilder::create( - output_ids_[0], - output_ids_[1], - split->in()->cloneWithoutRFactor(), + split_outer, + split_inner, + split_in, split->factor(), split->innerSplit(), split->startOffset(), @@ -121,44 +142,84 @@ void BackwardTransformCloner::handle(const Split* split) { } // We're going to replay this merge operation on the corresponding IDs -void BackwardTransformCloner::handle(const Merge* merge) { - TORCH_INTERNAL_ASSERT( - output_ids_.size() == 1, - "Expected one output to match merge: ", - merge->toString()); - - new_expr_ = IrBuilder::create( - output_ids_[0], - merge->outer()->cloneWithoutRFactor(), - merge->inner()->cloneWithoutRFactor()); +void ReplacementTransformCloner::handle(const Merge* merge) { + // Replace or clone + auto merge_outer = merge->outer(); + merge_outer = provided_expr_val_2_replacement_val_.find(merge_outer) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(merge_outer) + : merge_outer->cloneWithoutRFactor(); + + auto merge_inner = merge->inner(); + merge_inner = provided_expr_val_2_replacement_val_.find(merge_inner) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(merge_inner) + : merge_inner->cloneWithoutRFactor(); + + auto merge_out = merge->out(); + merge_out = provided_expr_val_2_replacement_val_.find(merge_out) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(merge_out) + : merge_out->cloneWithoutRFactor(); + + new_expr_ = IrBuilder::create(merge_out, merge_outer, merge_inner); } // We're going to replay this swizzle operation on the corresponding IDs // if replaying swizzle is enabled. -void BackwardTransformCloner::handle(const Swizzle2D* swizzle_2d) { - TORCH_INTERNAL_ASSERT( - output_ids_.size() == 2, - "Expected two outputs to match swizzle: ", - swizzle_2d->toString()); +void ReplacementTransformCloner::handle(const Swizzle2D* swizzle_2d) { + // Replace or clone + auto swizzle_inx = swizzle_2d->inX(); + swizzle_inx = provided_expr_val_2_replacement_val_.find(swizzle_inx) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(swizzle_inx) + : swizzle_inx->cloneWithoutRFactor(); + + // Replace or clone + auto swizzle_iny = swizzle_2d->inY(); + swizzle_iny = provided_expr_val_2_replacement_val_.find(swizzle_iny) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(swizzle_iny) + : swizzle_iny->cloneWithoutRFactor(); + + // Replace or clone + auto swizzle_outx = swizzle_2d->outX(); + swizzle_outx = provided_expr_val_2_replacement_val_.find(swizzle_outx) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(swizzle_outx) + : swizzle_outx->cloneWithoutRFactor(); + + // Replace or clone + auto swizzle_outy = swizzle_2d->outY(); + swizzle_outy = provided_expr_val_2_replacement_val_.find(swizzle_outy) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(swizzle_outy) + : swizzle_outy->cloneWithoutRFactor(); + new_expr_ = IrBuilder::create( - output_ids_[0], - output_ids_[1], - swizzle_2d->inX()->cloneWithoutRFactor(), - swizzle_2d->inY()->cloneWithoutRFactor(), + swizzle_outx, + swizzle_outy, + swizzle_inx, + swizzle_iny, swizzle_2d->swizzleType(), swizzle_2d->swizzleMode()); } -void BackwardTransformCloner::handle(const Resize* resize) { - TORCH_INTERNAL_ASSERT( - output_ids_.size() == 1, - "Expected one output to match resize: ", - resize->toString()); +void ReplacementTransformCloner::handle(const Resize* resize) { + auto resize_in = resize->in(); + resize_in = provided_expr_val_2_replacement_val_.find(resize_in) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(resize_in) + : resize_in->cloneWithoutRFactor(); + + auto resize_out = resize->out(); + resize_out = provided_expr_val_2_replacement_val_.find(resize_out) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(resize_out) + : resize_out->cloneWithoutRFactor(); + new_expr_ = IrBuilder::create( - output_ids_[0], - resize->in()->cloneWithoutRFactor(), - resize->leftExpand(), - resize->rightExpand()); + resize_out, resize_in, resize->leftExpand(), resize->rightExpand()); } // We're going to replay this split operation on the corresponding ID diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 0df6d7bd9a5..8c3721511c2 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -68,19 +68,27 @@ class ReplayTransform : OptInConstDispatch { const std::vector& input_ids_; }; -class BackwardTransformCloner : OptInConstDispatch { +class ReplacementTransformCloner : OptInConstDispatch { public: - // Generates a copy of expression_to_match with provided output - // IterDomains, cloning the inputs in expression_to_match. + // Generates a copy of expression_to_match with inputs and/or outputs replaced + // by entries provided in the map. Inputs and outputs are expected to be + // "clones". Not literally, but it's up to the envoking code to make the + // input/output replacements are safe to use in the cloned expression. No + // validation is done on provided inputs/outputs. + // + // In other words a split i0{I0}->i1{I0//2}, i2{2} with a map: + // i2{2} -> i3{48} wouldn't throw an error, but would not bevalid. static Expr* clone( - const std::vector& ordered_outputs, + const std::unordered_map& + provided_expr_val_2_replacement_val, const Expr* expression_to_match); private: - BackwardTransformCloner() = delete; + ReplacementTransformCloner() = delete; - BackwardTransformCloner( - const std::vector& ordered_outputs, + ReplacementTransformCloner( + const std::unordered_map& + expr_to_match_2_replacement, const Expr* expression_to_match); using OptInConstDispatch::handle; @@ -100,7 +108,8 @@ class BackwardTransformCloner : OptInConstDispatch { void handle(const Resize* resize) override; Expr* new_expr_ = nullptr; - const std::vector& output_ids_; + const std::unordered_map& + provided_expr_val_2_replacement_val_; }; // Uses the history of _target_domain, and replays that history using the From ee4e3114d4efedffd54b5269f4c8b6e43bf89913 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 13 Apr 2023 17:26:16 -0400 Subject: [PATCH 013/178] Fix building permissive graph. --- csrc/id_graphs.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index b11658f97b6..bcda351793e 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -2069,7 +2069,6 @@ void IterDomainGraphs::buildAlmostExactMap() { // Build almost exact map by forwarding through broadcast axes idGraph(IdMappingMode::ALMOSTEXACT) = idGraph(IdMappingMode::EXACT); idGraph(IdMappingMode::ALMOSTEXACT).mapThroughTrivialExprs(); - idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); } void IterDomainGraphs::validateAndPropagatePType() const { @@ -2142,6 +2141,11 @@ void IterDomainGraphs::build( buildAlmostExactMap(); std::cout << "buildPermissiveMap" << std::endl; buildPermissiveMap(tv_exprs); + // Permissive graph needs the trivial exprs from the almost exact graph to + // build correctly. Once built though we can remove the trivial expressions + // from the almost exact graph. + idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); + std::cout << "built non lowering graphs" << std::endl; // Only build loop map during lowering From 4249e94e8178c8f14d1ee413ff11bb3d8ce51b51 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Fri, 14 Apr 2023 15:05:18 -0400 Subject: [PATCH 014/178] Small fix, clean up printing. --- csrc/id_graphs.cpp | 72 +++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 40 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index bcda351793e..204bf78a8e9 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -650,12 +650,6 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) terminating_outputs = all_id_groups.subtract(not_outputs); } - std::cout << "Term inp: " - << debug_string::idGroupsStringShort(terminating_inputs) - << std::endl; - std::cout << "Term out: " - << debug_string::idGroupsStringShort(terminating_outputs) - << std::endl; // Track all expressions to get from outputs to this IterDomain. We // traverse backwards as that's the direction of indexing expressions. An // index is assigned to each leaf of a domain and as we traverse backwards @@ -2135,22 +2129,18 @@ void IterDomainGraphs::build( // expressions. idGraph(IdMappingMode::EXACT) = initializeIdGraph(); - std::cout << "buildExactMap" << std::endl; buildExactMap(tv_exprs); - std::cout << "buildAlmostExactMap" << std::endl; buildAlmostExactMap(); - std::cout << "buildPermissiveMap" << std::endl; buildPermissiveMap(tv_exprs); + // Permissive graph needs the trivial exprs from the almost exact graph to // build correctly. Once built though we can remove the trivial expressions // from the almost exact graph. idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); - std::cout << "built non lowering graphs" << std::endl; - // Only build loop map during lowering if (FusionGuard::getCurFusion()->isA()) { - FusionGuard::getCurFusion()->print(); + FusionGuard::getCurFusion()->print(std::cout, true); // Find loops that need to be promoted because of broadcast resolution, // figure out what that resolution should look like, compute IDs for it if // necessary. @@ -2720,8 +2710,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { replay = addReplayAs(promoted_inputs, iel_expr->front()); std::cout << " ***REPLAY***:\n " << iel_expr->front() << " As:" << replay->toString(); - } else { - std::cout << " Matched replay found: " << replay->toString(); } auto out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -2870,10 +2858,17 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; } + std::cout << "Loop promotion:" << std::endl; + for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { + std::cout << debug_string::idGroupStringShort(loop_group) << " -> " + << loop_graph_copy_promotion_map[loop_group]->toString() + << std::endl; + } + // Reset the promotion map for the second pass iel_promotion_map.clear(); - std::cout << "\n\nForward replay iel graph:" << std::endl; + std::cout << "\n\nSecond replay:" << std::endl; IdGraphStmtSort iel_stmt_sort2(intersection_exact_loop_graph); for (auto iel_expr : iel_stmt_sort2.exprs()) { @@ -2973,8 +2968,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { replay = addReplayAs(promoted_inputs, iel_expr->front()); std::cout << " ***REPLAY2***:\n " << iel_expr->front() << " As:" << replay->toString(); - } else { - std::cout << " Matched replay found: " << replay->toString(); } auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -3183,7 +3176,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // *************** STOP: Code copied verbatim from above ******************** - std::cout << "Loop graph copy promotion map: " << std::endl; + std::cout << "Promotion map from concrete id pass: " << std::endl; for (auto group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_graph_copy_promotion_map.find(group) == loop_graph_copy_promotion_map.end()) { @@ -3193,6 +3186,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { << loop_graph_copy_promotion_map.at(group)->toString() << std::endl; } + // Indexing traversal must start at leaf nodes of TensorViews as that's where // the loop indices are defined. For indexing we need to propagate leaves to // root domains. We want the indexing graph easy to traverse. Easy to traverse @@ -3368,6 +3362,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } } + std::cout << "Leaf iter domains that share a promoted iter domain." << std::endl; for (auto disjoint_set : shared_promoted_id.disjointSets()) { @@ -3456,6 +3451,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // just be removed. idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); + std::cout << "\n\nThird and final replay" << std::endl; std::cout << "Building promoted tensor view domains:" << std::endl; // Need to "replay" all of the indexing expressions to make sure roots are // connected to the promoted leaves, in a way we can index directly on the @@ -3475,8 +3471,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { auto promoted_domain = get_promoted_domain(tv->domain()); // replay from root to promoted leaves. std::cout << "\n\n Processing: TV" << tv->name() << "\n Root: TV" - << tv->getRootDomain() << "\n Promoted: " << promoted_domain - << std::endl; + << tv->getRootDomain() + << "\n Domain promoted to: " << promoted_domain << std::endl; // The promoted leaf iter domains are where indexing starts. We're going to // start at those expressions and replay transformations for this tensor @@ -3604,14 +3600,20 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { bool inps_match = true; for (auto inp_id : index_def_inputs) { IterDomain* promoted_inp = nullptr; - auto ae_inp_group = ae_graph.toGroups({inp_id}).front(); - auto promoted_inp_it = ae_group_2_id.find(ae_inp_group); - if (promoted_inp_it == ae_group_2_id.end()) { + std::cout << inp_id->toString() << std::endl; + auto ae_group_pair = ae_graph.disjointIdSet(inp_id); + + std::cout << inp_id->toString() << std::endl; + if (ae_group_pair.second && + ae_group_2_id.find(ae_group_pair.first) != ae_group_2_id.end()) { + promoted_inp = ae_group_2_id.at(ae_group_pair.first); + } else { + // TODO: Should this be here or should we continue below. Check + // Indexing20 test. + // This input is already almost exact mapped, and we don't need this // input to map exactly in the index map. continue; - } else { - promoted_inp = promoted_inp_it->second; } inps_match = inps_match && @@ -3647,19 +3649,9 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - // std::cout << " Replay: " << ae_expr->front(); - // std::cout << " With promoted inputs: " << promoted_inputs - // << std::endl; replay = addExprWithReplacement(replacement_map, ae_expr->front()); - // std::cout << " ***REPLAY3***:\n " << ae_expr->front() - // << " As:" << replay->toString(); std::cout << " ***REPLAY3***:\n " << " " << replay->toString(); - std::cout << debug_string::idGroups(idGraph(IdMappingMode::INDEX)) - << std::endl; - } else { - std::cout << " ***MATCHED3***:\n " - << " " << replay->toString(); } all_index_exprs.pushBack(replay); @@ -3686,11 +3678,11 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - std::cout << "All indexing expressions that need to be processed: " - << std::endl; - for (auto expr : all_index_exprs) { - std::cout << expr->toString(); - } + // std::cout << "All indexing expressions that need to be processed: " + // << std::endl; + // for (auto expr : all_index_exprs) { + // std::cout << expr->toString(); + // } std::cout << "All indexing expressions (on the index graph): " << std::endl; auto index_expr_groups = From 853cd445562af1a9cfc07e9fc43162efc50be09a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 20 Apr 2023 09:46:55 -0400 Subject: [PATCH 015/178] Rework some interfaces around IdGraph, add option to not propagate through expressions for IdGraph, don't use it yet. --- csrc/disjoint_set.h | 9 ++ csrc/id_graphs.cpp | 239 ++++++++++++++++++++++++------------- csrc/id_graphs.h | 52 ++++++-- test/test_gpu_indexing.cpp | 13 +- 4 files changed, 210 insertions(+), 103 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index d376646db85..331b069ed3b 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -57,6 +57,15 @@ class VectorOfUniqueEntries { template VectorOfUniqueEntries(InputIt first, InputIt last) { + pushBack(first, last); + } + + template + VectorOfUniqueEntries(const Container& container) + : VectorOfUniqueEntries(container.begin(), container.end()) {} + + template + void pushBack(InputIt first, InputIt last) { while (first != last) { pushBack(*first++); } diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 204bf78a8e9..3b40a34199c 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -146,8 +146,8 @@ std::string exprGroupStringShort( const IdGraph& id_graph, ExprGroup expr_group) { std::stringstream ss; - auto inputs = id_graph.inputGroups(expr_group); - auto outputs = id_graph.outputGroups(expr_group); + auto inputs = IdGroups(id_graph.inputGroups(expr_group)); + auto outputs = IdGroups(id_graph.outputGroups(expr_group)); ss << idGroupsStringShortInline(inputs) << " -" << exprGroupStringShort(expr_group) << "-> " << idGroupsStringShortInline(outputs); @@ -295,8 +295,8 @@ void IdGraphVisitor::traverse() { if (all_exprs.has(def)) { continue; } - auto inp_groups = graph().inputGroups(def); - auto out_groups = graph().outputGroups(def); + auto inp_groups = IdGroups(graph().inputGroups(def)); + auto out_groups = IdGroups(graph().outputGroups(def)); if (inp_groups.subtract(all_ids).empty() && out_groups.subtract(all_ids).empty()) { all_exprs.pushBack(def); @@ -314,8 +314,8 @@ void IdGraphVisitor::traverse() { IdGroups not_inputs; IdGroups not_outputs; for (auto expr_group : all_exprs) { - auto inp_groups = graph().inputGroups(expr_group); - auto out_groups = graph().outputGroups(expr_group); + auto inp_groups = IdGroups(graph().inputGroups(expr_group)); + auto out_groups = IdGroups(graph().outputGroups(expr_group)); if (inp_groups.intersect(out_groups).size() > 0) { // Expression is just a loop to its current group, ignore @@ -505,13 +505,29 @@ std::pair IdGraph::disjointExprSet(Expr* expr) const { return std::make_pair(disjoint_set_it->second, true); } +ExprGroup IdGraph::toGroup(Expr* expr) const { + auto disjoint_set_pair = disjointExprSet(expr); + TORCH_INTERNAL_ASSERT( + disjoint_set_pair.second, + "\nExpr group could not be found in graph associated with: ", + expr->toString()); + return disjoint_set_pair.first; +} + +IdGroup IdGraph::toGroup(IterDomain* id) const { + auto disjoint_set_pair = disjointIdSet(id); + TORCH_INTERNAL_ASSERT( + disjoint_set_pair.second, + "\nId group could not be found in graph associated with: ", + id->toString(), + "\n"); + return disjoint_set_pair.first; +} + ExprGroups IdGraph::toGroups(const VectorOfUniqueEntries& exprs) const { ExprGroups expr_groups; for (auto expr : exprs) { - auto disjoint_set_pair = disjointExprSet(expr); - if (disjoint_set_pair.second) { - expr_groups.pushBack(disjoint_set_pair.first); - } + expr_groups.pushBack(toGroup(expr)); } return expr_groups; } @@ -520,30 +536,55 @@ IdGroups IdGraph::toGroups( const VectorOfUniqueEntries& ids) const { IdGroups id_groups; for (auto id : ids) { - auto disjoint_set_pair = disjointIdSet(id); - if (disjoint_set_pair.second) { - id_groups.pushBack(disjoint_set_pair.first); - } + id_groups.pushBack(toGroup(id)); } return id_groups; } -IdGroups IdGraph::outputGroups(ExprGroup expr) const { - VectorOfUniqueEntries id_outputs; +std::vector IdGraph::outputGroups(ExprGroup expr) const { + std::vector output_groups; for (auto id_output : ir_utils::filterByType(expr->front()->outputs())) { - id_outputs.pushBack(id_output); + output_groups.push_back(toGroup(id_output)); } - return toGroups(id_outputs); + return output_groups; } -IdGroups IdGraph::inputGroups(ExprGroup expr) const { - VectorOfUniqueEntries id_inputs; +std::vector IdGraph::inputGroups(ExprGroup expr) const { + std::vector input_groups; for (auto id_input : ir_utils::filterByType(expr->front()->inputs())) { - id_inputs.pushBack(id_input); + input_groups.push_back(toGroup(id_input)); + } + return input_groups; +} + +bool IdGraph::groupsMatch( + std::vector id_groups0, + std::vector id_groups1) const { + if (id_groups0.size() != id_groups1.size()) { + return false; } - return toGroups(id_inputs); + for (auto id_g_i : c10::irange(id_groups0.size())) { + if (id_groups0[id_g_i] != id_groups1[id_g_i]) { + return false; + } + } + return true; +} + +bool IdGraph::groupsMatch( + std::vector expr_groups0, + std::vector expr_groups1) const { + if (expr_groups0.size() != expr_groups1.size()) { + return false; + } + for (auto id_g_i : c10::irange(expr_groups0.size())) { + if (expr_groups0[id_g_i] != expr_groups1[id_g_i]) { + return false; + } + } + return true; } ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { @@ -629,7 +670,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) for (auto expr_group : all_exprs) { auto inp_groups = inputGroups(expr_group); auto out_groups = outputGroups(expr_group); - if (inp_groups.intersect(out_groups).size() > 0) { + if (IdGroups(inp_groups).intersect(IdGroups(out_groups)).size() > 0) { // Expression is just a loop to its current group, ignore continue; } @@ -1129,34 +1170,18 @@ void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { return; } - // TODO: make these class functions for convenience, there are too many - // asserts in this file. - auto assert_get_expr_group = [&](Expr* expr) { - auto expr_group_pair = disjointExprSet(expr); - TORCH_INTERNAL_ASSERT( - expr_group_pair.second, "Could not find entry for expression: ", expr); - return expr_group_pair.first; - }; - - auto assert_get_id_group = [&](IterDomain* id) { - auto id_group_pair = disjointIdSet(id); - TORCH_INTERNAL_ASSERT( - id_group_pair.second, "Could not find entry for IterDomain: ", id); - return id_group_pair.first; - }; - - ExprGroup expr0_orig_group = assert_get_expr_group(expr0); - ExprGroup expr1_orig_group = assert_get_expr_group(expr1); + ExprGroup expr0_orig_group = toGroup(expr0); + ExprGroup expr1_orig_group = toGroup(expr1); disjointExprSets().mapEntries(expr0, expr1); - auto expr_new_group = assert_get_expr_group(expr0); + auto expr_new_group = toGroup(expr0); // Update unique uses of producers IdGroups producers; for (auto expr : std::vector{expr0, expr1}) { for (auto input_id : ir_utils::filterByType(expr->inputs())) { - producers.pushBack(assert_get_id_group(input_id)); + producers.pushBack(toGroup(input_id)); } } @@ -1170,7 +1195,7 @@ void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { IdGroups consumers; for (auto expr : std::vector{expr0, expr1}) { for (auto output_id : ir_utils::filterByType(expr->outputs())) { - consumers.pushBack(assert_get_id_group(output_id)); + consumers.pushBack(toGroup(output_id)); } } @@ -1225,8 +1250,17 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { auto use0 = use_group_0->front(); auto use1 = use_group_1->front(); if (exprsMap(use0, use1, true)) { - mapExprs(use0, use1); - mapThroughExpr(use0, use1, true); + if (propagate_exprs_) { + mapExprs(use0, use1); + mapThroughExpr(use0, use1, true); + } else if ((groupsMatch( + inputGroups(toGroup(use0)), + inputGroups(toGroup(use1))) && + groupsMatch( + outputGroups(toGroup(use0)), + outputGroups(toGroup(use1))))) { + mapExprs(use0, use1); + } } } } @@ -1245,8 +1279,17 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { auto def0 = def_group_0->front(); auto def1 = def_group_1->front(); if (exprsMap(def0, def1, false)) { - mapExprs(def0, def1); - mapThroughExpr(def0, def1, false); + if (propagate_exprs_) { + mapExprs(def0, def1); + mapThroughExpr(def0, def1, false); + } else if ((groupsMatch( + inputGroups(toGroup(def0)), + inputGroups(toGroup(def1))) && + groupsMatch( + outputGroups(toGroup(def0)), + outputGroups(toGroup(def1))))) { + mapExprs(def0, def1); + } } } } @@ -1327,7 +1370,7 @@ void IdGraph::removeTrivialExprs() { for (auto expr_group : disjointExprSets().disjointSets()) { auto inp_groups = inputGroups(expr_group); auto out_groups = outputGroups(expr_group); - if (inp_groups.intersect(out_groups).size()) { + if (IdGroups(inp_groups).intersect(IdGroups(out_groups)).size()) { trivial_expr_groups.pushBack(expr_group); } } @@ -2025,6 +2068,9 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { ForwardingInfo permissive_forwarding(p_tv, c_tv); for (auto entry : permissive_forwarding.producer_forwarding_map) { + std::cout << "Permissive producer forwarding: " + << entry.first->toString() << " -> " + << entry.second->toString() << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); } @@ -2032,11 +2078,17 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { + std::cout << "Permissive producer compliment: " + << entry.first->toString() << " -> " << entry_2->toString() + << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); } } for (auto entry : permissive_forwarding.consumer_forwarding_map) { + std::cout << "Permissive consumer forwarding: " + << entry.first->toString() << " -> " + << entry.second->toString() << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); } @@ -2044,6 +2096,9 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.consumer_compliment_map) { for (auto entry_2 : entry.second) { + std::cout << "Permissive consumer compliment: " + << entry.first->toString() << " -> " << entry_2->toString() + << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); } } @@ -2201,7 +2256,7 @@ std::unordered_map IterDomainGraphs:: for (auto def_group : definition_pair_it.first) { auto inp_groups = idGraph(IdMappingMode::ALMOSTEXACT).inputGroups(def_group); - producer_groups.pushBack(inp_groups); + producer_groups.pushBack(inp_groups.begin(), inp_groups.end()); } return producer_groups; }; @@ -2330,6 +2385,7 @@ std::unordered_map IterDomainGraphs:: void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { idGraph(IdMappingMode::LOOP) = initializeIdGraph(); + // idGraph(IdMappingMode::LOOP).disableExprPropagation(); std::unordered_map> p2c_root_broadcast_resolution_map; @@ -2431,10 +2487,16 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } auto c_ids = entry_it->second; for (auto c_id : c_ids) { + std::cout << "Map: " << p_id->toString() << " <-> " << c_id->toString() + << std::endl; idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); } } + std::cout << "Loop groups: " + << debug_string::idGroups(idGraph(IdMappingMode::LOOP)) + << std::endl; + // Terminal loop ids are iteration domains in each loop group that: // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a // consumer TV's iter domain maps to this domain in a way that that domain @@ -2645,13 +2707,14 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::cout << "Initial promotion replay:" << std::endl; for (auto iel_expr : iel_stmt_sort.exprs()) { - IdGroups input_groups = intersection_exact_loop_graph.inputGroups(iel_expr); - + std::cout << "a" << std::endl; + auto input_groups = intersection_exact_loop_graph.inputGroups(iel_expr); + std::cout << "b" << std::endl; // Check if any inputs need promotion indicating this expr group needs to // be replayed with promoted inputs std::vector promoted_inputs; bool an_input_was_promoted = false; - + std::cout << "c" << std::endl; for (auto inp : input_groups) { auto inp_promo_it = iel_promotion_map.find(inp); if (inp_promo_it == iel_promotion_map.end()) { @@ -2661,18 +2724,24 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { an_input_was_promoted = true; } } - + std::cout << "d" << std::endl; if (!an_input_was_promoted) { // No inputs need promotion so just continue continue; } - + std::cout << "e" << std::endl; Expr* replay = nullptr; - auto promoted_input_groups = intersection_exact_loop_graph.toGroups( - VectorOfUniqueEntries{ - promoted_inputs.begin(), promoted_inputs.end()}); + IdGroups promoted_input_groups; + for (auto inp_id : promoted_inputs) { + auto inp_disjoint_set_pair = + intersection_exact_loop_graph.disjointIdSet(inp_id); + if (inp_disjoint_set_pair.second) { + promoted_input_groups.pushBack(inp_disjoint_set_pair.first); + } + } + std::cout << "f" << std::endl; // Before replaying, check if there's already an expression like this, if so // use that for promotion. We would need the iel entries for non-promoted // inputs to match exactly to reuse the expression. @@ -2687,7 +2756,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { non_promoted_input_uses.pushBack( intersection_exact_loop_graph.uniqueUses(iel_group)); } - + std::cout << "g" << std::endl; for (auto iel_use_group : non_promoted_input_uses) { if (transformAtributesMatch(iel_expr->front(), iel_use_group->front())) { auto use_inps = @@ -2721,10 +2790,10 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { TORCH_INTERNAL_ASSERT(replay_out_ids.size() == out_groups.size()); for (auto i : c10::irange(replay_out_ids.size())) { - iel_promotion_map[out_groups.vector()[i]] = replay_out_ids[i]; + iel_promotion_map[out_groups[i]] = replay_out_ids[i]; } } - + std::cout << "post" << std::endl; // Map from an exact iter domain group, to all the exact iter domain groups it // covers std::unordered_map exact_covered_ids; @@ -2754,7 +2823,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGraphStmtSort exact_stmt_sort(idGraph(IdMappingMode::EXACT)); for (auto exact_expr : exact_stmt_sort.exprs()) { - auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); + auto input_groups = + idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); IdGroups covered; for (auto inp_group : input_groups) { @@ -2872,9 +2942,11 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGraphStmtSort iel_stmt_sort2(intersection_exact_loop_graph); for (auto iel_expr : iel_stmt_sort2.exprs()) { - auto iel_inp_groups = intersection_exact_loop_graph.inputGroups(iel_expr); + auto iel_inp_groups = + intersection_exact_loop_graph.inputGroups(iel_expr); - auto iel_out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); + auto iel_out_groups = + intersection_exact_loop_graph.outputGroups(iel_expr); // When replaying the transformations a second time we want to take loop // promotion into consideration. However, we don't want to blindly apply @@ -2888,14 +2960,12 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGroups inp_loop_groups; for (auto iel_inp_group : iel_inp_groups) { - inp_loop_groups.pushBack( - loop_graph_copy.toGroups({iel_inp_group->front()}).front()); + inp_loop_groups.pushBack(loop_graph_copy.toGroup(iel_inp_group->front())); } IdGroups out_loop_groups; for (auto iel_out_group : iel_out_groups) { - out_loop_groups.pushBack( - loop_graph_copy.toGroups({iel_out_group->front()}).front()); + out_loop_groups.pushBack(loop_graph_copy.toGroup(iel_out_group->front())); } bool loop_promote_inputs = @@ -2908,8 +2978,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Promote inputs for replay for (auto iel_inp_group : iel_inp_groups) { // Prefer loop promotion - auto loop_copy_group = - loop_graph_copy.toGroups({iel_inp_group->front()}).front(); + auto loop_copy_group = loop_graph_copy.toGroup(iel_inp_group->front()); auto inp_loop_promo_it = loop_graph_copy_promotion_map.find(loop_copy_group); if (loop_promote_inputs && @@ -2937,8 +3006,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // use that for promotion. ExprGroups promoted_input_uses; for (auto inp_id : promoted_inputs) { - auto inp_exact_group = - idGraph(IdMappingMode::EXACT).toGroups({inp_id}).front(); + auto inp_exact_group = idGraph(IdMappingMode::EXACT).toGroup(inp_id); promoted_input_uses.pushBack( idGraph(IdMappingMode::EXACT).uniqueUses(inp_exact_group)); } @@ -2970,7 +3038,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { << " As:" << replay->toString(); } - auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); + auto output_groups = + intersection_exact_loop_graph.outputGroups(iel_expr); // Mark outputs as having a promoted iter domain auto replay_out_ids = @@ -2981,9 +3050,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { for (auto i : c10::irange(replay_out_ids.size())) { if (!idGraph(IdMappingMode::EXACT) .disjointIdSets() - .strictAreMapped( - replay_out_ids[i], output_groups.vector()[i]->front())) { - iel_promotion_map[output_groups.vector()[i]] = replay_out_ids[i]; + .strictAreMapped(replay_out_ids[i], output_groups[i]->front())) { + iel_promotion_map[output_groups[i]] = replay_out_ids[i]; } } } @@ -3039,7 +3107,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGraphStmtSort exact_stmt_sort2(idGraph(IdMappingMode::EXACT)); for (auto exact_expr : exact_stmt_sort2.exprs()) { - auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); + auto input_groups = + idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); IdGroups covered; for (auto inp_group : input_groups) { @@ -3505,8 +3574,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { for (auto tv_id : all_ids) { // Use emplace here as it multiple tv_ids could map to the same ae_group. // Emplace will simply grab the first one that appears. - ae_group_2_id.emplace( - std::make_pair(ae_graph.toGroups({tv_id}).front(), tv_id)); + ae_group_2_id.emplace(std::make_pair(ae_graph.toGroup(tv_id), tv_id)); } auto ae_leaf_groups = ae_graph.toGroups(VectorOfUniqueEntries{ @@ -3559,8 +3627,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // so use that for promotion. ExprGroups promoted_output_defs; for (auto out_id : promoted_outputs) { - auto index_group = - idGraph(IdMappingMode::INDEX).toGroups({out_id}).front(); + auto index_group = idGraph(IdMappingMode::INDEX).toGroup(out_id); promoted_output_defs.pushBack( idGraph(IdMappingMode::INDEX).uniqueDefinitions(index_group)); } @@ -3600,10 +3667,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { bool inps_match = true; for (auto inp_id : index_def_inputs) { IterDomain* promoted_inp = nullptr; - std::cout << inp_id->toString() << std::endl; auto ae_group_pair = ae_graph.disjointIdSet(inp_id); - - std::cout << inp_id->toString() << std::endl; if (ae_group_pair.second && ae_group_2_id.find(ae_group_pair.first) != ae_group_2_id.end()) { promoted_inp = ae_group_2_id.at(ae_group_pair.first); @@ -3640,7 +3704,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::unordered_map replacement_map; for (auto id : ae_inps_outs) { - auto ae_group = ae_graph.toGroups({id}).front(); + auto ae_group = ae_graph.toGroup(id); auto promoted_it = ae_group_2_id.find(ae_group); if (promoted_it == ae_group_2_id.end()) { replacement_map[id] = id->cloneWithoutRFactor(); @@ -3652,6 +3716,9 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { replay = addExprWithReplacement(replacement_map, ae_expr->front()); std::cout << " ***REPLAY3***:\n " << " " << replay->toString(); + } else { + std::cout << " ***MATCH3***:\n " + << " " << replay->toString(); } all_index_exprs.pushBack(replay); @@ -3671,7 +3738,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { TORCH_INTERNAL_ASSERT(ae_inps.size() == replay_inps.size()); for (auto inp_i : c10::irange(ae_inps.size())) { - auto ae_group = ae_graph.toGroups({ae_inps[inp_i]}).front(); + auto ae_group = ae_graph.toGroup(ae_inps[inp_i]); // Only replace if entry does not exist. ae_group_2_id.emplace(std::make_pair(ae_group, replay_inps[inp_i])); } diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index 48a3241e726..cd906b2d140 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -46,11 +46,11 @@ class TORCH_CUDA_CU_API IdGraph { // Same as getDisjointIdSet but for the Expression sets. std::pair disjointExprSet(Expr* expr) const; - // TODO: Audit usage of toGroups: - // Being used when only a single expr or id, break that into a separate - // function. - // There may be an assumption that the size of incoming vector is same as - // output, that is not the case. + // Convert expr to its exprGroup, assert that it exists. + ExprGroup toGroup(Expr* expr) const; + + // Convert iter domain to its IdGroup, assert that it exists. + IdGroup toGroup(IterDomain* id) const; // Convert unique vector of expressions to unique vector of its groups ExprGroups toGroups(const VectorOfUniqueEntries& exprs) const; @@ -58,11 +58,21 @@ class TORCH_CUDA_CU_API IdGraph { // Convert unique vector of IterDomain to unique vector of its groups IdGroups toGroups(const VectorOfUniqueEntries& ids) const; - // Return output iter domain groups of provided expr - IdGroups outputGroups(ExprGroup expr) const; + // Return output/input iter domain groups of provided expr + std::vector outputGroups(ExprGroup expr) const; + std::vector inputGroups(ExprGroup expr) const; + + // Returns if for each group in id_groups0 is the same as all groups in + // id_groups1. Requires size and order to be exact. + bool groupsMatch( + std::vector id_groups0, + std::vector id_groups1) const; - // Return input iter domain groups of provided expr - IdGroups inputGroups(ExprGroup expr) const; + // Returns if for each group in expr_groups0 is the same as all groups in + // expr_groups1. Requires size and order to be exact. + bool groupsMatch( + std::vector expr_groups0, + std::vector expr_groups1) const; // Traverses uses of the IdGroups in 'of' and returns all ExprGroups // that have a use in their definition of provided of IdGroups. @@ -179,11 +189,31 @@ class TORCH_CUDA_CU_API IdGraph { // mappings from IdGraph::isTrivialExpr void removeTrivialExprs(); + // See comment on propagate_expr_ member bool for description + void enableExprPropagation() { + propagate_exprs_ = true; + } + // See comment on propagate_expr_ member bool for description + void disableExprPropagation() { + propagate_exprs_ = false; + } + private: // Removes the provided expression group from unique_definitions_ and // unique_uses_ breaking traversal through them. void eraseExprGroup(ExprGroup expr_group); + // If propagate_exprs_ = false, then mapThroughExpr will not be called as a + // consequence of calling mapIds. As well as mapThroughExpr will not be called + // (again) as a result of calling mapThroughExpr. + // + // Note: For the second sentence of above... mapThroughExpr can call mapIds + // which could in return call mapThoughExpr again, but propagate_exprs_ as + // mentioned above prevents that from happening. + // + // TODO: Should propagate_exprs_ be a const member? + bool propagate_exprs_ = true; + // Keeps a disjoint set entry for all IterDomain for all mapping mode types. // // Using an array here might be nice, but it seems hard to use an enum as an @@ -236,10 +266,10 @@ class TORCH_CUDA_CU_API IdGraphVisitor { IdGraphVisitor() = delete; IdGraphVisitor(const IdGraphVisitor& other) = default; - IdGraphVisitor& operator=(const IdGraphVisitor& other) = default; + IdGraphVisitor& operator=(const IdGraphVisitor& other) = delete; IdGraphVisitor(IdGraphVisitor&& other) = default; - IdGraphVisitor& operator=(IdGraphVisitor&& other) = default; + IdGraphVisitor& operator=(IdGraphVisitor&& other) = delete; virtual ~IdGraphVisitor() = default; diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index eca33e95df0..f4b77f2fa7d 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -909,9 +909,9 @@ TEST_F(NVFuserTest, FusionIndexing20_CUDA) { // [3, 5, 7] fusion.addOutput(tv7); - tv4->merge(0)->split(0, 3, false); + tv4->merge(0)->split(0, 2, false); // [3, 5] - // [3, 3*5/3] + // [3, 3*5//2] TransformPropagatorWithCheck propagator(tv4); MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); @@ -922,10 +922,11 @@ TEST_F(NVFuserTest, FusionIndexing20_CUDA) { tv2->inlineAt(1); tv4->inlineAt(1); - tv5->merge(1)->split(1, 5, false); - // [3, 3*5/3, 7] - tv7->merge(1)->split(1, 5, false); - // [3, 5, (3*5/3)*7/5] + // [2, 3*5//2] + tv5->merge(1)->split(1, 4, false); + // [2, 4, (3*5//2)*1//4] + tv7->merge(1)->split(1, 4, false); + // [2, 4, (3*5//2)*7//4] tv5->inlineAt(2); fusion.printKernel(); From 8d571a6dd2d19909f521921d65d49790d1bf6af9 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 20 Apr 2023 09:55:01 -0400 Subject: [PATCH 016/178] Cleanup debug print. --- csrc/id_graphs.cpp | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 3b40a34199c..44b10a46e6d 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -2707,14 +2707,13 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::cout << "Initial promotion replay:" << std::endl; for (auto iel_expr : iel_stmt_sort.exprs()) { - std::cout << "a" << std::endl; auto input_groups = intersection_exact_loop_graph.inputGroups(iel_expr); - std::cout << "b" << std::endl; + // Check if any inputs need promotion indicating this expr group needs to // be replayed with promoted inputs std::vector promoted_inputs; bool an_input_was_promoted = false; - std::cout << "c" << std::endl; + for (auto inp : input_groups) { auto inp_promo_it = iel_promotion_map.find(inp); if (inp_promo_it == iel_promotion_map.end()) { @@ -2724,12 +2723,12 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { an_input_was_promoted = true; } } - std::cout << "d" << std::endl; + if (!an_input_was_promoted) { // No inputs need promotion so just continue continue; } - std::cout << "e" << std::endl; + Expr* replay = nullptr; IdGroups promoted_input_groups; @@ -2741,7 +2740,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - std::cout << "f" << std::endl; // Before replaying, check if there's already an expression like this, if so // use that for promotion. We would need the iel entries for non-promoted // inputs to match exactly to reuse the expression. @@ -2756,7 +2754,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { non_promoted_input_uses.pushBack( intersection_exact_loop_graph.uniqueUses(iel_group)); } - std::cout << "g" << std::endl; + for (auto iel_use_group : non_promoted_input_uses) { if (transformAtributesMatch(iel_expr->front(), iel_use_group->front())) { auto use_inps = @@ -2793,7 +2791,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { iel_promotion_map[out_groups[i]] = replay_out_ids[i]; } } - std::cout << "post" << std::endl; // Map from an exact iter domain group, to all the exact iter domain groups it // covers std::unordered_map exact_covered_ids; @@ -2823,8 +2820,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGraphStmtSort exact_stmt_sort(idGraph(IdMappingMode::EXACT)); for (auto exact_expr : exact_stmt_sort.exprs()) { - auto input_groups = - idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); + auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); IdGroups covered; for (auto inp_group : input_groups) { @@ -2942,11 +2938,9 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGraphStmtSort iel_stmt_sort2(intersection_exact_loop_graph); for (auto iel_expr : iel_stmt_sort2.exprs()) { - auto iel_inp_groups = - intersection_exact_loop_graph.inputGroups(iel_expr); + auto iel_inp_groups = intersection_exact_loop_graph.inputGroups(iel_expr); - auto iel_out_groups = - intersection_exact_loop_graph.outputGroups(iel_expr); + auto iel_out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); // When replaying the transformations a second time we want to take loop // promotion into consideration. However, we don't want to blindly apply @@ -3038,8 +3032,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { << " As:" << replay->toString(); } - auto output_groups = - intersection_exact_loop_graph.outputGroups(iel_expr); + auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); // Mark outputs as having a promoted iter domain auto replay_out_ids = @@ -3107,8 +3100,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGraphStmtSort exact_stmt_sort2(idGraph(IdMappingMode::EXACT)); for (auto exact_expr : exact_stmt_sort2.exprs()) { - auto input_groups = - idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); + auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); IdGroups covered; for (auto inp_group : input_groups) { From 887945967a6fe71dddce8343d758b5519c2b0530 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 20 Apr 2023 10:35:07 -0400 Subject: [PATCH 017/178] Improve accuracy of loop grouping. --- csrc/id_graphs.cpp | 122 +++++++++++++++++++++++++++------------------ 1 file changed, 74 insertions(+), 48 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 44b10a46e6d..968a51dc9be 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -2385,7 +2385,9 @@ std::unordered_map IterDomainGraphs:: void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { idGraph(IdMappingMode::LOOP) = initializeIdGraph(); - // idGraph(IdMappingMode::LOOP).disableExprPropagation(); + // See Indexing20 example for why we shouldn't propagate when generating loop + // groups + idGraph(IdMappingMode::LOOP).disableExprPropagation(); std::unordered_map> p2c_root_broadcast_resolution_map; @@ -2395,36 +2397,14 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::unordered_map> p2c_ca_permissive_maps; - VectorOfUniqueEntries ordered_p_ca_ids; + // Tracks all p2c mappings in permissive maps even those not inlined between + // producer and consumer + std::unordered_map> + p2c_permissive_maps; - auto accumulateInMap = - [](std::unordered_map>& - map, - IterDomain* key, - IterDomain* new_value) { - auto entry_it = map.find(key); - if (map.find(key) == map.end()) { - map[key] = {new_value}; - } else { - auto& value = entry_it->second; - value.pushBack(new_value); - } - }; - - auto accumulateInMapVec = - [](std::unordered_map>& - map, - IterDomain* key, - const VectorOfUniqueEntries& new_values) { - auto entry_it = map.find(key); - if (map.find(key) == map.end()) { - map[key] = new_values; - } else { - auto& value = entry_it->second; - value.pushBack(new_values); - } - }; + VectorOfUniqueEntries ordered_p_ca_ids; + // Grab inlining relationships for (auto expr : exprs) { for (auto producer : ir_utils::filterByType(expr->inputs())) { auto producer_root = producer->getMaybeRFactorDomain(); @@ -2450,46 +2430,92 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { ir_utils::filterByType(expr->outputs())) { auto resolved_bcast_map = resolvedRootBroadcasts(producer, consumer); for (auto entry : resolved_bcast_map) { - accumulateInMap( - p2c_root_broadcast_resolution_map, entry.first, entry.second); + p2c_root_broadcast_resolution_map[entry.first].pushBack(entry.second); for (auto other_exact_bcast : *idGraph(IdMappingMode::EXACT) .disjointIdSet(entry.first) .first) { if (all_producer_ca_deps.has(other_exact_bcast)) { - accumulateInMap( - p2c_root_broadcast_resolution_map, - other_exact_bcast, + p2c_root_broadcast_resolution_map[other_exact_bcast].pushBack( entry.second); } } } - auto p2c_ca_permissive_map = idGraph(IdMappingMode::PERMISSIVE) - .buildMapBetween( - all_producer_ca_deps.vector(), - ir_utils::allIDsOf(consumer)); + auto all_consumer_ids = ir_utils::allIDsOf(consumer); + auto all_producer_ids = ir_utils::allIDsOf(producer); + + auto p2c_permissive_map = + idGraph(IdMappingMode::PERMISSIVE) + .buildMapBetween(all_producer_ids, all_consumer_ids); - for (auto entry : p2c_ca_permissive_map) { + for (auto entry : p2c_permissive_map) { if (entry.second.size() == 0) { continue; } - accumulateInMapVec(p2c_ca_permissive_maps, entry.first, entry.second); + if (all_producer_ca_deps.has(entry.first)) { + p2c_ca_permissive_maps[entry.first].pushBack(entry.second); + } + p2c_permissive_maps[entry.first].pushBack(entry.second); + } + + for (auto entry : p2c_permissive_map) { + if (entry.second.size() == 0) { + continue; + } + p2c_permissive_maps[entry.first].pushBack(entry.second); } } } } - // Make sure this is called in a deterministic order + // Make sure this is called in a deterministic order. Build all inlined + // relationships in loop graph. for (auto p_id : ordered_p_ca_ids) { auto entry_it = p2c_ca_permissive_maps.find(p_id); - if (entry_it == p2c_ca_permissive_maps.end()) { - continue; + if (entry_it != p2c_ca_permissive_maps.end()) { + auto c_ids = entry_it->second; + for (auto c_id : c_ids) { + std::cout << "Map: " << p_id->toString() << " <-> " << c_id->toString() + << std::endl; + idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); + } } - auto c_ids = entry_it->second; - for (auto c_id : c_ids) { - std::cout << "Map: " << p_id->toString() << " <-> " << c_id->toString() - << std::endl; - idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); + } + + // Opportunistically add loop relationships where they don't interfere with + // the loop groups. + for (auto p_id : ordered_p_ca_ids) { + auto entry_it = p2c_permissive_maps.find(p_id); + if (entry_it != p2c_permissive_maps.end()) { + auto c_ids = entry_it->second; + for (auto c_id : c_ids) { + if (idGraph(IdMappingMode::LOOP) + .disjointIdSets() + .permissiveAreMapped(p_id, c_id)) { + // Already mapped + continue; + } + // Grab all iter domains already in the loop groups for both iter + // domains. + auto loop_groups = + idGraph(IdMappingMode::LOOP) + .toGroups(VectorOfUniqueEntries{p_id, c_id}); + VectorOfUniqueEntries all_ids_in_groups; + for (auto loop_group : loop_groups) { + all_ids_in_groups.pushBack(*loop_group); + } + + // Grab the almost exact map of all iter domains in those loop groups + auto ae_groups = + idGraph(IdMappingMode::ALMOSTEXACT).toGroups(all_ids_in_groups); + // If there's no broadcast promotion within the loop group then all the + // iter domains will be almost exact mapped with eachother. + if (ae_groups.size() == 1) { + idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); + std::cout << "Map2: " << p_id->toString() << " <-> " + << c_id->toString() << std::endl; + } + } } } From a3a86fdaf7146ddaafb0a65d7f74b56d0c0eef38 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 25 Apr 2023 23:33:13 -0400 Subject: [PATCH 018/178] Four major tests working, minimal index map. --- csrc/id_graphs.cpp | 854 ++++++++++++++++++++++++++------------------- csrc/id_graphs.h | 18 +- 2 files changed, 501 insertions(+), 371 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 968a51dc9be..f69f3182dc8 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -1161,6 +1161,21 @@ ExprGroups IdGraph::uniqueUses(IdGroup group) const { return unique_uses_it->second; } +void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { + if (exprsMap(expr0, expr1, forward)) { + if (propagate_exprs_) { + mapExprs(expr0, expr1); + mapThroughExpr(expr0, expr1, forward); + } else if ((groupsMatch( + inputGroups(toGroup(expr0)), inputGroups(toGroup(expr1))) && + groupsMatch( + outputGroups(toGroup(expr0)), + outputGroups(toGroup(expr1))))) { + mapExprs(expr0, expr1); + } + } +} + void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { if (expr0 == expr1) { return; @@ -1249,19 +1264,7 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { for (auto use_group_0 : orig_uses0) { auto use0 = use_group_0->front(); auto use1 = use_group_1->front(); - if (exprsMap(use0, use1, true)) { - if (propagate_exprs_) { - mapExprs(use0, use1); - mapThroughExpr(use0, use1, true); - } else if ((groupsMatch( - inputGroups(toGroup(use0)), - inputGroups(toGroup(use1))) && - groupsMatch( - outputGroups(toGroup(use0)), - outputGroups(toGroup(use1))))) { - mapExprs(use0, use1); - } - } + maybeMapThroughExprs(use0, use1, true); } } } @@ -1278,19 +1281,7 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { for (auto def_group_0 : orig_defs0) { auto def0 = def_group_0->front(); auto def1 = def_group_1->front(); - if (exprsMap(def0, def1, false)) { - if (propagate_exprs_) { - mapExprs(def0, def1); - mapThroughExpr(def0, def1, false); - } else if ((groupsMatch( - inputGroups(toGroup(def0)), - inputGroups(toGroup(def1))) && - groupsMatch( - outputGroups(toGroup(def0)), - outputGroups(toGroup(def1))))) { - mapExprs(def0, def1); - } - } + maybeMapThroughExprs(def0, def1, false); } } } @@ -1306,6 +1297,10 @@ bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return false; } + TORCH_INTERNAL_ASSERT( + propagate_exprs_, + "Asked to propagate expression mappings on a graph that has propagate_exprs_ disabled."); + auto first_ids = ir_utils::filterByType( forward ? first->outputs() : first->inputs()) .vector(); @@ -1762,10 +1757,7 @@ Expr* IterDomainGraphs::addReplayAs( } for (auto rep_use : representative_uses) { - if (graph.exprsMap(rep_use, replay, true)) { - graph.mapExprs(rep_use, replay); - graph.mapThroughExpr(rep_use, replay, true); - } + graph.maybeMapThroughExprs(rep_use, replay, true); } } @@ -1866,33 +1858,36 @@ Expr* IterDomainGraphs::addExprWithReplacement( id_uses_[inp_id].pushBack(replay); } + // TODO: Update comments // Initialize output iter domains in the graphs for (auto mode : initialized_modes) { - idGraph(mode).disjointExprSets().initializeSet(replay); - auto replay_group = idGraph(mode).disjointExprSet(replay).first; + auto& graph = idGraph(mode); + + graph.disjointExprSets().initializeSet(replay); + auto replay_group = graph.disjointExprSet(replay).first; for (auto inp_id : ir_utils::filterByType(replay->inputs())) { - if (!idGraph(mode).disjointIdSets().mappingExists(inp_id)) { + if (!graph.disjointIdSets().mappingExists(inp_id)) { // inp_id is not initialized in the map, initialize it - idGraph(mode).initializeId(inp_id, {}, {replay}); + graph.initializeId(inp_id, {}, {replay}); } else { // inp_id is already initialized add the replay as a unique use of its // group. - auto inp_group = idGraph(mode).disjointIdSet(inp_id).first; - idGraph(mode).uniqueUses()[inp_group].pushBack(replay_group); + auto inp_group = graph.disjointIdSet(inp_id).first; + graph.uniqueUses()[inp_group].pushBack(replay_group); } } // Update definitions in the graph of the outputs for (auto out_id : ir_utils::filterByType(replay->outputs())) { - if (!idGraph(mode).disjointIdSets().mappingExists(out_id)) { + if (!graph.disjointIdSets().mappingExists(out_id)) { // out_id is not initialized in the map, initialize it - idGraph(mode).initializeId(out_id, {replay}, {}); + graph.initializeId(out_id, {replay}, {}); } else { // out_id is already initialized, add the replay as a unique definition // of its group - auto out_group = idGraph(mode).disjointIdSet(out_id).first; - idGraph(mode).uniqueDefinitions().at(out_group).pushBack(replay_group); + auto out_group = graph.disjointIdSet(out_id).first; + graph.uniqueDefinitions()[out_group].pushBack(replay_group); } } @@ -1900,51 +1895,41 @@ Expr* IterDomainGraphs::addExprWithReplacement( // already exist in the graphs. If the inputs were replaced we want to // replay forward through the newly added expression. If the outputs were // replaced we want to replay backwards (towards inputs) instead. - auto& graph = idGraph(mode); - // Gather all use expressions from inputs - if (all_inps_replaced) { - VectorOfUniqueEntries representative_uses; - for (auto in : ir_utils::filterByType(replay->inputs())) { - auto uses_pair = - graph.iterDomainGroupUses(graph.disjointIdSet(in).first); - if (uses_pair.second) { - for (auto def_group : uses_pair.first) { - representative_uses.pushBack(def_group->front()); + VectorOfUniqueEntries representative_uses; + for (auto in : ir_utils::filterByType(replay->inputs())) { + auto uses_pair = graph.iterDomainGroupUses(graph.disjointIdSet(in).first); + if (uses_pair.second) { + for (auto use_group : uses_pair.first) { + if (use_group == replay_group) { + continue; } + representative_uses.pushBack(use_group->front()); } } + } - representative_uses.erase(replay); - - for (auto rep_use : representative_uses) { - if (graph.exprsMap(rep_use, replay, true)) { - graph.mapExprs(rep_use, replay); - graph.mapThroughExpr(rep_use, replay, true); - } - } - - if (all_outs_replaced) { - VectorOfUniqueEntries representative_defs; - for (auto out : ir_utils::filterByType(replay->outputs())) { - auto defs_pair = - graph.iterDomainGroupDefinitions(graph.disjointIdSet(out).first); - if (defs_pair.second) { - for (auto def_group : defs_pair.first) { - representative_defs.pushBack(def_group->front()); - } - } - } - representative_defs.erase(replay); + for (auto rep_use : representative_uses) { + graph.maybeMapThroughExprs(rep_use, replay, true); + } - for (auto rep_def : representative_defs) { - if (graph.exprsMap(rep_def, replay, false)) { - graph.mapExprs(rep_def, replay); - graph.mapThroughExpr(rep_def, replay, false); + VectorOfUniqueEntries representative_defs; + for (auto out : ir_utils::filterByType(replay->outputs())) { + auto defs_pair = + graph.iterDomainGroupDefinitions(graph.disjointIdSet(out).first); + if (defs_pair.second) { + for (auto def_group : defs_pair.first) { + if (def_group == replay_group) { + continue; } + representative_defs.pushBack(def_group->front()); } } } + + for (auto rep_def : representative_defs) { + graph.maybeMapThroughExprs(rep_def, replay, false); + } } return replay; } @@ -2397,13 +2382,17 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::unordered_map> p2c_ca_permissive_maps; + // All producer ids in a deterministic order + VectorOfUniqueEntries ordered_p_ca_ids; + + // All ids in a deterministic order + VectorOfUniqueEntries ordered_c_ids; + // Tracks all p2c mappings in permissive maps even those not inlined between // producer and consumer std::unordered_map> p2c_permissive_maps; - VectorOfUniqueEntries ordered_p_ca_ids; - // Grab inlining relationships for (auto expr : exprs) { for (auto producer : ir_utils::filterByType(expr->inputs())) { @@ -2423,6 +2412,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { all_producer_ca_deps.insert( ca_deps_filter.begin(), ca_deps_filter.end()); } + std::cout << "Producer: " << producer->toString() << "\n " + << all_producer_ca_deps.toString() << std::endl; ordered_p_ca_ids.pushBack(all_producer_ca_deps); @@ -2441,8 +2432,9 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - auto all_consumer_ids = ir_utils::allIDsOf(consumer); auto all_producer_ids = ir_utils::allIDsOf(producer); + auto all_consumer_ids = ir_utils::allIDsOf(consumer); + ordered_c_ids.pushBack(all_consumer_ids); auto p2c_permissive_map = idGraph(IdMappingMode::PERMISSIVE) @@ -2475,50 +2467,11 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (entry_it != p2c_ca_permissive_maps.end()) { auto c_ids = entry_it->second; for (auto c_id : c_ids) { - std::cout << "Map: " << p_id->toString() << " <-> " << c_id->toString() - << std::endl; idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); } } } - // Opportunistically add loop relationships where they don't interfere with - // the loop groups. - for (auto p_id : ordered_p_ca_ids) { - auto entry_it = p2c_permissive_maps.find(p_id); - if (entry_it != p2c_permissive_maps.end()) { - auto c_ids = entry_it->second; - for (auto c_id : c_ids) { - if (idGraph(IdMappingMode::LOOP) - .disjointIdSets() - .permissiveAreMapped(p_id, c_id)) { - // Already mapped - continue; - } - // Grab all iter domains already in the loop groups for both iter - // domains. - auto loop_groups = - idGraph(IdMappingMode::LOOP) - .toGroups(VectorOfUniqueEntries{p_id, c_id}); - VectorOfUniqueEntries all_ids_in_groups; - for (auto loop_group : loop_groups) { - all_ids_in_groups.pushBack(*loop_group); - } - - // Grab the almost exact map of all iter domains in those loop groups - auto ae_groups = - idGraph(IdMappingMode::ALMOSTEXACT).toGroups(all_ids_in_groups); - // If there's no broadcast promotion within the loop group then all the - // iter domains will be almost exact mapped with eachother. - if (ae_groups.size() == 1) { - idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); - std::cout << "Map2: " << p_id->toString() << " <-> " - << c_id->toString() << std::endl; - } - } - } - } - std::cout << "Loop groups: " << debug_string::idGroups(idGraph(IdMappingMode::LOOP)) << std::endl; @@ -2605,6 +2558,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // (number of entries in groups ^ 2) auto intersection_exact_loop_graph = initializeIdGraph(); + intersection_exact_loop_graph.disableExprPropagation(); for (auto exact_group : idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { auto set_size = exact_group->size(); @@ -2799,6 +2753,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } + bool replayed = replay == nullptr; if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); std::cout << " ***REPLAY***:\n " << iel_expr->front() @@ -2810,13 +2765,104 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Mark outputs as having a promoted iter domain auto replay_out_ids = ir_utils::filterByType(replay->outputs()).vector(); + auto ref_out_ids = + ir_utils::filterByType(iel_expr->front()->outputs()) + .vector(); TORCH_INTERNAL_ASSERT(replay_out_ids.size() == out_groups.size()); for (auto i : c10::irange(replay_out_ids.size())) { iel_promotion_map[out_groups[i]] = replay_out_ids[i]; + // Explicitly map loop map since expr propagation doesn't happen + if (replayed) { + idGraph(IdMappingMode::LOOP).mapIds(replay_out_ids[i], ref_out_ids[i]); + } + } + } + + // Opportunistically add non-inlined loop relationships where they don't + // interfere with the loop groups. This should be on all p_ids that are not + // p_ca_ids. + for (auto p_id : ordered_c_ids.subtract(ordered_p_ca_ids)) { + auto entry_it = p2c_permissive_maps.find(p_id); + if (entry_it == p2c_permissive_maps.end()) { + continue; + } + auto c_ids = entry_it->second; + for (auto c_id : c_ids) { + if (idGraph(IdMappingMode::LOOP) + .disjointIdSets() + .permissiveAreMapped(p_id, c_id)) { + // Already mapped + continue; + } + // Grab all iter domains already in the loop groups for both iter + // domains. + auto loop_groups = + idGraph(IdMappingMode::LOOP) + .toGroups(VectorOfUniqueEntries{p_id, c_id}); + VectorOfUniqueEntries all_ids_in_groups; + for (auto loop_group : loop_groups) { + all_ids_in_groups.pushBack(*loop_group); + } + + // Ignore new loop mappings from replays, we can still opportunistically + // merge leaves if they already have a promoted id from replay associated + // with them. + all_ids_in_groups = all_ids_in_groups.intersect(ordered_c_ids); + + // Grab the almost exact map of all iter domains in those loop groups + auto ae_groups = + idGraph(IdMappingMode::ALMOSTEXACT).toGroups(all_ids_in_groups); + // If there's no broadcast promotion within the loop group then all the + // iter domains will be almost exact mapped with eachother. + if (ae_groups.size() == 1) { + idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); + std::cout << "Map2: " << p_id->toString() << " <-> " << c_id->toString() + << std::endl; + } + } + } + + // Need to update the iel_graph again since we've added operations to the + // exact and loop map. + intersection_exact_loop_graph = initializeIdGraph(); + intersection_exact_loop_graph.disableExprPropagation(); + for (auto exact_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + auto set_size = exact_group->size(); + for (auto id0_i : c10::irange(set_size)) { + auto id0 = exact_group->vector()[id0_i]; + for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { + auto id1 = exact_group->vector()[id1_i]; + // id0 and id1 map in the almost exact map, if they also map in the loop + // graph, then add the mapping to the inersection + if (idGraph(IdMappingMode::LOOP) + .disjointIdSets() + .strictAreMapped(id0, id1)) { + intersection_exact_loop_graph.mapIds(id0, id1); + } + } } } + + std::cout << "New loop groups:" << std::endl; + std::cout << debug_string::idGroups(idGraph(IdMappingMode::LOOP)) + << std::endl; + + { + // Update iel_promotion_map since we changed the loop map the IdGroup key is + // invalid + std::unordered_map old_iel_promotion_map; + std::swap(iel_promotion_map, old_iel_promotion_map); + for (auto entry : old_iel_promotion_map) { + auto old_iel_group = entry.first; + auto id = entry.second; + iel_promotion_map[intersection_exact_loop_graph.toGroup( + old_iel_group->front())] = id; + } + } + // Map from an exact iter domain group, to all the exact iter domain groups it // covers std::unordered_map exact_covered_ids; @@ -2842,23 +2888,23 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { exact_covered_ids[id_group] = {}; } } + { + IdGraphStmtSort exact_stmt_sort(idGraph(IdMappingMode::EXACT)); - IdGraphStmtSort exact_stmt_sort(idGraph(IdMappingMode::EXACT)); - - for (auto exact_expr : exact_stmt_sort.exprs()) { - auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); + for (auto exact_expr : exact_stmt_sort.exprs()) { + auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); - IdGroups covered; - for (auto inp_group : input_groups) { - covered.pushBack(exact_covered_ids.at(inp_group)); - } + IdGroups covered; + for (auto inp_group : input_groups) { + covered.pushBack(exact_covered_ids.at(inp_group)); + } - for (auto output_group : - idGraph(IdMappingMode::EXACT).outputGroups(exact_expr)) { - exact_covered_ids[output_group] = covered; + for (auto output_group : + idGraph(IdMappingMode::EXACT).outputGroups(exact_expr)) { + exact_covered_ids[output_group] = covered; + } } } - // Loop promotion map is to prepare for IterDomain replays. Since these // replays will modify the loop map, we operate on a copy of the loop map, // not the original one. @@ -2878,25 +2924,26 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // and the iter domain. std::vector> exact_promoted_terminal_ids; for (auto loop_id : *loop_group) { - if (terminal_loop_ids.has(loop_id)) { - auto iel_set_pair = - intersection_exact_loop_graph.disjointIdSet(loop_id); - TORCH_INTERNAL_ASSERT(iel_set_pair.second); - auto iel_group = iel_set_pair.first; - auto iel_promo_it = iel_promotion_map.find(iel_group); - if (iel_promo_it == iel_promotion_map.end()) { - auto promo_id_exact_it = - idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); - TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); - exact_promoted_terminal_ids.push_back( - std::make_pair(promo_id_exact_it.first, loop_id)); - } else { - auto promo_id_exact_it = - idGraph(IdMappingMode::EXACT).disjointIdSet(iel_promo_it->second); - TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); - exact_promoted_terminal_ids.push_back( - std::make_pair(promo_id_exact_it.first, iel_promo_it->second)); - } + if (!terminal_loop_ids.has(loop_id)) { + continue; + } + + auto iel_set_pair = intersection_exact_loop_graph.disjointIdSet(loop_id); + TORCH_INTERNAL_ASSERT(iel_set_pair.second); + auto iel_group = iel_set_pair.first; + auto iel_promo_it = iel_promotion_map.find(iel_group); + if (iel_promo_it == iel_promotion_map.end()) { + auto promo_id_exact_it = + idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); + TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); + exact_promoted_terminal_ids.push_back( + std::make_pair(promo_id_exact_it.first, loop_id)); + } else { + auto promo_id_exact_it = + idGraph(IdMappingMode::EXACT).disjointIdSet(iel_promo_it->second); + TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); + exact_promoted_terminal_ids.push_back( + std::make_pair(promo_id_exact_it.first, iel_promo_it->second)); } } @@ -2926,6 +2973,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { TORCH_INTERNAL_ASSERT(covered_it != exact_covered_ids.end()); if (loop_group_covered_ids.subtract(covered_it->second).size() == 0) { loop_promotion_id = terminal_id; + break; } } @@ -2937,24 +2985,30 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { err_msg << "\nnone of the terminal iter domains of this group:\n "; for (auto entry : exact_promoted_terminal_ids) { auto terminal_id_group = entry.first; + auto covered_id_groups = exact_covered_ids.at(terminal_id_group); err_msg << " " << debug_string::idGroupStringShort(terminal_id_group) + << " -(covers)-> " + << debug_string::idGroupsStringShortInline(covered_id_groups) << std::endl; } err_msg << "iter domains in this group cover all id groups:\n"; for (auto covered_group : loop_group_covered_ids) { err_msg << " " << debug_string::idGroupStringShort(covered_group); } - TORCH_INTERNAL_ASSERT(false, err_msg.str()); + // TORCH_INTERNAL_ASSERT(false, err_msg.str()); + } else { + loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; } - - loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; } std::cout << "Loop promotion:" << std::endl; for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { - std::cout << debug_string::idGroupStringShort(loop_group) << " -> " - << loop_graph_copy_promotion_map[loop_group]->toString() - << std::endl; + if (loop_graph_copy_promotion_map.find(loop_group) != + loop_graph_copy_promotion_map.end()) { + std::cout << debug_string::idGroupStringShort(loop_group) << " -> " + << loop_graph_copy_promotion_map[loop_group]->toString() + << std::endl; + } } // Reset the promotion map for the second pass @@ -3052,10 +3106,14 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } + bool replayed = replay == nullptr; if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); std::cout << " ***REPLAY2***:\n " << iel_expr->front() << " As:" << replay->toString(); + } else { + std::cout << " ***MATCH2***:\n " << iel_expr->front() + << " As:" << replay->toString(); } auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -3063,6 +3121,9 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Mark outputs as having a promoted iter domain auto replay_out_ids = ir_utils::filterByType(replay->outputs()).vector(); + auto ref_out_ids = + ir_utils::filterByType(iel_expr->front()->outputs()) + .vector(); TORCH_INTERNAL_ASSERT(replay_out_ids.size() == output_groups.size()); @@ -3071,173 +3132,195 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { .disjointIdSets() .strictAreMapped(replay_out_ids[i], output_groups[i]->front())) { iel_promotion_map[output_groups[i]] = replay_out_ids[i]; + // Explicitly map loop map since expr propagation doesn't happen on the + // loop map and the replayed outputs are brand new so we can map them + // without joining disjoint loop groups (other than the new loop groups + // the outputs of the replay are in) + if (replayed) { + idGraph(IdMappingMode::LOOP) + .mapIds(replay_out_ids[i], ref_out_ids[i]); + } } } } - // Need to update the iel_graph again since we've added operations to the - // exact and loop map. - // *************** START: Code copied verbatim from above ******************** - intersection_exact_loop_graph = initializeIdGraph(); - for (auto exact_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - auto set_size = exact_group->size(); - for (auto id0_i : c10::irange(set_size)) { - auto id0 = exact_group->vector()[id0_i]; - for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { - auto id1 = exact_group->vector()[id1_i]; - // id0 and id1 map in the almost exact map, if they also map in the loop - // graph, then add the mapping to the inersection - if (idGraph(IdMappingMode::LOOP) - .disjointIdSets() - .strictAreMapped(id0, id1)) { - intersection_exact_loop_graph.mapIds(id0, id1); - } - } + std::cout << "Promotion map from second replay: " << std::endl; + for (auto group : + intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + if (iel_promotion_map.find(group) == iel_promotion_map.end()) { + continue; } + std::cout << debug_string::idGroupStringShort(group) << " -> " + << iel_promotion_map.at(group)->toString() << std::endl; } - // *************** STOP: Code copied verbatim from above ******************** - - // *************** START: Code copied verbatim from above ******************** - exact_covered_ids.clear(); - for (auto id_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - // Initialize inputs - if (idGraph(IdMappingMode::EXACT).uniqueDefinitions(id_group).empty()) { - exact_covered_ids[id_group] = {id_group}; + // Need to perform some updates after replay + { + intersection_exact_loop_graph = initializeIdGraph(); + intersection_exact_loop_graph.disableExprPropagation(); + for (auto exact_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + auto set_size = exact_group->size(); + for (auto id0_i : c10::irange(set_size)) { + auto id0 = exact_group->vector()[id0_i]; + for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { + auto id1 = exact_group->vector()[id1_i]; + // id0 and id1 map in the almost exact map, if they also map in the + // loop graph, then add the mapping to the inersection + if (idGraph(IdMappingMode::LOOP) + .disjointIdSets() + .strictAreMapped(id0, id1)) { + intersection_exact_loop_graph.mapIds(id0, id1); + } + } + } } - // Initialize rfactor groups - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); - })) { - exact_covered_ids[id_group] = {id_group}; + // Update iel_promotion_map since we changed the loop map the IdGroup key is + // invalid + std::unordered_map old_iel_promotion_map; + std::swap(iel_promotion_map, old_iel_promotion_map); + for (auto entry : old_iel_promotion_map) { + auto old_iel_group = entry.first; + auto id = entry.second; + iel_promotion_map[intersection_exact_loop_graph.toGroup( + old_iel_group->front())] = id; } - // Initialize broadcast groups to empty - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return id->isBroadcast(); - })) { - exact_covered_ids[id_group] = {}; - } - } + exact_covered_ids.clear(); - IdGraphStmtSort exact_stmt_sort2(idGraph(IdMappingMode::EXACT)); + for (auto id_group : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + // Initialize inputs + if (idGraph(IdMappingMode::EXACT).uniqueDefinitions(id_group).empty()) { + exact_covered_ids[id_group] = {id_group}; + } - for (auto exact_expr : exact_stmt_sort2.exprs()) { - auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); + // Initialize rfactor groups + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); + })) { + exact_covered_ids[id_group] = {id_group}; + } - IdGroups covered; - for (auto inp_group : input_groups) { - covered.pushBack(exact_covered_ids.at(inp_group)); + // Initialize broadcast groups to empty + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return id->isBroadcast(); + })) { + exact_covered_ids[id_group] = {}; + } } - for (auto output_group : - idGraph(IdMappingMode::EXACT).outputGroups(exact_expr)) { - exact_covered_ids[output_group] = covered; + IdGraphStmtSort exact_stmt_sort(idGraph(IdMappingMode::EXACT)); + + for (auto exact_expr : exact_stmt_sort.exprs()) { + auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); + + IdGroups covered; + for (auto inp_group : input_groups) { + covered.pushBack(exact_covered_ids.at(inp_group)); + } + + for (auto output_group : + idGraph(IdMappingMode::EXACT).outputGroups(exact_expr)) { + exact_covered_ids[output_group] = covered; + } } - } - // Loop promotion map is to prepare for IterDomain replays. Since these - // replays will modify the loop map, we operate on a copy of the loop map, - // not the original one. + // Loop promotion map is to prepare for IterDomain replays. Since these + // replays will modify the loop map, we operate on a copy of the loop map, + // not the original one. - loop_graph_copy = idGraph(IdMappingMode::LOOP); - loop_graph_copy_promotion_map.clear(); + loop_graph_copy = idGraph(IdMappingMode::LOOP); + loop_graph_copy_promotion_map.clear(); + } + + // Returns a new promoted domain if one is found in the iel_promotion_map, + // otherwise returns original id. + auto get_promoted_id = [&](IterDomain* id) { + auto iel_group = intersection_exact_loop_graph.toGroup(id); + auto iel_promotion_map_it = iel_promotion_map.find(iel_group); + if (iel_promotion_map_it != iel_promotion_map.end()) { + return iel_promotion_map_it->second; + } + return id; + }; - std::cout << "Find promoted ids within loop groups." << std::endl; + // Returns the entry in exact_covered_ids associated with provided IterDomain + auto get_covered_exact_groups = [&](IterDomain* id) { + auto exact_group = idGraph(IdMappingMode::EXACT).toGroup(id); + auto covered_it = exact_covered_ids.find(exact_group); + TORCH_INTERNAL_ASSERT( + covered_it != exact_covered_ids.end(), + "Missing map entry in analysis for: ", + debug_string::idGroupStringShort(exact_group)); + return covered_it->second; + }; + std::cout << "Find promoted ids from loop group or promoted iter domains." + << std::endl; for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_group->size() == 1) { - loop_graph_copy_promotion_map[loop_group] = loop_group->front(); + auto promoted_id = get_promoted_id(loop_group->front()); + + TORCH_INTERNAL_ASSERT( + get_covered_exact_groups(loop_group->front()) + .subtract(get_covered_exact_groups(promoted_id)) + .size() == 0, + "Promotion failed, promoted id: ", + promoted_id->toString(), + " doesn't cover the right domains for ", + loop_group->front()->toString()); + loop_graph_copy_promotion_map[loop_group] = promoted_id; continue; } - // We need to check the exact groups the terminal id's are in, but for - // promotion we want an iter domain within the loop group. Since exact - // group can traverse loop group boundaires, save a vector of the group - // and the iter domain. - std::vector> exact_promoted_terminal_ids; - for (auto loop_id : *loop_group) { - // *************** START DIFF ******************** - // This is different as there's iter domains not based on the original - // producer-consumer relationships, so finding terminal id's can be a bit - // different here. + // If promotion entry exists for any terminal id the promoted id will be + // stored here. + std::vector promoted_terminal_ids; - // If there's an entry in the p2c_ca_permissive map, this loop_id is not a - // promotion candidate. - if (p2c_ca_permissive_maps.find(loop_id) != - p2c_ca_permissive_maps.end()) { - continue; - } - - // Grab all the output groups of uses in the iel graph. - TORCH_INTERNAL_ASSERT( - intersection_exact_loop_graph.disjointIdSet(loop_id).second); - auto iel_group = - intersection_exact_loop_graph.disjointIdSet(loop_id).first; - auto iel_uses = intersection_exact_loop_graph.uniqueUses(iel_group); - - IdGroups iel_output_groups; - for (auto iel_use : iel_uses) { - iel_output_groups.pushBack( - intersection_exact_loop_graph.outputGroups(iel_use)); - } + // If a promotion entry doesn't exist for a terminal id, put it here. + std::vector terminal_ids; - // Convert the iel output groups into loop groups - IdGroups loop_output_groups; - for (auto iel_group : iel_output_groups) { - TORCH_INTERNAL_ASSERT( - intersection_exact_loop_graph.disjointIdSet(iel_group->front()) - .second); - loop_output_groups.pushBack( - intersection_exact_loop_graph.disjointIdSet(iel_group->front()) - .first); - } + IdGroups all_covered_exact_groups; - // If all outputs of the uses of this id in the iel graph are within the - // same loop group, then it's not a promotion candidate. - if (loop_output_groups.size() == 1 && - loop_output_groups.front() == loop_group) { + for (auto loop_id : *loop_group) { + if (!terminal_loop_ids.has(loop_id)) { continue; } - // This id is a promotion candidate - auto promo_id_exact_it = - idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); - TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); - exact_promoted_terminal_ids.push_back( - std::make_pair(promo_id_exact_it.first, loop_id)); - } - // *************** STOP DIFF ******************** + all_covered_exact_groups.pushBack(get_covered_exact_groups(loop_id)); - // All exact groups with iter domains in this loop group - IdGroups exact_groups; - for (auto loop_id : *loop_group) { - auto exact_set_pair = - idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); - TORCH_INTERNAL_ASSERT(exact_set_pair.second); - exact_groups.pushBack(exact_set_pair.first); + auto promoted_id = get_promoted_id(loop_id); + if (promoted_id == loop_id) { + terminal_ids.push_back(loop_id); + } else { + promoted_terminal_ids.push_back(promoted_id); + } } - // All exact groups covered by all iter domains in this loop group - IdGroups loop_group_covered_ids; - for (auto exact_group : exact_groups) { - auto covered_it = exact_covered_ids.find(exact_group); - TORCH_INTERNAL_ASSERT(covered_it != exact_covered_ids.end()); - loop_group_covered_ids.pushBack(covered_it->second); - } + auto candidate_ids = + promoted_terminal_ids.empty() ? terminal_ids : promoted_terminal_ids; IterDomain* loop_promotion_id = nullptr; - for (auto entry : exact_promoted_terminal_ids) { - auto terminal_id_group = entry.first; - auto terminal_id = entry.second; - auto covered_it = exact_covered_ids.find(terminal_id_group); - TORCH_INTERNAL_ASSERT(covered_it != exact_covered_ids.end()); - if (loop_group_covered_ids.subtract(covered_it->second).size() == 0) { - loop_promotion_id = terminal_id; + for (auto candidate_id : candidate_ids) { + if (all_covered_exact_groups + .subtract(get_covered_exact_groups(candidate_id)) + .empty()) { + loop_promotion_id = candidate_id; + } + } + + // Try any replayed IDs if we're still mising the promoted id. + if (loop_promotion_id == nullptr) { + candidate_ids = loop_group->subtract(ordered_c_ids).vector(); + for (auto candidate_id : candidate_ids) { + if (all_covered_exact_groups + .subtract(get_covered_exact_groups(candidate_id)) + .empty()) { + loop_promotion_id = candidate_id; + } } } @@ -3245,24 +3328,19 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::stringstream err_msg; err_msg << "\nCould not find promotion for loop group:\n "; err_msg << debug_string::idGroupStringShort(loop_group); - err_msg << "\nnone of the terminal iter domains of this group:\n "; - for (auto entry : exact_promoted_terminal_ids) { - auto terminal_id_group = entry.first; - err_msg << " " << debug_string::idGroupStringShort(terminal_id_group) - << std::endl; - } - err_msg << "iter domains in this group cover all id groups:\n"; - for (auto covered_group : loop_group_covered_ids) { - err_msg << " " << debug_string::idGroupStringShort(covered_group); - } + err_msg << "\nnone of the candidate iter domains of this group:\n "; + err_msg << " " + << VectorOfUniqueEntries(candidate_ids).toString(); + err_msg << "\n cover all id groups that the loop group covers:\n"; + err_msg << " " + << debug_string::idGroupsStringShort(all_covered_exact_groups) + << std::endl; TORCH_INTERNAL_ASSERT(false, err_msg.str()); } loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; } - // *************** STOP: Code copied verbatim from above ******************** - std::cout << "Promotion map from concrete id pass: " << std::endl; for (auto group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_graph_copy_promotion_map.find(group) == @@ -3450,6 +3528,63 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } + auto get_representative_promoted_id = [&](IterDomain* id) { + auto loop_copy_group_pair = loop_graph_copy.disjointIdSet(id); + TORCH_INTERNAL_ASSERT(loop_copy_group_pair.second); + auto loop_copy_group = loop_copy_group_pair.first; + + auto promo_id_it = loop_graph_copy_promotion_map.find(loop_copy_group); + TORCH_INTERNAL_ASSERT(promo_id_it != loop_graph_copy_promotion_map.end()); + + return promo_id_it->second; + }; + + std::cout << "Opportunistic joining of shared promos:" << std::endl; + // Opportunistically collapse indexing of non-inlined leaf domains + for (auto expr : exprs) { + for (auto producer : ir_utils::filterByType(expr->inputs())) { + std::cout << " Producer: " << producer->toString() << std::endl; + auto producer_root = producer->getMaybeRFactorDomain(); + + auto non_inline_producer_domain = producer->domain()->domain(); + non_inline_producer_domain.erase( + non_inline_producer_domain.begin(), + non_inline_producer_domain.begin() + + producer->getComputeAtPosition()); + + for (auto consumer : + ir_utils::filterByType(expr->outputs())) { + std::cout << " Consumer: " << consumer->toString() << std::endl; + auto consumer_domain = consumer->domain()->domain(); + + auto p2c_permissive_map = + idGraph(IdMappingMode::PERMISSIVE) + .buildMapBetween(non_inline_producer_domain, consumer_domain); + + for (auto p_id : non_inline_producer_domain) { + auto p2c_it = p2c_permissive_map.find(p_id); + if (p2c_it == p2c_permissive_map.end() || p2c_it->second.empty()) { + continue; + } + + auto rep_p_id = get_representative_promoted_id(p_id); + auto c_id = p2c_it->second.front(); + auto rep_c_id = get_representative_promoted_id(c_id); + + std::cout << " " << p_id->toString() << " -> " + << rep_p_id->toString() << " :: " << c_id->toString() + << " -> " << rep_c_id->toString() << std::endl; + if (idGraph(IdMappingMode::ALMOSTEXACT) + .disjointIdSets() + .strictAreMapped(rep_p_id, rep_c_id)) { + std::cout << " Mapped" << std::endl; + shared_promoted_id.mapEntries(p_id, c_id); + } + } + } + } + } + std::cout << "Leaf iter domains that share a promoted iter domain." << std::endl; for (auto disjoint_set : shared_promoted_id.disjointSets()) { @@ -3467,15 +3602,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { VectorOfUniqueEntries used_promo_ids; for (auto id_group : shared_promoted_id.disjointSets()) { - auto first_id = id_group->front(); - auto loop_copy_group_pair = loop_graph_copy.disjointIdSet(first_id); - TORCH_INTERNAL_ASSERT(loop_copy_group_pair.second); - auto loop_copy_group = loop_copy_group_pair.first; - - auto promo_id_it = loop_graph_copy_promotion_map.find(loop_copy_group); - TORCH_INTERNAL_ASSERT(promo_id_it != loop_graph_copy_promotion_map.end()); - - IterDomain* promo_id = promo_id_it->second; + IterDomain* promo_id = get_representative_promoted_id(id_group->front()); // Promoted id is already part of the group, just use that. if (std::find(id_group->begin(), id_group->end(), promo_id) != @@ -3488,8 +3615,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Promo id generated from running replay, we can use it for one of the // index groups. - if (!shared_promoted_id.mappingExists(promo_id) && - !used_promo_ids.has(promo_id)) { + if (!ordered_c_ids.has(promo_id) && !used_promo_ids.has(promo_id)) { used_promo_ids.pushBack(promo_id); for (auto id : *id_group) { leaf_promotion_map[id] = promo_id; @@ -3505,15 +3631,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - std::cout << "Iter domain group to their promoted iter domain." << std::endl; - for (auto id_group : shared_promoted_id.disjointSets()) { - std::cout << id_group->toString() << "\n -> " - << leaf_promotion_map.at(id_group->front()) << std::endl; - } - - // Could pass this into the function, but just using this for now. - auto all_tvs = ir_utils::allTvsOfExprs(exprs); - // TODO: This needs to be available as a member function auto get_promoted_domain = [&](TensorDomain* td) { std::vector promoted_leaves; @@ -3525,7 +3642,18 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { return promoted_leaves; }; + std::cout << "Iter domain group to their promoted iter domain." << std::endl; + for (auto id_group : shared_promoted_id.disjointSets()) { + std::cout << id_group->toString() << "\n -> " + << leaf_promotion_map.at(id_group->front()) << std::endl; + } + + // Could pass this into the function, but just using this for now. + auto all_tvs = ir_utils::allTvsOfExprs(exprs); + idGraph(IdMappingMode::INDEX) = initializeIdGraph(); + idGraph(IdMappingMode::INDEX).mapThroughTrivialExprs(); + idGraph(IdMappingMode::INDEX).removeTrivialExprs(); // Track every expression required for indexing VectorOfUniqueEntries all_index_exprs; @@ -3580,11 +3708,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { auto all_ids = VectorOfUniqueEntries(all_ids_v.begin(), all_ids_v.end()); - // Add the promoted domain ids - for (auto promoted_id : promoted_domain) { - all_ids.pushBack(promoted_id); - } - // Create a map from the ae group to the iter domain as when we replay we'll // replace the ae iter domain in the replay with the id in this map. std::unordered_map ae_group_2_id; @@ -3592,7 +3715,13 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { for (auto tv_id : all_ids) { // Use emplace here as it multiple tv_ids could map to the same ae_group. // Emplace will simply grab the first one that appears. - ae_group_2_id.emplace(std::make_pair(ae_graph.toGroup(tv_id), tv_id)); + ae_group_2_id[ae_graph.toGroup(tv_id)] = tv_id; + } + + // Add the promoted domain ids + for (auto promoted_id : promoted_domain) { + all_ids.pushBack(promoted_id); + ae_group_2_id[ae_graph.toGroup(promoted_id)] = promoted_id; } auto ae_leaf_groups = ae_graph.toGroups(VectorOfUniqueEntries{ @@ -3616,7 +3745,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Replay indexing transformations start on leaf nodes propagating back to // the root domain - for (ExprGroup ae_expr : reverse_indexing_transforms) { + for (ExprGroup ae_expr_group : reverse_indexing_transforms) { // Outputs must be promoted with the ae_group_2_id map. Inputs may be // promoted when we intercept the history of the TV with the replay. // @@ -3627,16 +3756,16 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // however having an additional ID and expression per case doesn't seem // too bad right now. - auto ae_output_groups = ae_graph.outputGroups(ae_expr); + auto ae_output_groups = ae_graph.outputGroups(ae_expr_group); std::vector promoted_outputs; for (auto out_group : ae_output_groups) { auto out_promo_it = ae_group_2_id.find(out_group); - TORCH_INTERNAL_ASSERT( - out_promo_it != ae_group_2_id.end(), - "Expected promoted iter domain for: ", - debug_string::idGroupStringShort(out_group)); - promoted_outputs.push_back(out_promo_it->second); + if (out_promo_it == ae_group_2_id.end()) { + promoted_outputs.push_back(out_group->front()); + } else { + promoted_outputs.push_back(out_promo_it->second); + } } Expr* replay = nullptr; @@ -3651,10 +3780,11 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } for (auto index_def_group : promoted_output_defs) { - // This enforces that inputs and outputs are all almost exact mapping + // This enforces that inputs and outputs are all almost exact mapped if (!idGraph(IdMappingMode::ALMOSTEXACT) .disjointExprSets() - .strictAreMapped(index_def_group->front(), ae_expr->front())) { + .strictAreMapped( + index_def_group->front(), ae_expr_group->front())) { continue; } @@ -3663,61 +3793,48 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { auto index_def_outputs = ir_utils::filterByType( index_def_group->front()->outputs()) .vector(); + bool outs_match = true; - for (auto inp_i : c10::irange(index_def_outputs.size())) { + for (auto out_i : c10::irange(index_def_outputs.size())) { outs_match = outs_match && idGraph(IdMappingMode::INDEX) .disjointIdSets() .strictAreMapped( - index_def_outputs[inp_i], promoted_outputs[inp_i]); + index_def_outputs[out_i], promoted_outputs[out_i]); } if (!outs_match) { continue; } - // Outputs all match in the index map, but need to make sure the inputs - // do as well. - auto index_def_inputs = ir_utils::filterByType( - index_def_group->front()->inputs()) - .vector(); + replay = index_def_group->front(); - bool inps_match = true; - for (auto inp_id : index_def_inputs) { - IterDomain* promoted_inp = nullptr; - auto ae_group_pair = ae_graph.disjointIdSet(inp_id); - if (ae_group_pair.second && - ae_group_2_id.find(ae_group_pair.first) != ae_group_2_id.end()) { - promoted_inp = ae_group_2_id.at(ae_group_pair.first); - } else { - // TODO: Should this be here or should we continue below. Check - // Indexing20 test. + std::vector ae_inps = + ir_utils::filterByType(ae_expr_group->front()->inputs()) + .vector(); + + auto replay_inputs = + ir_utils::filterByType(replay->inputs()).vector(); - // This input is already almost exact mapped, and we don't need this - // input to map exactly in the index map. + for (auto inp_i : c10::irange(replay_inputs.size())) { + auto ae_group_pair = ae_graph.disjointIdSet(ae_inps[inp_i]); + if (!(ae_group_pair.second && + ae_group_2_id.find(ae_group_pair.first) != + ae_group_2_id.end())) { continue; } - - inps_match = inps_match && - idGraph(IdMappingMode::INDEX) - .disjointIdSets() - .strictAreMapped(inp_id, promoted_inp); - } - - if (!inps_match) { - continue; + idGraph(IdMappingMode::INDEX) + .mapIds( + replay_inputs[inp_i], ae_group_2_id.at(ae_group_pair.first)); } - - replay = index_def_group->front(); - break; } if (replay == nullptr) { std::vector ae_inps_outs = - ir_utils::filterByType(ae_expr->front()->inputs()) + ir_utils::filterByType(ae_expr_group->front()->inputs()) .vector(); - auto outs = - ir_utils::filterByType(ae_expr->front()->outputs()); + auto outs = ir_utils::filterByType( + ae_expr_group->front()->outputs()); ae_inps_outs.insert(ae_inps_outs.end(), outs.begin(), outs.end()); std::unordered_map replacement_map; @@ -3731,9 +3848,12 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - replay = addExprWithReplacement(replacement_map, ae_expr->front()); + replay = + addExprWithReplacement(replacement_map, ae_expr_group->front()); std::cout << " ***REPLAY3***:\n " - << " " << replay->toString(); + << ae_expr_group->front()->toString() + << " As:" << replay->toString(); + } else { std::cout << " ***MATCH3***:\n " << " " << replay->toString(); @@ -3749,7 +3869,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } std::vector ae_inps = - ir_utils::filterByType(ae_expr->front()->inputs()) + ir_utils::filterByType(ae_expr_group->front()->inputs()) .vector(); std::vector replay_inps = ir_utils::filterByType(replay->inputs()).vector(); diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index cd906b2d140..2963164b165 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -163,7 +163,13 @@ class TORCH_CUDA_CU_API IdGraph { // new mapping through id0/id1 definitions/uses. void mapIds(IterDomain* id0, IterDomain* id1); + // Checks if expr0 and expr1 should map together, maps them together, and if + // expression propagation is on, propagates mapping through them. This should + // be the only call in IdGraph to mapThroughExpr + void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); + // Map expr0 and expr1 with eachother, update unique_definitions_ unique_uses_ + // TODO: Make this variant hidden? void mapExprs(Expr* expr0, Expr* expr1); // Checks if expr's are considered "the same" where sameness inputs and @@ -175,6 +181,8 @@ class TORCH_CUDA_CU_API IdGraph { // will map inputs // in the provided mode. // Returns if expressions were mapped through. + // + // TODO: Make this private bool mapThroughExpr(Expr* first, Expr* second, bool forward); // Map through loop swizzles, as input/output IterDomains are exact, only the @@ -190,10 +198,12 @@ class TORCH_CUDA_CU_API IdGraph { void removeTrivialExprs(); // See comment on propagate_expr_ member bool for description - void enableExprPropagation() { - propagate_exprs_ = true; - } - // See comment on propagate_expr_ member bool for description + // Once disabled this can't be reenabled on a graph. If it's reenabled it's + // hard to predict how mappings will propagate, which will be triggered on the + // next mapping. To support changing this flag, we should likely run through + // all expressions currently registered and propagate through all of them on + // switch. Then once enabled it couldn't be redisabled because we don't record + // the history of mapId calls. void disableExprPropagation() { propagate_exprs_ = false; } From 6eaeb061f6fc2b57376966064aeec77819fb37ee Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 26 Apr 2023 14:03:19 -0400 Subject: [PATCH 019/178] Cleanup print utilities. --- csrc/id_graphs.cpp | 297 ++++++++++++++++++++++++++++++--------------- csrc/id_graphs.h | 69 +++++++++++ 2 files changed, 267 insertions(+), 99 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index f69f3182dc8..0280ddc12d6 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -13,22 +13,39 @@ namespace nvfuser { -namespace debug_string { -// A few compressed printing utilities to show critical uniqueness information. -// i.e. being able to tell slight differences between groups we're working with. +// Printing utilities to show critical uniqueness information. i.e. being able +// to tell slight differences between groups we're working with. +namespace debug { +namespace { // Sometimes it can be helpful to directly check the pointer addresses of the // groups. As one group might look exactly like another group but are in // different disjoint sets. Leaving commented out by default. +template +std::string toString(const T* ptr, bool enable) { + if (!enable) { + return ""; + } + std::stringstream ss; + ss << ptr; + return "[0x." + ss.str().substr(9) + "]"; +} -// template -// std::string ptrStringShort(const T* ptr) { -// std::stringstream ss; -// ss << ptr; -// return "0x." + ss.str().substr(9); -// } +std::string indent(int size = 0) { + std::stringstream ss; + for (auto i : c10::irange(size)) { + // Unused variable error + if (i >= 0) { + ss << " "; + } + } + return ss.str(); +} +} // namespace -std::string idsStringShort(const VectorOfUniqueEntries& id_group) { +std::string toString( + const std::vector& id_group, + int indent_size) { std::vector names; for (auto id : id_group) { names.push_back(id->name()); @@ -36,18 +53,24 @@ std::string idsStringShort(const VectorOfUniqueEntries& id_group) { std::sort(names.begin(), names.end()); std::stringstream ss; - ss << "{" << names << "}"; + ss << indent(indent_size) << "{" << names << "}"; return ss.str(); } -std::string idGroupStringShort(const IdGroup& id_group) { +std::string toString(const IdGroup& id_group, int indent_size, bool with_ptr) { std::stringstream ss; - ss << /* ptrStringShort(id_group.get()) << */ "(idg)" - << idsStringShort(*id_group); + ss << indent(indent_size) << "idg" << (with_ptr ? "(" : "") + << toString(id_group.get(), with_ptr) << (with_ptr ? ")" : "") + << toString(id_group->vector()); return ss.str(); } -std::string idGroupsStringShortInline(const std::vector& id_groups) { +std::string toString( + const std::vector& id_groups, + int indent_size, + bool with_ptr) { + std::stringstream ss; + // Track position in id_groups and its min iter domain name in the set std::vector> group_name_info; @@ -63,31 +86,24 @@ std::string idGroupsStringShortInline(const std::vector& id_groups) { group_name_info.push_back(std::make_pair(min_id_name, pos++)); } + ss << indent(indent_size) << "(idgs){\n"; + // Sort based on minimum id in the group std::sort(group_name_info.begin(), group_name_info.end()); - std::stringstream ss; - ss << /* ptrStringShort(&id_groups) <<*/ "(idgs){"; - bool first = true; for (auto i : c10::irange(group_name_info.size())) { - if (first) { - first = false; - } else { - ss << ", "; - } auto pos = group_name_info[i].second; - ss << idGroupStringShort(id_groups[pos]); + ss << toString(id_groups[pos], indent_size + 1, with_ptr) << "\n"; } ss << "}"; return ss.str(); } -std::string idGroupsStringShortInline(const IdGroups& id_groups) { - return idGroupsStringShortInline(id_groups.vector()); -} - -std::string idGroupsStringShort(const std::vector& id_groups) { +std::string toString( + const IdGroups& id_groups, + int indent_size, + bool with_ptr) { std::stringstream ss; // Track position in id_groups and its min iter domain name in the set @@ -105,58 +121,86 @@ std::string idGroupsStringShort(const std::vector& id_groups) { group_name_info.push_back(std::make_pair(min_id_name, pos++)); } - ss << /* ptrStringShort(&id_groups) <<*/ "(idgs){\n"; + ss << indent(indent_size) << "(idgs){\n"; // Sort based on minimum id in the group std::sort(group_name_info.begin(), group_name_info.end()); for (auto i : c10::irange(group_name_info.size())) { auto pos = group_name_info[i].second; - ss << " " << idGroupStringShort(id_groups[pos]) << "\n"; + ss << toString(id_groups.vector()[pos], indent_size + 1, with_ptr) << "\n"; } ss << "}"; return ss.str(); } -std::string idGroupsStringShort(const IdGroups& id_groups) { - return idGroupsStringShort(id_groups.vector()); -} +std::string toInlineString(const std::vector& id_groups) { + // Track position in id_groups and its min iter domain name in the set + std::vector> group_name_info; -std::string idGroups(const IdGraph& id_graph) { - IdGroups id_groups( - id_graph.disjointIdSets().disjointSets().begin(), - id_graph.disjointIdSets().disjointSets().end()); - return idGroupsStringShort(id_groups); + unsigned int pos = 0; + + for (auto id_group : id_groups) { + unsigned int min_id_name = std::numeric_limits::max(); + for (auto id : *id_group) { + if (id->name() < min_id_name) { + min_id_name = id->name(); + } + } + group_name_info.push_back(std::make_pair(min_id_name, pos++)); + } + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + std::stringstream ss; + + ss << "(idgs){"; + bool first = true; + for (auto i : c10::irange(group_name_info.size())) { + if (first) { + first = false; + } else { + ss << ", "; + } + auto pos = group_name_info[i].second; + ss << toString(id_groups[pos]); + } + + return ss.str(); } -std::string exprGroupStringShort(ExprGroup expr_group) { +std::string toString(const std::vector& expr_group, int indent_size) { std::vector names; - for (auto expr : *expr_group) { + for (auto expr : expr_group) { names.push_back(expr->name()); } std::sort(names.begin(), names.end()); std::stringstream ss; - ss << /* ptrStringShort(&expr_group) <<*/ "(exprg){" << names << "}"; + ss << indent(indent_size) << "{" << names << "}"; return ss.str(); } -std::string exprGroupStringShort( - const IdGraph& id_graph, - ExprGroup expr_group) { +std::string toString( + const ExprGroup& expr_group, + int indent_size, + bool with_ptr) { std::stringstream ss; - auto inputs = IdGroups(id_graph.inputGroups(expr_group)); - auto outputs = IdGroups(id_graph.outputGroups(expr_group)); - ss << idGroupsStringShortInline(inputs) << " -" - << exprGroupStringShort(expr_group) << "-> " - << idGroupsStringShortInline(outputs); + ss << indent(indent_size) << "exprg" << (with_ptr ? "(" : "") + << toString(expr_group.get(), with_ptr) << (with_ptr ? ")" : "") + << toString(expr_group->vector()); return ss.str(); } -std::string exprGroupsStringShort( +std::string toString( const IdGraph& id_graph, - ExprGroups expr_groups) { + const std::vector& expr_groups, + int indent_size, + bool with_ptr) { + std::stringstream ss; + // Track position in expr_groups and its min iter domain name in the set std::vector> group_name_info; @@ -172,30 +216,93 @@ std::string exprGroupsStringShort( group_name_info.push_back(std::make_pair(min_expr_name, pos++)); } + ss << indent(indent_size) << "(exprgs){\n"; + // Sort based on minimum id in the group std::sort(group_name_info.begin(), group_name_info.end()); + for (auto i : c10::irange(group_name_info.size())) { + auto pos = group_name_info[i].second; + auto expr_group = expr_groups[pos]; + + auto inputs = IdGroups(id_graph.inputGroups(expr_group)); + auto outputs = IdGroups(id_graph.outputGroups(expr_group)); + + ss << indent(indent_size + 1) << toInlineString(inputs.vector()) << " --" + << toString(expr_group, 0, with_ptr) << "--> " + << toInlineString(outputs.vector()) << "\n"; + } + + ss << indent(indent_size) << "}"; + return ss.str(); +} + +std::string toString( + const IdGraph& id_graph, + const ExprGroups& expr_groups, + int indent_size, + bool with_ptr) { std::stringstream ss; - ss << /* ptrStringShort(&expr_groups) <<*/ "(exprs) {"; + + // Track position in expr_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto expr_group : expr_groups) { + unsigned int min_id_name = std::numeric_limits::max(); + for (auto id : *expr_group) { + if (id->name() < min_id_name) { + min_id_name = id->name(); + } + } + group_name_info.push_back(std::make_pair(min_id_name, pos++)); + } + + ss << indent(indent_size) << "(exprgs){\n"; + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + for (auto i : c10::irange(group_name_info.size())) { auto pos = group_name_info[i].second; - ss << " " << exprGroupStringShort(id_graph, expr_groups.vector()[pos]) - << "\n"; + auto expr_group = expr_groups.vector()[pos]; + + auto inputs = IdGroups(id_graph.inputGroups(expr_group)); + auto outputs = IdGroups(id_graph.outputGroups(expr_group)); + + ss << indent(indent_size + 1) << toInlineString(inputs.vector()) << " --" + << toString(expr_group, 0, with_ptr) << "--> " + << toInlineString(outputs.vector()) << "\n"; } - ss << "}"; + ss << indent(indent_size) << "}"; return ss.str(); } -std::string exprGroups(const IdGraph& id_graph) { +std::string idGroupsString( + const IdGraph& id_graph, + int indent_size, + bool with_ptr) { + IdGroups id_groups( + id_graph.disjointIdSets().disjointSets().begin(), + id_graph.disjointIdSets().disjointSets().end()); + return toString(id_groups, indent_size, with_ptr); +} +std::string exprGroupsString( + const IdGraph& id_graph, + int indent_size, + bool with_ptr) { ExprGroups expr_groups( id_graph.disjointExprSets().disjointSets().begin(), id_graph.disjointExprSets().disjointSets().end()); - return exprGroupsStringShort(id_graph, expr_groups); + return toString(id_graph, expr_groups, indent_size, with_ptr); } -std::string definitionsToString(const IdGraph& id_graph) { - std::stringstream ss; +std::string definitionsString( + const IdGraph& id_graph, + int indent_size, + bool with_ptr) { ExprGroups defs; for (auto id_group : id_graph.disjointIdSets().disjointSets()) { auto definition_pair = id_graph.iterDomainGroupDefinitions(id_group); @@ -205,28 +312,26 @@ std::string definitionsToString(const IdGraph& id_graph) { } } } - for (auto expr : defs) { - ss << exprGroupStringShort(id_graph, expr) << std::endl; - } - return ss.str(); + return toString(id_graph, defs, indent_size, with_ptr); } -std::string usesToString(const IdGraph& id_graph) { - std::stringstream ss; - +std::string usesString( + const IdGraph& id_graph, + int indent_size, + bool with_ptr) { + ExprGroups uses; for (auto id_group : id_graph.disjointIdSets().disjointSets()) { - auto uses_pair = id_graph.iterDomainGroupUses(id_group); - ss << idGroupStringShort(id_group) << std::endl; - if (uses_pair.second) { - for (auto expr_group : uses_pair.first) { - ss << " " << exprGroupStringShort(id_graph, expr_group) << std::endl; + auto definition_pair = id_graph.iterDomainGroupUses(id_group); + if (definition_pair.second) { + for (auto expr_group : definition_pair.first) { + uses.pushBack(expr_group); } } } - return ss.str(); + return toString(id_graph, uses, indent_size, with_ptr); } -} // namespace debug_string +} // namespace debug namespace { @@ -1010,11 +1115,12 @@ std::pair IdGraph::iterDomainGroupUses( return std::make_pair(uses_it->second, true); } -// TODO: Improve and extend to include other information. std::string IdGraph::toString() const { std::stringstream ss; ss << "IdGraph { \n"; - ss << "Disjoint Id Set " << disjoint_ids_.toString() << std::endl; + ss << "Disjoint Ids:\n" + << debug::idGroupsString(*this, 1) << "\n\nDisjoint Expression groups:\n" + << debug::exprGroupsString(*this, 1) << std::endl; ss << " } IdGraph\n" << std::endl; return ss.str(); } @@ -1047,7 +1153,6 @@ std::vector> IdGraph::isTrivialExpr(Expr* expr) { return mapped_ids; } -// TODO: Add explicit id_definitions_ and id_uses_ void IdGraph::initializeId( IterDomain* id, const VectorOfUniqueEntries& definitions, @@ -1409,11 +1514,11 @@ void IdGraph::eraseExprGroup(ExprGroup expr_group) { TORCH_INTERNAL_ASSERT( unique_definitions_.find(id_group) != unique_definitions_.end(), "Broken definitions, couldn't find entry for id group, ", - debug_string::idGroupStringShort(id_group)); + debug::toString(id_group, 0, true)); TORCH_INTERNAL_ASSERT( unique_uses_.find(id_group) != unique_uses_.end(), "Broken uses, couldn't find entry for id group, ", - debug_string::idGroupStringShort(id_group)); + debug::toString(id_group, 0, true)); unique_definitions_[id_group].erase(expr_group); unique_uses_[id_group].erase(expr_group); @@ -2473,8 +2578,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } std::cout << "Loop groups: " - << debug_string::idGroups(idGraph(IdMappingMode::LOOP)) - << std::endl; + << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) << std::endl; // Terminal loop ids are iteration domains in each loop group that: // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a @@ -2847,8 +2951,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } std::cout << "New loop groups:" << std::endl; - std::cout << debug_string::idGroups(idGraph(IdMappingMode::LOOP)) - << std::endl; + std::cout << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) << std::endl; { // Update iel_promotion_map since we changed the loop map the IdGroup key is @@ -2981,19 +3084,18 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::stringstream err_msg; err_msg << "\n ERROR Loop promotion map build. Could not find promotion for loop group:\n "; - err_msg << debug_string::idGroupStringShort(loop_group); + err_msg << debug::toString(loop_group, 0, true); err_msg << "\nnone of the terminal iter domains of this group:\n "; for (auto entry : exact_promoted_terminal_ids) { auto terminal_id_group = entry.first; auto covered_id_groups = exact_covered_ids.at(terminal_id_group); - err_msg << " " << debug_string::idGroupStringShort(terminal_id_group) - << " -(covers)-> " - << debug_string::idGroupsStringShortInline(covered_id_groups) + err_msg << " " << debug::toString(terminal_id_group, 0, true) + << " -(covers)-> " << debug::toString(covered_id_groups) << std::endl; } err_msg << "iter domains in this group cover all id groups:\n"; for (auto covered_group : loop_group_covered_ids) { - err_msg << " " << debug_string::idGroupStringShort(covered_group); + err_msg << " " << debug::toString(covered_group, 0, true); } // TORCH_INTERNAL_ASSERT(false, err_msg.str()); } else { @@ -3005,7 +3107,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_graph_copy_promotion_map.find(loop_group) != loop_graph_copy_promotion_map.end()) { - std::cout << debug_string::idGroupStringShort(loop_group) << " -> " + std::cout << debug::toString(loop_group, 0, true) << " -> " << loop_graph_copy_promotion_map[loop_group]->toString() << std::endl; } @@ -3150,7 +3252,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (iel_promotion_map.find(group) == iel_promotion_map.end()) { continue; } - std::cout << debug_string::idGroupStringShort(group) << " -> " + std::cout << debug::toString(group, 0, true) << " -> " << iel_promotion_map.at(group)->toString() << std::endl; } @@ -3253,7 +3355,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { TORCH_INTERNAL_ASSERT( covered_it != exact_covered_ids.end(), "Missing map entry in analysis for: ", - debug_string::idGroupStringShort(exact_group)); + debug::toString(exact_group, 0, true)); return covered_it->second; }; @@ -3327,14 +3429,12 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { if (loop_promotion_id == nullptr) { std::stringstream err_msg; err_msg << "\nCould not find promotion for loop group:\n "; - err_msg << debug_string::idGroupStringShort(loop_group); + err_msg << debug::toString(loop_group, 0, true); err_msg << "\nnone of the candidate iter domains of this group:\n "; err_msg << " " << VectorOfUniqueEntries(candidate_ids).toString(); err_msg << "\n cover all id groups that the loop group covers:\n"; - err_msg << " " - << debug_string::idGroupsStringShort(all_covered_exact_groups) - << std::endl; + err_msg << " " << debug::toString(all_covered_exact_groups) << std::endl; TORCH_INTERNAL_ASSERT(false, err_msg.str()); } @@ -3347,7 +3447,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { loop_graph_copy_promotion_map.end()) { continue; } - std::cout << debug_string::idGroupStringShort(group) << " -> " + std::cout << debug::toString(group, 0, true) << " -> " << loop_graph_copy_promotion_map.at(group)->toString() << std::endl; } @@ -3892,13 +3992,12 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::cout << "All indexing expressions (on the index graph): " << std::endl; auto index_expr_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_exprs); - std::cout << debug_string::exprGroupsStringShort( - idGraph(IdMappingMode::INDEX), index_expr_groups) + std::cout << debug::toString(idGraph(IdMappingMode::INDEX), index_expr_groups) << std::endl; std::cout << "All iter domains (on the index graph): " << std::endl; auto index_id_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_ids); - std::cout << debug_string::idGroupsStringShort(index_id_groups) << std::endl; + std::cout << debug::toString(index_id_groups) << std::endl; // std::cout << "All iter domains that would be indexed: " // << all_index_ids.toString() << std::endl; diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index 2963164b165..627faedd175 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -246,6 +246,75 @@ class TORCH_CUDA_CU_API IdGraph { std::unordered_set view_rfactor_ids_; }; +// Debuging print functions +namespace debug { +std::string toString( + const std::vector& id_group, + int indent_size = 0); +std::string toString( + const IdGroup& id_group, + int indent_size = 0, + bool with_ptr = false); + +std::string toString( + const std::vector& id_groups, + int indent_size = 0, + bool with_ptr = false); + +std::string toString( + const IdGroups& id_groups, + int indent_size = 0, + bool with_ptr = false); + +std::string toInlineString(const std::vector& id_groups); +std::string toInlineString(const IdGroups& id_groups); + +std::string toString(const std::vector& expr_group, int indent_size = 0); +std::string toString( + const ExprGroup& expr_group, + int indent_size = 0, + bool with_ptr = false); + +std::string toString( + const IdGraph& id_graph, + const std::vector& expr_group, + int indent_size = 0, + bool with_ptr = false); +std::string toString( + const IdGraph& id_graph, + const ExprGroup& expr_groups, + int indent_size = 0, + bool with_ptr = false); + +std::string toString( + const IdGraph& id_graph, + const std::vector& expr_groups, + int indent_size = 0, + bool with_ptr = false); +std::string toString( + const IdGraph& id_graph, + const ExprGroups& expr_groups, + int indent_size = 0, + bool with_ptr = false); + +std::string idGroupsString( + const IdGraph& id_graph, + int indent_size = 0, + bool with_ptr = false); +std::string exprGroupsString( + const IdGraph& id_graph, + int indent_size = 0, + bool with_ptr = false); +std::string definitionsString( + const IdGraph& id_graph, + int indent_size = 0, + bool with_ptr = false); +std::string usesString( + const IdGraph& id_graph, + int indent_size = 0, + bool with_ptr = false); +} // namespace debug + // Iterates through an IterDomain Graph in topological order, calling handle on // all Id and all Expr groups in a forward topological order. // From 32cbc5bdf1fc6b1db49864fd03283fb14b95cbaa Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 26 Apr 2023 14:23:27 -0400 Subject: [PATCH 020/178] Cleanup. --- csrc/id_graphs.cpp | 258 ++++++++++++--------------------------------- 1 file changed, 67 insertions(+), 191 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 0280ddc12d6..d0730c29a64 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -1753,14 +1753,36 @@ void IterDomainGraphs::buildIterDomainDefinitionsAndUses( } } -// TODO: Extend to include other information. std::string IterDomainGraphs::toString() const { + // Figure out which graphs are already initialized to make sure we add the new + // expression to them. + std::vector initialized_modes; + for (auto mode : kIdMappingModes) { + auto graph_it = id_graphs_.find(mode); + if (graph_it == id_graphs_.end()) { + continue; + } + + auto& graph = graph_it->second; + if (graph.disjointIdSets().disjointSetMap().empty()) { + continue; + } + + initialized_modes.push_back(mode); + } + std::stringstream ss; ss << "IterDomainGraphs { \n"; - // for (auto set : disjoint_ids_) { - // ss << "Set " << set.first << ": " << std::endl; - // ss << set.second.toString() << std::endl; - // } + for (auto mode : initialized_modes) { + std::stringstream ss; + ss << " IdGraph " << mode << "{ \n"; + ss << " Disjoint Ids:\n" + << debug::idGroupsString(idGraph(mode), 2) + << "\n Disjoint Expression groups:\n" + << debug::exprGroupsString(idGraph(mode), 2) << std::endl; + ss << " } IdGraph\n" << std::endl; + return ss.str(); + } ss << " } IterDomainGraphs\n" << std::endl; return ss.str(); } @@ -1869,9 +1891,8 @@ Expr* IterDomainGraphs::addReplayAs( return replay; } -// Generate a new expr with the IterDomain outputs provided and IterDomain -// inputs that exactly match expr->inputs - +// Generate a new expr with the IterDomain inputs/outputs replaced based on map. +// Replaced inputs/outputs should almost exact match with provided expr. Expr* IterDomainGraphs::addExprWithReplacement( const std::unordered_map& old_2_new_ids, Expr* old_expr) { @@ -1952,38 +1973,38 @@ Expr* IterDomainGraphs::addExprWithReplacement( // Create the new expression with provided outputs auto replay = ReplacementTransformCloner::clone(replacement_map, old_expr); + // Add new output iter domains to id_definitions_/id_uses_ of IdGraphs for (auto out_id : ir_utils::filterByType(replay->outputs())) { id_definitions_[out_id].pushBack(replay); id_uses_[out_id]; } - // Add the expression to the uses of the inputs + // Add new input iter domains to id_definitions_/id_uses_ of IdGraphs for (auto inp_id : ir_utils::filterByType(replay->inputs())) { id_definitions_[inp_id]; id_uses_[inp_id].pushBack(replay); } - // TODO: Update comments - // Initialize output iter domains in the graphs + // Update all the initialized graph mappings for (auto mode : initialized_modes) { auto& graph = idGraph(mode); graph.disjointExprSets().initializeSet(replay); auto replay_group = graph.disjointExprSet(replay).first; + // Initialize any non-existant input ids, update existing ones for (auto inp_id : ir_utils::filterByType(replay->inputs())) { if (!graph.disjointIdSets().mappingExists(inp_id)) { // inp_id is not initialized in the map, initialize it graph.initializeId(inp_id, {}, {replay}); } else { - // inp_id is already initialized add the replay as a unique use of its - // group. + // Update unique uses of existing input ids auto inp_group = graph.disjointIdSet(inp_id).first; graph.uniqueUses()[inp_group].pushBack(replay_group); } } - // Update definitions in the graph of the outputs + // Initialize any non-existant output ids, update existing ones for (auto out_id : ir_utils::filterByType(replay->outputs())) { if (!graph.disjointIdSets().mappingExists(out_id)) { // out_id is not initialized in the map, initialize it @@ -1996,11 +2017,11 @@ Expr* IterDomainGraphs::addExprWithReplacement( } } - // We expect that inputs or outputs were replaced by iter domains that - // already exist in the graphs. If the inputs were replaced we want to - // replay forward through the newly added expression. If the outputs were - // replaced we want to replay backwards (towards inputs) instead. + // If the inputs were replaced we want to map through forward the newly + // added expression. If the outputs were replaced we want to map through + // backwards the newly added expression. + // Forward VectorOfUniqueEntries representative_uses; for (auto in : ir_utils::filterByType(replay->inputs())) { auto uses_pair = graph.iterDomainGroupUses(graph.disjointIdSet(in).first); @@ -2018,6 +2039,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( graph.maybeMapThroughExprs(rep_use, replay, true); } + // Backwards VectorOfUniqueEntries representative_defs; for (auto out : ir_utils::filterByType(replay->outputs())) { auto defs_pair = @@ -2211,26 +2233,32 @@ void IterDomainGraphs::buildAlmostExactMap() { } void IterDomainGraphs::validateAndPropagatePType() const { - for (const auto& loop_disjoint_set : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - ParallelType common_ptype = ParallelType::Serial; - for (auto id : loop_disjoint_set->vector()) { - auto id_ptype = id->getParallelType(); - TORCH_INTERNAL_ASSERT( - id_ptype == common_ptype || id_ptype == ParallelType::Serial || - common_ptype == ParallelType::Serial, - "Issue validating parallel type disjoint ptype is, ", - common_ptype, - " but found in the set the id: ", - id->toString()); - common_ptype = - common_ptype == ParallelType::Serial ? id_ptype : common_ptype; - } - - for (auto id : loop_disjoint_set->vector()) { - id->parallelize(common_ptype); - } - } + // TODO: This needs to be done when the loop map is correctly defined to do + // this. The loop map gets built, but then later pulls in iter domains that + // are not inlined. Parallel propagate should be done on inlined iter domains, + // and the loop map shouldn't group together indices of iter domains that are + // not inlined and parallelized differently. + // + // for (const auto& loop_disjoint_set : + // idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + // ParallelType common_ptype = ParallelType::Serial; + // for (auto id : loop_disjoint_set->vector()) { + // auto id_ptype = id->getParallelType(); + // TORCH_INTERNAL_ASSERT( + // id_ptype == common_ptype || id_ptype == ParallelType::Serial || + // common_ptype == ParallelType::Serial, + // "Issue validating parallel type disjoint ptype is, ", + // common_ptype, + // " but found in the set the id: ", + // id->toString()); + // common_ptype = + // common_ptype == ParallelType::Serial ? id_ptype : common_ptype; + // } + + // for (auto id : loop_disjoint_set->vector()) { + // id->parallelize(common_ptype); + // } + // } } void IterDomainGraphs::build( @@ -2291,8 +2319,7 @@ void IterDomainGraphs::build( // necessary. buildLoopPromotionMap(tv_exprs); - TORCH_INTERNAL_ASSERT(false); - + // Doesn't do anything right now validateAndPropagatePType(); } @@ -2333,146 +2360,6 @@ std::unordered_map resolvedRootBroadcasts( } // namespace -std::unordered_map IterDomainGraphs:: - buildCoveredAlmostExact() { - // Helper functions. - auto producerIdGroups = [&](IdGroup id_group) { - IdGroups producer_groups; - auto definition_pair_it = idGraph(IdMappingMode::ALMOSTEXACT) - .iterDomainGroupDefinitions(id_group); - if (!definition_pair_it.second) { - return producer_groups; - } - for (auto def_group : definition_pair_it.first) { - auto inp_groups = - idGraph(IdMappingMode::ALMOSTEXACT).inputGroups(def_group); - producer_groups.pushBack(inp_groups.begin(), inp_groups.end()); - } - return producer_groups; - }; - - auto consumerIdGroups = [&](IdGroup id_group) { - IdGroups consumer_groups; - auto uses_pair_it = - idGraph(IdMappingMode::ALMOSTEXACT).iterDomainGroupUses(id_group); - if (!uses_pair_it.second) { - return consumer_groups; - } - for (auto use_group : uses_pair_it.first) { - auto out_groups = - idGraph(IdMappingMode::ALMOSTEXACT).outputGroups(use_group); - consumer_groups.pushBack(out_groups); - } - return consumer_groups; - }; - - // Start at terminating inputs of the almost exact graph and almost exact - // entries that are rfactor nodes. Propagate and accumulate these nodes - // through consumers. - // - // The almost exact entries covered by an iteration domain is effectively - // all the iteration domains this domain relies on. Initialize broadcast - // entries to not cover any domains. - std::unordered_map covered_almost_exact_entries; - - // We will traverse over the almost exact set expressions. Save where we - // want to start traversal: - IdGroups to_visit; - // Initialize covered groups - for (auto almost_exact_set : - idGraph(IdMappingMode::ALMOSTEXACT).disjointIdSets().disjointSets()) { - // what broadcast domains cover doesn't matter - if (std::all_of( - almost_exact_set->begin(), - almost_exact_set->end(), - [&](IterDomain* id) { return id->isBroadcast(); })) { - covered_almost_exact_entries[almost_exact_set] = {}; - continue; - } - - // Initialize rfactor domains to cover themselves only - if (std::any_of( - almost_exact_set->begin(), - almost_exact_set->end(), - [&](IterDomain* id) { - return viewRfactorIds().find(id) != viewRfactorIds().end(); - })) { - covered_almost_exact_entries[almost_exact_set] = {almost_exact_set}; - to_visit.pushBack(consumerIdGroups(almost_exact_set)); - continue; - } - - // Initialize any groups that don't have a definition except (potentialy) - // ones that traverse back to this set. - auto def_pair = idGraph(IdMappingMode::ALMOSTEXACT) - .iterDomainGroupDefinitions(almost_exact_set); - if (!def_pair.second) { - covered_almost_exact_entries[almost_exact_set] = {almost_exact_set}; - to_visit.pushBack(consumerIdGroups(almost_exact_set)); - continue; - } - - for (auto def : def_pair.first) { - // If all definitions are self mapping (can happen with - // merging our splitting with a broadcast/ dim of size 1) - // then this group is an input. - auto inp_groups = idGraph(IdMappingMode::ALMOSTEXACT).inputGroups(def); - if (std::find(inp_groups.begin(), inp_groups.end(), almost_exact_set) == - inp_groups.end()) { - goto loop_continue; - } - } - - covered_almost_exact_entries[almost_exact_set] = {almost_exact_set}; - to_visit.pushBack(consumerIdGroups(almost_exact_set)); - - loop_continue:; - } - - // Starting from the initialized inputs propagate forward from those inputs to - // mark what every iter domain in the graph covers. This will be used in later - // analysis. - while (to_visit.size() > 0) { - IdGroups still_to_visit; - bool something_processed = false; - while (to_visit.size() > 0) { - auto currently_visiting = to_visit.popFront(); - if (covered_almost_exact_entries.find(currently_visiting) != - covered_almost_exact_entries.end()) { - continue; - } - auto producer_ids = producerIdGroups(currently_visiting); - producer_ids.erase(currently_visiting); - IdGroups currently_visiting_covered; - for (auto producer_id : producer_ids) { - auto producer_covered_it = - covered_almost_exact_entries.find(producer_id); - if (producer_covered_it == covered_almost_exact_entries.end()) { - still_to_visit.pushBack(currently_visiting); - goto inner_while_continue; - } - for (auto entry : producer_covered_it->second) { - if (currently_visiting_covered.has(entry)) { - continue; - } - } - currently_visiting_covered.pushBack(producer_covered_it->second); - } - covered_almost_exact_entries[currently_visiting] = - currently_visiting_covered; - to_visit.pushBack(consumerIdGroups(currently_visiting)); - something_processed = true; - - inner_while_continue:; - } - TORCH_INTERNAL_ASSERT( - still_to_visit.empty() || something_processed, - "Entered infinite loop."); - std::swap(still_to_visit, to_visit); - } - return covered_almost_exact_entries; -} - void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { idGraph(IdMappingMode::LOOP) = initializeIdGraph(); // See Indexing20 example for why we shouldn't propagate when generating loop @@ -3983,12 +3870,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - // std::cout << "All indexing expressions that need to be processed: " - // << std::endl; - // for (auto expr : all_index_exprs) { - // std::cout << expr->toString(); - // } - std::cout << "All indexing expressions (on the index graph): " << std::endl; auto index_expr_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_exprs); @@ -3998,11 +3879,6 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::cout << "All iter domains (on the index graph): " << std::endl; auto index_id_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_ids); std::cout << debug::toString(index_id_groups) << std::endl; - - // std::cout << "All iter domains that would be indexed: " - // << all_index_ids.toString() << std::endl; - - TORCH_INTERNAL_ASSERT(false); } } // namespace nvfuser From db3ba36f27b747fa609e35c162f4cf0e7ce70d8b Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 30 Apr 2023 08:22:50 -0400 Subject: [PATCH 021/178] Refactoring mono-function, in progress, functional. --- csrc/id_graphs.cpp | 624 +++++++++++++++++++++++++-------------------- csrc/id_graphs.h | 67 ++++- 2 files changed, 413 insertions(+), 278 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index d0730c29a64..a16598504dd 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -2232,103 +2232,73 @@ void IterDomainGraphs::buildAlmostExactMap() { idGraph(IdMappingMode::ALMOSTEXACT).mapThroughTrivialExprs(); } -void IterDomainGraphs::validateAndPropagatePType() const { - // TODO: This needs to be done when the loop map is correctly defined to do - // this. The loop map gets built, but then later pulls in iter domains that - // are not inlined. Parallel propagate should be done on inlined iter domains, - // and the loop map shouldn't group together indices of iter domains that are - // not inlined and parallelized differently. - // - // for (const auto& loop_disjoint_set : - // idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - // ParallelType common_ptype = ParallelType::Serial; - // for (auto id : loop_disjoint_set->vector()) { - // auto id_ptype = id->getParallelType(); - // TORCH_INTERNAL_ASSERT( - // id_ptype == common_ptype || id_ptype == ParallelType::Serial || - // common_ptype == ParallelType::Serial, - // "Issue validating parallel type disjoint ptype is, ", - // common_ptype, - // " but found in the set the id: ", - // id->toString()); - // common_ptype = - // common_ptype == ParallelType::Serial ? id_ptype : common_ptype; - // } - - // for (auto id : loop_disjoint_set->vector()) { - // id->parallelize(common_ptype); - // } - // } -} - -void IterDomainGraphs::build( - const std::vector& exprs, - const std::vector& additional_tvs) { - // Initialize the required sets as if a permissive relationship is never - // found, then querying an empty permissive map will fail later. - // Initialize disjoint sets - for (auto mode : kIdMappingModes) { - id_graphs_[mode] = IdGraph(); +void IterDomainGraphs::validatePTypes( + const std::vector& all_tvs) const { + VectorOfUniqueEntries leaf_ids; + for (auto tv : all_tvs) { + leaf_ids.pushBack(tv->domain()->domain()); } - std::vector tv_exprs; - - std::copy_if( - exprs.begin(), exprs.end(), std::back_inserter(tv_exprs), [](Expr* expr) { - TORCH_INTERNAL_ASSERT(expr != nullptr); - return ir_utils::isTvOp(expr); - }); + for (const auto& disjoint_set : + idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + for (auto id : disjoint_set->vector()) { + auto id_ptype = id->getParallelType(); - auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs); - if (additional_tvs.size() > 0) { - std::unordered_set all_added_tvs( - all_tvs.begin(), all_tvs.end()); - for (auto additional_tv : additional_tvs) { - if (all_added_tvs.find(additional_tv) == all_added_tvs.end()) { - all_tvs.push_back(additional_tv); - } + TORCH_INTERNAL_ASSERT( + leaf_ids.has(id) || id_ptype == ParallelType::Serial, + "Invalid parallelization of non leaf iter domain: ", + id->toString()); } } +} - if (all_tvs.empty()) { - return; - } +void IterDomainGraphs::propagateLoopPTypes() const { + for (const auto& loop_disjoint_set : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + ParallelType common_ptype = ParallelType::Serial; + for (auto id : loop_disjoint_set->vector()) { + auto id_ptype = id->getParallelType(); - FusionGuard fg(all_tvs.front()->fusion()); - // Add uses and definitions to all iter domains. - buildIterDomainDefinitionsAndUses(all_tvs); + TORCH_INTERNAL_ASSERT( + id_ptype == common_ptype || id_ptype == ParallelType::Serial || + common_ptype == ParallelType::Serial, + "Issue validating parallel type disjoint ptype is, ", + common_ptype, + " but found in the set the id: ", + id->toString()); - // Initialize the maps with all the IterDomains used in the provded - // expressions. - idGraph(IdMappingMode::EXACT) = initializeIdGraph(); + common_ptype = + common_ptype == ParallelType::Serial ? id_ptype : common_ptype; + } - buildExactMap(tv_exprs); - buildAlmostExactMap(); - buildPermissiveMap(tv_exprs); + for (auto id : loop_disjoint_set->vector()) { + id->parallelize(common_ptype); + } + } +} - // Permissive graph needs the trivial exprs from the almost exact graph to - // build correctly. Once built though we can remove the trivial expressions - // from the almost exact graph. - idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); +namespace { +struct StatefulLoweringInfo { + // Tracks all p2c mappings in permissive maps even those not inlined between + // producer and consumer + std::unordered_map> + p2c_permissive_maps; - // Only build loop map during lowering - if (FusionGuard::getCurFusion()->isA()) { - FusionGuard::getCurFusion()->print(std::cout, true); - // Find loops that need to be promoted because of broadcast resolution, - // figure out what that resolution should look like, compute IDs for it if - // necessary. - buildLoopPromotionMap(tv_exprs); + // All consumer ids in a deterministic order (ignores fusion->inputs()) + VectorOfUniqueEntries ordered_c_ids; - // Doesn't do anything right now - validateAndPropagatePType(); - } + // p2c mappings through the fusion within (including dependencies of) inlined + // leaf domains. + std::unordered_map> + p2c_ca_permissive_maps; - // Debug, make sure there's no self mapping in TensorView's during lowering - // that would invalidate lowering assumptions. - self_mapping_info_ = findFirstSelfMapping(all_tvs, *this); -} + // All producer ids within (including dependencies of) inlined leaf domains, + // used for deterministic order + VectorOfUniqueEntries ordered_p_ca_ids; -namespace { + std::unordered_map> + p2c_root_broadcast_resolution_map; +}; // Returns the root producer iteration domains that are resolved by provided // consumer @@ -2358,33 +2328,11 @@ std::unordered_map resolvedRootBroadcasts( return resolved_bcast_map; } -} // namespace - -void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { - idGraph(IdMappingMode::LOOP) = initializeIdGraph(); - // See Indexing20 example for why we shouldn't propagate when generating loop - // groups - idGraph(IdMappingMode::LOOP).disableExprPropagation(); - - std::unordered_map> - p2c_root_broadcast_resolution_map; - - // Track all of the p2c mappings through the fusion within those inlined - // domains. - std::unordered_map> - p2c_ca_permissive_maps; - - // All producer ids in a deterministic order - VectorOfUniqueEntries ordered_p_ca_ids; - - // All ids in a deterministic order - VectorOfUniqueEntries ordered_c_ids; - - // Tracks all p2c mappings in permissive maps even those not inlined between - // producer and consumer - std::unordered_map> - p2c_permissive_maps; - +StatefulLoweringInfo buildInfo( + const std::vector& exprs, + const IdGraph& exact_graph, + const IdGraph& permissive_graph) { + StatefulLoweringInfo info; // Grab inlining relationships for (auto expr : exprs) { for (auto producer : ir_utils::filterByType(expr->inputs())) { @@ -2407,97 +2355,163 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::cout << "Producer: " << producer->toString() << "\n " << all_producer_ca_deps.toString() << std::endl; - ordered_p_ca_ids.pushBack(all_producer_ca_deps); + info.ordered_p_ca_ids.pushBack(all_producer_ca_deps); for (auto consumer : ir_utils::filterByType(expr->outputs())) { auto resolved_bcast_map = resolvedRootBroadcasts(producer, consumer); for (auto entry : resolved_bcast_map) { - p2c_root_broadcast_resolution_map[entry.first].pushBack(entry.second); - for (auto other_exact_bcast : *idGraph(IdMappingMode::EXACT) - .disjointIdSet(entry.first) - .first) { + info.p2c_root_broadcast_resolution_map[entry.first].pushBack( + entry.second); + for (auto other_exact_bcast : + *exact_graph.disjointIdSet(entry.first).first) { if (all_producer_ca_deps.has(other_exact_bcast)) { - p2c_root_broadcast_resolution_map[other_exact_bcast].pushBack( - entry.second); + info.p2c_root_broadcast_resolution_map[other_exact_bcast] + .pushBack(entry.second); } } } auto all_producer_ids = ir_utils::allIDsOf(producer); auto all_consumer_ids = ir_utils::allIDsOf(consumer); - ordered_c_ids.pushBack(all_consumer_ids); + info.ordered_c_ids.pushBack(all_consumer_ids); - auto p2c_permissive_map = - idGraph(IdMappingMode::PERMISSIVE) - .buildMapBetween(all_producer_ids, all_consumer_ids); + auto p2c_permissive_map = permissive_graph.buildMapBetween( + all_producer_ids, all_consumer_ids); for (auto entry : p2c_permissive_map) { if (entry.second.size() == 0) { continue; } if (all_producer_ca_deps.has(entry.first)) { - p2c_ca_permissive_maps[entry.first].pushBack(entry.second); + info.p2c_ca_permissive_maps[entry.first].pushBack(entry.second); } - p2c_permissive_maps[entry.first].pushBack(entry.second); + info.p2c_permissive_maps[entry.first].pushBack(entry.second); } for (auto entry : p2c_permissive_map) { if (entry.second.size() == 0) { continue; } - p2c_permissive_maps[entry.first].pushBack(entry.second); + info.p2c_permissive_maps[entry.first].pushBack(entry.second); } } } } + return info; +} - // Make sure this is called in a deterministic order. Build all inlined - // relationships in loop graph. - for (auto p_id : ordered_p_ca_ids) { - auto entry_it = p2c_ca_permissive_maps.find(p_id); - if (entry_it != p2c_ca_permissive_maps.end()) { - auto c_ids = entry_it->second; - for (auto c_id : c_ids) { - idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); +} // namespace + +void IterDomainGraphs::build( + const std::vector& exprs, + const std::vector& additional_tvs) { + // Initialize the required sets as if a permissive relationship is never + // found, then querying an empty permissive map will fail later. + // Initialize disjoint sets + for (auto mode : kIdMappingModes) { + id_graphs_[mode] = IdGraph(); + } + + std::vector tv_exprs; + + std::copy_if( + exprs.begin(), exprs.end(), std::back_inserter(tv_exprs), [](Expr* expr) { + TORCH_INTERNAL_ASSERT(expr != nullptr); + return ir_utils::isTvOp(expr); + }); + + auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs); + if (additional_tvs.size() > 0) { + std::unordered_set all_added_tvs( + all_tvs.begin(), all_tvs.end()); + for (auto additional_tv : additional_tvs) { + if (all_added_tvs.find(additional_tv) == all_added_tvs.end()) { + all_tvs.push_back(additional_tv); } } } - std::cout << "Loop groups: " - << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) << std::endl; + if (all_tvs.empty()) { + return; + } - // Terminal loop ids are iteration domains in each loop group that: - // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a - // consumer TV's iter domain maps to this domain in a way that that domain - // is also in the same loop group - // 2) Don't have a direct IterDomain consumer within the group - VectorOfUniqueEntries terminal_loop_ids; + FusionGuard fg(all_tvs.front()->fusion()); + // Add uses and definitions to all iter domains. + buildIterDomainDefinitionsAndUses(all_tvs); - // Case (1) - VectorOfUniqueEntries p2c_ca_terminal_loop_ids; - // Case (2) - VectorOfUniqueEntries id_consumer_terminal_loop_ids; + // Initialize the maps with all the IterDomains used in the provded + // expressions. + idGraph(IdMappingMode::EXACT) = initializeIdGraph(); + + buildExactMap(tv_exprs); + buildAlmostExactMap(); + buildPermissiveMap(tv_exprs); + // Permissive graph needs the trivial exprs from the almost exact graph to + // build correctly. Once built though we can remove the trivial expressions + // from the almost exact graph. + idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); + + // Only build loop map during lowering + if (FusionGuard::getCurFusion()->isA()) { + validatePTypes(all_tvs); + + FusionGuard::getCurFusion()->print(std::cout, true); + + StatefulLoweringInfo info = buildInfo( + tv_exprs, + idGraph(IdMappingMode::EXACT), + idGraph(IdMappingMode::PERMISSIVE)); + + initializeLoopMap(info); + std::cout << "Loop groups: " + << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) + << std::endl; + + std::cout << "Promoted groups: " + << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) + << std::endl; + + // Initial propagation of parallel types for inlined iter domains. Each time + // new expressions are replayed this needs to be run. The disjoint sets in + // the loop graph can only be joined after this point. + propagateLoopPTypes(); + + auto iel_promotion_map = buildInlinePromotions(info); + propagateLoopPTypes(); + + // Find loops that need to be promoted because of broadcast resolution, + // figure out what that resolution should look like, compute IDs for it if + // necessary. + buildLoopPromotionMap(tv_exprs, info, iel_promotion_map); + propagateLoopPTypes(); + } + + // Debug, make sure there's no self mapping in TensorView's during lowering + // that would invalidate lowering assumptions. + self_mapping_info_ = findFirstSelfMapping(all_tvs, *this); +} + +VectorOfUniqueEntries IterDomainGraphs::computeTerminalLoopIds( + const StatefulLoweringInfo info) { + VectorOfUniqueEntries terminal_loop_ids; for (auto group : idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { if (group->size() == 1) { - p2c_ca_terminal_loop_ids.pushBack(group->front()); - id_consumer_terminal_loop_ids.pushBack(group->front()); + terminal_loop_ids.pushBack(group->front()); } // Don't select producer iter domains for (auto loop_id : *group) { - if (p2c_ca_permissive_maps.find(loop_id) != - p2c_ca_permissive_maps.end()) { + if (info.p2c_ca_permissive_maps.find(loop_id) != + info.p2c_ca_permissive_maps.end()) { continue; } - p2c_ca_terminal_loop_ids.pushBack(loop_id); - auto uses_it = id_uses_.find(loop_id); if (uses_it == id_uses_.end()) { - id_consumer_terminal_loop_ids.pushBack(loop_id); + terminal_loop_ids.pushBack(loop_id); continue; } @@ -2517,14 +2531,59 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } if (!all_outs_in_loop_group) { - id_consumer_terminal_loop_ids.pushBack(loop_id); + terminal_loop_ids.pushBack(loop_id); + } + } + } + return terminal_loop_ids; +} + +IdGraph IterDomainGraphs::buildIntersection( + const IdGraph& graph0, + const IdGraph& graph1, + bool propagate_exprs) { + auto intersection = initializeIdGraph(); + if (!propagate_exprs) { + intersection.disableExprPropagation(); + } + for (auto exact_group : graph0.disjointIdSets().disjointSets()) { + auto set_size = exact_group->size(); + for (auto id0_i : c10::irange(set_size)) { + auto id0 = exact_group->vector()[id0_i]; + for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { + auto id1 = exact_group->vector()[id1_i]; + // id0 and id1 map in the almost exact map, if they also map in the loop + // graph, then add the mapping to the inersection + if (graph1.disjointIdSets().strictAreMapped(id0, id1)) { + intersection.mapIds(id0, id1); + } } } } + return intersection; +} - terminal_loop_ids = - p2c_ca_terminal_loop_ids.intersect(id_consumer_terminal_loop_ids); +void IterDomainGraphs::initializeLoopMap(StatefulLoweringInfo& info) { + idGraph(IdMappingMode::LOOP) = initializeIdGraph(); + // See Indexing20 example for why we shouldn't propagate when generating loop + // groups + idGraph(IdMappingMode::LOOP).disableExprPropagation(); + // Make sure this is called in a deterministic order. Build all inlined + // relationships in loop graph. + for (auto p_id : info.ordered_p_ca_ids) { + auto entry_it = info.p2c_ca_permissive_maps.find(p_id); + if (entry_it != info.p2c_ca_permissive_maps.end()) { + auto c_ids = entry_it->second; + for (auto c_id : c_ids) { + idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); + } + } + } +} + +std::unordered_map IterDomainGraphs:: + buildInlinePromotions(StatefulLoweringInfo& info) { // Make an intersection of the exact and loop map. This will group together // entries in each loop group that are exact with eachother. This provides a // better graph to do promotion and replays. @@ -2548,25 +2607,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // smaller groups and this algorithm scales with the number of groups * // (number of entries in groups ^ 2) - auto intersection_exact_loop_graph = initializeIdGraph(); - intersection_exact_loop_graph.disableExprPropagation(); - for (auto exact_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - auto set_size = exact_group->size(); - for (auto id0_i : c10::irange(set_size)) { - auto id0 = exact_group->vector()[id0_i]; - for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { - auto id1 = exact_group->vector()[id1_i]; - // id0 and id1 map in the almost exact map, if they also map in the loop - // graph, then add the mapping to the inersection - if (idGraph(IdMappingMode::LOOP) - .disjointIdSets() - .strictAreMapped(id0, id1)) { - intersection_exact_loop_graph.mapIds(id0, id1); - } - } - } - } + auto intersection_exact_loop_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); // Promotion logic is going to be on the intersection of the exact and loop // graph. We will generate a map on the entries of this graph so it's @@ -2590,10 +2632,10 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { IdGroups resolved_exact_groups; for (auto bcast_id : *iel_group) { auto p2c_root_broadcast_resolution_map_it = - p2c_root_broadcast_resolution_map.find(bcast_id); + info.p2c_root_broadcast_resolution_map.find(bcast_id); if (p2c_root_broadcast_resolution_map_it == - p2c_root_broadcast_resolution_map.end()) { + info.p2c_root_broadcast_resolution_map.end()) { continue; } @@ -2770,13 +2812,92 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } } + return iel_promotion_map; +} + +namespace { + +std::unordered_map updateMap( + const std::unordered_map stale_map, + IdGraph& new_graph) { + std::unordered_map new_map; + for (auto stale_entry : stale_map) { + auto stale_id_group = stale_entry.first; + auto new_groups = new_graph.toGroups(*stale_id_group); + TORCH_INTERNAL_ASSERT( + new_groups.size() == 1, + "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", + "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", + "old:", + debug::toString(stale_id_group), + "new: ", + debug::toString(new_groups)); + new_map[new_groups.front()] = stale_entry.second; + } + return new_map; +} + +// Returns for each IdGroup in provided IdGraph what the input IdGroups are +// traversing on definitions. Ignoring broadcast IdGroups and resetting inputs +// at RFactor IdGroups. +std::unordered_map computeCoveredGroups( + const IdGraph& graph, + std::unordered_set view_rfactor_ids) { + // Map from an exact iter domain group, to all the exact iter domain groups it + // covers + std::unordered_map covered_ids; + + for (auto id_group : graph.disjointIdSets().disjointSets()) { + // Initialize inputs + if (graph.uniqueDefinitions(id_group).empty()) { + covered_ids[id_group] = {id_group}; + } + + // Initialize rfactor groups + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return view_rfactor_ids.find(id) != view_rfactor_ids.end(); + })) { + covered_ids[id_group] = {id_group}; + } + + // Initialize broadcast groups to empty + if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { + return id->isBroadcast(); + })) { + covered_ids[id_group] = {}; + } + } + + IdGraphStmtSort exact_stmt_sort(graph); + + for (auto exact_expr : exact_stmt_sort.exprs()) { + auto input_groups = graph.inputGroups(exact_expr); + + IdGroups covered; + for (auto inp_group : input_groups) { + covered.pushBack(covered_ids.at(inp_group)); + } + + for (auto output_group : graph.outputGroups(exact_expr)) { + covered_ids[output_group] = covered; + } + } + + return covered_ids; +} +}; // namespace +std::unordered_map IterDomainGraphs:: + buildLoopPromotionMap( + const std::vector& exprs, + StatefulLoweringInfo& info, + std::unordered_map stale_promotion_map) { // Opportunistically add non-inlined loop relationships where they don't // interfere with the loop groups. This should be on all p_ids that are not // p_ca_ids. - for (auto p_id : ordered_c_ids.subtract(ordered_p_ca_ids)) { - auto entry_it = p2c_permissive_maps.find(p_id); - if (entry_it == p2c_permissive_maps.end()) { + for (auto p_id : info.ordered_c_ids.subtract(info.ordered_p_ca_ids)) { + auto entry_it = info.p2c_permissive_maps.find(p_id); + if (entry_it == info.p2c_permissive_maps.end()) { continue; } auto c_ids = entry_it->second; @@ -2787,12 +2908,27 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Already mapped continue; } + // Grab all iter domains already in the loop groups for both iter // domains. auto loop_groups = idGraph(IdMappingMode::LOOP) .toGroups(VectorOfUniqueEntries{p_id, c_id}); + VectorOfUniqueEntries all_ids_in_groups; + + ParallelType common_ptype = + loop_groups.front()->front()->getParallelType(); + if (std::any_of( + loop_groups.begin() + 1, + loop_groups.end(), + [common_ptype](IdGroup id_group) { + return id_group->front()->getParallelType() != common_ptype; + })) { + // Parallel types don't match, cannot merge non-inlined loop groups. + continue; + } + for (auto loop_group : loop_groups) { all_ids_in_groups.pushBack(*loop_group); } @@ -2800,104 +2936,40 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Ignore new loop mappings from replays, we can still opportunistically // merge leaves if they already have a promoted id from replay associated // with them. - all_ids_in_groups = all_ids_in_groups.intersect(ordered_c_ids); + all_ids_in_groups = all_ids_in_groups.intersect(info.ordered_c_ids); // Grab the almost exact map of all iter domains in those loop groups auto ae_groups = idGraph(IdMappingMode::ALMOSTEXACT).toGroups(all_ids_in_groups); + // If there's no broadcast promotion within the loop group then all the // iter domains will be almost exact mapped with eachother. if (ae_groups.size() == 1) { idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); - std::cout << "Map2: " << p_id->toString() << " <-> " << c_id->toString() - << std::endl; } } } - // Need to update the iel_graph again since we've added operations to the - // exact and loop map. - intersection_exact_loop_graph = initializeIdGraph(); - intersection_exact_loop_graph.disableExprPropagation(); - for (auto exact_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - auto set_size = exact_group->size(); - for (auto id0_i : c10::irange(set_size)) { - auto id0 = exact_group->vector()[id0_i]; - for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { - auto id1 = exact_group->vector()[id1_i]; - // id0 and id1 map in the almost exact map, if they also map in the loop - // graph, then add the mapping to the inersection - if (idGraph(IdMappingMode::LOOP) - .disjointIdSets() - .strictAreMapped(id0, id1)) { - intersection_exact_loop_graph.mapIds(id0, id1); - } - } - } - } + // Need to use the intersection of exact and loop map again. + auto intersection_exact_loop_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); - std::cout << "New loop groups:" << std::endl; - std::cout << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) << std::endl; - - { - // Update iel_promotion_map since we changed the loop map the IdGroup key is - // invalid - std::unordered_map old_iel_promotion_map; - std::swap(iel_promotion_map, old_iel_promotion_map); - for (auto entry : old_iel_promotion_map) { - auto old_iel_group = entry.first; - auto id = entry.second; - iel_promotion_map[intersection_exact_loop_graph.toGroup( - old_iel_group->front())] = id; - } - } + // Update the promotion map + auto iel_promotion_map = + updateMap(stale_promotion_map, intersection_exact_loop_graph); // Map from an exact iter domain group, to all the exact iter domain groups it // covers - std::unordered_map exact_covered_ids; - - for (auto id_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - // Initialize inputs - if (idGraph(IdMappingMode::EXACT).uniqueDefinitions(id_group).empty()) { - exact_covered_ids[id_group] = {id_group}; - } + std::unordered_map exact_covered_ids = + computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); - // Initialize rfactor groups - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); - })) { - exact_covered_ids[id_group] = {id_group}; - } + // Grab terminal iter domain in the loop groups. + VectorOfUniqueEntries terminal_loop_ids = + computeTerminalLoopIds(info); - // Initialize broadcast groups to empty - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return id->isBroadcast(); - })) { - exact_covered_ids[id_group] = {}; - } - } - { - IdGraphStmtSort exact_stmt_sort(idGraph(IdMappingMode::EXACT)); - - for (auto exact_expr : exact_stmt_sort.exprs()) { - auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); - - IdGroups covered; - for (auto inp_group : input_groups) { - covered.pushBack(exact_covered_ids.at(inp_group)); - } - - for (auto output_group : - idGraph(IdMappingMode::EXACT).outputGroups(exact_expr)) { - exact_covered_ids[output_group] = covered; - } - } - } - // Loop promotion map is to prepare for IterDomain replays. Since these - // replays will modify the loop map, we operate on a copy of the loop map, - // not the original one. + // Loop promotion map is to prepare for IterDomain replays to resolve + // non-inlined loop groups. Since these replays will modify the loop map, we + // operate on a copy of the loop map, not the original one. auto loop_graph_copy = idGraph(IdMappingMode::LOOP); std::unordered_map loop_graph_copy_promotion_map; @@ -2908,27 +2980,31 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { continue; } - // We need to check the exact groups the terminal id's are in, but for - // promotion we want an iter domain within the loop group. Since exact - // group can traverse loop group boundaires, save a vector of the group - // and the iter domain. + // Grab all the (potentially promoted) terminal iter domains in this group. + // Save the exact group and the iter domain in this vector. std::vector> exact_promoted_terminal_ids; for (auto loop_id : *loop_group) { + // If not a terminal id in the group skip if (!terminal_loop_ids.has(loop_id)) { continue; } + // Grab the iel entry auto iel_set_pair = intersection_exact_loop_graph.disjointIdSet(loop_id); TORCH_INTERNAL_ASSERT(iel_set_pair.second); auto iel_group = iel_set_pair.first; + auto iel_promo_it = iel_promotion_map.find(iel_group); if (iel_promo_it == iel_promotion_map.end()) { + // If this terminal ID has a promotion, grab the promoted ID. auto promo_id_exact_it = idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); exact_promoted_terminal_ids.push_back( std::make_pair(promo_id_exact_it.first, loop_id)); } else { + // If this terminal ID doesn't have a promotion associated with it, save + // the terminal ID. auto promo_id_exact_it = idGraph(IdMappingMode::EXACT).disjointIdSet(iel_promo_it->second); TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); @@ -2937,14 +3013,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } - // All exact groups with iter domains in this loop group - IdGroups exact_groups; - for (auto loop_id : *loop_group) { - auto exact_set_pair = - idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); - TORCH_INTERNAL_ASSERT(exact_set_pair.second); - exact_groups.pushBack(exact_set_pair.first); - } + // Collect all the exact groups of the iter domains in the loop group + IdGroups exact_groups = idGraph(IdMappingMode::EXACT).toGroups(*loop_group); // All exact groups covered by all iter domains in this loop group IdGroups loop_group_covered_ids; @@ -3133,6 +3203,14 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { } } + // ==================================================================================== + // ==================================================================================== + // ==================================================================================== + // ==================================================================================== + // ==================================================================================== + // ==================================================================================== + // ==================================================================================== + std::cout << "Promotion map from second replay: " << std::endl; for (auto group : intersection_exact_loop_graph.disjointIdSets().disjointSets()) { @@ -3303,7 +3381,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Try any replayed IDs if we're still mising the promoted id. if (loop_promotion_id == nullptr) { - candidate_ids = loop_group->subtract(ordered_c_ids).vector(); + candidate_ids = loop_group->subtract(info.ordered_c_ids).vector(); for (auto candidate_id : candidate_ids) { if (all_covered_exact_groups .subtract(get_covered_exact_groups(candidate_id)) @@ -3602,7 +3680,7 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { // Promo id generated from running replay, we can use it for one of the // index groups. - if (!ordered_c_ids.has(promo_id) && !used_promo_ids.has(promo_id)) { + if (!info.ordered_c_ids.has(promo_id) && !used_promo_ids.has(promo_id)) { used_promo_ids.pushBack(promo_id); for (auto id : *id_group) { leaf_promotion_map[id] = promo_id; @@ -3879,6 +3957,8 @@ void IterDomainGraphs::buildLoopPromotionMap(const std::vector& exprs) { std::cout << "All iter domains (on the index graph): " << std::endl; auto index_id_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_ids); std::cout << debug::toString(index_id_groups) << std::endl; + + return {}; } } // namespace nvfuser diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index 627faedd175..2067179c3d5 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -36,6 +36,8 @@ class TORCH_CUDA_CU_API IdGraph { // otherwise a null shared ptr // (2) If the disjoint set of the provided Iter Domain exists // } + // + // TODO: Audit usage std::pair disjointIdSet(IterDomain* id) const; // Returns the disjoint Expr set. @@ -44,6 +46,8 @@ class TORCH_CUDA_CU_API IdGraph { DisjointSets& disjointExprSets(); // Same as getDisjointIdSet but for the Expression sets. + // + // TODO: Audit usage std::pair disjointExprSet(Expr* expr) const; // Convert expr to its exprGroup, assert that it exists. @@ -329,7 +333,7 @@ class TORCH_CUDA_CU_API IdGraphVisitor { // If sub_selection is assumed to be a set of iter domains by which form a // sub-regrion of the IdGraph provided. Only that sub-region will be visited. IdGraphVisitor( - IdGraph& id_graph, + const IdGraph& id_graph, const VectorOfUniqueEntries sub_selection = {}) : id_graph_(id_graph), sub_selection_(sub_selection) {} @@ -338,7 +342,7 @@ class TORCH_CUDA_CU_API IdGraphVisitor { void traverse(); - IdGraph& graph() { + const IdGraph& graph() { return id_graph_; }; @@ -353,7 +357,7 @@ class TORCH_CUDA_CU_API IdGraphVisitor { virtual ~IdGraphVisitor() = default; private: - IdGraph& id_graph_; + const IdGraph& id_graph_; const VectorOfUniqueEntries sub_selection_; }; @@ -361,7 +365,7 @@ class TORCH_CUDA_CU_API IdGraphVisitor { class IdGraphStmtSort : public IdGraphVisitor { public: IdGraphStmtSort( - IdGraph& id_graph, + const IdGraph& id_graph, const VectorOfUniqueEntries sub_selection = {}) : IdGraphVisitor(id_graph, sub_selection) { IdGraphVisitor::traverse(); @@ -391,6 +395,12 @@ class IdGraphStmtSort : public IdGraphVisitor { IdGroups sorted_ids; }; +namespace { +// Convenience to store some intermediate data across a few lowering build +// passes. +struct StatefulLoweringInfo; +} // namespace + // TODO: Comment is stale, update. // // There's three modes of these iter domain mappings all uniquely important in @@ -557,12 +567,57 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // AlmostExact entries, then map through broadcasts void buildPermissiveMap(const std::vector& exprs); + // Make sure only leaf nodes of tensor views are parallelized + void validatePTypes(const std::vector& all_tvs) const; + //! Run through disjoint sets in the LOOP map, make sure there's only one //! non-serial parallel type in each disjoint set, set the parallel type of //! all IterDomains in the disjoint set to that PType. - void validateAndPropagatePType() const; + void propagateLoopPTypes() const; + + // !! START Helper functions to build loop promotion and index map!! + + // Terminal loop ids are iteration domains in each loop group that: + // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a + // consumer TV's iter domain maps to this domain in a way that that domain + // is also in the same loop group + // 2) Don't have a direct IterDomain consumer within the group + VectorOfUniqueEntries computeTerminalLoopIds( + const StatefulLoweringInfo info); + + // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and + // graph1. + IdGraph buildIntersection( + const IdGraph& graph0, + const IdGraph& graph1, + bool propagate_exprs = true); + + // !! END Helper functions to build loop promotion and index map!! + + // Start loop map by grouping inlined iter domains + void initializeLoopMap(StatefulLoweringInfo& info); + + // Returns map of IdGroups in the loop map to a representative IterDomain that + // contains all resolved transformations that the terminal IterDomains should + // be promoted to. The returned promotions are valid only for inlined iter + // domains. + std::unordered_map buildInlinePromotions( + StatefulLoweringInfo& info); + + // Returns a similar thing to buildInlinePromotions but also includes iter + // domains that are not inlined. + std::unordered_map buildLoopPromotionMap( + const std::vector& exprs, + StatefulLoweringInfo& info, + std::unordered_map stale_promotion_map); - void buildLoopPromotionMap(const std::vector& exprs); + // Builds idGraph(IdMappingMode::INDEX) and returns the iter domain promotion + // map to go from leaf domains of each (consumer only?) tensor to their + // corresponding leaf domain in the index graph. + std::unordered_map buildIndexGraph( + const std::vector& exprs, + StatefulLoweringInfo& info, + std::unordered_map stale_promotion_map); // Returns the terminal rfactor or input iter domains each group in the almost // exact map covers (in the almost exact map). This effectively returns all From adbad3dbd63dd0153c90ea7ce3f4fd589adcc87b Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Fri, 5 May 2023 21:43:19 -0400 Subject: [PATCH 022/178] WIP Index Graph. --- csrc/id_graphs.cpp | 578 ++++++++++++++++++++++++++------------------- csrc/id_graphs.h | 16 +- 2 files changed, 336 insertions(+), 258 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index a16598504dd..09e784471d9 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -664,34 +664,6 @@ std::vector IdGraph::inputGroups(ExprGroup expr) const { return input_groups; } -bool IdGraph::groupsMatch( - std::vector id_groups0, - std::vector id_groups1) const { - if (id_groups0.size() != id_groups1.size()) { - return false; - } - for (auto id_g_i : c10::irange(id_groups0.size())) { - if (id_groups0[id_g_i] != id_groups1[id_g_i]) { - return false; - } - } - return true; -} - -bool IdGraph::groupsMatch( - std::vector expr_groups0, - std::vector expr_groups1) const { - if (expr_groups0.size() != expr_groups1.size()) { - return false; - } - for (auto id_g_i : c10::irange(expr_groups0.size())) { - if (expr_groups0[id_g_i] != expr_groups1[id_g_i]) { - return false; - } - } - return true; -} - ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { ExprGroups to_visit; for (auto of_id_group : of) { @@ -1271,11 +1243,9 @@ void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { if (propagate_exprs_) { mapExprs(expr0, expr1); mapThroughExpr(expr0, expr1, forward); - } else if ((groupsMatch( - inputGroups(toGroup(expr0)), inputGroups(toGroup(expr1))) && - groupsMatch( - outputGroups(toGroup(expr0)), - outputGroups(toGroup(expr1))))) { + } else if ( + inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) && + outputGroups(toGroup(expr0)) == outputGroups(toGroup(expr1))) { mapExprs(expr0, expr1); } } @@ -1478,7 +1448,7 @@ void IdGraph::removeTrivialExprs() { // Clear out expressions that map inputs and outputs to the same group // from definitions and uses. They shouldn't be important in traversal, and // will break the terminal input/terminal output logic of traversal. Similar - // to what's drafted in buildIndexMap + // to what's drafted in buildIndexGraph for (auto trivial_expr_group : trivial_expr_groups) { // Complexity of erase not good as both disjoint set and vector of unique // entries require a vector find to erase an entry. @@ -2484,7 +2454,18 @@ void IterDomainGraphs::build( // Find loops that need to be promoted because of broadcast resolution, // figure out what that resolution should look like, compute IDs for it if // necessary. - buildLoopPromotionMap(tv_exprs, info, iel_promotion_map); + iel_promotion_map = + buildLoopPromotionMap(tv_exprs, info, iel_promotion_map); + // Loop map potentialy changed changed, as we could have replayed + // expressions. Re-propagate parallel types. + propagateLoopPTypes(); + + // Find loops that need to be promoted because of broadcast resolution, + // figure out what that resolution should look like, compute IDs for it if + // necessary. + auto leaf_id_promo_map = + buildIndexGraph(tv_exprs, all_tvs, info, iel_promotion_map); + // Make sure we update ptypes onto the index leaf iter domains propagateLoopPTypes(); } @@ -2950,7 +2931,8 @@ std::unordered_map IterDomainGraphs:: } } - // Need to use the intersection of exact and loop map again. + // Need to use the intersection of exact and loop map again, it needs to be + // recomputed. auto intersection_exact_loop_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); @@ -2959,7 +2941,7 @@ std::unordered_map IterDomainGraphs:: updateMap(stale_promotion_map, intersection_exact_loop_graph); // Map from an exact iter domain group, to all the exact iter domain groups it - // covers + // covers; needs to be recomputed. std::unordered_map exact_covered_ids = computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); @@ -2968,12 +2950,20 @@ std::unordered_map IterDomainGraphs:: computeTerminalLoopIds(info); // Loop promotion map is to prepare for IterDomain replays to resolve - // non-inlined loop groups. Since these replays will modify the loop map, we - // operate on a copy of the loop map, not the original one. - + // non-inlined loop groups. Since these replays will modify the loop map as + // we're iterating over the loop map, operate on a copy of the loop map, not + // the original one. auto loop_graph_copy = idGraph(IdMappingMode::LOOP); + + // Build a map from loop iter domain group to a promoted iter domain (doesn't + // have to be in the loop group) that covers all the exact groups + // representative of the resolved transformations within the loop group. Only + // the inlined loop groups will be covered here. std::unordered_map loop_graph_copy_promotion_map; + // TODO: I'm uncertain if we can simply use the iel_promotion_map. Once this + // system is in use we should test not recomputing the "concrete ids". + for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_group->size() == 1) { loop_graph_copy_promotion_map[loop_group] = loop_group->front(); @@ -3013,7 +3003,7 @@ std::unordered_map IterDomainGraphs:: } } - // Collect all the exact groups of the iter domains in the loop group + // All the exact groups of the iter domains in the loop group IdGroups exact_groups = idGraph(IdMappingMode::EXACT).toGroups(*loop_group); // All exact groups covered by all iter domains in this loop group @@ -3026,6 +3016,9 @@ std::unordered_map IterDomainGraphs:: IterDomain* loop_promotion_id = nullptr; + // Check if any of the candidate Iter Domains we collected cover all the + // exact groups of loop_group_covered_ids. If so, that's the correct + // promoted iter domain of this group. for (auto entry : exact_promoted_terminal_ids) { auto terminal_id_group = entry.first; auto terminal_id = entry.second; @@ -3060,7 +3053,7 @@ std::unordered_map IterDomainGraphs:: } } - std::cout << "Loop promotion:" << std::endl; + std::cout << "Loop promotion before second replay:" << std::endl; for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_graph_copy_promotion_map.find(loop_group) != loop_graph_copy_promotion_map.end()) { @@ -3070,26 +3063,41 @@ std::unordered_map IterDomainGraphs:: } } - // Reset the promotion map for the second pass + // Reset the promotion map for the second pass. + // TODO: Unclear if we could simply update the iel_promotion_map from + // buildInlinePromotions, instead of manually building it. iel_promotion_map.clear(); - std::cout << "\n\nSecond replay:" << std::endl; + // Need to run a replay for the loop groups that are dependent on inlined loop + // groups, but themselves are not inlined loop groups. - IdGraphStmtSort iel_stmt_sort2(intersection_exact_loop_graph); - for (auto iel_expr : iel_stmt_sort2.exprs()) { + for (auto iel_expr : IdGraphStmtSort(intersection_exact_loop_graph).exprs()) { auto iel_inp_groups = intersection_exact_loop_graph.inputGroups(iel_expr); auto iel_out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); - // When replaying the transformations a second time we want to take loop - // promotion into consideration. However, we don't want to blindly apply - // loop promotion to all iter domains within a loop group as it would - // replay the transformations within that loop group on the promoted id of - // that loop group. + // When replaying the transformations we can't blindly apply loop promotion + // to all iter domains within a loop group as it would replay the + // transformations within that loop group on the promoted id of that loop + // group. + // + // i.e. if we have the inlined domains from: + // T2[i0*i1] pa(1) = T0[i0*b1]ca(1) + T1[i0*i1]ca(1) + // The inlined loop group would be: + // + // i0, i1, b1, i0*i1, b0*i1 + // Then if we replayed the iel transformations they would be: + // merge(i0, i1) + // merge(i0, b1) // - // Instead only promote an input if the inputs are of a different loop - // group than the outputs. Then we want to promote the inputs to compute - // the output. + // So if we replayed them with loop promotion, then i0, i1, b1 would be + // promoted to i0*i1, and the merges would be replayed. + // + // Therefore only promote i0*b1 to i0*i1, or i0*i1 to i0*i1 (i.e. don't + // promote an input to any transformation within the loop group). + // + // So if we have an iel_expr make sure it's inputs and outputs are not in + // the same loop group. IdGroups inp_loop_groups; for (auto iel_inp_group : iel_inp_groups) { @@ -3101,6 +3109,7 @@ std::unordered_map IterDomainGraphs:: out_loop_groups.pushBack(loop_graph_copy.toGroup(iel_out_group->front())); } + // The inputs should be promoted based on the loop promotion map. bool loop_promote_inputs = !inp_loop_groups.subtract(out_loop_groups).empty(); @@ -3110,7 +3119,12 @@ std::unordered_map IterDomainGraphs:: // Promote inputs for replay for (auto iel_inp_group : iel_inp_groups) { - // Prefer loop promotion + // Promote loops based on the loop promotion map. If the loop promotion + // map should be used and has an entry we should use that promotion. This + // happen when an iel expression is across a loop group boundary. + // Signifying and capturing instances when we traverse across an inlined + // loop group to a non-inlined loop group boundary (think of the iel graph + // projected onto the loop graph). auto loop_copy_group = loop_graph_copy.toGroup(iel_inp_group->front()); auto inp_loop_promo_it = loop_graph_copy_promotion_map.find(loop_copy_group); @@ -3119,6 +3133,10 @@ std::unordered_map IterDomainGraphs:: promoted_inputs.push_back(inp_loop_promo_it->second); an_input_was_promoted = true; } else { + // We still could require an input promotion. We could be traversing + // across non-inlined groups. Meaning we have inputs that were promoted + // in an inlined loop group traversing through the non-inlined portions + // of the iel graph. auto inp_promo_it = iel_promotion_map.find(iel_inp_group); if (inp_promo_it == iel_promotion_map.end()) { promoted_inputs.push_back(iel_inp_group->front()); @@ -3136,33 +3154,39 @@ std::unordered_map IterDomainGraphs:: Expr* replay = nullptr; // Before replaying, check if there's already an expression like this, if so - // use that for promotion. + // use that for promotion. We're still only looking for representative iter + // domains, so if there's already an expression that would produce something + // representative (matching in the exact graph) of what the new inputs would + // generate, just promote to that expressions outputs, don't bother + // generating a new one. + // + // Check all uses of the exact map the inputs are in, and look for one that + // would match. Grab all uses of the promoted inputs' groups in the exact + // map. + std::vector promoted_input_groups; + ExprGroups promoted_input_uses; for (auto inp_id : promoted_inputs) { auto inp_exact_group = idGraph(IdMappingMode::EXACT).toGroup(inp_id); + promoted_input_groups.push_back(inp_exact_group); promoted_input_uses.pushBack( idGraph(IdMappingMode::EXACT).uniqueUses(inp_exact_group)); } + // Check every use to see if it matches for (auto exact_use_group : promoted_input_uses) { - if (transformAtributesMatch( + // Check if all the attributes (including type) of the transform match + if (!transformAtributesMatch( iel_expr->front(), exact_use_group->front())) { - auto exact_use_inps = ir_utils::filterByType( - exact_use_group->front()->inputs()) - .vector(); - bool inps_match = true; - for (auto inp_i : c10::irange(exact_use_inps.size())) { - inps_match = inps_match && - idGraph(IdMappingMode::EXACT) - .disjointIdSets() - .strictAreMapped( - exact_use_inps[inp_i], promoted_inputs[inp_i]); - } - if (inps_match) { - replay = exact_use_group->front(); - break; - } + continue; + } + // Check if inputs all match + if (promoted_input_groups != + idGraph(IdMappingMode::EXACT).inputGroups(exact_use_group)) { + continue; } + replay = exact_use_group->front(); + break; } bool replayed = replay == nullptr; @@ -3177,7 +3201,7 @@ std::unordered_map IterDomainGraphs:: auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); - // Mark outputs as having a promoted iter domain + // Match or replay, mark promotion for output groups. auto replay_out_ids = ir_utils::filterByType(replay->outputs()).vector(); auto ref_out_ids = @@ -3190,12 +3214,16 @@ std::unordered_map IterDomainGraphs:: if (!idGraph(IdMappingMode::EXACT) .disjointIdSets() .strictAreMapped(replay_out_ids[i], output_groups[i]->front())) { + // Promote if necessary, if the output is already in the same exact map + // it doesn't need a promotion. iel_promotion_map[output_groups[i]] = replay_out_ids[i]; // Explicitly map loop map since expr propagation doesn't happen on the // loop map and the replayed outputs are brand new so we can map them // without joining disjoint loop groups (other than the new loop groups // the outputs of the replay are in) if (replayed) { + // If we built new iter domains because we generated a new expression, + // link the outputs in the loop graph. idGraph(IdMappingMode::LOOP) .mapIds(replay_out_ids[i], ref_out_ids[i]); } @@ -3203,14 +3231,6 @@ std::unordered_map IterDomainGraphs:: } } - // ==================================================================================== - // ==================================================================================== - // ==================================================================================== - // ==================================================================================== - // ==================================================================================== - // ==================================================================================== - // ==================================================================================== - std::cout << "Promotion map from second replay: " << std::endl; for (auto group : intersection_exact_loop_graph.disjointIdSets().disjointSets()) { @@ -3221,90 +3241,48 @@ std::unordered_map IterDomainGraphs:: << iel_promotion_map.at(group)->toString() << std::endl; } - // Need to perform some updates after replay - { - intersection_exact_loop_graph = initializeIdGraph(); - intersection_exact_loop_graph.disableExprPropagation(); - for (auto exact_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - auto set_size = exact_group->size(); - for (auto id0_i : c10::irange(set_size)) { - auto id0 = exact_group->vector()[id0_i]; - for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { - auto id1 = exact_group->vector()[id1_i]; - // id0 and id1 map in the almost exact map, if they also map in the - // loop graph, then add the mapping to the inersection - if (idGraph(IdMappingMode::LOOP) - .disjointIdSets() - .strictAreMapped(id0, id1)) { - intersection_exact_loop_graph.mapIds(id0, id1); - } - } - } - } - - // Update iel_promotion_map since we changed the loop map the IdGroup key is - // invalid - std::unordered_map old_iel_promotion_map; - std::swap(iel_promotion_map, old_iel_promotion_map); - for (auto entry : old_iel_promotion_map) { - auto old_iel_group = entry.first; - auto id = entry.second; - iel_promotion_map[intersection_exact_loop_graph.toGroup( - old_iel_group->front())] = id; - } - - exact_covered_ids.clear(); - - for (auto id_group : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - // Initialize inputs - if (idGraph(IdMappingMode::EXACT).uniqueDefinitions(id_group).empty()) { - exact_covered_ids[id_group] = {id_group}; - } - - // Initialize rfactor groups - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return view_rfactor_ids_.find(id) != view_rfactor_ids_.end(); - })) { - exact_covered_ids[id_group] = {id_group}; - } - - // Initialize broadcast groups to empty - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return id->isBroadcast(); - })) { - exact_covered_ids[id_group] = {}; - } - } + return iel_promotion_map; +} - IdGraphStmtSort exact_stmt_sort(idGraph(IdMappingMode::EXACT)); +std::unordered_map IterDomainGraphs::buildIndexGraph( + const std::vector& exprs, + const std::vector& all_tvs, + StatefulLoweringInfo& info, + std::unordered_map stale_promotion_map) { + // Update the iel graph + auto intersection_exact_loop_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); - for (auto exact_expr : exact_stmt_sort.exprs()) { - auto input_groups = idGraph(IdMappingMode::EXACT).inputGroups(exact_expr); + // Update the promotion map + auto iel_promotion_map = + updateMap(stale_promotion_map, intersection_exact_loop_graph); - IdGroups covered; - for (auto inp_group : input_groups) { - covered.pushBack(exact_covered_ids.at(inp_group)); - } + auto exact_covered_ids = + computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); - for (auto output_group : - idGraph(IdMappingMode::EXACT).outputGroups(exact_expr)) { - exact_covered_ids[output_group] = covered; - } - } + // Grab terminal iter domain in the loop groups. + VectorOfUniqueEntries terminal_loop_ids = + computeTerminalLoopIds(info); - // Loop promotion map is to prepare for IterDomain replays. Since these - // replays will modify the loop map, we operate on a copy of the loop map, - // not the original one. + // Loop promotion map is to prepare for IterDomain replays. Since these + // replays will modify the loop map, we operate on a copy of the loop map, + // not the original one. + // Loop promotion map is to prepare for IterDomain replays to resolve + // non-inlined loop groups. Since these replays will modify the loop map as + // we're iterating over the loop map, operate on a copy of the loop map, not + // the original one. + auto loop_graph_copy = idGraph(IdMappingMode::LOOP); - loop_graph_copy = idGraph(IdMappingMode::LOOP); - loop_graph_copy_promotion_map.clear(); - } + // Build a map from loop iter domain group to a promoted iter domain (doesn't + // have to be in the loop group) that covers all the exact groups + // representative of the resolved transformations within the loop group. Only + // the inlined loop groups will be covered here. + std::unordered_map loop_graph_copy_promotion_map; // Returns a new promoted domain if one is found in the iel_promotion_map, // otherwise returns original id. - auto get_promoted_id = [&](IterDomain* id) { + auto get_promoted_id = [&intersection_exact_loop_graph, + &iel_promotion_map](IterDomain* id) { auto iel_group = intersection_exact_loop_graph.toGroup(id); auto iel_promotion_map_it = iel_promotion_map.find(iel_group); if (iel_promotion_map_it != iel_promotion_map.end()) { @@ -3313,7 +3291,8 @@ std::unordered_map IterDomainGraphs:: return id; }; - // Returns the entry in exact_covered_ids associated with provided IterDomain + // Returns the entry in exact_covered_ids associated with provided IterDomain. + // Basically calling .at but with a better error. auto get_covered_exact_groups = [&](IterDomain* id) { auto exact_group = idGraph(IdMappingMode::EXACT).toGroup(id); auto covered_it = exact_covered_ids.find(exact_group); @@ -3324,6 +3303,13 @@ std::unordered_map IterDomainGraphs:: return covered_it->second; }; + // Now we need to find the right promoted ID for every loop group, making + // sure the promoted ID covers every ID of the IDs in the loop group. + // This ID could be a terminal ID in the group. A promoted ID of the terminal + // IDs, or an ID that was replayed previously and now part of the loop group. + // + // The correct/final promoted ID of the loop group must exist at this point. + // It just might not be within the loop group we're looking at. std::cout << "Find promoted ids from loop group or promoted iter domains." << std::endl; for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { @@ -3349,8 +3335,10 @@ std::unordered_map IterDomainGraphs:: // If a promotion entry doesn't exist for a terminal id, put it here. std::vector terminal_ids; + // All exact groups that the terminal loop id's cover. IdGroups all_covered_exact_groups; + // Populate all three structures above. for (auto loop_id : *loop_group) { if (!terminal_loop_ids.has(loop_id)) { continue; @@ -3366,20 +3354,24 @@ std::unordered_map IterDomainGraphs:: } } + // If promoted id's exist, those are the candidates to have the right + // transformations for indexing. Otherwise, use the terminal _ids. auto candidate_ids = promoted_terminal_ids.empty() ? terminal_ids : promoted_terminal_ids; + // Find the loop promotion id from the candidates. IterDomain* loop_promotion_id = nullptr; - for (auto candidate_id : candidate_ids) { if (all_covered_exact_groups .subtract(get_covered_exact_groups(candidate_id)) .empty()) { loop_promotion_id = candidate_id; + break; } } - // Try any replayed IDs if we're still mising the promoted id. + // If we're still missing the loop_promotion_id, check all replayed IDs in + // the loop group. if (loop_promotion_id == nullptr) { candidate_ids = loop_group->subtract(info.ordered_c_ids).vector(); for (auto candidate_id : candidate_ids) { @@ -3406,7 +3398,7 @@ std::unordered_map IterDomainGraphs:: loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; } - std::cout << "Promotion map from concrete id pass: " << std::endl; + std::cout << "Promotion map to build the Index Graph: " << std::endl; for (auto group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_graph_copy_promotion_map.find(group) == loop_graph_copy_promotion_map.end()) { @@ -3440,9 +3432,10 @@ std::unordered_map IterDomainGraphs:: // index math. Therefore, roughly what we need to do is: // - Figure out which leaves share exact indexing and map them together: - // (1) Promoted producer-consumer leaf nodes are almost exact. - // (2) Producer-consumer leaf nodes are inlined with eachother, and they're - // almost exact. + // (1) Producer-consumer leaf nodes are inlined with eachother (map to the + // same promoted id) + // (2) Promoted producer-consumer leaf nodes are almost exact, have the same + // parallel type, but are not inlined. // - Start at the promoted leaf nodes of each tensor view @@ -3459,55 +3452,6 @@ std::unordered_map IterDomainGraphs:: // Mark all iter domains that share a loop nest and are almost exact mapped. // Ignores promotion. - auto index_graph = initializeIdGraph(); - - for (auto expr : exprs) { - // Iter domains in producer that are inlined with consumer iter domains - std::vector producer_inlined_leaves; - - // Copy of all the producer id's for determinism - VectorOfUniqueEntries all_p_ids; - for (auto producer : ir_utils::filterByType(expr->inputs())) { - all_p_ids.insert( - producer->domain()->domain().begin(), - producer->domain()->domain().begin() + - producer->getComputeAtPosition()); - producer_inlined_leaves.insert( - producer_inlined_leaves.end(), - producer->domain()->domain().begin(), - producer->domain()->domain().begin() + - producer->getComputeAtPosition()); - } - - // Grab potentially inlined iter domains in consumers - std::vector consumer_inlined_leaves; - for (auto consumer : ir_utils::filterByType(expr->outputs())) { - consumer_inlined_leaves.insert( - consumer_inlined_leaves.end(), - consumer->domain()->domain().begin(), - consumer->domain()->domain().begin() + - consumer->getMaxProducerPosition()); - } - - // Almost exact map from producer inlined iter domains to all the consumer - // domains they could be inlined into. Build an almost exact map between - // those. - auto p2c_loop_map = - idGraph(IdMappingMode::ALMOSTEXACT) - .buildMapBetween(producer_inlined_leaves, consumer_inlined_leaves); - // Make sure we call mapIds deterministically - for (auto p_id : all_p_ids) { - auto p2c_loop_map_it = p2c_loop_map.find(p_id); - if (p2c_loop_map_it == p2c_loop_map.end()) { - continue; - } - auto c_ids = p2c_loop_map_it->second; - - for (auto c_id : c_ids) { - index_graph.mapIds(p_id, c_id); - } - } - } // Doing the same as above on promoted iter domains is a bit tricky, because // there's a promoted IterDomian per IEL group, we need a promoted IterDomain @@ -3517,7 +3461,8 @@ std::unordered_map IterDomainGraphs:: // TODO: I think we need to validate that for each tensor view leaf domains, // no two leaves within a tensor domain map to another leaf in the same tensor - // domain in the IEL graph. + // domain in the IEL graph. Not sure how this could occur, but I suspect it + // could. // Which non-promoted iter domains, share their promoted iterdomains DisjointSets shared_promoted_id; @@ -3594,18 +3539,15 @@ std::unordered_map IterDomainGraphs:: } auto get_representative_promoted_id = [&](IterDomain* id) { - auto loop_copy_group_pair = loop_graph_copy.disjointIdSet(id); - TORCH_INTERNAL_ASSERT(loop_copy_group_pair.second); - auto loop_copy_group = loop_copy_group_pair.first; - - auto promo_id_it = loop_graph_copy_promotion_map.find(loop_copy_group); + auto promo_id_it = + loop_graph_copy_promotion_map.find(loop_graph_copy.toGroup(id)); TORCH_INTERNAL_ASSERT(promo_id_it != loop_graph_copy_promotion_map.end()); - return promo_id_it->second; }; std::cout << "Opportunistic joining of shared promos:" << std::endl; - // Opportunistically collapse indexing of non-inlined leaf domains + // Opportunistically collapse indexing of non-inlined leaf domains if their + // promoted ids are almost exact mapped and have the same parallel type. for (auto expr : exprs) { for (auto producer : ir_utils::filterByType(expr->inputs())) { std::cout << " Producer: " << producer->toString() << std::endl; @@ -3639,12 +3581,16 @@ std::unordered_map IterDomainGraphs:: std::cout << " " << p_id->toString() << " -> " << rep_p_id->toString() << " :: " << c_id->toString() << " -> " << rep_c_id->toString() << std::endl; - if (idGraph(IdMappingMode::ALMOSTEXACT) - .disjointIdSets() - .strictAreMapped(rep_p_id, rep_c_id)) { - std::cout << " Mapped" << std::endl; - shared_promoted_id.mapEntries(p_id, c_id); + if (!idGraph(IdMappingMode::ALMOSTEXACT) + .disjointIdSets() + .strictAreMapped(rep_p_id, rep_c_id)) { + continue; + } + if (rep_p_id->getParallelType() != rep_c_id->getParallelType()) { + continue; } + std::cout << " Mapped" << std::endl; + shared_promoted_id.mapEntries(p_id, c_id); } } } @@ -3713,24 +3659,11 @@ std::unordered_map IterDomainGraphs:: << leaf_promotion_map.at(id_group->front()) << std::endl; } - // Could pass this into the function, but just using this for now. - auto all_tvs = ir_utils::allTvsOfExprs(exprs); - - idGraph(IdMappingMode::INDEX) = initializeIdGraph(); - idGraph(IdMappingMode::INDEX).mapThroughTrivialExprs(); - idGraph(IdMappingMode::INDEX).removeTrivialExprs(); - // Track every expression required for indexing VectorOfUniqueEntries all_index_exprs; // Track every iter domain required for indexing VectorOfUniqueEntries all_index_ids; - // The almost exact map could have new trivial expression groups from the - // replays, which are expressions that have an input mapped to an output of - // that expression. getExprsBetween protects against these, but they can also - // just be removed. - idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); - std::cout << "\n\nThird and final replay" << std::endl; std::cout << "Building promoted tensor view domains:" << std::endl; // Need to "replay" all of the indexing expressions to make sure roots are @@ -3741,6 +3674,64 @@ std::unordered_map IterDomainGraphs:: // on. auto ae_graph = idGraph(IdMappingMode::ALMOSTEXACT); + // Because of how replays work in buildInlinePromotions and + // buildLoopPromotionMap, we could have multiple uses and definitions of the + // the same iter domain. + // + // However, for the index graph we want to go back to every iter domain having + // at most one use and definition. + // + // We also want to use expressions that exist if we can. + // + // If there's multiple paths on the index graph then we would generate + // conflicting indicies (unless somehow the expressions all end up collapsing + // by being mapped later). Enforce one defintion and use per iter domain. + std::unordered_map id_to_index_use; + std::unordered_map id_to_index_def; + + // Initialize index graph using the history of each tensorview. These + // expressions are not guaranteed to be used, but if it is used, this will + // prefer those used in a tv's history. + // + // This prevents conflicts later where we try to reuse an expression and take + // an expression in another tensor view's history. + for (auto tv : all_tvs) { + auto transforms = StmtSort::getExprsBetween( + FusionGuard::getCurFusion(), + {tv->getRootDomain().begin(), tv->getRootDomain().end()}, + {tv->domain()->domain().begin(), tv->domain()->domain().end()}); + for (auto transform : transforms) { + for (auto inp : ir_utils::filterByType(transform->inputs())) { + id_to_index_use[inp] = transform; + } + for (auto out : + ir_utils::filterByType(transform->outputs())) { + id_to_index_def[out] = transform; + } + } + } + + // Manually initialize the index graph + for (auto id_group : + idGraph(IdMappingMode::ALMOSTEXACT).disjointIdSets().disjointSets()) { + for (auto id : *id_group) { + VectorOfUniqueEntries defs; + if (id_to_index_def.find(id) != id_to_index_def.end()) { + defs.pushBack(id_to_index_def.at(id)); + } + + VectorOfUniqueEntries uses; + if (id_to_index_use.find(id) != id_to_index_use.end()) { + uses.pushBack(id_to_index_use.at(id)); + } + + idGraph(IdMappingMode::INDEX).initializeId(id, defs, uses); + } + } + + idGraph(IdMappingMode::INDEX).mapThroughTrivialExprs(); + idGraph(IdMappingMode::INDEX).removeTrivialExprs(); + for (auto tv : all_tvs) { // We don't have to process inputs at this point as they're already // allocated on a global @@ -3780,7 +3771,7 @@ std::unordered_map IterDomainGraphs:: for (auto tv_id : all_ids) { // Use emplace here as it multiple tv_ids could map to the same ae_group. // Emplace will simply grab the first one that appears. - ae_group_2_id[ae_graph.toGroup(tv_id)] = tv_id; + ae_group_2_id.emplace(std::make_pair(ae_graph.toGroup(tv_id), tv_id)); } // Add the promoted domain ids @@ -3835,6 +3826,16 @@ std::unordered_map IterDomainGraphs:: Expr* replay = nullptr; + // Check if we already have this expression covered in the index graph. If + // so, don't add another expr, just add mappings for the iter domains + // necessary. + + // If there isn't already an index expression covering this, check the + // almost exact map if there's any expression not already in the index + // graph that we can use, and add in the index graph. + + // Else generate a new index expression from scratch. + // Before replaying, check if there's already an expression like this, if // so use that for promotion. ExprGroups promoted_output_defs; @@ -3872,7 +3873,87 @@ std::unordered_map IterDomainGraphs:: continue; } - replay = index_def_group->front(); + // Look for an expression in the group we can reuse. + // + // See comment on definition of id_to_index_use + for (auto maybe_match : *index_def_group) { + VectorOfUniqueEntries input_uses; + for (auto inp : + ir_utils::filterByType(maybe_match->inputs())) { + auto use_it = id_to_index_use.find(inp); + if (use_it == id_to_index_use.end()) { + continue; + } + input_uses.pushBack(use_it->second); + } + + // If there's already a use, make sure it's this use. + if (input_uses.subtract({maybe_match}).size() > 0) { + continue; + } + + VectorOfUniqueEntries output_defs; + for (auto out : + ir_utils::filterByType(maybe_match->outputs())) { + auto def_it = id_to_index_def.find(out); + if (def_it == id_to_index_def.end()) { + continue; + } + output_defs.pushBack(def_it->second); + } + + // If there's already a def, make sure it's this def. + if (output_defs.subtract({maybe_match}).size() > 0) { + continue; + } + + std::vector ae_inps = + ir_utils::filterByType( + ae_expr_group->front()->inputs()) + .vector(); + + auto maybe_match_inputs = + ir_utils::filterByType(maybe_match->inputs()) + .vector(); + + // If there are promoted inputs, we need them to match exactly, + // otherwise we can't reuse this expression. So although replay is not + // nullptr, we may set it back and keep looking. + bool promo_inps_match = true; + for (auto inp_i : c10::irange(maybe_match_inputs.size())) { + auto ae_group_pair = ae_graph.disjointIdSet(ae_inps[inp_i]); + if (ae_group_pair.second && + ae_group_2_id.find(ae_group_pair.first) != + ae_group_2_id.end()) { + auto promo_inp = ae_group_2_id.at(ae_group_pair.first); + if (promo_inp != maybe_match_inputs[inp_i]) { + promo_inps_match = false; + } + } + } + + if (!promo_inps_match) { + continue; + } + + replay = maybe_match; + + for (auto inp : + ir_utils::filterByType(replay->inputs())) { + id_to_index_use[inp] = replay; + } + + for (auto out : + ir_utils::filterByType(replay->outputs())) { + id_to_index_def[out] = replay; + } + break; + } + + // No expression we could use found, keep trying. + if (replay == nullptr) { + continue; + } std::vector ae_inps = ir_utils::filterByType(ae_expr_group->front()->inputs()) @@ -3894,6 +3975,7 @@ std::unordered_map IterDomainGraphs:: } } + // No existing expression could be reused. if (replay == nullptr) { std::vector ae_inps_outs = ir_utils::filterByType(ae_expr_group->front()->inputs()) @@ -3951,12 +4033,18 @@ std::unordered_map IterDomainGraphs:: std::cout << "All indexing expressions (on the index graph): " << std::endl; auto index_expr_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_exprs); - std::cout << debug::toString(idGraph(IdMappingMode::INDEX), index_expr_groups) - << std::endl; - std::cout << "All iter domains (on the index graph): " << std::endl; - auto index_id_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_ids); - std::cout << debug::toString(index_id_groups) << std::endl; + ExprGroups extraneous_expr_groups = + ExprGroups( + idGraph(IdMappingMode::INDEX).disjointExprSets().disjointSets()) + .subtract(index_expr_groups); + for (auto group : extraneous_expr_groups) { + idGraph(IdMappingMode::INDEX).eraseExprGroup(group); + } + + std::cout << "All index graph exprs: " << std::endl; + std::cout << debug::exprGroupsString(idGraph(IdMappingMode::INDEX)) + << std::endl; return {}; } diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index 2067179c3d5..f20db85ab73 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -66,18 +66,6 @@ class TORCH_CUDA_CU_API IdGraph { std::vector outputGroups(ExprGroup expr) const; std::vector inputGroups(ExprGroup expr) const; - // Returns if for each group in id_groups0 is the same as all groups in - // id_groups1. Requires size and order to be exact. - bool groupsMatch( - std::vector id_groups0, - std::vector id_groups1) const; - - // Returns if for each group in expr_groups0 is the same as all groups in - // expr_groups1. Requires size and order to be exact. - bool groupsMatch( - std::vector expr_groups0, - std::vector expr_groups1) const; - // Traverses uses of the IdGroups in 'of' and returns all ExprGroups // that have a use in their definition of provided of IdGroups. ExprGroups allUsesOf(const IdGroups& of) const; @@ -212,11 +200,11 @@ class TORCH_CUDA_CU_API IdGraph { propagate_exprs_ = false; } - private: // Removes the provided expression group from unique_definitions_ and // unique_uses_ breaking traversal through them. void eraseExprGroup(ExprGroup expr_group); + private: // If propagate_exprs_ = false, then mapThroughExpr will not be called as a // consequence of calling mapIds. As well as mapThroughExpr will not be called // (again) as a result of calling mapThroughExpr. @@ -616,6 +604,7 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // corresponding leaf domain in the index graph. std::unordered_map buildIndexGraph( const std::vector& exprs, + const std::vector& all_tvs, StatefulLoweringInfo& info, std::unordered_map stale_promotion_map); @@ -626,6 +615,7 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // have to be resolved by or before the rfactor iter domain. std::unordered_map buildCoveredAlmostExact(); + // TODO: Remove void buildIndexMap(const std::vector& all_tvs); // ======= END Iteration domain build process in order called ======= From 6f682d73a057deefe6be9b5de2724afd14808420 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Fri, 5 May 2023 22:07:58 -0400 Subject: [PATCH 023/178] Merge conflicts. --- csrc/id_graphs.cpp | 24 ++++++++++++------------ csrc/transform_iter.cpp | 36 ++++++++++++------------------------ 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 09e784471d9..132a419d763 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -1654,8 +1654,8 @@ findFirstSelfMapping( } // Leaf domains - auto self_mappped_leaf_pair = detectMappablePair( - tv->domain()->domain(), id_graph, IdMappingMode::LOOP); + auto self_mappped_leaf_pair = + detectMappablePair(tv->domain()->leaf(), id_graph, IdMappingMode::LOOP); if (self_mappped_leaf_pair.has_value()) { return std::make_tuple( tv, @@ -1686,7 +1686,7 @@ void IterDomainGraphs::buildIterDomainDefinitionsAndUses( // If the tensor domain is a view like domain, and the iteration // domain is marked as an rfactor product and is in the rfactor // domain, it's a view like rfactor iteration domain - const auto& rfactor_domain = tv->domain()->getMaybeRFactorDomain(); + const auto& rfactor_domain = tv->domain()->maybeRFactor(); if (std::find(rfactor_domain.begin(), rfactor_domain.end(), id) != rfactor_domain.end()) { view_rfactor_ids_.emplace(id); @@ -2117,7 +2117,7 @@ void IterDomainGraphs::buildExactMap(const std::vector& exprs) { // non-broadcast dimensions. Prevent any broadcasted axes being mapped // to non-broadcasted axes. auto exact_c2p_root_map = - PairwiseRootDomainMap(p_tv, c_tv, true) + PairwiseRootDomainMap(p_tv, c_tv) .mapConsumerToProducer(c_tv->domain(), p_tv->domain()); for (auto c_id : getSortedKeys(exact_c2p_root_map, Statement::lessThan)) { @@ -2206,7 +2206,7 @@ void IterDomainGraphs::validatePTypes( const std::vector& all_tvs) const { VectorOfUniqueEntries leaf_ids; for (auto tv : all_tvs) { - leaf_ids.pushBack(tv->domain()->domain()); + leaf_ids.pushBack(tv->domain()->leaf()); } for (const auto& disjoint_set : @@ -2307,7 +2307,7 @@ StatefulLoweringInfo buildInfo( for (auto expr : exprs) { for (auto producer : ir_utils::filterByType(expr->inputs())) { auto producer_root = producer->getMaybeRFactorDomain(); - auto producer_domain = producer->domain()->domain(); + auto producer_domain = producer->domain()->leaf(); // Grab all iteration domains in producer that its compute at iter domains // depend on. @@ -3477,7 +3477,7 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( VectorOfUniqueEntries all_promo_ids; for (auto producer : ir_utils::filterByType(expr->inputs())) { - for (auto p_id : producer->domain()->domain()) { + for (auto p_id : producer->domain()->leaf()) { // Initialize all entries shared_promoted_id.initializeSet(p_id); @@ -3496,7 +3496,7 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( } for (auto consumer : ir_utils::filterByType(expr->outputs())) { - for (auto c_id : consumer->domain()->domain()) { + for (auto c_id : consumer->domain()->leaf()) { // Initialize all entries shared_promoted_id.initializeSet(c_id); @@ -3553,7 +3553,7 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( std::cout << " Producer: " << producer->toString() << std::endl; auto producer_root = producer->getMaybeRFactorDomain(); - auto non_inline_producer_domain = producer->domain()->domain(); + auto non_inline_producer_domain = producer->domain()->leaf(); non_inline_producer_domain.erase( non_inline_producer_domain.begin(), non_inline_producer_domain.begin() + @@ -3562,7 +3562,7 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( for (auto consumer : ir_utils::filterByType(expr->outputs())) { std::cout << " Consumer: " << consumer->toString() << std::endl; - auto consumer_domain = consumer->domain()->domain(); + auto consumer_domain = consumer->domain()->leaf(); auto p2c_permissive_map = idGraph(IdMappingMode::PERMISSIVE) @@ -3645,7 +3645,7 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( // TODO: This needs to be available as a member function auto get_promoted_domain = [&](TensorDomain* td) { std::vector promoted_leaves; - for (auto id : td->domain()) { + for (auto id : td->leaf()) { auto promo_it = leaf_promotion_map.find(id); TORCH_INTERNAL_ASSERT(promo_it != leaf_promotion_map.end()); promoted_leaves.push_back(promo_it->second); @@ -3699,7 +3699,7 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( auto transforms = StmtSort::getExprsBetween( FusionGuard::getCurFusion(), {tv->getRootDomain().begin(), tv->getRootDomain().end()}, - {tv->domain()->domain().begin(), tv->domain()->domain().end()}); + {tv->domain()->leaf().begin(), tv->domain()->leaf().end()}); for (auto transform : transforms) { for (auto inp : ir_utils::filterByType(transform->inputs())) { id_to_index_use[inp] = transform; diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 13911c759d6..ca41c9606d8 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -924,8 +924,8 @@ ForwardingInfo::ForwardingInfo( std::vector active_tv_history = StmtSort::getExprs( FusionGuard::getCurFusion(), std::vector( - active_tv->domain()->domain().begin(), - active_tv->domain()->domain().end())); + active_tv->domain()->leaf().begin(), + active_tv->domain()->leaf().end())); auto isInForwardIdSet = [&forwarded_ids](IterDomain* input_id) { return forwarded_ids.count(input_id) > 0; @@ -953,34 +953,21 @@ ForwardingInfo::ForwardingInfo( // For the sake of BestEffortReplay we can forward the input mapping // to both the active and inactive tensor to the output of the // expression - std::vector forwarded_ids; + std::vector forwarded_ids_vec; std::vector compliment_ids; - // We have root axes in active_tv that don't exist in the inactive tensor, - // now forward those to include all id's in active_tv comprised of only axes - // not in the inactive tensor. - std::vector active_tv_history = StmtSort::getExprs( - FusionGuard::getCurFusion(), - std::vector( - active_tv->domain()->leaf().begin(), - active_tv->domain()->leaf().end())); - - auto isIdOnlyInActiveTv = [&forwarded_ids](IterDomain* input_id) { - return forwarded_ids.count(input_id) > 0; - }; - - for (auto expr : active_tv_history) { - auto input_ids = ir_utils::filterByType(expr->inputs()); - // If expr inputs are all in forwarded_ids, then so are all outputs - if (std::all_of(input_ids.begin(), input_ids.end(), isIdOnlyInActiveTv)) { - for (auto output_ids : - ir_utils::filterByType(expr->outputs())) { - forwarded_ids.emplace(output_ids); + for (auto input_id : input_ids) { + if (!isInForwardIdSet(input_id)) { + forwarded_ids_vec.emplace_back(input_id); + active_forwarding_map->emplace( + std::make_pair(input_id, merge_expr->out())); + } else { + compliment_ids.push_back(input_id); } } // Set up compliment map - for (auto forwarded_id : forwarded_ids) { + for (auto forwarded_id : forwarded_ids_vec) { active_compliment_map->emplace( std::make_pair(forwarded_id, compliment_ids)); } @@ -1008,6 +995,7 @@ IterDomain* getSwizzleFinalOutput( // This means id is a leaf that doesn't // have any consumers. Stop iteration in this case. if (expr_it == id2expr.end()) { + is_swizzle_input = false; break; } From 05882766bd921d878c7c39fb79b9e55c0438bd4a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 7 May 2023 15:24:43 -0400 Subject: [PATCH 024/178] Swizzle fix. --- csrc/id_graphs.cpp | 234 ++++++++++++++++++++++++--------------------- csrc/id_graphs.h | 5 + 2 files changed, 131 insertions(+), 108 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 132a419d763..04c20259b7c 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -457,7 +457,7 @@ void IdGraphVisitor::traverse() { return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { return expr_group->empty() || visited_exprs.has(expr_group) || - IdGraph::isTrivialExpr(expr_group->front()).size(); + graph().isTrivialExprGroup(expr_group); }); }; @@ -1395,18 +1395,21 @@ bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return true; } +// TODO: Actually assert if self mapping found. Self mapping test is not correct +// yet. void IterDomainGraphs::assertNoSelfMapping() { - TORCH_INTERNAL_ASSERT( - !hasSelfMapping(), - "Unsupported domain mapping detected in ", - std::get<0>(*self_mapping_info_)->toString(), - ". ", - std::get<3>(*self_mapping_info_), - " domains, ", - std::get<1>(*self_mapping_info_)->toString(), - " and ", - std::get<2>(*self_mapping_info_)->toString(), - ", are mapped with each other."); + if (hasSelfMapping()) { + TORCH_WARN( + "IdGraphs thinks there's a self mapping in the problem. It's probably IdGraphs problem, not yours... ", + std::get<0>(*self_mapping_info_)->toString(), + ". ", + std::get<3>(*self_mapping_info_), + " domains, ", + std::get<1>(*self_mapping_info_)->toString(), + " and ", + std::get<2>(*self_mapping_info_)->toString(), + ", are mapped with each other."); + } } void IdGraph::mapThroughTrivialExprs() { @@ -1437,10 +1440,9 @@ void IdGraph::mapThroughTrivialExprs() { void IdGraph::removeTrivialExprs() { ExprGroups trivial_expr_groups; + // This seems like it shouls just be a copy if. for (auto expr_group : disjointExprSets().disjointSets()) { - auto inp_groups = inputGroups(expr_group); - auto out_groups = outputGroups(expr_group); - if (IdGroups(inp_groups).intersect(IdGroups(out_groups)).size()) { + if (isTrivialExprGroup(expr_group)) { trivial_expr_groups.pushBack(expr_group); } } @@ -1457,20 +1459,21 @@ void IdGraph::removeTrivialExprs() { } void IdGraph::mapThroughLoopSwizzles() { - for (auto use_pairs : unique_uses_) { - auto use_groups = use_pairs.second; - for (auto use_group : use_groups) { - for (auto use : *use_group) { - if (auto swizzle_2d = dynamic_cast(use)) { - // Map each input to its corresponding output on the given - // disjoint set if this is a loop swizzle. Loop swizzles don't impact - // indexing, only iteration order. - if (swizzle_2d->swizzleMode() == SwizzleMode::Loop) { - mapIds(swizzle_2d->inX(), swizzle_2d->outX()); - mapIds(swizzle_2d->inY(), swizzle_2d->outY()); - } - } - } + std::vector all_swizzles; + + for (auto expr_set : disjointExprSets().disjointSets()) { + auto swizzles_in_expr_set = ir_utils::filterByType( + expr_set->vector().begin(), expr_set->vector().end()); + all_swizzles.insert( + all_swizzles.end(), + swizzles_in_expr_set.begin(), + swizzles_in_expr_set.end()); + } + + for (auto swizzle : all_swizzles) { + if (swizzle->swizzleMode() == SwizzleMode::Loop) { + mapIds(swizzle->inX(), swizzle->outX()); + mapIds(swizzle->inY(), swizzle->outY()); } } } @@ -1499,6 +1502,12 @@ void IdGraph::eraseExprGroup(ExprGroup expr_group) { } } +bool IdGraph::isTrivialExprGroup(ExprGroup expr_group) const { + return !IdGroups(inputGroups(expr_group)) + .intersect(IdGroups(outputGroups(expr_group))) + .empty(); +} + IterDomainGraphs::IterDomainGraphs( const std::vector& exprs, const std::vector& additional_tvs, @@ -2150,9 +2159,9 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { ForwardingInfo permissive_forwarding(p_tv, c_tv); for (auto entry : permissive_forwarding.producer_forwarding_map) { - std::cout << "Permissive producer forwarding: " - << entry.first->toString() << " -> " - << entry.second->toString() << std::endl; + // std::cout << "Permissive producer forwarding: " + // << entry.first->toString() << " -> " + // << entry.second->toString() << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); } @@ -2160,17 +2169,18 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { - std::cout << "Permissive producer compliment: " - << entry.first->toString() << " -> " << entry_2->toString() - << std::endl; + // std::cout << "Permissive producer compliment: " + // << entry.first->toString() << " -> " << + // entry_2->toString() + // << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); } } for (auto entry : permissive_forwarding.consumer_forwarding_map) { - std::cout << "Permissive consumer forwarding: " - << entry.first->toString() << " -> " - << entry.second->toString() << std::endl; + // std::cout << "Permissive consumer forwarding: " + // << entry.first->toString() << " -> " + // << entry.second->toString() << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); } @@ -2178,9 +2188,10 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.consumer_compliment_map) { for (auto entry_2 : entry.second) { - std::cout << "Permissive consumer compliment: " - << entry.first->toString() << " -> " << entry_2->toString() - << std::endl; + // std::cout << "Permissive consumer compliment: " + // << entry.first->toString() << " -> " << + // entry_2->toString() + // << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); } } @@ -2322,8 +2333,8 @@ StatefulLoweringInfo buildInfo( all_producer_ca_deps.insert( ca_deps_filter.begin(), ca_deps_filter.end()); } - std::cout << "Producer: " << producer->toString() << "\n " - << all_producer_ca_deps.toString() << std::endl; + // std::cout << "Producer: " << producer->toString() << "\n " + // << all_producer_ca_deps.toString() << std::endl; info.ordered_p_ca_ids.pushBack(all_producer_ca_deps); @@ -2427,7 +2438,7 @@ void IterDomainGraphs::build( if (FusionGuard::getCurFusion()->isA()) { validatePTypes(all_tvs); - FusionGuard::getCurFusion()->print(std::cout, true); + // FusionGuard::getCurFusion()->print(std::cout, true); StatefulLoweringInfo info = buildInfo( tv_exprs, @@ -2435,21 +2446,21 @@ void IterDomainGraphs::build( idGraph(IdMappingMode::PERMISSIVE)); initializeLoopMap(info); - std::cout << "Loop groups: " - << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) - << std::endl; + // std::cout << "Loop groups: " + // << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) + // << std::endl; - std::cout << "Promoted groups: " - << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) - << std::endl; + // std::cout << "Promoted groups: " + // << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) + // << std::endl; // Initial propagation of parallel types for inlined iter domains. Each time // new expressions are replayed this needs to be run. The disjoint sets in // the loop graph can only be joined after this point. - propagateLoopPTypes(); + // propagateLoopPTypes(); auto iel_promotion_map = buildInlinePromotions(info); - propagateLoopPTypes(); + // propagateLoopPTypes(); // Find loops that need to be promoted because of broadcast resolution, // figure out what that resolution should look like, compute IDs for it if @@ -2458,8 +2469,11 @@ void IterDomainGraphs::build( buildLoopPromotionMap(tv_exprs, info, iel_promotion_map); // Loop map potentialy changed changed, as we could have replayed // expressions. Re-propagate parallel types. - propagateLoopPTypes(); + // propagateLoopPTypes(); + // This pass still doesn't work, disable for now in case it's disruptive to + // tests. + /* // Find loops that need to be promoted because of broadcast resolution, // figure out what that resolution should look like, compute IDs for it if // necessary. @@ -2467,6 +2481,7 @@ void IterDomainGraphs::build( buildIndexGraph(tv_exprs, all_tvs, info, iel_promotion_map); // Make sure we update ptypes onto the index leaf iter domains propagateLoopPTypes(); + */ } // Debug, make sure there's no self mapping in TensorView's during lowering @@ -2685,7 +2700,7 @@ std::unordered_map IterDomainGraphs:: iel_promotion_map[iel_group] = promoted_iel_groups.front()->front(); } - std::cout << "Initial promotion map:" << std::endl; + // std::cout << "Initial promotion map:" << std::endl; for (auto iel_group : intersection_exact_loop_graph.disjointIdSets().disjointSets()) { @@ -2693,13 +2708,13 @@ std::unordered_map IterDomainGraphs:: if (entry_it == iel_promotion_map.end()) { continue; } - std::cout << " " << entry_it->second->toString() << " <- " - << entry_it->first->toString() << std::endl; + // std::cout << " " << entry_it->second->toString() << " <- " + // << entry_it->first->toString() << std::endl; } IdGraphStmtSort iel_stmt_sort(intersection_exact_loop_graph); - std::cout << "Initial promotion replay:" << std::endl; + // std::cout << "Initial promotion replay:" << std::endl; for (auto iel_expr : iel_stmt_sort.exprs()) { auto input_groups = intersection_exact_loop_graph.inputGroups(iel_expr); @@ -2770,8 +2785,8 @@ std::unordered_map IterDomainGraphs:: bool replayed = replay == nullptr; if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); - std::cout << " ***REPLAY***:\n " << iel_expr->front() - << " As:" << replay->toString(); + // std::cout << " ***REPLAY***:\n " << iel_expr->front() + // << " As:" << replay->toString(); } auto out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -3053,13 +3068,13 @@ std::unordered_map IterDomainGraphs:: } } - std::cout << "Loop promotion before second replay:" << std::endl; + // std::cout << "Loop promotion before second replay:" << std::endl; for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_graph_copy_promotion_map.find(loop_group) != loop_graph_copy_promotion_map.end()) { - std::cout << debug::toString(loop_group, 0, true) << " -> " - << loop_graph_copy_promotion_map[loop_group]->toString() - << std::endl; + // std::cout << debug::toString(loop_group, 0, true) << " -> " + // << loop_graph_copy_promotion_map[loop_group]->toString() + // << std::endl; } } @@ -3192,12 +3207,13 @@ std::unordered_map IterDomainGraphs:: bool replayed = replay == nullptr; if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); - std::cout << " ***REPLAY2***:\n " << iel_expr->front() - << " As:" << replay->toString(); - } else { - std::cout << " ***MATCH2***:\n " << iel_expr->front() - << " As:" << replay->toString(); } + // std::cout << " ***REPLAY2***:\n " << iel_expr->front() + // << " As:" << replay->toString(); + // } else { + // std::cout << " ***MATCH2***:\n " << iel_expr->front() + // << " As:" << replay->toString(); + // } auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -3231,14 +3247,14 @@ std::unordered_map IterDomainGraphs:: } } - std::cout << "Promotion map from second replay: " << std::endl; + // std::cout << "Promotion map from second replay: " << std::endl; for (auto group : intersection_exact_loop_graph.disjointIdSets().disjointSets()) { if (iel_promotion_map.find(group) == iel_promotion_map.end()) { continue; } - std::cout << debug::toString(group, 0, true) << " -> " - << iel_promotion_map.at(group)->toString() << std::endl; + // std::cout << debug::toString(group, 0, true) << " -> " + // << iel_promotion_map.at(group)->toString() << std::endl; } return iel_promotion_map; @@ -3310,8 +3326,8 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( // // The correct/final promoted ID of the loop group must exist at this point. // It just might not be within the loop group we're looking at. - std::cout << "Find promoted ids from loop group or promoted iter domains." - << std::endl; + // std::cout << "Find promoted ids from loop group or promoted iter domains." + // << std::endl; for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_group->size() == 1) { auto promoted_id = get_promoted_id(loop_group->front()); @@ -3398,15 +3414,15 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; } - std::cout << "Promotion map to build the Index Graph: " << std::endl; + // std::cout << "Promotion map to build the Index Graph: " << std::endl; for (auto group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_graph_copy_promotion_map.find(group) == loop_graph_copy_promotion_map.end()) { continue; } - std::cout << debug::toString(group, 0, true) << " -> " - << loop_graph_copy_promotion_map.at(group)->toString() - << std::endl; + // std::cout << debug::toString(group, 0, true) << " -> " + // << loop_graph_copy_promotion_map.at(group)->toString() + // << std::endl; } // Indexing traversal must start at leaf nodes of TensorViews as that's where @@ -3545,12 +3561,12 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( return promo_id_it->second; }; - std::cout << "Opportunistic joining of shared promos:" << std::endl; + // std::cout << "Opportunistic joining of shared promos:" << std::endl; // Opportunistically collapse indexing of non-inlined leaf domains if their // promoted ids are almost exact mapped and have the same parallel type. for (auto expr : exprs) { for (auto producer : ir_utils::filterByType(expr->inputs())) { - std::cout << " Producer: " << producer->toString() << std::endl; + // std::cout << " Producer: " << producer->toString() << std::endl; auto producer_root = producer->getMaybeRFactorDomain(); auto non_inline_producer_domain = producer->domain()->leaf(); @@ -3561,7 +3577,7 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( for (auto consumer : ir_utils::filterByType(expr->outputs())) { - std::cout << " Consumer: " << consumer->toString() << std::endl; + // std::cout << " Consumer: " << consumer->toString() << std::endl; auto consumer_domain = consumer->domain()->leaf(); auto p2c_permissive_map = @@ -3578,9 +3594,9 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( auto c_id = p2c_it->second.front(); auto rep_c_id = get_representative_promoted_id(c_id); - std::cout << " " << p_id->toString() << " -> " - << rep_p_id->toString() << " :: " << c_id->toString() - << " -> " << rep_c_id->toString() << std::endl; + // std::cout << " " << p_id->toString() << " -> " + // << rep_p_id->toString() << " :: " << c_id->toString() + // << " -> " << rep_c_id->toString() << std::endl; if (!idGraph(IdMappingMode::ALMOSTEXACT) .disjointIdSets() .strictAreMapped(rep_p_id, rep_c_id)) { @@ -3589,18 +3605,18 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( if (rep_p_id->getParallelType() != rep_c_id->getParallelType()) { continue; } - std::cout << " Mapped" << std::endl; + // std::cout << " Mapped" << std::endl; shared_promoted_id.mapEntries(p_id, c_id); } } } } - std::cout << "Leaf iter domains that share a promoted iter domain." - << std::endl; - for (auto disjoint_set : shared_promoted_id.disjointSets()) { - std::cout << disjoint_set->toString() << std::endl; - } + // std::cout << "Leaf iter domains that share a promoted iter domain." + // << std::endl; + // for (auto disjoint_set : shared_promoted_id.disjointSets()) { + // std::cout << disjoint_set->toString() << std::endl; + // } // Map from leaf iter domains to their potentially promoted iter domain used // for indexing. @@ -3653,19 +3669,19 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( return promoted_leaves; }; - std::cout << "Iter domain group to their promoted iter domain." << std::endl; - for (auto id_group : shared_promoted_id.disjointSets()) { - std::cout << id_group->toString() << "\n -> " - << leaf_promotion_map.at(id_group->front()) << std::endl; - } + // std::cout << "Iter domain group to their promoted iter domain." << + // std::endl; for (auto id_group : shared_promoted_id.disjointSets()) { + // std::cout << id_group->toString() << "\n -> " + // << leaf_promotion_map.at(id_group->front()) << std::endl; + // } // Track every expression required for indexing VectorOfUniqueEntries all_index_exprs; // Track every iter domain required for indexing VectorOfUniqueEntries all_index_ids; - std::cout << "\n\nThird and final replay" << std::endl; - std::cout << "Building promoted tensor view domains:" << std::endl; + // std::cout << "\n\nThird and final replay" << std::endl; + // std::cout << "Building promoted tensor view domains:" << std::endl; // Need to "replay" all of the indexing expressions to make sure roots are // connected to the promoted leaves, in a way we can index directly on the // index graph. @@ -3741,9 +3757,10 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( auto promoted_domain = get_promoted_domain(tv->domain()); // replay from root to promoted leaves. - std::cout << "\n\n Processing: TV" << tv->name() << "\n Root: TV" - << tv->getRootDomain() - << "\n Domain promoted to: " << promoted_domain << std::endl; + // std::cout << "\n\n Processing: TV" << tv->name() << "\n Root: TV" + // << tv->getRootDomain() + // << "\n Domain promoted to: " << promoted_domain << + // std::endl; // The promoted leaf iter domains are where indexing starts. We're going to // start at those expressions and replay transformations for this tensor @@ -3997,13 +4014,13 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( replay = addExprWithReplacement(replacement_map, ae_expr_group->front()); - std::cout << " ***REPLAY3***:\n " - << ae_expr_group->front()->toString() - << " As:" << replay->toString(); + // std::cout << " ***REPLAY3***:\n " + // << ae_expr_group->front()->toString() + // << " As:" << replay->toString(); } else { - std::cout << " ***MATCH3***:\n " - << " " << replay->toString(); + // std::cout << " ***MATCH3***:\n " + // << " " << replay->toString(); } all_index_exprs.pushBack(replay); @@ -4030,7 +4047,8 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( } } - std::cout << "All indexing expressions (on the index graph): " << std::endl; + // std::cout << "All indexing expressions (on the index graph): " << + // std::endl; auto index_expr_groups = idGraph(IdMappingMode::INDEX).toGroups(all_index_exprs); @@ -4042,9 +4060,9 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( idGraph(IdMappingMode::INDEX).eraseExprGroup(group); } - std::cout << "All index graph exprs: " << std::endl; - std::cout << debug::exprGroupsString(idGraph(IdMappingMode::INDEX)) - << std::endl; + // std::cout << "All index graph exprs: " << std::endl; + // std::cout << debug::exprGroupsString(idGraph(IdMappingMode::INDEX)) + // << std::endl; return {}; } diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index f20db85ab73..46fbb31bf8d 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -204,6 +204,11 @@ class TORCH_CUDA_CU_API IdGraph { // unique_uses_ breaking traversal through them. void eraseExprGroup(ExprGroup expr_group); + // Returns if the expression group has an input id group that matches an + // output id group. This means traversing on this expression doesn't actually + // do anything. + bool isTrivialExprGroup(ExprGroup expr_group) const; + private: // If propagate_exprs_ = false, then mapThroughExpr will not be called as a // consequence of calling mapIds. As well as mapThroughExpr will not be called From 81ba29983b7db3557c171786ce4401a13915b712 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 8 May 2023 11:25:30 -0400 Subject: [PATCH 025/178] Most tests passing, only those looking for exact code match fail, or new tests added. --- csrc/disjoint_set.h | 9 +++--- csrc/id_graphs.cpp | 72 +++++++++++++++++++++++++++++++++++---------- csrc/id_graphs.h | 2 +- 3 files changed, 62 insertions(+), 21 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 03994acbd75..2ca43be19e1 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -334,10 +334,6 @@ class DisjointSets { // belonging to entry0, maps all entries of disjoint set belonging to entry1 // to entry0, removes original disjoint set belonging to entry1. void mapEntries(T entry0, T entry1) { - if (entry0 == entry1) { - return; - } - auto set_it_0 = disjoint_set_maps_.find(entry0); auto set_it_1 = disjoint_set_maps_.find(entry1); @@ -371,6 +367,11 @@ class DisjointSets { disjoint_set_maps_[entry0] = new_set; } + // This should be after we enter a new set in case it doesn't exist. + if (entry0 == entry1) { + return; + } + if (set_1_found) { auto set_1 = set_it_1->second; for (auto set_1_entry : *set_1) { diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 04c20259b7c..3f150d3cc60 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -341,7 +341,8 @@ bool transformAtributesMatch(Expr* first, Expr* second) { } TORCH_INTERNAL_ASSERT( - first->isA() || first->isA() || first->isA(), + first->isA() || first->isA() || first->isA() || + first->isA(), "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n", first->toString()); @@ -1217,6 +1218,21 @@ bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { } } + // TODO: For now we're using same as, however we could know what val's are + // exactly the same given the exact map. We might want to pipe that + // information through to here. + if (first->isA()) { + if (!first->as()->leftExpand()->sameAs( + second->as()->leftExpand())) { + return false; + } + + if (!first->as()->rightExpand()->sameAs( + second->as()->rightExpand())) { + return false; + } + } + return true; } @@ -1768,7 +1784,7 @@ std::string IterDomainGraphs::toString() const { // Replay Expr but with the inputs provided. Expr* IterDomainGraphs::addReplayAs( - const std::vector& new_inputs, + std::vector new_inputs, Expr* expr) { // Figure out which graphs are already initialized to make sure we add the new // expression to them. @@ -1791,6 +1807,28 @@ Expr* IterDomainGraphs::addReplayAs( std::vector orig_input_ids( orig_inputs.begin(), orig_inputs.end()); + if (std::any_of( + new_inputs.begin(), + new_inputs.end(), + [](IterDomain* id) { return id->isReduction(); }) && + std::any_of(new_inputs.begin(), new_inputs.end(), [](IterDomain* id) { + return !id->isReduction(); + })) { + // Inputs have mismatched type, replace new_inputs + decltype(new_inputs) tmp_inputs; + std::swap(tmp_inputs, new_inputs); + for (auto tmp_input : tmp_inputs) { + new_inputs.push_back( + IterDomainBuilder(tmp_input).iter_type(IterType::Iteration).build()); + id_definitions_[new_inputs.back()]; + id_uses_[new_inputs.back()]; + for (auto mode : initialized_modes) { + idGraph(mode).initializeId(new_inputs.back(), {}, {}); + idGraph(mode).mapIds(new_inputs.back(), tmp_input); + } + } + } + { TORCH_INTERNAL_ASSERT( new_inputs.size() == orig_input_ids.size(), @@ -2213,24 +2251,26 @@ void IterDomainGraphs::buildAlmostExactMap() { idGraph(IdMappingMode::ALMOSTEXACT).mapThroughTrivialExprs(); } +// TODO: Reenable after reenabling parallel propagation. +// propagateLoopPTypes void IterDomainGraphs::validatePTypes( const std::vector& all_tvs) const { - VectorOfUniqueEntries leaf_ids; - for (auto tv : all_tvs) { - leaf_ids.pushBack(tv->domain()->leaf()); - } + // VectorOfUniqueEntries leaf_ids; + // for (auto tv : all_tvs) { + // leaf_ids.pushBack(tv->domain()->leaf()); + // } - for (const auto& disjoint_set : - idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { - for (auto id : disjoint_set->vector()) { - auto id_ptype = id->getParallelType(); + // for (const auto& disjoint_set : + // idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + // for (auto id : disjoint_set->vector()) { + // auto id_ptype = id->getParallelType(); - TORCH_INTERNAL_ASSERT( - leaf_ids.has(id) || id_ptype == ParallelType::Serial, - "Invalid parallelization of non leaf iter domain: ", - id->toString()); - } - } + // TORCH_INTERNAL_ASSERT( + // leaf_ids.has(id) || id_ptype == ParallelType::Serial, + // "Invalid parallelization of non leaf iter domain: ", + // id->toString()); + // } + // } } void IterDomainGraphs::propagateLoopPTypes() const { diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h index 46fbb31bf8d..37661c2e019 100644 --- a/csrc/id_graphs.h +++ b/csrc/id_graphs.h @@ -497,7 +497,7 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // Replay Expr but with the inputs provided. IterDomainGraphss will be updated // for all maps that have entries, adding the output iter domains of the // replayed expression and adding potential mappings through the expression. - Expr* addReplayAs(const std::vector& new_inputs, Expr* expr); + Expr* addReplayAs(std::vector new_inputs, Expr* expr); // Similar to addReplayAs, but clones the expr exactly instead of replaying it // forward. It's up to the calling code to make sure the replacements are From 4db481b64b36c6f64a6ffdd6b5e517dd0d957c76 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 9 May 2023 09:12:06 -0400 Subject: [PATCH 026/178] Project CI Green --- test/test_gpu_indexing.cpp | 33 +++++++++++--- test/test_loop_rotation.cpp | 86 ++++++++++++++++++------------------- 2 files changed, 71 insertions(+), 48 deletions(-) diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index c6673f8d06f..d51af7d1cc2 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -832,6 +832,7 @@ TEST_F(NVFuserTest, FusionIndexing18_CUDA) { } // TODO: Finish and enable test +#if 0 // // Create a case where we're missing a valid concrete id so the compute at map // processing will fail. We need to be able to create the concrete ID not just @@ -880,7 +881,9 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { fusion.print(); fusion.printKernel(); } +#endif +#if 0 // TODO: Finish and enable test // // Progressive loop promotion. producer gets promoted in consumer, consumer is @@ -931,6 +934,7 @@ TEST_F(NVFuserTest, FusionIndexing20_CUDA) { fusion.printKernel(); } +#endif // Repro for issue #1873 TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { @@ -972,9 +976,6 @@ TEST_F(NVFuserTest, FusionMultiPromotion_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - int w = 3, x = 4, y = 7, z = 8; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - // [y] auto tv0 = makeSymbolicTensor(1); // [w, x, y, z] @@ -1000,6 +1001,9 @@ TEST_F(NVFuserTest, FusionMultiPromotion_CUDA) { FusionExecutor fe; + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({y}, options); at::Tensor t1 = at::randn({w, x, y, z}, options); @@ -1015,6 +1019,8 @@ TEST_F(NVFuserTest, FusionMultiPromotion_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } +#if 0 +// TODO: Finish and enable test. // Broadcast and concretize same domain in two different ways and try to merge // their loops remains unsupported. TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { @@ -1056,6 +1062,7 @@ TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } +#endif // TODO: All the above tests are merges followed by splits, we should make some // more complex examples even though merging then spliting is the most likely @@ -1086,8 +1093,24 @@ TEST_F(NVFuserTest, FusionIndexSplitMerge_CUDA) { MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); inlineAllAt(tv3, 2, false); - fusion.printKernel(); -} + FusionExecutor fe; + + int x = 4, y = 7; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({x}, options); + at::Tensor t1 = at::randn({x, y}, options); + + auto t2 = t0.unsqueeze(-1); + auto aten_output = t1.add(t2); + + std::vector aten_inputs = {t0, t1}; + + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} } // namespace nvfuser diff --git a/test/test_loop_rotation.cpp b/test/test_loop_rotation.cpp index b4163426fc1..266d4e7dabb 100644 --- a/test/test_loop_rotation.cpp +++ b/test/test_loop_rotation.cpp @@ -36,11 +36,11 @@ TEST_F(LoopRotationTest, RotateInner_CUDA) { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_DEFINE_MAGIC_ZERO #pragma unroll 1 - for(nvfuser_index_t i21 = 0; i21 < T0.size[0]; ++i21) { + for(nvfuser_index_t i22 = 0; i22 < T0.size[0]; ++i22) { int64_t i52; - i52 = T0.stride[0] * i21; + i52 = T0.stride[0] * i22; int64_t i84; - i84 = 3 * i21; + i84 = 3 * i22; float T1[1]; float T2[1]; T1[0] = 0; @@ -50,13 +50,13 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { = T1[0]; NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll - for(nvfuser_index_t i22 = 0; i22 < 3; ++i22) { + for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) { int64_t i111; - i111 = (1 + i22) + nvfuser_zero; + i111 = (1 + i21) + nvfuser_zero; float T3[1]; T3[0] = T2[0]; - T4[(i84 + (i22 + nvfuser_zero))] + T4[(i84 + (i21 + nvfuser_zero))] = T3[0]; T1[0] = 0; if ((i111 < 3)) { @@ -125,13 +125,13 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll 1 - for(nvfuser_index_t i24 = 0; i24 < T0.size[0]; ++i24) { + for(nvfuser_index_t i25 = 0; i25 < T0.size[0]; ++i25) { int64_t i90; - i90 = 3 * i24; + i90 = 3 * i25; int64_t i123; - i123 = T0.stride[0] + (T0.stride[0] * i24); + i123 = T0.stride[0] + (T0.stride[0] * i25); bool b225; - b225 = (1 + i24) < T0.size[0]; + b225 = (1 + i25) < T0.size[0]; // Alias Allocation - register auto& T3 = T1; #pragma unroll @@ -141,9 +141,9 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll - for(nvfuser_index_t i25 = 0; i25 < 3; ++i25) { - T4[(i90 + (i25 + nvfuser_zero))] - = T3[i25]; + for(nvfuser_index_t i24 = 0; i24 < 3; ++i24) { + T4[(i90 + (i24 + nvfuser_zero))] + = T3[i24]; } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll @@ -228,9 +228,9 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll 1 - for(nvfuser_index_t i39 = 0; i39 < (ceilDiv((T0.size[0] * T0.size[1]), 5)); ++i39) { + for(nvfuser_index_t i40 = 0; i40 < (ceilDiv((T0.size[0] * T0.size[1]), 5)); ++i40) { int64_t i238; - i238 = 5 * i39; + i238 = 5 * i40; int64_t i936; i936 = 5 + i238; // Alias Allocation - register @@ -242,12 +242,12 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll - for(nvfuser_index_t i40 = 0; i40 < 5; ++i40) { + for(nvfuser_index_t i39 = 0; i39 < 5; ++i39) { int64_t i239; - i239 = i238 + (i40 + nvfuser_zero); + i239 = i238 + (i39 + nvfuser_zero); if ((i239 < i1117)) { T4[i239] - = T3[i40]; + = T3[i39]; } } NVFUSER_UPDATE_MAGIC_ZERO @@ -310,13 +310,13 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { i199 = T0.stride[0] * 4; float T1[15]; #pragma unroll - for(nvfuser_index_t i24 = 0; i24 < 4; ++i24) { + for(nvfuser_index_t i25 = 0; i25 < 4; ++i25) { int64_t i62; - i62 = 3 * i24; + i62 = 3 * i25; int64_t i79; - i79 = T0.stride[0] * i24; + i79 = T0.stride[0] * i25; bool b377; - b377 = (i24 + nvfuser_zero) < T0.size[0]; + b377 = (i25 + nvfuser_zero) < T0.size[0]; #pragma unroll for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) { T1[(i62 + i21)] = 0; @@ -338,17 +338,17 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll 1 - for(nvfuser_index_t i25 = 0; i25 < T0.size[0]; ++i25) { + for(nvfuser_index_t i26 = 0; i26 < T0.size[0]; ++i26) { int64_t i169; - i169 = 4 + i25; + i169 = 4 + i26; int64_t i171; i171 = 3 * (i169 % 5); int64_t i201; - i201 = i199 + (T0.stride[0] * i25); + i201 = i199 + (T0.stride[0] * i26); int64_t i298; - i298 = 3 * i25; + i298 = 3 * i26; int64_t i365; - i365 = 3 * ((1 + i25) % 5); + i365 = 3 * ((1 + i26) % 5); bool b440; b440 = i169 < T0.size[0]; #pragma unroll @@ -372,9 +372,9 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll - for(nvfuser_index_t i27 = 0; i27 < 3; ++i27) { - T4[(i298 + (i27 + nvfuser_zero))] - = T3[i27]; + for(nvfuser_index_t i24 = 0; i24 < 3; ++i24) { + T4[(i298 + (i24 + nvfuser_zero))] + = T3[i24]; } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll @@ -440,13 +440,13 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll - for(nvfuser_index_t i24 = 0; i24 < 4; ++i24) { + for(nvfuser_index_t i25 = 0; i25 < 4; ++i25) { int64_t i87; - i87 = 3 + (3 * i24); + i87 = 3 + (3 * i25); int64_t i116; - i116 = T0.stride[0] + (T0.stride[0] * i24); + i116 = T0.stride[0] + (T0.stride[0] * i25); bool b575; - b575 = ((1 + i24) + nvfuser_zero) < T0.size[0]; + b575 = ((1 + i25) + nvfuser_zero) < T0.size[0]; #pragma unroll for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) { T1[(i87 + i21)] = 0; @@ -481,17 +481,17 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll 1 - for(nvfuser_index_t i25 = 0; i25 < T0.size[0]; ++i25) { + for(nvfuser_index_t i26 = 0; i26 < T0.size[0]; ++i26) { int64_t i217; - i217 = 3 * i25; + i217 = 3 * i26; int64_t i304; - i304 = 3 * (i25 % 5); + i304 = 3 * (i26 % 5); int64_t i342; - i342 = i340 + (T0.stride[0] * i25); + i342 = i340 + (T0.stride[0] * i26); int64_t i498; - i498 = 3 * ((1 + i25) % 5); + i498 = 3 * ((1 + i26) % 5); bool b680; - b680 = (5 + i25) < T0.size[0]; + b680 = (5 + i26) < T0.size[0]; float T3[3]; #pragma unroll for(nvfuser_index_t i23 = 0; i23 < 3; ++i23) { @@ -500,9 +500,9 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll - for(nvfuser_index_t i27 = 0; i27 < 3; ++i27) { - T4[(i217 + (i27 + nvfuser_zero))] - = T3[i27]; + for(nvfuser_index_t i24 = 0; i24 < 3; ++i24) { + T4[(i217 + (i24 + nvfuser_zero))] + = T3[i24]; } NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll From ad7012b486a44ff4646e70d7bfa52346383d8fb3 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 9 May 2023 18:29:27 -0400 Subject: [PATCH 027/178] Reenable self mapping. --- csrc/id_graphs.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index 3f150d3cc60..aff8749d115 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -1411,12 +1411,11 @@ bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return true; } -// TODO: Actually assert if self mapping found. Self mapping test is not correct -// yet. void IterDomainGraphs::assertNoSelfMapping() { if (hasSelfMapping()) { - TORCH_WARN( - "IdGraphs thinks there's a self mapping in the problem. It's probably IdGraphs problem, not yours... ", + TORCH_INTERNAL_ASSERT( + !hasSelfMapping(), + "Unsupported domain mapping detected in ", std::get<0>(*self_mapping_info_)->toString(), ". ", std::get<3>(*self_mapping_info_), @@ -1679,8 +1678,11 @@ findFirstSelfMapping( } // Leaf domains + // TODO: Exact map isn't quite right here, it should be based on the index + // map. However, it should also be impossible for index map to generate a + // case like this. auto self_mappped_leaf_pair = - detectMappablePair(tv->domain()->leaf(), id_graph, IdMappingMode::LOOP); + detectMappablePair(tv->domain()->leaf(), id_graph, IdMappingMode::EXACT); if (self_mappped_leaf_pair.has_value()) { return std::make_tuple( tv, From a9d192eff1432a03daeda567d78c6d4f4672e2da Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 9 May 2023 18:32:05 -0400 Subject: [PATCH 028/178] Cleanup. --- csrc/id_graphs.cpp | 850 +-------------------------------------------- 1 file changed, 3 insertions(+), 847 deletions(-) diff --git a/csrc/id_graphs.cpp b/csrc/id_graphs.cpp index aff8749d115..bc315b219ec 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_graphs.cpp @@ -1681,8 +1681,8 @@ findFirstSelfMapping( // TODO: Exact map isn't quite right here, it should be based on the index // map. However, it should also be impossible for index map to generate a // case like this. - auto self_mappped_leaf_pair = - detectMappablePair(tv->domain()->leaf(), id_graph, IdMappingMode::EXACT); + auto self_mappped_leaf_pair = detectMappablePair( + tv->domain()->leaf(), id_graph, IdMappingMode::EXACT); if (self_mappped_leaf_pair.has_value()) { return std::make_tuple( tv, @@ -2199,9 +2199,6 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { ForwardingInfo permissive_forwarding(p_tv, c_tv); for (auto entry : permissive_forwarding.producer_forwarding_map) { - // std::cout << "Permissive producer forwarding: " - // << entry.first->toString() << " -> " - // << entry.second->toString() << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); } @@ -2209,18 +2206,11 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { - // std::cout << "Permissive producer compliment: " - // << entry.first->toString() << " -> " << - // entry_2->toString() - // << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); } } for (auto entry : permissive_forwarding.consumer_forwarding_map) { - // std::cout << "Permissive consumer forwarding: " - // << entry.first->toString() << " -> " - // << entry.second->toString() << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); } @@ -2228,10 +2218,6 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.consumer_compliment_map) { for (auto entry_2 : entry.second) { - // std::cout << "Permissive consumer compliment: " - // << entry.first->toString() << " -> " << - // entry_2->toString() - // << std::endl; idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); } } @@ -2375,8 +2361,6 @@ StatefulLoweringInfo buildInfo( all_producer_ca_deps.insert( ca_deps_filter.begin(), ca_deps_filter.end()); } - // std::cout << "Producer: " << producer->toString() << "\n " - // << all_producer_ca_deps.toString() << std::endl; info.ordered_p_ca_ids.pushBack(all_producer_ca_deps); @@ -2480,21 +2464,12 @@ void IterDomainGraphs::build( if (FusionGuard::getCurFusion()->isA()) { validatePTypes(all_tvs); - // FusionGuard::getCurFusion()->print(std::cout, true); - StatefulLoweringInfo info = buildInfo( tv_exprs, idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::PERMISSIVE)); initializeLoopMap(info); - // std::cout << "Loop groups: " - // << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) - // << std::endl; - - // std::cout << "Promoted groups: " - // << debug::idGroupsString(idGraph(IdMappingMode::LOOP)) - // << std::endl; // Initial propagation of parallel types for inlined iter domains. Each time // new expressions are replayed this needs to be run. The disjoint sets in @@ -2742,21 +2717,16 @@ std::unordered_map IterDomainGraphs:: iel_promotion_map[iel_group] = promoted_iel_groups.front()->front(); } - // std::cout << "Initial promotion map:" << std::endl; - for (auto iel_group : intersection_exact_loop_graph.disjointIdSets().disjointSets()) { auto entry_it = iel_promotion_map.find(iel_group); if (entry_it == iel_promotion_map.end()) { continue; } - // std::cout << " " << entry_it->second->toString() << " <- " - // << entry_it->first->toString() << std::endl; } IdGraphStmtSort iel_stmt_sort(intersection_exact_loop_graph); - // std::cout << "Initial promotion replay:" << std::endl; for (auto iel_expr : iel_stmt_sort.exprs()) { auto input_groups = intersection_exact_loop_graph.inputGroups(iel_expr); @@ -2827,8 +2797,6 @@ std::unordered_map IterDomainGraphs:: bool replayed = replay == nullptr; if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); - // std::cout << " ***REPLAY***:\n " << iel_expr->front() - // << " As:" << replay->toString(); } auto out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -3110,13 +3078,9 @@ std::unordered_map IterDomainGraphs:: } } - // std::cout << "Loop promotion before second replay:" << std::endl; for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_graph_copy_promotion_map.find(loop_group) != loop_graph_copy_promotion_map.end()) { - // std::cout << debug::toString(loop_group, 0, true) << " -> " - // << loop_graph_copy_promotion_map[loop_group]->toString() - // << std::endl; } } @@ -3250,12 +3214,6 @@ std::unordered_map IterDomainGraphs:: if (replay == nullptr) { replay = addReplayAs(promoted_inputs, iel_expr->front()); } - // std::cout << " ***REPLAY2***:\n " << iel_expr->front() - // << " As:" << replay->toString(); - // } else { - // std::cout << " ***MATCH2***:\n " << iel_expr->front() - // << " As:" << replay->toString(); - // } auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); @@ -3289,14 +3247,11 @@ std::unordered_map IterDomainGraphs:: } } - // std::cout << "Promotion map from second replay: " << std::endl; for (auto group : intersection_exact_loop_graph.disjointIdSets().disjointSets()) { if (iel_promotion_map.find(group) == iel_promotion_map.end()) { continue; } - // std::cout << debug::toString(group, 0, true) << " -> " - // << iel_promotion_map.at(group)->toString() << std::endl; } return iel_promotion_map; @@ -3307,806 +3262,7 @@ std::unordered_map IterDomainGraphs::buildIndexGraph( const std::vector& all_tvs, StatefulLoweringInfo& info, std::unordered_map stale_promotion_map) { - // Update the iel graph - auto intersection_exact_loop_graph = buildIntersection( - idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); - - // Update the promotion map - auto iel_promotion_map = - updateMap(stale_promotion_map, intersection_exact_loop_graph); - - auto exact_covered_ids = - computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); - - // Grab terminal iter domain in the loop groups. - VectorOfUniqueEntries terminal_loop_ids = - computeTerminalLoopIds(info); - - // Loop promotion map is to prepare for IterDomain replays. Since these - // replays will modify the loop map, we operate on a copy of the loop map, - // not the original one. - // Loop promotion map is to prepare for IterDomain replays to resolve - // non-inlined loop groups. Since these replays will modify the loop map as - // we're iterating over the loop map, operate on a copy of the loop map, not - // the original one. - auto loop_graph_copy = idGraph(IdMappingMode::LOOP); - - // Build a map from loop iter domain group to a promoted iter domain (doesn't - // have to be in the loop group) that covers all the exact groups - // representative of the resolved transformations within the loop group. Only - // the inlined loop groups will be covered here. - std::unordered_map loop_graph_copy_promotion_map; - - // Returns a new promoted domain if one is found in the iel_promotion_map, - // otherwise returns original id. - auto get_promoted_id = [&intersection_exact_loop_graph, - &iel_promotion_map](IterDomain* id) { - auto iel_group = intersection_exact_loop_graph.toGroup(id); - auto iel_promotion_map_it = iel_promotion_map.find(iel_group); - if (iel_promotion_map_it != iel_promotion_map.end()) { - return iel_promotion_map_it->second; - } - return id; - }; - - // Returns the entry in exact_covered_ids associated with provided IterDomain. - // Basically calling .at but with a better error. - auto get_covered_exact_groups = [&](IterDomain* id) { - auto exact_group = idGraph(IdMappingMode::EXACT).toGroup(id); - auto covered_it = exact_covered_ids.find(exact_group); - TORCH_INTERNAL_ASSERT( - covered_it != exact_covered_ids.end(), - "Missing map entry in analysis for: ", - debug::toString(exact_group, 0, true)); - return covered_it->second; - }; - - // Now we need to find the right promoted ID for every loop group, making - // sure the promoted ID covers every ID of the IDs in the loop group. - // This ID could be a terminal ID in the group. A promoted ID of the terminal - // IDs, or an ID that was replayed previously and now part of the loop group. - // - // The correct/final promoted ID of the loop group must exist at this point. - // It just might not be within the loop group we're looking at. - // std::cout << "Find promoted ids from loop group or promoted iter domains." - // << std::endl; - for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { - if (loop_group->size() == 1) { - auto promoted_id = get_promoted_id(loop_group->front()); - - TORCH_INTERNAL_ASSERT( - get_covered_exact_groups(loop_group->front()) - .subtract(get_covered_exact_groups(promoted_id)) - .size() == 0, - "Promotion failed, promoted id: ", - promoted_id->toString(), - " doesn't cover the right domains for ", - loop_group->front()->toString()); - loop_graph_copy_promotion_map[loop_group] = promoted_id; - continue; - } - - // If promotion entry exists for any terminal id the promoted id will be - // stored here. - std::vector promoted_terminal_ids; - - // If a promotion entry doesn't exist for a terminal id, put it here. - std::vector terminal_ids; - - // All exact groups that the terminal loop id's cover. - IdGroups all_covered_exact_groups; - - // Populate all three structures above. - for (auto loop_id : *loop_group) { - if (!terminal_loop_ids.has(loop_id)) { - continue; - } - - all_covered_exact_groups.pushBack(get_covered_exact_groups(loop_id)); - - auto promoted_id = get_promoted_id(loop_id); - if (promoted_id == loop_id) { - terminal_ids.push_back(loop_id); - } else { - promoted_terminal_ids.push_back(promoted_id); - } - } - - // If promoted id's exist, those are the candidates to have the right - // transformations for indexing. Otherwise, use the terminal _ids. - auto candidate_ids = - promoted_terminal_ids.empty() ? terminal_ids : promoted_terminal_ids; - - // Find the loop promotion id from the candidates. - IterDomain* loop_promotion_id = nullptr; - for (auto candidate_id : candidate_ids) { - if (all_covered_exact_groups - .subtract(get_covered_exact_groups(candidate_id)) - .empty()) { - loop_promotion_id = candidate_id; - break; - } - } - - // If we're still missing the loop_promotion_id, check all replayed IDs in - // the loop group. - if (loop_promotion_id == nullptr) { - candidate_ids = loop_group->subtract(info.ordered_c_ids).vector(); - for (auto candidate_id : candidate_ids) { - if (all_covered_exact_groups - .subtract(get_covered_exact_groups(candidate_id)) - .empty()) { - loop_promotion_id = candidate_id; - } - } - } - - if (loop_promotion_id == nullptr) { - std::stringstream err_msg; - err_msg << "\nCould not find promotion for loop group:\n "; - err_msg << debug::toString(loop_group, 0, true); - err_msg << "\nnone of the candidate iter domains of this group:\n "; - err_msg << " " - << VectorOfUniqueEntries(candidate_ids).toString(); - err_msg << "\n cover all id groups that the loop group covers:\n"; - err_msg << " " << debug::toString(all_covered_exact_groups) << std::endl; - TORCH_INTERNAL_ASSERT(false, err_msg.str()); - } - - loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; - } - - // std::cout << "Promotion map to build the Index Graph: " << std::endl; - for (auto group : loop_graph_copy.disjointIdSets().disjointSets()) { - if (loop_graph_copy_promotion_map.find(group) == - loop_graph_copy_promotion_map.end()) { - continue; - } - // std::cout << debug::toString(group, 0, true) << " -> " - // << loop_graph_copy_promotion_map.at(group)->toString() - // << std::endl; - } - - // Indexing traversal must start at leaf nodes of TensorViews as that's where - // the loop indices are defined. For indexing we need to propagate leaves to - // root domains. We want the indexing graph easy to traverse. Easy to traverse - // means that we start at terminating outputs of this graph and propagate to - // terminating inputs. We shouldn't have to worry about which paths each time - // we traverse the index graph as we may do it many times. - - // The IEL Map cannot be traversed for indexing, because the loop map is - // really only used to model broadcast promotion. We could have multiple paths - // from leaf nodes to an intermediate IEL entry. Meaning: - - // T0 root[i0, i1] T0 leaf domain [i0*i1//32, 4, 8] - // T1 root[i0, i1] T0 leaf domain [i0*i1//32, 8, 4] - - // Even though T0 and T1 are inlined on the outer most dimension, indexing - // into their roots is different. Yet, their roots would be in the same IEL - // entries. - - // The index graph should provide a direct model of what indices are reused, - // i.e. if two ID's in the IndexMap map to eachother, they should use the same - // index math. Therefore, roughly what we need to do is: - - // - Figure out which leaves share exact indexing and map them together: - // (1) Producer-consumer leaf nodes are inlined with eachother (map to the - // same promoted id) - // (2) Promoted producer-consumer leaf nodes are almost exact, have the same - // parallel type, but are not inlined. - - // - Start at the promoted leaf nodes of each tensor view - - // - If those promoted leaf nodes are *ALMOST EXACT* mapped from - // producer-consumer they can be mapped in the index map - - // - Traversing backward from each tensor view's leaf nodes, we directly reach - // the root nodes of that tensor view - - // - During the backward traversal, for an expression, if the output iter - // domains are mapped in the index map, their inputs should be mapped as well. - // So as we build the index map, we could also be accumulating mapped iter - // domains. - - // Mark all iter domains that share a loop nest and are almost exact mapped. - // Ignores promotion. - - // Doing the same as above on promoted iter domains is a bit tricky, because - // there's a promoted IterDomian per IEL group, we need a promoted IterDomain - // per index group. So let's figure out which leaf domains share a promoted - // iter domain, so we don't have to build a promoted iter domain for every - // leaf, then try to rejoin them. - - // TODO: I think we need to validate that for each tensor view leaf domains, - // no two leaves within a tensor domain map to another leaf in the same tensor - // domain in the IEL graph. Not sure how this could occur, but I suspect it - // could. - - // Which non-promoted iter domains, share their promoted iterdomains - DisjointSets shared_promoted_id; - - for (auto expr : exprs) { - std::unordered_map> - promo_id_to_producer_ids; - std::unordered_map> - promo_id_to_consumer_ids; - - // Copy of all promo ids for determinism - VectorOfUniqueEntries all_promo_ids; - - for (auto producer : ir_utils::filterByType(expr->inputs())) { - for (auto p_id : producer->domain()->leaf()) { - // Initialize all entries - shared_promoted_id.initializeSet(p_id); - - auto loop_copy_p_group_pair = loop_graph_copy.disjointIdSet(p_id); - TORCH_INTERNAL_ASSERT(loop_copy_p_group_pair.second); - auto loop_copy_p_group = loop_copy_p_group_pair.first; - - auto promo_id_it = - loop_graph_copy_promotion_map.find(loop_copy_p_group); - TORCH_INTERNAL_ASSERT( - promo_id_it != loop_graph_copy_promotion_map.end()); - - promo_id_to_producer_ids[promo_id_it->second].pushBack(p_id); - all_promo_ids.pushBack(promo_id_it->second); - } - } - - for (auto consumer : ir_utils::filterByType(expr->outputs())) { - for (auto c_id : consumer->domain()->leaf()) { - // Initialize all entries - shared_promoted_id.initializeSet(c_id); - - auto loop_copy_c_group_pair = loop_graph_copy.disjointIdSet(c_id); - TORCH_INTERNAL_ASSERT(loop_copy_c_group_pair.second); - auto loop_copy_c_group = loop_copy_c_group_pair.first; - - auto promo_id_it = - loop_graph_copy_promotion_map.find(loop_copy_c_group); - TORCH_INTERNAL_ASSERT( - promo_id_it != loop_graph_copy_promotion_map.end()); - - promo_id_to_consumer_ids[promo_id_it->second].pushBack(c_id); - all_promo_ids.pushBack(promo_id_it->second); - } - } - - for (auto promo_id : all_promo_ids) { - auto p_ids_it = promo_id_to_producer_ids.find(promo_id); - if (p_ids_it == promo_id_to_producer_ids.end()) { - continue; - } - auto p_ids = p_ids_it->second; - - auto c_ids_it = promo_id_to_consumer_ids.find(promo_id); - if (c_ids_it == promo_id_to_consumer_ids.end()) { - continue; - } - auto c_ids = c_ids_it->second; - - if (c_ids.size() && p_ids.size()) { - for (auto p_id : p_ids) { - shared_promoted_id.mapEntries(p_ids.front(), p_id); - } - for (auto c_id : c_ids) { - shared_promoted_id.mapEntries(p_ids.front(), c_id); - } - } - } - } - - auto get_representative_promoted_id = [&](IterDomain* id) { - auto promo_id_it = - loop_graph_copy_promotion_map.find(loop_graph_copy.toGroup(id)); - TORCH_INTERNAL_ASSERT(promo_id_it != loop_graph_copy_promotion_map.end()); - return promo_id_it->second; - }; - - // std::cout << "Opportunistic joining of shared promos:" << std::endl; - // Opportunistically collapse indexing of non-inlined leaf domains if their - // promoted ids are almost exact mapped and have the same parallel type. - for (auto expr : exprs) { - for (auto producer : ir_utils::filterByType(expr->inputs())) { - // std::cout << " Producer: " << producer->toString() << std::endl; - auto producer_root = producer->getMaybeRFactorDomain(); - - auto non_inline_producer_domain = producer->domain()->leaf(); - non_inline_producer_domain.erase( - non_inline_producer_domain.begin(), - non_inline_producer_domain.begin() + - producer->getComputeAtPosition()); - - for (auto consumer : - ir_utils::filterByType(expr->outputs())) { - // std::cout << " Consumer: " << consumer->toString() << std::endl; - auto consumer_domain = consumer->domain()->leaf(); - - auto p2c_permissive_map = - idGraph(IdMappingMode::PERMISSIVE) - .buildMapBetween(non_inline_producer_domain, consumer_domain); - - for (auto p_id : non_inline_producer_domain) { - auto p2c_it = p2c_permissive_map.find(p_id); - if (p2c_it == p2c_permissive_map.end() || p2c_it->second.empty()) { - continue; - } - - auto rep_p_id = get_representative_promoted_id(p_id); - auto c_id = p2c_it->second.front(); - auto rep_c_id = get_representative_promoted_id(c_id); - - // std::cout << " " << p_id->toString() << " -> " - // << rep_p_id->toString() << " :: " << c_id->toString() - // << " -> " << rep_c_id->toString() << std::endl; - if (!idGraph(IdMappingMode::ALMOSTEXACT) - .disjointIdSets() - .strictAreMapped(rep_p_id, rep_c_id)) { - continue; - } - if (rep_p_id->getParallelType() != rep_c_id->getParallelType()) { - continue; - } - // std::cout << " Mapped" << std::endl; - shared_promoted_id.mapEntries(p_id, c_id); - } - } - } - } - - // std::cout << "Leaf iter domains that share a promoted iter domain." - // << std::endl; - // for (auto disjoint_set : shared_promoted_id.disjointSets()) { - // std::cout << disjoint_set->toString() << std::endl; - // } - - // Map from leaf iter domains to their potentially promoted iter domain used - // for indexing. - std::unordered_map leaf_promotion_map; - - // If a promoted iter domain was generated by replays, it won't be connected - // in the index graph. We can reuse these iter domains directly instead of - // having to make a clone of them. However, we can only use them once for a - // group. - VectorOfUniqueEntries used_promo_ids; - - for (auto id_group : shared_promoted_id.disjointSets()) { - IterDomain* promo_id = get_representative_promoted_id(id_group->front()); - - // Promoted id is already part of the group, just use that. - if (std::find(id_group->begin(), id_group->end(), promo_id) != - id_group->end()) { - for (auto id : *id_group) { - leaf_promotion_map[id] = promo_id; - } - continue; - } - - // Promo id generated from running replay, we can use it for one of the - // index groups. - if (!info.ordered_c_ids.has(promo_id) && !used_promo_ids.has(promo_id)) { - used_promo_ids.pushBack(promo_id); - for (auto id : *id_group) { - leaf_promotion_map[id] = promo_id; - } - continue; - } - - // Need to take a copy of the promo_id as it's already dedicated to an index - // group. - promo_id = cloneIterDomain(promo_id); - for (auto id : *id_group) { - leaf_promotion_map[id] = promo_id; - } - } - - // TODO: This needs to be available as a member function - auto get_promoted_domain = [&](TensorDomain* td) { - std::vector promoted_leaves; - for (auto id : td->leaf()) { - auto promo_it = leaf_promotion_map.find(id); - TORCH_INTERNAL_ASSERT(promo_it != leaf_promotion_map.end()); - promoted_leaves.push_back(promo_it->second); - } - return promoted_leaves; - }; - - // std::cout << "Iter domain group to their promoted iter domain." << - // std::endl; for (auto id_group : shared_promoted_id.disjointSets()) { - // std::cout << id_group->toString() << "\n -> " - // << leaf_promotion_map.at(id_group->front()) << std::endl; - // } - - // Track every expression required for indexing - VectorOfUniqueEntries all_index_exprs; - // Track every iter domain required for indexing - VectorOfUniqueEntries all_index_ids; - - // std::cout << "\n\nThird and final replay" << std::endl; - // std::cout << "Building promoted tensor view domains:" << std::endl; - // Need to "replay" all of the indexing expressions to make sure roots are - // connected to the promoted leaves, in a way we can index directly on the - // index graph. - // - // Since we're performing replays we need to copy the graph we're iterating - // on. - auto ae_graph = idGraph(IdMappingMode::ALMOSTEXACT); - - // Because of how replays work in buildInlinePromotions and - // buildLoopPromotionMap, we could have multiple uses and definitions of the - // the same iter domain. - // - // However, for the index graph we want to go back to every iter domain having - // at most one use and definition. - // - // We also want to use expressions that exist if we can. - // - // If there's multiple paths on the index graph then we would generate - // conflicting indicies (unless somehow the expressions all end up collapsing - // by being mapped later). Enforce one defintion and use per iter domain. - std::unordered_map id_to_index_use; - std::unordered_map id_to_index_def; - - // Initialize index graph using the history of each tensorview. These - // expressions are not guaranteed to be used, but if it is used, this will - // prefer those used in a tv's history. - // - // This prevents conflicts later where we try to reuse an expression and take - // an expression in another tensor view's history. - for (auto tv : all_tvs) { - auto transforms = StmtSort::getExprsBetween( - FusionGuard::getCurFusion(), - {tv->getRootDomain().begin(), tv->getRootDomain().end()}, - {tv->domain()->leaf().begin(), tv->domain()->leaf().end()}); - for (auto transform : transforms) { - for (auto inp : ir_utils::filterByType(transform->inputs())) { - id_to_index_use[inp] = transform; - } - for (auto out : - ir_utils::filterByType(transform->outputs())) { - id_to_index_def[out] = transform; - } - } - } - - // Manually initialize the index graph - for (auto id_group : - idGraph(IdMappingMode::ALMOSTEXACT).disjointIdSets().disjointSets()) { - for (auto id : *id_group) { - VectorOfUniqueEntries defs; - if (id_to_index_def.find(id) != id_to_index_def.end()) { - defs.pushBack(id_to_index_def.at(id)); - } - - VectorOfUniqueEntries uses; - if (id_to_index_use.find(id) != id_to_index_use.end()) { - uses.pushBack(id_to_index_use.at(id)); - } - - idGraph(IdMappingMode::INDEX).initializeId(id, defs, uses); - } - } - - idGraph(IdMappingMode::INDEX).mapThroughTrivialExprs(); - idGraph(IdMappingMode::INDEX).removeTrivialExprs(); - - for (auto tv : all_tvs) { - // We don't have to process inputs at this point as they're already - // allocated on a global - if (tv->isFusionInput()) { - continue; - } - - auto promoted_domain = get_promoted_domain(tv->domain()); - // replay from root to promoted leaves. - // std::cout << "\n\n Processing: TV" << tv->name() << "\n Root: TV" - // << tv->getRootDomain() - // << "\n Domain promoted to: " << promoted_domain << - // std::endl; - - // The promoted leaf iter domains are where indexing starts. We're going to - // start at those expressions and replay transformations for this tensor - // view working back to root domains. We want to intercept the history of - // the transformations local to the tensor view where possible. - // - // So effectively what we have to do is map the ae graph to the history of - // the tensor view as well as the promoted iter domains. We start traversal - // at the promoted iter domains and will intercept the tensor view history - // as possible. - // - // We must be able to interecept the provided tensor view at the rfactor and - // root domains, otherwise we wouldn't be able to allocate or index into the - // buffer at tensor view (rfactor domain) or it's producer (root domain). - - // Grab all the domains and convert them to their ae groups. - auto all_ids_v = ir_utils::allIDsOf(tv); - auto all_ids = - VectorOfUniqueEntries(all_ids_v.begin(), all_ids_v.end()); - - // Create a map from the ae group to the iter domain as when we replay we'll - // replace the ae iter domain in the replay with the id in this map. - std::unordered_map ae_group_2_id; - - for (auto tv_id : all_ids) { - // Use emplace here as it multiple tv_ids could map to the same ae_group. - // Emplace will simply grab the first one that appears. - ae_group_2_id.emplace(std::make_pair(ae_graph.toGroup(tv_id), tv_id)); - } - - // Add the promoted domain ids - for (auto promoted_id : promoted_domain) { - all_ids.pushBack(promoted_id); - ae_group_2_id[ae_graph.toGroup(promoted_id)] = promoted_id; - } - - auto ae_leaf_groups = ae_graph.toGroups(VectorOfUniqueEntries{ - promoted_domain.begin(), promoted_domain.end()}); - - // Don't support multiple leaf domains promoted to the same ae graph at this - // point. - TORCH_INTERNAL_ASSERT( - ae_leaf_groups.size() == promoted_domain.size(), - "Multiple leaf domains that map almost exactly is not supported at this point."); - - auto ae_root_groups = ae_graph.toGroups(VectorOfUniqueEntries{ - tv->getRootDomain().begin(), tv->getRootDomain().end()}); - - // Make a copy of the expressions so we can reverse them - auto reverse_indexing_transforms = - ae_graph.getExprsBetween(ae_root_groups, ae_leaf_groups).vector(); - - std::reverse( - reverse_indexing_transforms.begin(), reverse_indexing_transforms.end()); - - // Replay indexing transformations start on leaf nodes propagating back to - // the root domain - for (ExprGroup ae_expr_group : reverse_indexing_transforms) { - // Outputs must be promoted with the ae_group_2_id map. Inputs may be - // promoted when we intercept the history of the TV with the replay. - // - // if there isn't an entry in ae_group_2_id, then we have a resolved - // merged in broadcast, and that resolved iter domain will need to be - // cloned. Would be nice to see if the dangling input has already been - // added already through another indexing path that this overlaps with, - // however having an additional ID and expression per case doesn't seem - // too bad right now. - - auto ae_output_groups = ae_graph.outputGroups(ae_expr_group); - - std::vector promoted_outputs; - for (auto out_group : ae_output_groups) { - auto out_promo_it = ae_group_2_id.find(out_group); - if (out_promo_it == ae_group_2_id.end()) { - promoted_outputs.push_back(out_group->front()); - } else { - promoted_outputs.push_back(out_promo_it->second); - } - } - - Expr* replay = nullptr; - - // Check if we already have this expression covered in the index graph. If - // so, don't add another expr, just add mappings for the iter domains - // necessary. - - // If there isn't already an index expression covering this, check the - // almost exact map if there's any expression not already in the index - // graph that we can use, and add in the index graph. - - // Else generate a new index expression from scratch. - - // Before replaying, check if there's already an expression like this, if - // so use that for promotion. - ExprGroups promoted_output_defs; - for (auto out_id : promoted_outputs) { - auto index_group = idGraph(IdMappingMode::INDEX).toGroup(out_id); - promoted_output_defs.pushBack( - idGraph(IdMappingMode::INDEX).uniqueDefinitions(index_group)); - } - - for (auto index_def_group : promoted_output_defs) { - // This enforces that inputs and outputs are all almost exact mapped - if (!idGraph(IdMappingMode::ALMOSTEXACT) - .disjointExprSets() - .strictAreMapped( - index_def_group->front(), ae_expr_group->front())) { - continue; - } - - // Check that the outputs we need on the replay match in the index map - // with this expression. - auto index_def_outputs = ir_utils::filterByType( - index_def_group->front()->outputs()) - .vector(); - - bool outs_match = true; - for (auto out_i : c10::irange(index_def_outputs.size())) { - outs_match = outs_match && - idGraph(IdMappingMode::INDEX) - .disjointIdSets() - .strictAreMapped( - index_def_outputs[out_i], promoted_outputs[out_i]); - } - - if (!outs_match) { - continue; - } - - // Look for an expression in the group we can reuse. - // - // See comment on definition of id_to_index_use - for (auto maybe_match : *index_def_group) { - VectorOfUniqueEntries input_uses; - for (auto inp : - ir_utils::filterByType(maybe_match->inputs())) { - auto use_it = id_to_index_use.find(inp); - if (use_it == id_to_index_use.end()) { - continue; - } - input_uses.pushBack(use_it->second); - } - - // If there's already a use, make sure it's this use. - if (input_uses.subtract({maybe_match}).size() > 0) { - continue; - } - - VectorOfUniqueEntries output_defs; - for (auto out : - ir_utils::filterByType(maybe_match->outputs())) { - auto def_it = id_to_index_def.find(out); - if (def_it == id_to_index_def.end()) { - continue; - } - output_defs.pushBack(def_it->second); - } - - // If there's already a def, make sure it's this def. - if (output_defs.subtract({maybe_match}).size() > 0) { - continue; - } - - std::vector ae_inps = - ir_utils::filterByType( - ae_expr_group->front()->inputs()) - .vector(); - - auto maybe_match_inputs = - ir_utils::filterByType(maybe_match->inputs()) - .vector(); - - // If there are promoted inputs, we need them to match exactly, - // otherwise we can't reuse this expression. So although replay is not - // nullptr, we may set it back and keep looking. - bool promo_inps_match = true; - for (auto inp_i : c10::irange(maybe_match_inputs.size())) { - auto ae_group_pair = ae_graph.disjointIdSet(ae_inps[inp_i]); - if (ae_group_pair.second && - ae_group_2_id.find(ae_group_pair.first) != - ae_group_2_id.end()) { - auto promo_inp = ae_group_2_id.at(ae_group_pair.first); - if (promo_inp != maybe_match_inputs[inp_i]) { - promo_inps_match = false; - } - } - } - - if (!promo_inps_match) { - continue; - } - - replay = maybe_match; - - for (auto inp : - ir_utils::filterByType(replay->inputs())) { - id_to_index_use[inp] = replay; - } - - for (auto out : - ir_utils::filterByType(replay->outputs())) { - id_to_index_def[out] = replay; - } - break; - } - - // No expression we could use found, keep trying. - if (replay == nullptr) { - continue; - } - - std::vector ae_inps = - ir_utils::filterByType(ae_expr_group->front()->inputs()) - .vector(); - - auto replay_inputs = - ir_utils::filterByType(replay->inputs()).vector(); - - for (auto inp_i : c10::irange(replay_inputs.size())) { - auto ae_group_pair = ae_graph.disjointIdSet(ae_inps[inp_i]); - if (!(ae_group_pair.second && - ae_group_2_id.find(ae_group_pair.first) != - ae_group_2_id.end())) { - continue; - } - idGraph(IdMappingMode::INDEX) - .mapIds( - replay_inputs[inp_i], ae_group_2_id.at(ae_group_pair.first)); - } - } - - // No existing expression could be reused. - if (replay == nullptr) { - std::vector ae_inps_outs = - ir_utils::filterByType(ae_expr_group->front()->inputs()) - .vector(); - auto outs = ir_utils::filterByType( - ae_expr_group->front()->outputs()); - ae_inps_outs.insert(ae_inps_outs.end(), outs.begin(), outs.end()); - - std::unordered_map replacement_map; - for (auto id : ae_inps_outs) { - auto ae_group = ae_graph.toGroup(id); - auto promoted_it = ae_group_2_id.find(ae_group); - if (promoted_it == ae_group_2_id.end()) { - replacement_map[id] = id->cloneWithoutRFactor(); - } else { - replacement_map[id] = promoted_it->second; - } - } - - replay = - addExprWithReplacement(replacement_map, ae_expr_group->front()); - // std::cout << " ***REPLAY3***:\n " - // << ae_expr_group->front()->toString() - // << " As:" << replay->toString(); - - } else { - // std::cout << " ***MATCH3***:\n " - // << " " << replay->toString(); - } - - all_index_exprs.pushBack(replay); - { - auto in_ids = ir_utils::filterByType(replay->inputs()); - all_index_ids.insert(in_ids.begin(), in_ids.end()); - - auto out_ids = ir_utils::filterByType(replay->outputs()); - all_index_ids.insert(out_ids.begin(), out_ids.end()); - } - - std::vector ae_inps = - ir_utils::filterByType(ae_expr_group->front()->inputs()) - .vector(); - std::vector replay_inps = - ir_utils::filterByType(replay->inputs()).vector(); - TORCH_INTERNAL_ASSERT(ae_inps.size() == replay_inps.size()); - - for (auto inp_i : c10::irange(ae_inps.size())) { - auto ae_group = ae_graph.toGroup(ae_inps[inp_i]); - // Only replace if entry does not exist. - ae_group_2_id.emplace(std::make_pair(ae_group, replay_inps[inp_i])); - } - } - } - - // std::cout << "All indexing expressions (on the index graph): " << - // std::endl; - auto index_expr_groups = - idGraph(IdMappingMode::INDEX).toGroups(all_index_exprs); - - ExprGroups extraneous_expr_groups = - ExprGroups( - idGraph(IdMappingMode::INDEX).disjointExprSets().disjointSets()) - .subtract(index_expr_groups); - for (auto group : extraneous_expr_groups) { - idGraph(IdMappingMode::INDEX).eraseExprGroup(group); - } - - // std::cout << "All index graph exprs: " << std::endl; - // std::cout << debug::exprGroupsString(idGraph(IdMappingMode::INDEX)) - // << std::endl; - - return {}; + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } } // namespace nvfuser From e923a0a388802b4c7fd91ccb2c4d61c7a6eac907 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 10 May 2023 07:57:19 -0400 Subject: [PATCH 029/178] Cleanup, move id modelling into its own directory. --- CMakeLists.txt | 5 +- csrc/id_graphs.h | 660 ------------- csrc/id_model/id_graph.cpp | 1026 +++++++++++++++++++ csrc/id_model/id_graph.h | 248 +++++ csrc/{ => id_model}/id_graphs.cpp | 1518 +---------------------------- csrc/id_model/id_graphs.h | 290 ++++++ csrc/id_model/to_string.cpp | 322 ++++++ csrc/id_model/to_string.h | 77 ++ csrc/id_model/visitor.cpp | 156 +++ csrc/id_model/visitor.h | 85 ++ csrc/lower2device.cpp | 2 +- 11 files changed, 2222 insertions(+), 2167 deletions(-) delete mode 100644 csrc/id_graphs.h create mode 100644 csrc/id_model/id_graph.cpp create mode 100644 csrc/id_model/id_graph.h rename csrc/{ => id_model}/id_graphs.cpp (57%) create mode 100644 csrc/id_model/id_graphs.h create mode 100644 csrc/id_model/to_string.cpp create mode 100644 csrc/id_model/to_string.h create mode 100644 csrc/id_model/visitor.cpp create mode 100644 csrc/id_model/visitor.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 62117920b72..cfab2a69352 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -87,7 +87,10 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp ${NVFUSER_SRCS_DIR}/index_compute.cpp ${NVFUSER_SRCS_DIR}/lower_index_compute.cpp - ${NVFUSER_SRCS_DIR}/id_graphs.cpp + ${NVFUSER_SRCS_DIR}/id_model/id_graph.cpp + ${NVFUSER_SRCS_DIR}/id_model/id_graphs.cpp + ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp + ${NVFUSER_SRCS_DIR}/id_model/visitor.cpp ${NVFUSER_SRCS_DIR}/instrumentation.cpp ${NVFUSER_SRCS_DIR}/ir_base_nodes.cpp ${NVFUSER_SRCS_DIR}/ir_builder.cpp diff --git a/csrc/id_graphs.h b/csrc/id_graphs.h deleted file mode 100644 index 37661c2e019..00000000000 --- a/csrc/id_graphs.h +++ /dev/null @@ -1,660 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include -#include - -namespace nvfuser { - -using IdGroup = std::shared_ptr>; -using IdGroups = VectorOfUniqueEntries; -using ExprGroup = std::shared_ptr>; -using ExprGroups = VectorOfUniqueEntries; - -class TORCH_CUDA_CU_API IdGraph { - public: - IdGraph() = default; - - IdGraph(const IdGraph& other); - IdGraph(IdGraph&& other) = default; - - IdGraph& operator=(const IdGraph& other); - IdGraph& operator=(IdGraph&& other) = default; - - // Returns the disjoint IterDomain set. - const DisjointSets& disjointIdSets() const; - - DisjointSets& disjointIdSets(); - - // Returns - // { - // (1) The disjoint set of the provided Iter Domain if it exists, - // otherwise a null shared ptr - // (2) If the disjoint set of the provided Iter Domain exists - // } - // - // TODO: Audit usage - std::pair disjointIdSet(IterDomain* id) const; - - // Returns the disjoint Expr set. - const DisjointSets& disjointExprSets() const; - - DisjointSets& disjointExprSets(); - - // Same as getDisjointIdSet but for the Expression sets. - // - // TODO: Audit usage - std::pair disjointExprSet(Expr* expr) const; - - // Convert expr to its exprGroup, assert that it exists. - ExprGroup toGroup(Expr* expr) const; - - // Convert iter domain to its IdGroup, assert that it exists. - IdGroup toGroup(IterDomain* id) const; - - // Convert unique vector of expressions to unique vector of its groups - ExprGroups toGroups(const VectorOfUniqueEntries& exprs) const; - - // Convert unique vector of IterDomain to unique vector of its groups - IdGroups toGroups(const VectorOfUniqueEntries& ids) const; - - // Return output/input iter domain groups of provided expr - std::vector outputGroups(ExprGroup expr) const; - std::vector inputGroups(ExprGroup expr) const; - - // Traverses uses of the IdGroups in 'of' and returns all ExprGroups - // that have a use in their definition of provided of IdGroups. - ExprGroups allUsesOf(const IdGroups& of) const; - - // Traverses definitions of the IdGroups in 'of' and returns all ExprGroups - // used in this history of defining the 'of' IdGroups. - ExprGroups allDefinitionsOf(const IdGroups& of) const; - - // Return sorted expressions to go from the provided IterDomains in from to - // the provided IterDomains in to with provided mode. Minimal expressions to - // get from 'from' to 'to' returned. - ExprGroups getExprsBetween(const IdGroups& from, const IdGroups& to) const; - - // Supports one to many mappings, uses the disjoint sets of the provided mode - // to produce mappings between from and to. If multiple IterDomains in to map - // to a single iter domain in from, the order of the IterDomains in value of - // the map is preserved to be the order provided in to. - std::unordered_map> - buildMapBetween( - const std::vector& from, - const std::vector& to) const; - - // Alias of the above on unique vector entries - std::unordered_map> - buildMapBetween( - const VectorOfUniqueEntries& from, - const VectorOfUniqueEntries& to) const; - - //! Returns - //! (1) The expressions associated with the definitions of the provided - //! IterDomain group in the provided mapping mode (if it exists). - //! (2) If there is a definitions entry of the provided IterDomain group in - //! the provided mapping mode. - //! First entry in the returned pair is a vector of vector of expressions. The - //! inner vector is proven to be equivalent based on the provided mode. The - //! outer vector are expression groups that are not equivalent based on the - //! provided mode, but produce one of the IterDomains within the same disjoint - //! Iter Domain set based on the provided mode. - //! TODO: Change name to start with get - std::pair iterDomainGroupDefinitions( - IdGroup id_group) const; - - //! Same as iterDomainGroupDefinitions but for uses instead of definitions - //! TODO: Change name to start with get - std::pair iterDomainGroupUses(IdGroup id_group) const; - - std::string toString() const; - - // Checks if the expression is a trivial operation where an input is simply an - // output of the transformation. Returns the mapped iter domains if found. - static std::vector> isTrivialExpr(Expr* expr); - - // Initializes entries for the provided IterDomain in the IterDomainGraphs - void initializeId( - IterDomain* id, - const VectorOfUniqueEntries& definitions, - const VectorOfUniqueEntries& uses); - - // Returns if first and second are expressions through which the provided - // id_map have matching inputs (if forward), or outputs (if not forward). - // Returning true means the expressions are "the same", in terms they modify - // matching original extents, by the same amount. - bool exprsMap( - Expr* first, - Expr* second, - bool forward - // , std::vector second_input_or_output_override - ) const; - - // Returns entry in unique_definitions_ for provided group in provided mode, - // otherwise errors if no entry is found. - ExprGroups uniqueDefinitions(IdGroup group) const; - - // Returns entry in unique_uses_ for provided group in provided mode, - // otherwise errors if no entry is found. - ExprGroups uniqueUses(IdGroup group) const; - - std::unordered_map& uniqueUses() { - return unique_uses_; - } - - std::unordered_map& uniqueDefinitions() { - return unique_definitions_; - } - - // Set id0 and id1 to mapped in disjointIdsSet[mode], attempt to propagate - // new mapping through id0/id1 definitions/uses. - void mapIds(IterDomain* id0, IterDomain* id1); - - // Checks if expr0 and expr1 should map together, maps them together, and if - // expression propagation is on, propagates mapping through them. This should - // be the only call in IdGraph to mapThroughExpr - void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); - - // Map expr0 and expr1 with eachother, update unique_definitions_ unique_uses_ - // TODO: Make this variant hidden? - void mapExprs(Expr* expr0, Expr* expr1); - - // Checks if expr's are considered "the same" where sameness inputs and - // outputs in the same position across expressions map with provided - // MappingMode. If the expressions are determined the same then - // if forward - // will map outputs - // else - // will map inputs - // in the provided mode. - // Returns if expressions were mapped through. - // - // TODO: Make this private - bool mapThroughExpr(Expr* first, Expr* second, bool forward); - - // Map through loop swizzles, as input/output IterDomains are exact, only the - // order they're traversed differs. - void mapThroughLoopSwizzles(); - - // Maps iter domain pairs returned by calling that return mappings from - // IdGraph::isTrivialExpr on every expression in the graph. - void mapThroughTrivialExprs(); - - // Removes expressions from unique_definitions_ and unique_uses_ that return - // mappings from IdGraph::isTrivialExpr - void removeTrivialExprs(); - - // See comment on propagate_expr_ member bool for description - // Once disabled this can't be reenabled on a graph. If it's reenabled it's - // hard to predict how mappings will propagate, which will be triggered on the - // next mapping. To support changing this flag, we should likely run through - // all expressions currently registered and propagate through all of them on - // switch. Then once enabled it couldn't be redisabled because we don't record - // the history of mapId calls. - void disableExprPropagation() { - propagate_exprs_ = false; - } - - // Removes the provided expression group from unique_definitions_ and - // unique_uses_ breaking traversal through them. - void eraseExprGroup(ExprGroup expr_group); - - // Returns if the expression group has an input id group that matches an - // output id group. This means traversing on this expression doesn't actually - // do anything. - bool isTrivialExprGroup(ExprGroup expr_group) const; - - private: - // If propagate_exprs_ = false, then mapThroughExpr will not be called as a - // consequence of calling mapIds. As well as mapThroughExpr will not be called - // (again) as a result of calling mapThroughExpr. - // - // Note: For the second sentence of above... mapThroughExpr can call mapIds - // which could in return call mapThoughExpr again, but propagate_exprs_ as - // mentioned above prevents that from happening. - // - // TODO: Should propagate_exprs_ be a const member? - bool propagate_exprs_ = true; - - // Keeps a disjoint set entry for all IterDomain for all mapping mode types. - // - // Using an array here might be nice, but it seems hard to use an enum as an - // array key - // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum - DisjointSets disjoint_ids_; - - // Keeps a disjoint set entry for all Expressions for all mapping mode types. - DisjointSets disjoint_exprs_; - - std::unordered_map unique_definitions_; - - std::unordered_map unique_uses_; - - // Hold a set of IterDomains that are considered view rfactor ids. This - // identification is particularly important to understand if split operations - // are divisible or not. - // - // TODO: This should just be in IterDomainGraphs, not here. - std::unordered_set view_rfactor_ids_; -}; - -// Debuging print functions -namespace debug { -std::string toString( - const std::vector& id_group, - int indent_size = 0); -std::string toString( - const IdGroup& id_group, - int indent_size = 0, - bool with_ptr = false); - -std::string toString( - const std::vector& id_groups, - int indent_size = 0, - bool with_ptr = false); - -std::string toString( - const IdGroups& id_groups, - int indent_size = 0, - bool with_ptr = false); - -std::string toInlineString(const std::vector& id_groups); -std::string toInlineString(const IdGroups& id_groups); - -std::string toString(const std::vector& expr_group, int indent_size = 0); -std::string toString( - const ExprGroup& expr_group, - int indent_size = 0, - bool with_ptr = false); - -std::string toString( - const IdGraph& id_graph, - const std::vector& expr_group, - int indent_size = 0, - bool with_ptr = false); -std::string toString( - const IdGraph& id_graph, - const ExprGroup& expr_groups, - int indent_size = 0, - bool with_ptr = false); - -std::string toString( - const IdGraph& id_graph, - const std::vector& expr_groups, - int indent_size = 0, - bool with_ptr = false); -std::string toString( - const IdGraph& id_graph, - const ExprGroups& expr_groups, - int indent_size = 0, - bool with_ptr = false); - -std::string idGroupsString( - const IdGraph& id_graph, - int indent_size = 0, - bool with_ptr = false); -std::string exprGroupsString( - const IdGraph& id_graph, - int indent_size = 0, - bool with_ptr = false); -std::string definitionsString( - const IdGraph& id_graph, - int indent_size = 0, - bool with_ptr = false); -std::string usesString( - const IdGraph& id_graph, - int indent_size = 0, - bool with_ptr = false); -} // namespace debug - -// Iterates through an IterDomain Graph in topological order, calling handle on -// all Id and all Expr groups in a forward topological order. -// -// Warning: Expr groups that have an input and output in the same IdGroup are -// ignored. -// -// Warning: This is not a great iterator if there's a desire to minimize paths -// traveled to simply visit all IdGroups in order. See ExprsBetween to see how -// we might minimize paths. -class TORCH_CUDA_CU_API IdGraphVisitor { - protected: - // If sub_selection is assumed to be a set of iter domains by which form a - // sub-regrion of the IdGraph provided. Only that sub-region will be visited. - IdGraphVisitor( - const IdGraph& id_graph, - const VectorOfUniqueEntries sub_selection = {}) - : id_graph_(id_graph), sub_selection_(sub_selection) {} - - virtual void handle(IdGroup id_group) = 0; - virtual void handle(ExprGroup expr_group) = 0; - - void traverse(); - - const IdGraph& graph() { - return id_graph_; - }; - - IdGraphVisitor() = delete; - - IdGraphVisitor(const IdGraphVisitor& other) = default; - IdGraphVisitor& operator=(const IdGraphVisitor& other) = delete; - - IdGraphVisitor(IdGraphVisitor&& other) = default; - IdGraphVisitor& operator=(IdGraphVisitor&& other) = delete; - - virtual ~IdGraphVisitor() = default; - - private: - const IdGraph& id_graph_; - const VectorOfUniqueEntries sub_selection_; -}; - -// Statement sorting based on IdGraphVisitor, see warnings to IdGraph Visitor. -class IdGraphStmtSort : public IdGraphVisitor { - public: - IdGraphStmtSort( - const IdGraph& id_graph, - const VectorOfUniqueEntries sub_selection = {}) - : IdGraphVisitor(id_graph, sub_selection) { - IdGraphVisitor::traverse(); - } - - ExprGroups exprs() { - return sorted_exprs; - } - - IdGroups ids() { - return sorted_ids; - } - - ~IdGraphStmtSort() override = default; - - protected: - using IdGraphVisitor::handle; - void handle(IdGroup id_group) override { - sorted_ids.pushBack(id_group); - } - - void handle(ExprGroup expr_group) override { - sorted_exprs.pushBack(expr_group); - } - - ExprGroups sorted_exprs; - IdGroups sorted_ids; -}; - -namespace { -// Convenience to store some intermediate data across a few lowering build -// passes. -struct StatefulLoweringInfo; -} // namespace - -// TODO: Comment is stale, update. -// -// There's three modes of these iter domain mappings all uniquely important in -// the lowering process. -// -// For EXACT/PERMISSIVE mode consider: -// -// consumer[i0, b1] = producer[i0] -// consumer->merge(0) (consumer will now be [i0 * b1]) -// When producer is replayed as consumer (the direction we use for mapping) -// with BestEffortReplay forward_bcast_mismatch = True the producer to -// consumer map will have both a mapping of consumer(i0) to producer(i0) as -// well as consumer(i0*b1) to producer(i0). This latter mapping is important -// for loop nest mappings as the consumer will generate a loop based on i0*b1 -// and the producer may be computeAt inside this loop nest. However, for -// indexing we do not want these two maps as producer may be indexed as i0*i1 -// depending on the loop nest structure and how it was built. Therefore we -// really need to carry (at least) two sets of maps around for lowering. -// -// LOOP mode is important if we have something like: -// consumer[i0o, threadIdx.x{i0i}] = producer[i0o, threadIdx.y{i0i}](computeAt -// = 1) which can easily happen when using shared memory. We want to make sure -// that the iteration domain used for loop construction (concreteId) has the -// proper parallelization strategy. In parallel mode we do typical iteration -// domain mapping, however we remove from it any iteration domains outside the -// computeAt of producer when mapping. This guarentees we won't map -// IterDomains that could have different parallelization strategies. We also -// propagate the parallel strategy in parallel mode so all mapped IDs that -// must have the same parallel type, do. -// -// IdMappingMode::LOOP -// Only maps leaf axes to left of compute at -// Forward broadcast axes in replay -// IdMappingMode::PERMISSIVE -// Forward broadcast axes in replay -// Map all iteration domains -// Always contain root mappings (otherwise they could have been forwarded in -// broadcast) -// IdMappingMode::EXACT -// Don't map any broadcast axes to non-broadcast axes -// Do not forward through any broadcast IDs -// IdMappingMode::AlmostExact -// Forward through broadcast axes, but not through to a non-broadcast axis -// i.e. id{b1*i0}, id{i0} are mapped -// id{i1*i0}, id{i0} are not mapped (this part is the difference from -// PERMISSIVE) -// Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped -// -class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { - public: - IterDomainGraphs( - const std::vector& exprs, - const std::vector& additional_tvs, - bool allow_self_mapping = false); - - IterDomainGraphs( - const std::vector& exprs, - bool allow_self_mapping = false); - - // Same as the above constructor with fusion->exprs() excpet fusion may have - // some dangling inputs/outputs that are expected to have IterDomain entries - // even though there's no possible connections from them. - IterDomainGraphs(Fusion* fusion, bool allow_self_mapping = false); - - // Returns iter domain graph of provided mode. - const IdGraph& idGraph(IdMappingMode mode) const; - IdGraph& idGraph(IdMappingMode mode); - - // IterDomains from the original fusion are only allowed to be used once in - // the IterDomain graph, id->uses() are not directly used as there's no bounds - // check that would prevent a use from being defined that's not part of the - // actual fusion definition. - // - // Note, any iter domains used during something like loop or concrete id - // resolution could actually have multiple Expr* uses, and uses on disjoint id - // sets should be used, not this. - // - // TODO: Refactor or remove? - Expr* idUse(IterDomain* id) const; - Expr* idDef(IterDomain* id) const; - - // TODO: Seems a bit unfortunate that this isn't IterDomain local information. - const std::unordered_set& viewRfactorIds() const { - return view_rfactor_ids_; - } - - // Returns if a self mapping was detected that would invalidate assumptions of - // the overall lowering system. - // - // TODO: Can we make this more of an alias analysis? - // Ref: https://github.com/csarofeen/pytorch/pull/1954#discussion_r961940498 - bool hasSelfMapping() const { - return self_mapping_info_.has_value(); - } - - // Update the LOOP ID disjoint sets with resolved computeWith - void updateComputeWith(TensorView* compute_with_tv); - - std::string toString() const; - - // Replay Expr but with the inputs provided. IterDomainGraphss will be updated - // for all maps that have entries, adding the output iter domains of the - // replayed expression and adding potential mappings through the expression. - Expr* addReplayAs(std::vector new_inputs, Expr* expr); - - // Similar to addReplayAs, but clones the expr exactly instead of replaying it - // forward. It's up to the calling code to make sure the replacements are - // valid for the provided expr. It's generally recommended that the - // IterDomains exactly match those in the expr. - // - // "forward" dictates the same argument for mapThroughExpr. If forward the - // function will apply mapThroughExpr forward if inputs map in each - // initialized map. Else does the same but backwards through the expression - // from outputs. - Expr* addExprWithReplacement( - const std::unordered_map& old_2_new_ids, - Expr* old_expr); - - // Make a new expr matching that provided but using the outputs provided. - // IterDomainGraphss will be updated for all maps that have entries. Adding - // the input iter domains of the replayed expression and adding potential - // mappings through the expressions. Input domains will match exactly in all - // properties as those in expr. This is unlike addReplayAs which will produce - // new outputs using transformations directly. - Expr* addBackwardsReplayAs( - const std::vector& new_outputs, - Expr* expr); - - // Make an exact copy of provided IterDomain (without rfactor set), and map - // the copy to the original in all registered IdGraphs. IterDomain copy will - // not have any registered uses or definitions. - IterDomain* cloneIterDomain(IterDomain* id); - - // TODO: Should this not be private? - protected: - // Sometimes fusion inputs or outputs are disconnected from expressions, in - // those cases we still may want to send in some additional tensor views from - // the Fusion that don't have expressions associated with them. - void build( - const std::vector& exprs, - const std::vector& additional_tvs); - - // ======= START Iteration domain build process in order called ======= - - // Fills id_uses_ and id_definitions_ for all IterDomains active in the - // fusion. - void buildIterDomainDefinitionsAndUses( - const std::vector& all_tvs); - - // Iterates over all IterDomains in id_definitions_ and calls initializeID on - // a new IdGraph and returns it. - IdGraph initializeIdGraph(); - - // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs - // and first output of expr - void buildExactMap(const std::vector& exprs); - - // Fills disjoint_ids_[IdMappingMode::ALMOSTEXACT]. Initialize AlmostExact as - // Exact entries, then map anything that's either merged with a size-1 or - // split by a size-1 dimension. - void buildAlmostExactMap(); - - // Fills disjoint_ids_[IdMappingMode::PERMISSIVE]. Initialize PermissiveMap as - // AlmostExact entries, then map through broadcasts - void buildPermissiveMap(const std::vector& exprs); - - // Make sure only leaf nodes of tensor views are parallelized - void validatePTypes(const std::vector& all_tvs) const; - - //! Run through disjoint sets in the LOOP map, make sure there's only one - //! non-serial parallel type in each disjoint set, set the parallel type of - //! all IterDomains in the disjoint set to that PType. - void propagateLoopPTypes() const; - - // !! START Helper functions to build loop promotion and index map!! - - // Terminal loop ids are iteration domains in each loop group that: - // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a - // consumer TV's iter domain maps to this domain in a way that that domain - // is also in the same loop group - // 2) Don't have a direct IterDomain consumer within the group - VectorOfUniqueEntries computeTerminalLoopIds( - const StatefulLoweringInfo info); - - // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and - // graph1. - IdGraph buildIntersection( - const IdGraph& graph0, - const IdGraph& graph1, - bool propagate_exprs = true); - - // !! END Helper functions to build loop promotion and index map!! - - // Start loop map by grouping inlined iter domains - void initializeLoopMap(StatefulLoweringInfo& info); - - // Returns map of IdGroups in the loop map to a representative IterDomain that - // contains all resolved transformations that the terminal IterDomains should - // be promoted to. The returned promotions are valid only for inlined iter - // domains. - std::unordered_map buildInlinePromotions( - StatefulLoweringInfo& info); - - // Returns a similar thing to buildInlinePromotions but also includes iter - // domains that are not inlined. - std::unordered_map buildLoopPromotionMap( - const std::vector& exprs, - StatefulLoweringInfo& info, - std::unordered_map stale_promotion_map); - - // Builds idGraph(IdMappingMode::INDEX) and returns the iter domain promotion - // map to go from leaf domains of each (consumer only?) tensor to their - // corresponding leaf domain in the index graph. - std::unordered_map buildIndexGraph( - const std::vector& exprs, - const std::vector& all_tvs, - StatefulLoweringInfo& info, - std::unordered_map stale_promotion_map); - - // Returns the terminal rfactor or input iter domains each group in the almost - // exact map covers (in the almost exact map). This effectively returns all - // the input almost exact iter domain groups for each almost exact iter domain - // group. RFactor axes are considered an "input" as all broadcast dimensions - // have to be resolved by or before the rfactor iter domain. - std::unordered_map buildCoveredAlmostExact(); - - // TODO: Remove - void buildIndexMap(const std::vector& all_tvs); - - // ======= END Iteration domain build process in order called ======= - - // Errors if self mapping occurs - void assertNoSelfMapping(); - - // Keeps a disjoint set entry for all IterDomain for all mapping mode types. - // - // Using an array here might be nice, but it seems hard to use an enum as an - // array key - // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum - std::unordered_map id_graphs_; - - // If multiple transformations occur IterDomains could have multiple uses, - // however only one should be active in the given Fusion. When we resolve loop - // promotions during lowering, we can generate new iter domains from existing - // ones, so there can be multiple uses generated. Tracks all the active iter - // domain uses. - std::unordered_map> id_uses_; - - // Make sure we don't blindly use definitions as we don't want to grab - // transformations before a tensor view's root domain. - std::unordered_map> id_definitions_; - - // Debug information to hold if a self mapping in a TensorView is found. - c10::optional> - self_mapping_info_ = c10::nullopt; - - std::unordered_map loop_promotion_map_; - - std::unordered_set view_rfactor_ids_; -}; - -using DoubleBufferIndices = std::unordered_map; - -} // namespace nvfuser diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp new file mode 100644 index 00000000000..f2c772435d6 --- /dev/null +++ b/csrc/id_model/id_graph.cpp @@ -0,0 +1,1026 @@ +#include +#include +#include + +namespace nvfuser { + +IdGraph::IdGraph(const IdGraph& other) { + disjoint_ids_ = other.disjoint_ids_; + disjoint_exprs_ = other.disjoint_exprs_; + view_rfactor_ids_ = other.view_rfactor_ids_; + + for (auto orig_unique_def_pair : other.unique_definitions_) { + auto orig_id_group = orig_unique_def_pair.first; + auto orig_expr_groups = orig_unique_def_pair.second; + + auto new_id_group_pair = disjointIdSet(orig_id_group->front()); + TORCH_INTERNAL_ASSERT(new_id_group_pair.second); + auto new_id_group = new_id_group_pair.first; + + ExprGroups new_expr_groups; + for (auto orig_expr_group : orig_expr_groups) { + auto new_expr_group_pair = disjointExprSet(orig_expr_group->front()); + TORCH_INTERNAL_ASSERT(new_expr_group_pair.second); + new_expr_groups.pushBack(new_expr_group_pair.first); + } + + unique_definitions_[new_id_group] = new_expr_groups; + } + + for (auto orig_unique_use_pair : other.unique_uses_) { + auto orig_id_group = orig_unique_use_pair.first; + auto orig_expr_groups = orig_unique_use_pair.second; + + auto new_id_group_pair = disjointIdSet(orig_id_group->front()); + TORCH_INTERNAL_ASSERT(new_id_group_pair.second); + auto new_id_group = new_id_group_pair.first; + + ExprGroups new_expr_groups; + for (auto orig_expr_group : orig_expr_groups) { + auto new_expr_group_pair = disjointExprSet(orig_expr_group->front()); + TORCH_INTERNAL_ASSERT(new_expr_group_pair.second); + new_expr_groups.pushBack(new_expr_group_pair.first); + } + + unique_uses_[new_id_group] = new_expr_groups; + } +} + +IdGraph& IdGraph::operator=(const IdGraph& other) { + disjoint_ids_.clear(); + disjoint_exprs_.clear(); + unique_definitions_.clear(); + unique_uses_.clear(); + view_rfactor_ids_.clear(); + IdGraph copy(other); + std::swap(*this, copy); + return *this; +} + +const DisjointSets& IdGraph::disjointIdSets() const { + return disjoint_ids_; +} + +DisjointSets& IdGraph::disjointIdSets() { + return disjoint_ids_; +} + +std::pair IdGraph::disjointIdSet(IterDomain* id) const { + auto disjoint_set_it = disjoint_ids_.disjointSetMap().find(id); + if (disjoint_set_it == disjoint_ids_.disjointSetMap().end()) { + return std::make_pair(IdGroup(nullptr), false); + } + return std::make_pair(disjoint_set_it->second, true); +} + +const DisjointSets& IdGraph::disjointExprSets() const { + return disjoint_exprs_; +} + +DisjointSets& IdGraph::disjointExprSets() { + return disjoint_exprs_; +} + +std::pair IdGraph::disjointExprSet(Expr* expr) const { + auto disjoint_set_it = disjoint_exprs_.disjointSetMap().find(expr); + if (disjoint_set_it == disjoint_exprs_.disjointSetMap().end()) { + return std::make_pair(ExprGroup(nullptr), false); + } + return std::make_pair(disjoint_set_it->second, true); +} + +ExprGroup IdGraph::toGroup(Expr* expr) const { + auto disjoint_set_pair = disjointExprSet(expr); + TORCH_INTERNAL_ASSERT( + disjoint_set_pair.second, + "\nExpr group could not be found in graph associated with: ", + expr->toString()); + return disjoint_set_pair.first; +} + +IdGroup IdGraph::toGroup(IterDomain* id) const { + auto disjoint_set_pair = disjointIdSet(id); + TORCH_INTERNAL_ASSERT( + disjoint_set_pair.second, + "\nId group could not be found in graph associated with: ", + id->toString(), + "\n"); + return disjoint_set_pair.first; +} + +ExprGroups IdGraph::toGroups(const VectorOfUniqueEntries& exprs) const { + ExprGroups expr_groups; + for (auto expr : exprs) { + expr_groups.pushBack(toGroup(expr)); + } + return expr_groups; +} + +IdGroups IdGraph::toGroups( + const VectorOfUniqueEntries& ids) const { + IdGroups id_groups; + for (auto id : ids) { + id_groups.pushBack(toGroup(id)); + } + return id_groups; +} + +std::vector IdGraph::outputGroups(ExprGroup expr) const { + std::vector output_groups; + for (auto id_output : + ir_utils::filterByType(expr->front()->outputs())) { + output_groups.push_back(toGroup(id_output)); + } + return output_groups; +} + +std::vector IdGraph::inputGroups(ExprGroup expr) const { + std::vector input_groups; + for (auto id_input : + ir_utils::filterByType(expr->front()->inputs())) { + input_groups.push_back(toGroup(id_input)); + } + return input_groups; +} + +ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { + ExprGroups to_visit; + for (auto of_id_group : of) { + auto group_uses_pair = iterDomainGroupUses(of_id_group); + if (group_uses_pair.second) { + to_visit.pushBack(group_uses_pair.first); + } + } + + ExprGroups visited; + while (to_visit.size() > 0) { + auto current_expr = to_visit.popFront(); + visited.pushBack(current_expr); + auto output_ids = outputGroups(current_expr); + for (auto output_id : output_ids) { + auto group_uses_pair = iterDomainGroupUses(output_id); + if (!group_uses_pair.second) { + continue; + } + for (auto group_use : group_uses_pair.first) { + if (visited.has(group_use)) { + continue; + } + to_visit.pushBack(group_use); + } + } + } + + return visited; +} + +ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { + ExprGroups to_visit; + for (auto of_id_group : of) { + auto group_defs_pair = iterDomainGroupDefinitions(of_id_group); + if (group_defs_pair.second) { + to_visit.pushBack(group_defs_pair.first); + } + } + + ExprGroups visited; + while (to_visit.size() > 0) { + auto current_expr = to_visit.popFront(); + visited.pushBack(current_expr); + auto input_ids = inputGroups(current_expr); + for (auto input_id : input_ids) { + auto group_defs_pair = iterDomainGroupDefinitions(input_id); + if (!group_defs_pair.second) { + continue; + } + for (auto group_def : group_defs_pair.first) { + if (visited.has(group_def)) { + continue; + } + to_visit.pushBack(group_def); + } + } + } + + return visited; +} + +ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) + const { + auto all_uses_of_from = allUsesOf(from); + auto all_definitions_of_to = allDefinitionsOf(to); + + // All of the expressions between from and to. Not all will be used as we + // just want to define each iter domain group once. + auto all_exprs = all_uses_of_from.intersect(all_definitions_of_to); + + // There could be IterDomains in from or to that are between other from and + // to nodes. Make sure to clear those out. + IdGroups terminating_inputs; + IdGroups terminating_outputs; + { + IdGroups not_inputs; + IdGroups not_outputs; + IdGroups all_id_groups; + + for (auto expr_group : all_exprs) { + auto inp_groups = inputGroups(expr_group); + auto out_groups = outputGroups(expr_group); + if (IdGroups(inp_groups).intersect(IdGroups(out_groups)).size() > 0) { + // Expression is just a loop to its current group, ignore + continue; + } + + all_id_groups.pushBack(inp_groups); + + if (!inp_groups.empty()) { + not_outputs.pushBack(inp_groups); + } + + all_id_groups.pushBack(out_groups); + + if (!out_groups.empty()) { + not_inputs.pushBack(out_groups); + } + } + terminating_inputs = all_id_groups.subtract(not_inputs); + terminating_outputs = all_id_groups.subtract(not_outputs); + } + + // Track all expressions to get from outputs to this IterDomain. We + // traverse backwards as that's the direction of indexing expressions. An + // index is assigned to each leaf of a domain and as we traverse backwards + // we're effectively accumulating indexing math. We'll only keep the fewest + // expression lists to get to the iter domain. + std::unordered_map required_ind_exprs_ids; + std::unordered_map required_ind_exprs_exprs; + + // Return if all output IterDomain groups of an expression group have + // already been visited + auto outputsVisited = [&](ExprGroup expr) { + for (auto id_group : outputGroups(expr)) { + if (required_ind_exprs_ids.find(id_group) == + required_ind_exprs_ids.end()) { + return false; + } + } + return true; + }; + + auto allIdUsesVisisted = [&](IdGroup id) { + auto uses_pair = iterDomainGroupUses(id); + if (!uses_pair.second) { + return true; + } + for (auto use_group : uses_pair.first) { + if (all_exprs.has(use_group)) { + if (required_ind_exprs_exprs.find(use_group) == + required_ind_exprs_exprs.end()) { + return false; + } + } + } + return true; + }; + + // Returns all expression groups in required_ind_exprs_ids of outputs + auto requiredExprsOutputs = [&](ExprGroup expr) { + ExprGroups all_output_required_exprs; + for (auto id_group : outputGroups(expr)) { + auto id_group_exprs_it = required_ind_exprs_ids.find(id_group); + TORCH_INTERNAL_ASSERT( + id_group_exprs_it != required_ind_exprs_ids.end(), + "Failure in Iter Domain Graph index resolution, count expected for group: ", + id_group->toString()); + all_output_required_exprs.pushBack(id_group_exprs_it->second); + } + return all_output_required_exprs; + }; + + auto processExpr = [&](ExprGroup expr) { + if (!outputsVisited(expr)) { + return false; + } + // Accumulate expressions from all outputs add this expression and set it + // as current expressions required indexing expressions. + required_ind_exprs_exprs[expr] = requiredExprsOutputs(expr); + return true; + }; + + auto processId = [&](IdGroup id) { + // Track if we've grabed any of the uses required indexing expressions. + bool initialized = false; + // Expression group of all indexing expressions required for this iter + // domain coming back from any of its uses. + ExprGroups min_groups; + + auto uses_pair = iterDomainGroupUses(id); + if (!uses_pair.second) { + // No expressions required for this iter domain, it must be a + // terminating output. + required_ind_exprs_ids[id] = min_groups; + return true; + } + + // Only worry about expressions between inputs and outputs we're + // looking at. + for (auto use_group : uses_pair.first.intersect(all_exprs)) { + auto use_required_ind_exprs_it = required_ind_exprs_exprs.find(use_group); + if (use_required_ind_exprs_it == required_ind_exprs_exprs.end()) { + // If there isn't an entry for the use expression it wasn't + // processed, so don't try to process this iter domain yet. + return false; + } + if (!initialized) { + // If first use found initialize the minimum expression group + min_groups = + use_required_ind_exprs_it->second.computeUnion({use_group}); + initialized = true; + } else if ( + use_required_ind_exprs_it->second.size() + 1 < min_groups.size()) { + // If current use has fewer expressions use that, make sure to add the + // use expression. + min_groups = + use_required_ind_exprs_it->second.computeUnion({use_group}); + } + } + required_ind_exprs_ids[id] = min_groups; + return true; + }; + + IdGroups to_visit_ids = terminating_outputs; + ExprGroups to_visit_exprs; + + while (to_visit_ids.size() > 0 || to_visit_exprs.size() > 0) { + // Process expressions first as all uses of iter domains have to be + // processed before we can process that iter domain. + + // Try to detect when nothing has been processed which would put us in an + // infinite loop + bool something_was_processed = false; + ExprGroups still_to_visit_exprs; + while (to_visit_exprs.size() > 0) { + auto currently_visiting = to_visit_exprs.popFront(); + if (required_ind_exprs_exprs.find(currently_visiting) != + required_ind_exprs_exprs.end()) { + continue; + } + if (processExpr(currently_visiting)) { + something_was_processed = true; + auto inp_groups = inputGroups(currently_visiting); + for (auto inp_group : inp_groups) { + to_visit_ids.pushBack(inp_group); + } + } else { + still_to_visit_exprs.pushBack(currently_visiting); + } + } + + std::swap(to_visit_exprs, still_to_visit_exprs); + + IdGroups still_to_visit_ids; + while (to_visit_ids.size() > 0) { + auto currently_visiting = to_visit_ids.popFront(); + if (required_ind_exprs_ids.find(currently_visiting) != + required_ind_exprs_ids.end()) { + continue; + } + + if (processId(currently_visiting)) { + something_was_processed = true; + auto definitions_pair = iterDomainGroupDefinitions(currently_visiting); + if (definitions_pair.second) { + for (auto def : definitions_pair.first) { + if (!all_exprs.has(def)) { + continue; + } + if (required_ind_exprs_exprs.find(def) == + required_ind_exprs_exprs.end()) { + to_visit_exprs.pushBack(def); + } + } + } + } else { + still_to_visit_ids.pushBack(currently_visiting); + } + } + + TORCH_INTERNAL_ASSERT( + something_was_processed || + (to_visit_ids.size() == 0 && to_visit_exprs.size() == 0), + "Infinite loop entered."); + } + + // We want to traverse the expressions registered in required_ind_exprs_ids, + // let's create a strict "uses path" + std::unordered_map uses_path; + for (auto entry : required_ind_exprs_ids) { + auto id = entry.first; + auto traverse_exprs = entry.second; + auto all_uses = iterDomainGroupUses(id); + if (all_uses.second) { + uses_path[id] = traverse_exprs.intersect(all_uses.first); + } else { + uses_path[id] = {}; + continue; + } + } + + // Topologically sort the uses_path. + ExprGroups sorted_exprs; + ExprGroups to_visit; + + for (auto inp : terminating_inputs) { + auto use_it = uses_path.find(inp); + if (use_it == uses_path.end()) { + // This can happen for a trivial traversal where inputs and outputs are + // exactly the same. + continue; + } + auto uses = use_it->second; + for (auto use : uses) { + to_visit.pushBack(use); + } + } + + IdGroups visited = terminating_inputs; + + while (to_visit.size() > 0) { + bool something_processed = false; + ExprGroups still_to_visit; + while (to_visit.size() > 0) { + auto currently_visiting = to_visit.popFront(); + auto inputs = inputGroups(currently_visiting); + if (std::all_of(inputs.begin(), inputs.end(), [&](IdGroup inp_id) { + return visited.has(inp_id); + })) { + something_processed = true; + sorted_exprs.pushBack(currently_visiting); + auto outputs = outputGroups(currently_visiting); + for (auto out_id : outputs) { + visited.pushBack(out_id); + auto use_pair = iterDomainGroupUses(out_id); + if (!use_pair.second) { + continue; + } + still_to_visit.pushBack(use_pair.first.intersect(all_exprs)); + } + } else { + still_to_visit.pushBack(currently_visiting); + } + } + std::swap(to_visit, still_to_visit); + TORCH_INTERNAL_ASSERT(something_processed, "Infinite loop entered."); + } + + return sorted_exprs; +} + +std::unordered_map> IdGraph:: + buildMapBetween( + const std::vector& from, + const std::vector& to) const { + std::unordered_map from_ids2set; + + for (auto from_id : from) { + auto from_disjoint_set_pair = disjointIdSet(from_id); + if (!from_disjoint_set_pair.second) { + continue; + } + from_ids2set[from_id] = from_disjoint_set_pair.first; + } + + // Map from the sets associated with the IterDomains in to, to those iter + // domains + std::unordered_map> set2to_ids; + + for (auto to_id : to) { + auto to_disjoint_set_pair = disjointIdSet(to_id); + if (!to_disjoint_set_pair.second) { + continue; + } + auto to_set = to_disjoint_set_pair.first; + auto set2to_ids_it = set2to_ids.find(to_set); + + if (set2to_ids_it == set2to_ids.end()) { + set2to_ids[to_set] = {to_id}; + } else { + set2to_ids[to_set].pushBack(to_id); + } + } + + std::unordered_map> + from_ids2to_ids; + for (auto from_id : from) { + from_ids2to_ids[from_id] = VectorOfUniqueEntries(); + + auto from_it = from_ids2set.find(from_id); + TORCH_INTERNAL_ASSERT(from_it != from_ids2set.end()); + + auto from_set = from_it->second; + auto to_entry_it = set2to_ids.find(from_set); + if (to_entry_it == set2to_ids.end()) { + continue; + } + from_ids2to_ids[from_id] = to_entry_it->second; + } + return from_ids2to_ids; +} + +std::unordered_map> IdGraph:: + buildMapBetween( + const VectorOfUniqueEntries& from, + const VectorOfUniqueEntries& to) const { + return buildMapBetween(from.vector(), to.vector()); +} + +std::pair IdGraph::iterDomainGroupDefinitions( + IdGroup id_group) const { + auto null_return = std::make_pair(ExprGroups(), false); + + if (id_group == nullptr) { + return null_return; + } + + auto definitions_it = unique_definitions_.find(id_group); + if (definitions_it == unique_definitions_.end()) { + return null_return; + } + + return std::make_pair(definitions_it->second, true); +} + +std::pair IdGraph::iterDomainGroupUses( + IdGroup id_group) const { + auto null_return = std::make_pair(ExprGroups(), false); + + if (id_group == nullptr) { + return null_return; + } + + auto uses_it = unique_uses_.find(id_group); + if (uses_it == unique_uses_.end()) { + return null_return; + } + + return std::make_pair(uses_it->second, true); +} + +std::string IdGraph::toString() const { + std::stringstream ss; + ss << "IdGraph { \n"; + ss << "Disjoint Ids:\n" + << idGroupsString(*this, 1) << "\n\nDisjoint Expression groups:\n" + << exprGroupsString(*this, 1) << std::endl; + ss << " } IdGraph\n" << std::endl; + return ss.str(); +} + +std::vector> IdGraph::isTrivialExpr(Expr* expr) { + std::vector> mapped_ids; + if (auto merge = dynamic_cast(expr)) { + if (merge->inner()->extent()->isOneInt()) { + mapped_ids.push_back({merge->outer(), merge->out()}); + } + if (merge->outer()->extent()->isOneInt()) { + mapped_ids.push_back({merge->inner(), merge->out()}); + } + } else if (auto split = dynamic_cast(expr)) { + if (split->factor()->isOneInt() && split->startOffset()->isZeroInt() && + split->stopOffset()->isZeroInt()) { + if (split->innerSplit()) { + mapped_ids.push_back({split->in(), split->outer()}); + } else { + mapped_ids.push_back({split->in(), split->inner()}); + } + } + } else if (auto swizzle = dynamic_cast(expr)) { + if (swizzle->swizzleType() == Swizzle2DType::NoSwizzle || + swizzle->swizzleMode() == SwizzleMode::NoSwizzle) { + mapped_ids.push_back({swizzle->inX(), swizzle->outX()}); + mapped_ids.push_back({swizzle->inY(), swizzle->outY()}); + } + } + return mapped_ids; +} + +bool IdGraph::transformAtributesMatch(Expr* first, Expr* second) { + if (first == nullptr || second == nullptr) { + return false; + } + + TORCH_INTERNAL_ASSERT( + first->isA() || first->isA() || first->isA() || + first->isA(), + "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n", + first->toString()); + + if (typeid(*first) != typeid(*second)) { + return false; + } + + if (first->isA()) { + auto first_split = first->as(); + auto second_split = second->as(); + if (!first_split->factor()->sameAs(second_split->factor()) || + first_split->innerSplit() != second_split->innerSplit() || + !first_split->startOffset()->sameAs(second_split->startOffset()) || + !first_split->stopOffset()->sameAs(second_split->stopOffset())) { + return false; + } + } + + if (first->isA()) { + auto first_swizzle = first->as(); + auto second_swizzle = second->as(); + if (first_swizzle->swizzleMode() != second_swizzle->swizzleMode() || + first_swizzle->swizzleType() != second_swizzle->swizzleType()) { + return false; + } + } + + return true; +} + +void IdGraph::initializeId( + IterDomain* id, + const VectorOfUniqueEntries& definitions, + const VectorOfUniqueEntries& uses) { + auto id_disjoint_set = disjointIdSets().initializeSet(id).first->second; + + ExprGroups def_groups; + for (auto def : definitions) { + auto expr_set = disjointExprSets().initializeSet(def).first->second; + def_groups.pushBack(expr_set); + } + unique_definitions_[id_disjoint_set] = def_groups; + + ExprGroups use_groups; + for (auto use : uses) { + auto expr_set = disjointExprSets().initializeSet(use).first->second; + use_groups.pushBack(expr_set); + } + unique_uses_[id_disjoint_set] = use_groups; +} + +bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { + if (!transformAtributesMatch(first, second)) { + return false; + } + + auto first_ids = ir_utils::filterByType( + forward ? first->inputs() : first->outputs()) + .vector(); + + auto second_ids = ir_utils::filterByType( + forward ? second->inputs() : second->outputs()) + .vector(); + + TORCH_INTERNAL_ASSERT( + first_ids.size() == second_ids.size(), + "Expected number of ", + (forward ? "inputs" : "outputs"), + " to match for\n", + first->toString(), + second->toString()); + + { + std::vector> zipped_ids; + + std::transform( + first_ids.begin(), + first_ids.end(), + second_ids.begin(), + std::back_inserter(zipped_ids), + [](IterDomain* first, IterDomain* second) { + return std::make_pair(first, second); + }); + + if (std::any_of( + zipped_ids.begin(), + zipped_ids.end(), + [&](std::pair id_pair) { + return !disjointIdSets().permissiveAreMapped( + id_pair.first, id_pair.second); + })) { + return false; + } + } + + // Special handling for backprop of merge + if (first->isA() && !forward) { + // Can't back prop through merge without making sure one input actually + // matches. This can be done on a map or extent basis. + auto merge0 = first->as(); + auto merge1 = second->as(); + + auto extent_0o = merge0->outer()->extent(); + auto extent_0i = merge0->inner()->extent(); + auto extent_1o = merge1->outer()->extent(); + auto extent_1i = merge1->inner()->extent(); + + auto extent_0_match = extent_0o->sameAs(extent_1o) || + (extent_0o->isConstInt() && extent_1o->isConstInt() && + extent_0o->evaluateInt() == extent_1o->evaluateInt()) || + disjointIdSets().permissiveAreMapped(merge0->outer(), merge1->outer()); + + auto extent_1_match = extent_0i->sameAs(extent_1i) || + (extent_0i->isConstInt() && extent_1i->isConstInt() && + extent_0i->evaluateInt() == extent_1i->evaluateInt()) || + disjointIdSets().permissiveAreMapped(merge0->inner(), merge1->inner()); + + if (!(extent_0_match || extent_1_match)) { + return false; + } + } + + // TODO: For now we're using same as, however we could know what val's are + // exactly the same given the exact map. We might want to pipe that + // information through to here. + if (first->isA()) { + if (!first->as()->leftExpand()->sameAs( + second->as()->leftExpand())) { + return false; + } + + if (!first->as()->rightExpand()->sameAs( + second->as()->rightExpand())) { + return false; + } + } + + return true; +} + +ExprGroups IdGraph::uniqueDefinitions(IdGroup group) const { + auto unique_defs_it = unique_definitions_.find(group); + TORCH_INTERNAL_ASSERT( + unique_defs_it != unique_definitions_.end(), + "Definition not found for IdGroup: ", + group->toString()); + return unique_defs_it->second; +} + +ExprGroups IdGraph::uniqueUses(IdGroup group) const { + auto unique_uses_it = unique_uses_.find(group); + TORCH_INTERNAL_ASSERT( + unique_uses_it != unique_uses_.end(), + "Uses not found for IdGroup: ", + group->toString()); + return unique_uses_it->second; +} + +void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { + if (id0 == id1) { + return; + } + + if (disjointIdSets().strictAreMapped(id0, id1)) { + return; + } + // Definitions and uses are based on the groups of id0 and id1, don't merge + // them into a single group until we grab all definitions and uses for later + // processing. + auto orig_id_group0 = disjointIdSet(id0).first; + auto orig_id_group1 = disjointIdSet(id1).first; + ExprGroups orig_defs0 = uniqueDefinitions(orig_id_group0); + ExprGroups orig_defs1 = uniqueDefinitions(orig_id_group1); + ExprGroups orig_uses0 = uniqueUses(orig_id_group0); + ExprGroups orig_uses1 = uniqueUses(orig_id_group1); + + // Map the iter domains together before we traverse across definitions and + // uses. Traversing definitions and uses could use the new property of id0 and + // id1 being mapped. + disjointIdSets().mapEntries(id0, id1); + auto new_id_group = disjointIdSet(id0).first; + + unique_definitions_.erase(orig_id_group0); + unique_definitions_.erase(orig_id_group1); + unique_uses_.erase(orig_id_group0); + unique_uses_.erase(orig_id_group1); + + unique_definitions_[new_id_group] = orig_defs0.computeUnion(orig_defs1); + unique_uses_[new_id_group] = orig_uses0.computeUnion(orig_uses1); + + // Propagate on uses + if (orig_uses0.size() > 0 || orig_uses1.size() > 0) { + if (orig_uses0.size() > 0 && orig_uses1.size() > 0) { + for (auto use_group_1 : orig_uses1) { + if (orig_uses0.has(use_group_1)) { + continue; + } + + for (auto use_group_0 : orig_uses0) { + auto use0 = use_group_0->front(); + auto use1 = use_group_1->front(); + maybeMapThroughExprs(use0, use1, true); + } + } + } + } + + // Propagate on definitions + if (orig_defs0.size() > 0 || orig_defs1.size() > 0) { + if (orig_defs0.size() > 0 && orig_defs1.size() > 0) { + for (auto def_group_1 : orig_defs1) { + if (orig_defs0.has(def_group_1)) { + continue; + } + + for (auto def_group_0 : orig_defs0) { + auto def0 = def_group_0->front(); + auto def1 = def_group_1->front(); + maybeMapThroughExprs(def0, def1, false); + } + } + } + } +} + +void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { + if (exprsMap(expr0, expr1, forward)) { + if (propagate_exprs_) { + mapExprs(expr0, expr1); + mapThroughExpr(expr0, expr1, forward); + } else if ( + inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) && + outputGroups(toGroup(expr0)) == outputGroups(toGroup(expr1))) { + mapExprs(expr0, expr1); + } + } +} + +void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { + if (expr0 == expr1) { + return; + } + + if (disjointExprSets().strictAreMapped(expr0, expr1)) { + return; + } + + ExprGroup expr0_orig_group = toGroup(expr0); + ExprGroup expr1_orig_group = toGroup(expr1); + + disjointExprSets().mapEntries(expr0, expr1); + + auto expr_new_group = toGroup(expr0); + + // Update unique uses of producers + IdGroups producers; + for (auto expr : std::vector{expr0, expr1}) { + for (auto input_id : ir_utils::filterByType(expr->inputs())) { + producers.pushBack(toGroup(input_id)); + } + } + + for (auto producer_group : producers) { + uniqueUses().at(producer_group).erase(expr0_orig_group); + uniqueUses().at(producer_group).erase(expr1_orig_group); + uniqueUses().at(producer_group).pushBack(expr_new_group); + } + + // Update unique definitinos of consumers + IdGroups consumers; + for (auto expr : std::vector{expr0, expr1}) { + for (auto output_id : ir_utils::filterByType(expr->outputs())) { + consumers.pushBack(toGroup(output_id)); + } + } + + for (auto consumer_group : consumers) { + uniqueDefinitions().at(consumer_group).erase(expr0_orig_group); + uniqueDefinitions().at(consumer_group).erase(expr1_orig_group); + uniqueDefinitions().at(consumer_group).pushBack(expr_new_group); + } +} + +bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { + if (first == nullptr || second == nullptr) { + return false; + } + + if (!exprsMap(first, second, forward)) { + return false; + } + + TORCH_INTERNAL_ASSERT( + propagate_exprs_, + "Asked to propagate expression mappings on a graph that has propagate_exprs_ disabled."); + + auto first_ids = ir_utils::filterByType( + forward ? first->outputs() : first->inputs()) + .vector(); + auto second_ids = ir_utils::filterByType( + forward ? second->outputs() : second->inputs()) + .vector(); + TORCH_INTERNAL_ASSERT( + first_ids.size() == second_ids.size(), + "This should be unreachable, if transformation expressions match, their number of inputs and outputs should as well.\n However found:\n", + first->toString(), + "\nand\n", + second->toString()); + for (auto out_i : c10::irange(first_ids.size())) { + mapIds(first_ids[out_i], second_ids[out_i]); + } + + return true; +} + +void IdGraph::mapThroughLoopSwizzles() { + std::vector all_swizzles; + + for (auto expr_set : disjointExprSets().disjointSets()) { + auto swizzles_in_expr_set = ir_utils::filterByType( + expr_set->vector().begin(), expr_set->vector().end()); + all_swizzles.insert( + all_swizzles.end(), + swizzles_in_expr_set.begin(), + swizzles_in_expr_set.end()); + } + + for (auto swizzle : all_swizzles) { + if (swizzle->swizzleMode() == SwizzleMode::Loop) { + mapIds(swizzle->inX(), swizzle->outX()); + mapIds(swizzle->inY(), swizzle->outY()); + } + } +} + +void IdGraph::mapThroughTrivialExprs() { + // Grab all expressions + std::vector exprs; + + for (auto expr_group : disjointExprSets().disjointSets()) { + for (auto expr : *expr_group) { + exprs.push_back(expr); + } + } + + for (auto expr : exprs) { + // If not trivial continue + auto mapped_ids = IdGraph::isTrivialExpr(expr); + if (mapped_ids.empty()) { + continue; + } + + // Map through trivial expressions + for (auto mapped_id_group : mapped_ids) { + for (auto id : mapped_id_group) { + mapIds(mapped_id_group.front(), id); + } + } + } +} + +void IdGraph::removeTrivialExprs() { + ExprGroups trivial_expr_groups; + // This seems like it shouls just be a copy if. + for (auto expr_group : disjointExprSets().disjointSets()) { + if (isTrivialExprGroup(expr_group)) { + trivial_expr_groups.pushBack(expr_group); + } + } + + // Clear out expressions that map inputs and outputs to the same group + // from definitions and uses. They shouldn't be important in traversal, and + // will break the terminal input/terminal output logic of traversal. Similar + // to what's drafted in buildIndexGraph + for (auto trivial_expr_group : trivial_expr_groups) { + // Complexity of erase not good as both disjoint set and vector of unique + // entries require a vector find to erase an entry. + eraseExprGroup(trivial_expr_group); + } +} + +// Complexity here is not great. We might want a better complexity version when +// erasing multiple expr_groups. +void IdGraph::eraseExprGroup(ExprGroup expr_group) { + // Erase entries that exist in unique_definitions_ and unique_uses_ + for (auto id_group : disjointIdSets().disjointSets()) { + // Make sure the entries exists + TORCH_INTERNAL_ASSERT( + unique_definitions_.find(id_group) != unique_definitions_.end(), + "Broken definitions, couldn't find entry for id group, ", + nvfuser::toString(id_group, 0, true)); + TORCH_INTERNAL_ASSERT( + unique_uses_.find(id_group) != unique_uses_.end(), + "Broken uses, couldn't find entry for id group, ", + nvfuser::toString(id_group, 0, true)); + + unique_definitions_[id_group].erase(expr_group); + unique_uses_[id_group].erase(expr_group); + } + + for (auto expr : *expr_group) { + disjoint_exprs_.erase(expr); + } +} + +bool IdGraph::isTrivialExprGroup(ExprGroup expr_group) const { + return !IdGroups(inputGroups(expr_group)) + .intersect(IdGroups(outputGroups(expr_group))) + .empty(); +} + +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h new file mode 100644 index 00000000000..b8d7a4e3686 --- /dev/null +++ b/csrc/id_model/id_graph.h @@ -0,0 +1,248 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace nvfuser { + +using IdGroup = std::shared_ptr>; +using IdGroups = VectorOfUniqueEntries; +using ExprGroup = std::shared_ptr>; +using ExprGroups = VectorOfUniqueEntries; + +class TORCH_CUDA_CU_API IdGraph { + public: + IdGraph() = default; + + IdGraph(const IdGraph& other); + IdGraph(IdGraph&& other) = default; + + IdGraph& operator=(const IdGraph& other); + IdGraph& operator=(IdGraph&& other) = default; + + // Returns the disjoint IterDomain set. + const DisjointSets& disjointIdSets() const; + + DisjointSets& disjointIdSets(); + + // Returns + // { + // (1) The disjoint set of the provided Iter Domain if it exists, + // otherwise a null shared ptr + // (2) If the disjoint set of the provided Iter Domain exists + // } + // + // TODO: Audit usage + std::pair disjointIdSet(IterDomain* id) const; + + // Returns the disjoint Expr set. + const DisjointSets& disjointExprSets() const; + + DisjointSets& disjointExprSets(); + + // Same as getDisjointIdSet but for the Expression sets. + // + // TODO: Audit usage + std::pair disjointExprSet(Expr* expr) const; + + // Convert expr to its exprGroup, assert that it exists. + ExprGroup toGroup(Expr* expr) const; + + // Convert iter domain to its IdGroup, assert that it exists. + IdGroup toGroup(IterDomain* id) const; + + // Convert unique vector of expressions to unique vector of its groups + ExprGroups toGroups(const VectorOfUniqueEntries& exprs) const; + + // Convert unique vector of IterDomain to unique vector of its groups + IdGroups toGroups(const VectorOfUniqueEntries& ids) const; + + // Return output/input iter domain groups of provided expr + std::vector outputGroups(ExprGroup expr) const; + std::vector inputGroups(ExprGroup expr) const; + + // Traverses uses of the IdGroups in 'of' and returns all ExprGroups + // that have a use in their definition of provided of IdGroups. + ExprGroups allUsesOf(const IdGroups& of) const; + + // Traverses definitions of the IdGroups in 'of' and returns all ExprGroups + // used in this history of defining the 'of' IdGroups. + ExprGroups allDefinitionsOf(const IdGroups& of) const; + + // Return sorted expressions to go from the provided IterDomains in from to + // the provided IterDomains in to with provided mode. Minimal expressions to + // get from 'from' to 'to' returned. + ExprGroups getExprsBetween(const IdGroups& from, const IdGroups& to) const; + + // Supports one to many mappings, uses the disjoint sets of the provided mode + // to produce mappings between from and to. If multiple IterDomains in to map + // to a single iter domain in from, the order of the IterDomains in value of + // the map is preserved to be the order provided in to. + std::unordered_map> + buildMapBetween( + const std::vector& from, + const std::vector& to) const; + + // Alias of the above on unique vector entries + std::unordered_map> + buildMapBetween( + const VectorOfUniqueEntries& from, + const VectorOfUniqueEntries& to) const; + + //! Returns + //! (1) The expressions associated with the definitions of the provided + //! IterDomain group in the provided mapping mode (if it exists). + //! (2) If there is a definitions entry of the provided IterDomain group in + //! the provided mapping mode. + //! First entry in the returned pair is a vector of vector of expressions. The + //! inner vector is proven to be equivalent based on the provided mode. The + //! outer vector are expression groups that are not equivalent based on the + //! provided mode, but produce one of the IterDomains within the same disjoint + //! Iter Domain set based on the provided mode. + //! TODO: Change name to start with get + std::pair iterDomainGroupDefinitions( + IdGroup id_group) const; + + //! Same as iterDomainGroupDefinitions but for uses instead of definitions + //! TODO: Change name to start with get + std::pair iterDomainGroupUses(IdGroup id_group) const; + + std::string toString() const; + + // Checks if the expression is a trivial operation where an input is simply an + // output of the transformation. Returns the mapped iter domains if found. + static std::vector> isTrivialExpr(Expr* expr); + + // Returns if all atributes of the ID transforms first and second are the same + static bool transformAtributesMatch(Expr* first, Expr* second); + + // Initializes entries for the provided IterDomain in the IterDomainGraphs + void initializeId( + IterDomain* id, + const VectorOfUniqueEntries& definitions, + const VectorOfUniqueEntries& uses); + + // Returns if first and second are expressions through which the provided + // id_map have matching inputs (if forward), or outputs (if not forward). + // Returning true means the expressions are "the same", in terms they modify + // matching original extents, by the same amount. + bool exprsMap( + Expr* first, + Expr* second, + bool forward + // , std::vector second_input_or_output_override + ) const; + + // Returns entry in unique_definitions_ for provided group in provided mode, + // otherwise errors if no entry is found. + ExprGroups uniqueDefinitions(IdGroup group) const; + + // Returns entry in unique_uses_ for provided group in provided mode, + // otherwise errors if no entry is found. + ExprGroups uniqueUses(IdGroup group) const; + + std::unordered_map& uniqueUses() { + return unique_uses_; + } + + std::unordered_map& uniqueDefinitions() { + return unique_definitions_; + } + + // Set id0 and id1 to mapped in disjointIdsSet[mode], attempt to propagate + // new mapping through id0/id1 definitions/uses. + void mapIds(IterDomain* id0, IterDomain* id1); + + // Checks if expr0 and expr1 should map together, maps them together, and if + // expression propagation is on, propagates mapping through them. This should + // be the only call in IdGraph to mapThroughExpr + void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); + + // Map expr0 and expr1 with eachother, update unique_definitions_ unique_uses_ + // TODO: Make this variant hidden? + void mapExprs(Expr* expr0, Expr* expr1); + + // Checks if expr's are considered "the same" where sameness inputs and + // outputs in the same position across expressions map with provided + // MappingMode. If the expressions are determined the same then + // if forward + // will map outputs + // else + // will map inputs + // in the provided mode. + // Returns if expressions were mapped through. + // + // TODO: Make this private + bool mapThroughExpr(Expr* first, Expr* second, bool forward); + + // Map through loop swizzles, as input/output IterDomains are exact, only the + // order they're traversed differs. + void mapThroughLoopSwizzles(); + + // Maps iter domain pairs returned by calling that return mappings from + // IdGraph::isTrivialExpr on every expression in the graph. + void mapThroughTrivialExprs(); + + // Removes expressions from unique_definitions_ and unique_uses_ that return + // mappings from IdGraph::isTrivialExpr + void removeTrivialExprs(); + + // See comment on propagate_expr_ member bool for description + // Once disabled this can't be reenabled on a graph. If it's reenabled it's + // hard to predict how mappings will propagate, which will be triggered on the + // next mapping. To support changing this flag, we should likely run through + // all expressions currently registered and propagate through all of them on + // switch. Then once enabled it couldn't be redisabled because we don't record + // the history of mapId calls. + void disableExprPropagation() { + propagate_exprs_ = false; + } + + // Removes the provided expression group from unique_definitions_ and + // unique_uses_ breaking traversal through them. + void eraseExprGroup(ExprGroup expr_group); + + // Returns if the expression group has an input id group that matches an + // output id group. This means traversing on this expression doesn't actually + // do anything. + bool isTrivialExprGroup(ExprGroup expr_group) const; + + private: + // If propagate_exprs_ = false, then mapThroughExpr will not be called as a + // consequence of calling mapIds. As well as mapThroughExpr will not be called + // (again) as a result of calling mapThroughExpr. + // + // Note: For the second sentence of above... mapThroughExpr can call mapIds + // which could in return call mapThoughExpr again, but propagate_exprs_ as + // mentioned above prevents that from happening. + // + // TODO: Should propagate_exprs_ be a const member? + bool propagate_exprs_ = true; + + // Keeps a disjoint set entry for all IterDomain for all mapping mode types. + // + // Using an array here might be nice, but it seems hard to use an enum as an + // array key + // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum + DisjointSets disjoint_ids_; + + // Keeps a disjoint set entry for all Expressions for all mapping mode types. + DisjointSets disjoint_exprs_; + + std::unordered_map unique_definitions_; + + std::unordered_map unique_uses_; + + // Hold a set of IterDomains that are considered view rfactor ids. This + // identification is particularly important to understand if split operations + // are divisible or not. + // + // TODO: This should just be in IterDomainGraphs, not here. + std::unordered_set view_rfactor_ids_; +}; + +} \ No newline at end of file diff --git a/csrc/id_graphs.cpp b/csrc/id_model/id_graphs.cpp similarity index 57% rename from csrc/id_graphs.cpp rename to csrc/id_model/id_graphs.cpp index bc315b219ec..f6938680883 100644 --- a/csrc/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1,4 +1,6 @@ -#include +#include +#include +#include #include #include @@ -13,1404 +15,6 @@ namespace nvfuser { -// Printing utilities to show critical uniqueness information. i.e. being able -// to tell slight differences between groups we're working with. -namespace debug { - -namespace { -// Sometimes it can be helpful to directly check the pointer addresses of the -// groups. As one group might look exactly like another group but are in -// different disjoint sets. Leaving commented out by default. -template -std::string toString(const T* ptr, bool enable) { - if (!enable) { - return ""; - } - std::stringstream ss; - ss << ptr; - return "[0x." + ss.str().substr(9) + "]"; -} - -std::string indent(int size = 0) { - std::stringstream ss; - for (auto i : c10::irange(size)) { - // Unused variable error - if (i >= 0) { - ss << " "; - } - } - return ss.str(); -} -} // namespace - -std::string toString( - const std::vector& id_group, - int indent_size) { - std::vector names; - for (auto id : id_group) { - names.push_back(id->name()); - } - std::sort(names.begin(), names.end()); - - std::stringstream ss; - ss << indent(indent_size) << "{" << names << "}"; - return ss.str(); -} - -std::string toString(const IdGroup& id_group, int indent_size, bool with_ptr) { - std::stringstream ss; - ss << indent(indent_size) << "idg" << (with_ptr ? "(" : "") - << toString(id_group.get(), with_ptr) << (with_ptr ? ")" : "") - << toString(id_group->vector()); - return ss.str(); -} - -std::string toString( - const std::vector& id_groups, - int indent_size, - bool with_ptr) { - std::stringstream ss; - - // Track position in id_groups and its min iter domain name in the set - std::vector> group_name_info; - - unsigned int pos = 0; - - for (auto id_group : id_groups) { - unsigned int min_id_name = std::numeric_limits::max(); - for (auto id : *id_group) { - if (id->name() < min_id_name) { - min_id_name = id->name(); - } - } - group_name_info.push_back(std::make_pair(min_id_name, pos++)); - } - - ss << indent(indent_size) << "(idgs){\n"; - - // Sort based on minimum id in the group - std::sort(group_name_info.begin(), group_name_info.end()); - - for (auto i : c10::irange(group_name_info.size())) { - auto pos = group_name_info[i].second; - ss << toString(id_groups[pos], indent_size + 1, with_ptr) << "\n"; - } - - ss << "}"; - return ss.str(); -} - -std::string toString( - const IdGroups& id_groups, - int indent_size, - bool with_ptr) { - std::stringstream ss; - - // Track position in id_groups and its min iter domain name in the set - std::vector> group_name_info; - - unsigned int pos = 0; - - for (auto id_group : id_groups) { - unsigned int min_id_name = std::numeric_limits::max(); - for (auto id : *id_group) { - if (id->name() < min_id_name) { - min_id_name = id->name(); - } - } - group_name_info.push_back(std::make_pair(min_id_name, pos++)); - } - - ss << indent(indent_size) << "(idgs){\n"; - - // Sort based on minimum id in the group - std::sort(group_name_info.begin(), group_name_info.end()); - - for (auto i : c10::irange(group_name_info.size())) { - auto pos = group_name_info[i].second; - ss << toString(id_groups.vector()[pos], indent_size + 1, with_ptr) << "\n"; - } - - ss << "}"; - return ss.str(); -} - -std::string toInlineString(const std::vector& id_groups) { - // Track position in id_groups and its min iter domain name in the set - std::vector> group_name_info; - - unsigned int pos = 0; - - for (auto id_group : id_groups) { - unsigned int min_id_name = std::numeric_limits::max(); - for (auto id : *id_group) { - if (id->name() < min_id_name) { - min_id_name = id->name(); - } - } - group_name_info.push_back(std::make_pair(min_id_name, pos++)); - } - - // Sort based on minimum id in the group - std::sort(group_name_info.begin(), group_name_info.end()); - - std::stringstream ss; - - ss << "(idgs){"; - bool first = true; - for (auto i : c10::irange(group_name_info.size())) { - if (first) { - first = false; - } else { - ss << ", "; - } - auto pos = group_name_info[i].second; - ss << toString(id_groups[pos]); - } - - return ss.str(); -} - -std::string toString(const std::vector& expr_group, int indent_size) { - std::vector names; - for (auto expr : expr_group) { - names.push_back(expr->name()); - } - std::sort(names.begin(), names.end()); - - std::stringstream ss; - ss << indent(indent_size) << "{" << names << "}"; - return ss.str(); -} - -std::string toString( - const ExprGroup& expr_group, - int indent_size, - bool with_ptr) { - std::stringstream ss; - ss << indent(indent_size) << "exprg" << (with_ptr ? "(" : "") - << toString(expr_group.get(), with_ptr) << (with_ptr ? ")" : "") - << toString(expr_group->vector()); - return ss.str(); -} - -std::string toString( - const IdGraph& id_graph, - const std::vector& expr_groups, - int indent_size, - bool with_ptr) { - std::stringstream ss; - - // Track position in expr_groups and its min iter domain name in the set - std::vector> group_name_info; - - unsigned int pos = 0; - - for (auto expr_group : expr_groups) { - unsigned int min_expr_name = std::numeric_limits::max(); - for (auto expr : *expr_group) { - if (expr->name() < min_expr_name) { - min_expr_name = expr->name(); - } - } - group_name_info.push_back(std::make_pair(min_expr_name, pos++)); - } - - ss << indent(indent_size) << "(exprgs){\n"; - - // Sort based on minimum id in the group - std::sort(group_name_info.begin(), group_name_info.end()); - - for (auto i : c10::irange(group_name_info.size())) { - auto pos = group_name_info[i].second; - auto expr_group = expr_groups[pos]; - - auto inputs = IdGroups(id_graph.inputGroups(expr_group)); - auto outputs = IdGroups(id_graph.outputGroups(expr_group)); - - ss << indent(indent_size + 1) << toInlineString(inputs.vector()) << " --" - << toString(expr_group, 0, with_ptr) << "--> " - << toInlineString(outputs.vector()) << "\n"; - } - - ss << indent(indent_size) << "}"; - return ss.str(); -} - -std::string toString( - const IdGraph& id_graph, - const ExprGroups& expr_groups, - int indent_size, - bool with_ptr) { - std::stringstream ss; - - // Track position in expr_groups and its min iter domain name in the set - std::vector> group_name_info; - - unsigned int pos = 0; - - for (auto expr_group : expr_groups) { - unsigned int min_id_name = std::numeric_limits::max(); - for (auto id : *expr_group) { - if (id->name() < min_id_name) { - min_id_name = id->name(); - } - } - group_name_info.push_back(std::make_pair(min_id_name, pos++)); - } - - ss << indent(indent_size) << "(exprgs){\n"; - - // Sort based on minimum id in the group - std::sort(group_name_info.begin(), group_name_info.end()); - - for (auto i : c10::irange(group_name_info.size())) { - auto pos = group_name_info[i].second; - auto expr_group = expr_groups.vector()[pos]; - - auto inputs = IdGroups(id_graph.inputGroups(expr_group)); - auto outputs = IdGroups(id_graph.outputGroups(expr_group)); - - ss << indent(indent_size + 1) << toInlineString(inputs.vector()) << " --" - << toString(expr_group, 0, with_ptr) << "--> " - << toInlineString(outputs.vector()) << "\n"; - } - - ss << indent(indent_size) << "}"; - return ss.str(); -} - -std::string idGroupsString( - const IdGraph& id_graph, - int indent_size, - bool with_ptr) { - IdGroups id_groups( - id_graph.disjointIdSets().disjointSets().begin(), - id_graph.disjointIdSets().disjointSets().end()); - return toString(id_groups, indent_size, with_ptr); -} -std::string exprGroupsString( - const IdGraph& id_graph, - int indent_size, - bool with_ptr) { - ExprGroups expr_groups( - id_graph.disjointExprSets().disjointSets().begin(), - id_graph.disjointExprSets().disjointSets().end()); - return toString(id_graph, expr_groups, indent_size, with_ptr); -} - -std::string definitionsString( - const IdGraph& id_graph, - int indent_size, - bool with_ptr) { - ExprGroups defs; - for (auto id_group : id_graph.disjointIdSets().disjointSets()) { - auto definition_pair = id_graph.iterDomainGroupDefinitions(id_group); - if (definition_pair.second) { - for (auto expr_group : definition_pair.first) { - defs.pushBack(expr_group); - } - } - } - return toString(id_graph, defs, indent_size, with_ptr); -} - -std::string usesString( - const IdGraph& id_graph, - int indent_size, - bool with_ptr) { - ExprGroups uses; - for (auto id_group : id_graph.disjointIdSets().disjointSets()) { - auto definition_pair = id_graph.iterDomainGroupUses(id_group); - if (definition_pair.second) { - for (auto expr_group : definition_pair.first) { - uses.pushBack(expr_group); - } - } - } - return toString(id_graph, uses, indent_size, with_ptr); -} - -} // namespace debug - -namespace { - -bool transformAtributesMatch(Expr* first, Expr* second) { - if (first == nullptr || second == nullptr) { - return false; - } - - TORCH_INTERNAL_ASSERT( - first->isA() || first->isA() || first->isA() || - first->isA(), - "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n", - first->toString()); - - if (typeid(*first) != typeid(*second)) { - return false; - } - - if (first->isA()) { - auto first_split = first->as(); - auto second_split = second->as(); - if (!first_split->factor()->sameAs(second_split->factor()) || - first_split->innerSplit() != second_split->innerSplit() || - !first_split->startOffset()->sameAs(second_split->startOffset()) || - !first_split->stopOffset()->sameAs(second_split->stopOffset())) { - return false; - } - } - - if (first->isA()) { - auto first_swizzle = first->as(); - auto second_swizzle = second->as(); - if (first_swizzle->swizzleMode() != second_swizzle->swizzleMode() || - first_swizzle->swizzleType() != second_swizzle->swizzleType()) { - return false; - } - } - - return true; -} -} // namespace - -void IdGraphVisitor::traverse() { - IdGroups all_ids; - ExprGroups all_exprs; - { - if (sub_selection_.empty()) { - all_ids = IdGroups( - graph().disjointIdSets().disjointSets().begin(), - graph().disjointIdSets().disjointSets().end()); - } else { - for (auto id : sub_selection_) { - auto disjoint_pair = graph().disjointIdSet(id); - if (disjoint_pair.second) { - all_ids.pushBack(disjoint_pair.first); - } - } - } - - if (sub_selection_.empty()) { - all_exprs = ExprGroups( - graph().disjointExprSets().disjointSets().begin(), - graph().disjointExprSets().disjointSets().end()); - } else { - for (auto id_group : all_ids) { - for (auto def : graph().uniqueDefinitions(id_group)) { - if (all_exprs.has(def)) { - continue; - } - auto inp_groups = IdGroups(graph().inputGroups(def)); - auto out_groups = IdGroups(graph().outputGroups(def)); - if (inp_groups.subtract(all_ids).empty() && - out_groups.subtract(all_ids).empty()) { - all_exprs.pushBack(def); - } - } - } - } - } - // There could be IterDomains in from or to that are between other from and - // to nodes. Make sure to clear those out. - IdGroups terminating_inputs; - IdGroups terminating_outputs; - - { - IdGroups not_inputs; - IdGroups not_outputs; - for (auto expr_group : all_exprs) { - auto inp_groups = IdGroups(graph().inputGroups(expr_group)); - auto out_groups = IdGroups(graph().outputGroups(expr_group)); - - if (inp_groups.intersect(out_groups).size() > 0) { - // Expression is just a loop to its current group, ignore - continue; - } - - not_inputs.pushBack(out_groups); - not_outputs.pushBack(inp_groups); - } - - terminating_inputs = - IdGroups(all_ids.begin(), all_ids.end()).subtract(not_inputs); - - terminating_outputs = - IdGroups(all_ids.begin(), all_ids.end()).subtract(not_outputs); - } - - IdGroups to_visit_ids = terminating_inputs; - IdGroups visited_ids; - - ExprGroups to_visit_exprs; - ExprGroups visited_exprs; - - auto is_expr_ready = [&](ExprGroup expr_group) { - auto inp_groups = graph().inputGroups(expr_group); - return std::all_of( - inp_groups.begin(), inp_groups.end(), [&](IdGroup id_group) { - return visited_ids.has(id_group) || id_group->empty(); - }); - }; - - auto is_id_ready = [&](IdGroup id_group) { - auto unique_defs = graph().uniqueDefinitions(id_group); - return std::all_of( - unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { - return expr_group->empty() || visited_exprs.has(expr_group) || - graph().isTrivialExprGroup(expr_group); - }); - }; - - while (to_visit_ids.size() > 0 || to_visit_exprs.size() > 0) { - // Process expressions first as all definitions of iter domains have to be - // processed before we can process that iter domain. - - // Detect if nothing has been processed which would put us in an infinite - // loop - bool something_was_processed = false; - ExprGroups still_to_visit_exprs; - - while (to_visit_exprs.size() > 0) { - auto current_expr_group = to_visit_exprs.popFront(); - if (visited_exprs.has(current_expr_group)) { - continue; - } - - if (is_expr_ready(current_expr_group)) { - handle(current_expr_group); - - something_was_processed = true; - visited_exprs.pushBack(current_expr_group); - - auto out_groups = graph().outputGroups(current_expr_group); - for (auto out_group : out_groups) { - to_visit_ids.pushBack(out_group); - } - } else { - still_to_visit_exprs.pushBack(current_expr_group); - } - } - - std::swap(to_visit_exprs, still_to_visit_exprs); - - IdGroups still_to_visit_ids; - while (to_visit_ids.size() > 0) { - auto current_id_group = to_visit_ids.popFront(); - if (visited_ids.has(current_id_group)) { - continue; - } - - if (is_id_ready(current_id_group)) { - handle(current_id_group); - - something_was_processed = true; - visited_ids.pushBack(current_id_group); - - if (!terminating_outputs.has(current_id_group)) { - auto uses_pair = graph().iterDomainGroupUses(current_id_group); - if (uses_pair.second) { - to_visit_exprs.pushBack(uses_pair.first); - } - } - } else { - still_to_visit_ids.pushBack(current_id_group); - } - } - std::swap(to_visit_ids, still_to_visit_ids); - - TORCH_INTERNAL_ASSERT( - something_was_processed || - (to_visit_ids.size() == 0 && to_visit_exprs.size() == 0), - "Infinite loop entered."); - } -} - -IdGraph::IdGraph(const IdGraph& other) { - disjoint_ids_ = other.disjoint_ids_; - disjoint_exprs_ = other.disjoint_exprs_; - view_rfactor_ids_ = other.view_rfactor_ids_; - - for (auto orig_unique_def_pair : other.unique_definitions_) { - auto orig_id_group = orig_unique_def_pair.first; - auto orig_expr_groups = orig_unique_def_pair.second; - - auto new_id_group_pair = disjointIdSet(orig_id_group->front()); - TORCH_INTERNAL_ASSERT(new_id_group_pair.second); - auto new_id_group = new_id_group_pair.first; - - ExprGroups new_expr_groups; - for (auto orig_expr_group : orig_expr_groups) { - auto new_expr_group_pair = disjointExprSet(orig_expr_group->front()); - TORCH_INTERNAL_ASSERT(new_expr_group_pair.second); - new_expr_groups.pushBack(new_expr_group_pair.first); - } - - unique_definitions_[new_id_group] = new_expr_groups; - } - - for (auto orig_unique_use_pair : other.unique_uses_) { - auto orig_id_group = orig_unique_use_pair.first; - auto orig_expr_groups = orig_unique_use_pair.second; - - auto new_id_group_pair = disjointIdSet(orig_id_group->front()); - TORCH_INTERNAL_ASSERT(new_id_group_pair.second); - auto new_id_group = new_id_group_pair.first; - - ExprGroups new_expr_groups; - for (auto orig_expr_group : orig_expr_groups) { - auto new_expr_group_pair = disjointExprSet(orig_expr_group->front()); - TORCH_INTERNAL_ASSERT(new_expr_group_pair.second); - new_expr_groups.pushBack(new_expr_group_pair.first); - } - - unique_uses_[new_id_group] = new_expr_groups; - } -} - -IdGraph& IdGraph::operator=(const IdGraph& other) { - disjoint_ids_.clear(); - disjoint_exprs_.clear(); - unique_definitions_.clear(); - unique_uses_.clear(); - view_rfactor_ids_.clear(); - IdGraph copy(other); - std::swap(*this, copy); - return *this; -} - -const DisjointSets& IdGraph::disjointIdSets() const { - return disjoint_ids_; -} - -DisjointSets& IdGraph::disjointIdSets() { - return disjoint_ids_; -} - -std::pair IdGraph::disjointIdSet(IterDomain* id) const { - auto disjoint_set_it = disjoint_ids_.disjointSetMap().find(id); - if (disjoint_set_it == disjoint_ids_.disjointSetMap().end()) { - return std::make_pair(IdGroup(nullptr), false); - } - return std::make_pair(disjoint_set_it->second, true); -} - -const DisjointSets& IdGraph::disjointExprSets() const { - return disjoint_exprs_; -} - -DisjointSets& IdGraph::disjointExprSets() { - return disjoint_exprs_; -} - -std::pair IdGraph::disjointExprSet(Expr* expr) const { - auto disjoint_set_it = disjoint_exprs_.disjointSetMap().find(expr); - if (disjoint_set_it == disjoint_exprs_.disjointSetMap().end()) { - return std::make_pair(ExprGroup(nullptr), false); - } - return std::make_pair(disjoint_set_it->second, true); -} - -ExprGroup IdGraph::toGroup(Expr* expr) const { - auto disjoint_set_pair = disjointExprSet(expr); - TORCH_INTERNAL_ASSERT( - disjoint_set_pair.second, - "\nExpr group could not be found in graph associated with: ", - expr->toString()); - return disjoint_set_pair.first; -} - -IdGroup IdGraph::toGroup(IterDomain* id) const { - auto disjoint_set_pair = disjointIdSet(id); - TORCH_INTERNAL_ASSERT( - disjoint_set_pair.second, - "\nId group could not be found in graph associated with: ", - id->toString(), - "\n"); - return disjoint_set_pair.first; -} - -ExprGroups IdGraph::toGroups(const VectorOfUniqueEntries& exprs) const { - ExprGroups expr_groups; - for (auto expr : exprs) { - expr_groups.pushBack(toGroup(expr)); - } - return expr_groups; -} - -IdGroups IdGraph::toGroups( - const VectorOfUniqueEntries& ids) const { - IdGroups id_groups; - for (auto id : ids) { - id_groups.pushBack(toGroup(id)); - } - return id_groups; -} - -std::vector IdGraph::outputGroups(ExprGroup expr) const { - std::vector output_groups; - for (auto id_output : - ir_utils::filterByType(expr->front()->outputs())) { - output_groups.push_back(toGroup(id_output)); - } - return output_groups; -} - -std::vector IdGraph::inputGroups(ExprGroup expr) const { - std::vector input_groups; - for (auto id_input : - ir_utils::filterByType(expr->front()->inputs())) { - input_groups.push_back(toGroup(id_input)); - } - return input_groups; -} - -ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { - ExprGroups to_visit; - for (auto of_id_group : of) { - auto group_uses_pair = iterDomainGroupUses(of_id_group); - if (group_uses_pair.second) { - to_visit.pushBack(group_uses_pair.first); - } - } - - ExprGroups visited; - while (to_visit.size() > 0) { - auto current_expr = to_visit.popFront(); - visited.pushBack(current_expr); - auto output_ids = outputGroups(current_expr); - for (auto output_id : output_ids) { - auto group_uses_pair = iterDomainGroupUses(output_id); - if (!group_uses_pair.second) { - continue; - } - for (auto group_use : group_uses_pair.first) { - if (visited.has(group_use)) { - continue; - } - to_visit.pushBack(group_use); - } - } - } - - return visited; -} - -ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { - ExprGroups to_visit; - for (auto of_id_group : of) { - auto group_defs_pair = iterDomainGroupDefinitions(of_id_group); - if (group_defs_pair.second) { - to_visit.pushBack(group_defs_pair.first); - } - } - - ExprGroups visited; - while (to_visit.size() > 0) { - auto current_expr = to_visit.popFront(); - visited.pushBack(current_expr); - auto input_ids = inputGroups(current_expr); - for (auto input_id : input_ids) { - auto group_defs_pair = iterDomainGroupDefinitions(input_id); - if (!group_defs_pair.second) { - continue; - } - for (auto group_def : group_defs_pair.first) { - if (visited.has(group_def)) { - continue; - } - to_visit.pushBack(group_def); - } - } - } - - return visited; -} - -ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) - const { - auto all_uses_of_from = allUsesOf(from); - auto all_definitions_of_to = allDefinitionsOf(to); - - // All of the expressions between from and to. Not all will be used as we - // just want to define each iter domain group once. - auto all_exprs = all_uses_of_from.intersect(all_definitions_of_to); - - // There could be IterDomains in from or to that are between other from and - // to nodes. Make sure to clear those out. - IdGroups terminating_inputs; - IdGroups terminating_outputs; - { - IdGroups not_inputs; - IdGroups not_outputs; - IdGroups all_id_groups; - - for (auto expr_group : all_exprs) { - auto inp_groups = inputGroups(expr_group); - auto out_groups = outputGroups(expr_group); - if (IdGroups(inp_groups).intersect(IdGroups(out_groups)).size() > 0) { - // Expression is just a loop to its current group, ignore - continue; - } - - all_id_groups.pushBack(inp_groups); - - if (!inp_groups.empty()) { - not_outputs.pushBack(inp_groups); - } - - all_id_groups.pushBack(out_groups); - - if (!out_groups.empty()) { - not_inputs.pushBack(out_groups); - } - } - terminating_inputs = all_id_groups.subtract(not_inputs); - terminating_outputs = all_id_groups.subtract(not_outputs); - } - - // Track all expressions to get from outputs to this IterDomain. We - // traverse backwards as that's the direction of indexing expressions. An - // index is assigned to each leaf of a domain and as we traverse backwards - // we're effectively accumulating indexing math. We'll only keep the fewest - // expression lists to get to the iter domain. - std::unordered_map required_ind_exprs_ids; - std::unordered_map required_ind_exprs_exprs; - - // Return if all output IterDomain groups of an expression group have - // already been visited - auto outputsVisited = [&](ExprGroup expr) { - for (auto id_group : outputGroups(expr)) { - if (required_ind_exprs_ids.find(id_group) == - required_ind_exprs_ids.end()) { - return false; - } - } - return true; - }; - - auto allIdUsesVisisted = [&](IdGroup id) { - auto uses_pair = iterDomainGroupUses(id); - if (!uses_pair.second) { - return true; - } - for (auto use_group : uses_pair.first) { - if (all_exprs.has(use_group)) { - if (required_ind_exprs_exprs.find(use_group) == - required_ind_exprs_exprs.end()) { - return false; - } - } - } - return true; - }; - - // Returns all expression groups in required_ind_exprs_ids of outputs - auto requiredExprsOutputs = [&](ExprGroup expr) { - ExprGroups all_output_required_exprs; - for (auto id_group : outputGroups(expr)) { - auto id_group_exprs_it = required_ind_exprs_ids.find(id_group); - TORCH_INTERNAL_ASSERT( - id_group_exprs_it != required_ind_exprs_ids.end(), - "Failure in Iter Domain Graph index resolution, count expected for group: ", - id_group->toString()); - all_output_required_exprs.pushBack(id_group_exprs_it->second); - } - return all_output_required_exprs; - }; - - auto processExpr = [&](ExprGroup expr) { - if (!outputsVisited(expr)) { - return false; - } - // Accumulate expressions from all outputs add this expression and set it - // as current expressions required indexing expressions. - required_ind_exprs_exprs[expr] = requiredExprsOutputs(expr); - return true; - }; - - auto processId = [&](IdGroup id) { - // Track if we've grabed any of the uses required indexing expressions. - bool initialized = false; - // Expression group of all indexing expressions required for this iter - // domain coming back from any of its uses. - ExprGroups min_groups; - - auto uses_pair = iterDomainGroupUses(id); - if (!uses_pair.second) { - // No expressions required for this iter domain, it must be a - // terminating output. - required_ind_exprs_ids[id] = min_groups; - return true; - } - - // Only worry about expressions between inputs and outputs we're - // looking at. - for (auto use_group : uses_pair.first.intersect(all_exprs)) { - auto use_required_ind_exprs_it = required_ind_exprs_exprs.find(use_group); - if (use_required_ind_exprs_it == required_ind_exprs_exprs.end()) { - // If there isn't an entry for the use expression it wasn't - // processed, so don't try to process this iter domain yet. - return false; - } - if (!initialized) { - // If first use found initialize the minimum expression group - min_groups = - use_required_ind_exprs_it->second.computeUnion({use_group}); - initialized = true; - } else if ( - use_required_ind_exprs_it->second.size() + 1 < min_groups.size()) { - // If current use has fewer expressions use that, make sure to add the - // use expression. - min_groups = - use_required_ind_exprs_it->second.computeUnion({use_group}); - } - } - required_ind_exprs_ids[id] = min_groups; - return true; - }; - - IdGroups to_visit_ids = terminating_outputs; - ExprGroups to_visit_exprs; - - while (to_visit_ids.size() > 0 || to_visit_exprs.size() > 0) { - // Process expressions first as all uses of iter domains have to be - // processed before we can process that iter domain. - - // Try to detect when nothing has been processed which would put us in an - // infinite loop - bool something_was_processed = false; - ExprGroups still_to_visit_exprs; - while (to_visit_exprs.size() > 0) { - auto currently_visiting = to_visit_exprs.popFront(); - if (required_ind_exprs_exprs.find(currently_visiting) != - required_ind_exprs_exprs.end()) { - continue; - } - if (processExpr(currently_visiting)) { - something_was_processed = true; - auto inp_groups = inputGroups(currently_visiting); - for (auto inp_group : inp_groups) { - to_visit_ids.pushBack(inp_group); - } - } else { - still_to_visit_exprs.pushBack(currently_visiting); - } - } - - std::swap(to_visit_exprs, still_to_visit_exprs); - - IdGroups still_to_visit_ids; - while (to_visit_ids.size() > 0) { - auto currently_visiting = to_visit_ids.popFront(); - if (required_ind_exprs_ids.find(currently_visiting) != - required_ind_exprs_ids.end()) { - continue; - } - - if (processId(currently_visiting)) { - something_was_processed = true; - auto definitions_pair = iterDomainGroupDefinitions(currently_visiting); - if (definitions_pair.second) { - for (auto def : definitions_pair.first) { - if (!all_exprs.has(def)) { - continue; - } - if (required_ind_exprs_exprs.find(def) == - required_ind_exprs_exprs.end()) { - to_visit_exprs.pushBack(def); - } - } - } - } else { - still_to_visit_ids.pushBack(currently_visiting); - } - } - - TORCH_INTERNAL_ASSERT( - something_was_processed || - (to_visit_ids.size() == 0 && to_visit_exprs.size() == 0), - "Infinite loop entered."); - } - - // We want to traverse the expressions registered in required_ind_exprs_ids, - // let's create a strict "uses path" - std::unordered_map uses_path; - for (auto entry : required_ind_exprs_ids) { - auto id = entry.first; - auto traverse_exprs = entry.second; - auto all_uses = iterDomainGroupUses(id); - if (all_uses.second) { - uses_path[id] = traverse_exprs.intersect(all_uses.first); - } else { - uses_path[id] = {}; - continue; - } - } - - // Topologically sort the uses_path. - ExprGroups sorted_exprs; - ExprGroups to_visit; - - for (auto inp : terminating_inputs) { - auto use_it = uses_path.find(inp); - if (use_it == uses_path.end()) { - // This can happen for a trivial traversal where inputs and outputs are - // exactly the same. - continue; - } - auto uses = use_it->second; - for (auto use : uses) { - to_visit.pushBack(use); - } - } - - IdGroups visited = terminating_inputs; - - while (to_visit.size() > 0) { - bool something_processed = false; - ExprGroups still_to_visit; - while (to_visit.size() > 0) { - auto currently_visiting = to_visit.popFront(); - auto inputs = inputGroups(currently_visiting); - if (std::all_of(inputs.begin(), inputs.end(), [&](IdGroup inp_id) { - return visited.has(inp_id); - })) { - something_processed = true; - sorted_exprs.pushBack(currently_visiting); - auto outputs = outputGroups(currently_visiting); - for (auto out_id : outputs) { - visited.pushBack(out_id); - auto use_pair = iterDomainGroupUses(out_id); - if (!use_pair.second) { - continue; - } - still_to_visit.pushBack(use_pair.first.intersect(all_exprs)); - } - } else { - still_to_visit.pushBack(currently_visiting); - } - } - std::swap(to_visit, still_to_visit); - TORCH_INTERNAL_ASSERT(something_processed, "Infinite loop entered."); - } - - return sorted_exprs; -} - -std::unordered_map> IdGraph:: - buildMapBetween( - const std::vector& from, - const std::vector& to) const { - std::unordered_map from_ids2set; - - for (auto from_id : from) { - auto from_disjoint_set_pair = disjointIdSet(from_id); - if (!from_disjoint_set_pair.second) { - continue; - } - from_ids2set[from_id] = from_disjoint_set_pair.first; - } - - // Map from the sets associated with the IterDomains in to, to those iter - // domains - std::unordered_map> set2to_ids; - - for (auto to_id : to) { - auto to_disjoint_set_pair = disjointIdSet(to_id); - if (!to_disjoint_set_pair.second) { - continue; - } - auto to_set = to_disjoint_set_pair.first; - auto set2to_ids_it = set2to_ids.find(to_set); - - if (set2to_ids_it == set2to_ids.end()) { - set2to_ids[to_set] = {to_id}; - } else { - set2to_ids[to_set].pushBack(to_id); - } - } - - std::unordered_map> - from_ids2to_ids; - for (auto from_id : from) { - from_ids2to_ids[from_id] = VectorOfUniqueEntries(); - - auto from_it = from_ids2set.find(from_id); - TORCH_INTERNAL_ASSERT(from_it != from_ids2set.end()); - - auto from_set = from_it->second; - auto to_entry_it = set2to_ids.find(from_set); - if (to_entry_it == set2to_ids.end()) { - continue; - } - from_ids2to_ids[from_id] = to_entry_it->second; - } - return from_ids2to_ids; -} - -std::unordered_map> IdGraph:: - buildMapBetween( - const VectorOfUniqueEntries& from, - const VectorOfUniqueEntries& to) const { - return buildMapBetween(from.vector(), to.vector()); -} - -std::pair IdGraph::iterDomainGroupDefinitions( - IdGroup id_group) const { - auto null_return = std::make_pair(ExprGroups(), false); - - if (id_group == nullptr) { - return null_return; - } - - auto definitions_it = unique_definitions_.find(id_group); - if (definitions_it == unique_definitions_.end()) { - return null_return; - } - - return std::make_pair(definitions_it->second, true); -} - -std::pair IdGraph::iterDomainGroupUses( - IdGroup id_group) const { - auto null_return = std::make_pair(ExprGroups(), false); - - if (id_group == nullptr) { - return null_return; - } - - auto uses_it = unique_uses_.find(id_group); - if (uses_it == unique_uses_.end()) { - return null_return; - } - - return std::make_pair(uses_it->second, true); -} - -std::string IdGraph::toString() const { - std::stringstream ss; - ss << "IdGraph { \n"; - ss << "Disjoint Ids:\n" - << debug::idGroupsString(*this, 1) << "\n\nDisjoint Expression groups:\n" - << debug::exprGroupsString(*this, 1) << std::endl; - ss << " } IdGraph\n" << std::endl; - return ss.str(); -} - -std::vector> IdGraph::isTrivialExpr(Expr* expr) { - std::vector> mapped_ids; - if (auto merge = dynamic_cast(expr)) { - if (merge->inner()->extent()->isOneInt()) { - mapped_ids.push_back({merge->outer(), merge->out()}); - } - if (merge->outer()->extent()->isOneInt()) { - mapped_ids.push_back({merge->inner(), merge->out()}); - } - } else if (auto split = dynamic_cast(expr)) { - if (split->factor()->isOneInt() && split->startOffset()->isZeroInt() && - split->stopOffset()->isZeroInt()) { - if (split->innerSplit()) { - mapped_ids.push_back({split->in(), split->outer()}); - } else { - mapped_ids.push_back({split->in(), split->inner()}); - } - } - } else if (auto swizzle = dynamic_cast(expr)) { - if (swizzle->swizzleType() == Swizzle2DType::NoSwizzle || - swizzle->swizzleMode() == SwizzleMode::NoSwizzle) { - mapped_ids.push_back({swizzle->inX(), swizzle->outX()}); - mapped_ids.push_back({swizzle->inY(), swizzle->outY()}); - } - } - return mapped_ids; -} - -void IdGraph::initializeId( - IterDomain* id, - const VectorOfUniqueEntries& definitions, - const VectorOfUniqueEntries& uses) { - auto id_disjoint_set = disjointIdSets().initializeSet(id).first->second; - - ExprGroups def_groups; - for (auto def : definitions) { - auto expr_set = disjointExprSets().initializeSet(def).first->second; - def_groups.pushBack(expr_set); - } - unique_definitions_[id_disjoint_set] = def_groups; - - ExprGroups use_groups; - for (auto use : uses) { - auto expr_set = disjointExprSets().initializeSet(use).first->second; - use_groups.pushBack(expr_set); - } - unique_uses_[id_disjoint_set] = use_groups; -} - -bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { - if (!transformAtributesMatch(first, second)) { - return false; - } - - auto first_ids = ir_utils::filterByType( - forward ? first->inputs() : first->outputs()) - .vector(); - - auto second_ids = ir_utils::filterByType( - forward ? second->inputs() : second->outputs()) - .vector(); - - TORCH_INTERNAL_ASSERT( - first_ids.size() == second_ids.size(), - "Expected number of ", - (forward ? "inputs" : "outputs"), - " to match for\n", - first->toString(), - second->toString()); - - { - std::vector> zipped_ids; - - std::transform( - first_ids.begin(), - first_ids.end(), - second_ids.begin(), - std::back_inserter(zipped_ids), - [](IterDomain* first, IterDomain* second) { - return std::make_pair(first, second); - }); - - if (std::any_of( - zipped_ids.begin(), - zipped_ids.end(), - [&](std::pair id_pair) { - return !disjointIdSets().permissiveAreMapped( - id_pair.first, id_pair.second); - })) { - return false; - } - } - - // Special handling for backprop of merge - if (first->isA() && !forward) { - // Can't back prop through merge without making sure one input actually - // matches. This can be done on a map or extent basis. - auto merge0 = first->as(); - auto merge1 = second->as(); - - auto extent_0o = merge0->outer()->extent(); - auto extent_0i = merge0->inner()->extent(); - auto extent_1o = merge1->outer()->extent(); - auto extent_1i = merge1->inner()->extent(); - - auto extent_0_match = extent_0o->sameAs(extent_1o) || - (extent_0o->isConstInt() && extent_1o->isConstInt() && - extent_0o->evaluateInt() == extent_1o->evaluateInt()) || - disjointIdSets().permissiveAreMapped(merge0->outer(), merge1->outer()); - - auto extent_1_match = extent_0i->sameAs(extent_1i) || - (extent_0i->isConstInt() && extent_1i->isConstInt() && - extent_0i->evaluateInt() == extent_1i->evaluateInt()) || - disjointIdSets().permissiveAreMapped(merge0->inner(), merge1->inner()); - - if (!(extent_0_match || extent_1_match)) { - return false; - } - } - - // TODO: For now we're using same as, however we could know what val's are - // exactly the same given the exact map. We might want to pipe that - // information through to here. - if (first->isA()) { - if (!first->as()->leftExpand()->sameAs( - second->as()->leftExpand())) { - return false; - } - - if (!first->as()->rightExpand()->sameAs( - second->as()->rightExpand())) { - return false; - } - } - - return true; -} - -ExprGroups IdGraph::uniqueDefinitions(IdGroup group) const { - auto unique_defs_it = unique_definitions_.find(group); - TORCH_INTERNAL_ASSERT( - unique_defs_it != unique_definitions_.end(), - "Definition not found for IdGroup: ", - group->toString()); - return unique_defs_it->second; -} - -ExprGroups IdGraph::uniqueUses(IdGroup group) const { - auto unique_uses_it = unique_uses_.find(group); - TORCH_INTERNAL_ASSERT( - unique_uses_it != unique_uses_.end(), - "Uses not found for IdGroup: ", - group->toString()); - return unique_uses_it->second; -} - -void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { - if (exprsMap(expr0, expr1, forward)) { - if (propagate_exprs_) { - mapExprs(expr0, expr1); - mapThroughExpr(expr0, expr1, forward); - } else if ( - inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) && - outputGroups(toGroup(expr0)) == outputGroups(toGroup(expr1))) { - mapExprs(expr0, expr1); - } - } -} - -void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { - if (expr0 == expr1) { - return; - } - - if (disjointExprSets().strictAreMapped(expr0, expr1)) { - return; - } - - ExprGroup expr0_orig_group = toGroup(expr0); - ExprGroup expr1_orig_group = toGroup(expr1); - - disjointExprSets().mapEntries(expr0, expr1); - - auto expr_new_group = toGroup(expr0); - - // Update unique uses of producers - IdGroups producers; - for (auto expr : std::vector{expr0, expr1}) { - for (auto input_id : ir_utils::filterByType(expr->inputs())) { - producers.pushBack(toGroup(input_id)); - } - } - - for (auto producer_group : producers) { - uniqueUses().at(producer_group).erase(expr0_orig_group); - uniqueUses().at(producer_group).erase(expr1_orig_group); - uniqueUses().at(producer_group).pushBack(expr_new_group); - } - - // Update unique definitinos of consumers - IdGroups consumers; - for (auto expr : std::vector{expr0, expr1}) { - for (auto output_id : ir_utils::filterByType(expr->outputs())) { - consumers.pushBack(toGroup(output_id)); - } - } - - for (auto consumer_group : consumers) { - uniqueDefinitions().at(consumer_group).erase(expr0_orig_group); - uniqueDefinitions().at(consumer_group).erase(expr1_orig_group); - uniqueDefinitions().at(consumer_group).pushBack(expr_new_group); - } -} - -void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { - if (id0 == id1) { - return; - } - - if (disjointIdSets().strictAreMapped(id0, id1)) { - return; - } - // Definitions and uses are based on the groups of id0 and id1, don't merge - // them into a single group until we grab all definitions and uses for later - // processing. - auto orig_id_group0 = disjointIdSet(id0).first; - auto orig_id_group1 = disjointIdSet(id1).first; - ExprGroups orig_defs0 = uniqueDefinitions(orig_id_group0); - ExprGroups orig_defs1 = uniqueDefinitions(orig_id_group1); - ExprGroups orig_uses0 = uniqueUses(orig_id_group0); - ExprGroups orig_uses1 = uniqueUses(orig_id_group1); - - // Map the iter domains together before we traverse across definitions and - // uses. Traversing definitions and uses could use the new property of id0 and - // id1 being mapped. - disjointIdSets().mapEntries(id0, id1); - auto new_id_group = disjointIdSet(id0).first; - - unique_definitions_.erase(orig_id_group0); - unique_definitions_.erase(orig_id_group1); - unique_uses_.erase(orig_id_group0); - unique_uses_.erase(orig_id_group1); - - unique_definitions_[new_id_group] = orig_defs0.computeUnion(orig_defs1); - unique_uses_[new_id_group] = orig_uses0.computeUnion(orig_uses1); - - // Propagate on uses - if (orig_uses0.size() > 0 || orig_uses1.size() > 0) { - if (orig_uses0.size() > 0 && orig_uses1.size() > 0) { - for (auto use_group_1 : orig_uses1) { - if (orig_uses0.has(use_group_1)) { - continue; - } - - for (auto use_group_0 : orig_uses0) { - auto use0 = use_group_0->front(); - auto use1 = use_group_1->front(); - maybeMapThroughExprs(use0, use1, true); - } - } - } - } - - // Propagate on definitions - if (orig_defs0.size() > 0 || orig_defs1.size() > 0) { - if (orig_defs0.size() > 0 && orig_defs1.size() > 0) { - for (auto def_group_1 : orig_defs1) { - if (orig_defs0.has(def_group_1)) { - continue; - } - - for (auto def_group_0 : orig_defs0) { - auto def0 = def_group_0->front(); - auto def1 = def_group_1->front(); - maybeMapThroughExprs(def0, def1, false); - } - } - } - } -} - -bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { - if (first == nullptr || second == nullptr) { - return false; - } - - if (!exprsMap(first, second, forward)) { - return false; - } - - TORCH_INTERNAL_ASSERT( - propagate_exprs_, - "Asked to propagate expression mappings on a graph that has propagate_exprs_ disabled."); - - auto first_ids = ir_utils::filterByType( - forward ? first->outputs() : first->inputs()) - .vector(); - auto second_ids = ir_utils::filterByType( - forward ? second->outputs() : second->inputs()) - .vector(); - TORCH_INTERNAL_ASSERT( - first_ids.size() == second_ids.size(), - "This should be unreachable, if transformation expressions match, their number of inputs and outputs should as well.\n However found:\n", - first->toString(), - "\nand\n", - second->toString()); - for (auto out_i : c10::irange(first_ids.size())) { - mapIds(first_ids[out_i], second_ids[out_i]); - } - - return true; -} - void IterDomainGraphs::assertNoSelfMapping() { if (hasSelfMapping()) { TORCH_INTERNAL_ASSERT( @@ -1427,102 +31,6 @@ void IterDomainGraphs::assertNoSelfMapping() { } } -void IdGraph::mapThroughTrivialExprs() { - // Grab all expressions - std::vector exprs; - - for (auto expr_group : disjointExprSets().disjointSets()) { - for (auto expr : *expr_group) { - exprs.push_back(expr); - } - } - - for (auto expr : exprs) { - // If not trivial continue - auto mapped_ids = IdGraph::isTrivialExpr(expr); - if (mapped_ids.empty()) { - continue; - } - - // Map through trivial expressions - for (auto mapped_id_group : mapped_ids) { - for (auto id : mapped_id_group) { - mapIds(mapped_id_group.front(), id); - } - } - } -} - -void IdGraph::removeTrivialExprs() { - ExprGroups trivial_expr_groups; - // This seems like it shouls just be a copy if. - for (auto expr_group : disjointExprSets().disjointSets()) { - if (isTrivialExprGroup(expr_group)) { - trivial_expr_groups.pushBack(expr_group); - } - } - - // Clear out expressions that map inputs and outputs to the same group - // from definitions and uses. They shouldn't be important in traversal, and - // will break the terminal input/terminal output logic of traversal. Similar - // to what's drafted in buildIndexGraph - for (auto trivial_expr_group : trivial_expr_groups) { - // Complexity of erase not good as both disjoint set and vector of unique - // entries require a vector find to erase an entry. - eraseExprGroup(trivial_expr_group); - } -} - -void IdGraph::mapThroughLoopSwizzles() { - std::vector all_swizzles; - - for (auto expr_set : disjointExprSets().disjointSets()) { - auto swizzles_in_expr_set = ir_utils::filterByType( - expr_set->vector().begin(), expr_set->vector().end()); - all_swizzles.insert( - all_swizzles.end(), - swizzles_in_expr_set.begin(), - swizzles_in_expr_set.end()); - } - - for (auto swizzle : all_swizzles) { - if (swizzle->swizzleMode() == SwizzleMode::Loop) { - mapIds(swizzle->inX(), swizzle->outX()); - mapIds(swizzle->inY(), swizzle->outY()); - } - } -} - -// Complexity here is not great. We might want a better complexity version when -// erasing multiple expr_groups. -void IdGraph::eraseExprGroup(ExprGroup expr_group) { - // Erase entries that exist in unique_definitions_ and unique_uses_ - for (auto id_group : disjointIdSets().disjointSets()) { - // Make sure the entries exists - TORCH_INTERNAL_ASSERT( - unique_definitions_.find(id_group) != unique_definitions_.end(), - "Broken definitions, couldn't find entry for id group, ", - debug::toString(id_group, 0, true)); - TORCH_INTERNAL_ASSERT( - unique_uses_.find(id_group) != unique_uses_.end(), - "Broken uses, couldn't find entry for id group, ", - debug::toString(id_group, 0, true)); - - unique_definitions_[id_group].erase(expr_group); - unique_uses_[id_group].erase(expr_group); - } - - for (auto expr : *expr_group) { - disjoint_exprs_.erase(expr); - } -} - -bool IdGraph::isTrivialExprGroup(ExprGroup expr_group) const { - return !IdGroups(inputGroups(expr_group)) - .intersect(IdGroups(outputGroups(expr_group))) - .empty(); -} - IterDomainGraphs::IterDomainGraphs( const std::vector& exprs, const std::vector& additional_tvs, @@ -1774,9 +282,9 @@ std::string IterDomainGraphs::toString() const { std::stringstream ss; ss << " IdGraph " << mode << "{ \n"; ss << " Disjoint Ids:\n" - << debug::idGroupsString(idGraph(mode), 2) + << idGroupsString(idGraph(mode), 2) << "\n Disjoint Expression groups:\n" - << debug::exprGroupsString(idGraph(mode), 2) << std::endl; + << exprGroupsString(idGraph(mode), 2) << std::endl; ss << " } IdGraph\n" << std::endl; return ss.str(); } @@ -2777,7 +1285,7 @@ std::unordered_map IterDomainGraphs:: } for (auto iel_use_group : non_promoted_input_uses) { - if (transformAtributesMatch(iel_expr->front(), iel_use_group->front())) { + if (IdGraph::transformAtributesMatch(iel_expr->front(), iel_use_group->front())) { auto use_inps = ir_utils::filterByType(iel_use_group->front()->inputs()) .vector(); @@ -2835,9 +1343,9 @@ std::unordered_map updateMap( "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", "old:", - debug::toString(stale_id_group), + nvfuser::toString(stale_id_group), "new: ", - debug::toString(new_groups)); + nvfuser::toString(new_groups)); new_map[new_groups.front()] = stale_entry.second; } return new_map; @@ -3059,18 +1567,18 @@ std::unordered_map IterDomainGraphs:: std::stringstream err_msg; err_msg << "\n ERROR Loop promotion map build. Could not find promotion for loop group:\n "; - err_msg << debug::toString(loop_group, 0, true); + err_msg << nvfuser::toString(loop_group, 0, true); err_msg << "\nnone of the terminal iter domains of this group:\n "; for (auto entry : exact_promoted_terminal_ids) { auto terminal_id_group = entry.first; auto covered_id_groups = exact_covered_ids.at(terminal_id_group); - err_msg << " " << debug::toString(terminal_id_group, 0, true) - << " -(covers)-> " << debug::toString(covered_id_groups) + err_msg << " " << nvfuser::toString(terminal_id_group, 0, true) + << " -(covers)-> " << nvfuser::toString(covered_id_groups) << std::endl; } err_msg << "iter domains in this group cover all id groups:\n"; for (auto covered_group : loop_group_covered_ids) { - err_msg << " " << debug::toString(covered_group, 0, true); + err_msg << " " << nvfuser::toString(covered_group, 0, true); } // TORCH_INTERNAL_ASSERT(false, err_msg.str()); } else { @@ -3197,7 +1705,7 @@ std::unordered_map IterDomainGraphs:: // Check every use to see if it matches for (auto exact_use_group : promoted_input_uses) { // Check if all the attributes (including type) of the transform match - if (!transformAtributesMatch( + if (!IdGraph::transformAtributesMatch( iel_expr->front(), exact_use_group->front())) { continue; } diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_graphs.h new file mode 100644 index 00000000000..e8b9a1e5f36 --- /dev/null +++ b/csrc/id_model/id_graphs.h @@ -0,0 +1,290 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace nvfuser { + +class IdGraph; + +namespace { +// Convenience to store some intermediate data across a few lowering build +// passes. +struct StatefulLoweringInfo; +} // namespace + +// A collection of IterDomainGraphs that are built from a fusion or series of +// expressions. These graphs are related, but have some distinct features based +// on the IdMappingMode. +// +// EXACT/PERMISSIVE mode: +// +// consumer[i0, b1] = producer[i0] +// consumer->merge(0) (consumer will now be [i0 * b1]) +// +// When producer is replayed as consumer (the direction we use for mapping) +// with forwarding from ForwardingInfo the producer to consumer map will have +// both a mapping of consumer(i0) to producer(i0) as well as consumer(i0*b1) to +// producer(i0). This latter mapping is important for loop nest mappings as the +// consumer will generate a loop based on i0*b1 and the producer may be +// computeAt inside this loop nest. However, for indexing we do not want these +// two iter domains mapped as producer may be indexed as i0*i1 depending on the +// loop nest structure and how it was built. +// +// Exact mode is if the iter domain relationships from producer to consumer are +// considered the exact same size operating on matching dimensions from the root +// domain mapping. +// +// LOOP mode is important to resolve inlined broadcassts. If we have something +// like: consumer[i0o, threadIdx.x{i0i}] = producer[i0o, +// threadIdx.y{i0i}](computeAt = 1) which can easily happen when using shared +// memory. Loop is actually defined for all iteration domains, and resembles +// groups of iter domains that are effectively inlined with eachother. Therefore +// iter domain's that are a common dependency of inlined leaf domains may be +// loop mapped together. This map is developed in lowering from +// bulidInlinePromotions and buildLoopPromotionMap. +// +// Loop promotion is a mechanism by which to capture inlined resolved +// broadcasts. If a consumer resolves a broadcast of a producer, and the +// producer's broadcast is inlined (in total or partially). Then the producer's +// iter domain will be "promoted" to the size of the consumers iter domain. +// +// IdMappingMode::LOOP +// Forward broadcast axes in replay +// Denotes groups of IterDomains that are considered promoted to a common iter +// domain size +// IdMappingMode::PERMISSIVE +// Forward broadcast axes in replay +// Map all iteration domains +// Always contain root mappings (otherwise they could have been forwarded in +// broadcast) +// IdMappingMode::EXACT +// Don't map any broadcast axes to non-broadcast axes +// Do not forward through any broadcast IDs +// IdMappingMode::AlmostExact +// Forward through broadcast axes, but not through to a non-broadcast axis +// i.e. id{b1*i0}, id{i0} are mapped +// id{i1*i0}, id{i0} are not mapped (this part is the difference from +// PERMISSIVE) +// Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped +// +class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { + public: + IterDomainGraphs( + const std::vector& exprs, + const std::vector& additional_tvs, + bool allow_self_mapping = false); + + IterDomainGraphs( + const std::vector& exprs, + bool allow_self_mapping = false); + + // Same as the above constructor with fusion->exprs() excpet fusion may have + // some dangling inputs/outputs that are expected to have IterDomain entries + // even though there's no possible connections from them. + IterDomainGraphs(Fusion* fusion, bool allow_self_mapping = false); + + // Returns iter domain graph of provided mode. + const IdGraph& idGraph(IdMappingMode mode) const; + IdGraph& idGraph(IdMappingMode mode); + + // IterDomains from the original fusion are only allowed to be used once in + // the IterDomain graph, id->uses() are not directly used as there's no bounds + // check that would prevent a use from being defined that's not part of the + // actual fusion definition. + // + // Note, any iter domains used during something like loop or concrete id + // resolution could actually have multiple Expr* uses, and uses on disjoint id + // sets should be used, not this. + // + // TODO: Refactor or remove? + Expr* idUse(IterDomain* id) const; + Expr* idDef(IterDomain* id) const; + + // TODO: Seems a bit unfortunate that this isn't IterDomain local information. + const std::unordered_set& viewRfactorIds() const { + return view_rfactor_ids_; + } + + // Returns if a self mapping was detected that would invalidate assumptions of + // the overall lowering system. + // + // TODO: Can we make this more of an alias analysis? + // Ref: https://github.com/csarofeen/pytorch/pull/1954#discussion_r961940498 + bool hasSelfMapping() const { + return self_mapping_info_.has_value(); + } + + // Update the LOOP ID disjoint sets with resolved computeWith + void updateComputeWith(TensorView* compute_with_tv); + + std::string toString() const; + + // Replay Expr but with the inputs provided. IterDomainGraphss will be updated + // for all maps that have entries, adding the output iter domains of the + // replayed expression and adding potential mappings through the expression. + Expr* addReplayAs(std::vector new_inputs, Expr* expr); + + // Similar to addReplayAs, but clones the expr exactly instead of replaying it + // forward. It's up to the calling code to make sure the replacements are + // valid for the provided expr. It's generally recommended that the + // IterDomains exactly match those in the expr. + // + // "forward" dictates the same argument for mapThroughExpr. If forward the + // function will apply mapThroughExpr forward if inputs map in each + // initialized map. Else does the same but backwards through the expression + // from outputs. + Expr* addExprWithReplacement( + const std::unordered_map& old_2_new_ids, + Expr* old_expr); + + // Make a new expr matching that provided but using the outputs provided. + // IterDomainGraphss will be updated for all maps that have entries. Adding + // the input iter domains of the replayed expression and adding potential + // mappings through the expressions. Input domains will match exactly in all + // properties as those in expr. This is unlike addReplayAs which will produce + // new outputs using transformations directly. + Expr* addBackwardsReplayAs( + const std::vector& new_outputs, + Expr* expr); + + // Make an exact copy of provided IterDomain (without rfactor set), and map + // the copy to the original in all registered IdGraphs. IterDomain copy will + // not have any registered uses or definitions. + IterDomain* cloneIterDomain(IterDomain* id); + + // TODO: Should this not be private? + protected: + // Sometimes fusion inputs or outputs are disconnected from expressions, in + // those cases we still may want to send in some additional tensor views from + // the Fusion that don't have expressions associated with them. + void build( + const std::vector& exprs, + const std::vector& additional_tvs); + + // ======= START Iteration domain build process in order called ======= + + // Fills id_uses_ and id_definitions_ for all IterDomains active in the + // fusion. + void buildIterDomainDefinitionsAndUses( + const std::vector& all_tvs); + + // Iterates over all IterDomains in id_definitions_ and calls initializeID on + // a new IdGraph and returns it. + IdGraph initializeIdGraph(); + + // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs + // and first output of expr + void buildExactMap(const std::vector& exprs); + + // Fills disjoint_ids_[IdMappingMode::ALMOSTEXACT]. Initialize AlmostExact as + // Exact entries, then map anything that's either merged with a size-1 or + // split by a size-1 dimension. + void buildAlmostExactMap(); + + // Fills disjoint_ids_[IdMappingMode::PERMISSIVE]. Initialize PermissiveMap as + // AlmostExact entries, then map through broadcasts + void buildPermissiveMap(const std::vector& exprs); + + // Make sure only leaf nodes of tensor views are parallelized + void validatePTypes(const std::vector& all_tvs) const; + + //! Run through disjoint sets in the LOOP map, make sure there's only one + //! non-serial parallel type in each disjoint set, set the parallel type of + //! all IterDomains in the disjoint set to that PType. + void propagateLoopPTypes() const; + + // !! START Helper functions to build loop promotion and index map!! + + // Terminal loop ids are iteration domains in each loop group that: + // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a + // consumer TV's iter domain maps to this domain in a way that that domain + // is also in the same loop group + // 2) Don't have a direct IterDomain consumer within the group + VectorOfUniqueEntries computeTerminalLoopIds( + const StatefulLoweringInfo info); + + // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and + // graph1. + IdGraph buildIntersection( + const IdGraph& graph0, + const IdGraph& graph1, + bool propagate_exprs = true); + + // !! END Helper functions to build loop promotion and index map!! + + // Start loop map by grouping inlined iter domains + void initializeLoopMap(StatefulLoweringInfo& info); + + // Returns map of IdGroups in the loop map to a representative IterDomain that + // contains all resolved transformations that the terminal IterDomains should + // be promoted to. The returned promotions are valid only for inlined iter + // domains. + std::unordered_map buildInlinePromotions( + StatefulLoweringInfo& info); + + // Returns a similar thing to buildInlinePromotions but also includes iter + // domains that are not inlined. + std::unordered_map buildLoopPromotionMap( + const std::vector& exprs, + StatefulLoweringInfo& info, + std::unordered_map stale_promotion_map); + + // Builds idGraph(IdMappingMode::INDEX) and returns the iter domain promotion + // map to go from leaf domains of each (consumer only?) tensor to their + // corresponding leaf domain in the index graph. + std::unordered_map buildIndexGraph( + const std::vector& exprs, + const std::vector& all_tvs, + StatefulLoweringInfo& info, + std::unordered_map stale_promotion_map); + + // Returns the terminal rfactor or input iter domains each group in the almost + // exact map covers (in the almost exact map). This effectively returns all + // the input almost exact iter domain groups for each almost exact iter domain + // group. RFactor axes are considered an "input" as all broadcast dimensions + // have to be resolved by or before the rfactor iter domain. + std::unordered_map buildCoveredAlmostExact(); + + // ======= END Iteration domain build process in order called ======= + + // Errors if self mapping occurs + void assertNoSelfMapping(); + + // Keeps a disjoint set entry for all IterDomain for all mapping mode types. + // + // Using an array here might be nice, but it seems hard to use an enum as an + // array key + // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum + std::unordered_map id_graphs_; + + // If multiple transformations occur IterDomains could have multiple uses, + // however only one should be active in the given Fusion. When we resolve loop + // promotions during lowering, we can generate new iter domains from existing + // ones, so there can be multiple uses generated. Tracks all the active iter + // domain uses. + std::unordered_map> id_uses_; + + // Make sure we don't blindly use definitions as we don't want to grab + // transformations before a tensor view's root domain. + std::unordered_map> id_definitions_; + + // Debug information to hold if a self mapping in a TensorView is found. + c10::optional> + self_mapping_info_ = c10::nullopt; + + std::unordered_map loop_promotion_map_; + + std::unordered_set view_rfactor_ids_; +}; + +using DoubleBufferIndices = std::unordered_map; + +} // namespace nvfuser diff --git a/csrc/id_model/to_string.cpp b/csrc/id_model/to_string.cpp new file mode 100644 index 00000000000..e5991b4ade5 --- /dev/null +++ b/csrc/id_model/to_string.cpp @@ -0,0 +1,322 @@ + +#include + +namespace nvfuser { + +// Printing utilities to show critical uniqueness information. i.e. being able +// to tell slight differences between groups we're working with. +namespace { +// Sometimes it can be helpful to directly check the pointer addresses of the +// groups. As one group might look exactly like another group but are in +// different disjoint sets. Leaving commented out by default. +template +std::string toString(const T* ptr, bool enable) { + if (!enable) { + return ""; + } + std::stringstream ss; + ss << ptr; + return "[0x." + ss.str().substr(9) + "]"; +} + +std::string indent(int size = 0) { + std::stringstream ss; + for (auto i : c10::irange(size)) { + // Unused variable error + if (i >= 0) { + ss << " "; + } + } + return ss.str(); +} +} // namespace + +std::string toString( + const std::vector& id_group, + int indent_size) { + std::vector names; + for (auto id : id_group) { + names.push_back(id->name()); + } + std::sort(names.begin(), names.end()); + + std::stringstream ss; + ss << indent(indent_size) << "{" << names << "}"; + return ss.str(); +} + +std::string toString(const IdGroup& id_group, int indent_size, bool with_ptr) { + std::stringstream ss; + ss << indent(indent_size) << "idg" << (with_ptr ? "(" : "") + << toString(id_group.get(), with_ptr) << (with_ptr ? ")" : "") + << toString(id_group->vector()); + return ss.str(); +} + +std::string toString( + const std::vector& id_groups, + int indent_size, + bool with_ptr) { + std::stringstream ss; + + // Track position in id_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto id_group : id_groups) { + unsigned int min_id_name = std::numeric_limits::max(); + for (auto id : *id_group) { + if (id->name() < min_id_name) { + min_id_name = id->name(); + } + } + group_name_info.push_back(std::make_pair(min_id_name, pos++)); + } + + ss << indent(indent_size) << "(idgs){\n"; + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + for (auto i : c10::irange(group_name_info.size())) { + auto pos = group_name_info[i].second; + ss << toString(id_groups[pos], indent_size + 1, with_ptr) << "\n"; + } + + ss << "}"; + return ss.str(); +} + +std::string toString( + const IdGroups& id_groups, + int indent_size, + bool with_ptr) { + std::stringstream ss; + + // Track position in id_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto id_group : id_groups) { + unsigned int min_id_name = std::numeric_limits::max(); + for (auto id : *id_group) { + if (id->name() < min_id_name) { + min_id_name = id->name(); + } + } + group_name_info.push_back(std::make_pair(min_id_name, pos++)); + } + + ss << indent(indent_size) << "(idgs){\n"; + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + for (auto i : c10::irange(group_name_info.size())) { + auto pos = group_name_info[i].second; + ss << toString(id_groups.vector()[pos], indent_size + 1, with_ptr) << "\n"; + } + + ss << "}"; + return ss.str(); +} + +std::string toInlineString(const std::vector& id_groups) { + // Track position in id_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto id_group : id_groups) { + unsigned int min_id_name = std::numeric_limits::max(); + for (auto id : *id_group) { + if (id->name() < min_id_name) { + min_id_name = id->name(); + } + } + group_name_info.push_back(std::make_pair(min_id_name, pos++)); + } + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + std::stringstream ss; + + ss << "(idgs){"; + bool first = true; + for (auto i : c10::irange(group_name_info.size())) { + if (first) { + first = false; + } else { + ss << ", "; + } + auto pos = group_name_info[i].second; + ss << toString(id_groups[pos]); + } + + return ss.str(); +} + +std::string toString(const std::vector& expr_group, int indent_size) { + std::vector names; + for (auto expr : expr_group) { + names.push_back(expr->name()); + } + std::sort(names.begin(), names.end()); + + std::stringstream ss; + ss << indent(indent_size) << "{" << names << "}"; + return ss.str(); +} + +std::string toString( + const ExprGroup& expr_group, + int indent_size, + bool with_ptr) { + std::stringstream ss; + ss << indent(indent_size) << "exprg" << (with_ptr ? "(" : "") + << toString(expr_group.get(), with_ptr) << (with_ptr ? ")" : "") + << toString(expr_group->vector()); + return ss.str(); +} + +std::string toString( + const IdGraph& id_graph, + const std::vector& expr_groups, + int indent_size, + bool with_ptr) { + std::stringstream ss; + + // Track position in expr_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto expr_group : expr_groups) { + unsigned int min_expr_name = std::numeric_limits::max(); + for (auto expr : *expr_group) { + if (expr->name() < min_expr_name) { + min_expr_name = expr->name(); + } + } + group_name_info.push_back(std::make_pair(min_expr_name, pos++)); + } + + ss << indent(indent_size) << "(exprgs){\n"; + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + for (auto i : c10::irange(group_name_info.size())) { + auto pos = group_name_info[i].second; + auto expr_group = expr_groups[pos]; + + auto inputs = IdGroups(id_graph.inputGroups(expr_group)); + auto outputs = IdGroups(id_graph.outputGroups(expr_group)); + + ss << indent(indent_size + 1) << toInlineString(inputs.vector()) << " --" + << toString(expr_group, 0, with_ptr) << "--> " + << toInlineString(outputs.vector()) << "\n"; + } + + ss << indent(indent_size) << "}"; + return ss.str(); +} + +std::string toString( + const IdGraph& id_graph, + const ExprGroups& expr_groups, + int indent_size, + bool with_ptr) { + std::stringstream ss; + + // Track position in expr_groups and its min iter domain name in the set + std::vector> group_name_info; + + unsigned int pos = 0; + + for (auto expr_group : expr_groups) { + unsigned int min_id_name = std::numeric_limits::max(); + for (auto id : *expr_group) { + if (id->name() < min_id_name) { + min_id_name = id->name(); + } + } + group_name_info.push_back(std::make_pair(min_id_name, pos++)); + } + + ss << indent(indent_size) << "(exprgs){\n"; + + // Sort based on minimum id in the group + std::sort(group_name_info.begin(), group_name_info.end()); + + for (auto i : c10::irange(group_name_info.size())) { + auto pos = group_name_info[i].second; + auto expr_group = expr_groups.vector()[pos]; + + auto inputs = IdGroups(id_graph.inputGroups(expr_group)); + auto outputs = IdGroups(id_graph.outputGroups(expr_group)); + + ss << indent(indent_size + 1) << toInlineString(inputs.vector()) << " --" + << toString(expr_group, 0, with_ptr) << "--> " + << toInlineString(outputs.vector()) << "\n"; + } + + ss << indent(indent_size) << "}"; + return ss.str(); +} + +std::string idGroupsString( + const IdGraph& id_graph, + int indent_size, + bool with_ptr) { + IdGroups id_groups( + id_graph.disjointIdSets().disjointSets().begin(), + id_graph.disjointIdSets().disjointSets().end()); + return toString(id_groups, indent_size, with_ptr); +} +std::string exprGroupsString( + const IdGraph& id_graph, + int indent_size, + bool with_ptr) { + ExprGroups expr_groups( + id_graph.disjointExprSets().disjointSets().begin(), + id_graph.disjointExprSets().disjointSets().end()); + return toString(id_graph, expr_groups, indent_size, with_ptr); +} + +std::string definitionsString( + const IdGraph& id_graph, + int indent_size, + bool with_ptr) { + ExprGroups defs; + for (auto id_group : id_graph.disjointIdSets().disjointSets()) { + auto definition_pair = id_graph.iterDomainGroupDefinitions(id_group); + if (definition_pair.second) { + for (auto expr_group : definition_pair.first) { + defs.pushBack(expr_group); + } + } + } + return toString(id_graph, defs, indent_size, with_ptr); +} + +std::string usesString( + const IdGraph& id_graph, + int indent_size, + bool with_ptr) { + ExprGroups uses; + for (auto id_group : id_graph.disjointIdSets().disjointSets()) { + auto definition_pair = id_graph.iterDomainGroupUses(id_group); + if (definition_pair.second) { + for (auto expr_group : definition_pair.first) { + uses.pushBack(expr_group); + } + } + } + return toString(id_graph, uses, indent_size, with_ptr); +} + +} \ No newline at end of file diff --git a/csrc/id_model/to_string.h b/csrc/id_model/to_string.h new file mode 100644 index 00000000000..eafe6ffcbd5 --- /dev/null +++ b/csrc/id_model/to_string.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include + +#include +#include + +namespace nvfuser { + +std::string toString( + const std::vector& id_group, + int indent_size = 0); +std::string toString( + const IdGroup& id_group, + int indent_size = 0, + bool with_ptr = false); + +std::string toString( + const std::vector& id_groups, + int indent_size = 0, + bool with_ptr = false); + +std::string toString( + const IdGroups& id_groups, + int indent_size = 0, + bool with_ptr = false); + +std::string toInlineString(const std::vector& id_groups); +std::string toInlineString(const IdGroups& id_groups); + +std::string toString(const std::vector& expr_group, int indent_size = 0); +std::string toString( + const ExprGroup& expr_group, + int indent_size = 0, + bool with_ptr = false); + +std::string toString( + const IdGraph& id_graph, + const std::vector& expr_group, + int indent_size = 0, + bool with_ptr = false); +std::string toString( + const IdGraph& id_graph, + const ExprGroup& expr_groups, + int indent_size = 0, + bool with_ptr = false); + +std::string toString( + const IdGraph& id_graph, + const std::vector& expr_groups, + int indent_size = 0, + bool with_ptr = false); +std::string toString( + const IdGraph& id_graph, + const ExprGroups& expr_groups, + int indent_size = 0, + bool with_ptr = false); + +std::string idGroupsString( + const IdGraph& id_graph, + int indent_size = 0, + bool with_ptr = false); +std::string exprGroupsString( + const IdGraph& id_graph, + int indent_size = 0, + bool with_ptr = false); +std::string definitionsString( + const IdGraph& id_graph, + int indent_size = 0, + bool with_ptr = false); +std::string usesString( + const IdGraph& id_graph, + int indent_size = 0, + bool with_ptr = false); + +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp new file mode 100644 index 00000000000..91c7aeab4ab --- /dev/null +++ b/csrc/id_model/visitor.cpp @@ -0,0 +1,156 @@ +#include + +namespace nvfuser{ + +void IdGraphVisitor::traverse() { + IdGroups all_ids; + ExprGroups all_exprs; + { + if (sub_selection_.empty()) { + all_ids = IdGroups( + graph().disjointIdSets().disjointSets().begin(), + graph().disjointIdSets().disjointSets().end()); + } else { + for (auto id : sub_selection_) { + auto disjoint_pair = graph().disjointIdSet(id); + if (disjoint_pair.second) { + all_ids.pushBack(disjoint_pair.first); + } + } + } + + if (sub_selection_.empty()) { + all_exprs = ExprGroups( + graph().disjointExprSets().disjointSets().begin(), + graph().disjointExprSets().disjointSets().end()); + } else { + for (auto id_group : all_ids) { + for (auto def : graph().uniqueDefinitions(id_group)) { + if (all_exprs.has(def)) { + continue; + } + auto inp_groups = IdGroups(graph().inputGroups(def)); + auto out_groups = IdGroups(graph().outputGroups(def)); + if (inp_groups.subtract(all_ids).empty() && + out_groups.subtract(all_ids).empty()) { + all_exprs.pushBack(def); + } + } + } + } + } + // There could be IterDomains in from or to that are between other from and + // to nodes. Make sure to clear those out. + IdGroups terminating_inputs; + IdGroups terminating_outputs; + + { + IdGroups not_inputs; + IdGroups not_outputs; + for (auto expr_group : all_exprs) { + auto inp_groups = IdGroups(graph().inputGroups(expr_group)); + auto out_groups = IdGroups(graph().outputGroups(expr_group)); + + if (inp_groups.intersect(out_groups).size() > 0) { + // Expression is just a loop to its current group, ignore + continue; + } + + not_inputs.pushBack(out_groups); + not_outputs.pushBack(inp_groups); + } + + terminating_inputs = + IdGroups(all_ids.begin(), all_ids.end()).subtract(not_inputs); + + terminating_outputs = + IdGroups(all_ids.begin(), all_ids.end()).subtract(not_outputs); + } + + IdGroups to_visit_ids = terminating_inputs; + IdGroups visited_ids; + + ExprGroups to_visit_exprs; + ExprGroups visited_exprs; + + auto is_expr_ready = [&](ExprGroup expr_group) { + auto inp_groups = graph().inputGroups(expr_group); + return std::all_of( + inp_groups.begin(), inp_groups.end(), [&](IdGroup id_group) { + return visited_ids.has(id_group) || id_group->empty(); + }); + }; + + auto is_id_ready = [&](IdGroup id_group) { + auto unique_defs = graph().uniqueDefinitions(id_group); + return std::all_of( + unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { + return expr_group->empty() || visited_exprs.has(expr_group) || + graph().isTrivialExprGroup(expr_group); + }); + }; + + while (to_visit_ids.size() > 0 || to_visit_exprs.size() > 0) { + // Process expressions first as all definitions of iter domains have to be + // processed before we can process that iter domain. + + // Detect if nothing has been processed which would put us in an infinite + // loop + bool something_was_processed = false; + ExprGroups still_to_visit_exprs; + + while (to_visit_exprs.size() > 0) { + auto current_expr_group = to_visit_exprs.popFront(); + if (visited_exprs.has(current_expr_group)) { + continue; + } + + if (is_expr_ready(current_expr_group)) { + handle(current_expr_group); + + something_was_processed = true; + visited_exprs.pushBack(current_expr_group); + + auto out_groups = graph().outputGroups(current_expr_group); + for (auto out_group : out_groups) { + to_visit_ids.pushBack(out_group); + } + } else { + still_to_visit_exprs.pushBack(current_expr_group); + } + } + + std::swap(to_visit_exprs, still_to_visit_exprs); + + IdGroups still_to_visit_ids; + while (to_visit_ids.size() > 0) { + auto current_id_group = to_visit_ids.popFront(); + if (visited_ids.has(current_id_group)) { + continue; + } + + if (is_id_ready(current_id_group)) { + handle(current_id_group); + + something_was_processed = true; + visited_ids.pushBack(current_id_group); + + if (!terminating_outputs.has(current_id_group)) { + auto uses_pair = graph().iterDomainGroupUses(current_id_group); + if (uses_pair.second) { + to_visit_exprs.pushBack(uses_pair.first); + } + } + } else { + still_to_visit_ids.pushBack(current_id_group); + } + } + std::swap(to_visit_ids, still_to_visit_ids); + + TORCH_INTERNAL_ASSERT( + something_was_processed || + (to_visit_ids.size() == 0 && to_visit_exprs.size() == 0), + "Infinite loop entered."); + } +} +} \ No newline at end of file diff --git a/csrc/id_model/visitor.h b/csrc/id_model/visitor.h new file mode 100644 index 00000000000..34cc61704be --- /dev/null +++ b/csrc/id_model/visitor.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Iterates through an IterDomain Graph in topological order, calling handle on +// all Id and all Expr groups in a forward topological order. +// +// Warning: Expr groups that have an input and output in the same IdGroup are +// ignored. +// +// Warning: This is not a great iterator if there's a desire to minimize paths +// traveled to simply visit all IdGroups in order. See ExprsBetween to see how +// we might minimize paths. +class TORCH_CUDA_CU_API IdGraphVisitor { + protected: + // If sub_selection is assumed to be a set of iter domains by which form a + // sub-regrion of the IdGraph provided. Only that sub-region will be visited. + IdGraphVisitor( + const IdGraph& id_graph, + const VectorOfUniqueEntries sub_selection = {}) + : id_graph_(id_graph), sub_selection_(sub_selection) {} + + virtual void handle(IdGroup id_group) = 0; + virtual void handle(ExprGroup expr_group) = 0; + + void traverse(); + + const IdGraph& graph() { + return id_graph_; + }; + + IdGraphVisitor() = delete; + + IdGraphVisitor(const IdGraphVisitor& other) = default; + IdGraphVisitor& operator=(const IdGraphVisitor& other) = delete; + + IdGraphVisitor(IdGraphVisitor&& other) = default; + IdGraphVisitor& operator=(IdGraphVisitor&& other) = delete; + + virtual ~IdGraphVisitor() = default; + + private: + const IdGraph& id_graph_; + const VectorOfUniqueEntries sub_selection_; +}; + +// Statement sorting based on IdGraphVisitor, see warnings to IdGraph Visitor. +class IdGraphStmtSort : public IdGraphVisitor { + public: + IdGraphStmtSort( + const IdGraph& id_graph, + const VectorOfUniqueEntries sub_selection = {}) + : IdGraphVisitor(id_graph, sub_selection) { + IdGraphVisitor::traverse(); + } + + ExprGroups exprs() { + return sorted_exprs; + } + + IdGroups ids() { + return sorted_ids; + } + + ~IdGraphStmtSort() override = default; + + protected: + using IdGraphVisitor::handle; + void handle(IdGroup id_group) override { + sorted_ids.pushBack(id_group); + } + + void handle(ExprGroup expr_group) override { + sorted_exprs.pushBack(expr_group); + } + + ExprGroups sorted_exprs; + IdGroups sorted_ids; +}; + +} \ No newline at end of file diff --git a/csrc/lower2device.cpp b/csrc/lower2device.cpp index a01c8147d97..eb85adbd596 100644 --- a/csrc/lower2device.cpp +++ b/csrc/lower2device.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include From 2896a284044685df0fa26a39decd3f01ff82f1d4 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 10 May 2023 09:18:15 -0400 Subject: [PATCH 030/178] License reference in files. --- csrc/disjoint_set.h | 7 +++++++ csrc/id_model/id_graph.cpp | 20 ++++++++++++++------ csrc/id_model/id_graph.h | 9 ++++++++- csrc/id_model/id_graphs.cpp | 12 ++++++++++-- csrc/id_model/id_graphs.h | 7 +++++++ csrc/id_model/to_string.cpp | 10 ++++++++-- csrc/id_model/to_string.h | 9 ++++++++- csrc/id_model/visitor.cpp | 13 ++++++++++--- csrc/id_model/visitor.h | 9 ++++++++- 9 files changed, 80 insertions(+), 16 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 2ca43be19e1..b0693392f4b 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -1,3 +1,10 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #pragma once #include diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index f2c772435d6..dfebccde8a6 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -1,14 +1,22 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #include #include #include namespace nvfuser { -IdGraph::IdGraph(const IdGraph& other) { - disjoint_ids_ = other.disjoint_ids_; - disjoint_exprs_ = other.disjoint_exprs_; - view_rfactor_ids_ = other.view_rfactor_ids_; - +IdGraph::IdGraph(const IdGraph& other) + : disjoint_ids_(other.disjoint_ids_), + disjoint_exprs_(other.disjoint_exprs_), + view_rfactor_ids_(other.view_rfactor_ids_), + unique_definitions_(), + unique_uses_() { for (auto orig_unique_def_pair : other.unique_definitions_) { auto orig_id_group = orig_unique_def_pair.first; auto orig_expr_groups = orig_unique_def_pair.second; @@ -1023,4 +1031,4 @@ bool IdGraph::isTrivialExprGroup(ExprGroup expr_group) const { .empty(); } -} // namespace nvfuser \ No newline at end of file +} // namespace nvfuser diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index b8d7a4e3686..ffd18d14758 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -1,3 +1,10 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #pragma once #include @@ -245,4 +252,4 @@ class TORCH_CUDA_CU_API IdGraph { std::unordered_set view_rfactor_ids_; }; -} \ No newline at end of file +} // namespace nvfuser diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index f6938680883..9ea1cc5423d 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1,6 +1,13 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #include -#include #include +#include #include #include @@ -1285,7 +1292,8 @@ std::unordered_map IterDomainGraphs:: } for (auto iel_use_group : non_promoted_input_uses) { - if (IdGraph::transformAtributesMatch(iel_expr->front(), iel_use_group->front())) { + if (IdGraph::transformAtributesMatch( + iel_expr->front(), iel_use_group->front())) { auto use_inps = ir_utils::filterByType(iel_use_group->front()->inputs()) .vector(); diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_graphs.h index e8b9a1e5f36..d196a8b6fa3 100644 --- a/csrc/id_model/id_graphs.h +++ b/csrc/id_model/id_graphs.h @@ -1,3 +1,10 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #pragma once #include diff --git a/csrc/id_model/to_string.cpp b/csrc/id_model/to_string.cpp index e5991b4ade5..0c828f814a5 100644 --- a/csrc/id_model/to_string.cpp +++ b/csrc/id_model/to_string.cpp @@ -1,4 +1,10 @@ - +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #include namespace nvfuser { @@ -319,4 +325,4 @@ std::string usesString( return toString(id_graph, uses, indent_size, with_ptr); } -} \ No newline at end of file +} // namespace nvfuser diff --git a/csrc/id_model/to_string.h b/csrc/id_model/to_string.h index eafe6ffcbd5..97ad537375a 100644 --- a/csrc/id_model/to_string.h +++ b/csrc/id_model/to_string.h @@ -1,3 +1,10 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #pragma once #include @@ -74,4 +81,4 @@ std::string usesString( int indent_size = 0, bool with_ptr = false); -} // namespace nvfuser \ No newline at end of file +} // namespace nvfuser diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index 91c7aeab4ab..07ff76ee1ec 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -1,7 +1,14 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #include -namespace nvfuser{ - +namespace nvfuser { + void IdGraphVisitor::traverse() { IdGroups all_ids; ExprGroups all_exprs; @@ -153,4 +160,4 @@ void IdGraphVisitor::traverse() { "Infinite loop entered."); } } -} \ No newline at end of file +} // namespace nvfuser diff --git a/csrc/id_model/visitor.h b/csrc/id_model/visitor.h index 34cc61704be..3bd84abb8bd 100644 --- a/csrc/id_model/visitor.h +++ b/csrc/id_model/visitor.h @@ -1,3 +1,10 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on #pragma once #include @@ -82,4 +89,4 @@ class IdGraphStmtSort : public IdGraphVisitor { IdGroups sorted_ids; }; -} \ No newline at end of file +} // namespace nvfuser From ddc858ede8eb6146f0e3339ba9f4b7aa02179582 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 15:49:30 -0400 Subject: [PATCH 031/178] Merge Conflict Fix. --- CMakeLists.txt | 1 - csrc/disjoint_set.h | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b7b8a9ea147..6e9b7e74522 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -86,7 +86,6 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/graph_fuser.cpp ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp ${NVFUSER_SRCS_DIR}/index_compute.cpp - ${NVFUSER_SRCS_DIR}/lower_index_compute.cpp ${NVFUSER_SRCS_DIR}/id_model/id_graph.cpp ${NVFUSER_SRCS_DIR}/id_model/id_graphs.cpp ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 019fb1e209e..6f1305eb864 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -428,7 +428,10 @@ class DisjointSets { } auto set = entry_it->second; - if (set->size() == 1 && set->front() == entry) { + if (set->size() == 1) { + TORCH_INTERNAL_ASSERT( + set->front() == entry, + "Disjoint set container found to be in inconsistent state."); disjoint_set_maps_.erase(entry); disjoint_sets_.erase( std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set)); From 8fd5bceb09f0e5aafae5ccfda8b15bb11e30b445 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 16:34:59 -0400 Subject: [PATCH 032/178] Cleanup. --- csrc/id_model/id_graph.cpp | 70 ++++++++++++++----------------------- csrc/id_model/id_graph.h | 29 ++++++--------- csrc/id_model/id_graphs.cpp | 69 ++++++++++++++---------------------- csrc/id_model/id_graphs.h | 2 +- csrc/id_model/to_string.h | 2 +- csrc/id_model/visitor.cpp | 5 ++- csrc/id_model/visitor.h | 2 +- csrc/transform_iter.cpp | 2 +- 8 files changed, 68 insertions(+), 113 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index dfebccde8a6..91699cf068c 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -7,7 +7,7 @@ // clang-format on #include #include -#include +#include namespace nvfuser { @@ -20,16 +20,11 @@ IdGraph::IdGraph(const IdGraph& other) for (auto orig_unique_def_pair : other.unique_definitions_) { auto orig_id_group = orig_unique_def_pair.first; auto orig_expr_groups = orig_unique_def_pair.second; - - auto new_id_group_pair = disjointIdSet(orig_id_group->front()); - TORCH_INTERNAL_ASSERT(new_id_group_pair.second); - auto new_id_group = new_id_group_pair.first; + auto new_id_group = toGroup(orig_id_group->front()); ExprGroups new_expr_groups; for (auto orig_expr_group : orig_expr_groups) { - auto new_expr_group_pair = disjointExprSet(orig_expr_group->front()); - TORCH_INTERNAL_ASSERT(new_expr_group_pair.second); - new_expr_groups.pushBack(new_expr_group_pair.first); + new_expr_groups.pushBack(toGroup(orig_expr_group->front())); } unique_definitions_[new_id_group] = new_expr_groups; @@ -38,16 +33,11 @@ IdGraph::IdGraph(const IdGraph& other) for (auto orig_unique_use_pair : other.unique_uses_) { auto orig_id_group = orig_unique_use_pair.first; auto orig_expr_groups = orig_unique_use_pair.second; - - auto new_id_group_pair = disjointIdSet(orig_id_group->front()); - TORCH_INTERNAL_ASSERT(new_id_group_pair.second); - auto new_id_group = new_id_group_pair.first; + auto new_id_group = toGroup(orig_id_group->front()); ExprGroups new_expr_groups; for (auto orig_expr_group : orig_expr_groups) { - auto new_expr_group_pair = disjointExprSet(orig_expr_group->front()); - TORCH_INTERNAL_ASSERT(new_expr_group_pair.second); - new_expr_groups.pushBack(new_expr_group_pair.first); + new_expr_groups.pushBack(toGroup(orig_expr_group->front())); } unique_uses_[new_id_group] = new_expr_groups; @@ -73,14 +63,6 @@ DisjointSets& IdGraph::disjointIdSets() { return disjoint_ids_; } -std::pair IdGraph::disjointIdSet(IterDomain* id) const { - auto disjoint_set_it = disjoint_ids_.disjointSetMap().find(id); - if (disjoint_set_it == disjoint_ids_.disjointSetMap().end()) { - return std::make_pair(IdGroup(nullptr), false); - } - return std::make_pair(disjoint_set_it->second, true); -} - const DisjointSets& IdGraph::disjointExprSets() const { return disjoint_exprs_; } @@ -89,31 +71,33 @@ DisjointSets& IdGraph::disjointExprSets() { return disjoint_exprs_; } -std::pair IdGraph::disjointExprSet(Expr* expr) const { - auto disjoint_set_it = disjoint_exprs_.disjointSetMap().find(expr); - if (disjoint_set_it == disjoint_exprs_.disjointSetMap().end()) { - return std::make_pair(ExprGroup(nullptr), false); - } - return std::make_pair(disjoint_set_it->second, true); +// Return if there's a group entry in the graph for this expr +bool IdGraph::hasGroup(Expr* expr) const { + return disjoint_exprs_.mappingExists(expr); +} + +// Return if there's a group entry in the graph for this id +bool IdGraph::hasGroup(IterDomain* id) const { + return disjoint_ids_.mappingExists(id); } ExprGroup IdGraph::toGroup(Expr* expr) const { - auto disjoint_set_pair = disjointExprSet(expr); + auto disjoint_set_it = disjoint_exprs_.disjointSetMap().find(expr); TORCH_INTERNAL_ASSERT( - disjoint_set_pair.second, + disjoint_set_it != disjoint_exprs_.disjointSetMap().end(), "\nExpr group could not be found in graph associated with: ", expr->toString()); - return disjoint_set_pair.first; + return disjoint_set_it->second; } IdGroup IdGraph::toGroup(IterDomain* id) const { - auto disjoint_set_pair = disjointIdSet(id); + auto disjoint_set_it = disjoint_ids_.disjointSetMap().find(id); TORCH_INTERNAL_ASSERT( - disjoint_set_pair.second, + disjoint_set_it != disjoint_ids_.disjointSetMap().end(), "\nId group could not be found in graph associated with: ", id->toString(), "\n"); - return disjoint_set_pair.first; + return disjoint_set_it->second; } ExprGroups IdGraph::toGroups(const VectorOfUniqueEntries& exprs) const { @@ -491,11 +475,10 @@ std::unordered_map> IdGraph:: std::unordered_map from_ids2set; for (auto from_id : from) { - auto from_disjoint_set_pair = disjointIdSet(from_id); - if (!from_disjoint_set_pair.second) { + if (!hasGroup(from_id)) { continue; } - from_ids2set[from_id] = from_disjoint_set_pair.first; + from_ids2set[from_id] = toGroup(from_id); } // Map from the sets associated with the IterDomains in to, to those iter @@ -503,11 +486,10 @@ std::unordered_map> IdGraph:: std::unordered_map> set2to_ids; for (auto to_id : to) { - auto to_disjoint_set_pair = disjointIdSet(to_id); - if (!to_disjoint_set_pair.second) { + if (!hasGroup(to_id)) { continue; } - auto to_set = to_disjoint_set_pair.first; + auto to_set = toGroup(to_id); auto set2to_ids_it = set2to_ids.find(to_set); if (set2to_ids_it == set2to_ids.end()) { @@ -789,8 +771,8 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { // Definitions and uses are based on the groups of id0 and id1, don't merge // them into a single group until we grab all definitions and uses for later // processing. - auto orig_id_group0 = disjointIdSet(id0).first; - auto orig_id_group1 = disjointIdSet(id1).first; + auto orig_id_group0 = toGroup(id0); + auto orig_id_group1 = toGroup(id1); ExprGroups orig_defs0 = uniqueDefinitions(orig_id_group0); ExprGroups orig_defs1 = uniqueDefinitions(orig_id_group1); ExprGroups orig_uses0 = uniqueUses(orig_id_group0); @@ -800,7 +782,7 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { // uses. Traversing definitions and uses could use the new property of id0 and // id1 being mapped. disjointIdSets().mapEntries(id0, id1); - auto new_id_group = disjointIdSet(id0).first; + auto new_id_group = toGroup(id0); unique_definitions_.erase(orig_id_group0); unique_definitions_.erase(orig_id_group1); diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index ffd18d14758..4581350f25a 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -8,7 +8,7 @@ #pragma once #include -#include +#include #include #include @@ -36,25 +36,16 @@ class TORCH_CUDA_CU_API IdGraph { DisjointSets& disjointIdSets(); - // Returns - // { - // (1) The disjoint set of the provided Iter Domain if it exists, - // otherwise a null shared ptr - // (2) If the disjoint set of the provided Iter Domain exists - // } - // - // TODO: Audit usage - std::pair disjointIdSet(IterDomain* id) const; - // Returns the disjoint Expr set. const DisjointSets& disjointExprSets() const; DisjointSets& disjointExprSets(); - // Same as getDisjointIdSet but for the Expression sets. - // - // TODO: Audit usage - std::pair disjointExprSet(Expr* expr) const; + // Return if there's a group entry in the graph for this expr + bool hasGroup(Expr* expr) const; + + // Return if there's a group entry in the graph for this id + bool hasGroup(IterDomain* id) const; // Convert expr to its exprGroup, assert that it exists. ExprGroup toGroup(Expr* expr) const; @@ -240,16 +231,16 @@ class TORCH_CUDA_CU_API IdGraph { // Keeps a disjoint set entry for all Expressions for all mapping mode types. DisjointSets disjoint_exprs_; - std::unordered_map unique_definitions_; - - std::unordered_map unique_uses_; - // Hold a set of IterDomains that are considered view rfactor ids. This // identification is particularly important to understand if split operations // are divisible or not. // // TODO: This should just be in IterDomainGraphs, not here. std::unordered_set view_rfactor_ids_; + + std::unordered_map unique_definitions_; + + std::unordered_map unique_uses_; }; } // namespace nvfuser diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 9ea1cc5423d..3d5cce6353f 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -9,11 +9,11 @@ #include #include +#include +#include +#include #include -#include -#include -#include -#include +#include #include #include @@ -363,7 +363,7 @@ Expr* IterDomainGraphs::addReplayAs( for (auto mode : initialized_modes) { for (auto inp : all_inputs) { TORCH_INTERNAL_ASSERT( - idGraph(mode).disjointIdSet(inp).second, + idGraph(mode).hasGroup(inp), "All inputs for replay need to be initialized in all graphs, ", inp->toString(), " was not found in mode: ", @@ -389,7 +389,7 @@ Expr* IterDomainGraphs::addReplayAs( // Initialize output iter domains in the graphs for (auto mode : initialized_modes) { idGraph(mode).disjointExprSets().initializeSet(replay); - auto replay_group = idGraph(mode).disjointExprSet(replay).first; + auto replay_group = idGraph(mode).toGroup(replay); // Initialize output ids in map for (auto out_id : ir_utils::filterByType(replay->outputs())) { @@ -398,7 +398,7 @@ Expr* IterDomainGraphs::addReplayAs( // Update uses of the inputs in the graphs for (auto inp_id : ir_utils::filterByType(replay->inputs())) { - auto inp_group = idGraph(mode).disjointIdSet(inp_id).first; + auto inp_group = idGraph(mode).toGroup(inp_id); idGraph(mode).uniqueUses().at(inp_group).pushBack(replay_group); } @@ -408,8 +408,7 @@ Expr* IterDomainGraphs::addReplayAs( // Gather all use expressions from inputs VectorOfUniqueEntries representative_uses; for (auto inp : new_inputs) { - auto uses_pair = - graph.iterDomainGroupUses(graph.disjointIdSet(inp).first); + auto uses_pair = graph.iterDomainGroupUses(graph.toGroup(inp)); if (uses_pair.second) { for (auto use_group : uses_pair.first) { representative_uses.pushBack(use_group->front()); @@ -495,7 +494,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( ? ir_utils::filterByType(old_expr->inputs()) : ir_utils::filterByType(old_expr->outputs())) { TORCH_INTERNAL_ASSERT( - idGraph(mode).disjointIdSet(inp_or_out_id).second, + idGraph(mode).hasGroup(inp_or_out_id), "Expected ", inp_or_out_id->toString(), " to be initialized in graph mode: ", @@ -524,7 +523,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( auto& graph = idGraph(mode); graph.disjointExprSets().initializeSet(replay); - auto replay_group = graph.disjointExprSet(replay).first; + auto replay_group = graph.toGroup(replay); // Initialize any non-existant input ids, update existing ones for (auto inp_id : ir_utils::filterByType(replay->inputs())) { @@ -533,7 +532,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( graph.initializeId(inp_id, {}, {replay}); } else { // Update unique uses of existing input ids - auto inp_group = graph.disjointIdSet(inp_id).first; + auto inp_group = graph.toGroup(inp_id); graph.uniqueUses()[inp_group].pushBack(replay_group); } } @@ -546,7 +545,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( } else { // out_id is already initialized, add the replay as a unique definition // of its group - auto out_group = graph.disjointIdSet(out_id).first; + auto out_group = graph.toGroup(out_id); graph.uniqueDefinitions()[out_group].pushBack(replay_group); } } @@ -558,7 +557,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( // Forward VectorOfUniqueEntries representative_uses; for (auto in : ir_utils::filterByType(replay->inputs())) { - auto uses_pair = graph.iterDomainGroupUses(graph.disjointIdSet(in).first); + auto uses_pair = graph.iterDomainGroupUses(graph.toGroup(in)); if (uses_pair.second) { for (auto use_group : uses_pair.first) { if (use_group == replay_group) { @@ -576,8 +575,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( // Backwards VectorOfUniqueEntries representative_defs; for (auto out : ir_utils::filterByType(replay->outputs())) { - auto defs_pair = - graph.iterDomainGroupDefinitions(graph.disjointIdSet(out).first); + auto defs_pair = graph.iterDomainGroupDefinitions(graph.toGroup(out)); if (defs_pair.second) { for (auto def_group : defs_pair.first) { if (def_group == replay_group) { @@ -885,8 +883,7 @@ StatefulLoweringInfo buildInfo( for (auto entry : resolved_bcast_map) { info.p2c_root_broadcast_resolution_map[entry.first].pushBack( entry.second); - for (auto other_exact_bcast : - *exact_graph.disjointIdSet(entry.first).first) { + for (auto other_exact_bcast : *exact_graph.toGroup(entry.first)) { if (all_producer_ca_deps.has(other_exact_bcast)) { info.p2c_root_broadcast_resolution_map[other_exact_bcast] .pushBack(entry.second); @@ -1049,10 +1046,7 @@ VectorOfUniqueEntries IterDomainGraphs::computeTerminalLoopIds( bool all_outs_in_loop_group = uses_it->second.size() == 0 ? false : true; for (auto use : uses_it->second) { for (auto out_id : ir_utils::filterByType(use->outputs())) { - auto out_loop_set_pair = - idGraph(IdMappingMode::LOOP).disjointIdSet(out_id); - TORCH_INTERNAL_ASSERT(out_loop_set_pair.second); - if (group != out_loop_set_pair.first) { + if (group != idGraph(IdMappingMode::LOOP).toGroup(out_id)) { all_outs_in_loop_group = false; } } @@ -1173,10 +1167,7 @@ std::unordered_map IterDomainGraphs:: } // Collect all the exact groups in the loop set containing this iel_group - auto loop_group_pair = - idGraph(IdMappingMode::LOOP).disjointIdSet(iel_group->front()); - TORCH_INTERNAL_ASSERT(loop_group_pair.second); - auto loop_group = loop_group_pair.first; + auto loop_group = idGraph(IdMappingMode::LOOP).toGroup(iel_group->front()); auto loop_covered_exact_groups = idGraph(IdMappingMode::EXACT).toGroups(*loop_group); @@ -1269,10 +1260,9 @@ std::unordered_map IterDomainGraphs:: IdGroups promoted_input_groups; for (auto inp_id : promoted_inputs) { - auto inp_disjoint_set_pair = - intersection_exact_loop_graph.disjointIdSet(inp_id); - if (inp_disjoint_set_pair.second) { - promoted_input_groups.pushBack(inp_disjoint_set_pair.first); + if (intersection_exact_loop_graph.hasGroup(inp_id)) { + promoted_input_groups.pushBack( + intersection_exact_loop_graph.toGroup(inp_id)); } } @@ -1521,26 +1511,19 @@ std::unordered_map IterDomainGraphs:: } // Grab the iel entry - auto iel_set_pair = intersection_exact_loop_graph.disjointIdSet(loop_id); - TORCH_INTERNAL_ASSERT(iel_set_pair.second); - auto iel_group = iel_set_pair.first; + auto iel_group = intersection_exact_loop_graph.toGroup(loop_id); auto iel_promo_it = iel_promotion_map.find(iel_group); if (iel_promo_it == iel_promotion_map.end()) { // If this terminal ID has a promotion, grab the promoted ID. - auto promo_id_exact_it = - idGraph(IdMappingMode::EXACT).disjointIdSet(loop_id); - TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); - exact_promoted_terminal_ids.push_back( - std::make_pair(promo_id_exact_it.first, loop_id)); + exact_promoted_terminal_ids.push_back(std::make_pair( + idGraph(IdMappingMode::EXACT).toGroup(loop_id), loop_id)); } else { // If this terminal ID doesn't have a promotion associated with it, save // the terminal ID. - auto promo_id_exact_it = - idGraph(IdMappingMode::EXACT).disjointIdSet(iel_promo_it->second); - TORCH_INTERNAL_ASSERT(promo_id_exact_it.second); - exact_promoted_terminal_ids.push_back( - std::make_pair(promo_id_exact_it.first, iel_promo_it->second)); + exact_promoted_terminal_ids.push_back(std::make_pair( + idGraph(IdMappingMode::EXACT).toGroup(iel_promo_it->second), + iel_promo_it->second)); } } diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_graphs.h index d196a8b6fa3..b0177e466e1 100644 --- a/csrc/id_model/id_graphs.h +++ b/csrc/id_model/id_graphs.h @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include diff --git a/csrc/id_model/to_string.h b/csrc/id_model/to_string.h index 97ad537375a..f58cf00a2a2 100644 --- a/csrc/id_model/to_string.h +++ b/csrc/id_model/to_string.h @@ -8,7 +8,7 @@ #pragma once #include -#include +#include #include #include diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index 07ff76ee1ec..910ba31de67 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -19,9 +19,8 @@ void IdGraphVisitor::traverse() { graph().disjointIdSets().disjointSets().end()); } else { for (auto id : sub_selection_) { - auto disjoint_pair = graph().disjointIdSet(id); - if (disjoint_pair.second) { - all_ids.pushBack(disjoint_pair.first); + if (graph().hasGroup(id)) { + all_ids.pushBack(graph().toGroup(id)); } } } diff --git a/csrc/id_model/visitor.h b/csrc/id_model/visitor.h index 3bd84abb8bd..4029fdced0e 100644 --- a/csrc/id_model/visitor.h +++ b/csrc/id_model/visitor.h @@ -9,7 +9,7 @@ #include #include -#include +#include namespace nvfuser { diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 8b3d4b53ad7..a942b68515e 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -7,7 +7,7 @@ // clang-format on #include -#include +#include #include #include From 84ed670e9be25b49dec99f961bb4e14eeb81d373 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 16:44:26 -0400 Subject: [PATCH 033/178] Comment out WIP test. --- test/test_gpu_indexing.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index 2f6d5bf493f..7efa05ad319 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -794,6 +794,7 @@ TEST_F(NVFuserTest, FusionIndexing17_CUDA) { &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } +#if 0 // TODO: Finish and enable test TEST_F(NVFuserTest, FusionIndexing18_CUDA) { Fusion fusion; @@ -830,6 +831,7 @@ TEST_F(NVFuserTest, FusionIndexing18_CUDA) { // ComputeAtMap ca_map(&fusion); // std::cout << ca_map.idGraph().loopNodes().toString() << std::endl; } +#endif // TODO: Finish and enable test #if 0 From a466c5c66ffb948ec8e64516b62fcb816c9498a0 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 16:44:30 -0400 Subject: [PATCH 034/178] Minor cleanup. --- csrc/transform_iter.cpp | 78 ++++++++++------------------------------- csrc/transform_iter.h | 10 ++++-- 2 files changed, 26 insertions(+), 62 deletions(-) diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index a942b68515e..1815b4aee6a 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -107,26 +107,21 @@ ReplacementTransformCloner::ReplacementTransformCloner( OptOutConstDispatch::handle(expression_to_match); } +IterDomain* ReplacementTransformCloner::replaceOrClone(IterDomain* id) { + if (provided_expr_val_2_replacement_val_.find(id) != + provided_expr_val_2_replacement_val_.end()) { + return provided_expr_val_2_replacement_val_.at(id); + } + return id->cloneWithoutRFactor(); +} + // We're going to replay this split operation on the corresponding ID void ReplacementTransformCloner::handle(const Split* split) { // Replace or clone - auto split_in = split->in(); - split_in = provided_expr_val_2_replacement_val_.find(split_in) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(split_in) - : split_in->cloneWithoutRFactor(); - - auto split_outer = split->outer(); - split_outer = provided_expr_val_2_replacement_val_.find(split_outer) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(split_outer) - : split_outer->cloneWithoutRFactor(); - auto split_inner = split->inner(); - split_inner = provided_expr_val_2_replacement_val_.find(split_inner) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(split_inner) - : split_inner->cloneWithoutRFactor(); + auto split_in = replaceOrClone(split->in()); + auto split_outer = replaceOrClone(split->outer()); + auto split_inner = replaceOrClone(split->inner()); // TODO: Should we check inner/outer matches the factor if // innerSplit()/!innerSplit()? @@ -144,24 +139,9 @@ void ReplacementTransformCloner::handle(const Split* split) { // We're going to replay this merge operation on the corresponding IDs void ReplacementTransformCloner::handle(const Merge* merge) { // Replace or clone - auto merge_outer = merge->outer(); - merge_outer = provided_expr_val_2_replacement_val_.find(merge_outer) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(merge_outer) - : merge_outer->cloneWithoutRFactor(); - - auto merge_inner = merge->inner(); - merge_inner = provided_expr_val_2_replacement_val_.find(merge_inner) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(merge_inner) - : merge_inner->cloneWithoutRFactor(); - - auto merge_out = merge->out(); - merge_out = provided_expr_val_2_replacement_val_.find(merge_out) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(merge_out) - : merge_out->cloneWithoutRFactor(); - + auto merge_outer = replaceOrClone(merge->outer()); + auto merge_inner = replaceOrClone(merge->inner()); + auto merge_out = replaceOrClone(merge->out()); new_expr_ = IrBuilder::create(merge_out, merge_outer, merge_inner); } @@ -169,32 +149,10 @@ void ReplacementTransformCloner::handle(const Merge* merge) { // if replaying swizzle is enabled. void ReplacementTransformCloner::handle(const Swizzle2D* swizzle_2d) { // Replace or clone - auto swizzle_inx = swizzle_2d->inX(); - swizzle_inx = provided_expr_val_2_replacement_val_.find(swizzle_inx) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(swizzle_inx) - : swizzle_inx->cloneWithoutRFactor(); - - // Replace or clone - auto swizzle_iny = swizzle_2d->inY(); - swizzle_iny = provided_expr_val_2_replacement_val_.find(swizzle_iny) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(swizzle_iny) - : swizzle_iny->cloneWithoutRFactor(); - - // Replace or clone - auto swizzle_outx = swizzle_2d->outX(); - swizzle_outx = provided_expr_val_2_replacement_val_.find(swizzle_outx) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(swizzle_outx) - : swizzle_outx->cloneWithoutRFactor(); - - // Replace or clone - auto swizzle_outy = swizzle_2d->outY(); - swizzle_outy = provided_expr_val_2_replacement_val_.find(swizzle_outy) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(swizzle_outy) - : swizzle_outy->cloneWithoutRFactor(); + auto swizzle_inx = replaceOrClone(swizzle_2d->inX()); + auto swizzle_iny = replaceOrClone(swizzle_2d->inY()); + auto swizzle_outx = replaceOrClone(swizzle_2d->outX()); + auto swizzle_outy = replaceOrClone(swizzle_2d->outY()); new_expr_ = IrBuilder::create( swizzle_outx, diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 37d1cf458ec..55ab90bc042 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -93,6 +93,10 @@ class ReplacementTransformCloner : OptInConstDispatch { using OptInConstDispatch::handle; + // Returns entry in provided_expr_val_2_replacement_val_ if exists otherwise + // returns a clone of the provided iter domain. + IterDomain* replaceOrClone(IterDomain* id); + // We're going to replay this split operation on the corresponding ID void handle(const Split* split) override; @@ -276,10 +280,12 @@ class ForwardingInfo { * * Given an Expr in target_domain, check if its inputs are in replay_map. If so, * check if the mapped domain in replay_map are recorded to be transformed by an - * "equivelent" operation in replay_domain's history. If so, "forward" the + * "equivelent" operation in replay_domain's history. If so, forward the * operation and update replay_map to map the outputs of the expressions across * target_domain and reference_domain. * + * Long Description: + * * replay_map maps root IDs in the history of target_domain to root IDs in the * history replay_domain. PasC and CasP is just a convenient mechanism to have * BestEffortReplay make this base root mapping. @@ -322,7 +328,7 @@ class ForwardingInfo { * we want to make sure those transformations are consistent with T4 (between * T4's root and rfactor domain). Best Effort Replay does not actually add any * transformations to the tensors provided. However, it will provide information - * to determine producers's transformations are consistent consumers + * to determine producers's transformations are consistent with consumers * transformations (or the other way around). Best Effort Replay will return * discovered mappings between tensors that it detects to be matching based on * provided initial information (or just through p2c/c2p root domain mappings). From 3934fe69ecb93acabf97c197809614d9442170c9 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 16:51:27 -0400 Subject: [PATCH 035/178] Cleanup. --- CMakeLists.txt | 1 + csrc/id_model/id_graphs.cpp | 1 + csrc/id_model/replacement_transform.cpp | 96 ++++++++++++++++++++++++ csrc/id_model/replacement_transform.h | 67 +++++++++++++++++ csrc/transform_iter.cpp | 97 +------------------------ csrc/transform_iter.h | 52 +------------ 6 files changed, 170 insertions(+), 144 deletions(-) create mode 100644 csrc/id_model/replacement_transform.cpp create mode 100644 csrc/id_model/replacement_transform.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e9b7e74522..719db3ceb5e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,6 +88,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/index_compute.cpp ${NVFUSER_SRCS_DIR}/id_model/id_graph.cpp ${NVFUSER_SRCS_DIR}/id_model/id_graphs.cpp + ${NVFUSER_SRCS_DIR}/id_model/replacement_transform.cpp ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp ${NVFUSER_SRCS_DIR}/id_model/visitor.cpp ${NVFUSER_SRCS_DIR}/instrumentation.cpp diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 3d5cce6353f..85d8d79d596 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include diff --git a/csrc/id_model/replacement_transform.cpp b/csrc/id_model/replacement_transform.cpp new file mode 100644 index 00000000000..ffbcd07ec94 --- /dev/null +++ b/csrc/id_model/replacement_transform.cpp @@ -0,0 +1,96 @@ +#include + +#include + +namespace nvfuser { +Expr* ReplacementTransformCloner::clone( + const std::unordered_map& + provided_expr_val_2_replacement_val, + const Expr* expression_to_match) { + ReplacementTransformCloner replay( + provided_expr_val_2_replacement_val, expression_to_match); + return replay.new_expr_; +} + +ReplacementTransformCloner::ReplacementTransformCloner( + const std::unordered_map& + provided_expr_val_2_replacement_val, + const Expr* expression_to_match) + : provided_expr_val_2_replacement_val_( + provided_expr_val_2_replacement_val) { + OptOutConstDispatch::handle(expression_to_match); +} + +IterDomain* ReplacementTransformCloner::replaceOrClone(IterDomain* id) { + if (provided_expr_val_2_replacement_val_.find(id) != + provided_expr_val_2_replacement_val_.end()) { + return provided_expr_val_2_replacement_val_.at(id); + } + return id->cloneWithoutRFactor(); +} + +// We're going to replay this split operation on the corresponding ID +void ReplacementTransformCloner::handle(const Split* split) { + // Replace or clone + + auto split_in = replaceOrClone(split->in()); + auto split_outer = replaceOrClone(split->outer()); + auto split_inner = replaceOrClone(split->inner()); + + // TODO: Should we check inner/outer matches the factor if + // innerSplit()/!innerSplit()? + + new_expr_ = IrBuilder::create( + split_outer, + split_inner, + split_in, + split->factor(), + split->innerSplit(), + split->startOffset(), + split->stopOffset()); +} + +// We're going to replay this merge operation on the corresponding IDs +void ReplacementTransformCloner::handle(const Merge* merge) { + // Replace or clone + auto merge_outer = replaceOrClone(merge->outer()); + auto merge_inner = replaceOrClone(merge->inner()); + auto merge_out = replaceOrClone(merge->out()); + new_expr_ = IrBuilder::create(merge_out, merge_outer, merge_inner); +} + +// We're going to replay this swizzle operation on the corresponding IDs +// if replaying swizzle is enabled. +void ReplacementTransformCloner::handle(const Swizzle2D* swizzle_2d) { + // Replace or clone + auto swizzle_inx = replaceOrClone(swizzle_2d->inX()); + auto swizzle_iny = replaceOrClone(swizzle_2d->inY()); + auto swizzle_outx = replaceOrClone(swizzle_2d->outX()); + auto swizzle_outy = replaceOrClone(swizzle_2d->outY()); + + new_expr_ = IrBuilder::create( + swizzle_outx, + swizzle_outy, + swizzle_inx, + swizzle_iny, + swizzle_2d->swizzleType(), + swizzle_2d->swizzleMode()); +} + +void ReplacementTransformCloner::handle(const Resize* resize) { + auto resize_in = resize->in(); + resize_in = provided_expr_val_2_replacement_val_.find(resize_in) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(resize_in) + : resize_in->cloneWithoutRFactor(); + + auto resize_out = resize->out(); + resize_out = provided_expr_val_2_replacement_val_.find(resize_out) != + provided_expr_val_2_replacement_val_.end() + ? provided_expr_val_2_replacement_val_.at(resize_out) + : resize_out->cloneWithoutRFactor(); + + new_expr_ = IrBuilder::create( + resize_out, resize_in, resize->leftExpand(), resize->rightExpand()); +} +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/id_model/replacement_transform.h b/csrc/id_model/replacement_transform.h new file mode 100644 index 00000000000..c40b70874e7 --- /dev/null +++ b/csrc/id_model/replacement_transform.h @@ -0,0 +1,67 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include + +#include +#include + +namespace nvfuser { + +class ReplacementTransformCloner : OptInConstDispatch { + public: + // Generates a copy of expression_to_match with inputs and/or outputs replaced + // by entries provided in the map. Inputs and outputs are expected to be + // "clones". Not literally, but it's up to the envoking code to make the + // input/output replacements are safe to use in the cloned expression. No + // validation is done on provided inputs/outputs. + // + // In other words a split i0{I0}->i1{I0//2}, i2{2} with a map: + // i2{2} -> i3{48} wouldn't throw an error, but would not bevalid. + static Expr* clone( + const std::unordered_map& + provided_expr_val_2_replacement_val, + const Expr* expression_to_match); + + private: + ReplacementTransformCloner() = delete; + + ReplacementTransformCloner( + const std::unordered_map& + expr_to_match_2_replacement, + const Expr* expression_to_match); + + using OptInConstDispatch::handle; + + // Returns entry in provided_expr_val_2_replacement_val_ if exists otherwise + // returns a clone of the provided iter domain. + IterDomain* replaceOrClone(IterDomain* id); + + // We're going to replay this split operation on the corresponding ID + void handle(const Split* split) override; + + // We're going to replay this merge operation on the corresponding IDs + void handle(const Merge* merge) override; + + // We're going to replay this swizzle operation on the corresponding IDs + // if replaying swizzle is enabled. + void handle(const Swizzle2D* swizzle_2d) override; + + // We're going to replay this resize operation on the corresponding IDs + // if replaying resize is enabled. + void handle(const Resize* resize) override; + + Expr* new_expr_ = nullptr; + const std::unordered_map& + provided_expr_val_2_replacement_val_; +}; + +} // namespace nvfuser diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 1815b4aee6a..67859419b67 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -5,12 +5,12 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include - #include +#include +#include +#include #include -#include #include @@ -89,97 +89,6 @@ void ReplayTransformations::handle(Expr* e) { IterVisitor::handle(e); } -Expr* ReplacementTransformCloner::clone( - const std::unordered_map& - provided_expr_val_2_replacement_val, - const Expr* expression_to_match) { - ReplacementTransformCloner replay( - provided_expr_val_2_replacement_val, expression_to_match); - return replay.new_expr_; -} - -ReplacementTransformCloner::ReplacementTransformCloner( - const std::unordered_map& - provided_expr_val_2_replacement_val, - const Expr* expression_to_match) - : provided_expr_val_2_replacement_val_( - provided_expr_val_2_replacement_val) { - OptOutConstDispatch::handle(expression_to_match); -} - -IterDomain* ReplacementTransformCloner::replaceOrClone(IterDomain* id) { - if (provided_expr_val_2_replacement_val_.find(id) != - provided_expr_val_2_replacement_val_.end()) { - return provided_expr_val_2_replacement_val_.at(id); - } - return id->cloneWithoutRFactor(); -} - -// We're going to replay this split operation on the corresponding ID -void ReplacementTransformCloner::handle(const Split* split) { - // Replace or clone - - auto split_in = replaceOrClone(split->in()); - auto split_outer = replaceOrClone(split->outer()); - auto split_inner = replaceOrClone(split->inner()); - - // TODO: Should we check inner/outer matches the factor if - // innerSplit()/!innerSplit()? - - new_expr_ = IrBuilder::create( - split_outer, - split_inner, - split_in, - split->factor(), - split->innerSplit(), - split->startOffset(), - split->stopOffset()); -} - -// We're going to replay this merge operation on the corresponding IDs -void ReplacementTransformCloner::handle(const Merge* merge) { - // Replace or clone - auto merge_outer = replaceOrClone(merge->outer()); - auto merge_inner = replaceOrClone(merge->inner()); - auto merge_out = replaceOrClone(merge->out()); - new_expr_ = IrBuilder::create(merge_out, merge_outer, merge_inner); -} - -// We're going to replay this swizzle operation on the corresponding IDs -// if replaying swizzle is enabled. -void ReplacementTransformCloner::handle(const Swizzle2D* swizzle_2d) { - // Replace or clone - auto swizzle_inx = replaceOrClone(swizzle_2d->inX()); - auto swizzle_iny = replaceOrClone(swizzle_2d->inY()); - auto swizzle_outx = replaceOrClone(swizzle_2d->outX()); - auto swizzle_outy = replaceOrClone(swizzle_2d->outY()); - - new_expr_ = IrBuilder::create( - swizzle_outx, - swizzle_outy, - swizzle_inx, - swizzle_iny, - swizzle_2d->swizzleType(), - swizzle_2d->swizzleMode()); -} - -void ReplacementTransformCloner::handle(const Resize* resize) { - auto resize_in = resize->in(); - resize_in = provided_expr_val_2_replacement_val_.find(resize_in) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(resize_in) - : resize_in->cloneWithoutRFactor(); - - auto resize_out = resize->out(); - resize_out = provided_expr_val_2_replacement_val_.find(resize_out) != - provided_expr_val_2_replacement_val_.end() - ? provided_expr_val_2_replacement_val_.at(resize_out) - : resize_out->cloneWithoutRFactor(); - - new_expr_ = IrBuilder::create( - resize_out, resize_in, resize->leftExpand(), resize->rightExpand()); -} - // We're going to replay this split operation on the corresponding ID void ReplayTransformations::handle(Split* s) { // Grab our input to the split node diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 55ab90bc042..7bce042fbae 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -11,14 +11,14 @@ #include #include -#include #include -#include #include #include namespace nvfuser { +class RootDomainMap; + namespace { // Enable pair in a set, size_t must be unique in set @@ -68,54 +68,6 @@ class ReplayTransform : OptInConstDispatch { const std::vector& input_ids_; }; -class ReplacementTransformCloner : OptInConstDispatch { - public: - // Generates a copy of expression_to_match with inputs and/or outputs replaced - // by entries provided in the map. Inputs and outputs are expected to be - // "clones". Not literally, but it's up to the envoking code to make the - // input/output replacements are safe to use in the cloned expression. No - // validation is done on provided inputs/outputs. - // - // In other words a split i0{I0}->i1{I0//2}, i2{2} with a map: - // i2{2} -> i3{48} wouldn't throw an error, but would not bevalid. - static Expr* clone( - const std::unordered_map& - provided_expr_val_2_replacement_val, - const Expr* expression_to_match); - - private: - ReplacementTransformCloner() = delete; - - ReplacementTransformCloner( - const std::unordered_map& - expr_to_match_2_replacement, - const Expr* expression_to_match); - - using OptInConstDispatch::handle; - - // Returns entry in provided_expr_val_2_replacement_val_ if exists otherwise - // returns a clone of the provided iter domain. - IterDomain* replaceOrClone(IterDomain* id); - - // We're going to replay this split operation on the corresponding ID - void handle(const Split* split) override; - - // We're going to replay this merge operation on the corresponding IDs - void handle(const Merge* merge) override; - - // We're going to replay this swizzle operation on the corresponding IDs - // if replaying swizzle is enabled. - void handle(const Swizzle2D* swizzle_2d) override; - - // We're going to replay this resize operation on the corresponding IDs - // if replaying resize is enabled. - void handle(const Resize* resize) override; - - Expr* new_expr_ = nullptr; - const std::unordered_map& - provided_expr_val_2_replacement_val_; -}; - // Uses the history of _target_domain, and replays that history using the // provided map. // From 0729deebef9ab832bd8636fcdc2cf7fa17819f16 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 16:57:47 -0400 Subject: [PATCH 036/178] Shuffling code. --- CMakeLists.txt | 2 +- csrc/id_model/id_graphs.cpp | 2 +- ...ent_transform.cpp => transform_replay.cpp} | 75 ++++++++++++++++++- ...acement_transform.h => transform_replay.h} | 36 +++++++++ csrc/transform_iter.cpp | 65 ---------------- csrc/transform_iter.h | 36 --------- 6 files changed, 112 insertions(+), 104 deletions(-) rename csrc/id_model/{replacement_transform.cpp => transform_replay.cpp} (57%) rename csrc/id_model/{replacement_transform.h => transform_replay.h} (65%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 719db3ceb5e..3c03ceef45a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,8 +88,8 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/index_compute.cpp ${NVFUSER_SRCS_DIR}/id_model/id_graph.cpp ${NVFUSER_SRCS_DIR}/id_model/id_graphs.cpp - ${NVFUSER_SRCS_DIR}/id_model/replacement_transform.cpp ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp + ${NVFUSER_SRCS_DIR}/id_model/transform_replay.cpp ${NVFUSER_SRCS_DIR}/id_model/visitor.cpp ${NVFUSER_SRCS_DIR}/instrumentation.cpp ${NVFUSER_SRCS_DIR}/ir/base_nodes.cpp diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 85d8d79d596..0335f731ecf 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -6,8 +6,8 @@ */ // clang-format on #include -#include #include +#include #include #include diff --git a/csrc/id_model/replacement_transform.cpp b/csrc/id_model/transform_replay.cpp similarity index 57% rename from csrc/id_model/replacement_transform.cpp rename to csrc/id_model/transform_replay.cpp index ffbcd07ec94..dbf376949b4 100644 --- a/csrc/id_model/replacement_transform.cpp +++ b/csrc/id_model/transform_replay.cpp @@ -1,8 +1,81 @@ -#include +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include #include namespace nvfuser { + +Expr* ReplayTransform::replayAs( + const std::vector& ordered_inputs, + const Expr* expression_to_match) { + ReplayTransform replay(ordered_inputs, expression_to_match); + return replay.replayed_expr_; +} + +ReplayTransform::ReplayTransform( + const std::vector& ordered_inputs, + const Expr* expression_to_match) + : input_ids_(ordered_inputs) { + OptOutConstDispatch::handle(expression_to_match); +} + +// We're going to replay this split operation on the corresponding ID +void ReplayTransform::handle(const Split* split) { + TORCH_INTERNAL_ASSERT( + input_ids_.size() == 1, + "Expected one input to match split: ", + split->toString()); + replayed_expr_ = IterDomain::split( + input_ids_[0], + split->factor(), + split->innerSplit(), + split->startOffset(), + split->stopOffset()) + .first->definition(); +} + +// We're going to replay this merge operation on the corresponding IDs +void ReplayTransform::handle(const Merge* merge) { + TORCH_INTERNAL_ASSERT( + input_ids_.size() == 2, + "Expected two inputs to match merge: ", + merge->toString()); + replayed_expr_ = + IterDomain::merge(input_ids_[0], input_ids_[1])->definition(); +} + +// We're going to replay this swizzle operation on the corresponding IDs +// if replaying swizzle is enabled. +void ReplayTransform::handle(const Swizzle2D* swizzle_2d) { + TORCH_INTERNAL_ASSERT( + input_ids_.size() == 2, + "Expected two inputs to match swizzle: ", + swizzle_2d->toString()); + replayed_expr_ = IterDomain::swizzle( + swizzle_2d->swizzleType(), + input_ids_[0], + input_ids_[1], + swizzle_2d->swizzleMode()) + .first->definition(); +} + +void ReplayTransform::handle(const Resize* resize) { + TORCH_INTERNAL_ASSERT( + input_ids_.size() == 1, + "Expected one input to match resize: ", + resize->toString()); + replayed_expr_ = + IterDomain::resize( + input_ids_[0], resize->leftExpand(), resize->rightExpand()) + ->definition(); +} + Expr* ReplacementTransformCloner::clone( const std::unordered_map& provided_expr_val_2_replacement_val, diff --git a/csrc/id_model/replacement_transform.h b/csrc/id_model/transform_replay.h similarity index 65% rename from csrc/id_model/replacement_transform.h rename to csrc/id_model/transform_replay.h index c40b70874e7..f1b0bcc0afe 100644 --- a/csrc/id_model/replacement_transform.h +++ b/csrc/id_model/transform_replay.h @@ -16,6 +16,42 @@ namespace nvfuser { +class ReplayTransform : OptInConstDispatch { + public: + // Replays expression_to_match with the provided ordered_inputs. Inputs should + // be ordered as they would be used in provided expression. Returns new + // replayed expression. + static Expr* replayAs( + const std::vector& ordered_inputs, + const Expr* expression_to_match); + + private: + ReplayTransform() = delete; + + ReplayTransform( + const std::vector& ordered_inputs, + const Expr* expression_to_match); + + using OptInConstDispatch::handle; + + // We're going to replay this split operation on the corresponding ID + void handle(const Split* split) override; + + // We're going to replay this merge operation on the corresponding IDs + void handle(const Merge* merge) override; + + // We're going to replay this swizzle operation on the corresponding IDs + // if replaying swizzle is enabled. + void handle(const Swizzle2D* swizzle_2d) override; + + // We're going to replay this resize operation on the corresponding IDs + // if replaying resize is enabled. + void handle(const Resize* resize) override; + + Expr* replayed_expr_ = nullptr; + const std::vector& input_ids_; +}; + class ReplacementTransformCloner : OptInConstDispatch { public: // Generates a copy of expression_to_match with inputs and/or outputs replaced diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 67859419b67..47dd26f00c9 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -16,71 +16,6 @@ namespace nvfuser { -Expr* ReplayTransform::replayAs( - const std::vector& ordered_inputs, - const Expr* expression_to_match) { - ReplayTransform replay(ordered_inputs, expression_to_match); - return replay.replayed_expr_; -} - -ReplayTransform::ReplayTransform( - const std::vector& ordered_inputs, - const Expr* expression_to_match) - : input_ids_(ordered_inputs) { - OptOutConstDispatch::handle(expression_to_match); -} - -// We're going to replay this split operation on the corresponding ID -void ReplayTransform::handle(const Split* split) { - TORCH_INTERNAL_ASSERT( - input_ids_.size() == 1, - "Expected one input to match split: ", - split->toString()); - replayed_expr_ = IterDomain::split( - input_ids_[0], - split->factor(), - split->innerSplit(), - split->startOffset(), - split->stopOffset()) - .first->definition(); -} - -// We're going to replay this merge operation on the corresponding IDs -void ReplayTransform::handle(const Merge* merge) { - TORCH_INTERNAL_ASSERT( - input_ids_.size() == 2, - "Expected two inputs to match merge: ", - merge->toString()); - replayed_expr_ = - IterDomain::merge(input_ids_[0], input_ids_[1])->definition(); -} - -// We're going to replay this swizzle operation on the corresponding IDs -// if replaying swizzle is enabled. -void ReplayTransform::handle(const Swizzle2D* swizzle_2d) { - TORCH_INTERNAL_ASSERT( - input_ids_.size() == 2, - "Expected two inputs to match swizzle: ", - swizzle_2d->toString()); - replayed_expr_ = IterDomain::swizzle( - swizzle_2d->swizzleType(), - input_ids_[0], - input_ids_[1], - swizzle_2d->swizzleMode()) - .first->definition(); -} - -void ReplayTransform::handle(const Resize* resize) { - TORCH_INTERNAL_ASSERT( - input_ids_.size() == 1, - "Expected one input to match resize: ", - resize->toString()); - replayed_expr_ = - IterDomain::resize( - input_ids_[0], resize->leftExpand(), resize->rightExpand()) - ->definition(); -} - // Transform dispatch void ReplayTransformations::handle(Expr* e) { auto is_supported_expr = e->isOneOf(); diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 7bce042fbae..f3ae3b63c9d 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -32,42 +32,6 @@ struct id_int_lt { } // namespace -class ReplayTransform : OptInConstDispatch { - public: - // Replays expression_to_match with the provided ordered_inputs. Inputs should - // be ordered as they would be used in provided expression. Returns new - // replayed expression. - static Expr* replayAs( - const std::vector& ordered_inputs, - const Expr* expression_to_match); - - private: - ReplayTransform() = delete; - - ReplayTransform( - const std::vector& ordered_inputs, - const Expr* expression_to_match); - - using OptInConstDispatch::handle; - - // We're going to replay this split operation on the corresponding ID - void handle(const Split* split) override; - - // We're going to replay this merge operation on the corresponding IDs - void handle(const Merge* merge) override; - - // We're going to replay this swizzle operation on the corresponding IDs - // if replaying swizzle is enabled. - void handle(const Swizzle2D* swizzle_2d) override; - - // We're going to replay this resize operation on the corresponding IDs - // if replaying resize is enabled. - void handle(const Resize* resize) override; - - Expr* replayed_expr_ = nullptr; - const std::vector& input_ids_; -}; - // Uses the history of _target_domain, and replays that history using the // provided map. // From 1f10bd3977741762b806f9a2d3789319229b4933 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 17:00:28 -0400 Subject: [PATCH 037/178] Test fix. --- test/test_gpu_indexing.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index 7efa05ad319..079ee1b62c3 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -851,6 +851,8 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { auto tv2 = broadcast(tv1, {false, true}); auto tv3 = makeConcreteTensor({7, 11}); + fusion.addInput(tv3); + auto tv4 = add(tv3, tv2); auto tv5 = broadcast(tv4, {false, false, true}); // tv4[7, 11, 1] @@ -1044,7 +1046,7 @@ TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { // [w] auto tv4 = broadcast(tv3, {false, true}); // [w, 1] - auto tv5 = add(tv4, tv2); + auto tv5 = add(tv4, tv1); // [w, x] fusion.addOutput(tv5); @@ -1053,6 +1055,7 @@ TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { // [w, 1] auto tv7 = add(tv6, tv2); // [y] + fusion.addOutput(tv7); for (auto tv : std::vector{tv4, tv5, tv6, tv7}) { tv->merge(0); From 2a96ccb2ab5737200cdf0f00e92ad1d835473a44 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 17:17:25 -0400 Subject: [PATCH 038/178] Comments. --- csrc/id_model/id_graphs.cpp | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 0335f731ecf..1376a8b8612 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1069,14 +1069,14 @@ IdGraph IterDomainGraphs::buildIntersection( if (!propagate_exprs) { intersection.disableExprPropagation(); } - for (auto exact_group : graph0.disjointIdSets().disjointSets()) { - auto set_size = exact_group->size(); + for (auto group0 : graph0.disjointIdSets().disjointSets()) { + auto set_size = group0->size(); for (auto id0_i : c10::irange(set_size)) { - auto id0 = exact_group->vector()[id0_i]; + auto id0 = group0->vector()[id0_i]; for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { - auto id1 = exact_group->vector()[id1_i]; - // id0 and id1 map in the almost exact map, if they also map in the loop - // graph, then add the mapping to the inersection + auto id1 = group0->vector()[id1_i]; + // id0 and id1 map in group0. If they also map in the group1, + // add the mapping to the inersection. if (graph1.disjointIdSets().strictAreMapped(id0, id1)) { intersection.mapIds(id0, id1); } @@ -1115,15 +1115,16 @@ std::unordered_map IterDomainGraphs:: // need to model broadcast promotion, and if we have two tensors like: // // T1[i0, b1] = T0[i0] - // T2[i0, b1] = T0[i0] - // + // T2[i0, b2] = T0[i0] // Then resolution of: // T4 = T1[i0, b1] + T3[i0, i1] - // T6 = T2[i0, b1] + T5[i0, i2] + // T6 = T2[i0, b2] + T5[i0, i2] + // + // Then merge(0, 1) with all tensors except for T0 // - // The almost exact map will map T1's and T2's b1 together, but they're being - // resolved to i1 and i2 respectively. So we want to have separate entries so - // we can have an easy to process promotion map. + // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 + // are being resolved to i1 and i2 respectively. So we want to have separate + // entries so we can have an easy to process promotion map. // // Loop is a permissive like map, it could have many entries, use the exact // map as the one we iterate on to reduce complexity as it hopefully has @@ -1392,7 +1393,10 @@ std::unordered_map computeCoveredGroups( } for (auto output_group : graph.outputGroups(exact_expr)) { - covered_ids[output_group] = covered; + // Don't overwrite initialized cases due to rfactor markings. + if(covered_ids.find(output_group) == covered_ids.end()){ + covered_ids[output_group] = covered; + } } } From cccb0792074544ece07a9d15b0a3fbbd7c98710e Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 20 May 2023 17:19:09 -0400 Subject: [PATCH 039/178] clang format. --- csrc/id_model/id_graphs.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 1376a8b8612..7d2a352aa00 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1394,7 +1394,7 @@ std::unordered_map computeCoveredGroups( for (auto output_group : graph.outputGroups(exact_expr)) { // Don't overwrite initialized cases due to rfactor markings. - if(covered_ids.find(output_group) == covered_ids.end()){ + if (covered_ids.find(output_group) == covered_ids.end()) { covered_ids[output_group] = covered; } } From 23b2e789d0b4d9d991b4eb5976f088242e074364 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 30 Jul 2023 09:57:35 -0400 Subject: [PATCH 040/178] merge upstream --- csrc/id_model/id_graphs.h | 2 - csrc/id_model/transform_replay.cpp | 4 +- csrc/id_model/transform_replay.h | 8 +- csrc/options.cpp | 1 + csrc/options.h | 1 + csrc/python_frontend/fusion_record.h | 31 +- csrc/python_frontend/python_bindings.cpp | 185 +- .../test/test_nvfuser_fusion_record.cpp | 21 +- csrc/scheduler/mma_utils.cpp | 8 +- csrc/serde/fusion_cache_generated.h | 2934 +++++++++-------- csrc/transform_iter.cpp | 2 +- csrc/type_traits.h | 7 +- csrc/utils.h | 5 +- test/test_gpu_indexing.cpp | 12 +- 14 files changed, 1748 insertions(+), 1473 deletions(-) diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_graphs.h index b0177e466e1..0d144c3ee52 100644 --- a/csrc/id_model/id_graphs.h +++ b/csrc/id_model/id_graphs.h @@ -292,6 +292,4 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { std::unordered_set view_rfactor_ids_; }; -using DoubleBufferIndices = std::unordered_map; - } // namespace nvfuser diff --git a/csrc/id_model/transform_replay.cpp b/csrc/id_model/transform_replay.cpp index dbf376949b4..562b4687987 100644 --- a/csrc/id_model/transform_replay.cpp +++ b/csrc/id_model/transform_replay.cpp @@ -22,7 +22,7 @@ ReplayTransform::ReplayTransform( const std::vector& ordered_inputs, const Expr* expression_to_match) : input_ids_(ordered_inputs) { - OptOutConstDispatch::handle(expression_to_match); + OptOutConstDispatch::dispatch(expression_to_match); } // We're going to replay this split operation on the corresponding ID @@ -91,7 +91,7 @@ ReplacementTransformCloner::ReplacementTransformCloner( const Expr* expression_to_match) : provided_expr_val_2_replacement_val_( provided_expr_val_2_replacement_val) { - OptOutConstDispatch::handle(expression_to_match); + OptOutConstDispatch::dispatch(expression_to_match); } IterDomain* ReplacementTransformCloner::replaceOrClone(IterDomain* id) { diff --git a/csrc/id_model/transform_replay.h b/csrc/id_model/transform_replay.h index f1b0bcc0afe..290f82f33f5 100644 --- a/csrc/id_model/transform_replay.h +++ b/csrc/id_model/transform_replay.h @@ -35,18 +35,18 @@ class ReplayTransform : OptInConstDispatch { using OptInConstDispatch::handle; // We're going to replay this split operation on the corresponding ID - void handle(const Split* split) override; + void handle(const Split* split) final; // We're going to replay this merge operation on the corresponding IDs - void handle(const Merge* merge) override; + void handle(const Merge* merge) final; // We're going to replay this swizzle operation on the corresponding IDs // if replaying swizzle is enabled. - void handle(const Swizzle2D* swizzle_2d) override; + void handle(const Swizzle2D* swizzle_2d) final; // We're going to replay this resize operation on the corresponding IDs // if replaying resize is enabled. - void handle(const Resize* resize) override; + void handle(const Resize* resize) final; Expr* replayed_expr_ = nullptr; const std::vector& input_ids_; diff --git a/csrc/options.cpp b/csrc/options.cpp index fb8f09dd428..77ba1ad5eee 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -127,6 +127,7 @@ std::unordered_map> Options< {"kernel_ir", DebugDumpOption::KernelIr}, {"launch_param", DebugDumpOption::LaunchParam}, {"loop_rotation", DebugDumpOption::LoopRotation}, + {"lower_name_only", DebugDumpOption::LowerNameOnly}, {"lower_verbose", DebugDumpOption::LowerVerbose}, {"matmul_checks", DebugDumpOption::MatmulChecks}, {"occupancy", DebugDumpOption::Occupancy}, diff --git a/csrc/options.h b/csrc/options.h index 931ce4d7f47..00181d71d8a 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -60,6 +60,7 @@ enum class DebugDumpOption { Ptx, //! Dump compiled PTX BankConflictInfo, //! Dump bank confliction info SyncMap, //! RAW dependency info + LowerNameOnly, //! Print all passes' names as they're run in GpuLower::lower LowerVerbose, //! Print all passes' transform in GpuLower::lower ExprSimplification, //! Print all passes' transform in simplifyExpr ExprSort, //! Print merging decisions on expression sorting diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 1fb80b0bc43..1e5b69f0d9f 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -1674,32 +1674,21 @@ struct ReductionOpRecord : RecordFunctor { result = result && (*fusion_op_.template target< - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() == + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() == *child_ptr->fusion_op_.template target< - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>()); + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>()); if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << " Target Ptr [self: 0x" << std::hex - << (size_t)*fusion_op_.template target< + debug() + << " Target Ptr [self: 0x" << std::hex + << (size_t)*fusion_op_.template target< - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() - << "] [other: 0x" << std::hex - << (size_t)*child_ptr->fusion_op_.template target< + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() + << "] [other: 0x" << std::hex + << (size_t)*child_ptr->fusion_op_.template target< - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() - << "]\n"; + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() + << "]\n"; } result = result && (keep_dim_ == child_ptr->keep_dim_); result = result && (dtype_ == child_ptr->dtype_); diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index a598469303d..453092b4e11 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1576,100 +1576,97 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul) #undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP -#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - TORCH_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = 0; \ - std::vector axes(arg.dims); \ - std::iota(axes.begin(), axes.end(), 0); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - axes, \ - false, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - int axis, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - TORCH_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - {axis}, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("axis"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - const std::vector& axes, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - TORCH_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - axes, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("axes"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ +#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + TORCH_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = 0; \ + std::vector axes(arg.dims); \ + std::iota(axes.begin(), axes.end(), 0); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast< \ + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ + op_name), \ + axes, \ + false, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("dtype") = DataType::Null, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + int axis, \ + bool keepdim, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + TORCH_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast< \ + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ + op_name), \ + {axis}, \ + keepdim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axis"), \ + py::arg("keepdim") = false, \ + py::arg("dtype") = DataType::Null, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + const std::vector& axes, \ + bool keepdim, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + TORCH_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast< \ + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ + op_name), \ + axes, \ + keepdim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axes"), \ + py::arg("keepdim") = false, \ + py::arg("dtype") = DataType::Null, \ py::return_value_policy::reference); NVFUSER_PYTHON_BINDING_REDUCTION_OP( diff --git a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp index e0eabf5122d..170531cfaad 100644 --- a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp +++ b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp @@ -98,10 +98,9 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast&, - bool, - DataType)>(sum), + static_cast< + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( + sum), {0}, false, DataType::Float)); @@ -110,10 +109,9 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast&, - bool, - DataType)>(sum), + static_cast< + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( + sum), {0}, false, DataType::Float)); @@ -122,10 +120,9 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast&, - bool, - DataType)>(sum), + static_cast< + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( + sum), {0}, false, DataType::Float)); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 2e745e665b7..bac6f5f6754 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -40,11 +40,11 @@ bool generateSharedMemoryEpilogueHeuristics( properties->warpSize * vector_word; const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k; const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k; - const size_t smem_a = (size_t)(ceilDiv(mk, round_to_factor) * - round_to_factor * smem_double_buffer_stage) * + const size_t smem_a = + (size_t)(ceilDiv(mk, round_to_factor) * round_to_factor * smem_double_buffer_stage) * dataTypeSize(data_types[0]); - const size_t smem_b = (size_t)(ceilDiv(nk, round_to_factor) * - round_to_factor * smem_double_buffer_stage) * + const size_t smem_b = + (size_t)(ceilDiv(nk, round_to_factor) * round_to_factor * smem_double_buffer_stage) * dataTypeSize(data_types[1]); const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); diff --git a/csrc/serde/fusion_cache_generated.h b/csrc/serde/fusion_cache_generated.h index f8a02893b50..060de4420cb 100644 --- a/csrc/serde/fusion_cache_generated.h +++ b/csrc/serde/fusion_cache_generated.h @@ -1,6 +1,5 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_FUSIONCACHE_NVFUSER_SERDE_H_ #define FLATBUFFERS_GENERATED_FUSIONCACHE_NVFUSER_SERDE_H_ @@ -8,10 +7,10 @@ // Ensure the included flatbuffers.h is the same version as when this file was // generated, otherwise it may not be compatible. -static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && - FLATBUFFERS_VERSION_MINOR == 3 && - FLATBUFFERS_VERSION_REVISION == 3, - "Non-compatible flatbuffers version included"); +static_assert( + FLATBUFFERS_VERSION_MAJOR == 23 && FLATBUFFERS_VERSION_MINOR == 3 && + FLATBUFFERS_VERSION_REVISION == 3, + "Non-compatible flatbuffers version included"); namespace nvfuser { namespace serde { @@ -148,39 +147,38 @@ enum DataType : int32_t { inline const DataType (&EnumValuesDataType())[10] { static const DataType values[] = { - DataType_Double, - DataType_Float, - DataType_Half, - DataType_Int, - DataType_Int32, - DataType_Bool, - DataType_BFloat16, - DataType_ComplexFloat, - DataType_ComplexDouble, - DataType_None - }; + DataType_Double, + DataType_Float, + DataType_Half, + DataType_Int, + DataType_Int32, + DataType_Bool, + DataType_BFloat16, + DataType_ComplexFloat, + DataType_ComplexDouble, + DataType_None}; return values; } -inline const char * const *EnumNamesDataType() { - static const char * const names[11] = { - "Double", - "Float", - "Half", - "Int", - "Int32", - "Bool", - "BFloat16", - "ComplexFloat", - "ComplexDouble", - "None", - nullptr - }; +inline const char* const* EnumNamesDataType() { + static const char* const names[11] = { + "Double", + "Float", + "Half", + "Int", + "Int32", + "Bool", + "BFloat16", + "ComplexFloat", + "ComplexDouble", + "None", + nullptr}; return names; } -inline const char *EnumNameDataType(DataType e) { - if (::flatbuffers::IsOutRange(e, DataType_Double, DataType_None)) return ""; +inline const char* EnumNameDataType(DataType e) { + if (::flatbuffers::IsOutRange(e, DataType_Double, DataType_None)) + return ""; const size_t index = static_cast(e); return EnumNamesDataType()[index]; } @@ -196,27 +194,19 @@ enum StateType : int32_t { inline const StateType (&EnumValuesStateType())[4] { static const StateType values[] = { - StateType_Tensor, - StateType_Scalar, - StateType_Vector, - StateType_None - }; + StateType_Tensor, StateType_Scalar, StateType_Vector, StateType_None}; return values; } -inline const char * const *EnumNamesStateType() { - static const char * const names[5] = { - "Tensor", - "Scalar", - "Vector", - "None", - nullptr - }; +inline const char* const* EnumNamesStateType() { + static const char* const names[5] = { + "Tensor", "Scalar", "Vector", "None", nullptr}; return names; } -inline const char *EnumNameStateType(StateType e) { - if (::flatbuffers::IsOutRange(e, StateType_Tensor, StateType_None)) return ""; +inline const char* EnumNameStateType(StateType e) { + if (::flatbuffers::IsOutRange(e, StateType_Tensor, StateType_None)) + return ""; const size_t index = static_cast(e); return EnumNamesStateType()[index]; } @@ -231,25 +221,19 @@ enum Contiguity : int32_t { inline const Contiguity (&EnumValuesContiguity())[3] { static const Contiguity values[] = { - Contiguity_Strided, - Contiguity_Contiguous, - Contiguity_None - }; + Contiguity_Strided, Contiguity_Contiguous, Contiguity_None}; return values; } -inline const char * const *EnumNamesContiguity() { - static const char * const names[4] = { - "Strided", - "Contiguous", - "None", - nullptr - }; +inline const char* const* EnumNamesContiguity() { + static const char* const names[4] = { + "Strided", "Contiguous", "None", nullptr}; return names; } -inline const char *EnumNameContiguity(Contiguity e) { - if (::flatbuffers::IsOutRange(e, Contiguity_Strided, Contiguity_None)) return ""; +inline const char* EnumNameContiguity(Contiguity e) { + if (::flatbuffers::IsOutRange(e, Contiguity_Strided, Contiguity_None)) + return ""; const size_t index = static_cast(e); return EnumNamesContiguity()[index]; } @@ -316,129 +300,128 @@ enum RecordType : int32_t { inline const RecordType (&EnumValuesRecordType())[55] { static const RecordType values[] = { - RecordType_Base, - RecordType_BatchNormOp, - RecordType_BroadcastOp, - RecordType_BroadcastInDim, - RecordType_BroadcastInDimSymbolic, - RecordType_CastTv, - RecordType_CastVal, - RecordType_CatOp, - RecordType_End, - RecordType_FullOp, - RecordType_IotaOp, - RecordType_IndexSelectOp, - RecordType_TorchGatherOp, - RecordType_TakeAlongAxisOp, - RecordType_Unary_TV, - RecordType_Unary_VAL, - RecordType_Binary_TV, - RecordType_Binary_VAL, - RecordType_Binary_TV_VAL, - RecordType_Binary_VAL_TV, - RecordType_Ternary_TV, - RecordType_Ternary_VAL, - RecordType_Ternary_TV_TV_VAL, - RecordType_Ternary_TV_VAL_TV, - RecordType_Ternary_VAL_TV_TV, - RecordType_Ternary_VAL_VAL_TV, - RecordType_Ternary_TV_VAL_VAL, - RecordType_Ternary_VAL_TV_VAL, - RecordType_Ternary_Alpha_TV, - RecordType_Ternary_Alpha_VAL, - RecordType_Ternary_Alpha_TV_TV_VAL, - RecordType_Ternary_Alpha_TV_VAL_TV, - RecordType_Ternary_Alpha_VAL_TV_TV, - RecordType_Ternary_Alpha_VAL_VAL_TV, - RecordType_Ternary_Alpha_TV_VAL_VAL, - RecordType_Ternary_Alpha_VAL_TV_VAL, - RecordType_OutputTv, - RecordType_OutputVal, - RecordType_PadOp, - RecordType_PermuteOp, - RecordType_RandomOp, - RecordType_ReductionMax, - RecordType_ReductionMin, - RecordType_ReductionProd, - RecordType_ReductionSum, - RecordType_ReshapeOp, - RecordType_Scalar, - RecordType_SliceOp, - RecordType_SqueezeOp, - RecordType_Start, - RecordType_Tensor, - RecordType_TensorSizes, - RecordType_VarianceOp, - RecordType_VarianceMeanOp, - RecordType_Vector - }; + RecordType_Base, + RecordType_BatchNormOp, + RecordType_BroadcastOp, + RecordType_BroadcastInDim, + RecordType_BroadcastInDimSymbolic, + RecordType_CastTv, + RecordType_CastVal, + RecordType_CatOp, + RecordType_End, + RecordType_FullOp, + RecordType_IotaOp, + RecordType_IndexSelectOp, + RecordType_TorchGatherOp, + RecordType_TakeAlongAxisOp, + RecordType_Unary_TV, + RecordType_Unary_VAL, + RecordType_Binary_TV, + RecordType_Binary_VAL, + RecordType_Binary_TV_VAL, + RecordType_Binary_VAL_TV, + RecordType_Ternary_TV, + RecordType_Ternary_VAL, + RecordType_Ternary_TV_TV_VAL, + RecordType_Ternary_TV_VAL_TV, + RecordType_Ternary_VAL_TV_TV, + RecordType_Ternary_VAL_VAL_TV, + RecordType_Ternary_TV_VAL_VAL, + RecordType_Ternary_VAL_TV_VAL, + RecordType_Ternary_Alpha_TV, + RecordType_Ternary_Alpha_VAL, + RecordType_Ternary_Alpha_TV_TV_VAL, + RecordType_Ternary_Alpha_TV_VAL_TV, + RecordType_Ternary_Alpha_VAL_TV_TV, + RecordType_Ternary_Alpha_VAL_VAL_TV, + RecordType_Ternary_Alpha_TV_VAL_VAL, + RecordType_Ternary_Alpha_VAL_TV_VAL, + RecordType_OutputTv, + RecordType_OutputVal, + RecordType_PadOp, + RecordType_PermuteOp, + RecordType_RandomOp, + RecordType_ReductionMax, + RecordType_ReductionMin, + RecordType_ReductionProd, + RecordType_ReductionSum, + RecordType_ReshapeOp, + RecordType_Scalar, + RecordType_SliceOp, + RecordType_SqueezeOp, + RecordType_Start, + RecordType_Tensor, + RecordType_TensorSizes, + RecordType_VarianceOp, + RecordType_VarianceMeanOp, + RecordType_Vector}; return values; } -inline const char * const *EnumNamesRecordType() { - static const char * const names[56] = { - "Base", - "BatchNormOp", - "BroadcastOp", - "BroadcastInDim", - "BroadcastInDimSymbolic", - "CastTv", - "CastVal", - "CatOp", - "End", - "FullOp", - "IotaOp", - "IndexSelectOp", - "TorchGatherOp", - "TakeAlongAxisOp", - "Unary_TV", - "Unary_VAL", - "Binary_TV", - "Binary_VAL", - "Binary_TV_VAL", - "Binary_VAL_TV", - "Ternary_TV", - "Ternary_VAL", - "Ternary_TV_TV_VAL", - "Ternary_TV_VAL_TV", - "Ternary_VAL_TV_TV", - "Ternary_VAL_VAL_TV", - "Ternary_TV_VAL_VAL", - "Ternary_VAL_TV_VAL", - "Ternary_Alpha_TV", - "Ternary_Alpha_VAL", - "Ternary_Alpha_TV_TV_VAL", - "Ternary_Alpha_TV_VAL_TV", - "Ternary_Alpha_VAL_TV_TV", - "Ternary_Alpha_VAL_VAL_TV", - "Ternary_Alpha_TV_VAL_VAL", - "Ternary_Alpha_VAL_TV_VAL", - "OutputTv", - "OutputVal", - "PadOp", - "PermuteOp", - "RandomOp", - "ReductionMax", - "ReductionMin", - "ReductionProd", - "ReductionSum", - "ReshapeOp", - "Scalar", - "SliceOp", - "SqueezeOp", - "Start", - "Tensor", - "TensorSizes", - "VarianceOp", - "VarianceMeanOp", - "Vector", - nullptr - }; +inline const char* const* EnumNamesRecordType() { + static const char* const names[56] = { + "Base", + "BatchNormOp", + "BroadcastOp", + "BroadcastInDim", + "BroadcastInDimSymbolic", + "CastTv", + "CastVal", + "CatOp", + "End", + "FullOp", + "IotaOp", + "IndexSelectOp", + "TorchGatherOp", + "TakeAlongAxisOp", + "Unary_TV", + "Unary_VAL", + "Binary_TV", + "Binary_VAL", + "Binary_TV_VAL", + "Binary_VAL_TV", + "Ternary_TV", + "Ternary_VAL", + "Ternary_TV_TV_VAL", + "Ternary_TV_VAL_TV", + "Ternary_VAL_TV_TV", + "Ternary_VAL_VAL_TV", + "Ternary_TV_VAL_VAL", + "Ternary_VAL_TV_VAL", + "Ternary_Alpha_TV", + "Ternary_Alpha_VAL", + "Ternary_Alpha_TV_TV_VAL", + "Ternary_Alpha_TV_VAL_TV", + "Ternary_Alpha_VAL_TV_TV", + "Ternary_Alpha_VAL_VAL_TV", + "Ternary_Alpha_TV_VAL_VAL", + "Ternary_Alpha_VAL_TV_VAL", + "OutputTv", + "OutputVal", + "PadOp", + "PermuteOp", + "RandomOp", + "ReductionMax", + "ReductionMin", + "ReductionProd", + "ReductionSum", + "ReshapeOp", + "Scalar", + "SliceOp", + "SqueezeOp", + "Start", + "Tensor", + "TensorSizes", + "VarianceOp", + "VarianceMeanOp", + "Vector", + nullptr}; return names; } -inline const char *EnumNameRecordType(RecordType e) { - if (::flatbuffers::IsOutRange(e, RecordType_Base, RecordType_Vector)) return ""; +inline const char* EnumNameRecordType(RecordType e) { + if (::flatbuffers::IsOutRange(e, RecordType_Base, RecordType_Vector)) + return ""; const size_t index = static_cast(e); return EnumNamesRecordType()[index]; } @@ -470,145 +453,170 @@ enum RecordData : uint8_t { inline const RecordData (&EnumValuesRecordData())[20] { static const RecordData values[] = { - RecordData_NONE, - RecordData_BatchNorm, - RecordData_Broadcast, - RecordData_BroadcastInDim, - RecordData_BroadcastInDimSymbolic, - RecordData_Dimension, - RecordData_Dtype, - RecordData_Norm, - RecordData_Output, - RecordData_Pad, - RecordData_Permute, - RecordData_Slice, - RecordData_Squeeze, - RecordData_Reduction, - RecordData_Reshape, - RecordData_Scalar, - RecordData_Tensor, - RecordData_TensorCreation, - RecordData_TensorCreationSymbolic, - RecordData_Vector - }; + RecordData_NONE, + RecordData_BatchNorm, + RecordData_Broadcast, + RecordData_BroadcastInDim, + RecordData_BroadcastInDimSymbolic, + RecordData_Dimension, + RecordData_Dtype, + RecordData_Norm, + RecordData_Output, + RecordData_Pad, + RecordData_Permute, + RecordData_Slice, + RecordData_Squeeze, + RecordData_Reduction, + RecordData_Reshape, + RecordData_Scalar, + RecordData_Tensor, + RecordData_TensorCreation, + RecordData_TensorCreationSymbolic, + RecordData_Vector}; return values; } -inline const char * const *EnumNamesRecordData() { - static const char * const names[21] = { - "NONE", - "BatchNorm", - "Broadcast", - "BroadcastInDim", - "BroadcastInDimSymbolic", - "Dimension", - "Dtype", - "Norm", - "Output", - "Pad", - "Permute", - "Slice", - "Squeeze", - "Reduction", - "Reshape", - "Scalar", - "Tensor", - "TensorCreation", - "TensorCreationSymbolic", - "Vector", - nullptr - }; +inline const char* const* EnumNamesRecordData() { + static const char* const names[21] = { + "NONE", + "BatchNorm", + "Broadcast", + "BroadcastInDim", + "BroadcastInDimSymbolic", + "Dimension", + "Dtype", + "Norm", + "Output", + "Pad", + "Permute", + "Slice", + "Squeeze", + "Reduction", + "Reshape", + "Scalar", + "Tensor", + "TensorCreation", + "TensorCreationSymbolic", + "Vector", + nullptr}; return names; } -inline const char *EnumNameRecordData(RecordData e) { - if (::flatbuffers::IsOutRange(e, RecordData_NONE, RecordData_Vector)) return ""; +inline const char* EnumNameRecordData(RecordData e) { + if (::flatbuffers::IsOutRange(e, RecordData_NONE, RecordData_Vector)) + return ""; const size_t index = static_cast(e); return EnumNamesRecordData()[index]; } -template struct RecordDataTraits { +template +struct RecordDataTraits { static const RecordData enum_value = RecordData_NONE; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_BatchNorm; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Broadcast; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_BroadcastInDim; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_BroadcastInDimSymbolic; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Dimension; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Dtype; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Norm; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Output; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Pad; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Permute; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Slice; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Squeeze; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Reduction; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Reshape; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Scalar; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Tensor; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_TensorCreation; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_TensorCreationSymbolic; }; -template<> struct RecordDataTraits { +template <> +struct RecordDataTraits { static const RecordData enum_value = RecordData_Vector; }; -bool VerifyRecordData(::flatbuffers::Verifier &verifier, const void *obj, RecordData type); -bool VerifyRecordDataVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); +bool VerifyRecordData( + ::flatbuffers::Verifier& verifier, + const void* obj, + RecordData type); +bool VerifyRecordDataVector( + ::flatbuffers::Verifier& verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset>* values, + const ::flatbuffers::Vector* types); enum ArgAbstractData : uint8_t { ArgAbstractData_NONE = 0, @@ -622,55 +630,61 @@ enum ArgAbstractData : uint8_t { inline const ArgAbstractData (&EnumValuesArgAbstractData())[5] { static const ArgAbstractData values[] = { - ArgAbstractData_NONE, - ArgAbstractData_Scalar, - ArgAbstractData_PhiloxCudaState, - ArgAbstractData_ScalarCpu, - ArgAbstractData_TensorArg - }; + ArgAbstractData_NONE, + ArgAbstractData_Scalar, + ArgAbstractData_PhiloxCudaState, + ArgAbstractData_ScalarCpu, + ArgAbstractData_TensorArg}; return values; } -inline const char * const *EnumNamesArgAbstractData() { - static const char * const names[6] = { - "NONE", - "Scalar", - "PhiloxCudaState", - "ScalarCpu", - "TensorArg", - nullptr - }; +inline const char* const* EnumNamesArgAbstractData() { + static const char* const names[6] = { + "NONE", "Scalar", "PhiloxCudaState", "ScalarCpu", "TensorArg", nullptr}; return names; } -inline const char *EnumNameArgAbstractData(ArgAbstractData e) { - if (::flatbuffers::IsOutRange(e, ArgAbstractData_NONE, ArgAbstractData_TensorArg)) return ""; +inline const char* EnumNameArgAbstractData(ArgAbstractData e) { + if (::flatbuffers::IsOutRange( + e, ArgAbstractData_NONE, ArgAbstractData_TensorArg)) + return ""; const size_t index = static_cast(e); return EnumNamesArgAbstractData()[index]; } -template struct ArgAbstractDataTraits { +template +struct ArgAbstractDataTraits { static const ArgAbstractData enum_value = ArgAbstractData_NONE; }; -template<> struct ArgAbstractDataTraits { +template <> +struct ArgAbstractDataTraits { static const ArgAbstractData enum_value = ArgAbstractData_Scalar; }; -template<> struct ArgAbstractDataTraits { +template <> +struct ArgAbstractDataTraits { static const ArgAbstractData enum_value = ArgAbstractData_PhiloxCudaState; }; -template<> struct ArgAbstractDataTraits { +template <> +struct ArgAbstractDataTraits { static const ArgAbstractData enum_value = ArgAbstractData_ScalarCpu; }; -template<> struct ArgAbstractDataTraits { +template <> +struct ArgAbstractDataTraits { static const ArgAbstractData enum_value = ArgAbstractData_TensorArg; }; -bool VerifyArgAbstractData(::flatbuffers::Verifier &verifier, const void *obj, ArgAbstractData type); -bool VerifyArgAbstractDataVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); +bool VerifyArgAbstractData( + ::flatbuffers::Verifier& verifier, + const void* obj, + ArgAbstractData type); +bool VerifyArgAbstractDataVector( + ::flatbuffers::Verifier& verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset>* values, + const ::flatbuffers::Vector* types); FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) State FLATBUFFERS_FINAL_CLASS { private: @@ -678,19 +692,16 @@ FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) State FLATBUFFERS_FINAL_CLASS { int32_t type_; public: - State() - : index_(0), - type_(0) { - } + State() : index_(0), type_(0) {} State(int32_t _index, nvfuser::serde::StateType _type) : index_(::flatbuffers::EndianScalar(_index)), - type_(::flatbuffers::EndianScalar(static_cast(_type))) { - } + type_(::flatbuffers::EndianScalar(static_cast(_type))) {} int32_t index() const { return ::flatbuffers::EndianScalar(index_); } nvfuser::serde::StateType type() const { - return static_cast(::flatbuffers::EndianScalar(type_)); + return static_cast( + ::flatbuffers::EndianScalar(type_)); } }; FLATBUFFERS_STRUCT_END(State, 8); @@ -701,14 +712,10 @@ FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) EncodingEntry FLATBUFFERS_FINAL_CLASS { uint64_t lru_iter_; public: - EncodingEntry() - : id_(0), - lru_iter_(0) { - } + EncodingEntry() : id_(0), lru_iter_(0) {} EncodingEntry(uint64_t _id, uint64_t _lru_iter) : id_(::flatbuffers::EndianScalar(_id)), - lru_iter_(::flatbuffers::EndianScalar(_lru_iter)) { - } + lru_iter_(::flatbuffers::EndianScalar(_lru_iter)) {} uint64_t id() const { return ::flatbuffers::EndianScalar(id_); } @@ -731,13 +738,15 @@ struct Scalar FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_IMAG_VALUE = 18 }; nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } bool has_value() const { return GetField(VT_HAS_VALUE, 0) != 0; } nvfuser::serde::DataType value_type() const { - return static_cast(GetField(VT_VALUE_TYPE, 0)); + return static_cast( + GetField(VT_VALUE_TYPE, 0)); } bool bool_value() const { return GetField(VT_BOOL_VALUE, 0) != 0; @@ -754,35 +763,37 @@ struct Scalar FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { double imag_value() const { return GetField(VT_IMAG_VALUE, 0.0); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DTYPE, 4) && - VerifyField(verifier, VT_HAS_VALUE, 1) && - VerifyField(verifier, VT_VALUE_TYPE, 4) && - VerifyField(verifier, VT_BOOL_VALUE, 1) && - VerifyField(verifier, VT_LONG_VALUE, 8) && - VerifyField(verifier, VT_DOUBLE_VALUE, 8) && - VerifyField(verifier, VT_REAL_VALUE, 8) && - VerifyField(verifier, VT_IMAG_VALUE, 8) && - verifier.EndTable(); + VerifyField(verifier, VT_DTYPE, 4) && + VerifyField(verifier, VT_HAS_VALUE, 1) && + VerifyField(verifier, VT_VALUE_TYPE, 4) && + VerifyField(verifier, VT_BOOL_VALUE, 1) && + VerifyField(verifier, VT_LONG_VALUE, 8) && + VerifyField(verifier, VT_DOUBLE_VALUE, 8) && + VerifyField(verifier, VT_REAL_VALUE, 8) && + VerifyField(verifier, VT_IMAG_VALUE, 8) && verifier.EndTable(); } }; struct ScalarBuilder { typedef Scalar Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_dtype(nvfuser::serde::DataType dtype) { fbb_.AddElement(Scalar::VT_DTYPE, static_cast(dtype), 0); } void add_has_value(bool has_value) { - fbb_.AddElement(Scalar::VT_HAS_VALUE, static_cast(has_value), 0); + fbb_.AddElement( + Scalar::VT_HAS_VALUE, static_cast(has_value), 0); } void add_value_type(nvfuser::serde::DataType value_type) { - fbb_.AddElement(Scalar::VT_VALUE_TYPE, static_cast(value_type), 0); + fbb_.AddElement( + Scalar::VT_VALUE_TYPE, static_cast(value_type), 0); } void add_bool_value(bool bool_value) { - fbb_.AddElement(Scalar::VT_BOOL_VALUE, static_cast(bool_value), 0); + fbb_.AddElement( + Scalar::VT_BOOL_VALUE, static_cast(bool_value), 0); } void add_long_value(int64_t long_value) { fbb_.AddElement(Scalar::VT_LONG_VALUE, long_value, 0); @@ -796,8 +807,7 @@ struct ScalarBuilder { void add_imag_value(double imag_value) { fbb_.AddElement(Scalar::VT_IMAG_VALUE, imag_value, 0.0); } - explicit ScalarBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit ScalarBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -808,7 +818,7 @@ struct ScalarBuilder { }; inline ::flatbuffers::Offset CreateScalar( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double, bool has_value = false, nvfuser::serde::DataType value_type = nvfuser::serde::DataType_Double, @@ -834,26 +844,24 @@ struct TensorShape FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_SHAPE = 4 }; - const ::flatbuffers::Vector *shape() const { - return GetPointer *>(VT_SHAPE); + const ::flatbuffers::Vector* shape() const { + return GetPointer*>(VT_SHAPE); } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && verifier.EndTable(); } }; struct TensorShapeBuilder { typedef TensorShape Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_shape(::flatbuffers::Offset<::flatbuffers::Vector> shape) { fbb_.AddOffset(TensorShape::VT_SHAPE, shape); } - explicit TensorShapeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit TensorShapeBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -864,7 +872,7 @@ struct TensorShapeBuilder { }; inline ::flatbuffers::Offset CreateTensorShape( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> shape = 0) { TensorShapeBuilder builder_(_fbb); builder_.add_shape(shape); @@ -872,12 +880,10 @@ inline ::flatbuffers::Offset CreateTensorShape( } inline ::flatbuffers::Offset CreateTensorShapeDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *shape = nullptr) { + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* shape = nullptr) { auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; - return nvfuser::serde::CreateTensorShape( - _fbb, - shape__); + return nvfuser::serde::CreateTensorShape(_fbb, shape__); } struct ScalarInput FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -886,24 +892,25 @@ struct ScalarInput FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_DTYPE = 4 }; nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DTYPE, 4) && - verifier.EndTable(); + VerifyField(verifier, VT_DTYPE, 4) && verifier.EndTable(); } }; struct ScalarInputBuilder { typedef ScalarInput Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_dtype(nvfuser::serde::DataType dtype) { - fbb_.AddElement(ScalarInput::VT_DTYPE, static_cast(dtype), 0); + fbb_.AddElement( + ScalarInput::VT_DTYPE, static_cast(dtype), 0); } - explicit ScalarInputBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit ScalarInputBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -914,7 +921,7 @@ struct ScalarInputBuilder { }; inline ::flatbuffers::Offset CreateScalarInput( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { ScalarInputBuilder builder_(_fbb); builder_.add_dtype(dtype); @@ -933,17 +940,16 @@ struct PhiloxCudaState FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { uint64_t offset() const { return GetField(VT_OFFSET, 0); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_SEED, 8) && - VerifyField(verifier, VT_OFFSET, 8) && - verifier.EndTable(); + VerifyField(verifier, VT_SEED, 8) && + VerifyField(verifier, VT_OFFSET, 8) && verifier.EndTable(); } }; struct PhiloxCudaStateBuilder { typedef PhiloxCudaState Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_seed(uint64_t seed) { fbb_.AddElement(PhiloxCudaState::VT_SEED, seed, 0); @@ -951,8 +957,8 @@ struct PhiloxCudaStateBuilder { void add_offset(uint64_t offset) { fbb_.AddElement(PhiloxCudaState::VT_OFFSET, offset, 0); } - explicit PhiloxCudaStateBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit PhiloxCudaStateBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -963,7 +969,7 @@ struct PhiloxCudaStateBuilder { }; inline ::flatbuffers::Offset CreatePhiloxCudaState( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t seed = 0, uint64_t offset = 0) { PhiloxCudaStateBuilder builder_(_fbb); @@ -978,33 +984,32 @@ struct ScalarCpu FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_INSTANCE = 4, VT_SIZE = 6 }; - const ::flatbuffers::Vector *instance() const { - return GetPointer *>(VT_INSTANCE); + const ::flatbuffers::Vector* instance() const { + return GetPointer*>(VT_INSTANCE); } uint64_t size() const { return GetField(VT_SIZE, 0); } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_INSTANCE) && - verifier.VerifyVector(instance()) && - VerifyField(verifier, VT_SIZE, 8) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_INSTANCE) && + verifier.VerifyVector(instance()) && + VerifyField(verifier, VT_SIZE, 8) && verifier.EndTable(); } }; struct ScalarCpuBuilder { typedef ScalarCpu Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_instance(::flatbuffers::Offset<::flatbuffers::Vector> instance) { + void add_instance( + ::flatbuffers::Offset<::flatbuffers::Vector> instance) { fbb_.AddOffset(ScalarCpu::VT_INSTANCE, instance); } void add_size(uint64_t size) { fbb_.AddElement(ScalarCpu::VT_SIZE, size, 0); } - explicit ScalarCpuBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit ScalarCpuBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1015,7 +1020,7 @@ struct ScalarCpuBuilder { }; inline ::flatbuffers::Offset CreateScalarCpu( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> instance = 0, uint64_t size = 0) { ScalarCpuBuilder builder_(_fbb); @@ -1025,14 +1030,11 @@ inline ::flatbuffers::Offset CreateScalarCpu( } inline ::flatbuffers::Offset CreateScalarCpuDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *instance = nullptr, + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* instance = nullptr, uint64_t size = 0) { auto instance__ = instance ? _fbb.CreateVector(*instance) : 0; - return nvfuser::serde::CreateScalarCpu( - _fbb, - instance__, - size); + return nvfuser::serde::CreateScalarCpu(_fbb, instance__, size); } struct TensorArg FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1048,14 +1050,15 @@ struct TensorArg FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { uint64_t ptr() const { return GetField(VT_PTR, 0); } - const ::flatbuffers::Vector *sizes() const { - return GetPointer *>(VT_SIZES); + const ::flatbuffers::Vector* sizes() const { + return GetPointer*>(VT_SIZES); } - const ::flatbuffers::Vector *strides() const { - return GetPointer *>(VT_STRIDES); + const ::flatbuffers::Vector* strides() const { + return GetPointer*>(VT_STRIDES); } nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } bool is_int_index_mode() const { return GetField(VT_IS_INT_INDEX_MODE, 0) != 0; @@ -1063,23 +1066,22 @@ struct TensorArg FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool index_type_resolved() const { return GetField(VT_INDEX_TYPE_RESOLVED, 0) != 0; } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_PTR, 8) && - VerifyOffset(verifier, VT_SIZES) && - verifier.VerifyVector(sizes()) && - VerifyOffset(verifier, VT_STRIDES) && - verifier.VerifyVector(strides()) && - VerifyField(verifier, VT_DTYPE, 4) && - VerifyField(verifier, VT_IS_INT_INDEX_MODE, 1) && - VerifyField(verifier, VT_INDEX_TYPE_RESOLVED, 1) && - verifier.EndTable(); + VerifyField(verifier, VT_PTR, 8) && + VerifyOffset(verifier, VT_SIZES) && verifier.VerifyVector(sizes()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && + VerifyField(verifier, VT_DTYPE, 4) && + VerifyField(verifier, VT_IS_INT_INDEX_MODE, 1) && + VerifyField(verifier, VT_INDEX_TYPE_RESOLVED, 1) && + verifier.EndTable(); } }; struct TensorArgBuilder { typedef TensorArg Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_ptr(uint64_t ptr) { fbb_.AddElement(TensorArg::VT_PTR, ptr, 0); @@ -1087,20 +1089,28 @@ struct TensorArgBuilder { void add_sizes(::flatbuffers::Offset<::flatbuffers::Vector> sizes) { fbb_.AddOffset(TensorArg::VT_SIZES, sizes); } - void add_strides(::flatbuffers::Offset<::flatbuffers::Vector> strides) { + void add_strides( + ::flatbuffers::Offset<::flatbuffers::Vector> strides) { fbb_.AddOffset(TensorArg::VT_STRIDES, strides); } void add_dtype(nvfuser::serde::DataType dtype) { - fbb_.AddElement(TensorArg::VT_DTYPE, static_cast(dtype), 0); + fbb_.AddElement( + TensorArg::VT_DTYPE, static_cast(dtype), 0); } void add_is_int_index_mode(bool is_int_index_mode) { - fbb_.AddElement(TensorArg::VT_IS_INT_INDEX_MODE, static_cast(is_int_index_mode), 0); + fbb_.AddElement( + TensorArg::VT_IS_INT_INDEX_MODE, + static_cast(is_int_index_mode), + 0); } void add_index_type_resolved(bool index_type_resolved) { - fbb_.AddElement(TensorArg::VT_INDEX_TYPE_RESOLVED, static_cast(index_type_resolved), 0); + fbb_.AddElement( + TensorArg::VT_INDEX_TYPE_RESOLVED, + static_cast(index_type_resolved), + 0); } - explicit TensorArgBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit TensorArgBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1111,7 +1121,7 @@ struct TensorArgBuilder { }; inline ::flatbuffers::Offset CreateTensorArg( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t ptr = 0, ::flatbuffers::Offset<::flatbuffers::Vector> sizes = 0, ::flatbuffers::Offset<::flatbuffers::Vector> strides = 0, @@ -1129,10 +1139,10 @@ inline ::flatbuffers::Offset CreateTensorArg( } inline ::flatbuffers::Offset CreateTensorArgDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t ptr = 0, - const std::vector *sizes = nullptr, - const std::vector *strides = nullptr, + const std::vector* sizes = nullptr, + const std::vector* strides = nullptr, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double, bool is_int_index_mode = false, bool index_type_resolved = false) { @@ -1155,61 +1165,80 @@ struct ArgAbstract FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_DATA = 6 }; nvfuser::serde::ArgAbstractData data_type() const { - return static_cast(GetField(VT_DATA_TYPE, 0)); - } - const void *data() const { - return GetPointer(VT_DATA); - } - template const T *data_as() const; - const nvfuser::serde::Scalar *data_as_Scalar() const { - return data_type() == nvfuser::serde::ArgAbstractData_Scalar ? static_cast(data()) : nullptr; - } - const nvfuser::serde::PhiloxCudaState *data_as_PhiloxCudaState() const { - return data_type() == nvfuser::serde::ArgAbstractData_PhiloxCudaState ? static_cast(data()) : nullptr; - } - const nvfuser::serde::ScalarCpu *data_as_ScalarCpu() const { - return data_type() == nvfuser::serde::ArgAbstractData_ScalarCpu ? static_cast(data()) : nullptr; - } - const nvfuser::serde::TensorArg *data_as_TensorArg() const { - return data_type() == nvfuser::serde::ArgAbstractData_TensorArg ? static_cast(data()) : nullptr; - } - bool Verify(::flatbuffers::Verifier &verifier) const { + return static_cast( + GetField(VT_DATA_TYPE, 0)); + } + const void* data() const { + return GetPointer(VT_DATA); + } + template + const T* data_as() const; + const nvfuser::serde::Scalar* data_as_Scalar() const { + return data_type() == nvfuser::serde::ArgAbstractData_Scalar + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::PhiloxCudaState* data_as_PhiloxCudaState() const { + return data_type() == nvfuser::serde::ArgAbstractData_PhiloxCudaState + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::ScalarCpu* data_as_ScalarCpu() const { + return data_type() == nvfuser::serde::ArgAbstractData_ScalarCpu + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::TensorArg* data_as_TensorArg() const { + return data_type() == nvfuser::serde::ArgAbstractData_TensorArg + ? static_cast(data()) + : nullptr; + } + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DATA_TYPE, 1) && - VerifyOffset(verifier, VT_DATA) && - VerifyArgAbstractData(verifier, data(), data_type()) && - verifier.EndTable(); + VerifyField(verifier, VT_DATA_TYPE, 1) && + VerifyOffset(verifier, VT_DATA) && + VerifyArgAbstractData(verifier, data(), data_type()) && + verifier.EndTable(); } }; -template<> inline const nvfuser::serde::Scalar *ArgAbstract::data_as() const { +template <> +inline const nvfuser::serde::Scalar* ArgAbstract::data_as< + nvfuser::serde::Scalar>() const { return data_as_Scalar(); } -template<> inline const nvfuser::serde::PhiloxCudaState *ArgAbstract::data_as() const { +template <> +inline const nvfuser::serde::PhiloxCudaState* ArgAbstract::data_as< + nvfuser::serde::PhiloxCudaState>() const { return data_as_PhiloxCudaState(); } -template<> inline const nvfuser::serde::ScalarCpu *ArgAbstract::data_as() const { +template <> +inline const nvfuser::serde::ScalarCpu* ArgAbstract::data_as< + nvfuser::serde::ScalarCpu>() const { return data_as_ScalarCpu(); } -template<> inline const nvfuser::serde::TensorArg *ArgAbstract::data_as() const { +template <> +inline const nvfuser::serde::TensorArg* ArgAbstract::data_as< + nvfuser::serde::TensorArg>() const { return data_as_TensorArg(); } struct ArgAbstractBuilder { typedef ArgAbstract Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_data_type(nvfuser::serde::ArgAbstractData data_type) { - fbb_.AddElement(ArgAbstract::VT_DATA_TYPE, static_cast(data_type), 0); + fbb_.AddElement( + ArgAbstract::VT_DATA_TYPE, static_cast(data_type), 0); } void add_data(::flatbuffers::Offset data) { fbb_.AddOffset(ArgAbstract::VT_DATA, data); } - explicit ArgAbstractBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit ArgAbstractBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1220,8 +1249,9 @@ struct ArgAbstractBuilder { }; inline ::flatbuffers::Offset CreateArgAbstract( - ::flatbuffers::FlatBufferBuilder &_fbb, - nvfuser::serde::ArgAbstractData data_type = nvfuser::serde::ArgAbstractData_NONE, + ::flatbuffers::FlatBufferBuilder& _fbb, + nvfuser::serde::ArgAbstractData data_type = + nvfuser::serde::ArgAbstractData_NONE, ::flatbuffers::Offset data = 0) { ArgAbstractBuilder builder_(_fbb); builder_.add_data(data); @@ -1229,15 +1259,19 @@ inline ::flatbuffers::Offset CreateArgAbstract( return builder_.Finish(); } -struct KernelArgumentHolder FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { +struct KernelArgumentHolder FLATBUFFERS_FINAL_CLASS + : private ::flatbuffers::Table { typedef KernelArgumentHolderBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_ARGUMENTS = 4, VT_DEVICE_INDEX = 6, VT_CACHE_ID = 8 }; - const ::flatbuffers::Vector<::flatbuffers::Offset> *arguments() const { - return GetPointer> *>(VT_ARGUMENTS); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + arguments() const { + return GetPointer>*>(VT_ARGUMENTS); } int8_t device_index() const { return GetField(VT_DEVICE_INDEX, 0); @@ -1245,32 +1279,33 @@ struct KernelArgumentHolder FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Tab uint64_t cache_id() const { return GetField(VT_CACHE_ID, 0); } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ARGUMENTS) && - verifier.VerifyVector(arguments()) && - verifier.VerifyVectorOfTables(arguments()) && - VerifyField(verifier, VT_DEVICE_INDEX, 1) && - VerifyField(verifier, VT_CACHE_ID, 8) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ARGUMENTS) && + verifier.VerifyVector(arguments()) && + verifier.VerifyVectorOfTables(arguments()) && + VerifyField(verifier, VT_DEVICE_INDEX, 1) && + VerifyField(verifier, VT_CACHE_ID, 8) && verifier.EndTable(); } }; struct KernelArgumentHolderBuilder { typedef KernelArgumentHolder Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_arguments(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> arguments) { + void add_arguments( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> arguments) { fbb_.AddOffset(KernelArgumentHolder::VT_ARGUMENTS, arguments); } void add_device_index(int8_t device_index) { - fbb_.AddElement(KernelArgumentHolder::VT_DEVICE_INDEX, device_index, 0); + fbb_.AddElement( + KernelArgumentHolder::VT_DEVICE_INDEX, device_index, 0); } void add_cache_id(uint64_t cache_id) { fbb_.AddElement(KernelArgumentHolder::VT_CACHE_ID, cache_id, 0); } - explicit KernelArgumentHolderBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit KernelArgumentHolderBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1281,8 +1316,9 @@ struct KernelArgumentHolderBuilder { }; inline ::flatbuffers::Offset CreateKernelArgumentHolder( - ::flatbuffers::FlatBufferBuilder &_fbb, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> arguments = 0, + ::flatbuffers::FlatBufferBuilder& _fbb, + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> arguments = 0, int8_t device_index = 0, uint64_t cache_id = 0) { KernelArgumentHolderBuilder builder_(_fbb); @@ -1292,17 +1328,19 @@ inline ::flatbuffers::Offset CreateKernelArgumentHolder( return builder_.Finish(); } -inline ::flatbuffers::Offset CreateKernelArgumentHolderDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<::flatbuffers::Offset> *arguments = nullptr, +inline ::flatbuffers::Offset +CreateKernelArgumentHolderDirect( + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector<::flatbuffers::Offset>* + arguments = nullptr, int8_t device_index = 0, uint64_t cache_id = 0) { - auto arguments__ = arguments ? _fbb.CreateVector<::flatbuffers::Offset>(*arguments) : 0; + auto arguments__ = arguments + ? _fbb.CreateVector<::flatbuffers::Offset>( + *arguments) + : 0; return nvfuser::serde::CreateKernelArgumentHolder( - _fbb, - arguments__, - device_index, - cache_id); + _fbb, arguments__, device_index, cache_id); } struct LaunchParams FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1338,28 +1376,30 @@ struct LaunchParams FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int64_t smem() const { return GetField(VT_SMEM, 0); } - const ::flatbuffers::Vector<::flatbuffers::Offset> *output_sizes() const { - return GetPointer> *>(VT_OUTPUT_SIZES); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + output_sizes() const { + return GetPointer>*>(VT_OUTPUT_SIZES); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_GDIMX, 8) && - VerifyField(verifier, VT_GDIMY, 8) && - VerifyField(verifier, VT_GDIMZ, 8) && - VerifyField(verifier, VT_BDIMX, 8) && - VerifyField(verifier, VT_BDIMY, 8) && - VerifyField(verifier, VT_BDIMZ, 8) && - VerifyField(verifier, VT_SMEM, 8) && - VerifyOffset(verifier, VT_OUTPUT_SIZES) && - verifier.VerifyVector(output_sizes()) && - verifier.VerifyVectorOfTables(output_sizes()) && - verifier.EndTable(); + VerifyField(verifier, VT_GDIMX, 8) && + VerifyField(verifier, VT_GDIMY, 8) && + VerifyField(verifier, VT_GDIMZ, 8) && + VerifyField(verifier, VT_BDIMX, 8) && + VerifyField(verifier, VT_BDIMY, 8) && + VerifyField(verifier, VT_BDIMZ, 8) && + VerifyField(verifier, VT_SMEM, 8) && + VerifyOffset(verifier, VT_OUTPUT_SIZES) && + verifier.VerifyVector(output_sizes()) && + verifier.VerifyVectorOfTables(output_sizes()) && verifier.EndTable(); } }; struct LaunchParamsBuilder { typedef LaunchParams Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_gdimx(int64_t gdimx) { fbb_.AddElement(LaunchParams::VT_GDIMX, gdimx, 0); @@ -1382,11 +1422,13 @@ struct LaunchParamsBuilder { void add_smem(int64_t smem) { fbb_.AddElement(LaunchParams::VT_SMEM, smem, 0); } - void add_output_sizes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> output_sizes) { + void add_output_sizes( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> output_sizes) { fbb_.AddOffset(LaunchParams::VT_OUTPUT_SIZES, output_sizes); } - explicit LaunchParamsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit LaunchParamsBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1397,7 +1439,7 @@ struct LaunchParamsBuilder { }; inline ::flatbuffers::Offset CreateLaunchParams( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, int64_t gdimx = 0, int64_t gdimy = 0, int64_t gdimz = 0, @@ -1405,7 +1447,8 @@ inline ::flatbuffers::Offset CreateLaunchParams( int64_t bdimy = 0, int64_t bdimz = 0, int64_t smem = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> output_sizes = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> output_sizes = 0) { LaunchParamsBuilder builder_(_fbb); builder_.add_smem(smem); builder_.add_bdimz(bdimz); @@ -1419,7 +1462,7 @@ inline ::flatbuffers::Offset CreateLaunchParams( } inline ::flatbuffers::Offset CreateLaunchParamsDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, int64_t gdimx = 0, int64_t gdimy = 0, int64_t gdimz = 0, @@ -1427,18 +1470,14 @@ inline ::flatbuffers::Offset CreateLaunchParamsDirect( int64_t bdimy = 0, int64_t bdimz = 0, int64_t smem = 0, - const std::vector<::flatbuffers::Offset> *output_sizes = nullptr) { - auto output_sizes__ = output_sizes ? _fbb.CreateVector<::flatbuffers::Offset>(*output_sizes) : 0; + const std::vector<::flatbuffers::Offset>* + output_sizes = nullptr) { + auto output_sizes__ = output_sizes + ? _fbb.CreateVector<::flatbuffers::Offset>( + *output_sizes) + : 0; return nvfuser::serde::CreateLaunchParams( - _fbb, - gdimx, - gdimy, - gdimz, - bdimx, - bdimy, - bdimz, - smem, - output_sizes__); + _fbb, gdimx, gdimy, gdimz, bdimx, bdimy, bdimz, smem, output_sizes__); } struct GlobalBufferInfo FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1455,14 +1494,15 @@ struct GlobalBufferInfo FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int64_t tv() const { return GetField(VT_TV, -1LL); } - const ::flatbuffers::Vector *sizes() const { - return GetPointer *>(VT_SIZES); + const ::flatbuffers::Vector* sizes() const { + return GetPointer*>(VT_SIZES); } - const ::flatbuffers::Vector *strides() const { - return GetPointer *>(VT_STRIDES); + const ::flatbuffers::Vector* strides() const { + return GetPointer*>(VT_STRIDES); } nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } bool zero_init() const { return GetField(VT_ZERO_INIT, 0) != 0; @@ -1473,24 +1513,23 @@ struct GlobalBufferInfo FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool is_fusion_output() const { return GetField(VT_IS_FUSION_OUTPUT, 0) != 0; } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_TV, 8) && - VerifyOffset(verifier, VT_SIZES) && - verifier.VerifyVector(sizes()) && - VerifyOffset(verifier, VT_STRIDES) && - verifier.VerifyVector(strides()) && - VerifyField(verifier, VT_DTYPE, 4) && - VerifyField(verifier, VT_ZERO_INIT, 1) && - VerifyField(verifier, VT_IS_PROFILE_BUFFER, 1) && - VerifyField(verifier, VT_IS_FUSION_OUTPUT, 1) && - verifier.EndTable(); + VerifyField(verifier, VT_TV, 8) && + VerifyOffset(verifier, VT_SIZES) && verifier.VerifyVector(sizes()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && + VerifyField(verifier, VT_DTYPE, 4) && + VerifyField(verifier, VT_ZERO_INIT, 1) && + VerifyField(verifier, VT_IS_PROFILE_BUFFER, 1) && + VerifyField(verifier, VT_IS_FUSION_OUTPUT, 1) && + verifier.EndTable(); } }; struct GlobalBufferInfoBuilder { typedef GlobalBufferInfo Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_tv(int64_t tv) { fbb_.AddElement(GlobalBufferInfo::VT_TV, tv, -1LL); @@ -1498,23 +1537,32 @@ struct GlobalBufferInfoBuilder { void add_sizes(::flatbuffers::Offset<::flatbuffers::Vector> sizes) { fbb_.AddOffset(GlobalBufferInfo::VT_SIZES, sizes); } - void add_strides(::flatbuffers::Offset<::flatbuffers::Vector> strides) { + void add_strides( + ::flatbuffers::Offset<::flatbuffers::Vector> strides) { fbb_.AddOffset(GlobalBufferInfo::VT_STRIDES, strides); } void add_dtype(nvfuser::serde::DataType dtype) { - fbb_.AddElement(GlobalBufferInfo::VT_DTYPE, static_cast(dtype), 0); + fbb_.AddElement( + GlobalBufferInfo::VT_DTYPE, static_cast(dtype), 0); } void add_zero_init(bool zero_init) { - fbb_.AddElement(GlobalBufferInfo::VT_ZERO_INIT, static_cast(zero_init), 0); + fbb_.AddElement( + GlobalBufferInfo::VT_ZERO_INIT, static_cast(zero_init), 0); } void add_is_profile_buffer(bool is_profile_buffer) { - fbb_.AddElement(GlobalBufferInfo::VT_IS_PROFILE_BUFFER, static_cast(is_profile_buffer), 0); + fbb_.AddElement( + GlobalBufferInfo::VT_IS_PROFILE_BUFFER, + static_cast(is_profile_buffer), + 0); } void add_is_fusion_output(bool is_fusion_output) { - fbb_.AddElement(GlobalBufferInfo::VT_IS_FUSION_OUTPUT, static_cast(is_fusion_output), 0); + fbb_.AddElement( + GlobalBufferInfo::VT_IS_FUSION_OUTPUT, + static_cast(is_fusion_output), + 0); } - explicit GlobalBufferInfoBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit GlobalBufferInfoBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1525,7 +1573,7 @@ struct GlobalBufferInfoBuilder { }; inline ::flatbuffers::Offset CreateGlobalBufferInfo( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, int64_t tv = -1LL, ::flatbuffers::Offset<::flatbuffers::Vector> sizes = 0, ::flatbuffers::Offset<::flatbuffers::Vector> strides = 0, @@ -1545,10 +1593,10 @@ inline ::flatbuffers::Offset CreateGlobalBufferInfo( } inline ::flatbuffers::Offset CreateGlobalBufferInfoDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, int64_t tv = -1LL, - const std::vector *sizes = nullptr, - const std::vector *strides = nullptr, + const std::vector* sizes = nullptr, + const std::vector* strides = nullptr, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double, bool zero_init = false, bool is_profile_buffer = false, @@ -1580,71 +1628,87 @@ struct ExecutorEntry FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool init() const { return GetField(VT_INIT, 0) != 0; } - const nvfuser::serde::LaunchParams *launch_params() const { - return GetPointer(VT_LAUNCH_PARAMS); + const nvfuser::serde::LaunchParams* launch_params() const { + return GetPointer(VT_LAUNCH_PARAMS); } - const ::flatbuffers::Vector *output_aliases() const { - return GetPointer *>(VT_OUTPUT_ALIASES); + const ::flatbuffers::Vector* output_aliases() const { + return GetPointer*>(VT_OUTPUT_ALIASES); } - const ::flatbuffers::Vector *input_aliases() const { - return GetPointer *>(VT_INPUT_ALIASES); + const ::flatbuffers::Vector* input_aliases() const { + return GetPointer*>(VT_INPUT_ALIASES); } - const ::flatbuffers::Vector<::flatbuffers::Offset> *outputs() const { - return GetPointer> *>(VT_OUTPUTS); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + outputs() const { + return GetPointer>*>(VT_OUTPUTS); } - const ::flatbuffers::Vector<::flatbuffers::Offset> *intermediates() const { - return GetPointer> *>(VT_INTERMEDIATES); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + intermediates() const { + return GetPointer>*>( + VT_INTERMEDIATES); } uint64_t rand_offset() const { return GetField(VT_RAND_OFFSET, 0); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_INIT, 1) && - VerifyOffset(verifier, VT_LAUNCH_PARAMS) && - verifier.VerifyTable(launch_params()) && - VerifyOffset(verifier, VT_OUTPUT_ALIASES) && - verifier.VerifyVector(output_aliases()) && - VerifyOffset(verifier, VT_INPUT_ALIASES) && - verifier.VerifyVector(input_aliases()) && - VerifyOffset(verifier, VT_OUTPUTS) && - verifier.VerifyVector(outputs()) && - verifier.VerifyVectorOfTables(outputs()) && - VerifyOffset(verifier, VT_INTERMEDIATES) && - verifier.VerifyVector(intermediates()) && - verifier.VerifyVectorOfTables(intermediates()) && - VerifyField(verifier, VT_RAND_OFFSET, 8) && - verifier.EndTable(); + VerifyField(verifier, VT_INIT, 1) && + VerifyOffset(verifier, VT_LAUNCH_PARAMS) && + verifier.VerifyTable(launch_params()) && + VerifyOffset(verifier, VT_OUTPUT_ALIASES) && + verifier.VerifyVector(output_aliases()) && + VerifyOffset(verifier, VT_INPUT_ALIASES) && + verifier.VerifyVector(input_aliases()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfTables(outputs()) && + VerifyOffset(verifier, VT_INTERMEDIATES) && + verifier.VerifyVector(intermediates()) && + verifier.VerifyVectorOfTables(intermediates()) && + VerifyField(verifier, VT_RAND_OFFSET, 8) && + verifier.EndTable(); } }; struct ExecutorEntryBuilder { typedef ExecutorEntry Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_init(bool init) { - fbb_.AddElement(ExecutorEntry::VT_INIT, static_cast(init), 0); + fbb_.AddElement( + ExecutorEntry::VT_INIT, static_cast(init), 0); } - void add_launch_params(::flatbuffers::Offset launch_params) { + void add_launch_params( + ::flatbuffers::Offset launch_params) { fbb_.AddOffset(ExecutorEntry::VT_LAUNCH_PARAMS, launch_params); } - void add_output_aliases(::flatbuffers::Offset<::flatbuffers::Vector> output_aliases) { + void add_output_aliases( + ::flatbuffers::Offset<::flatbuffers::Vector> output_aliases) { fbb_.AddOffset(ExecutorEntry::VT_OUTPUT_ALIASES, output_aliases); } - void add_input_aliases(::flatbuffers::Offset<::flatbuffers::Vector> input_aliases) { + void add_input_aliases( + ::flatbuffers::Offset<::flatbuffers::Vector> input_aliases) { fbb_.AddOffset(ExecutorEntry::VT_INPUT_ALIASES, input_aliases); } - void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> outputs) { + void add_outputs( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> outputs) { fbb_.AddOffset(ExecutorEntry::VT_OUTPUTS, outputs); } - void add_intermediates(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> intermediates) { + void add_intermediates( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> + intermediates) { fbb_.AddOffset(ExecutorEntry::VT_INTERMEDIATES, intermediates); } void add_rand_offset(uint64_t rand_offset) { fbb_.AddElement(ExecutorEntry::VT_RAND_OFFSET, rand_offset, 0); } - explicit ExecutorEntryBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit ExecutorEntryBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1655,13 +1719,15 @@ struct ExecutorEntryBuilder { }; inline ::flatbuffers::Offset CreateExecutorEntry( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, bool init = false, ::flatbuffers::Offset launch_params = 0, ::flatbuffers::Offset<::flatbuffers::Vector> output_aliases = 0, ::flatbuffers::Offset<::flatbuffers::Vector> input_aliases = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> outputs = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> intermediates = 0, + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> outputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset< + nvfuser::serde::GlobalBufferInfo>>> intermediates = 0, uint64_t rand_offset = 0) { ExecutorEntryBuilder builder_(_fbb); builder_.add_rand_offset(rand_offset); @@ -1675,18 +1741,29 @@ inline ::flatbuffers::Offset CreateExecutorEntry( } inline ::flatbuffers::Offset CreateExecutorEntryDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, bool init = false, ::flatbuffers::Offset launch_params = 0, - const std::vector *output_aliases = nullptr, - const std::vector *input_aliases = nullptr, - const std::vector<::flatbuffers::Offset> *outputs = nullptr, - const std::vector<::flatbuffers::Offset> *intermediates = nullptr, + const std::vector* output_aliases = nullptr, + const std::vector* input_aliases = nullptr, + const std::vector<::flatbuffers::Offset>* + outputs = nullptr, + const std::vector<::flatbuffers::Offset>* + intermediates = nullptr, uint64_t rand_offset = 0) { - auto output_aliases__ = output_aliases ? _fbb.CreateVector(*output_aliases) : 0; - auto input_aliases__ = input_aliases ? _fbb.CreateVector(*input_aliases) : 0; - auto outputs__ = outputs ? _fbb.CreateVector<::flatbuffers::Offset>(*outputs) : 0; - auto intermediates__ = intermediates ? _fbb.CreateVector<::flatbuffers::Offset>(*intermediates) : 0; + auto output_aliases__ = + output_aliases ? _fbb.CreateVector(*output_aliases) : 0; + auto input_aliases__ = + input_aliases ? _fbb.CreateVector(*input_aliases) : 0; + auto outputs__ = outputs + ? _fbb.CreateVector< + ::flatbuffers::Offset>(*outputs) + : 0; + auto intermediates__ = intermediates + ? _fbb.CreateVector< + ::flatbuffers::Offset>( + *intermediates) + : 0; return nvfuser::serde::CreateExecutorEntry( _fbb, init, @@ -1710,26 +1787,28 @@ struct BatchNorm FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool channels_last() const { return GetField(VT_CHANNELS_LAST, 0) != 0; } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_TRAINING, 1) && - VerifyField(verifier, VT_CHANNELS_LAST, 1) && - verifier.EndTable(); + VerifyField(verifier, VT_TRAINING, 1) && + VerifyField(verifier, VT_CHANNELS_LAST, 1) && + verifier.EndTable(); } }; struct BatchNormBuilder { typedef BatchNorm Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_training(bool training) { - fbb_.AddElement(BatchNorm::VT_TRAINING, static_cast(training), 0); + fbb_.AddElement( + BatchNorm::VT_TRAINING, static_cast(training), 0); } void add_channels_last(bool channels_last) { - fbb_.AddElement(BatchNorm::VT_CHANNELS_LAST, static_cast(channels_last), 0); + fbb_.AddElement( + BatchNorm::VT_CHANNELS_LAST, static_cast(channels_last), 0); } - explicit BatchNormBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit BatchNormBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1740,7 +1819,7 @@ struct BatchNormBuilder { }; inline ::flatbuffers::Offset CreateBatchNorm( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, bool training = false, bool channels_last = false) { BatchNormBuilder builder_(_fbb); @@ -1754,26 +1833,26 @@ struct Broadcast FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_BROADCAST_DIMS = 4 }; - const ::flatbuffers::Vector *broadcast_dims() const { - return GetPointer *>(VT_BROADCAST_DIMS); + const ::flatbuffers::Vector* broadcast_dims() const { + return GetPointer*>(VT_BROADCAST_DIMS); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_BROADCAST_DIMS) && - verifier.VerifyVector(broadcast_dims()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_BROADCAST_DIMS) && + verifier.VerifyVector(broadcast_dims()) && verifier.EndTable(); } }; struct BroadcastBuilder { typedef Broadcast Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_broadcast_dims(::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims) { + void add_broadcast_dims( + ::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims) { fbb_.AddOffset(Broadcast::VT_BROADCAST_DIMS, broadcast_dims); } - explicit BroadcastBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit BroadcastBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1784,7 +1863,7 @@ struct BroadcastBuilder { }; inline ::flatbuffers::Offset CreateBroadcast( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims = 0) { BroadcastBuilder builder_(_fbb); builder_.add_broadcast_dims(broadcast_dims); @@ -1792,12 +1871,11 @@ inline ::flatbuffers::Offset CreateBroadcast( } inline ::flatbuffers::Offset CreateBroadcastDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *broadcast_dims = nullptr) { - auto broadcast_dims__ = broadcast_dims ? _fbb.CreateVector(*broadcast_dims) : 0; - return nvfuser::serde::CreateBroadcast( - _fbb, - broadcast_dims__); + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* broadcast_dims = nullptr) { + auto broadcast_dims__ = + broadcast_dims ? _fbb.CreateVector(*broadcast_dims) : 0; + return nvfuser::serde::CreateBroadcast(_fbb, broadcast_dims__); } struct BroadcastInDim FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1806,34 +1884,35 @@ struct BroadcastInDim FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_OUTPUT_SHAPE = 4, VT_BROADCAST_DIMS = 6 }; - const ::flatbuffers::Vector *output_shape() const { - return GetPointer *>(VT_OUTPUT_SHAPE); + const ::flatbuffers::Vector* output_shape() const { + return GetPointer*>(VT_OUTPUT_SHAPE); } - const ::flatbuffers::Vector *broadcast_dims() const { - return GetPointer *>(VT_BROADCAST_DIMS); + const ::flatbuffers::Vector* broadcast_dims() const { + return GetPointer*>(VT_BROADCAST_DIMS); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_OUTPUT_SHAPE) && - verifier.VerifyVector(output_shape()) && - VerifyOffset(verifier, VT_BROADCAST_DIMS) && - verifier.VerifyVector(broadcast_dims()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_OUTPUT_SHAPE) && + verifier.VerifyVector(output_shape()) && + VerifyOffset(verifier, VT_BROADCAST_DIMS) && + verifier.VerifyVector(broadcast_dims()) && verifier.EndTable(); } }; struct BroadcastInDimBuilder { typedef BroadcastInDim Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_output_shape(::flatbuffers::Offset<::flatbuffers::Vector> output_shape) { + void add_output_shape( + ::flatbuffers::Offset<::flatbuffers::Vector> output_shape) { fbb_.AddOffset(BroadcastInDim::VT_OUTPUT_SHAPE, output_shape); } - void add_broadcast_dims(::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims) { + void add_broadcast_dims( + ::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims) { fbb_.AddOffset(BroadcastInDim::VT_BROADCAST_DIMS, broadcast_dims); } - explicit BroadcastInDimBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit BroadcastInDimBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1844,7 +1923,7 @@ struct BroadcastInDimBuilder { }; inline ::flatbuffers::Offset CreateBroadcastInDim( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> output_shape = 0, ::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims = 0) { BroadcastInDimBuilder builder_(_fbb); @@ -1854,51 +1933,57 @@ inline ::flatbuffers::Offset CreateBroadcastInDim( } inline ::flatbuffers::Offset CreateBroadcastInDimDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *output_shape = nullptr, - const std::vector *broadcast_dims = nullptr) { - auto output_shape__ = output_shape ? _fbb.CreateVector(*output_shape) : 0; - auto broadcast_dims__ = broadcast_dims ? _fbb.CreateVector(*broadcast_dims) : 0; + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* output_shape = nullptr, + const std::vector* broadcast_dims = nullptr) { + auto output_shape__ = + output_shape ? _fbb.CreateVector(*output_shape) : 0; + auto broadcast_dims__ = + broadcast_dims ? _fbb.CreateVector(*broadcast_dims) : 0; return nvfuser::serde::CreateBroadcastInDim( - _fbb, - output_shape__, - broadcast_dims__); + _fbb, output_shape__, broadcast_dims__); } -struct BroadcastInDimSymbolic FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { +struct BroadcastInDimSymbolic FLATBUFFERS_FINAL_CLASS + : private ::flatbuffers::Table { typedef BroadcastInDimSymbolicBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_OUTPUT_SHAPE = 4, VT_BROADCAST_DIMS = 6 }; - const ::flatbuffers::Vector *output_shape() const { - return GetPointer *>(VT_OUTPUT_SHAPE); + const ::flatbuffers::Vector* output_shape() + const { + return GetPointer< + const ::flatbuffers::Vector*>( + VT_OUTPUT_SHAPE); } - const ::flatbuffers::Vector *broadcast_dims() const { - return GetPointer *>(VT_BROADCAST_DIMS); + const ::flatbuffers::Vector* broadcast_dims() const { + return GetPointer*>(VT_BROADCAST_DIMS); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_OUTPUT_SHAPE) && - verifier.VerifyVector(output_shape()) && - VerifyOffset(verifier, VT_BROADCAST_DIMS) && - verifier.VerifyVector(broadcast_dims()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_OUTPUT_SHAPE) && + verifier.VerifyVector(output_shape()) && + VerifyOffset(verifier, VT_BROADCAST_DIMS) && + verifier.VerifyVector(broadcast_dims()) && verifier.EndTable(); } }; struct BroadcastInDimSymbolicBuilder { typedef BroadcastInDimSymbolic Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_output_shape(::flatbuffers::Offset<::flatbuffers::Vector> output_shape) { + void add_output_shape( + ::flatbuffers::Offset<::flatbuffers::Vector> + output_shape) { fbb_.AddOffset(BroadcastInDimSymbolic::VT_OUTPUT_SHAPE, output_shape); } - void add_broadcast_dims(::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims) { + void add_broadcast_dims( + ::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims) { fbb_.AddOffset(BroadcastInDimSymbolic::VT_BROADCAST_DIMS, broadcast_dims); } - explicit BroadcastInDimSymbolicBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit BroadcastInDimSymbolicBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1908,9 +1993,11 @@ struct BroadcastInDimSymbolicBuilder { } }; -inline ::flatbuffers::Offset CreateBroadcastInDimSymbolic( - ::flatbuffers::FlatBufferBuilder &_fbb, - ::flatbuffers::Offset<::flatbuffers::Vector> output_shape = 0, +inline ::flatbuffers::Offset +CreateBroadcastInDimSymbolic( + ::flatbuffers::FlatBufferBuilder& _fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> + output_shape = 0, ::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dims = 0) { BroadcastInDimSymbolicBuilder builder_(_fbb); builder_.add_broadcast_dims(broadcast_dims); @@ -1918,16 +2005,18 @@ inline ::flatbuffers::Offset CreateBroadcastInDimSymboli return builder_.Finish(); } -inline ::flatbuffers::Offset CreateBroadcastInDimSymbolicDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *output_shape = nullptr, - const std::vector *broadcast_dims = nullptr) { - auto output_shape__ = output_shape ? _fbb.CreateVectorOfStructs(*output_shape) : 0; - auto broadcast_dims__ = broadcast_dims ? _fbb.CreateVector(*broadcast_dims) : 0; +inline ::flatbuffers::Offset +CreateBroadcastInDimSymbolicDirect( + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* output_shape = nullptr, + const std::vector* broadcast_dims = nullptr) { + auto output_shape__ = output_shape + ? _fbb.CreateVectorOfStructs(*output_shape) + : 0; + auto broadcast_dims__ = + broadcast_dims ? _fbb.CreateVector(*broadcast_dims) : 0; return nvfuser::serde::CreateBroadcastInDimSymbolic( - _fbb, - output_shape__, - broadcast_dims__); + _fbb, output_shape__, broadcast_dims__); } struct Dtype FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1936,24 +2025,23 @@ struct Dtype FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_DTYPE = 4 }; nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DTYPE, 4) && - verifier.EndTable(); + VerifyField(verifier, VT_DTYPE, 4) && verifier.EndTable(); } }; struct DtypeBuilder { typedef Dtype Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_dtype(nvfuser::serde::DataType dtype) { fbb_.AddElement(Dtype::VT_DTYPE, static_cast(dtype), 0); } - explicit DtypeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit DtypeBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -1964,7 +2052,7 @@ struct DtypeBuilder { }; inline ::flatbuffers::Offset CreateDtype( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { DtypeBuilder builder_(_fbb); builder_.add_dtype(dtype); @@ -1979,22 +2067,21 @@ struct Dimension FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int64_t dim() const { return GetField(VT_DIM, 0); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DIM, 8) && - verifier.EndTable(); + VerifyField(verifier, VT_DIM, 8) && verifier.EndTable(); } }; struct DimensionBuilder { typedef Dimension Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_dim(int64_t dim) { fbb_.AddElement(Dimension::VT_DIM, dim, 0); } - explicit DimensionBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit DimensionBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2005,7 +2092,7 @@ struct DimensionBuilder { }; inline ::flatbuffers::Offset CreateDimension( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, int64_t dim = 0) { DimensionBuilder builder_(_fbb); builder_.add_dim(dim); @@ -2019,8 +2106,8 @@ struct Norm FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_CORRECTION = 6, VT_KEEP_DIM = 8 }; - const ::flatbuffers::Vector *axes() const { - return GetPointer *>(VT_AXES); + const ::flatbuffers::Vector* axes() const { + return GetPointer*>(VT_AXES); } int64_t correction() const { return GetField(VT_CORRECTION, 0); @@ -2028,19 +2115,17 @@ struct Norm FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool keep_dim() const { return GetField(VT_KEEP_DIM, 0) != 0; } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_AXES) && - verifier.VerifyVector(axes()) && - VerifyField(verifier, VT_CORRECTION, 8) && - VerifyField(verifier, VT_KEEP_DIM, 1) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_AXES) && + verifier.VerifyVector(axes()) && + VerifyField(verifier, VT_CORRECTION, 8) && + VerifyField(verifier, VT_KEEP_DIM, 1) && verifier.EndTable(); } }; struct NormBuilder { typedef Norm Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_axes(::flatbuffers::Offset<::flatbuffers::Vector> axes) { fbb_.AddOffset(Norm::VT_AXES, axes); @@ -2049,10 +2134,10 @@ struct NormBuilder { fbb_.AddElement(Norm::VT_CORRECTION, correction, 0); } void add_keep_dim(bool keep_dim) { - fbb_.AddElement(Norm::VT_KEEP_DIM, static_cast(keep_dim), 0); + fbb_.AddElement( + Norm::VT_KEEP_DIM, static_cast(keep_dim), 0); } - explicit NormBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit NormBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2063,7 +2148,7 @@ struct NormBuilder { }; inline ::flatbuffers::Offset CreateNorm( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> axes = 0, int64_t correction = 0, bool keep_dim = false) { @@ -2075,16 +2160,12 @@ inline ::flatbuffers::Offset CreateNorm( } inline ::flatbuffers::Offset CreateNormDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *axes = nullptr, + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* axes = nullptr, int64_t correction = 0, bool keep_dim = false) { auto axes__ = axes ? _fbb.CreateVector(*axes) : 0; - return nvfuser::serde::CreateNorm( - _fbb, - axes__, - correction, - keep_dim); + return nvfuser::serde::CreateNorm(_fbb, axes__, correction, keep_dim); } struct Output FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2092,26 +2173,25 @@ struct Output FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_STRIDE_ORDER = 4 }; - const ::flatbuffers::Vector *stride_order() const { - return GetPointer *>(VT_STRIDE_ORDER); + const ::flatbuffers::Vector* stride_order() const { + return GetPointer*>(VT_STRIDE_ORDER); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_STRIDE_ORDER) && - verifier.VerifyVector(stride_order()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_STRIDE_ORDER) && + verifier.VerifyVector(stride_order()) && verifier.EndTable(); } }; struct OutputBuilder { typedef Output Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_stride_order(::flatbuffers::Offset<::flatbuffers::Vector> stride_order) { + void add_stride_order( + ::flatbuffers::Offset<::flatbuffers::Vector> stride_order) { fbb_.AddOffset(Output::VT_STRIDE_ORDER, stride_order); } - explicit OutputBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit OutputBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2122,7 +2202,7 @@ struct OutputBuilder { }; inline ::flatbuffers::Offset CreateOutput( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> stride_order = 0) { OutputBuilder builder_(_fbb); builder_.add_stride_order(stride_order); @@ -2130,12 +2210,11 @@ inline ::flatbuffers::Offset CreateOutput( } inline ::flatbuffers::Offset CreateOutputDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *stride_order = nullptr) { - auto stride_order__ = stride_order ? _fbb.CreateVector(*stride_order) : 0; - return nvfuser::serde::CreateOutput( - _fbb, - stride_order__); + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* stride_order = nullptr) { + auto stride_order__ = + stride_order ? _fbb.CreateVector(*stride_order) : 0; + return nvfuser::serde::CreateOutput(_fbb, stride_order__); } struct Pad FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2143,26 +2222,25 @@ struct Pad FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_PAD_WIDTHS = 4 }; - const ::flatbuffers::Vector *pad_widths() const { - return GetPointer *>(VT_PAD_WIDTHS); + const ::flatbuffers::Vector* pad_widths() const { + return GetPointer*>(VT_PAD_WIDTHS); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_PAD_WIDTHS) && - verifier.VerifyVector(pad_widths()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_PAD_WIDTHS) && + verifier.VerifyVector(pad_widths()) && verifier.EndTable(); } }; struct PadBuilder { typedef Pad Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_pad_widths(::flatbuffers::Offset<::flatbuffers::Vector> pad_widths) { + void add_pad_widths( + ::flatbuffers::Offset<::flatbuffers::Vector> pad_widths) { fbb_.AddOffset(Pad::VT_PAD_WIDTHS, pad_widths); } - explicit PadBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit PadBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2173,7 +2251,7 @@ struct PadBuilder { }; inline ::flatbuffers::Offset CreatePad( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> pad_widths = 0) { PadBuilder builder_(_fbb); builder_.add_pad_widths(pad_widths); @@ -2181,12 +2259,10 @@ inline ::flatbuffers::Offset CreatePad( } inline ::flatbuffers::Offset CreatePadDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *pad_widths = nullptr) { + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* pad_widths = nullptr) { auto pad_widths__ = pad_widths ? _fbb.CreateVector(*pad_widths) : 0; - return nvfuser::serde::CreatePad( - _fbb, - pad_widths__); + return nvfuser::serde::CreatePad(_fbb, pad_widths__); } struct Permute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2194,26 +2270,23 @@ struct Permute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_DIMS = 4 }; - const ::flatbuffers::Vector *dims() const { - return GetPointer *>(VT_DIMS); + const ::flatbuffers::Vector* dims() const { + return GetPointer*>(VT_DIMS); } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_DIMS) && - verifier.VerifyVector(dims()) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DIMS) && + verifier.VerifyVector(dims()) && verifier.EndTable(); } }; struct PermuteBuilder { typedef Permute Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_dims(::flatbuffers::Offset<::flatbuffers::Vector> dims) { fbb_.AddOffset(Permute::VT_DIMS, dims); } - explicit PermuteBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit PermuteBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2224,7 +2297,7 @@ struct PermuteBuilder { }; inline ::flatbuffers::Offset CreatePermute( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> dims = 0) { PermuteBuilder builder_(_fbb); builder_.add_dims(dims); @@ -2232,12 +2305,10 @@ inline ::flatbuffers::Offset CreatePermute( } inline ::flatbuffers::Offset CreatePermuteDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *dims = nullptr) { + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* dims = nullptr) { auto dims__ = dims ? _fbb.CreateVector(*dims) : 0; - return nvfuser::serde::CreatePermute( - _fbb, - dims__); + return nvfuser::serde::CreatePermute(_fbb, dims__); } struct Reduction FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2247,40 +2318,41 @@ struct Reduction FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_KEEP_DIM = 6, VT_DTYPE = 8 }; - const ::flatbuffers::Vector *axes() const { - return GetPointer *>(VT_AXES); + const ::flatbuffers::Vector* axes() const { + return GetPointer*>(VT_AXES); } bool keep_dim() const { return GetField(VT_KEEP_DIM, 0) != 0; } nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_AXES) && - verifier.VerifyVector(axes()) && - VerifyField(verifier, VT_KEEP_DIM, 1) && - VerifyField(verifier, VT_DTYPE, 4) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_AXES) && + verifier.VerifyVector(axes()) && + VerifyField(verifier, VT_KEEP_DIM, 1) && + VerifyField(verifier, VT_DTYPE, 4) && verifier.EndTable(); } }; struct ReductionBuilder { typedef Reduction Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_axes(::flatbuffers::Offset<::flatbuffers::Vector> axes) { fbb_.AddOffset(Reduction::VT_AXES, axes); } void add_keep_dim(bool keep_dim) { - fbb_.AddElement(Reduction::VT_KEEP_DIM, static_cast(keep_dim), 0); + fbb_.AddElement( + Reduction::VT_KEEP_DIM, static_cast(keep_dim), 0); } void add_dtype(nvfuser::serde::DataType dtype) { - fbb_.AddElement(Reduction::VT_DTYPE, static_cast(dtype), 0); + fbb_.AddElement( + Reduction::VT_DTYPE, static_cast(dtype), 0); } - explicit ReductionBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit ReductionBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2291,7 +2363,7 @@ struct ReductionBuilder { }; inline ::flatbuffers::Offset CreateReduction( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> axes = 0, bool keep_dim = false, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { @@ -2303,16 +2375,12 @@ inline ::flatbuffers::Offset CreateReduction( } inline ::flatbuffers::Offset CreateReductionDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *axes = nullptr, + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* axes = nullptr, bool keep_dim = false, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { auto axes__ = axes ? _fbb.CreateVector(*axes) : 0; - return nvfuser::serde::CreateReduction( - _fbb, - axes__, - keep_dim, - dtype); + return nvfuser::serde::CreateReduction(_fbb, axes__, keep_dim, dtype); } struct Reshape FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2321,34 +2389,34 @@ struct Reshape FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_ORIGINAL_SHAPE = 4, VT_NEW_SHAPE = 6 }; - const ::flatbuffers::Vector *original_shape() const { - return GetPointer *>(VT_ORIGINAL_SHAPE); + const ::flatbuffers::Vector* original_shape() const { + return GetPointer*>(VT_ORIGINAL_SHAPE); } - const ::flatbuffers::Vector *new_shape() const { - return GetPointer *>(VT_NEW_SHAPE); + const ::flatbuffers::Vector* new_shape() const { + return GetPointer*>(VT_NEW_SHAPE); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ORIGINAL_SHAPE) && - verifier.VerifyVector(original_shape()) && - VerifyOffset(verifier, VT_NEW_SHAPE) && - verifier.VerifyVector(new_shape()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_ORIGINAL_SHAPE) && + verifier.VerifyVector(original_shape()) && + VerifyOffset(verifier, VT_NEW_SHAPE) && + verifier.VerifyVector(new_shape()) && verifier.EndTable(); } }; struct ReshapeBuilder { typedef Reshape Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_original_shape(::flatbuffers::Offset<::flatbuffers::Vector> original_shape) { + void add_original_shape( + ::flatbuffers::Offset<::flatbuffers::Vector> original_shape) { fbb_.AddOffset(Reshape::VT_ORIGINAL_SHAPE, original_shape); } - void add_new_shape(::flatbuffers::Offset<::flatbuffers::Vector> new_shape) { + void add_new_shape( + ::flatbuffers::Offset<::flatbuffers::Vector> new_shape) { fbb_.AddOffset(Reshape::VT_NEW_SHAPE, new_shape); } - explicit ReshapeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit ReshapeBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2359,7 +2427,7 @@ struct ReshapeBuilder { }; inline ::flatbuffers::Offset CreateReshape( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> original_shape = 0, ::flatbuffers::Offset<::flatbuffers::Vector> new_shape = 0) { ReshapeBuilder builder_(_fbb); @@ -2369,15 +2437,13 @@ inline ::flatbuffers::Offset CreateReshape( } inline ::flatbuffers::Offset CreateReshapeDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *original_shape = nullptr, - const std::vector *new_shape = nullptr) { - auto original_shape__ = original_shape ? _fbb.CreateVector(*original_shape) : 0; + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* original_shape = nullptr, + const std::vector* new_shape = nullptr) { + auto original_shape__ = + original_shape ? _fbb.CreateVector(*original_shape) : 0; auto new_shape__ = new_shape ? _fbb.CreateVector(*new_shape) : 0; - return nvfuser::serde::CreateReshape( - _fbb, - original_shape__, - new_shape__); + return nvfuser::serde::CreateReshape(_fbb, original_shape__, new_shape__); } struct Slice FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2387,42 +2453,43 @@ struct Slice FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_END_INDICES = 6, VT_STRIDES = 8 }; - const ::flatbuffers::Vector *start_indices() const { - return GetPointer *>(VT_START_INDICES); + const ::flatbuffers::Vector* start_indices() const { + return GetPointer*>(VT_START_INDICES); } - const ::flatbuffers::Vector *end_indices() const { - return GetPointer *>(VT_END_INDICES); + const ::flatbuffers::Vector* end_indices() const { + return GetPointer*>(VT_END_INDICES); } - const ::flatbuffers::Vector *strides() const { - return GetPointer *>(VT_STRIDES); + const ::flatbuffers::Vector* strides() const { + return GetPointer*>(VT_STRIDES); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_START_INDICES) && - verifier.VerifyVector(start_indices()) && - VerifyOffset(verifier, VT_END_INDICES) && - verifier.VerifyVector(end_indices()) && - VerifyOffset(verifier, VT_STRIDES) && - verifier.VerifyVector(strides()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_START_INDICES) && + verifier.VerifyVector(start_indices()) && + VerifyOffset(verifier, VT_END_INDICES) && + verifier.VerifyVector(end_indices()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && verifier.EndTable(); } }; struct SliceBuilder { typedef Slice Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_start_indices(::flatbuffers::Offset<::flatbuffers::Vector> start_indices) { + void add_start_indices( + ::flatbuffers::Offset<::flatbuffers::Vector> start_indices) { fbb_.AddOffset(Slice::VT_START_INDICES, start_indices); } - void add_end_indices(::flatbuffers::Offset<::flatbuffers::Vector> end_indices) { + void add_end_indices( + ::flatbuffers::Offset<::flatbuffers::Vector> end_indices) { fbb_.AddOffset(Slice::VT_END_INDICES, end_indices); } - void add_strides(::flatbuffers::Offset<::flatbuffers::Vector> strides) { + void add_strides( + ::flatbuffers::Offset<::flatbuffers::Vector> strides) { fbb_.AddOffset(Slice::VT_STRIDES, strides); } - explicit SliceBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit SliceBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2433,7 +2500,7 @@ struct SliceBuilder { }; inline ::flatbuffers::Offset CreateSlice( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> start_indices = 0, ::flatbuffers::Offset<::flatbuffers::Vector> end_indices = 0, ::flatbuffers::Offset<::flatbuffers::Vector> strides = 0) { @@ -2445,18 +2512,17 @@ inline ::flatbuffers::Offset CreateSlice( } inline ::flatbuffers::Offset CreateSliceDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *start_indices = nullptr, - const std::vector *end_indices = nullptr, - const std::vector *strides = nullptr) { - auto start_indices__ = start_indices ? _fbb.CreateVector(*start_indices) : 0; - auto end_indices__ = end_indices ? _fbb.CreateVector(*end_indices) : 0; + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* start_indices = nullptr, + const std::vector* end_indices = nullptr, + const std::vector* strides = nullptr) { + auto start_indices__ = + start_indices ? _fbb.CreateVector(*start_indices) : 0; + auto end_indices__ = + end_indices ? _fbb.CreateVector(*end_indices) : 0; auto strides__ = strides ? _fbb.CreateVector(*strides) : 0; return nvfuser::serde::CreateSlice( - _fbb, - start_indices__, - end_indices__, - strides__); + _fbb, start_indices__, end_indices__, strides__); } struct Squeeze FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2465,34 +2531,34 @@ struct Squeeze FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_ORIGINAL_SHAPE = 4, VT_SQUEEZE_DIMS = 6 }; - const ::flatbuffers::Vector *original_shape() const { - return GetPointer *>(VT_ORIGINAL_SHAPE); + const ::flatbuffers::Vector* original_shape() const { + return GetPointer*>(VT_ORIGINAL_SHAPE); } - const ::flatbuffers::Vector *squeeze_dims() const { - return GetPointer *>(VT_SQUEEZE_DIMS); + const ::flatbuffers::Vector* squeeze_dims() const { + return GetPointer*>(VT_SQUEEZE_DIMS); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ORIGINAL_SHAPE) && - verifier.VerifyVector(original_shape()) && - VerifyOffset(verifier, VT_SQUEEZE_DIMS) && - verifier.VerifyVector(squeeze_dims()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_ORIGINAL_SHAPE) && + verifier.VerifyVector(original_shape()) && + VerifyOffset(verifier, VT_SQUEEZE_DIMS) && + verifier.VerifyVector(squeeze_dims()) && verifier.EndTable(); } }; struct SqueezeBuilder { typedef Squeeze Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_original_shape(::flatbuffers::Offset<::flatbuffers::Vector> original_shape) { + void add_original_shape( + ::flatbuffers::Offset<::flatbuffers::Vector> original_shape) { fbb_.AddOffset(Squeeze::VT_ORIGINAL_SHAPE, original_shape); } - void add_squeeze_dims(::flatbuffers::Offset<::flatbuffers::Vector> squeeze_dims) { + void add_squeeze_dims( + ::flatbuffers::Offset<::flatbuffers::Vector> squeeze_dims) { fbb_.AddOffset(Squeeze::VT_SQUEEZE_DIMS, squeeze_dims); } - explicit SqueezeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit SqueezeBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2503,7 +2569,7 @@ struct SqueezeBuilder { }; inline ::flatbuffers::Offset CreateSqueeze( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> original_shape = 0, ::flatbuffers::Offset<::flatbuffers::Vector> squeeze_dims = 0) { SqueezeBuilder builder_(_fbb); @@ -2513,15 +2579,14 @@ inline ::flatbuffers::Offset CreateSqueeze( } inline ::flatbuffers::Offset CreateSqueezeDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *original_shape = nullptr, - const std::vector *squeeze_dims = nullptr) { - auto original_shape__ = original_shape ? _fbb.CreateVector(*original_shape) : 0; - auto squeeze_dims__ = squeeze_dims ? _fbb.CreateVector(*squeeze_dims) : 0; - return nvfuser::serde::CreateSqueeze( - _fbb, - original_shape__, - squeeze_dims__); + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* original_shape = nullptr, + const std::vector* squeeze_dims = nullptr) { + auto original_shape__ = + original_shape ? _fbb.CreateVector(*original_shape) : 0; + auto squeeze_dims__ = + squeeze_dims ? _fbb.CreateVector(*squeeze_dims) : 0; + return nvfuser::serde::CreateSqueeze(_fbb, original_shape__, squeeze_dims__); } struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2532,48 +2597,48 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_DTYPE = 8, VT_IS_CPU = 10 }; - const ::flatbuffers::Vector *sizes() const { - return GetPointer *>(VT_SIZES); + const ::flatbuffers::Vector* sizes() const { + return GetPointer*>(VT_SIZES); } - const ::flatbuffers::Vector *contiguity() const { - return GetPointer *>(VT_CONTIGUITY); + const ::flatbuffers::Vector* contiguity() const { + return GetPointer*>(VT_CONTIGUITY); } nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } bool is_cpu() const { return GetField(VT_IS_CPU, 0) != 0; } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_SIZES) && - verifier.VerifyVector(sizes()) && - VerifyOffset(verifier, VT_CONTIGUITY) && - verifier.VerifyVector(contiguity()) && - VerifyField(verifier, VT_DTYPE, 4) && - VerifyField(verifier, VT_IS_CPU, 1) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SIZES) && + verifier.VerifyVector(sizes()) && + VerifyOffset(verifier, VT_CONTIGUITY) && + verifier.VerifyVector(contiguity()) && + VerifyField(verifier, VT_DTYPE, 4) && + VerifyField(verifier, VT_IS_CPU, 1) && verifier.EndTable(); } }; struct TensorBuilder { typedef Tensor Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_sizes(::flatbuffers::Offset<::flatbuffers::Vector> sizes) { fbb_.AddOffset(Tensor::VT_SIZES, sizes); } - void add_contiguity(::flatbuffers::Offset<::flatbuffers::Vector> contiguity) { + void add_contiguity( + ::flatbuffers::Offset<::flatbuffers::Vector> contiguity) { fbb_.AddOffset(Tensor::VT_CONTIGUITY, contiguity); } void add_dtype(nvfuser::serde::DataType dtype) { fbb_.AddElement(Tensor::VT_DTYPE, static_cast(dtype), 0); } void add_is_cpu(bool is_cpu) { - fbb_.AddElement(Tensor::VT_IS_CPU, static_cast(is_cpu), 0); + fbb_.AddElement( + Tensor::VT_IS_CPU, static_cast(is_cpu), 0); } - explicit TensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit TensorBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2584,7 +2649,7 @@ struct TensorBuilder { }; inline ::flatbuffers::Offset CreateTensor( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> sizes = 0, ::flatbuffers::Offset<::flatbuffers::Vector> contiguity = 0, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double, @@ -2598,19 +2663,15 @@ inline ::flatbuffers::Offset CreateTensor( } inline ::flatbuffers::Offset CreateTensorDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *sizes = nullptr, - const std::vector *contiguity = nullptr, + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* sizes = nullptr, + const std::vector* contiguity = nullptr, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double, bool is_cpu = false) { auto sizes__ = sizes ? _fbb.CreateVector(*sizes) : 0; auto contiguity__ = contiguity ? _fbb.CreateVector(*contiguity) : 0; return nvfuser::serde::CreateTensor( - _fbb, - sizes__, - contiguity__, - dtype, - is_cpu); + _fbb, sizes__, contiguity__, dtype, is_cpu); } struct TensorCreation FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2619,33 +2680,33 @@ struct TensorCreation FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_SHAPE = 4, VT_DTYPE = 6 }; - const ::flatbuffers::Vector *shape() const { - return GetPointer *>(VT_SHAPE); + const ::flatbuffers::Vector* shape() const { + return GetPointer*>(VT_SHAPE); } nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && - VerifyField(verifier, VT_DTYPE, 4) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + VerifyField(verifier, VT_DTYPE, 4) && verifier.EndTable(); } }; struct TensorCreationBuilder { typedef TensorCreation Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_shape(::flatbuffers::Offset<::flatbuffers::Vector> shape) { fbb_.AddOffset(TensorCreation::VT_SHAPE, shape); } void add_dtype(nvfuser::serde::DataType dtype) { - fbb_.AddElement(TensorCreation::VT_DTYPE, static_cast(dtype), 0); + fbb_.AddElement( + TensorCreation::VT_DTYPE, static_cast(dtype), 0); } - explicit TensorCreationBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit TensorCreationBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2656,7 +2717,7 @@ struct TensorCreationBuilder { }; inline ::flatbuffers::Offset CreateTensorCreation( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset<::flatbuffers::Vector> shape = 0, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { TensorCreationBuilder builder_(_fbb); @@ -2666,49 +2727,50 @@ inline ::flatbuffers::Offset CreateTensorCreation( } inline ::flatbuffers::Offset CreateTensorCreationDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *shape = nullptr, + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* shape = nullptr, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; - return nvfuser::serde::CreateTensorCreation( - _fbb, - shape__, - dtype); + return nvfuser::serde::CreateTensorCreation(_fbb, shape__, dtype); } -struct TensorCreationSymbolic FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { +struct TensorCreationSymbolic FLATBUFFERS_FINAL_CLASS + : private ::flatbuffers::Table { typedef TensorCreationSymbolicBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_SHAPE = 4, VT_DTYPE = 6 }; - const ::flatbuffers::Vector *shape() const { - return GetPointer *>(VT_SHAPE); + const ::flatbuffers::Vector* shape() const { + return GetPointer< + const ::flatbuffers::Vector*>(VT_SHAPE); } nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && - VerifyField(verifier, VT_DTYPE, 4) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + VerifyField(verifier, VT_DTYPE, 4) && verifier.EndTable(); } }; struct TensorCreationSymbolicBuilder { typedef TensorCreationSymbolic Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_shape(::flatbuffers::Offset<::flatbuffers::Vector> shape) { + void add_shape( + ::flatbuffers::Offset<::flatbuffers::Vector> + shape) { fbb_.AddOffset(TensorCreationSymbolic::VT_SHAPE, shape); } void add_dtype(nvfuser::serde::DataType dtype) { - fbb_.AddElement(TensorCreationSymbolic::VT_DTYPE, static_cast(dtype), 0); + fbb_.AddElement( + TensorCreationSymbolic::VT_DTYPE, static_cast(dtype), 0); } - explicit TensorCreationSymbolicBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit TensorCreationSymbolicBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2718,9 +2780,11 @@ struct TensorCreationSymbolicBuilder { } }; -inline ::flatbuffers::Offset CreateTensorCreationSymbolic( - ::flatbuffers::FlatBufferBuilder &_fbb, - ::flatbuffers::Offset<::flatbuffers::Vector> shape = 0, +inline ::flatbuffers::Offset +CreateTensorCreationSymbolic( + ::flatbuffers::FlatBufferBuilder& _fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> + shape = 0, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { TensorCreationSymbolicBuilder builder_(_fbb); builder_.add_dtype(dtype); @@ -2728,15 +2792,14 @@ inline ::flatbuffers::Offset CreateTensorCreationSymboli return builder_.Finish(); } -inline ::flatbuffers::Offset CreateTensorCreationSymbolicDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *shape = nullptr, +inline ::flatbuffers::Offset +CreateTensorCreationSymbolicDirect( + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* shape = nullptr, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { - auto shape__ = shape ? _fbb.CreateVectorOfStructs(*shape) : 0; - return nvfuser::serde::CreateTensorCreationSymbolic( - _fbb, - shape__, - dtype); + auto shape__ = + shape ? _fbb.CreateVectorOfStructs(*shape) : 0; + return nvfuser::serde::CreateTensorCreationSymbolic(_fbb, shape__, dtype); } struct Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2745,24 +2808,23 @@ struct Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_DTYPE = 4 }; nvfuser::serde::DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + return static_cast( + GetField(VT_DTYPE, 0)); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DTYPE, 4) && - verifier.EndTable(); + VerifyField(verifier, VT_DTYPE, 4) && verifier.EndTable(); } }; struct VectorBuilder { typedef Vector Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_dtype(nvfuser::serde::DataType dtype) { fbb_.AddElement(Vector::VT_DTYPE, static_cast(dtype), 0); } - explicit VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit VectorBuilder(::flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2773,7 +2835,7 @@ struct VectorBuilder { }; inline ::flatbuffers::Offset CreateVector( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, nvfuser::serde::DataType dtype = nvfuser::serde::DataType_Double) { VectorBuilder builder_(_fbb); builder_.add_dtype(dtype); @@ -2812,50 +2874,62 @@ struct FusionExecutor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int64_t fusion_id_counter() const { return GetField(VT_FUSION_ID_COUNTER, 0); } - const ::flatbuffers::String *kernel_code() const { - return GetPointer(VT_KERNEL_CODE); + const ::flatbuffers::String* kernel_code() const { + return GetPointer(VT_KERNEL_CODE); } - const ::flatbuffers::Vector *executor_entry_lookup_keys() const { - return GetPointer *>(VT_EXECUTOR_ENTRY_LOOKUP_KEYS); + const ::flatbuffers::Vector* executor_entry_lookup_keys() const { + return GetPointer*>( + VT_EXECUTOR_ENTRY_LOOKUP_KEYS); } - const ::flatbuffers::Vector<::flatbuffers::Offset> *executor_entry_lookup_values() const { - return GetPointer> *>(VT_EXECUTOR_ENTRY_LOOKUP_VALUES); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + executor_entry_lookup_values() const { + return GetPointer>*>( + VT_EXECUTOR_ENTRY_LOOKUP_VALUES); } nvfuser::serde::DataType index_type() const { - return static_cast(GetField(VT_INDEX_TYPE, 0)); + return static_cast( + GetField(VT_INDEX_TYPE, 0)); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DEVICE_SMEM_LIMIT, 8) && - VerifyField(verifier, VT_BLOCK_SIZE_HIGH_WATER_MARK, 8) && - VerifyField(verifier, VT_MAXRREGCOUNT_HIGH_WATER_MARK, 8) && - VerifyField(verifier, VT_WARP_SIZE, 8) && - VerifyField(verifier, VT_FUSION_ID, 8) && - VerifyField(verifier, VT_FUSION_ID_COUNTER, 8) && - VerifyOffset(verifier, VT_KERNEL_CODE) && - verifier.VerifyString(kernel_code()) && - VerifyOffset(verifier, VT_EXECUTOR_ENTRY_LOOKUP_KEYS) && - verifier.VerifyVector(executor_entry_lookup_keys()) && - VerifyOffset(verifier, VT_EXECUTOR_ENTRY_LOOKUP_VALUES) && - verifier.VerifyVector(executor_entry_lookup_values()) && - verifier.VerifyVectorOfTables(executor_entry_lookup_values()) && - VerifyField(verifier, VT_INDEX_TYPE, 4) && - verifier.EndTable(); + VerifyField(verifier, VT_DEVICE_SMEM_LIMIT, 8) && + VerifyField(verifier, VT_BLOCK_SIZE_HIGH_WATER_MARK, 8) && + VerifyField(verifier, VT_MAXRREGCOUNT_HIGH_WATER_MARK, 8) && + VerifyField(verifier, VT_WARP_SIZE, 8) && + VerifyField(verifier, VT_FUSION_ID, 8) && + VerifyField(verifier, VT_FUSION_ID_COUNTER, 8) && + VerifyOffset(verifier, VT_KERNEL_CODE) && + verifier.VerifyString(kernel_code()) && + VerifyOffset(verifier, VT_EXECUTOR_ENTRY_LOOKUP_KEYS) && + verifier.VerifyVector(executor_entry_lookup_keys()) && + VerifyOffset(verifier, VT_EXECUTOR_ENTRY_LOOKUP_VALUES) && + verifier.VerifyVector(executor_entry_lookup_values()) && + verifier.VerifyVectorOfTables(executor_entry_lookup_values()) && + VerifyField(verifier, VT_INDEX_TYPE, 4) && verifier.EndTable(); } }; struct FusionExecutorBuilder { typedef FusionExecutor Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_device_smem_limit(int64_t device_smem_limit) { - fbb_.AddElement(FusionExecutor::VT_DEVICE_SMEM_LIMIT, device_smem_limit, 0); + fbb_.AddElement( + FusionExecutor::VT_DEVICE_SMEM_LIMIT, device_smem_limit, 0); } void add_block_size_high_water_mark(int64_t block_size_high_water_mark) { - fbb_.AddElement(FusionExecutor::VT_BLOCK_SIZE_HIGH_WATER_MARK, block_size_high_water_mark, 0); + fbb_.AddElement( + FusionExecutor::VT_BLOCK_SIZE_HIGH_WATER_MARK, + block_size_high_water_mark, + 0); } void add_maxrregcount_high_water_mark(int64_t maxrregcount_high_water_mark) { - fbb_.AddElement(FusionExecutor::VT_MAXRREGCOUNT_HIGH_WATER_MARK, maxrregcount_high_water_mark, 0); + fbb_.AddElement( + FusionExecutor::VT_MAXRREGCOUNT_HIGH_WATER_MARK, + maxrregcount_high_water_mark, + 0); } void add_warp_size(int64_t warp_size) { fbb_.AddElement(FusionExecutor::VT_WARP_SIZE, warp_size, 0); @@ -2864,22 +2938,34 @@ struct FusionExecutorBuilder { fbb_.AddElement(FusionExecutor::VT_FUSION_ID, fusion_id, 0); } void add_fusion_id_counter(int64_t fusion_id_counter) { - fbb_.AddElement(FusionExecutor::VT_FUSION_ID_COUNTER, fusion_id_counter, 0); + fbb_.AddElement( + FusionExecutor::VT_FUSION_ID_COUNTER, fusion_id_counter, 0); } - void add_kernel_code(::flatbuffers::Offset<::flatbuffers::String> kernel_code) { + void add_kernel_code( + ::flatbuffers::Offset<::flatbuffers::String> kernel_code) { fbb_.AddOffset(FusionExecutor::VT_KERNEL_CODE, kernel_code); } - void add_executor_entry_lookup_keys(::flatbuffers::Offset<::flatbuffers::Vector> executor_entry_lookup_keys) { - fbb_.AddOffset(FusionExecutor::VT_EXECUTOR_ENTRY_LOOKUP_KEYS, executor_entry_lookup_keys); - } - void add_executor_entry_lookup_values(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> executor_entry_lookup_values) { - fbb_.AddOffset(FusionExecutor::VT_EXECUTOR_ENTRY_LOOKUP_VALUES, executor_entry_lookup_values); + void add_executor_entry_lookup_keys( + ::flatbuffers::Offset<::flatbuffers::Vector> + executor_entry_lookup_keys) { + fbb_.AddOffset( + FusionExecutor::VT_EXECUTOR_ENTRY_LOOKUP_KEYS, + executor_entry_lookup_keys); + } + void add_executor_entry_lookup_values( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> + executor_entry_lookup_values) { + fbb_.AddOffset( + FusionExecutor::VT_EXECUTOR_ENTRY_LOOKUP_VALUES, + executor_entry_lookup_values); } void add_index_type(nvfuser::serde::DataType index_type) { - fbb_.AddElement(FusionExecutor::VT_INDEX_TYPE, static_cast(index_type), 0); + fbb_.AddElement( + FusionExecutor::VT_INDEX_TYPE, static_cast(index_type), 0); } - explicit FusionExecutorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FusionExecutorBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2890,7 +2976,7 @@ struct FusionExecutorBuilder { }; inline ::flatbuffers::Offset CreateFusionExecutor( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, int64_t device_smem_limit = 0, int64_t block_size_high_water_mark = 0, int64_t maxrregcount_high_water_mark = 0, @@ -2898,8 +2984,10 @@ inline ::flatbuffers::Offset CreateFusionExecutor( int64_t fusion_id = 0, int64_t fusion_id_counter = 0, ::flatbuffers::Offset<::flatbuffers::String> kernel_code = 0, - ::flatbuffers::Offset<::flatbuffers::Vector> executor_entry_lookup_keys = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> executor_entry_lookup_values = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> + executor_entry_lookup_keys = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset< + nvfuser::serde::ExecutorEntry>>> executor_entry_lookup_values = 0, nvfuser::serde::DataType index_type = nvfuser::serde::DataType_Double) { FusionExecutorBuilder builder_(_fbb); builder_.add_fusion_id_counter(fusion_id_counter); @@ -2916,20 +3004,26 @@ inline ::flatbuffers::Offset CreateFusionExecutor( } inline ::flatbuffers::Offset CreateFusionExecutorDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, int64_t device_smem_limit = 0, int64_t block_size_high_water_mark = 0, int64_t maxrregcount_high_water_mark = 0, int64_t warp_size = 0, int64_t fusion_id = 0, int64_t fusion_id_counter = 0, - const char *kernel_code = nullptr, - const std::vector *executor_entry_lookup_keys = nullptr, - const std::vector<::flatbuffers::Offset> *executor_entry_lookup_values = nullptr, + const char* kernel_code = nullptr, + const std::vector* executor_entry_lookup_keys = nullptr, + const std::vector<::flatbuffers::Offset>* + executor_entry_lookup_values = nullptr, nvfuser::serde::DataType index_type = nvfuser::serde::DataType_Double) { auto kernel_code__ = kernel_code ? _fbb.CreateString(kernel_code) : 0; - auto executor_entry_lookup_keys__ = executor_entry_lookup_keys ? _fbb.CreateVector(*executor_entry_lookup_keys) : 0; - auto executor_entry_lookup_values__ = executor_entry_lookup_values ? _fbb.CreateVector<::flatbuffers::Offset>(*executor_entry_lookup_values) : 0; + auto executor_entry_lookup_keys__ = executor_entry_lookup_keys + ? _fbb.CreateVector(*executor_entry_lookup_keys) + : 0; + auto executor_entry_lookup_values__ = executor_entry_lookup_values + ? _fbb.CreateVector<::flatbuffers::Offset>( + *executor_entry_lookup_values) + : 0; return nvfuser::serde::CreateFusionExecutor( _fbb, device_smem_limit, @@ -2944,41 +3038,45 @@ inline ::flatbuffers::Offset CreateFusionExecutorDirect( index_type); } -struct FusionKernelRuntime FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { +struct FusionKernelRuntime FLATBUFFERS_FINAL_CLASS + : private ::flatbuffers::Table { typedef FusionKernelRuntimeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_ARGS = 4, VT_EXECUTORS = 6 }; - const nvfuser::serde::KernelArgumentHolder *args() const { - return GetPointer(VT_ARGS); + const nvfuser::serde::KernelArgumentHolder* args() const { + return GetPointer(VT_ARGS); } - const ::flatbuffers::Vector<::flatbuffers::Offset> *executors() const { - return GetPointer> *>(VT_EXECUTORS); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + executors() const { + return GetPointer>*>(VT_EXECUTORS); } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ARGS) && - verifier.VerifyTable(args()) && - VerifyOffset(verifier, VT_EXECUTORS) && - verifier.VerifyVector(executors()) && - verifier.VerifyVectorOfTables(executors()) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ARGS) && + verifier.VerifyTable(args()) && VerifyOffset(verifier, VT_EXECUTORS) && + verifier.VerifyVector(executors()) && + verifier.VerifyVectorOfTables(executors()) && verifier.EndTable(); } }; struct FusionKernelRuntimeBuilder { typedef FusionKernelRuntime Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_args(::flatbuffers::Offset args) { + void add_args( + ::flatbuffers::Offset args) { fbb_.AddOffset(FusionKernelRuntime::VT_ARGS, args); } - void add_executors(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> executors) { + void add_executors( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> executors) { fbb_.AddOffset(FusionKernelRuntime::VT_EXECUTORS, executors); } - explicit FusionKernelRuntimeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FusionKernelRuntimeBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -2989,24 +3087,27 @@ struct FusionKernelRuntimeBuilder { }; inline ::flatbuffers::Offset CreateFusionKernelRuntime( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset args = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> executors = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> executors = 0) { FusionKernelRuntimeBuilder builder_(_fbb); builder_.add_executors(executors); builder_.add_args(args); return builder_.Finish(); } -inline ::flatbuffers::Offset CreateFusionKernelRuntimeDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, +inline ::flatbuffers::Offset +CreateFusionKernelRuntimeDirect( + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset args = 0, - const std::vector<::flatbuffers::Offset> *executors = nullptr) { - auto executors__ = executors ? _fbb.CreateVector<::flatbuffers::Offset>(*executors) : 0; - return nvfuser::serde::CreateFusionKernelRuntime( - _fbb, - args, - executors__); + const std::vector<::flatbuffers::Offset>* + executors = nullptr) { + auto executors__ = executors + ? _fbb.CreateVector< + ::flatbuffers::Offset>(*executors) + : 0; + return nvfuser::serde::CreateFusionKernelRuntime(_fbb, args, executors__); } struct InputsIdLookup FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -3024,52 +3125,71 @@ struct InputsIdLookup FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { uint64_t current_id() const { return GetField(VT_CURRENT_ID, 0); } - const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *lru_cache() const { - return GetPointer> *>(VT_LRU_CACHE); + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* + lru_cache() const { + return GetPointer>*>(VT_LRU_CACHE); } - const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *encoding_lookup_keys() const { - return GetPointer> *>(VT_ENCODING_LOOKUP_KEYS); + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* + encoding_lookup_keys() const { + return GetPointer>*>( + VT_ENCODING_LOOKUP_KEYS); } - const ::flatbuffers::Vector *encoding_lookup_values() const { - return GetPointer *>(VT_ENCODING_LOOKUP_VALUES); + const ::flatbuffers::Vector* + encoding_lookup_values() const { + return GetPointer< + const ::flatbuffers::Vector*>( + VT_ENCODING_LOOKUP_VALUES); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_MAX_CACHE_SIZE, 8) && - VerifyField(verifier, VT_CURRENT_ID, 8) && - VerifyOffset(verifier, VT_LRU_CACHE) && - verifier.VerifyVector(lru_cache()) && - verifier.VerifyVectorOfStrings(lru_cache()) && - VerifyOffset(verifier, VT_ENCODING_LOOKUP_KEYS) && - verifier.VerifyVector(encoding_lookup_keys()) && - verifier.VerifyVectorOfStrings(encoding_lookup_keys()) && - VerifyOffset(verifier, VT_ENCODING_LOOKUP_VALUES) && - verifier.VerifyVector(encoding_lookup_values()) && - verifier.EndTable(); + VerifyField(verifier, VT_MAX_CACHE_SIZE, 8) && + VerifyField(verifier, VT_CURRENT_ID, 8) && + VerifyOffset(verifier, VT_LRU_CACHE) && + verifier.VerifyVector(lru_cache()) && + verifier.VerifyVectorOfStrings(lru_cache()) && + VerifyOffset(verifier, VT_ENCODING_LOOKUP_KEYS) && + verifier.VerifyVector(encoding_lookup_keys()) && + verifier.VerifyVectorOfStrings(encoding_lookup_keys()) && + VerifyOffset(verifier, VT_ENCODING_LOOKUP_VALUES) && + verifier.VerifyVector(encoding_lookup_values()) && verifier.EndTable(); } }; struct InputsIdLookupBuilder { typedef InputsIdLookup Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_max_cache_size(uint64_t max_cache_size) { - fbb_.AddElement(InputsIdLookup::VT_MAX_CACHE_SIZE, max_cache_size, 0); + fbb_.AddElement( + InputsIdLookup::VT_MAX_CACHE_SIZE, max_cache_size, 0); } void add_current_id(uint64_t current_id) { fbb_.AddElement(InputsIdLookup::VT_CURRENT_ID, current_id, 0); } - void add_lru_cache(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> lru_cache) { + void add_lru_cache( + ::flatbuffers::Offset< + ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> + lru_cache) { fbb_.AddOffset(InputsIdLookup::VT_LRU_CACHE, lru_cache); } - void add_encoding_lookup_keys(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> encoding_lookup_keys) { - fbb_.AddOffset(InputsIdLookup::VT_ENCODING_LOOKUP_KEYS, encoding_lookup_keys); - } - void add_encoding_lookup_values(::flatbuffers::Offset<::flatbuffers::Vector> encoding_lookup_values) { - fbb_.AddOffset(InputsIdLookup::VT_ENCODING_LOOKUP_VALUES, encoding_lookup_values); - } - explicit InputsIdLookupBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + void add_encoding_lookup_keys( + ::flatbuffers::Offset< + ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> + encoding_lookup_keys) { + fbb_.AddOffset( + InputsIdLookup::VT_ENCODING_LOOKUP_KEYS, encoding_lookup_keys); + } + void add_encoding_lookup_values( + ::flatbuffers::Offset< + ::flatbuffers::Vector> + encoding_lookup_values) { + fbb_.AddOffset( + InputsIdLookup::VT_ENCODING_LOOKUP_VALUES, encoding_lookup_values); + } + explicit InputsIdLookupBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -3080,12 +3200,15 @@ struct InputsIdLookupBuilder { }; inline ::flatbuffers::Offset CreateInputsIdLookup( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t max_cache_size = 0, uint64_t current_id = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> lru_cache = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> encoding_lookup_keys = 0, - ::flatbuffers::Offset<::flatbuffers::Vector> encoding_lookup_values = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset<::flatbuffers::String>>> lru_cache = 0, + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset<::flatbuffers::String>>> encoding_lookup_keys = 0, + ::flatbuffers::Offset<::flatbuffers::Vector< + const nvfuser::serde::EncodingEntry*>> encoding_lookup_values = 0) { InputsIdLookupBuilder builder_(_fbb); builder_.add_current_id(current_id); builder_.add_max_cache_size(max_cache_size); @@ -3096,15 +3219,27 @@ inline ::flatbuffers::Offset CreateInputsIdLookup( } inline ::flatbuffers::Offset CreateInputsIdLookupDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t max_cache_size = 0, uint64_t current_id = 0, - const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *lru_cache = nullptr, - const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *encoding_lookup_keys = nullptr, - const std::vector *encoding_lookup_values = nullptr) { - auto lru_cache__ = lru_cache ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*lru_cache) : 0; - auto encoding_lookup_keys__ = encoding_lookup_keys ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*encoding_lookup_keys) : 0; - auto encoding_lookup_values__ = encoding_lookup_values ? _fbb.CreateVectorOfStructs(*encoding_lookup_values) : 0; + const std::vector<::flatbuffers::Offset<::flatbuffers::String>>* lru_cache = + nullptr, + const std::vector<::flatbuffers::Offset<::flatbuffers::String>>* + encoding_lookup_keys = nullptr, + const std::vector* encoding_lookup_values = + nullptr) { + auto lru_cache__ = lru_cache + ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>( + *lru_cache) + : 0; + auto encoding_lookup_keys__ = encoding_lookup_keys + ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>( + *encoding_lookup_keys) + : 0; + auto encoding_lookup_values__ = encoding_lookup_values + ? _fbb.CreateVectorOfStructs( + *encoding_lookup_values) + : 0; return nvfuser::serde::CreateInputsIdLookup( _fbb, max_cache_size, @@ -3127,35 +3262,44 @@ struct KernelRuntimes FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool has_dynamic_transform_info() const { return GetField(VT_HAS_DYNAMIC_TRANSFORM_INFO, 0) != 0; } - const ::flatbuffers::Vector<::flatbuffers::Offset> *runtimes() const { - return GetPointer> *>(VT_RUNTIMES); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + runtimes() const { + return GetPointer>*>( + VT_RUNTIMES); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DEVICE_ID, 8) && - VerifyField(verifier, VT_HAS_DYNAMIC_TRANSFORM_INFO, 1) && - VerifyOffset(verifier, VT_RUNTIMES) && - verifier.VerifyVector(runtimes()) && - verifier.VerifyVectorOfTables(runtimes()) && - verifier.EndTable(); + VerifyField(verifier, VT_DEVICE_ID, 8) && + VerifyField(verifier, VT_HAS_DYNAMIC_TRANSFORM_INFO, 1) && + VerifyOffset(verifier, VT_RUNTIMES) && + verifier.VerifyVector(runtimes()) && + verifier.VerifyVectorOfTables(runtimes()) && verifier.EndTable(); } }; struct KernelRuntimesBuilder { typedef KernelRuntimes Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_device_id(uint64_t device_id) { fbb_.AddElement(KernelRuntimes::VT_DEVICE_ID, device_id, 0); } void add_has_dynamic_transform_info(bool has_dynamic_transform_info) { - fbb_.AddElement(KernelRuntimes::VT_HAS_DYNAMIC_TRANSFORM_INFO, static_cast(has_dynamic_transform_info), 0); - } - void add_runtimes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> runtimes) { + fbb_.AddElement( + KernelRuntimes::VT_HAS_DYNAMIC_TRANSFORM_INFO, + static_cast(has_dynamic_transform_info), + 0); + } + void add_runtimes( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> + runtimes) { fbb_.AddOffset(KernelRuntimes::VT_RUNTIMES, runtimes); } - explicit KernelRuntimesBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit KernelRuntimesBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -3166,10 +3310,12 @@ struct KernelRuntimesBuilder { }; inline ::flatbuffers::Offset CreateKernelRuntimes( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t device_id = 0, bool has_dynamic_transform_info = false, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> runtimes = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> runtimes = + 0) { KernelRuntimesBuilder builder_(_fbb); builder_.add_device_id(device_id); builder_.add_runtimes(runtimes); @@ -3178,19 +3324,23 @@ inline ::flatbuffers::Offset CreateKernelRuntimes( } inline ::flatbuffers::Offset CreateKernelRuntimesDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t device_id = 0, bool has_dynamic_transform_info = false, - const std::vector<::flatbuffers::Offset> *runtimes = nullptr) { - auto runtimes__ = runtimes ? _fbb.CreateVector<::flatbuffers::Offset>(*runtimes) : 0; + const std::vector< + ::flatbuffers::Offset>* runtimes = + nullptr) { + auto runtimes__ = runtimes + ? _fbb.CreateVector< + ::flatbuffers::Offset>( + *runtimes) + : 0; return nvfuser::serde::CreateKernelRuntimes( - _fbb, - device_id, - has_dynamic_transform_info, - runtimes__); + _fbb, device_id, has_dynamic_transform_info, runtimes__); } -struct FusionExecutorCache FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { +struct FusionExecutorCache FLATBUFFERS_FINAL_CLASS + : private ::flatbuffers::Table { typedef FusionExecutorCacheBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_INPUTS_CACHE = 4, @@ -3198,51 +3348,66 @@ struct FusionExecutorCache FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Tabl VT_KERNEL_CACHE_KEYS = 8, VT_KERNEL_CACHE_VALUES = 10 }; - const nvfuser::serde::InputsIdLookup *inputs_cache() const { - return GetPointer(VT_INPUTS_CACHE); + const nvfuser::serde::InputsIdLookup* inputs_cache() const { + return GetPointer(VT_INPUTS_CACHE); } - const ::flatbuffers::Vector<::flatbuffers::Offset> *kernel_runtimes() const { - return GetPointer> *>(VT_KERNEL_RUNTIMES); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + kernel_runtimes() const { + return GetPointer>*>( + VT_KERNEL_RUNTIMES); } - const ::flatbuffers::Vector *kernel_cache_keys() const { - return GetPointer *>(VT_KERNEL_CACHE_KEYS); + const ::flatbuffers::Vector* kernel_cache_keys() const { + return GetPointer*>( + VT_KERNEL_CACHE_KEYS); } - const ::flatbuffers::Vector *kernel_cache_values() const { - return GetPointer *>(VT_KERNEL_CACHE_VALUES); + const ::flatbuffers::Vector* kernel_cache_values() const { + return GetPointer*>( + VT_KERNEL_CACHE_VALUES); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_INPUTS_CACHE) && - verifier.VerifyTable(inputs_cache()) && - VerifyOffset(verifier, VT_KERNEL_RUNTIMES) && - verifier.VerifyVector(kernel_runtimes()) && - verifier.VerifyVectorOfTables(kernel_runtimes()) && - VerifyOffset(verifier, VT_KERNEL_CACHE_KEYS) && - verifier.VerifyVector(kernel_cache_keys()) && - VerifyOffset(verifier, VT_KERNEL_CACHE_VALUES) && - verifier.VerifyVector(kernel_cache_values()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_INPUTS_CACHE) && + verifier.VerifyTable(inputs_cache()) && + VerifyOffset(verifier, VT_KERNEL_RUNTIMES) && + verifier.VerifyVector(kernel_runtimes()) && + verifier.VerifyVectorOfTables(kernel_runtimes()) && + VerifyOffset(verifier, VT_KERNEL_CACHE_KEYS) && + verifier.VerifyVector(kernel_cache_keys()) && + VerifyOffset(verifier, VT_KERNEL_CACHE_VALUES) && + verifier.VerifyVector(kernel_cache_values()) && verifier.EndTable(); } }; struct FusionExecutorCacheBuilder { typedef FusionExecutorCache Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_inputs_cache(::flatbuffers::Offset inputs_cache) { + void add_inputs_cache( + ::flatbuffers::Offset inputs_cache) { fbb_.AddOffset(FusionExecutorCache::VT_INPUTS_CACHE, inputs_cache); } - void add_kernel_runtimes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> kernel_runtimes) { + void add_kernel_runtimes( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> + kernel_runtimes) { fbb_.AddOffset(FusionExecutorCache::VT_KERNEL_RUNTIMES, kernel_runtimes); } - void add_kernel_cache_keys(::flatbuffers::Offset<::flatbuffers::Vector> kernel_cache_keys) { - fbb_.AddOffset(FusionExecutorCache::VT_KERNEL_CACHE_KEYS, kernel_cache_keys); + void add_kernel_cache_keys( + ::flatbuffers::Offset<::flatbuffers::Vector> + kernel_cache_keys) { + fbb_.AddOffset( + FusionExecutorCache::VT_KERNEL_CACHE_KEYS, kernel_cache_keys); } - void add_kernel_cache_values(::flatbuffers::Offset<::flatbuffers::Vector> kernel_cache_values) { - fbb_.AddOffset(FusionExecutorCache::VT_KERNEL_CACHE_VALUES, kernel_cache_values); + void add_kernel_cache_values( + ::flatbuffers::Offset<::flatbuffers::Vector> + kernel_cache_values) { + fbb_.AddOffset( + FusionExecutorCache::VT_KERNEL_CACHE_VALUES, kernel_cache_values); } - explicit FusionExecutorCacheBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FusionExecutorCacheBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -3253,11 +3418,14 @@ struct FusionExecutorCacheBuilder { }; inline ::flatbuffers::Offset CreateFusionExecutorCache( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset inputs_cache = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> kernel_runtimes = 0, - ::flatbuffers::Offset<::flatbuffers::Vector> kernel_cache_keys = 0, - ::flatbuffers::Offset<::flatbuffers::Vector> kernel_cache_values = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset< + nvfuser::serde::KernelRuntimes>>> kernel_runtimes = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> kernel_cache_keys = + 0, + ::flatbuffers::Offset<::flatbuffers::Vector> kernel_cache_values = + 0) { FusionExecutorCacheBuilder builder_(_fbb); builder_.add_kernel_cache_values(kernel_cache_values); builder_.add_kernel_cache_keys(kernel_cache_keys); @@ -3266,15 +3434,24 @@ inline ::flatbuffers::Offset CreateFusionExecutorCache( return builder_.Finish(); } -inline ::flatbuffers::Offset CreateFusionExecutorCacheDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, +inline ::flatbuffers::Offset +CreateFusionExecutorCacheDirect( + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset inputs_cache = 0, - const std::vector<::flatbuffers::Offset> *kernel_runtimes = nullptr, - const std::vector *kernel_cache_keys = nullptr, - const std::vector *kernel_cache_values = nullptr) { - auto kernel_runtimes__ = kernel_runtimes ? _fbb.CreateVector<::flatbuffers::Offset>(*kernel_runtimes) : 0; - auto kernel_cache_keys__ = kernel_cache_keys ? _fbb.CreateVector(*kernel_cache_keys) : 0; - auto kernel_cache_values__ = kernel_cache_values ? _fbb.CreateVector(*kernel_cache_values) : 0; + const std::vector<::flatbuffers::Offset>* + kernel_runtimes = nullptr, + const std::vector* kernel_cache_keys = nullptr, + const std::vector* kernel_cache_values = nullptr) { + auto kernel_runtimes__ = kernel_runtimes + ? _fbb.CreateVector< + ::flatbuffers::Offset>( + *kernel_runtimes) + : 0; + auto kernel_cache_keys__ = + kernel_cache_keys ? _fbb.CreateVector(*kernel_cache_keys) : 0; + auto kernel_cache_values__ = kernel_cache_values + ? _fbb.CreateVector(*kernel_cache_values) + : 0; return nvfuser::serde::CreateFusionExecutorCache( _fbb, inputs_cache, @@ -3293,198 +3470,283 @@ struct RecordFunctor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_DATA_TYPE = 12, VT_DATA = 14 }; - const ::flatbuffers::Vector *args() const { - return GetPointer *>(VT_ARGS); + const ::flatbuffers::Vector* args() const { + return GetPointer< + const ::flatbuffers::Vector*>(VT_ARGS); } - const ::flatbuffers::Vector *outputs() const { - return GetPointer *>(VT_OUTPUTS); + const ::flatbuffers::Vector* outputs() const { + return GetPointer< + const ::flatbuffers::Vector*>(VT_OUTPUTS); } - const ::flatbuffers::String *name() const { - return GetPointer(VT_NAME); + const ::flatbuffers::String* name() const { + return GetPointer(VT_NAME); } nvfuser::serde::RecordType type() const { - return static_cast(GetField(VT_TYPE, 0)); + return static_cast( + GetField(VT_TYPE, 0)); } nvfuser::serde::RecordData data_type() const { - return static_cast(GetField(VT_DATA_TYPE, 0)); - } - const void *data() const { - return GetPointer(VT_DATA); - } - template const T *data_as() const; - const nvfuser::serde::BatchNorm *data_as_BatchNorm() const { - return data_type() == nvfuser::serde::RecordData_BatchNorm ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Broadcast *data_as_Broadcast() const { - return data_type() == nvfuser::serde::RecordData_Broadcast ? static_cast(data()) : nullptr; - } - const nvfuser::serde::BroadcastInDim *data_as_BroadcastInDim() const { - return data_type() == nvfuser::serde::RecordData_BroadcastInDim ? static_cast(data()) : nullptr; - } - const nvfuser::serde::BroadcastInDimSymbolic *data_as_BroadcastInDimSymbolic() const { - return data_type() == nvfuser::serde::RecordData_BroadcastInDimSymbolic ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Dimension *data_as_Dimension() const { - return data_type() == nvfuser::serde::RecordData_Dimension ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Dtype *data_as_Dtype() const { - return data_type() == nvfuser::serde::RecordData_Dtype ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Norm *data_as_Norm() const { - return data_type() == nvfuser::serde::RecordData_Norm ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Output *data_as_Output() const { - return data_type() == nvfuser::serde::RecordData_Output ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Pad *data_as_Pad() const { - return data_type() == nvfuser::serde::RecordData_Pad ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Permute *data_as_Permute() const { - return data_type() == nvfuser::serde::RecordData_Permute ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Slice *data_as_Slice() const { - return data_type() == nvfuser::serde::RecordData_Slice ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Squeeze *data_as_Squeeze() const { - return data_type() == nvfuser::serde::RecordData_Squeeze ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Reduction *data_as_Reduction() const { - return data_type() == nvfuser::serde::RecordData_Reduction ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Reshape *data_as_Reshape() const { - return data_type() == nvfuser::serde::RecordData_Reshape ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Scalar *data_as_Scalar() const { - return data_type() == nvfuser::serde::RecordData_Scalar ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Tensor *data_as_Tensor() const { - return data_type() == nvfuser::serde::RecordData_Tensor ? static_cast(data()) : nullptr; - } - const nvfuser::serde::TensorCreation *data_as_TensorCreation() const { - return data_type() == nvfuser::serde::RecordData_TensorCreation ? static_cast(data()) : nullptr; - } - const nvfuser::serde::TensorCreationSymbolic *data_as_TensorCreationSymbolic() const { - return data_type() == nvfuser::serde::RecordData_TensorCreationSymbolic ? static_cast(data()) : nullptr; - } - const nvfuser::serde::Vector *data_as_Vector() const { - return data_type() == nvfuser::serde::RecordData_Vector ? static_cast(data()) : nullptr; - } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ARGS) && - verifier.VerifyVector(args()) && - VerifyOffset(verifier, VT_OUTPUTS) && - verifier.VerifyVector(outputs()) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && - VerifyField(verifier, VT_TYPE, 4) && - VerifyField(verifier, VT_DATA_TYPE, 1) && - VerifyOffset(verifier, VT_DATA) && - VerifyRecordData(verifier, data(), data_type()) && - verifier.EndTable(); + return static_cast( + GetField(VT_DATA_TYPE, 0)); + } + const void* data() const { + return GetPointer(VT_DATA); + } + template + const T* data_as() const; + const nvfuser::serde::BatchNorm* data_as_BatchNorm() const { + return data_type() == nvfuser::serde::RecordData_BatchNorm + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Broadcast* data_as_Broadcast() const { + return data_type() == nvfuser::serde::RecordData_Broadcast + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::BroadcastInDim* data_as_BroadcastInDim() const { + return data_type() == nvfuser::serde::RecordData_BroadcastInDim + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::BroadcastInDimSymbolic* data_as_BroadcastInDimSymbolic() + const { + return data_type() == nvfuser::serde::RecordData_BroadcastInDimSymbolic + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Dimension* data_as_Dimension() const { + return data_type() == nvfuser::serde::RecordData_Dimension + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Dtype* data_as_Dtype() const { + return data_type() == nvfuser::serde::RecordData_Dtype + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Norm* data_as_Norm() const { + return data_type() == nvfuser::serde::RecordData_Norm + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Output* data_as_Output() const { + return data_type() == nvfuser::serde::RecordData_Output + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Pad* data_as_Pad() const { + return data_type() == nvfuser::serde::RecordData_Pad + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Permute* data_as_Permute() const { + return data_type() == nvfuser::serde::RecordData_Permute + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Slice* data_as_Slice() const { + return data_type() == nvfuser::serde::RecordData_Slice + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Squeeze* data_as_Squeeze() const { + return data_type() == nvfuser::serde::RecordData_Squeeze + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Reduction* data_as_Reduction() const { + return data_type() == nvfuser::serde::RecordData_Reduction + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Reshape* data_as_Reshape() const { + return data_type() == nvfuser::serde::RecordData_Reshape + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Scalar* data_as_Scalar() const { + return data_type() == nvfuser::serde::RecordData_Scalar + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Tensor* data_as_Tensor() const { + return data_type() == nvfuser::serde::RecordData_Tensor + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::TensorCreation* data_as_TensorCreation() const { + return data_type() == nvfuser::serde::RecordData_TensorCreation + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::TensorCreationSymbolic* data_as_TensorCreationSymbolic() + const { + return data_type() == nvfuser::serde::RecordData_TensorCreationSymbolic + ? static_cast(data()) + : nullptr; + } + const nvfuser::serde::Vector* data_as_Vector() const { + return data_type() == nvfuser::serde::RecordData_Vector + ? static_cast(data()) + : nullptr; + } + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ARGS) && + verifier.VerifyVector(args()) && VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField(verifier, VT_TYPE, 4) && + VerifyField(verifier, VT_DATA_TYPE, 1) && + VerifyOffset(verifier, VT_DATA) && + VerifyRecordData(verifier, data(), data_type()) && verifier.EndTable(); } }; -template<> inline const nvfuser::serde::BatchNorm *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::BatchNorm* RecordFunctor::data_as< + nvfuser::serde::BatchNorm>() const { return data_as_BatchNorm(); } -template<> inline const nvfuser::serde::Broadcast *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Broadcast* RecordFunctor::data_as< + nvfuser::serde::Broadcast>() const { return data_as_Broadcast(); } -template<> inline const nvfuser::serde::BroadcastInDim *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::BroadcastInDim* RecordFunctor::data_as< + nvfuser::serde::BroadcastInDim>() const { return data_as_BroadcastInDim(); } -template<> inline const nvfuser::serde::BroadcastInDimSymbolic *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::BroadcastInDimSymbolic* RecordFunctor::data_as< + nvfuser::serde::BroadcastInDimSymbolic>() const { return data_as_BroadcastInDimSymbolic(); } -template<> inline const nvfuser::serde::Dimension *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Dimension* RecordFunctor::data_as< + nvfuser::serde::Dimension>() const { return data_as_Dimension(); } -template<> inline const nvfuser::serde::Dtype *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Dtype* RecordFunctor::data_as< + nvfuser::serde::Dtype>() const { return data_as_Dtype(); } -template<> inline const nvfuser::serde::Norm *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Norm* RecordFunctor::data_as< + nvfuser::serde::Norm>() const { return data_as_Norm(); } -template<> inline const nvfuser::serde::Output *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Output* RecordFunctor::data_as< + nvfuser::serde::Output>() const { return data_as_Output(); } -template<> inline const nvfuser::serde::Pad *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Pad* RecordFunctor::data_as() + const { return data_as_Pad(); } -template<> inline const nvfuser::serde::Permute *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Permute* RecordFunctor::data_as< + nvfuser::serde::Permute>() const { return data_as_Permute(); } -template<> inline const nvfuser::serde::Slice *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Slice* RecordFunctor::data_as< + nvfuser::serde::Slice>() const { return data_as_Slice(); } -template<> inline const nvfuser::serde::Squeeze *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Squeeze* RecordFunctor::data_as< + nvfuser::serde::Squeeze>() const { return data_as_Squeeze(); } -template<> inline const nvfuser::serde::Reduction *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Reduction* RecordFunctor::data_as< + nvfuser::serde::Reduction>() const { return data_as_Reduction(); } -template<> inline const nvfuser::serde::Reshape *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Reshape* RecordFunctor::data_as< + nvfuser::serde::Reshape>() const { return data_as_Reshape(); } -template<> inline const nvfuser::serde::Scalar *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Scalar* RecordFunctor::data_as< + nvfuser::serde::Scalar>() const { return data_as_Scalar(); } -template<> inline const nvfuser::serde::Tensor *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Tensor* RecordFunctor::data_as< + nvfuser::serde::Tensor>() const { return data_as_Tensor(); } -template<> inline const nvfuser::serde::TensorCreation *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::TensorCreation* RecordFunctor::data_as< + nvfuser::serde::TensorCreation>() const { return data_as_TensorCreation(); } -template<> inline const nvfuser::serde::TensorCreationSymbolic *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::TensorCreationSymbolic* RecordFunctor::data_as< + nvfuser::serde::TensorCreationSymbolic>() const { return data_as_TensorCreationSymbolic(); } -template<> inline const nvfuser::serde::Vector *RecordFunctor::data_as() const { +template <> +inline const nvfuser::serde::Vector* RecordFunctor::data_as< + nvfuser::serde::Vector>() const { return data_as_Vector(); } struct RecordFunctorBuilder { typedef RecordFunctor Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; - void add_args(::flatbuffers::Offset<::flatbuffers::Vector> args) { + void add_args( + ::flatbuffers::Offset<::flatbuffers::Vector> + args) { fbb_.AddOffset(RecordFunctor::VT_ARGS, args); } - void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector> outputs) { + void add_outputs( + ::flatbuffers::Offset<::flatbuffers::Vector> + outputs) { fbb_.AddOffset(RecordFunctor::VT_OUTPUTS, outputs); } void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { fbb_.AddOffset(RecordFunctor::VT_NAME, name); } void add_type(nvfuser::serde::RecordType type) { - fbb_.AddElement(RecordFunctor::VT_TYPE, static_cast(type), 0); + fbb_.AddElement( + RecordFunctor::VT_TYPE, static_cast(type), 0); } void add_data_type(nvfuser::serde::RecordData data_type) { - fbb_.AddElement(RecordFunctor::VT_DATA_TYPE, static_cast(data_type), 0); + fbb_.AddElement( + RecordFunctor::VT_DATA_TYPE, static_cast(data_type), 0); } void add_data(::flatbuffers::Offset data) { fbb_.AddOffset(RecordFunctor::VT_DATA, data); } - explicit RecordFunctorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit RecordFunctorBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -3495,9 +3757,11 @@ struct RecordFunctorBuilder { }; inline ::flatbuffers::Offset CreateRecordFunctor( - ::flatbuffers::FlatBufferBuilder &_fbb, - ::flatbuffers::Offset<::flatbuffers::Vector> args = 0, - ::flatbuffers::Offset<::flatbuffers::Vector> outputs = 0, + ::flatbuffers::FlatBufferBuilder& _fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> + args = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> + outputs = 0, ::flatbuffers::Offset<::flatbuffers::String> name = 0, nvfuser::serde::RecordType type = nvfuser::serde::RecordType_Base, nvfuser::serde::RecordData data_type = nvfuser::serde::RecordData_NONE, @@ -3513,24 +3777,20 @@ inline ::flatbuffers::Offset CreateRecordFunctor( } inline ::flatbuffers::Offset CreateRecordFunctorDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *args = nullptr, - const std::vector *outputs = nullptr, - const char *name = nullptr, + ::flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* args = nullptr, + const std::vector* outputs = nullptr, + const char* name = nullptr, nvfuser::serde::RecordType type = nvfuser::serde::RecordType_Base, nvfuser::serde::RecordData data_type = nvfuser::serde::RecordData_NONE, ::flatbuffers::Offset data = 0) { - auto args__ = args ? _fbb.CreateVectorOfStructs(*args) : 0; - auto outputs__ = outputs ? _fbb.CreateVectorOfStructs(*outputs) : 0; + auto args__ = + args ? _fbb.CreateVectorOfStructs(*args) : 0; + auto outputs__ = + outputs ? _fbb.CreateVectorOfStructs(*outputs) : 0; auto name__ = name ? _fbb.CreateString(name) : 0; return nvfuser::serde::CreateRecordFunctor( - _fbb, - args__, - outputs__, - name__, - type, - data_type, - data); + _fbb, args__, outputs__, name__, type, data_type, data); } struct TrieNode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -3542,11 +3802,11 @@ struct TrieNode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_VISITS = 10, VT_IS_TERMINAL = 12 }; - const nvfuser::serde::RecordFunctor *record() const { - return GetPointer(VT_RECORD); + const nvfuser::serde::RecordFunctor* record() const { + return GetPointer(VT_RECORD); } - const ::flatbuffers::Vector *children() const { - return GetPointer *>(VT_CHILDREN); + const ::flatbuffers::Vector* children() const { + return GetPointer*>(VT_CHILDREN); } uint64_t fusion_id() const { return GetField(VT_FUSION_ID, 0); @@ -3557,27 +3817,26 @@ struct TrieNode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool is_terminal() const { return GetField(VT_IS_TERMINAL, 0) != 0; } - bool Verify(::flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_RECORD) && - verifier.VerifyTable(record()) && - VerifyOffset(verifier, VT_CHILDREN) && - verifier.VerifyVector(children()) && - VerifyField(verifier, VT_FUSION_ID, 8) && - VerifyField(verifier, VT_VISITS, 8) && - VerifyField(verifier, VT_IS_TERMINAL, 1) && - verifier.EndTable(); + bool Verify(::flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_RECORD) && + verifier.VerifyTable(record()) && VerifyOffset(verifier, VT_CHILDREN) && + verifier.VerifyVector(children()) && + VerifyField(verifier, VT_FUSION_ID, 8) && + VerifyField(verifier, VT_VISITS, 8) && + VerifyField(verifier, VT_IS_TERMINAL, 1) && + verifier.EndTable(); } }; struct TrieNodeBuilder { typedef TrieNode Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_record(::flatbuffers::Offset record) { fbb_.AddOffset(TrieNode::VT_RECORD, record); } - void add_children(::flatbuffers::Offset<::flatbuffers::Vector> children) { + void add_children( + ::flatbuffers::Offset<::flatbuffers::Vector> children) { fbb_.AddOffset(TrieNode::VT_CHILDREN, children); } void add_fusion_id(uint64_t fusion_id) { @@ -3587,10 +3846,11 @@ struct TrieNodeBuilder { fbb_.AddElement(TrieNode::VT_VISITS, visits, 0); } void add_is_terminal(bool is_terminal) { - fbb_.AddElement(TrieNode::VT_IS_TERMINAL, static_cast(is_terminal), 0); + fbb_.AddElement( + TrieNode::VT_IS_TERMINAL, static_cast(is_terminal), 0); } - explicit TrieNodeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit TrieNodeBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -3601,7 +3861,7 @@ struct TrieNodeBuilder { }; inline ::flatbuffers::Offset CreateTrieNode( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset record = 0, ::flatbuffers::Offset<::flatbuffers::Vector> children = 0, uint64_t fusion_id = 0, @@ -3617,20 +3877,15 @@ inline ::flatbuffers::Offset CreateTrieNode( } inline ::flatbuffers::Offset CreateTrieNodeDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, ::flatbuffers::Offset record = 0, - const std::vector *children = nullptr, + const std::vector* children = nullptr, uint64_t fusion_id = 0, uint64_t visits = 0, bool is_terminal = false) { auto children__ = children ? _fbb.CreateVector(*children) : 0; return nvfuser::serde::CreateTrieNode( - _fbb, - record, - children__, - fusion_id, - visits, - is_terminal); + _fbb, record, children__, fusion_id, visits, is_terminal); } struct FusionCache FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -3644,48 +3899,61 @@ struct FusionCache FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { uint64_t max_fusions() const { return GetField(VT_MAX_FUSIONS, 0); } - const ::flatbuffers::Vector<::flatbuffers::Offset> *structure() const { - return GetPointer> *>(VT_STRUCTURE); + const ::flatbuffers::Vector<::flatbuffers::Offset>* + structure() const { + return GetPointer>*>(VT_STRUCTURE); } - const ::flatbuffers::Vector *terminal_nodes() const { - return GetPointer *>(VT_TERMINAL_NODES); + const ::flatbuffers::Vector* terminal_nodes() const { + return GetPointer*>( + VT_TERMINAL_NODES); } - const ::flatbuffers::Vector<::flatbuffers::Offset> *auto_gen_schedules() const { - return GetPointer> *>(VT_AUTO_GEN_SCHEDULES); + const ::flatbuffers::Vector< + ::flatbuffers::Offset>* + auto_gen_schedules() const { + return GetPointer>*>( + VT_AUTO_GEN_SCHEDULES); } - bool Verify(::flatbuffers::Verifier &verifier) const { + bool Verify(::flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_MAX_FUSIONS, 8) && - VerifyOffset(verifier, VT_STRUCTURE) && - verifier.VerifyVector(structure()) && - verifier.VerifyVectorOfTables(structure()) && - VerifyOffset(verifier, VT_TERMINAL_NODES) && - verifier.VerifyVector(terminal_nodes()) && - VerifyOffset(verifier, VT_AUTO_GEN_SCHEDULES) && - verifier.VerifyVector(auto_gen_schedules()) && - verifier.VerifyVectorOfTables(auto_gen_schedules()) && - verifier.EndTable(); + VerifyField(verifier, VT_MAX_FUSIONS, 8) && + VerifyOffset(verifier, VT_STRUCTURE) && + verifier.VerifyVector(structure()) && + verifier.VerifyVectorOfTables(structure()) && + VerifyOffset(verifier, VT_TERMINAL_NODES) && + verifier.VerifyVector(terminal_nodes()) && + VerifyOffset(verifier, VT_AUTO_GEN_SCHEDULES) && + verifier.VerifyVector(auto_gen_schedules()) && + verifier.VerifyVectorOfTables(auto_gen_schedules()) && + verifier.EndTable(); } }; struct FusionCacheBuilder { typedef FusionCache Table; - ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::FlatBufferBuilder& fbb_; ::flatbuffers::uoffset_t start_; void add_max_fusions(uint64_t max_fusions) { fbb_.AddElement(FusionCache::VT_MAX_FUSIONS, max_fusions, 0); } - void add_structure(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> structure) { + void add_structure( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> structure) { fbb_.AddOffset(FusionCache::VT_STRUCTURE, structure); } - void add_terminal_nodes(::flatbuffers::Offset<::flatbuffers::Vector> terminal_nodes) { + void add_terminal_nodes( + ::flatbuffers::Offset<::flatbuffers::Vector> terminal_nodes) { fbb_.AddOffset(FusionCache::VT_TERMINAL_NODES, terminal_nodes); } - void add_auto_gen_schedules(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> auto_gen_schedules) { + void add_auto_gen_schedules( + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> + auto_gen_schedules) { fbb_.AddOffset(FusionCache::VT_AUTO_GEN_SCHEDULES, auto_gen_schedules); } - explicit FusionCacheBuilder(::flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FusionCacheBuilder(::flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ::flatbuffers::Offset Finish() { @@ -3696,11 +3964,13 @@ struct FusionCacheBuilder { }; inline ::flatbuffers::Offset CreateFusionCache( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t max_fusions = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> structure = 0, + ::flatbuffers::Offset<::flatbuffers::Vector< + ::flatbuffers::Offset>> structure = 0, ::flatbuffers::Offset<::flatbuffers::Vector> terminal_nodes = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> auto_gen_schedules = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset< + nvfuser::serde::FusionExecutorCache>>> auto_gen_schedules = 0) { FusionCacheBuilder builder_(_fbb); builder_.add_max_fusions(max_fusions); builder_.add_auto_gen_schedules(auto_gen_schedules); @@ -3710,201 +3980,227 @@ inline ::flatbuffers::Offset CreateFusionCache( } inline ::flatbuffers::Offset CreateFusionCacheDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::FlatBufferBuilder& _fbb, uint64_t max_fusions = 0, - const std::vector<::flatbuffers::Offset> *structure = nullptr, - const std::vector *terminal_nodes = nullptr, - const std::vector<::flatbuffers::Offset> *auto_gen_schedules = nullptr) { - auto structure__ = structure ? _fbb.CreateVector<::flatbuffers::Offset>(*structure) : 0; - auto terminal_nodes__ = terminal_nodes ? _fbb.CreateVector(*terminal_nodes) : 0; - auto auto_gen_schedules__ = auto_gen_schedules ? _fbb.CreateVector<::flatbuffers::Offset>(*auto_gen_schedules) : 0; + const std::vector<::flatbuffers::Offset>* + structure = nullptr, + const std::vector* terminal_nodes = nullptr, + const std::vector<::flatbuffers::Offset< + nvfuser::serde::FusionExecutorCache>>* auto_gen_schedules = nullptr) { + auto structure__ = structure + ? _fbb.CreateVector<::flatbuffers::Offset>( + *structure) + : 0; + auto terminal_nodes__ = + terminal_nodes ? _fbb.CreateVector(*terminal_nodes) : 0; + auto auto_gen_schedules__ = auto_gen_schedules + ? _fbb.CreateVector< + ::flatbuffers::Offset>( + *auto_gen_schedules) + : 0; return nvfuser::serde::CreateFusionCache( - _fbb, - max_fusions, - structure__, - terminal_nodes__, - auto_gen_schedules__); + _fbb, max_fusions, structure__, terminal_nodes__, auto_gen_schedules__); } -inline bool VerifyRecordData(::flatbuffers::Verifier &verifier, const void *obj, RecordData type) { +inline bool VerifyRecordData( + ::flatbuffers::Verifier& verifier, + const void* obj, + RecordData type) { switch (type) { case RecordData_NONE: { return true; } case RecordData_BatchNorm: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Broadcast: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_BroadcastInDim: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_BroadcastInDimSymbolic: { - auto ptr = reinterpret_cast(obj); + auto ptr = + reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Dimension: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Dtype: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Norm: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Output: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Pad: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Permute: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Slice: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Squeeze: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Reduction: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Reshape: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Scalar: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Tensor: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_TensorCreation: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_TensorCreationSymbolic: { - auto ptr = reinterpret_cast(obj); + auto ptr = + reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case RecordData_Vector: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - default: return true; + default: + return true; } } -inline bool VerifyRecordDataVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { - if (!values || !types) return !values && !types; - if (values->size() != types->size()) return false; +inline bool VerifyRecordDataVector( + ::flatbuffers::Verifier& verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset>* values, + const ::flatbuffers::Vector* types) { + if (!values || !types) + return !values && !types; + if (values->size() != types->size()) + return false; for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { if (!VerifyRecordData( - verifier, values->Get(i), types->GetEnum(i))) { + verifier, values->Get(i), types->GetEnum(i))) { return false; } } return true; } -inline bool VerifyArgAbstractData(::flatbuffers::Verifier &verifier, const void *obj, ArgAbstractData type) { +inline bool VerifyArgAbstractData( + ::flatbuffers::Verifier& verifier, + const void* obj, + ArgAbstractData type) { switch (type) { case ArgAbstractData_NONE: { return true; } case ArgAbstractData_Scalar: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case ArgAbstractData_PhiloxCudaState: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case ArgAbstractData_ScalarCpu: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case ArgAbstractData_TensorArg: { - auto ptr = reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - default: return true; + default: + return true; } } -inline bool VerifyArgAbstractDataVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { - if (!values || !types) return !values && !types; - if (values->size() != types->size()) return false; +inline bool VerifyArgAbstractDataVector( + ::flatbuffers::Verifier& verifier, + const ::flatbuffers::Vector<::flatbuffers::Offset>* values, + const ::flatbuffers::Vector* types) { + if (!values || !types) + return !values && !types; + if (values->size() != types->size()) + return false; for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { if (!VerifyArgAbstractData( - verifier, values->Get(i), types->GetEnum(i))) { + verifier, values->Get(i), types->GetEnum(i))) { return false; } } return true; } -inline const nvfuser::serde::FusionCache *GetFusionCache(const void *buf) { +inline const nvfuser::serde::FusionCache* GetFusionCache(const void* buf) { return ::flatbuffers::GetRoot(buf); } -inline const nvfuser::serde::FusionCache *GetSizePrefixedFusionCache(const void *buf) { +inline const nvfuser::serde::FusionCache* GetSizePrefixedFusionCache( + const void* buf) { return ::flatbuffers::GetSizePrefixedRoot(buf); } -inline const char *FusionCacheIdentifier() { +inline const char* FusionCacheIdentifier() { return "NV00"; } -inline bool FusionCacheBufferHasIdentifier(const void *buf) { - return ::flatbuffers::BufferHasIdentifier( - buf, FusionCacheIdentifier()); +inline bool FusionCacheBufferHasIdentifier(const void* buf) { + return ::flatbuffers::BufferHasIdentifier(buf, FusionCacheIdentifier()); } -inline bool SizePrefixedFusionCacheBufferHasIdentifier(const void *buf) { - return ::flatbuffers::BufferHasIdentifier( - buf, FusionCacheIdentifier(), true); +inline bool SizePrefixedFusionCacheBufferHasIdentifier(const void* buf) { + return ::flatbuffers::BufferHasIdentifier(buf, FusionCacheIdentifier(), true); } -inline bool VerifyFusionCacheBuffer( - ::flatbuffers::Verifier &verifier) { - return verifier.VerifyBuffer(FusionCacheIdentifier()); +inline bool VerifyFusionCacheBuffer(::flatbuffers::Verifier& verifier) { + return verifier.VerifyBuffer( + FusionCacheIdentifier()); } inline bool VerifySizePrefixedFusionCacheBuffer( - ::flatbuffers::Verifier &verifier) { - return verifier.VerifySizePrefixedBuffer(FusionCacheIdentifier()); + ::flatbuffers::Verifier& verifier) { + return verifier.VerifySizePrefixedBuffer( + FusionCacheIdentifier()); } inline void FinishFusionCacheBuffer( - ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::FlatBufferBuilder& fbb, ::flatbuffers::Offset root) { fbb.Finish(root, FusionCacheIdentifier()); } inline void FinishSizePrefixedFusionCacheBuffer( - ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::FlatBufferBuilder& fbb, ::flatbuffers::Offset root) { fbb.FinishSizePrefixed(root, FusionCacheIdentifier()); } -} // namespace serde -} // namespace nvfuser +} // namespace serde +} // namespace nvfuser -#endif // FLATBUFFERS_GENERATED_FUSIONCACHE_NVFUSER_SERDE_H_ +#endif // FLATBUFFERS_GENERATED_FUSIONCACHE_NVFUSER_SERDE_H_ diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 11797381f13..3720397e0d4 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -723,7 +723,7 @@ ForwardingInfo::ForwardingInfo( // We have root axes in active_tv that don't exist in the inactive tensor, // now forward those to include all id's in active_tv comprised of only axes // not in the inactive tensor. - std::vector active_tv_history = StmtSort::getExprs( + auto active_tv_history = StmtSort::getExprsTo( FusionGuard::getCurFusion(), std::vector( active_tv->domain()->leaf().begin(), diff --git a/csrc/type_traits.h b/csrc/type_traits.h index 8f285111fba..ad02e64a21b 100644 --- a/csrc/type_traits.h +++ b/csrc/type_traits.h @@ -291,7 +291,7 @@ DEFINE_BINARY_OP(>>=); // it manually. template constexpr auto operator,(OperatorChecker, OperatorChecker) - -> decltype((std::declval(), std::declval()), true) { + -> decltype((std::declval(), std::declval()), true) { return true; } @@ -588,8 +588,9 @@ constexpr bool any_check(Fun f, Tuples... tuples) { } // For example: -static_assert( - any_check([](auto x) constexpr { return x > 0; }, std::make_tuple(1, -1))); +static_assert(any_check( + [](auto x) constexpr { return x > 0; }, + std::make_tuple(1, -1))); static_assert(!any_check( [](auto x) constexpr { return x > 0; }, std::make_tuple(-2, -1))); diff --git a/csrc/utils.h b/csrc/utils.h index e541e951d06..037e02d24cc 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -192,8 +192,9 @@ std::vector getSortedKeys( // Based on https://stackoverflow.com/a/9154394 template -static auto hasToStringHelper(int) - -> decltype(std::declval::type>().toString(), std::true_type{}); +static auto hasToStringHelper(int) -> decltype( + std::declval::type>().toString(), + std::true_type{}); template static auto hasToStringHelper(long) -> std::false_type; diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index b5fa1af93e1..7192b4daeda 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -1091,17 +1091,11 @@ TEST_F(NVFuserTest, FusionIndexSplitMerge_CUDA) { tv3->merge(1); tv3->split(1, 5); - MaxRootDomainInfoSpanningTree tree(tv5); - TransformPropagator tp(tv5); + MaxRootDomainInfoSpanningTree tree(tv3); + TransformPropagator tp(tv3); tree.traverse(&tp); - inlineAllAt(tv4, 1, true); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({5}, options); - at::Tensor t1 = at::randn({5, 3}, options); - std::vector inputs = {t0, t1}; - + inlineAllAt(tv3, 1, true); FusionExecutor fe; int x = 4, y = 7; From cd593fd027fcb39352cf618edd0b446a869ab0ae Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 15 Aug 2023 10:20:57 -0700 Subject: [PATCH 041/178] Fix build with clang --- csrc/id_model/id_graph.cpp | 2 ++ csrc/id_model/transform_replay.cpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index 91699cf068c..241078aeaef 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -259,6 +259,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) return true; }; +#if 0 auto allIdUsesVisisted = [&](IdGroup id) { auto uses_pair = iterDomainGroupUses(id); if (!uses_pair.second) { @@ -274,6 +275,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) } return true; }; +#endif // Returns all expression groups in required_ind_exprs_ids of outputs auto requiredExprsOutputs = [&](ExprGroup expr) { diff --git a/csrc/id_model/transform_replay.cpp b/csrc/id_model/transform_replay.cpp index 562b4687987..93d9d37ef90 100644 --- a/csrc/id_model/transform_replay.cpp +++ b/csrc/id_model/transform_replay.cpp @@ -166,4 +166,4 @@ void ReplacementTransformCloner::handle(const Resize* resize) { new_expr_ = IrBuilder::create( resize_out, resize_in, resize->leftExpand(), resize->rightExpand()); } -} // namespace nvfuser \ No newline at end of file +} // namespace nvfuser From f8c1812d816afb2f9e5f0f96993580c7571a6505 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 15 Aug 2023 10:29:54 -0700 Subject: [PATCH 042/178] clang-format --- csrc/python_frontend/fusion_record.h | 31 ++- csrc/python_frontend/python_bindings.cpp | 185 +++++++++--------- .../test/test_nvfuser_fusion_record.cpp | 21 +- csrc/scheduler/mma_utils.cpp | 8 +- csrc/type_traits.h | 7 +- csrc/utils.h | 5 +- 6 files changed, 136 insertions(+), 121 deletions(-) diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index e4c5e4cbdcf..7a27486cc95 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -1674,21 +1674,32 @@ struct ReductionOpRecord : RecordFunctor { result = result && (*fusion_op_.template target< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() == + TensorView* (*)(TensorView*, + const std::vector&, + bool, + DataType)>() == *child_ptr->fusion_op_.template target< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>()); + TensorView* (*)(TensorView*, + const std::vector&, + bool, + DataType)>()); if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() - << " Target Ptr [self: 0x" << std::hex - << (size_t)*fusion_op_.template target< + debug() << " Target Ptr [self: 0x" << std::hex + << (size_t)*fusion_op_.template target< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() - << "] [other: 0x" << std::hex - << (size_t)*child_ptr->fusion_op_.template target< + TensorView* (*)(TensorView*, + const std::vector&, + bool, + DataType)>() + << "] [other: 0x" << std::hex + << (size_t)*child_ptr->fusion_op_.template target< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() - << "]\n"; + TensorView* (*)(TensorView*, + const std::vector&, + bool, + DataType)>() + << "]\n"; } result = result && (keep_dim_ == child_ptr->keep_dim_); result = result && (dtype_ == child_ptr->dtype_); diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 6fa10a9934f..d09be4142ab 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1564,97 +1564,100 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul) #undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP -#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - TORCH_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = 0; \ - std::vector axes(arg.dims); \ - std::iota(axes.begin(), axes.end(), 0); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast< \ - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ - op_name), \ - axes, \ - false, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - int axis, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - TORCH_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast< \ - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ - op_name), \ - {axis}, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("axis"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - const std::vector& axes, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - TORCH_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast< \ - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ - op_name), \ - axes, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("axes"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ +#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + TORCH_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = 0; \ + std::vector axes(arg.dims); \ + std::iota(axes.begin(), axes.end(), 0); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast&, \ + bool, \ + DataType)>(op_name), \ + axes, \ + false, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("dtype") = DataType::Null, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + int axis, \ + bool keepdim, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + TORCH_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast&, \ + bool, \ + DataType)>(op_name), \ + {axis}, \ + keepdim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axis"), \ + py::arg("keepdim") = false, \ + py::arg("dtype") = DataType::Null, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + const std::vector& axes, \ + bool keepdim, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + TORCH_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast&, \ + bool, \ + DataType)>(op_name), \ + axes, \ + keepdim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axes"), \ + py::arg("keepdim") = false, \ + py::arg("dtype") = DataType::Null, \ py::return_value_policy::reference); NVFUSER_PYTHON_BINDING_REDUCTION_OP( diff --git a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp index 170531cfaad..e0eabf5122d 100644 --- a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp +++ b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp @@ -98,9 +98,10 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( - sum), + static_cast&, + bool, + DataType)>(sum), {0}, false, DataType::Float)); @@ -109,9 +110,10 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( - sum), + static_cast&, + bool, + DataType)>(sum), {0}, false, DataType::Float)); @@ -120,9 +122,10 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( - sum), + static_cast&, + bool, + DataType)>(sum), {0}, false, DataType::Float)); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 7264ac35547..546e600908f 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -40,11 +40,11 @@ bool generateSharedMemoryEpilogueHeuristics( properties->warpSize * vector_word; const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k; const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k; - const size_t smem_a = - (size_t)(ceilDiv(mk, round_to_factor) * round_to_factor * smem_double_buffer_stage) * + const size_t smem_a = (size_t)(ceilDiv(mk, round_to_factor) * + round_to_factor * smem_double_buffer_stage) * dataTypeSize(data_types[0]); - const size_t smem_b = - (size_t)(ceilDiv(nk, round_to_factor) * round_to_factor * smem_double_buffer_stage) * + const size_t smem_b = (size_t)(ceilDiv(nk, round_to_factor) * + round_to_factor * smem_double_buffer_stage) * dataTypeSize(data_types[1]); const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); diff --git a/csrc/type_traits.h b/csrc/type_traits.h index ad02e64a21b..8f285111fba 100644 --- a/csrc/type_traits.h +++ b/csrc/type_traits.h @@ -291,7 +291,7 @@ DEFINE_BINARY_OP(>>=); // it manually. template constexpr auto operator,(OperatorChecker, OperatorChecker) - -> decltype((std::declval(), std::declval()), true) { + -> decltype((std::declval(), std::declval()), true) { return true; } @@ -588,9 +588,8 @@ constexpr bool any_check(Fun f, Tuples... tuples) { } // For example: -static_assert(any_check( - [](auto x) constexpr { return x > 0; }, - std::make_tuple(1, -1))); +static_assert( + any_check([](auto x) constexpr { return x > 0; }, std::make_tuple(1, -1))); static_assert(!any_check( [](auto x) constexpr { return x > 0; }, std::make_tuple(-2, -1))); diff --git a/csrc/utils.h b/csrc/utils.h index 34ea1b5070d..55feefc76fd 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -192,9 +192,8 @@ std::vector getSortedKeys( // Based on https://stackoverflow.com/a/9154394 template -static auto hasToStringHelper(int) -> decltype( - std::declval::type>().toString(), - std::true_type{}); +static auto hasToStringHelper(int) + -> decltype(std::declval::type>().toString(), std::true_type{}); template static auto hasToStringHelper(long) -> std::false_type; From b75e197bdce337bfba67c844d2e391aca65fd20f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 30 Aug 2023 12:00:06 -0700 Subject: [PATCH 043/178] Cosmetic cleanup Replaced auto with actual type just to make it easy to follow the code --- csrc/id_model/id_graph.cpp | 146 ++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 73 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index 241078aeaef..667add5705a 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -23,7 +23,7 @@ IdGraph::IdGraph(const IdGraph& other) auto new_id_group = toGroup(orig_id_group->front()); ExprGroups new_expr_groups; - for (auto orig_expr_group : orig_expr_groups) { + for (const ExprGroup& orig_expr_group : orig_expr_groups) { new_expr_groups.pushBack(toGroup(orig_expr_group->front())); } @@ -36,7 +36,7 @@ IdGraph::IdGraph(const IdGraph& other) auto new_id_group = toGroup(orig_id_group->front()); ExprGroups new_expr_groups; - for (auto orig_expr_group : orig_expr_groups) { + for (const ExprGroup& orig_expr_group : orig_expr_groups) { new_expr_groups.pushBack(toGroup(orig_expr_group->front())); } @@ -199,12 +199,12 @@ ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) const { - auto all_uses_of_from = allUsesOf(from); - auto all_definitions_of_to = allDefinitionsOf(to); + ExprGroups all_uses_of_from = allUsesOf(from); + ExprGroups all_definitions_of_to = allDefinitionsOf(to); // All of the expressions between from and to. Not all will be used as we // just want to define each iter domain group once. - auto all_exprs = all_uses_of_from.intersect(all_definitions_of_to); + ExprGroups all_exprs = all_uses_of_from.intersect(all_definitions_of_to); // There could be IterDomains in from or to that are between other from and // to nodes. Make sure to clear those out. @@ -215,10 +215,10 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) IdGroups not_outputs; IdGroups all_id_groups; - for (auto expr_group : all_exprs) { - auto inp_groups = inputGroups(expr_group); - auto out_groups = outputGroups(expr_group); - if (IdGroups(inp_groups).intersect(IdGroups(out_groups)).size() > 0) { + for (const ExprGroup& expr_group : all_exprs) { + std::vector inp_groups = inputGroups(expr_group); + std::vector out_groups = outputGroups(expr_group); + if (!IdGroups(inp_groups).intersect(IdGroups(out_groups)).empty()) { // Expression is just a loop to its current group, ignore continue; } @@ -278,47 +278,47 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) #endif // Returns all expression groups in required_ind_exprs_ids of outputs - auto requiredExprsOutputs = [&](ExprGroup expr) { + auto requiredExprsOutputs = [&](ExprGroup expr_group) -> ExprGroups { ExprGroups all_output_required_exprs; - for (auto id_group : outputGroups(expr)) { - auto id_group_exprs_it = required_ind_exprs_ids.find(id_group); + for (const IdGroup& output_id_group : outputGroups(expr_group)) { + auto id_group_exprs_it = required_ind_exprs_ids.find(output_id_group); TORCH_INTERNAL_ASSERT( id_group_exprs_it != required_ind_exprs_ids.end(), "Failure in Iter Domain Graph index resolution, count expected for group: ", - id_group->toString()); + output_id_group->toString()); all_output_required_exprs.pushBack(id_group_exprs_it->second); } return all_output_required_exprs; }; - auto processExpr = [&](ExprGroup expr) { - if (!outputsVisited(expr)) { + auto processExprGroup = [&](ExprGroup expr_group) -> bool { + if (!outputsVisited(expr_group)) { return false; } // Accumulate expressions from all outputs add this expression and set it // as current expressions required indexing expressions. - required_ind_exprs_exprs[expr] = requiredExprsOutputs(expr); + required_ind_exprs_exprs[expr_group] = requiredExprsOutputs(expr_group); return true; }; - auto processId = [&](IdGroup id) { + auto processIdGroup = [&](IdGroup id_group) -> bool { // Track if we've grabed any of the uses required indexing expressions. bool initialized = false; // Expression group of all indexing expressions required for this iter // domain coming back from any of its uses. ExprGroups min_groups; - auto uses_pair = iterDomainGroupUses(id); + auto uses_pair = iterDomainGroupUses(id_group); if (!uses_pair.second) { // No expressions required for this iter domain, it must be a // terminating output. - required_ind_exprs_ids[id] = min_groups; + required_ind_exprs_ids[id_group] = min_groups; return true; } // Only worry about expressions between inputs and outputs we're // looking at. - for (auto use_group : uses_pair.first.intersect(all_exprs)) { + for (const ExprGroup& use_group : uses_pair.first.intersect(all_exprs)) { auto use_required_ind_exprs_it = required_ind_exprs_exprs.find(use_group); if (use_required_ind_exprs_it == required_ind_exprs_exprs.end()) { // If there isn't an entry for the use expression it wasn't @@ -338,14 +338,15 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) use_required_ind_exprs_it->second.computeUnion({use_group}); } } - required_ind_exprs_ids[id] = min_groups; + required_ind_exprs_ids[id_group] = min_groups; return true; }; + // Backward traversal from the terminating outputs IdGroups to_visit_ids = terminating_outputs; ExprGroups to_visit_exprs; - while (to_visit_ids.size() > 0 || to_visit_exprs.size() > 0) { + while (!to_visit_ids.empty() || !to_visit_exprs.empty()) { // Process expressions first as all uses of iter domains have to be // processed before we can process that iter domain. @@ -353,38 +354,39 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // infinite loop bool something_was_processed = false; ExprGroups still_to_visit_exprs; - while (to_visit_exprs.size() > 0) { - auto currently_visiting = to_visit_exprs.popFront(); - if (required_ind_exprs_exprs.find(currently_visiting) != + while (!to_visit_exprs.empty()) { + ExprGroup currently_visiting_exprs = to_visit_exprs.popFront(); + if (required_ind_exprs_exprs.find(currently_visiting_exprs) != required_ind_exprs_exprs.end()) { continue; } - if (processExpr(currently_visiting)) { + if (processExprGroup(currently_visiting_exprs)) { something_was_processed = true; - auto inp_groups = inputGroups(currently_visiting); - for (auto inp_group : inp_groups) { + std::vector inp_groups = inputGroups(currently_visiting_exprs); + for (const IdGroup& inp_group : inp_groups) { to_visit_ids.pushBack(inp_group); } } else { - still_to_visit_exprs.pushBack(currently_visiting); + still_to_visit_exprs.pushBack(currently_visiting_exprs); } } std::swap(to_visit_exprs, still_to_visit_exprs); IdGroups still_to_visit_ids; - while (to_visit_ids.size() > 0) { - auto currently_visiting = to_visit_ids.popFront(); - if (required_ind_exprs_ids.find(currently_visiting) != + while (!to_visit_ids.empty()) { + auto currently_visiting_ids = to_visit_ids.popFront(); + if (required_ind_exprs_ids.find(currently_visiting_ids) != required_ind_exprs_ids.end()) { continue; } - if (processId(currently_visiting)) { + if (processIdGroup(currently_visiting_ids)) { something_was_processed = true; - auto definitions_pair = iterDomainGroupDefinitions(currently_visiting); + auto definitions_pair = + iterDomainGroupDefinitions(currently_visiting_ids); if (definitions_pair.second) { - for (auto def : definitions_pair.first) { + for (const ExprGroup& def : definitions_pair.first) { if (!all_exprs.has(def)) { continue; } @@ -395,7 +397,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) } } } else { - still_to_visit_ids.pushBack(currently_visiting); + still_to_visit_ids.pushBack(currently_visiting_ids); } } @@ -422,28 +424,28 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // Topologically sort the uses_path. ExprGroups sorted_exprs; - ExprGroups to_visit; + ExprGroups to_visit_expr_groups; - for (auto inp : terminating_inputs) { + for (const IdGroup& inp : terminating_inputs) { auto use_it = uses_path.find(inp); if (use_it == uses_path.end()) { // This can happen for a trivial traversal where inputs and outputs are // exactly the same. continue; } - auto uses = use_it->second; - for (auto use : uses) { - to_visit.pushBack(use); + const ExprGroups& uses = use_it->second; + for (const ExprGroup& use : uses) { + to_visit_expr_groups.pushBack(use); } } IdGroups visited = terminating_inputs; - while (to_visit.size() > 0) { + while (!to_visit_expr_groups.empty()) { bool something_processed = false; ExprGroups still_to_visit; - while (to_visit.size() > 0) { - auto currently_visiting = to_visit.popFront(); + while (!to_visit_expr_groups.empty()) { + auto currently_visiting = to_visit_expr_groups.popFront(); auto inputs = inputGroups(currently_visiting); if (std::all_of(inputs.begin(), inputs.end(), [&](IdGroup inp_id) { return visited.has(inp_id); @@ -463,7 +465,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) still_to_visit.pushBack(currently_visiting); } } - std::swap(to_visit, still_to_visit); + std::swap(to_visit_expr_groups, still_to_visit); TORCH_INTERNAL_ASSERT(something_processed, "Infinite loop entered."); } @@ -604,7 +606,7 @@ bool IdGraph::transformAtributesMatch(Expr* first, Expr* second) { TORCH_INTERNAL_ASSERT( first->isA() || first->isA() || first->isA() || first->isA(), - "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n", + "Unsupported rfactor expressions in compute at map:\n", first->toString()); if (typeid(*first) != typeid(*second)) { @@ -642,14 +644,16 @@ void IdGraph::initializeId( ExprGroups def_groups; for (auto def : definitions) { - auto expr_set = disjointExprSets().initializeSet(def).first->second; + const ExprGroup& expr_set = + disjointExprSets().initializeSet(def).first->second; def_groups.pushBack(expr_set); } unique_definitions_[id_disjoint_set] = def_groups; ExprGroups use_groups; for (auto use : uses) { - auto expr_set = disjointExprSets().initializeSet(use).first->second; + const ExprGroup& expr_set = + disjointExprSets().initializeSet(use).first->second; use_groups.pushBack(expr_set); } unique_uses_[id_disjoint_set] = use_groups; @@ -795,35 +799,31 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { unique_uses_[new_id_group] = orig_uses0.computeUnion(orig_uses1); // Propagate on uses - if (orig_uses0.size() > 0 || orig_uses1.size() > 0) { - if (orig_uses0.size() > 0 && orig_uses1.size() > 0) { - for (auto use_group_1 : orig_uses1) { - if (orig_uses0.has(use_group_1)) { - continue; - } + if (!orig_uses0.empty() && !orig_uses1.empty()) { + for (const ExprGroup& use_group_1 : orig_uses1) { + if (orig_uses0.has(use_group_1)) { + continue; + } - for (auto use_group_0 : orig_uses0) { - auto use0 = use_group_0->front(); - auto use1 = use_group_1->front(); - maybeMapThroughExprs(use0, use1, true); - } + for (const ExprGroup& use_group_0 : orig_uses0) { + Expr* use0 = use_group_0->front(); + Expr* use1 = use_group_1->front(); + maybeMapThroughExprs(use0, use1, true); } } } // Propagate on definitions - if (orig_defs0.size() > 0 || orig_defs1.size() > 0) { - if (orig_defs0.size() > 0 && orig_defs1.size() > 0) { - for (auto def_group_1 : orig_defs1) { - if (orig_defs0.has(def_group_1)) { - continue; - } + if (!orig_defs0.empty() && !orig_defs1.empty()) { + for (const ExprGroup& def_group_1 : orig_defs1) { + if (orig_defs0.has(def_group_1)) { + continue; + } - for (auto def_group_0 : orig_defs0) { - auto def0 = def_group_0->front(); - auto def1 = def_group_1->front(); - maybeMapThroughExprs(def0, def1, false); - } + for (const ExprGroup& def_group_0 : orig_defs0) { + auto def0 = def_group_0->front(); + auto def1 = def_group_1->front(); + maybeMapThroughExprs(def0, def1, false); } } } @@ -922,7 +922,7 @@ bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { void IdGraph::mapThroughLoopSwizzles() { std::vector all_swizzles; - for (auto expr_set : disjointExprSets().disjointSets()) { + for (const auto& expr_set : disjointExprSets().disjointSets()) { auto swizzles_in_expr_set = ir_utils::filterByType( expr_set->vector().begin(), expr_set->vector().end()); all_swizzles.insert( @@ -943,7 +943,7 @@ void IdGraph::mapThroughTrivialExprs() { // Grab all expressions std::vector exprs; - for (auto expr_group : disjointExprSets().disjointSets()) { + for (const auto& expr_group : disjointExprSets().disjointSets()) { for (auto expr : *expr_group) { exprs.push_back(expr); } @@ -968,7 +968,7 @@ void IdGraph::mapThroughTrivialExprs() { void IdGraph::removeTrivialExprs() { ExprGroups trivial_expr_groups; // This seems like it shouls just be a copy if. - for (auto expr_group : disjointExprSets().disjointSets()) { + for (const ExprGroup& expr_group : disjointExprSets().disjointSets()) { if (isTrivialExprGroup(expr_group)) { trivial_expr_groups.pushBack(expr_group); } From 8cf25afd3f65ec923a1c7a91b025cf3a814b38d4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 30 Aug 2023 18:03:31 -0700 Subject: [PATCH 044/178] Cleanup --- csrc/id_model/id_graph.cpp | 78 ++++++++++---------------- csrc/id_model/id_graph.h | 108 +++++++++++++++++++----------------- csrc/id_model/id_graphs.cpp | 18 +++--- csrc/id_model/to_string.cpp | 4 +- csrc/id_model/visitor.cpp | 6 +- 5 files changed, 99 insertions(+), 115 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index 667add5705a..f864551403a 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -14,7 +14,6 @@ namespace nvfuser { IdGraph::IdGraph(const IdGraph& other) : disjoint_ids_(other.disjoint_ids_), disjoint_exprs_(other.disjoint_exprs_), - view_rfactor_ids_(other.view_rfactor_ids_), unique_definitions_(), unique_uses_() { for (auto orig_unique_def_pair : other.unique_definitions_) { @@ -49,28 +48,11 @@ IdGraph& IdGraph::operator=(const IdGraph& other) { disjoint_exprs_.clear(); unique_definitions_.clear(); unique_uses_.clear(); - view_rfactor_ids_.clear(); IdGraph copy(other); std::swap(*this, copy); return *this; } -const DisjointSets& IdGraph::disjointIdSets() const { - return disjoint_ids_; -} - -DisjointSets& IdGraph::disjointIdSets() { - return disjoint_ids_; -} - -const DisjointSets& IdGraph::disjointExprSets() const { - return disjoint_exprs_; -} - -DisjointSets& IdGraph::disjointExprSets() { - return disjoint_exprs_; -} - // Return if there's a group entry in the graph for this expr bool IdGraph::hasGroup(Expr* expr) const { return disjoint_exprs_.mappingExists(expr); @@ -81,7 +63,7 @@ bool IdGraph::hasGroup(IterDomain* id) const { return disjoint_ids_.mappingExists(id); } -ExprGroup IdGraph::toGroup(Expr* expr) const { +const ExprGroup& IdGraph::toGroup(Expr* expr) const { auto disjoint_set_it = disjoint_exprs_.disjointSetMap().find(expr); TORCH_INTERNAL_ASSERT( disjoint_set_it != disjoint_exprs_.disjointSetMap().end(), @@ -90,7 +72,7 @@ ExprGroup IdGraph::toGroup(Expr* expr) const { return disjoint_set_it->second; } -IdGroup IdGraph::toGroup(IterDomain* id) const { +const IdGroup& IdGraph::toGroup(IterDomain* id) const { auto disjoint_set_it = disjoint_ids_.disjointSetMap().find(id); TORCH_INTERNAL_ASSERT( disjoint_set_it != disjoint_ids_.disjointSetMap().end(), @@ -117,7 +99,7 @@ IdGroups IdGraph::toGroups( return id_groups; } -std::vector IdGraph::outputGroups(ExprGroup expr) const { +std::vector IdGraph::outputGroups(const ExprGroup& expr) const { std::vector output_groups; for (auto id_output : ir_utils::filterByType(expr->front()->outputs())) { @@ -126,7 +108,7 @@ std::vector IdGraph::outputGroups(ExprGroup expr) const { return output_groups; } -std::vector IdGraph::inputGroups(ExprGroup expr) const { +std::vector IdGraph::inputGroups(const ExprGroup& expr) const { std::vector input_groups; for (auto id_input : ir_utils::filterByType(expr->front()->inputs())) { @@ -138,7 +120,7 @@ std::vector IdGraph::inputGroups(ExprGroup expr) const { ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { ExprGroups to_visit; for (auto of_id_group : of) { - auto group_uses_pair = iterDomainGroupUses(of_id_group); + auto group_uses_pair = getUses(of_id_group); if (group_uses_pair.second) { to_visit.pushBack(group_uses_pair.first); } @@ -150,7 +132,7 @@ ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { visited.pushBack(current_expr); auto output_ids = outputGroups(current_expr); for (auto output_id : output_ids) { - auto group_uses_pair = iterDomainGroupUses(output_id); + auto group_uses_pair = getUses(output_id); if (!group_uses_pair.second) { continue; } @@ -169,7 +151,7 @@ ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { ExprGroups to_visit; for (auto of_id_group : of) { - auto group_defs_pair = iterDomainGroupDefinitions(of_id_group); + auto group_defs_pair = getDefinitions(of_id_group); if (group_defs_pair.second) { to_visit.pushBack(group_defs_pair.first); } @@ -181,7 +163,7 @@ ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { visited.pushBack(current_expr); auto input_ids = inputGroups(current_expr); for (auto input_id : input_ids) { - auto group_defs_pair = iterDomainGroupDefinitions(input_id); + auto group_defs_pair = getDefinitions(input_id); if (!group_defs_pair.second) { continue; } @@ -308,7 +290,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // domain coming back from any of its uses. ExprGroups min_groups; - auto uses_pair = iterDomainGroupUses(id_group); + auto uses_pair = getUses(id_group); if (!uses_pair.second) { // No expressions required for this iter domain, it must be a // terminating output. @@ -383,8 +365,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) if (processIdGroup(currently_visiting_ids)) { something_was_processed = true; - auto definitions_pair = - iterDomainGroupDefinitions(currently_visiting_ids); + auto definitions_pair = getDefinitions(currently_visiting_ids); if (definitions_pair.second) { for (const ExprGroup& def : definitions_pair.first) { if (!all_exprs.has(def)) { @@ -413,7 +394,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) for (auto entry : required_ind_exprs_ids) { auto id = entry.first; auto traverse_exprs = entry.second; - auto all_uses = iterDomainGroupUses(id); + auto all_uses = getUses(id); if (all_uses.second) { uses_path[id] = traverse_exprs.intersect(all_uses.first); } else { @@ -455,7 +436,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) auto outputs = outputGroups(currently_visiting); for (auto out_id : outputs) { visited.pushBack(out_id); - auto use_pair = iterDomainGroupUses(out_id); + auto use_pair = getUses(out_id); if (!use_pair.second) { continue; } @@ -528,8 +509,8 @@ std::unordered_map> IdGraph:: return buildMapBetween(from.vector(), to.vector()); } -std::pair IdGraph::iterDomainGroupDefinitions( - IdGroup id_group) const { +std::pair IdGraph::getDefinitions( + const IdGroup& id_group) const { auto null_return = std::make_pair(ExprGroups(), false); if (id_group == nullptr) { @@ -544,8 +525,7 @@ std::pair IdGraph::iterDomainGroupDefinitions( return std::make_pair(definitions_it->second, true); } -std::pair IdGraph::iterDomainGroupUses( - IdGroup id_group) const { +std::pair IdGraph::getUses(const IdGroup& id_group) const { auto null_return = std::make_pair(ExprGroups(), false); if (id_group == nullptr) { @@ -748,7 +728,7 @@ bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { return true; } -ExprGroups IdGraph::uniqueDefinitions(IdGroup group) const { +const ExprGroups& IdGraph::getUniqueDefinitions(const IdGroup& group) const { auto unique_defs_it = unique_definitions_.find(group); TORCH_INTERNAL_ASSERT( unique_defs_it != unique_definitions_.end(), @@ -757,7 +737,7 @@ ExprGroups IdGraph::uniqueDefinitions(IdGroup group) const { return unique_defs_it->second; } -ExprGroups IdGraph::uniqueUses(IdGroup group) const { +const ExprGroups& IdGraph::getUniqueUses(const IdGroup& group) const { auto unique_uses_it = unique_uses_.find(group); TORCH_INTERNAL_ASSERT( unique_uses_it != unique_uses_.end(), @@ -779,10 +759,10 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { // processing. auto orig_id_group0 = toGroup(id0); auto orig_id_group1 = toGroup(id1); - ExprGroups orig_defs0 = uniqueDefinitions(orig_id_group0); - ExprGroups orig_defs1 = uniqueDefinitions(orig_id_group1); - ExprGroups orig_uses0 = uniqueUses(orig_id_group0); - ExprGroups orig_uses1 = uniqueUses(orig_id_group1); + ExprGroups orig_defs0 = getUniqueDefinitions(orig_id_group0); + ExprGroups orig_defs1 = getUniqueDefinitions(orig_id_group1); + ExprGroups orig_uses0 = getUniqueUses(orig_id_group0); + ExprGroups orig_uses1 = getUniqueUses(orig_id_group1); // Map the iter domains together before we traverse across definitions and // uses. Traversing definitions and uses could use the new property of id0 and @@ -867,9 +847,9 @@ void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { } for (auto producer_group : producers) { - uniqueUses().at(producer_group).erase(expr0_orig_group); - uniqueUses().at(producer_group).erase(expr1_orig_group); - uniqueUses().at(producer_group).pushBack(expr_new_group); + unique_uses_.at(producer_group).erase(expr0_orig_group); + unique_uses_.at(producer_group).erase(expr1_orig_group); + unique_uses_.at(producer_group).pushBack(expr_new_group); } // Update unique definitinos of consumers @@ -881,9 +861,9 @@ void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { } for (auto consumer_group : consumers) { - uniqueDefinitions().at(consumer_group).erase(expr0_orig_group); - uniqueDefinitions().at(consumer_group).erase(expr1_orig_group); - uniqueDefinitions().at(consumer_group).pushBack(expr_new_group); + unique_definitions_.at(consumer_group).erase(expr0_orig_group); + unique_definitions_.at(consumer_group).erase(expr1_orig_group); + unique_definitions_.at(consumer_group).pushBack(expr_new_group); } } @@ -987,7 +967,7 @@ void IdGraph::removeTrivialExprs() { // Complexity here is not great. We might want a better complexity version when // erasing multiple expr_groups. -void IdGraph::eraseExprGroup(ExprGroup expr_group) { +void IdGraph::eraseExprGroup(const ExprGroup& expr_group) { // Erase entries that exist in unique_definitions_ and unique_uses_ for (auto id_group : disjointIdSets().disjointSets()) { // Make sure the entries exists @@ -1009,7 +989,7 @@ void IdGraph::eraseExprGroup(ExprGroup expr_group) { } } -bool IdGraph::isTrivialExprGroup(ExprGroup expr_group) const { +bool IdGraph::isTrivialExprGroup(const ExprGroup& expr_group) const { return !IdGroups(inputGroups(expr_group)) .intersect(IdGroups(outputGroups(expr_group))) .empty(); diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index 4581350f25a..35b2c025ffb 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -32,14 +32,22 @@ class TORCH_CUDA_CU_API IdGraph { IdGraph& operator=(IdGraph&& other) = default; // Returns the disjoint IterDomain set. - const DisjointSets& disjointIdSets() const; + const DisjointSets& disjointIdSets() const { + return disjoint_ids_; + } - DisjointSets& disjointIdSets(); + DisjointSets& disjointIdSets() { + return disjoint_ids_; + } // Returns the disjoint Expr set. - const DisjointSets& disjointExprSets() const; + const DisjointSets& disjointExprSets() const { + return disjoint_exprs_; + } - DisjointSets& disjointExprSets(); + DisjointSets& disjointExprSets() { + return disjoint_exprs_; + } // Return if there's a group entry in the graph for this expr bool hasGroup(Expr* expr) const; @@ -48,10 +56,10 @@ class TORCH_CUDA_CU_API IdGraph { bool hasGroup(IterDomain* id) const; // Convert expr to its exprGroup, assert that it exists. - ExprGroup toGroup(Expr* expr) const; + const ExprGroup& toGroup(Expr* expr) const; // Convert iter domain to its IdGroup, assert that it exists. - IdGroup toGroup(IterDomain* id) const; + const IdGroup& toGroup(IterDomain* id) const; // Convert unique vector of expressions to unique vector of its groups ExprGroups toGroups(const VectorOfUniqueEntries& exprs) const; @@ -60,15 +68,17 @@ class TORCH_CUDA_CU_API IdGraph { IdGroups toGroups(const VectorOfUniqueEntries& ids) const; // Return output/input iter domain groups of provided expr - std::vector outputGroups(ExprGroup expr) const; - std::vector inputGroups(ExprGroup expr) const; + // Note that the same IdGroup can show up multiple times, so the + // output type cannot be VectorOfUniqueEntries + std::vector outputGroups(const ExprGroup& expr) const; + std::vector inputGroups(const ExprGroup& expr) const; - // Traverses uses of the IdGroups in 'of' and returns all ExprGroups - // that have a use in their definition of provided of IdGroups. + // Recursively traverses uses of the IdGroups in 'of' and returns all + // ExprGroups that have a use in their definition of provided of IdGroups. ExprGroups allUsesOf(const IdGroups& of) const; - // Traverses definitions of the IdGroups in 'of' and returns all ExprGroups - // used in this history of defining the 'of' IdGroups. + // Recursively traverses definitions of the IdGroups in 'of' and returns all + // ExprGroups used in this history of defining the 'of' IdGroups. ExprGroups allDefinitionsOf(const IdGroups& of) const; // Return sorted expressions to go from the provided IterDomains in from to @@ -101,13 +111,15 @@ class TORCH_CUDA_CU_API IdGraph { //! outer vector are expression groups that are not equivalent based on the //! provided mode, but produce one of the IterDomains within the same disjoint //! Iter Domain set based on the provided mode. - //! TODO: Change name to start with get - std::pair iterDomainGroupDefinitions( - IdGroup id_group) const; + //! + //! TODO-NM: ExprGroups is a real container. Consider returning a reference + std::pair getDefinitions(const IdGroup& id_group) const; - //! Same as iterDomainGroupDefinitions but for uses instead of definitions - //! TODO: Change name to start with get - std::pair iterDomainGroupUses(IdGroup id_group) const; + //! Same as iterDomainGroupDefinitions but for uses instead of + //! definitions + //! + //! TODO-NM: ExprGroups is a real container. Consider returning a reference + std::pair getUses(const IdGroup& id_group) const; std::string toString() const; @@ -137,18 +149,19 @@ class TORCH_CUDA_CU_API IdGraph { // Returns entry in unique_definitions_ for provided group in provided mode, // otherwise errors if no entry is found. - ExprGroups uniqueDefinitions(IdGroup group) const; + const ExprGroups& getUniqueDefinitions(const IdGroup& group) const; // Returns entry in unique_uses_ for provided group in provided mode, // otherwise errors if no entry is found. - ExprGroups uniqueUses(IdGroup group) const; + const ExprGroups& getUniqueUses(const IdGroup& group) const; - std::unordered_map& uniqueUses() { - return unique_uses_; + public: + void addUniqueUses(const IdGroup& id_group, const ExprGroup& uses) { + unique_uses_.at(id_group).pushBack(uses); } - std::unordered_map& uniqueDefinitions() { - return unique_definitions_; + void addUniqueDefinitions(const IdGroup& id_group, const ExprGroup& defs) { + unique_definitions_.at(id_group).pushBack(defs); } // Set id0 and id1 to mapped in disjointIdsSet[mode], attempt to propagate @@ -160,23 +173,6 @@ class TORCH_CUDA_CU_API IdGraph { // be the only call in IdGraph to mapThroughExpr void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); - // Map expr0 and expr1 with eachother, update unique_definitions_ unique_uses_ - // TODO: Make this variant hidden? - void mapExprs(Expr* expr0, Expr* expr1); - - // Checks if expr's are considered "the same" where sameness inputs and - // outputs in the same position across expressions map with provided - // MappingMode. If the expressions are determined the same then - // if forward - // will map outputs - // else - // will map inputs - // in the provided mode. - // Returns if expressions were mapped through. - // - // TODO: Make this private - bool mapThroughExpr(Expr* first, Expr* second, bool forward); - // Map through loop swizzles, as input/output IterDomains are exact, only the // order they're traversed differs. void mapThroughLoopSwizzles(); @@ -202,12 +198,29 @@ class TORCH_CUDA_CU_API IdGraph { // Removes the provided expression group from unique_definitions_ and // unique_uses_ breaking traversal through them. - void eraseExprGroup(ExprGroup expr_group); + void eraseExprGroup(const ExprGroup& expr_group); // Returns if the expression group has an input id group that matches an // output id group. This means traversing on this expression doesn't actually // do anything. - bool isTrivialExprGroup(ExprGroup expr_group) const; + bool isTrivialExprGroup(const ExprGroup& expr_group) const; + + private: + // Map expr0 and expr1 with eachother, update unique_definitions_ unique_uses_ + // TODO: Make this variant hidden? + void mapExprs(Expr* expr0, Expr* expr1); + + // Checks if expr's are considered "the same" where sameness inputs and + // outputs in the same position across expressions map with provided + // MappingMode. If the expressions are determined the same then + // if forward + // will map outputs + // else + // will map inputs + // in the provided mode. + // Returns if expressions were mapped through. + // + bool mapThroughExpr(Expr* first, Expr* second, bool forward); private: // If propagate_exprs_ = false, then mapThroughExpr will not be called as a @@ -217,8 +230,6 @@ class TORCH_CUDA_CU_API IdGraph { // Note: For the second sentence of above... mapThroughExpr can call mapIds // which could in return call mapThoughExpr again, but propagate_exprs_ as // mentioned above prevents that from happening. - // - // TODO: Should propagate_exprs_ be a const member? bool propagate_exprs_ = true; // Keeps a disjoint set entry for all IterDomain for all mapping mode types. @@ -231,13 +242,6 @@ class TORCH_CUDA_CU_API IdGraph { // Keeps a disjoint set entry for all Expressions for all mapping mode types. DisjointSets disjoint_exprs_; - // Hold a set of IterDomains that are considered view rfactor ids. This - // identification is particularly important to understand if split operations - // are divisible or not. - // - // TODO: This should just be in IterDomainGraphs, not here. - std::unordered_set view_rfactor_ids_; - std::unordered_map unique_definitions_; std::unordered_map unique_uses_; diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 7d2a352aa00..288b5992116 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -400,7 +400,7 @@ Expr* IterDomainGraphs::addReplayAs( // Update uses of the inputs in the graphs for (auto inp_id : ir_utils::filterByType(replay->inputs())) { auto inp_group = idGraph(mode).toGroup(inp_id); - idGraph(mode).uniqueUses().at(inp_group).pushBack(replay_group); + idGraph(mode).addUniqueUses(inp_group, replay_group); } // Propagate through all the uses of the iter domain groups of the inputs @@ -409,7 +409,7 @@ Expr* IterDomainGraphs::addReplayAs( // Gather all use expressions from inputs VectorOfUniqueEntries representative_uses; for (auto inp : new_inputs) { - auto uses_pair = graph.iterDomainGroupUses(graph.toGroup(inp)); + auto uses_pair = graph.getUses(graph.toGroup(inp)); if (uses_pair.second) { for (auto use_group : uses_pair.first) { representative_uses.pushBack(use_group->front()); @@ -534,7 +534,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( } else { // Update unique uses of existing input ids auto inp_group = graph.toGroup(inp_id); - graph.uniqueUses()[inp_group].pushBack(replay_group); + graph.addUniqueUses(inp_group, replay_group); } } @@ -547,7 +547,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( // out_id is already initialized, add the replay as a unique definition // of its group auto out_group = graph.toGroup(out_id); - graph.uniqueDefinitions()[out_group].pushBack(replay_group); + graph.addUniqueDefinitions(out_group, replay_group); } } @@ -558,7 +558,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( // Forward VectorOfUniqueEntries representative_uses; for (auto in : ir_utils::filterByType(replay->inputs())) { - auto uses_pair = graph.iterDomainGroupUses(graph.toGroup(in)); + auto uses_pair = graph.getUses(graph.toGroup(in)); if (uses_pair.second) { for (auto use_group : uses_pair.first) { if (use_group == replay_group) { @@ -576,7 +576,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( // Backwards VectorOfUniqueEntries representative_defs; for (auto out : ir_utils::filterByType(replay->outputs())) { - auto defs_pair = graph.iterDomainGroupDefinitions(graph.toGroup(out)); + auto defs_pair = graph.getDefinitions(graph.toGroup(out)); if (defs_pair.second) { for (auto def_group : defs_pair.first) { if (def_group == replay_group) { @@ -1280,7 +1280,7 @@ std::unordered_map IterDomainGraphs:: ExprGroups non_promoted_input_uses; for (auto iel_group : promoted_input_groups.intersect(input_groups)) { non_promoted_input_uses.pushBack( - intersection_exact_loop_graph.uniqueUses(iel_group)); + intersection_exact_loop_graph.getUniqueUses(iel_group)); } for (auto iel_use_group : non_promoted_input_uses) { @@ -1363,7 +1363,7 @@ std::unordered_map computeCoveredGroups( for (auto id_group : graph.disjointIdSets().disjointSets()) { // Initialize inputs - if (graph.uniqueDefinitions(id_group).empty()) { + if (graph.getUniqueDefinitions(id_group).empty()) { covered_ids[id_group] = {id_group}; } @@ -1695,7 +1695,7 @@ std::unordered_map IterDomainGraphs:: auto inp_exact_group = idGraph(IdMappingMode::EXACT).toGroup(inp_id); promoted_input_groups.push_back(inp_exact_group); promoted_input_uses.pushBack( - idGraph(IdMappingMode::EXACT).uniqueUses(inp_exact_group)); + idGraph(IdMappingMode::EXACT).getUniqueUses(inp_exact_group)); } // Check every use to see if it matches diff --git a/csrc/id_model/to_string.cpp b/csrc/id_model/to_string.cpp index 0c828f814a5..972729f6950 100644 --- a/csrc/id_model/to_string.cpp +++ b/csrc/id_model/to_string.cpp @@ -299,7 +299,7 @@ std::string definitionsString( bool with_ptr) { ExprGroups defs; for (auto id_group : id_graph.disjointIdSets().disjointSets()) { - auto definition_pair = id_graph.iterDomainGroupDefinitions(id_group); + auto definition_pair = id_graph.getDefinitions(id_group); if (definition_pair.second) { for (auto expr_group : definition_pair.first) { defs.pushBack(expr_group); @@ -315,7 +315,7 @@ std::string usesString( bool with_ptr) { ExprGroups uses; for (auto id_group : id_graph.disjointIdSets().disjointSets()) { - auto definition_pair = id_graph.iterDomainGroupUses(id_group); + auto definition_pair = id_graph.getUses(id_group); if (definition_pair.second) { for (auto expr_group : definition_pair.first) { uses.pushBack(expr_group); diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index 910ba31de67..f0f5bcd67ae 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -31,7 +31,7 @@ void IdGraphVisitor::traverse() { graph().disjointExprSets().disjointSets().end()); } else { for (auto id_group : all_ids) { - for (auto def : graph().uniqueDefinitions(id_group)) { + for (auto def : graph().getUniqueDefinitions(id_group)) { if (all_exprs.has(def)) { continue; } @@ -88,7 +88,7 @@ void IdGraphVisitor::traverse() { }; auto is_id_ready = [&](IdGroup id_group) { - auto unique_defs = graph().uniqueDefinitions(id_group); + auto unique_defs = graph().getUniqueDefinitions(id_group); return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { return expr_group->empty() || visited_exprs.has(expr_group) || @@ -142,7 +142,7 @@ void IdGraphVisitor::traverse() { visited_ids.pushBack(current_id_group); if (!terminating_outputs.has(current_id_group)) { - auto uses_pair = graph().iterDomainGroupUses(current_id_group); + auto uses_pair = graph().getUses(current_id_group); if (uses_pair.second) { to_visit_exprs.pushBack(uses_pair.first); } From a98883d90013b6fa34c95bedb45f206b3be7af69 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 30 Aug 2023 20:41:31 -0700 Subject: [PATCH 045/178] cleanup --- csrc/disjoint_set.h | 2 +- csrc/id_model/id_graph.cpp | 90 ++++++++++++++++++++------------------ 2 files changed, 49 insertions(+), 43 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 55514948f16..e0a2c082809 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -109,7 +109,7 @@ class VectorOfUniqueEntries { // Returns a new VectorOfUniqueEntries with entries that are in both this and // other, order is preserved as this. VectorOfUniqueEntries intersect( - const VectorOfUniqueEntries& other) { + const VectorOfUniqueEntries& other) const { VectorOfUniqueEntries intersection; for (auto entry : vector()) { if (other.has(entry)) { diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index f864551403a..938ebfae861 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -16,9 +16,8 @@ IdGraph::IdGraph(const IdGraph& other) disjoint_exprs_(other.disjoint_exprs_), unique_definitions_(), unique_uses_() { - for (auto orig_unique_def_pair : other.unique_definitions_) { - auto orig_id_group = orig_unique_def_pair.first; - auto orig_expr_groups = orig_unique_def_pair.second; + for (const auto& [orig_id_group, orig_expr_groups] : + other.unique_definitions_) { auto new_id_group = toGroup(orig_id_group->front()); ExprGroups new_expr_groups; @@ -29,9 +28,7 @@ IdGraph::IdGraph(const IdGraph& other) unique_definitions_[new_id_group] = new_expr_groups; } - for (auto orig_unique_use_pair : other.unique_uses_) { - auto orig_id_group = orig_unique_use_pair.first; - auto orig_expr_groups = orig_unique_use_pair.second; + for (const auto& [orig_id_group, orig_expr_groups] : other.unique_uses_) { auto new_id_group = toGroup(orig_id_group->front()); ExprGroups new_expr_groups; @@ -290,7 +287,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // domain coming back from any of its uses. ExprGroups min_groups; - auto uses_pair = getUses(id_group); + std::pair uses_pair = getUses(id_group); if (!uses_pair.second) { // No expressions required for this iter domain, it must be a // terminating output. @@ -384,18 +381,17 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) TORCH_INTERNAL_ASSERT( something_was_processed || - (to_visit_ids.size() == 0 && to_visit_exprs.size() == 0), + (to_visit_ids.empty() && to_visit_exprs.empty()), "Infinite loop entered."); } // We want to traverse the expressions registered in required_ind_exprs_ids, // let's create a strict "uses path" std::unordered_map uses_path; - for (auto entry : required_ind_exprs_ids) { - auto id = entry.first; - auto traverse_exprs = entry.second; - auto all_uses = getUses(id); - if (all_uses.second) { + for (const auto& entry : required_ind_exprs_ids) { + const IdGroup& id = entry.first; + const ExprGroups& traverse_exprs = entry.second; + if (auto all_uses = getUses(id); all_uses.second) { uses_path[id] = traverse_exprs.intersect(all_uses.first); } else { uses_path[id] = {}; @@ -511,33 +507,29 @@ std::unordered_map> IdGraph:: std::pair IdGraph::getDefinitions( const IdGroup& id_group) const { - auto null_return = std::make_pair(ExprGroups(), false); - - if (id_group == nullptr) { - return null_return; + if (!id_group) { + return {{}, false}; } - auto definitions_it = unique_definitions_.find(id_group); - if (definitions_it == unique_definitions_.end()) { - return null_return; + if (auto definitions_it = unique_definitions_.find(id_group); + definitions_it != unique_definitions_.end()) { + return std::make_pair(definitions_it->second, true); + } else { + return {{}, false}; } - - return std::make_pair(definitions_it->second, true); } std::pair IdGraph::getUses(const IdGroup& id_group) const { - auto null_return = std::make_pair(ExprGroups(), false); - - if (id_group == nullptr) { - return null_return; + if (!id_group) { + return {{}, false}; } - auto uses_it = unique_uses_.find(id_group); - if (uses_it == unique_uses_.end()) { - return null_return; + if (auto uses_it = unique_uses_.find(id_group); + uses_it != unique_uses_.end()) { + return std::make_pair(uses_it->second, true); + } else { + return {{}, false}; } - - return std::make_pair(uses_it->second, true); } std::string IdGraph::toString() const { @@ -613,6 +605,8 @@ bool IdGraph::transformAtributesMatch(Expr* first, Expr* second) { } } + // TODO: Resize properties + return true; } @@ -628,7 +622,10 @@ void IdGraph::initializeId( disjointExprSets().initializeSet(def).first->second; def_groups.pushBack(expr_set); } - unique_definitions_[id_disjoint_set] = def_groups; + // TODO-NM: def_groups can be empty. Should it be still mapped? + // TODO-NM: Can this be overwritten? + TORCH_INTERNAL_ASSERT( + unique_definitions_.emplace(id_disjoint_set, def_groups).second); ExprGroups use_groups; for (auto use : uses) { @@ -636,7 +633,10 @@ void IdGraph::initializeId( disjointExprSets().initializeSet(use).first->second; use_groups.pushBack(expr_set); } - unique_uses_[id_disjoint_set] = use_groups; + // TODO-NM: use_groups can be empty. Should it be still mapped? + // TODO-NM: Can this be overwritten? + TORCH_INTERNAL_ASSERT( + unique_uses_.emplace(id_disjoint_set, use_groups).second); } bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { @@ -713,6 +713,8 @@ bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { // TODO: For now we're using same as, however we could know what val's are // exactly the same given the exact map. We might want to pipe that // information through to here. + + // TODO-NM: Should this be transformAtributesMatch? if (first->isA()) { if (!first->as()->leftExpand()->sameAs( second->as()->leftExpand())) { @@ -810,15 +812,19 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { } void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { - if (exprsMap(expr0, expr1, forward)) { - if (propagate_exprs_) { - mapExprs(expr0, expr1); - mapThroughExpr(expr0, expr1, forward); - } else if ( - inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) && - outputGroups(toGroup(expr0)) == outputGroups(toGroup(expr1))) { - mapExprs(expr0, expr1); - } + if (!exprsMap(expr0, expr1, forward)) { + return; + } + + // Expr inputs are mapped. If propagate_exprs_ is true, map the + // exprs and outputs + if (propagate_exprs_) { + mapExprs(expr0, expr1); + mapThroughExpr(expr0, expr1, forward); + } else if ( + inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) && + outputGroups(toGroup(expr0)) == outputGroups(toGroup(expr1))) { + mapExprs(expr0, expr1); } } From 51243b5b450823db311f0a85ce18c266380558cb Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 30 Aug 2023 21:25:13 -0700 Subject: [PATCH 046/178] cleanup --- csrc/id_model/id_graph.cpp | 125 ++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 71 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index 938ebfae861..fc68cccf1e8 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -11,6 +11,11 @@ namespace nvfuser { +namespace { +using UnorderedSetOfExprGroup = std::unordered_set; +using DequeOfExprGroup = std::deque; +} // namespace + IdGraph::IdGraph(const IdGraph& other) : disjoint_ids_(other.disjoint_ids_), disjoint_exprs_(other.disjoint_exprs_), @@ -115,29 +120,26 @@ std::vector IdGraph::inputGroups(const ExprGroup& expr) const { } ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { - ExprGroups to_visit; - for (auto of_id_group : of) { - auto group_uses_pair = getUses(of_id_group); - if (group_uses_pair.second) { - to_visit.pushBack(group_uses_pair.first); - } - } - - ExprGroups visited; - while (to_visit.size() > 0) { - auto current_expr = to_visit.popFront(); - visited.pushBack(current_expr); - auto output_ids = outputGroups(current_expr); - for (auto output_id : output_ids) { - auto group_uses_pair = getUses(output_id); - if (!group_uses_pair.second) { - continue; - } - for (auto group_use : group_uses_pair.first) { - if (visited.has(group_use)) { - continue; + DequeOfExprGroup to_visit; + for (const IdGroup& of_id_group : of) { + if (const auto& [group_uses, found] = getUses(of_id_group); found) { + to_visit.insert(to_visit.end(), group_uses.begin(), group_uses.end()); + } + } + + UnorderedSetOfExprGroup visited; + while (!to_visit.empty()) { + ExprGroup current_expr = to_visit.front(); + to_visit.pop_front(); + visited.emplace(current_expr); + for (const IdGroup& output_id : outputGroups(current_expr)) { + if (const auto& [group_uses, found] = getUses(output_id); found) { + for (const ExprGroup& group_use : group_uses) { + if (visited.count(group_use)) { + continue; + } + to_visit.push_back(group_use); } - to_visit.pushBack(group_use); } } } @@ -146,29 +148,26 @@ ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { } ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { - ExprGroups to_visit; - for (auto of_id_group : of) { - auto group_defs_pair = getDefinitions(of_id_group); - if (group_defs_pair.second) { - to_visit.pushBack(group_defs_pair.first); - } - } - - ExprGroups visited; - while (to_visit.size() > 0) { - auto current_expr = to_visit.popFront(); - visited.pushBack(current_expr); - auto input_ids = inputGroups(current_expr); - for (auto input_id : input_ids) { - auto group_defs_pair = getDefinitions(input_id); - if (!group_defs_pair.second) { - continue; - } - for (auto group_def : group_defs_pair.first) { - if (visited.has(group_def)) { - continue; + DequeOfExprGroup to_visit; + for (const IdGroup& of_id_group : of) { + if (const auto& [group_defs, found] = getDefinitions(of_id_group); found) { + to_visit.insert(to_visit.end(), group_defs.begin(), group_defs.end()); + } + } + + UnorderedSetOfExprGroup visited; + while (!to_visit.empty()) { + ExprGroup current_expr = to_visit.front(); + to_visit.pop_front(); + visited.emplace(current_expr); + for (const IdGroup& input_id : inputGroups(current_expr)) { + if (const auto& [group_defs, found] = getDefinitions(input_id); found) { + for (const ExprGroup& group_def : group_defs) { + if (visited.count(group_def)) { + continue; + } + to_visit.push_back(group_def); } - to_visit.pushBack(group_def); } } } @@ -228,33 +227,16 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // Return if all output IterDomain groups of an expression group have // already been visited - auto outputsVisited = [&](ExprGroup expr) { - for (auto id_group : outputGroups(expr)) { - if (required_ind_exprs_ids.find(id_group) == - required_ind_exprs_ids.end()) { - return false; - } - } - return true; - }; - -#if 0 - auto allIdUsesVisisted = [&](IdGroup id) { - auto uses_pair = iterDomainGroupUses(id); - if (!uses_pair.second) { - return true; - } - for (auto use_group : uses_pair.first) { - if (all_exprs.has(use_group)) { - if (required_ind_exprs_exprs.find(use_group) == - required_ind_exprs_exprs.end()) { - return false; - } - } - } - return true; + auto outputsVisited = [&](ExprGroup expr_group) { + auto output_groups = outputGroups(expr_group); + return std::all_of( + output_groups.begin(), + output_groups.end(), + [&](const IdGroup& output_group) { + return required_ind_exprs_ids.find(output_group) != + required_ind_exprs_ids.end(); + }); }; -#endif // Returns all expression groups in required_ind_exprs_ids of outputs auto requiredExprsOutputs = [&](ExprGroup expr_group) -> ExprGroups { @@ -614,7 +596,8 @@ void IdGraph::initializeId( IterDomain* id, const VectorOfUniqueEntries& definitions, const VectorOfUniqueEntries& uses) { - auto id_disjoint_set = disjointIdSets().initializeSet(id).first->second; + const IdGroup& id_disjoint_set = + disjointIdSets().initializeSet(id).first->second; ExprGroups def_groups; for (auto def : definitions) { From a0d8c43f5230721567776a7e02e3b33a022a26e5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 30 Aug 2023 21:41:16 -0700 Subject: [PATCH 047/178] comment --- csrc/id_model/id_graph.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index 35b2c025ffb..bca5d27c229 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -242,6 +242,13 @@ class TORCH_CUDA_CU_API IdGraph { // Keeps a disjoint set entry for all Expressions for all mapping mode types. DisjointSets disjoint_exprs_; + // Definitions of IdGroup. There can be multiple definitions due to + // replays. + // TODO-NM: IdGroup by a new definition ExprGroup would not be used + // by existing uses. Does it make sense to represent uses and defs + // this way? In other words, there is a traversal path from a + // definition ExprGroup to an IdGroup and its use ExprGroup, but + // that does't guarantee the path actually exist std::unordered_map unique_definitions_; std::unordered_map unique_uses_; From 397edcc4778b579b38de25a682a907045a832822 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 30 Aug 2023 21:55:13 -0700 Subject: [PATCH 048/178] cleanup mapIds --- csrc/id_model/id_graph.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index fc68cccf1e8..dae4e075a6b 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -742,12 +742,12 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { // Definitions and uses are based on the groups of id0 and id1, don't merge // them into a single group until we grab all definitions and uses for later // processing. - auto orig_id_group0 = toGroup(id0); - auto orig_id_group1 = toGroup(id1); - ExprGroups orig_defs0 = getUniqueDefinitions(orig_id_group0); - ExprGroups orig_defs1 = getUniqueDefinitions(orig_id_group1); - ExprGroups orig_uses0 = getUniqueUses(orig_id_group0); - ExprGroups orig_uses1 = getUniqueUses(orig_id_group1); + IdGroup orig_id_group0 = toGroup(id0); + IdGroup orig_id_group1 = toGroup(id1); + const ExprGroups& orig_defs0 = getUniqueDefinitions(orig_id_group0); + const ExprGroups& orig_defs1 = getUniqueDefinitions(orig_id_group1); + const ExprGroups& orig_uses0 = getUniqueUses(orig_id_group0); + const ExprGroups& orig_uses1 = getUniqueUses(orig_id_group1); // Map the iter domains together before we traverse across definitions and // uses. Traversing definitions and uses could use the new property of id0 and @@ -755,11 +755,6 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { disjointIdSets().mapEntries(id0, id1); auto new_id_group = toGroup(id0); - unique_definitions_.erase(orig_id_group0); - unique_definitions_.erase(orig_id_group1); - unique_uses_.erase(orig_id_group0); - unique_uses_.erase(orig_id_group1); - unique_definitions_[new_id_group] = orig_defs0.computeUnion(orig_defs1); unique_uses_[new_id_group] = orig_uses0.computeUnion(orig_uses1); @@ -792,6 +787,11 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { } } } + + unique_definitions_.erase(orig_id_group0); + unique_definitions_.erase(orig_id_group1); + unique_uses_.erase(orig_id_group0); + unique_uses_.erase(orig_id_group1); } void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { From fd9752511a506a3002cbf0c1a461fed9e5817fb2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 30 Aug 2023 22:18:46 -0700 Subject: [PATCH 049/178] cleanup --- csrc/id_model/id_graph.cpp | 4 ++-- csrc/id_model/id_graph.h | 17 ++++------------- csrc/id_model/id_graphs.cpp | 12 ++++-------- csrc/id_model/id_graphs.h | 5 +++-- csrc/id_model/to_string.cpp | 1 + 5 files changed, 14 insertions(+), 25 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index dae4e075a6b..8e0300f0a11 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -801,7 +801,7 @@ void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { // Expr inputs are mapped. If propagate_exprs_ is true, map the // exprs and outputs - if (propagate_exprs_) { + if (propagate_through_exprs_) { mapExprs(expr0, expr1); mapThroughExpr(expr0, expr1, forward); } else if ( @@ -866,7 +866,7 @@ bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { } TORCH_INTERNAL_ASSERT( - propagate_exprs_, + propagate_through_exprs_, "Asked to propagate expression mappings on a graph that has propagate_exprs_ disabled."); auto first_ids = ir_utils::filterByType( diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index bca5d27c229..4f0d174e4e2 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -31,6 +31,8 @@ class TORCH_CUDA_CU_API IdGraph { IdGraph& operator=(const IdGraph& other); IdGraph& operator=(IdGraph&& other) = default; + IdGraph(bool propagate_through_exprs) : propagate_through_exprs_(propagate_through_exprs) {} + // Returns the disjoint IterDomain set. const DisjointSets& disjointIdSets() const { return disjoint_ids_; @@ -185,17 +187,6 @@ class TORCH_CUDA_CU_API IdGraph { // mappings from IdGraph::isTrivialExpr void removeTrivialExprs(); - // See comment on propagate_expr_ member bool for description - // Once disabled this can't be reenabled on a graph. If it's reenabled it's - // hard to predict how mappings will propagate, which will be triggered on the - // next mapping. To support changing this flag, we should likely run through - // all expressions currently registered and propagate through all of them on - // switch. Then once enabled it couldn't be redisabled because we don't record - // the history of mapId calls. - void disableExprPropagation() { - propagate_exprs_ = false; - } - // Removes the provided expression group from unique_definitions_ and // unique_uses_ breaking traversal through them. void eraseExprGroup(const ExprGroup& expr_group); @@ -223,14 +214,14 @@ class TORCH_CUDA_CU_API IdGraph { bool mapThroughExpr(Expr* first, Expr* second, bool forward); private: - // If propagate_exprs_ = false, then mapThroughExpr will not be called as a + // If propagate_through_exprs_ = false, then mapThroughExpr will not be called as a // consequence of calling mapIds. As well as mapThroughExpr will not be called // (again) as a result of calling mapThroughExpr. // // Note: For the second sentence of above... mapThroughExpr can call mapIds // which could in return call mapThoughExpr again, but propagate_exprs_ as // mentioned above prevents that from happening. - bool propagate_exprs_ = true; + bool propagate_through_exprs_ = true; // Keeps a disjoint set entry for all IterDomain for all mapping mode types. // diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 288b5992116..662fe52c8f5 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -627,8 +627,8 @@ IterDomain* IterDomainGraphs::cloneIterDomain(IterDomain* id) { return id_copy; } -IdGraph IterDomainGraphs::initializeIdGraph() { - IdGraph id_graph; +IdGraph IterDomainGraphs::initializeIdGraph(bool propagate_through_exprs) { + IdGraph id_graph(propagate_through_exprs); for (auto definition_entry : id_definitions_) { auto id = definition_entry.first; @@ -1065,10 +1065,7 @@ IdGraph IterDomainGraphs::buildIntersection( const IdGraph& graph0, const IdGraph& graph1, bool propagate_exprs) { - auto intersection = initializeIdGraph(); - if (!propagate_exprs) { - intersection.disableExprPropagation(); - } + auto intersection = initializeIdGraph(propagate_exprs); for (auto group0 : graph0.disjointIdSets().disjointSets()) { auto set_size = group0->size(); for (auto id0_i : c10::irange(set_size)) { @@ -1087,10 +1084,9 @@ IdGraph IterDomainGraphs::buildIntersection( } void IterDomainGraphs::initializeLoopMap(StatefulLoweringInfo& info) { - idGraph(IdMappingMode::LOOP) = initializeIdGraph(); // See Indexing20 example for why we shouldn't propagate when generating loop // groups - idGraph(IdMappingMode::LOOP).disableExprPropagation(); + idGraph(IdMappingMode::LOOP) = initializeIdGraph(false); // Make sure this is called in a deterministic order. Build all inlined // relationships in loop graph. diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_graphs.h index 0d144c3ee52..fca0bf1c842 100644 --- a/csrc/id_model/id_graphs.h +++ b/csrc/id_model/id_graphs.h @@ -185,7 +185,7 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // Iterates over all IterDomains in id_definitions_ and calls initializeID on // a new IdGraph and returns it. - IdGraph initializeIdGraph(); + IdGraph initializeIdGraph(bool propagate_through_exprs = true); // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs // and first output of expr @@ -280,7 +280,8 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { std::unordered_map> id_uses_; // Make sure we don't blindly use definitions as we don't want to grab - // transformations before a tensor view's root domain. + // transformations before a tensor view's root domain. There can be + // multiple definitions due to replays. std::unordered_map> id_definitions_; // Debug information to hold if a self mapping in a TensorView is found. diff --git a/csrc/id_model/to_string.cpp b/csrc/id_model/to_string.cpp index 972729f6950..7bf3b619ea8 100644 --- a/csrc/id_model/to_string.cpp +++ b/csrc/id_model/to_string.cpp @@ -161,6 +161,7 @@ std::string toInlineString(const std::vector& id_groups) { auto pos = group_name_info[i].second; ss << toString(id_groups[pos]); } + ss << "}"; return ss.str(); } From 177d40c7c232e305c4c9d134446d5c08aa8539d6 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 31 Aug 2023 07:25:00 -0700 Subject: [PATCH 050/178] cleanup --- csrc/id_model/id_graphs.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 662fe52c8f5..d0b208a3c08 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -942,7 +942,7 @@ void IterDomainGraphs::build( }); auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs); - if (additional_tvs.size() > 0) { + if (!additional_tvs.empty()) { std::unordered_set all_added_tvs( all_tvs.begin(), all_tvs.end()); for (auto additional_tv : additional_tvs) { From 535ce1dad40e5e0124011d83db9f6a811532e742 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 31 Aug 2023 07:31:20 -0700 Subject: [PATCH 051/178] Enable tests --- test/test_gpu_indexing.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index 7192b4daeda..31b806f2fc2 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -791,7 +791,6 @@ TEST_F(NVFuserTest, FusionIndexing17_CUDA) { &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -#if 0 // TODO: Finish and enable test TEST_F(NVFuserTest, FusionIndexing18_CUDA) { Fusion fusion; @@ -828,10 +827,8 @@ TEST_F(NVFuserTest, FusionIndexing18_CUDA) { // ComputeAtMap ca_map(&fusion); // std::cout << ca_map.idGraph().loopNodes().toString() << std::endl; } -#endif // TODO: Finish and enable test -#if 0 // // Create a case where we're missing a valid concrete id so the compute at map // processing will fail. We need to be able to create the concrete ID not just @@ -882,9 +879,7 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { fusion.print(); fusion.printKernel(); } -#endif -#if 0 // TODO: Finish and enable test // // Progressive loop promotion. producer gets promoted in consumer, consumer is @@ -935,7 +930,6 @@ TEST_F(NVFuserTest, FusionIndexing20_CUDA) { fusion.printKernel(); } -#endif // Repro for issue #1873 TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { @@ -1020,7 +1014,6 @@ TEST_F(NVFuserTest, FusionMultiPromotion_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -#if 0 // TODO: Finish and enable test. // Broadcast and concretize same domain in two different ways and try to merge // their loops remains unsupported. @@ -1064,7 +1057,6 @@ TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } -#endif // TODO: All the above tests are merges followed by splits, we should make some // more complex examples even though merging then spliting is the most likely From b80e871107fb9f26583736ba3bd25dc24bcc3b47 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 1 Sep 2023 07:41:16 -0700 Subject: [PATCH 052/178] cleanup --- csrc/id_model/id_graph.cpp | 43 ++++++++++++-------------------------- csrc/id_model/id_graph.h | 9 ++++---- 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index 8e0300f0a11..fc8bc950ba2 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -194,24 +194,19 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) IdGroups all_id_groups; for (const ExprGroup& expr_group : all_exprs) { - std::vector inp_groups = inputGroups(expr_group); - std::vector out_groups = outputGroups(expr_group); - if (!IdGroups(inp_groups).intersect(IdGroups(out_groups)).empty()) { + if (isTrivialExprGroup(expr_group)) { // Expression is just a loop to its current group, ignore continue; } - all_id_groups.pushBack(inp_groups); + std::vector inp_groups = inputGroups(expr_group); + std::vector out_groups = outputGroups(expr_group); - if (!inp_groups.empty()) { - not_outputs.pushBack(inp_groups); - } + all_id_groups.pushBack(inp_groups); + not_outputs.pushBack(inp_groups); all_id_groups.pushBack(out_groups); - - if (!out_groups.empty()) { - not_inputs.pushBack(out_groups); - } + not_inputs.pushBack(out_groups); } terminating_inputs = all_id_groups.subtract(not_inputs); terminating_outputs = all_id_groups.subtract(not_outputs); @@ -643,26 +638,14 @@ bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { first->toString(), second->toString()); + // TODO-MN: Is this equivalent as + // inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) ? { - std::vector> zipped_ids; - - std::transform( - first_ids.begin(), - first_ids.end(), - second_ids.begin(), - std::back_inserter(zipped_ids), - [](IterDomain* first, IterDomain* second) { - return std::make_pair(first, second); - }); - - if (std::any_of( - zipped_ids.begin(), - zipped_ids.end(), - [&](std::pair id_pair) { - return !disjointIdSets().permissiveAreMapped( - id_pair.first, id_pair.second); - })) { - return false; + for (const auto i : c10::irange(first_ids.size())) { + if (!disjointIdSets().permissiveAreMapped( + first_ids.at(i), second_ids.at(i))) { + return false; + } } } diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index 4f0d174e4e2..0e822d7451b 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -31,7 +31,8 @@ class TORCH_CUDA_CU_API IdGraph { IdGraph& operator=(const IdGraph& other); IdGraph& operator=(IdGraph&& other) = default; - IdGraph(bool propagate_through_exprs) : propagate_through_exprs_(propagate_through_exprs) {} + IdGraph(bool propagate_through_exprs) + : propagate_through_exprs_(propagate_through_exprs) {} // Returns the disjoint IterDomain set. const DisjointSets& disjointIdSets() const { @@ -214,9 +215,9 @@ class TORCH_CUDA_CU_API IdGraph { bool mapThroughExpr(Expr* first, Expr* second, bool forward); private: - // If propagate_through_exprs_ = false, then mapThroughExpr will not be called as a - // consequence of calling mapIds. As well as mapThroughExpr will not be called - // (again) as a result of calling mapThroughExpr. + // If propagate_through_exprs_ = false, then mapThroughExpr will not be called + // as a consequence of calling mapIds. As well as mapThroughExpr will not be + // called (again) as a result of calling mapThroughExpr. // // Note: For the second sentence of above... mapThroughExpr can call mapIds // which could in return call mapThoughExpr again, but propagate_exprs_ as From bb4968bb0e57bce65b3320aa376e2b8fb1662cb0 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 1 Sep 2023 07:49:39 -0700 Subject: [PATCH 053/178] comment --- csrc/id_model/id_graph.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index fc8bc950ba2..934593dca13 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -314,6 +314,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) ExprGroup currently_visiting_exprs = to_visit_exprs.popFront(); if (required_ind_exprs_exprs.find(currently_visiting_exprs) != required_ind_exprs_exprs.end()) { + // currently_visiting_exprs is already visited continue; } if (processExprGroup(currently_visiting_exprs)) { From d6504f88b17d89cc5a8d67e968d349bb88a4d19a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 22 Sep 2023 17:09:33 -0700 Subject: [PATCH 054/178] debug output --- csrc/id_model/id_graphs.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index d0b208a3c08..aa91c164064 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1322,6 +1322,16 @@ std::unordered_map IterDomainGraphs:: } } } + + std::cerr << "Inline promotion done\n"; + + std::stringstream ss; + ss << "Inline promotion map\n"; + for (const auto& [iel_group, promoted_id] : iel_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " << promoted_id->name() << std::endl; + } + std::cerr << ss.str(); + return iel_promotion_map; } @@ -1754,6 +1764,11 @@ std::unordered_map IterDomainGraphs:: } } + std::cerr << "Loop promotion map:\n"; + for (const auto& [iel_group, id]: iel_promotion_map) { + std::cerr << nvfuser::toString(iel_group) << " -> " << id->name() << std::endl; + } + return iel_promotion_map; } From 60407b238e19a02885dd5f189cc49c1d46825104 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 4 Oct 2023 14:32:15 -0700 Subject: [PATCH 055/178] Remove --no-allow-shlib-undefined as our dependencies may not have correct dependencies, e.g., c10 --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ee5b7ee9daf..09e7699e129 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,6 @@ cmake_minimum_required(VERSION 3.18 FATAL_ERROR) project(nvfuser) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-allow-shlib-undefined") # compensating the lack of PROJECT_IS_TOP_LEVEL for older cmake version if(CMAKE_VERSION VERSION_LESS 3.21) From 24dc7589eebf15132e7041a08fb9ccef5d938cf7 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 5 Oct 2023 16:55:12 -0700 Subject: [PATCH 056/178] Fill the final loop promotion map (#1028) This PR adds structural validations of loop graphs and promotions of one of the most important tests, `Indexing19`. In order to do so, it introduces a minimal set of changes as follows: - Enable building loop graphs and promotion mappings even for `Fusion` - Populate a map from loop ID groups to their promotion domains, which is temporarily done at the very end of `buildLoopPromotionMap`. Some cleanups are needed, but I will iteratively work on cleanups as well as functional changes. This PR severs as the baseline for forthcoming PRs as well as making sure my understanding is correct. --- csrc/id_model/id_graphs.cpp | 76 ++++++++++++++++++++++++++++--- csrc/id_model/id_graphs.h | 4 ++ csrc/id_model/utils.h | 55 +++++++++++++++++++++++ test/test_gpu_indexing.cpp | 90 ++++++++++++++++++++++++++++++++++++- 4 files changed, 217 insertions(+), 8 deletions(-) create mode 100644 csrc/id_model/utils.h diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 54c89966986..3c95724bee9 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -971,7 +972,8 @@ void IterDomainGraphs::build( idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); // Only build loop map during lowering - if (FusionGuard::getCurFusion()->isA()) { + // TODO: make this configurable + if (true || FusionGuard::getCurFusion()->isA()) { validatePTypes(all_tvs); StatefulLoweringInfo info = buildInfo( @@ -1320,15 +1322,13 @@ std::unordered_map IterDomainGraphs:: } } - std::cerr << "Inline promotion done\n"; - std::stringstream ss; ss << "Inline promotion map\n"; for (const auto& [iel_group, promoted_id] : iel_promotion_map) { ss << "\t" << nvfuser::toString(iel_group) << " -> " << promoted_id->name() << std::endl; } - std::cerr << ss.str(); + VERBOSE() << ss.str(); return iel_promotion_map; } @@ -1762,9 +1762,73 @@ std::unordered_map IterDomainGraphs:: } } - std::cerr << "Loop promotion map:\n"; + // TODO: cleanup + // Set loop_promotion_map_[loop_group] = promotion. + // Make sure the existing mapping, if exists, matches with the given + // promotion. + auto setLoopPromotion = + [this](const IdGroup& loop_group, IterDomain* promotion) -> void { + if (auto it = loop_promotion_map_.find(loop_group); + it != loop_promotion_map_.end()) { + auto existing_promotion = it->second; + NVF_ERROR( + idGraph(IdMappingMode::EXACT).toGroup(promotion) == + idGraph(IdMappingMode::EXACT).toGroup(existing_promotion), + "Different promotions found for ", + nvfuser::toString(loop_group), + ". ", + promotion->toString(), + ", ", + existing_promotion->toString()); + } else { + loop_promotion_map_.emplace(loop_group, promotion); + } + }; + + // Set up the loop promotion map of loops groups to promotion IDs + for (const IdGroup& loop_group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + bool promoted = false; + for (IterDomain* id : loop_group->vector()) { + const auto& iel_group = intersection_exact_loop_graph.toGroup(id); + if (auto iel_promotion_map_it = iel_promotion_map.find(iel_group); + iel_promotion_map_it != iel_promotion_map.end()) { + IterDomain* iel_promotion_id = iel_promotion_map_it->second; + setLoopPromotion(loop_group, iel_promotion_id); + promoted = true; + } + } + + if (promoted) { + continue; + } + + VERBOSE() << "No mapping in the IEL promotion map: " + << nvfuser::toString(loop_group) << std::endl; + + // No mapping in the IEL promotion map. If the loop group is still + // mapped in the loop group promotion map, that should be the + // correct promotion for this group + if (auto loop_graph_copy_promotion_map_it = + loop_graph_copy_promotion_map.find( + loop_graph_copy.toGroup(loop_group->vector().at(0))); + loop_graph_copy_promotion_map_it != + loop_graph_copy_promotion_map.end()) { + VERBOSE() << "Found in loop promotion: " << nvfuser::toString(loop_group) + << std::endl; + setLoopPromotion(loop_group, loop_graph_copy_promotion_map_it->second); + promoted = true; + } + + NVF_ERROR( + promoted, + "Loop promotion not found for ", + nvfuser::toString(loop_group)); + } + + VERBOSE() << "Loop promotion map:" << std::endl; for (const auto& [iel_group, id] : iel_promotion_map) { - std::cerr << nvfuser::toString(iel_group) << " -> " << id->name() + VERBOSE() << nvfuser::toString(iel_group) << " -> " << id->name() << std::endl; } diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_graphs.h index fca0bf1c842..8728f96147f 100644 --- a/csrc/id_model/id_graphs.h +++ b/csrc/id_model/id_graphs.h @@ -167,6 +167,10 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { // not have any registered uses or definitions. IterDomain* cloneIterDomain(IterDomain* id); + const std::unordered_map loopPromotionMap() const { + return loop_promotion_map_; + } + // TODO: Should this not be private? protected: // Sometimes fusion inputs or outputs are disconnected from expressions, in diff --git a/csrc/id_model/utils.h b/csrc/id_model/utils.h new file mode 100644 index 00000000000..563c5279774 --- /dev/null +++ b/csrc/id_model/utils.h @@ -0,0 +1,55 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include "utils.h" + +#include +#include +#include + +#define VERBOSE() verbose(__LINE__) +#define WARN() warn(__LINE__) + +namespace nvfuser { + +// Temporary logging utility +class DebugStream { + public: + DebugStream() + : enabled_(getNvFuserEnv("ID_MODEL_VERBOSE")), out_(std::cerr) {} + + template + DebugStream& operator<<(const T& v) { + if (enabled_) { + out_ << v; + } + return *this; + } + + DebugStream& operator<<(std::ostream& (*endl)(std::ostream&)) { + if (enabled_) { + out_ << endl; + } + return *this; + } + + private: + bool enabled_ = false; + std::ostream& out_; +}; + +inline DebugStream verbose(int line) { + return DebugStream() << "[DEBUG@" << line << "] "; +} + +inline DebugStream warn(int line) { + return DebugStream() << "[WARN@" << line << "] "; +} + +} // namespace nvfuser diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index 88ff2d092fa..ca81b0da900 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -11,6 +11,8 @@ #include #include +#include +#include #include #include #include @@ -877,8 +879,92 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { tensor->inlineAt(1); } - fusion.print(); - fusion.printKernel(); + IterDomainGraphs id_model(&fusion); + + // All of the IDs that are generated with merge operations from the + // root domains should be mapped to the single group. + const IdGroup& merge_loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(tv1->getRootDomain().at(0)); + for (auto tv : {tv1, tv2, tv4, tv5, tv6, tv8, tv9}) { + for (auto id : ir_utils::allIDsOf(tv)) { + if (dynamic_cast(id->definition()) == nullptr) { + const IdGroup& loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(id); + ASSERT_EQ(loop_group, merge_loop_group) + << "Unexpected loop group: " << nvfuser::toString(loop_group); + } + } + } + + const auto& promotion_map = id_model.loopPromotionMap(); + + // The merge loop group should be promoted to the output of the + // final merge in tv10 + auto ref_merge_out = tv10->axis(0) + ->definition() + ->input(0) + ->definition() + ->input(0) + ->as(); + + auto promotion_map_it = promotion_map.find(merge_loop_group); + ASSERT_TRUE(promotion_map_it != promotion_map.end()) + << "Loop promotion not found for merge loop group"; + ASSERT_EQ( + id_model.idGraph(IdMappingMode::EXACT).toGroup(promotion_map_it->second), + id_model.idGraph(IdMappingMode::EXACT).toGroup(ref_merge_out)) + << "Merge loop group should be promoted to " << ref_merge_out->toString(); + + // Get the corresponding reference ID in tv10 + auto getRefId = [&](TensorView* tv, IterDomain* id) -> IterDomain* { + if (dynamic_cast(id->definition()) != nullptr) { + if (id->uses().empty()) { + auto it = std::find( + tv->getLeafDomain().begin(), tv->getLeafDomain().end(), id); + NVF_ERROR(it != tv->getLeafDomain().end()); + int leaf_pos = + static_cast(std::distance(tv->getLeafDomain().begin(), it)); + return tv10->axis(leaf_pos); + } else { + return tv10->axis(0)->definition()->input(0)->as(); + } + } else { + return ref_merge_out; + } + }; + + // At this point, all of the IDs from the root until split are + // validated. Validating the remaining IDs + for (auto tv : {tv1, tv2, tv4, tv5, tv6, tv8, tv9}) { + for (auto id : ir_utils::allIDsOf(tv)) { + const auto& loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(id); + if (loop_group == merge_loop_group) { + // already validated + continue; + } + + auto promotion_map_it = promotion_map.find(loop_group); + ASSERT_TRUE(promotion_map_it != promotion_map.end()) + << "Loop promotion not found for " << id->toString() << " of " + << tv->toString() + << ". Loop group: " << nvfuser::toString(loop_group); + + auto promotion_exact_group = id_model.idGraph(IdMappingMode::EXACT) + .toGroup(promotion_map_it->second); + + auto ref_id = getRefId(tv, id); + auto ref_exact_group = + id_model.idGraph(IdMappingMode::EXACT).toGroup(ref_id); + + ASSERT_EQ(promotion_exact_group, ref_exact_group) + << "Invalid promotion: " << id->toString() << " of " << tv->toString() + << ". Promotion group: " << nvfuser::toString(promotion_exact_group); + } + } + + // The current ComputeAtMap fails with this fusion + // fusion.printKernel(); } // TODO: Finish and enable test From f9045f9b57d89c15ffdf49cbfb0af1d7bae4799a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 10 Oct 2023 20:55:29 -0700 Subject: [PATCH 057/178] Just cleanup. There should be no functional change (#1061) --- csrc/disjoint_set.h | 19 ++- csrc/id_model/id_graph.h | 6 +- csrc/id_model/id_graphs.cpp | 277 +++++++++++++++++++------------ csrc/id_model/id_graphs.h | 5 +- csrc/id_model/transform_replay.h | 6 +- csrc/id_model/visitor.cpp | 42 ++--- csrc/id_model/visitor.h | 43 ++--- 7 files changed, 238 insertions(+), 160 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 431e5478e8c..660fdbc60d2 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -99,8 +99,13 @@ class VectorOfUniqueEntries { // Returns if any node was added bool pushBack(const VectorOfUniqueEntries& other) { + return pushBack(other.vector()); + } + + // Returns if any node was added + bool pushBack(const std::vector& other) { bool any_added = false; - for (auto entry : other) { + for (const auto& entry : other) { auto added = pushBack(entry); any_added = any_added || added; } @@ -112,7 +117,7 @@ class VectorOfUniqueEntries { VectorOfUniqueEntries intersect( const VectorOfUniqueEntries& other) const { VectorOfUniqueEntries intersection; - for (auto entry : vector()) { + for (const auto& entry : vector()) { if (other.has(entry)) { intersection.pushBack(entry); } @@ -125,7 +130,7 @@ class VectorOfUniqueEntries { VectorOfUniqueEntries subtract( const VectorOfUniqueEntries& other) const { VectorOfUniqueEntries subtraction; - for (auto entry : vector()) { + for (const auto& entry : vector()) { if (!other.has(entry)) { subtraction.pushBack(entry); } @@ -139,7 +144,7 @@ class VectorOfUniqueEntries { const VectorOfUniqueEntries& other) const { const VectorOfUniqueEntries& this_ref = *this; VectorOfUniqueEntries union_(this_ref); - for (auto entry : other.vector()) { + for (const auto& entry : other.vector()) { union_.pushBack(entry); } return union_; @@ -253,7 +258,7 @@ class VectorOfUniqueEntries { std::string toString() const { std::stringstream ss; ss << "{ "; - for (auto entry : vector()) { + for (const auto& entry : vector()) { ss << abstractToString(entry); if (entry != vector().back()) { ss << "; "; @@ -476,6 +481,10 @@ class DisjointSets { return ss.str(); } + auto size() const { + return disjoint_sets_.size(); + } + private: // Disjoint sets std::unordered_map>, Hash> diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index 0e822d7451b..77333f1a3ed 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -21,7 +21,7 @@ using IdGroups = VectorOfUniqueEntries; using ExprGroup = std::shared_ptr>; using ExprGroups = VectorOfUniqueEntries; -class TORCH_CUDA_CU_API IdGraph { +class IdGraph { public: IdGraph() = default; @@ -121,7 +121,9 @@ class TORCH_CUDA_CU_API IdGraph { //! Same as iterDomainGroupDefinitions but for uses instead of //! definitions //! - //! TODO-NM: ExprGroups is a real container. Consider returning a reference + //! TODO-NM: ExprGroups is a real container. Consider returning a + //! reference + //! TODO-NM: Rename to getMaybeUses. See getUses std::pair getUses(const IdGroup& id_group) const; std::string toString() const; diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 3c95724bee9..b685ce249d3 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -326,6 +326,8 @@ Expr* IterDomainGraphs::addReplayAs( std::vector orig_input_ids( orig_inputs.begin(), orig_inputs.end()); + // Replace the provided inputs with IterType::Iteration domains as + // reduction domains cannot be merged with non-reduction domains. if (std::any_of( new_inputs.begin(), new_inputs.end(), @@ -359,8 +361,7 @@ Expr* IterDomainGraphs::addReplayAs( VectorOfUniqueEntries all_inputs{ orig_input_ids.begin(), orig_input_ids.end()}; - all_inputs.pushBack(VectorOfUniqueEntries{ - new_inputs.begin(), new_inputs.end()}); + all_inputs.pushBack(new_inputs); for (auto mode : initialized_modes) { for (auto inp : all_inputs) { @@ -409,10 +410,10 @@ Expr* IterDomainGraphs::addReplayAs( auto& graph = idGraph(mode); // Gather all use expressions from inputs VectorOfUniqueEntries representative_uses; - for (auto inp : new_inputs) { + for (IterDomain* inp : new_inputs) { auto uses_pair = graph.getUses(graph.toGroup(inp)); if (uses_pair.second) { - for (auto use_group : uses_pair.first) { + for (const ExprGroup& use_group : uses_pair.first) { representative_uses.pushBack(use_group->front()); } } @@ -717,7 +718,6 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { } // TODO: Should this just get rolled up in the forwarding map now? - // TODO: Why should IDs be mapped to their compliments? Is this right? for (auto entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); @@ -831,16 +831,19 @@ std::unordered_map resolvedRootBroadcasts( PairwiseRootDomainMap(producer, consumer).mapProducerToConsumer(); std::unordered_map resolved_bcast_map; - for (const auto& kv : p2c_map) { - auto p_id = kv.first; - // Ignore non-broadcast dims - if (!p_id->isBroadcast()) { + for (const auto& [p_id, c_id] : p2c_map) { + // Look for a broadcast producer and non-broadcast consumer + + // Ignore non-broadcast producer and broadcast consumer dims + if (!p_id->isBroadcast() || c_id->isBroadcast()) { continue; } - auto c_id = kv.second; - // If the consumer ID is a reduction (i.e., a trivial - // reduction), do not consider it's concretized. - if (c_id->isBroadcast() || c_id->isReduction()) { + + if (c_id->isReduction()) { + // This should only happen with expanded broadcast + // domains. Otherwise, squeeze should be used + NVF_ERROR( + p_id->hasExpandedExtent(), "Unexpected domain: ", c_id->toString()); continue; } @@ -875,10 +878,10 @@ StatefulLoweringInfo buildInfo( } info.ordered_p_ca_ids.pushBack(all_producer_ca_deps); - for (auto consumer : ir_utils::filterByType(expr->outputs())) { auto resolved_bcast_map = resolvedRootBroadcasts(producer, consumer); + for (auto entry : resolved_bcast_map) { info.p2c_root_broadcast_resolution_map[entry.first].pushBack( entry.second); @@ -897,8 +900,8 @@ StatefulLoweringInfo buildInfo( auto p2c_permissive_map = permissive_graph.buildMapBetween( all_producer_ids, all_consumer_ids); - for (auto entry : p2c_permissive_map) { - if (entry.second.size() == 0) { + for (const auto& entry : p2c_permissive_map) { + if (entry.second.empty()) { continue; } if (all_producer_ca_deps.has(entry.first)) { @@ -907,8 +910,8 @@ StatefulLoweringInfo buildInfo( info.p2c_permissive_maps[entry.first].pushBack(entry.second); } - for (auto entry : p2c_permissive_map) { - if (entry.second.size() == 0) { + for (const auto& entry : p2c_permissive_map) { + if (entry.second.empty()) { continue; } info.p2c_permissive_maps[entry.first].pushBack(entry.second); @@ -1021,7 +1024,7 @@ void IterDomainGraphs::build( VectorOfUniqueEntries IterDomainGraphs::computeTerminalLoopIds( const StatefulLoweringInfo info) { VectorOfUniqueEntries terminal_loop_ids; - for (auto group : + for (const IdGroup& group : idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { if (group->size() == 1) { terminal_loop_ids.pushBack(group->front()); @@ -1043,7 +1046,7 @@ VectorOfUniqueEntries IterDomainGraphs::computeTerminalLoopIds( // If there's an output group that is not in the same group, then it's id // consumer terminal. Also if there's no output groups it's id consumer // terminal. - bool all_outs_in_loop_group = uses_it->second.size() == 0 ? false : true; + bool all_outs_in_loop_group = uses_it->second.empty() ? false : true; for (auto use : uses_it->second) { for (auto out_id : ir_utils::filterByType(use->outputs())) { if (group != idGraph(IdMappingMode::LOOP).toGroup(out_id)) { @@ -1065,7 +1068,7 @@ IdGraph IterDomainGraphs::buildIntersection( const IdGraph& graph1, bool propagate_exprs) { auto intersection = initializeIdGraph(propagate_exprs); - for (auto group0 : graph0.disjointIdSets().disjointSets()) { + for (const auto& group0 : graph0.disjointIdSets().disjointSets()) { auto set_size = group0->size(); for (auto id0_i : c10::irange(set_size)) { auto id0 = group0->vector()[id0_i]; @@ -1089,11 +1092,11 @@ void IterDomainGraphs::initializeLoopMap(StatefulLoweringInfo& info) { // Make sure this is called in a deterministic order. Build all inlined // relationships in loop graph. - for (auto p_id : info.ordered_p_ca_ids) { + for (IterDomain* p_id : info.ordered_p_ca_ids) { auto entry_it = info.p2c_ca_permissive_maps.find(p_id); if (entry_it != info.p2c_ca_permissive_maps.end()) { - auto c_ids = entry_it->second; - for (auto c_id : c_ids) { + const VectorOfUniqueEntries& c_ids = entry_it->second; + for (IterDomain* c_id : c_ids) { idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); } } @@ -1103,7 +1106,7 @@ void IterDomainGraphs::initializeLoopMap(StatefulLoweringInfo& info) { std::unordered_map IterDomainGraphs:: buildInlinePromotions(StatefulLoweringInfo& info) { // Make an intersection of the exact and loop map. This will group together - // entries in each loop group that are exact with eachother. This provides a + // entries in each loop group that are exact with each other. This provides a // better graph to do promotion and replays. // It's tempting to use the intersection of the almost exact and loop, but we @@ -1126,7 +1129,7 @@ std::unordered_map IterDomainGraphs:: // smaller groups and this algorithm scales with the number of groups * // (number of entries in groups ^ 2) - auto intersection_exact_loop_graph = buildIntersection( + IdGraph intersection_exact_loop_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); // Promotion logic is going to be on the intersection of the exact and loop @@ -1141,26 +1144,46 @@ std::unordered_map IterDomainGraphs:: // able to modify a broadcast domain between root and rfactor which would be // required to resolve a non input broadcast domain. But for now leaving it as // traversal on all broadcast groups. - for (auto iel_group : + // + // TODO-NM: The ordering appears to be non-deterministic + + // We first visit all broadcast root domains. If a broadcast is + // resovled, see if it's promoted. Note that a domain be resolved to + // a domain that may not be loop mapped, yet it can still be + // promoted. In other words, there can be a domain that is exactly + // mapped with the resolving domain *and* is mapped with the + // broadcast domain by the loop map. The algorihm here is: + // + // 1. For a broadcast domain, find the domain that the broadcast is + // resolved to. + // 2. If the resolving domain is also loop-mapped with the + // broadcast, that is the promotion domain, but the resolving + // domain may not be loop mapped as mentioned above. Instead, + // find all loop-mapped domains with the broadcast domain and + // pick one that is exactly mapped with the resolving domain + // + // Note again this process is only done for root domains. Once we + // find promotion relationships for root domains, we propagate the + // mappings to derived domains + for (const IdGroup& iel_group : intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + NVF_ERROR(!iel_group->empty()); + if (!iel_group->front()->isBroadcast()) { continue; } // Collect all the exact groups of the resolutions of the broadcast id's IdGroups resolved_exact_groups; - for (auto bcast_id : *iel_group) { - auto p2c_root_broadcast_resolution_map_it = - info.p2c_root_broadcast_resolution_map.find(bcast_id); - - if (p2c_root_broadcast_resolution_map_it == + for (IterDomain* bcast_id : *iel_group) { + if (auto p2c_root_broadcast_resolution_map_it = + info.p2c_root_broadcast_resolution_map.find(bcast_id); + p2c_root_broadcast_resolution_map_it != info.p2c_root_broadcast_resolution_map.end()) { - continue; + resolved_exact_groups.pushBack( + idGraph(IdMappingMode::EXACT) + .toGroups(p2c_root_broadcast_resolution_map_it->second)); } - - resolved_exact_groups.pushBack( - idGraph(IdMappingMode::EXACT) - .toGroups(p2c_root_broadcast_resolution_map_it->second)); } // Collect all the exact groups in the loop set containing this iel_group @@ -1171,7 +1194,7 @@ std::unordered_map IterDomainGraphs:: // The intersection of the exact groups that the broadcast domains can be // broadcasted to, and those that exist within the same loop groop are is // the promotion needed for this iel_group. - auto loop_exact_resolved_intersection = + IdGroups loop_exact_resolved_intersection = resolved_exact_groups.intersect(loop_covered_exact_groups); if (loop_exact_resolved_intersection.empty()) { @@ -1186,21 +1209,21 @@ std::unordered_map IterDomainGraphs:: << "Invalid multiple broadcast resolution within shared loops detected, group:\n " << iel_group->toString() << "\nIs being broadcasted to:"; - for (auto entry : loop_exact_resolved_intersection) { + for (const IdGroup& entry : loop_exact_resolved_intersection) { err_msg << "\n " << entry->toString(); } NVF_ERROR(false, err_msg.str()); } // loop_exact_resolved_intersection.size() must be 1 at this point - auto exact_resolution_group = loop_exact_resolved_intersection.front(); + IdGroup exact_resolution_group = loop_exact_resolved_intersection.front(); VectorOfUniqueEntries resolved_ids = exact_resolution_group->intersect(*loop_group); auto promoted_iel_groups = intersection_exact_loop_graph.toGroups(resolved_ids); - if (promoted_iel_groups.size() == 0) { + if (promoted_iel_groups.empty()) { continue; } @@ -1211,7 +1234,7 @@ std::unordered_map IterDomainGraphs:: << "Invalid multiple broadcast resolution within shared loops detected, group:\n " << iel_group->toString() << "\nIs being broadcasted to:"; - for (auto entry : promoted_iel_groups) { + for (const IdGroup& entry : promoted_iel_groups) { err_msg << "\n " << entry->toString(); } NVF_ERROR(false, err_msg.str()); @@ -1220,25 +1243,29 @@ std::unordered_map IterDomainGraphs:: iel_promotion_map[iel_group] = promoted_iel_groups.front()->front(); } - for (auto iel_group : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - auto entry_it = iel_promotion_map.find(iel_group); - if (entry_it == iel_promotion_map.end()) { - continue; - } - } + // Propagate promotion mappings from root domains to derived domains + // by traversing IEL exprs. For each expr, if an input is promoted, + // the output needs to be promoted too. If there's already a domain + // that the output domain should be promoted to, create a mapping to it from + // the promoted output domain. If not, a new domain is created by + // replaying the expr with the promoted inputs. + // In order to make + // this traversal work, the traversal order must be toplogically + // sorted. IdGraphStmtSort iel_stmt_sort(intersection_exact_loop_graph); - for (auto iel_expr : iel_stmt_sort.exprs()) { - auto input_groups = intersection_exact_loop_graph.inputGroups(iel_expr); + for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) { + NVF_ERROR(!iel_expr->empty()); + std::vector input_groups = + intersection_exact_loop_graph.inputGroups(iel_expr); // Check if any inputs need promotion indicating this expr group needs to // be replayed with promoted inputs std::vector promoted_inputs; bool an_input_was_promoted = false; - for (auto inp : input_groups) { + for (const IdGroup& inp : input_groups) { auto inp_promo_it = iel_promotion_map.find(inp); if (inp_promo_it == iel_promotion_map.end()) { promoted_inputs.push_back(inp->front()); @@ -1253,8 +1280,6 @@ std::unordered_map IterDomainGraphs:: continue; } - Expr* replay = nullptr; - IdGroups promoted_input_groups; for (auto inp_id : promoted_inputs) { if (intersection_exact_loop_graph.hasGroup(inp_id)) { @@ -1271,14 +1296,32 @@ std::unordered_map IterDomainGraphs:: // we're not adding the replayed expression to the iel graph since we're // traversing the iel graph. // - // TODO: Can we reduce the number of new expressions generated here? + // TODO: Can we reduce the number of new expressions generated + // here? + // + // TODO-NM: This won't work for any single-input expr, e.g., + // split, as there's no other non-promoted input. Can't we just + // look at the use expr of the promoted IDGroup? + // + // TODO-NM: Why can't we just also use the promoted IDs and their + // uses? E.g., test Indexing5, t3 has a merge of iS11 and bS7, + // both of them are promoted to iS17 and iS45, respectively. Since + // there's no promoted input, there would be no reuse, but it + // seems perfectly fine to reuse the merge of iS17 and iS45. + ExprGroups non_promoted_input_uses; - for (auto iel_group : promoted_input_groups.intersect(input_groups)) { + for (const IdGroup& iel_group : + promoted_input_groups.intersect(input_groups)) { non_promoted_input_uses.pushBack( intersection_exact_loop_graph.getUniqueUses(iel_group)); } - for (auto iel_use_group : non_promoted_input_uses) { + Expr* replay = nullptr; + + // Look for exprs that have inputs that are mapped in the IEL + // graph with the (promoted) inputs of iel_expr. If found, no need + // to create a new expr to produce promoted outputs + for (const ExprGroup& iel_use_group : non_promoted_input_uses) { if (IdGraph::transformAtributesMatch( iel_expr->front(), iel_use_group->front())) { auto use_inps = @@ -1302,7 +1345,8 @@ std::unordered_map IterDomainGraphs:: replay = addReplayAs(promoted_inputs, iel_expr->front()); } - auto out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); + std::vector out_groups = + intersection_exact_loop_graph.outputGroups(iel_expr); // Mark outputs as having a promoted iter domain auto replay_out_ids = @@ -1336,21 +1380,21 @@ std::unordered_map IterDomainGraphs:: namespace { std::unordered_map updateMap( - const std::unordered_map stale_map, + const std::unordered_map& stale_map, IdGraph& new_graph) { std::unordered_map new_map; - for (auto stale_entry : stale_map) { - auto stale_id_group = stale_entry.first; - auto new_groups = new_graph.toGroups(*stale_id_group); + + for (const auto& [stale_key, mapped_id] : stale_map) { + const IdGroups& new_groups = new_graph.toGroups(*stale_key); NVF_ERROR( new_groups.size() == 1, "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", "old:", - nvfuser::toString(stale_id_group), + nvfuser::toString(stale_key), "new: ", nvfuser::toString(new_groups)); - new_map[new_groups.front()] = stale_entry.second; + new_map[new_groups.front()] = mapped_id; } return new_map; } @@ -1359,15 +1403,15 @@ std::unordered_map updateMap( // traversing on definitions. Ignoring broadcast IdGroups and resetting inputs // at RFactor IdGroups. std::unordered_map computeCoveredGroups( - const IdGraph& graph, - std::unordered_set view_rfactor_ids) { + const IdGraph& exact_graph, + const std::unordered_set& view_rfactor_ids) { // Map from an exact iter domain group, to all the exact iter domain groups it // covers std::unordered_map covered_ids; - for (auto id_group : graph.disjointIdSets().disjointSets()) { + for (const IdGroup& id_group : exact_graph.disjointIdSets().disjointSets()) { // Initialize inputs - if (graph.getUniqueDefinitions(id_group).empty()) { + if (exact_graph.getUniqueDefinitions(id_group).empty()) { covered_ids[id_group] = {id_group}; } @@ -1378,7 +1422,8 @@ std::unordered_map computeCoveredGroups( covered_ids[id_group] = {id_group}; } - // Initialize broadcast groups to empty + // Initialize broadcast groups to empty since broadcast domains + // don't matter for indexing if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { return id->isBroadcast(); })) { @@ -1386,17 +1431,17 @@ std::unordered_map computeCoveredGroups( } } - IdGraphStmtSort exact_stmt_sort(graph); + IdGraphStmtSort exact_stmt_sort(exact_graph); - for (auto exact_expr : exact_stmt_sort.exprs()) { - auto input_groups = graph.inputGroups(exact_expr); + for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) { + std::vector input_groups = exact_graph.inputGroups(exact_expr); IdGroups covered; - for (auto inp_group : input_groups) { + for (const IdGroup& inp_group : input_groups) { covered.pushBack(covered_ids.at(inp_group)); } - for (auto output_group : graph.outputGroups(exact_expr)) { + for (const IdGroup& output_group : exact_graph.outputGroups(exact_expr)) { // Don't overwrite initialized cases due to rfactor markings. if (covered_ids.find(output_group) == covered_ids.end()) { covered_ids[output_group] = covered; @@ -1412,11 +1457,16 @@ std::unordered_map IterDomainGraphs:: buildLoopPromotionMap( const std::vector& exprs, StatefulLoweringInfo& info, - std::unordered_map stale_promotion_map) { + const std::unordered_map& stale_promotion_map) { + // Non-ca domains may also need to be promoted if parent domains are + // promoted. + // Opportunistically add non-inlined loop relationships where they don't // interfere with the loop groups. This should be on all p_ids that are not // p_ca_ids. for (auto p_id : info.ordered_c_ids.subtract(info.ordered_p_ca_ids)) { + // p2c_permissive_maps include those that are not mapped with the + // loop map auto entry_it = info.p2c_permissive_maps.find(p_id); if (entry_it == info.p2c_permissive_maps.end()) { continue; @@ -1432,12 +1482,15 @@ std::unordered_map IterDomainGraphs:: // Grab all iter domains already in the loop groups for both iter // domains. - auto loop_groups = + IdGroups loop_groups = idGraph(IdMappingMode::LOOP) .toGroups(VectorOfUniqueEntries{p_id, c_id}); VectorOfUniqueEntries all_ids_in_groups; + // p_id and c_id are not loop mapped, so there must be two ID groups + NVF_ERROR(loop_groups.size() == 2); + ParallelType common_ptype = loop_groups.front()->front()->getParallelType(); if (std::any_of( @@ -1450,21 +1503,22 @@ std::unordered_map IterDomainGraphs:: continue; } - for (auto loop_group : loop_groups) { + for (const IdGroup& loop_group : loop_groups) { all_ids_in_groups.pushBack(*loop_group); } // Ignore new loop mappings from replays, we can still opportunistically // merge leaves if they already have a promoted id from replay associated - // with them. + // with them. Since they are not included in ordered_c_ids, + // taking intersection filters them out all_ids_in_groups = all_ids_in_groups.intersect(info.ordered_c_ids); // Grab the almost exact map of all iter domains in those loop groups - auto ae_groups = + const IdGroups& ae_groups = idGraph(IdMappingMode::ALMOSTEXACT).toGroups(all_ids_in_groups); // If there's no broadcast promotion within the loop group then all the - // iter domains will be almost exact mapped with eachother. + // iter domains will be almost exact mapped with each other. if (ae_groups.size() == 1) { idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); } @@ -1504,7 +1558,8 @@ std::unordered_map IterDomainGraphs:: // TODO: I'm uncertain if we can simply use the iel_promotion_map. Once this // system is in use we should test not recomputing the "concrete ids". - for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { + for (const IdGroup& loop_group : + loop_graph_copy.disjointIdSets().disjointSets()) { if (loop_group->size() == 1) { loop_graph_copy_promotion_map[loop_group] = loop_group->front(); continue; @@ -1520,7 +1575,7 @@ std::unordered_map IterDomainGraphs:: } // Grab the iel entry - auto iel_group = intersection_exact_loop_graph.toGroup(loop_id); + const IdGroup& iel_group = intersection_exact_loop_graph.toGroup(loop_id); auto iel_promo_it = iel_promotion_map.find(iel_group); if (iel_promo_it == iel_promotion_map.end()) { @@ -1530,7 +1585,7 @@ std::unordered_map IterDomainGraphs:: } else { // If this terminal ID doesn't have a promotion associated with it, save // the terminal ID. - exact_promoted_terminal_ids.push_back(std::make_pair( + exact_promoted_terminal_ids.emplace_back(std::make_pair( idGraph(IdMappingMode::EXACT).toGroup(iel_promo_it->second), iel_promo_it->second)); } @@ -1541,7 +1596,7 @@ std::unordered_map IterDomainGraphs:: // All exact groups covered by all iter domains in this loop group IdGroups loop_group_covered_ids; - for (auto exact_group : exact_groups) { + for (const IdGroup& exact_group : exact_groups) { auto covered_it = exact_covered_ids.find(exact_group); NVF_ERROR(covered_it != exact_covered_ids.end()); loop_group_covered_ids.pushBack(covered_it->second); @@ -1552,12 +1607,12 @@ std::unordered_map IterDomainGraphs:: // Check if any of the candidate Iter Domains we collected cover all the // exact groups of loop_group_covered_ids. If so, that's the correct // promoted iter domain of this group. - for (auto entry : exact_promoted_terminal_ids) { - auto terminal_id_group = entry.first; - auto terminal_id = entry.second; + for (const auto& entry : exact_promoted_terminal_ids) { + const IdGroup& terminal_id_group = entry.first; + IterDomain* terminal_id = entry.second; auto covered_it = exact_covered_ids.find(terminal_id_group); NVF_ERROR(covered_it != exact_covered_ids.end()); - if (loop_group_covered_ids.subtract(covered_it->second).size() == 0) { + if (loop_group_covered_ids.subtract(covered_it->second).empty()) { loop_promotion_id = terminal_id; break; } @@ -1569,15 +1624,16 @@ std::unordered_map IterDomainGraphs:: << "\n ERROR Loop promotion map build. Could not find promotion for loop group:\n "; err_msg << nvfuser::toString(loop_group, 0, true); err_msg << "\nnone of the terminal iter domains of this group:\n "; - for (auto entry : exact_promoted_terminal_ids) { - auto terminal_id_group = entry.first; - auto covered_id_groups = exact_covered_ids.at(terminal_id_group); + for (const auto& entry : exact_promoted_terminal_ids) { + const IdGroup& terminal_id_group = entry.first; + const IdGroups& covered_id_groups = + exact_covered_ids.at(terminal_id_group); err_msg << " " << nvfuser::toString(terminal_id_group, 0, true) << " -(covers)-> " << nvfuser::toString(covered_id_groups) << std::endl; } err_msg << "iter domains in this group cover all id groups:\n"; - for (auto covered_group : loop_group_covered_ids) { + for (const IdGroup& covered_group : loop_group_covered_ids) { err_msg << " " << nvfuser::toString(covered_group, 0, true); } // NVF_ERROR(false, err_msg.str()); @@ -1586,11 +1642,11 @@ std::unordered_map IterDomainGraphs:: } } - for (auto loop_group : loop_graph_copy.disjointIdSets().disjointSets()) { - if (loop_graph_copy_promotion_map.find(loop_group) != - loop_graph_copy_promotion_map.end()) { - } - } + // At this point, most of loop groups should have correct promoted + // IDs. However, non-inlined loop groups may miss promotion that + // should be propagated from parent ID groups, e.g., iS50 of T2 in + // Indexing19. Its parent ID loop group is promoted, but the loop + // group of iS50 is not found yet. // Reset the promotion map for the second pass. // TODO: Unclear if we could simply update the iel_promotion_map from @@ -1600,10 +1656,15 @@ std::unordered_map IterDomainGraphs:: // Need to run a replay for the loop groups that are dependent on inlined loop // groups, but themselves are not inlined loop groups. - for (auto iel_expr : IdGraphStmtSort(intersection_exact_loop_graph).exprs()) { - auto iel_inp_groups = intersection_exact_loop_graph.inputGroups(iel_expr); + for (const ExprGroup& iel_expr : + IdGraphStmtSort(intersection_exact_loop_graph).exprs()) { + NVF_ERROR(!iel_expr->empty()); + + std::vector iel_inp_groups = + intersection_exact_loop_graph.inputGroups(iel_expr); - auto iel_out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); + std::vector iel_out_groups = + intersection_exact_loop_graph.outputGroups(iel_expr); // When replaying the transformations we can't blindly apply loop promotion // to all iter domains within a loop group as it would replay the @@ -1629,12 +1690,12 @@ std::unordered_map IterDomainGraphs:: // the same loop group. IdGroups inp_loop_groups; - for (auto iel_inp_group : iel_inp_groups) { + for (const IdGroup& iel_inp_group : iel_inp_groups) { inp_loop_groups.pushBack(loop_graph_copy.toGroup(iel_inp_group->front())); } IdGroups out_loop_groups; - for (auto iel_out_group : iel_out_groups) { + for (const IdGroup& iel_out_group : iel_out_groups) { out_loop_groups.pushBack(loop_graph_copy.toGroup(iel_out_group->front())); } @@ -1647,14 +1708,15 @@ std::unordered_map IterDomainGraphs:: bool an_input_was_promoted = false; // Promote inputs for replay - for (auto iel_inp_group : iel_inp_groups) { + for (const IdGroup& iel_inp_group : iel_inp_groups) { // Promote loops based on the loop promotion map. If the loop promotion // map should be used and has an entry we should use that promotion. This // happen when an iel expression is across a loop group boundary. // Signifying and capturing instances when we traverse across an inlined // loop group to a non-inlined loop group boundary (think of the iel graph // projected onto the loop graph). - auto loop_copy_group = loop_graph_copy.toGroup(iel_inp_group->front()); + const IdGroup& loop_copy_group = + loop_graph_copy.toGroup(iel_inp_group->front()); auto inp_loop_promo_it = loop_graph_copy_promotion_map.find(loop_copy_group); if (loop_promote_inputs && @@ -1696,14 +1758,15 @@ std::unordered_map IterDomainGraphs:: ExprGroups promoted_input_uses; for (auto inp_id : promoted_inputs) { - auto inp_exact_group = idGraph(IdMappingMode::EXACT).toGroup(inp_id); + const auto& inp_exact_group = + idGraph(IdMappingMode::EXACT).toGroup(inp_id); promoted_input_groups.push_back(inp_exact_group); promoted_input_uses.pushBack( idGraph(IdMappingMode::EXACT).getUniqueUses(inp_exact_group)); } // Check every use to see if it matches - for (auto exact_use_group : promoted_input_uses) { + for (const ExprGroup& exact_use_group : promoted_input_uses) { // Check if all the attributes (including type) of the transform match if (!IdGraph::transformAtributesMatch( iel_expr->front(), exact_use_group->front())) { @@ -1755,7 +1818,7 @@ std::unordered_map IterDomainGraphs:: } } - for (auto group : + for (const IdGroup& group : intersection_exact_loop_graph.disjointIdSets().disjointSets()) { if (iel_promotion_map.find(group) == iel_promotion_map.end()) { continue; diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_graphs.h index 8728f96147f..440963288a6 100644 --- a/csrc/id_model/id_graphs.h +++ b/csrc/id_model/id_graphs.h @@ -82,7 +82,7 @@ struct StatefulLoweringInfo; // PERMISSIVE) // Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped // -class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { +class IterDomainGraphs : public PolymorphicBase { public: IterDomainGraphs( const std::vector& exprs, @@ -246,7 +246,7 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { std::unordered_map buildLoopPromotionMap( const std::vector& exprs, StatefulLoweringInfo& info, - std::unordered_map stale_promotion_map); + const std::unordered_map& stale_promotion_map); // Builds idGraph(IdMappingMode::INDEX) and returns the iter domain promotion // map to go from leaf domains of each (consumer only?) tensor to their @@ -292,6 +292,7 @@ class TORCH_CUDA_CU_API IterDomainGraphs : public PolymorphicBase { c10::optional> self_mapping_info_ = c10::nullopt; + // Promotion domain for each loop group std::unordered_map loop_promotion_map_; std::unordered_set view_rfactor_ids_; diff --git a/csrc/id_model/transform_replay.h b/csrc/id_model/transform_replay.h index 290f82f33f5..ff549db65e5 100644 --- a/csrc/id_model/transform_replay.h +++ b/csrc/id_model/transform_replay.h @@ -26,8 +26,6 @@ class ReplayTransform : OptInConstDispatch { const Expr* expression_to_match); private: - ReplayTransform() = delete; - ReplayTransform( const std::vector& ordered_inputs, const Expr* expression_to_match); @@ -61,15 +59,13 @@ class ReplacementTransformCloner : OptInConstDispatch { // validation is done on provided inputs/outputs. // // In other words a split i0{I0}->i1{I0//2}, i2{2} with a map: - // i2{2} -> i3{48} wouldn't throw an error, but would not bevalid. + // i2{2} -> i3{48} wouldn't throw an error, but would not be valid. static Expr* clone( const std::unordered_map& provided_expr_val_2_replacement_val, const Expr* expression_to_match); private: - ReplacementTransformCloner() = delete; - ReplacementTransformCloner( const std::unordered_map& expr_to_match_2_replacement, diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index 271227f8f42..0f33135ad5f 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -13,6 +13,8 @@ void IdGraphVisitor::traverse() { IdGroups all_ids; ExprGroups all_exprs; { + // Initialize IDs to traverse. If sub_selection is provided, only + // traverse IDs that are included in the set are traversed. if (sub_selection_.empty()) { all_ids = IdGroups( graph().disjointIdSets().disjointSets().begin(), @@ -25,13 +27,17 @@ void IdGraphVisitor::traverse() { } } + // Initialize exprs to traverse. If sub_selection is provided, + // only traverse exprs that are strictly contained within the provided + // sub_selection. Exprs are excluded if any of inputs or outputs + // is not in sub_selection. if (sub_selection_.empty()) { all_exprs = ExprGroups( graph().disjointExprSets().disjointSets().begin(), graph().disjointExprSets().disjointSets().end()); } else { - for (auto id_group : all_ids) { - for (auto def : graph().getUniqueDefinitions(id_group)) { + for (const IdGroup& id_group : all_ids) { + for (const ExprGroup& def : graph().getUniqueDefinitions(id_group)) { if (all_exprs.has(def)) { continue; } @@ -53,17 +59,14 @@ void IdGraphVisitor::traverse() { { IdGroups not_inputs; IdGroups not_outputs; - for (auto expr_group : all_exprs) { - auto inp_groups = IdGroups(graph().inputGroups(expr_group)); - auto out_groups = IdGroups(graph().outputGroups(expr_group)); - - if (inp_groups.intersect(out_groups).size() > 0) { + for (const ExprGroup& expr_group : all_exprs) { + if (graph().isTrivialExprGroup(expr_group)) { // Expression is just a loop to its current group, ignore continue; } - not_inputs.pushBack(out_groups); - not_outputs.pushBack(inp_groups); + not_inputs.pushBack(graph().outputGroups(expr_group)); + not_outputs.pushBack(graph().inputGroups(expr_group)); } terminating_inputs = @@ -79,7 +82,7 @@ void IdGraphVisitor::traverse() { ExprGroups to_visit_exprs; ExprGroups visited_exprs; - auto is_expr_ready = [&](ExprGroup expr_group) { + auto is_expr_ready = [&](const ExprGroup& expr_group) { auto inp_groups = graph().inputGroups(expr_group); return std::all_of( inp_groups.begin(), inp_groups.end(), [&](IdGroup id_group) { @@ -87,7 +90,7 @@ void IdGraphVisitor::traverse() { }); }; - auto is_id_ready = [&](IdGroup id_group) { + auto is_id_ready = [&](const IdGroup& id_group) { auto unique_defs = graph().getUniqueDefinitions(id_group); return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { @@ -96,7 +99,7 @@ void IdGraphVisitor::traverse() { }); }; - while (to_visit_ids.size() > 0 || to_visit_exprs.size() > 0) { + while (!to_visit_ids.empty() || !to_visit_exprs.empty()) { // Process expressions first as all definitions of iter domains have to be // processed before we can process that iter domain. @@ -105,8 +108,9 @@ void IdGraphVisitor::traverse() { bool something_was_processed = false; ExprGroups still_to_visit_exprs; - while (to_visit_exprs.size() > 0) { - auto current_expr_group = to_visit_exprs.popFront(); + while (!to_visit_exprs.empty()) { + ExprGroup current_expr_group = to_visit_exprs.popFront(); + NVF_ERROR(!current_expr_group->empty()); if (visited_exprs.has(current_expr_group)) { continue; } @@ -117,10 +121,7 @@ void IdGraphVisitor::traverse() { something_was_processed = true; visited_exprs.pushBack(current_expr_group); - auto out_groups = graph().outputGroups(current_expr_group); - for (auto out_group : out_groups) { - to_visit_ids.pushBack(out_group); - } + to_visit_ids.pushBack(graph().outputGroups(current_expr_group)); } else { still_to_visit_exprs.pushBack(current_expr_group); } @@ -129,8 +130,9 @@ void IdGraphVisitor::traverse() { std::swap(to_visit_exprs, still_to_visit_exprs); IdGroups still_to_visit_ids; - while (to_visit_ids.size() > 0) { + while (!to_visit_ids.empty()) { auto current_id_group = to_visit_ids.popFront(); + NVF_ERROR(!current_id_group->empty()); if (visited_ids.has(current_id_group)) { continue; } @@ -155,7 +157,7 @@ void IdGraphVisitor::traverse() { NVF_ERROR( something_was_processed || - (to_visit_ids.size() == 0 && to_visit_exprs.size() == 0), + (to_visit_ids.empty() && to_visit_exprs.empty()), "Infinite loop entered."); } } diff --git a/csrc/id_model/visitor.h b/csrc/id_model/visitor.h index 4029fdced0e..0c5122979f9 100644 --- a/csrc/id_model/visitor.h +++ b/csrc/id_model/visitor.h @@ -22,7 +22,16 @@ namespace nvfuser { // Warning: This is not a great iterator if there's a desire to minimize paths // traveled to simply visit all IdGroups in order. See ExprsBetween to see how // we might minimize paths. -class TORCH_CUDA_CU_API IdGraphVisitor { +class IdGraphVisitor { + public: + IdGraphVisitor() = delete; + + IdGraphVisitor& operator=(const IdGraphVisitor& other) = delete; + + IdGraphVisitor& operator=(IdGraphVisitor&& other) = delete; + + virtual ~IdGraphVisitor() = default; + protected: // If sub_selection is assumed to be a set of iter domains by which form a // sub-regrion of the IdGraph provided. Only that sub-region will be visited. @@ -31,6 +40,10 @@ class TORCH_CUDA_CU_API IdGraphVisitor { const VectorOfUniqueEntries sub_selection = {}) : id_graph_(id_graph), sub_selection_(sub_selection) {} + IdGraphVisitor(const IdGraphVisitor& other) = default; + + IdGraphVisitor(IdGraphVisitor&& other) = default; + virtual void handle(IdGroup id_group) = 0; virtual void handle(ExprGroup expr_group) = 0; @@ -40,16 +53,6 @@ class TORCH_CUDA_CU_API IdGraphVisitor { return id_graph_; }; - IdGraphVisitor() = delete; - - IdGraphVisitor(const IdGraphVisitor& other) = default; - IdGraphVisitor& operator=(const IdGraphVisitor& other) = delete; - - IdGraphVisitor(IdGraphVisitor&& other) = default; - IdGraphVisitor& operator=(IdGraphVisitor&& other) = delete; - - virtual ~IdGraphVisitor() = default; - private: const IdGraph& id_graph_; const VectorOfUniqueEntries sub_selection_; @@ -65,12 +68,14 @@ class IdGraphStmtSort : public IdGraphVisitor { IdGraphVisitor::traverse(); } - ExprGroups exprs() { - return sorted_exprs; + // Return non-reference so that code like below can work + // for (auto expr_group: IdGraphStmtSort(graph).exprs()) + ExprGroups exprs() const { + return sorted_exprs_; } - IdGroups ids() { - return sorted_ids; + IdGroups ids() const { + return sorted_ids_; } ~IdGraphStmtSort() override = default; @@ -78,15 +83,15 @@ class IdGraphStmtSort : public IdGraphVisitor { protected: using IdGraphVisitor::handle; void handle(IdGroup id_group) override { - sorted_ids.pushBack(id_group); + sorted_ids_.pushBack(id_group); } void handle(ExprGroup expr_group) override { - sorted_exprs.pushBack(expr_group); + sorted_exprs_.pushBack(expr_group); } - ExprGroups sorted_exprs; - IdGroups sorted_ids; + ExprGroups sorted_exprs_; + IdGroups sorted_ids_; }; } // namespace nvfuser From c51379edaa4cb078b2a8b619e8842feb3b2d4c3d Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 10 Oct 2023 22:43:16 -0700 Subject: [PATCH 058/178] IdGraph cleanup (#1062) --- csrc/disjoint_set.h | 6 ++---- csrc/id_model/id_graph.cpp | 10 +++++----- csrc/id_model/id_graphs.cpp | 24 +++++++++++------------- csrc/id_model/to_string.cpp | 32 +++++++++++++++++--------------- csrc/id_model/utils.h | 2 +- 5 files changed, 36 insertions(+), 38 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 660fdbc60d2..ff714b22844 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -50,10 +50,8 @@ class VectorOfUniqueEntries { } } - VectorOfUniqueEntries(const VectorOfUniqueEntries& other) { - vector_ = other.vector(); - set_ = other.set(); - } + VectorOfUniqueEntries(const VectorOfUniqueEntries& other) + : vector_(other.vector()), set_(other.set()) {} VectorOfUniqueEntries& operator=(const VectorOfUniqueEntries& other) { if (this != &other) { diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index 41c01e676e9..a9a2bf0a37b 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -408,7 +408,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) something_processed = true; sorted_exprs.pushBack(currently_visiting); auto outputs = outputGroups(currently_visiting); - for (auto out_id : outputs) { + for (const IdGroup& out_id : outputs) { visited.pushBack(out_id); auto use_pair = getUses(out_id); if (!use_pair.second) { @@ -817,7 +817,7 @@ void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { } } - for (auto producer_group : producers) { + for (const IdGroup& producer_group : producers) { unique_uses_.at(producer_group).erase(expr0_orig_group); unique_uses_.at(producer_group).erase(expr1_orig_group); unique_uses_.at(producer_group).pushBack(expr_new_group); @@ -831,7 +831,7 @@ void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { } } - for (auto consumer_group : consumers) { + for (const IdGroup& consumer_group : consumers) { unique_definitions_.at(consumer_group).erase(expr0_orig_group); unique_definitions_.at(consumer_group).erase(expr1_orig_group); unique_definitions_.at(consumer_group).pushBack(expr_new_group); @@ -929,7 +929,7 @@ void IdGraph::removeTrivialExprs() { // from definitions and uses. They shouldn't be important in traversal, and // will break the terminal input/terminal output logic of traversal. Similar // to what's drafted in buildIndexGraph - for (auto trivial_expr_group : trivial_expr_groups) { + for (const ExprGroup& trivial_expr_group : trivial_expr_groups) { // Complexity of erase not good as both disjoint set and vector of unique // entries require a vector find to erase an entry. eraseExprGroup(trivial_expr_group); @@ -940,7 +940,7 @@ void IdGraph::removeTrivialExprs() { // erasing multiple expr_groups. void IdGraph::eraseExprGroup(const ExprGroup& expr_group) { // Erase entries that exist in unique_definitions_ and unique_uses_ - for (auto id_group : disjointIdSets().disjointSets()) { + for (const IdGroup& id_group : disjointIdSets().disjointSets()) { // Make sure the entries exists NVF_ERROR( unique_definitions_.find(id_group) != unique_definitions_.end(), diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index b685ce249d3..b426abbbf2e 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -562,7 +562,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( for (auto in : ir_utils::filterByType(replay->inputs())) { auto uses_pair = graph.getUses(graph.toGroup(in)); if (uses_pair.second) { - for (auto use_group : uses_pair.first) { + for (const ExprGroup& use_group : uses_pair.first) { if (use_group == replay_group) { continue; } @@ -580,7 +580,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( for (auto out : ir_utils::filterByType(replay->outputs())) { auto defs_pair = graph.getDefinitions(graph.toGroup(out)); if (defs_pair.second) { - for (auto def_group : defs_pair.first) { + for (const ExprGroup& def_group : defs_pair.first) { if (def_group == replay_group) { continue; } @@ -632,9 +632,7 @@ IterDomain* IterDomainGraphs::cloneIterDomain(IterDomain* id) { IdGraph IterDomainGraphs::initializeIdGraph(bool propagate_through_exprs) { IdGraph id_graph(propagate_through_exprs); - for (auto definition_entry : id_definitions_) { - auto id = definition_entry.first; - auto defs = definition_entry.second; + for (const auto& [id, defs] : id_definitions_) { auto uses_it = id_uses_.find(id); NVF_ERROR( uses_it != id_uses_.end(), @@ -718,7 +716,7 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { } // TODO: Should this just get rolled up in the forwarding map now? - for (auto entry : permissive_forwarding.producer_compliment_map) { + for (const auto& entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); } @@ -730,7 +728,7 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { // TODO: Should this just get rolled up in the forwarding map now? // TODO: Why should IDs be mapped to their compliments? Is this right? - for (auto entry : permissive_forwarding.consumer_compliment_map) { + for (const auto& entry : permissive_forwarding.consumer_compliment_map) { for (auto entry_2 : entry.second) { idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); } @@ -1579,15 +1577,15 @@ std::unordered_map IterDomainGraphs:: auto iel_promo_it = iel_promotion_map.find(iel_group); if (iel_promo_it == iel_promotion_map.end()) { - // If this terminal ID has a promotion, grab the promoted ID. - exact_promoted_terminal_ids.push_back(std::make_pair( - idGraph(IdMappingMode::EXACT).toGroup(loop_id), loop_id)); - } else { // If this terminal ID doesn't have a promotion associated with it, save // the terminal ID. - exact_promoted_terminal_ids.emplace_back(std::make_pair( + exact_promoted_terminal_ids.emplace_back( + idGraph(IdMappingMode::EXACT).toGroup(loop_id), loop_id); + } else { + // If this terminal ID has a promotion, grab the promoted ID. + exact_promoted_terminal_ids.emplace_back( idGraph(IdMappingMode::EXACT).toGroup(iel_promo_it->second), - iel_promo_it->second)); + iel_promo_it->second); } } diff --git a/csrc/id_model/to_string.cpp b/csrc/id_model/to_string.cpp index 7bf3b619ea8..885e4171e57 100644 --- a/csrc/id_model/to_string.cpp +++ b/csrc/id_model/to_string.cpp @@ -41,6 +41,7 @@ std::string toString( const std::vector& id_group, int indent_size) { std::vector names; + names.reserve(id_group.size()); for (auto id : id_group) { names.push_back(id->name()); } @@ -70,14 +71,14 @@ std::string toString( unsigned int pos = 0; - for (auto id_group : id_groups) { + for (const IdGroup& id_group : id_groups) { unsigned int min_id_name = std::numeric_limits::max(); for (auto id : *id_group) { if (id->name() < min_id_name) { min_id_name = id->name(); } } - group_name_info.push_back(std::make_pair(min_id_name, pos++)); + group_name_info.emplace_back(min_id_name, pos++); } ss << indent(indent_size) << "(idgs){\n"; @@ -105,14 +106,14 @@ std::string toString( unsigned int pos = 0; - for (auto id_group : id_groups) { + for (const IdGroup& id_group : id_groups) { unsigned int min_id_name = std::numeric_limits::max(); for (auto id : *id_group) { if (id->name() < min_id_name) { min_id_name = id->name(); } } - group_name_info.push_back(std::make_pair(min_id_name, pos++)); + group_name_info.emplace_back(min_id_name, pos++); } ss << indent(indent_size) << "(idgs){\n"; @@ -135,14 +136,14 @@ std::string toInlineString(const std::vector& id_groups) { unsigned int pos = 0; - for (auto id_group : id_groups) { + for (const IdGroup& id_group : id_groups) { unsigned int min_id_name = std::numeric_limits::max(); for (auto id : *id_group) { if (id->name() < min_id_name) { min_id_name = id->name(); } } - group_name_info.push_back(std::make_pair(min_id_name, pos++)); + group_name_info.emplace_back(min_id_name, pos++); } // Sort based on minimum id in the group @@ -168,6 +169,7 @@ std::string toInlineString(const std::vector& id_groups) { std::string toString(const std::vector& expr_group, int indent_size) { std::vector names; + names.reserve(expr_group.size()); for (auto expr : expr_group) { names.push_back(expr->name()); } @@ -201,14 +203,14 @@ std::string toString( unsigned int pos = 0; - for (auto expr_group : expr_groups) { + for (const ExprGroup& expr_group : expr_groups) { unsigned int min_expr_name = std::numeric_limits::max(); for (auto expr : *expr_group) { if (expr->name() < min_expr_name) { min_expr_name = expr->name(); } } - group_name_info.push_back(std::make_pair(min_expr_name, pos++)); + group_name_info.emplace_back(min_expr_name, pos++); } ss << indent(indent_size) << "(exprgs){\n"; @@ -218,7 +220,7 @@ std::string toString( for (auto i : c10::irange(group_name_info.size())) { auto pos = group_name_info[i].second; - auto expr_group = expr_groups[pos]; + const ExprGroup& expr_group = expr_groups[pos]; auto inputs = IdGroups(id_graph.inputGroups(expr_group)); auto outputs = IdGroups(id_graph.outputGroups(expr_group)); @@ -244,14 +246,14 @@ std::string toString( unsigned int pos = 0; - for (auto expr_group : expr_groups) { + for (const ExprGroup& expr_group : expr_groups) { unsigned int min_id_name = std::numeric_limits::max(); for (auto id : *expr_group) { if (id->name() < min_id_name) { min_id_name = id->name(); } } - group_name_info.push_back(std::make_pair(min_id_name, pos++)); + group_name_info.emplace_back(min_id_name, pos++); } ss << indent(indent_size) << "(exprgs){\n"; @@ -299,10 +301,10 @@ std::string definitionsString( int indent_size, bool with_ptr) { ExprGroups defs; - for (auto id_group : id_graph.disjointIdSets().disjointSets()) { + for (const IdGroup& id_group : id_graph.disjointIdSets().disjointSets()) { auto definition_pair = id_graph.getDefinitions(id_group); if (definition_pair.second) { - for (auto expr_group : definition_pair.first) { + for (const ExprGroup& expr_group : definition_pair.first) { defs.pushBack(expr_group); } } @@ -315,10 +317,10 @@ std::string usesString( int indent_size, bool with_ptr) { ExprGroups uses; - for (auto id_group : id_graph.disjointIdSets().disjointSets()) { + for (const IdGroup& id_group : id_graph.disjointIdSets().disjointSets()) { auto definition_pair = id_graph.getUses(id_group); if (definition_pair.second) { - for (auto expr_group : definition_pair.first) { + for (const ExprGroup& expr_group : definition_pair.first) { uses.pushBack(expr_group); } } diff --git a/csrc/id_model/utils.h b/csrc/id_model/utils.h index 563c5279774..2d6327bf586 100644 --- a/csrc/id_model/utils.h +++ b/csrc/id_model/utils.h @@ -7,7 +7,7 @@ // clang-format on #pragma once -#include "utils.h" +#include #include #include From d77436cb9f50ed7aef827b09cc2d8e581d988981 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 10 Oct 2023 23:12:36 -0700 Subject: [PATCH 059/178] remove TORCH_CUDA_CU_API --- csrc/ir/utils.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index cf137543f70..fe328f2266e 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -312,8 +312,7 @@ std::vector outputTvsOf(std::vector tvs); std::vector allTvs(Fusion* fusion); // returns all tensor views used in the provided expressions -TORCH_CUDA_CU_API std::vector allTvsOfExprs( - const std::vector& exprs); +std::vector allTvsOfExprs(const std::vector& exprs); // returns all tensor views in fusion that are used between outputs and inputs // except the specified set. From 7f34b17b44e972cffd7e2b5be21f4c61aada5225 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 10 Oct 2023 23:42:29 -0700 Subject: [PATCH 060/178] clang-tidy --- csrc/ir/utils.cpp | 3 ++- csrc/transform_iter.cpp | 16 ++++++++-------- csrc/transform_iter.h | 3 ++- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 4adbe896856..ec29d369536 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -134,8 +134,9 @@ std::vector normalizeOld2New( // All available new positions std::set all_positions; - for (decltype(ndims) i{0}; i < ndims; i++) + for (decltype(ndims) i{0}; i < ndims; i++) { all_positions.insert((int)i); + } // Check what positions haven't been specified. std::set positions_left; diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 3e24a15c824..16d08a9dcdf 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -289,8 +289,9 @@ void ReplayTransformations::runReplay() { // Populate leaf_vec_ in a deterministic manner. This is deterministic // because size_t in leaf_ids is filled based on operation order. std::set, id_int_lt> ordered_set; - for (auto entry : leaf_ids_) + for (auto entry : leaf_ids_) { ordered_set.emplace(entry); + } leaf_vec_.clear(); leaf_vec_.resize(ordered_set.size()); @@ -781,19 +782,16 @@ namespace { IterDomain* getSwizzleFinalOutput( IterDomain* id, const std::unordered_map& id2expr) { - bool is_swizzle_input = true; - // Note: currently not supporting swizzling consumer of another // swizzle id, so this should terminate in 1 iter, but eventually // will try to support stacked swizzles so keeping this pass // generic. - while (is_swizzle_input) { + while (true) { auto expr_it = id2expr.find(id); // This means id is a leaf that doesn't // have any consumers. Stop iteration in this case. if (expr_it == id2expr.end()) { - is_swizzle_input = false; break; } @@ -813,7 +811,7 @@ IterDomain* getSwizzleFinalOutput( } else { // Probably unreachable but if the expression // is unknown type assume it is not a swizzle op. - is_swizzle_input = false; + break; } } @@ -903,8 +901,9 @@ BestEffortReplay BestEffortReplay::replayCasP( bool skip_consumer_swizzle, bool skip_producer_swizzle, bool skip_resize) { - if (producer_compute_at_axis < 0) + if (producer_compute_at_axis < 0) { producer_compute_at_axis += (int)producer->nDims() + 1; + } NVF_ERROR( producer_compute_at_axis >= 0 && @@ -969,8 +968,9 @@ BestEffortReplay BestEffortReplay::replayPasC( bool skip_producer_swizzle, bool skip_consumer_swizzle, bool skip_resize) { - if (consumer_compute_at_axis < 0) + if (consumer_compute_at_axis < 0) { consumer_compute_at_axis += (int)consumer->nDims() + 1; + } NVF_ERROR( consumer_compute_at_axis >= 0 && (unsigned int)consumer_compute_at_axis <= consumer->nDims(), diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 943c25f6d85..086c47c19ac 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -441,8 +441,9 @@ class BestEffortReplay { // Returned ordered set of IDs in getUnorderedLeafIDs std::vector getLeafIDs() { std::set, id_int_lt> ordered_set; - for (auto entry : leaf_ids_) + for (auto entry : leaf_ids_) { ordered_set.emplace(entry); + } std::vector leaf_vec_; leaf_vec_.resize(ordered_set.size()); From 6603e0a7da59c9a3eb1d8263d12d8de13b22437b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 12 Oct 2023 09:53:34 -0700 Subject: [PATCH 061/178] Do not map broadcast and non-broadcast domains in EXACT (#1065) `PairwiseRootDomainMap` by default allows mapping of broadcast and non-broadcast domains. We should probably disable that too. The validation in the Indexing19 test passes. Some other tests are failing but that is the case without this PR too. Fixes #1052 --- csrc/id_model/id_graphs.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index b426abbbf2e..4f2deb06d8a 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -679,8 +679,9 @@ void IterDomainGraphs::buildExactMap(const std::vector& exprs) { // For exact mapings do not map any broadcast dimensions to // non-broadcast dimensions. Prevent any broadcasted axes being mapped // to non-broadcasted axes. - auto exact_c2p_root_map = - PairwiseRootDomainMap(p_tv, c_tv).mapConsumerToProducer(); + auto exact_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv) + .mapBroadcast(false) + .mapConsumerToProducer(); for (auto c_id : getSortedKeys(exact_c2p_root_map, Statement::lessThan)) { auto p_id = exact_c2p_root_map.at(c_id); From aacc5298e27c8266ce7fce80bc900d1e28b3ab51 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 12 Oct 2023 12:57:35 -0700 Subject: [PATCH 062/178] Fix mapping propagations through uses (#1072) This doesn't seem necessary at this moment but required to stop using ALMOST_EXACT for permissive mappings. Suppose domains A and B have uses as: A: ExprGroup X and ExprGroup Y B: ExprGroup X and ExprGroup Y i.e., both have the same set of expr groups. Before this propagation, there's nothing to allow map X and Y. When A and B are mapped, we should map the uses of A and B, which means X and Y should be mapped. The check removed in this PR prevents the propagation. --- csrc/id_model/id_graph.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index a9a2bf0a37b..e31aa611c64 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -743,11 +743,10 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { // Propagate on uses if (!orig_uses0.empty() && !orig_uses1.empty()) { for (const ExprGroup& use_group_1 : orig_uses1) { - if (orig_uses0.has(use_group_1)) { - continue; - } - for (const ExprGroup& use_group_0 : orig_uses0) { + if (use_group_0 == use_group_1) { + continue; + } Expr* use0 = use_group_0->front(); Expr* use1 = use_group_1->front(); maybeMapThroughExprs(use0, use1, true); From a7614091c2aa1e23f49f83e8f40cf5a8bee58c2b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 12 Oct 2023 13:05:07 -0700 Subject: [PATCH 063/178] Use the EXACT map as the starting map of the PERMISSIVE map (#1073) ALMOST_EXACT should not be necessary for PERMISSIVE and LOOP Related: #1053 --- csrc/id_model/id_graphs.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 4f2deb06d8a..0b86fbb7b60 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -694,7 +694,10 @@ void IterDomainGraphs::buildExactMap(const std::vector& exprs) { } void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { - idGraph(IdMappingMode::PERMISSIVE) = idGraph(IdMappingMode::ALMOSTEXACT); + // Use the exact map as the starting map rather than the + // almost-exact map. Almost exact is useful for index hoisting but + // not necessary for permissive and loop maps + idGraph(IdMappingMode::PERMISSIVE) = idGraph(IdMappingMode::EXACT); for (auto expr : exprs) { // Multiple outputs are already mapped, we can ignore all but the first From d3eb4c17748a7a1f0a92acb328b812efc2b40a26 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 13 Oct 2023 09:23:07 -0700 Subject: [PATCH 064/178] Same as #1072 but for definitions (#1081) --- csrc/id_model/id_graph.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index e31aa611c64..e71e89b46fe 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -757,11 +757,10 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { // Propagate on definitions if (!orig_defs0.empty() && !orig_defs1.empty()) { for (const ExprGroup& def_group_1 : orig_defs1) { - if (orig_defs0.has(def_group_1)) { - continue; - } - for (const ExprGroup& def_group_0 : orig_defs0) { + if (def_group_0 == def_group_1) { + continue; + } auto def0 = def_group_0->front(); auto def1 = def_group_1->front(); maybeMapThroughExprs(def0, def1, false); From 3bb96926c421c95647651e4ce6067e094f9cc809 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 13 Oct 2023 10:06:50 -0700 Subject: [PATCH 065/178] minor cleanup --- csrc/id_model/id_graph.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index e71e89b46fe..c7bb547baa6 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -660,17 +660,17 @@ bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { auto extent_1o = merge1->outer()->extent(); auto extent_1i = merge1->inner()->extent(); - auto extent_0_match = extent_0o->sameAs(extent_1o) || + auto extent_o_match = extent_0o->sameAs(extent_1o) || (extent_0o->isConstInt() && extent_1o->isConstInt() && extent_0o->evaluateInt() == extent_1o->evaluateInt()) || disjointIdSets().permissiveAreMapped(merge0->outer(), merge1->outer()); - auto extent_1_match = extent_0i->sameAs(extent_1i) || + auto extent_i_match = extent_0i->sameAs(extent_1i) || (extent_0i->isConstInt() && extent_1i->isConstInt() && extent_0i->evaluateInt() == extent_1i->evaluateInt()) || disjointIdSets().permissiveAreMapped(merge0->inner(), merge1->inner()); - if (!(extent_0_match || extent_1_match)) { + if (!(extent_o_match || extent_i_match)) { return false; } } From b778788edf119def5e5c0a6e2ef22340cccb044c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 13 Oct 2023 13:00:00 -0700 Subject: [PATCH 066/178] Disable mappings of non-ca domains in LOOP. (#1082) Makes sense in INDEX but keep LOOP same as the current LOOP --- csrc/id_model/id_graphs.cpp | 64 ------ csrc/python_frontend/fusion_record.h | 31 +-- csrc/python_frontend/python_bindings.cpp | 185 +++++++++--------- .../test/test_nvfuser_fusion_record.cpp | 21 +- csrc/scheduler/mma_utils.cpp | 8 +- csrc/utils.h | 5 +- 6 files changed, 117 insertions(+), 197 deletions(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 0b86fbb7b60..08af43f3ff6 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1463,70 +1463,6 @@ std::unordered_map IterDomainGraphs:: // Non-ca domains may also need to be promoted if parent domains are // promoted. - // Opportunistically add non-inlined loop relationships where they don't - // interfere with the loop groups. This should be on all p_ids that are not - // p_ca_ids. - for (auto p_id : info.ordered_c_ids.subtract(info.ordered_p_ca_ids)) { - // p2c_permissive_maps include those that are not mapped with the - // loop map - auto entry_it = info.p2c_permissive_maps.find(p_id); - if (entry_it == info.p2c_permissive_maps.end()) { - continue; - } - auto c_ids = entry_it->second; - for (auto c_id : c_ids) { - if (idGraph(IdMappingMode::LOOP) - .disjointIdSets() - .permissiveAreMapped(p_id, c_id)) { - // Already mapped - continue; - } - - // Grab all iter domains already in the loop groups for both iter - // domains. - IdGroups loop_groups = - idGraph(IdMappingMode::LOOP) - .toGroups(VectorOfUniqueEntries{p_id, c_id}); - - VectorOfUniqueEntries all_ids_in_groups; - - // p_id and c_id are not loop mapped, so there must be two ID groups - NVF_ERROR(loop_groups.size() == 2); - - ParallelType common_ptype = - loop_groups.front()->front()->getParallelType(); - if (std::any_of( - loop_groups.begin() + 1, - loop_groups.end(), - [common_ptype](IdGroup id_group) { - return id_group->front()->getParallelType() != common_ptype; - })) { - // Parallel types don't match, cannot merge non-inlined loop groups. - continue; - } - - for (const IdGroup& loop_group : loop_groups) { - all_ids_in_groups.pushBack(*loop_group); - } - - // Ignore new loop mappings from replays, we can still opportunistically - // merge leaves if they already have a promoted id from replay associated - // with them. Since they are not included in ordered_c_ids, - // taking intersection filters them out - all_ids_in_groups = all_ids_in_groups.intersect(info.ordered_c_ids); - - // Grab the almost exact map of all iter domains in those loop groups - const IdGroups& ae_groups = - idGraph(IdMappingMode::ALMOSTEXACT).toGroups(all_ids_in_groups); - - // If there's no broadcast promotion within the loop group then all the - // iter domains will be almost exact mapped with each other. - if (ae_groups.size() == 1) { - idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); - } - } - } - // Need to use the intersection of exact and loop map again, it needs to be // recomputed. auto intersection_exact_loop_graph = buildIntersection( diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 2312f05458f..00a6842cc05 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -1530,32 +1530,21 @@ struct ReductionOpRecord : RecordFunctor { result = result && (*fusion_op_.template target< - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() == + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() == *child_ptr->fusion_op_.template target< - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>()); + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>()); if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() << " Target Ptr [self: 0x" << std::hex - << (size_t)*fusion_op_.template target< + debug() + << " Target Ptr [self: 0x" << std::hex + << (size_t)*fusion_op_.template target< - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() - << "] [other: 0x" << std::hex - << (size_t)*child_ptr->fusion_op_.template target< + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() + << "] [other: 0x" << std::hex + << (size_t)*child_ptr->fusion_op_.template target< - TensorView* (*)(TensorView*, - const std::vector&, - bool, - DataType)>() - << "]\n"; + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() + << "]\n"; } result = result && (keep_dim_ == child_ptr->keep_dim_); result = result && (dtype_ == child_ptr->dtype_); diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 75214901c7e..a0078314cd2 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1751,100 +1751,97 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul) #undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP -#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = 0; \ - std::vector axes(arg.dims); \ - std::iota(axes.begin(), axes.end(), 0); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - axes, \ - false, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - int axis, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - {axis}, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("axis"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - const std::vector& axes, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast&, \ - bool, \ - DataType)>(op_name), \ - axes, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("axes"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ +#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + NVF_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = 0; \ + std::vector axes(arg.dims); \ + std::iota(axes.begin(), axes.end(), 0); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast< \ + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ + op_name), \ + axes, \ + false, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("dtype") = DataType::Null, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + int axis, \ + bool keepdim, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + NVF_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast< \ + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ + op_name), \ + {axis}, \ + keepdim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axis"), \ + py::arg("keepdim") = false, \ + py::arg("dtype") = DataType::Null, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + const std::vector& axes, \ + bool keepdim, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + NVF_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast< \ + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ + op_name), \ + axes, \ + keepdim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axes"), \ + py::arg("keepdim") = false, \ + py::arg("dtype") = DataType::Null, \ py::return_value_policy::reference); NVFUSER_PYTHON_BINDING_REDUCTION_OP( diff --git a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp index e0eabf5122d..170531cfaad 100644 --- a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp +++ b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp @@ -98,10 +98,9 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast&, - bool, - DataType)>(sum), + static_cast< + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( + sum), {0}, false, DataType::Float)); @@ -110,10 +109,9 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast&, - bool, - DataType)>(sum), + static_cast< + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( + sum), {0}, false, DataType::Float)); @@ -122,10 +120,9 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast&, - bool, - DataType)>(sum), + static_cast< + TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( + sum), {0}, false, DataType::Float)); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 253c6dc46dc..451849e18a2 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -59,11 +59,11 @@ std::pair generateSharedMemoryEpilogueHeuristics( properties->warpSize * vector_word; const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k; const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k; - const size_t smem_a = (size_t)(ceilDiv(mk, round_to_factor) * - round_to_factor * smem_double_buffer_stage) * + const size_t smem_a = + (size_t)(ceilDiv(mk, round_to_factor) * round_to_factor * smem_double_buffer_stage) * dataTypeSize(data_types[0]); - const size_t smem_b = (size_t)(ceilDiv(nk, round_to_factor) * - round_to_factor * smem_double_buffer_stage) * + const size_t smem_b = + (size_t)(ceilDiv(nk, round_to_factor) * round_to_factor * smem_double_buffer_stage) * dataTypeSize(data_types[1]); const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); diff --git a/csrc/utils.h b/csrc/utils.h index 7358f1c55b9..e55f191a542 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -196,8 +196,9 @@ std::vector getSortedKeys( // Based on https://stackoverflow.com/a/9154394 template -static auto hasToStringHelper(int) - -> decltype(std::declval::type>().toString(), std::true_type{}); +static auto hasToStringHelper(int) -> decltype( + std::declval::type>().toString(), + std::true_type{}); template static auto hasToStringHelper(long) -> std::false_type; From 4c50dcf5672b24c0635f0049c297e1eba7f66096 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 16 Oct 2023 11:29:27 -0700 Subject: [PATCH 067/178] Remove unnecessary code (#1087) No idea why this is necessary. I don't think it should hit, so leaving an assertion --- csrc/id_model/id_graphs.cpp | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 08af43f3ff6..1a05bdc1fc9 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -884,13 +884,24 @@ StatefulLoweringInfo buildInfo( ir_utils::filterByType(expr->outputs())) { auto resolved_bcast_map = resolvedRootBroadcasts(producer, consumer); - for (auto entry : resolved_bcast_map) { - info.p2c_root_broadcast_resolution_map[entry.first].pushBack( - entry.second); - for (auto other_exact_bcast : *exact_graph.toGroup(entry.first)) { + for (const auto& [p_id, c_id] : resolved_bcast_map) { + info.p2c_root_broadcast_resolution_map[p_id].pushBack(c_id); + for (auto other_exact_bcast : *(exact_graph.toGroup(p_id))) { + if (p_id == other_exact_bcast) { + continue; + } if (all_producer_ca_deps.has(other_exact_bcast)) { + // TODO-NM: Why is this here? Can be removed? + NVF_ERROR( + false, + "Can this happen? Adding other exact: ", + other_exact_bcast->name(), + " in addition to ", + p_id->name(), + " of ", + producer->toString()); info.p2c_root_broadcast_resolution_map[other_exact_bcast] - .pushBack(entry.second); + .pushBack(c_id); } } } From e20a76bab050c297b42e39d564f408d09a87f3b4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 16 Oct 2023 11:54:54 -0700 Subject: [PATCH 068/178] clang-format --- csrc/python_frontend/fusion_record.h | 31 ++- csrc/python_frontend/python_bindings.cpp | 185 +++++++++--------- .../test/test_nvfuser_fusion_record.cpp | 21 +- csrc/scheduler/mma_utils.cpp | 8 +- csrc/utils.h | 5 +- 5 files changed, 133 insertions(+), 117 deletions(-) diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 844b566bdf6..afc4ef0d7d3 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -1454,21 +1454,32 @@ struct ReductionOpRecord : RecordFunctor { result = result && (*fusion_op_.template target< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() == + TensorView* (*)(TensorView*, + const std::vector&, + bool, + DataType)>() == *child_ptr->fusion_op_.template target< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>()); + TensorView* (*)(TensorView*, + const std::vector&, + bool, + DataType)>()); if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { - debug() - << " Target Ptr [self: 0x" << std::hex - << (size_t)*fusion_op_.template target< + debug() << " Target Ptr [self: 0x" << std::hex + << (size_t)*fusion_op_.template target< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() - << "] [other: 0x" << std::hex - << (size_t)*child_ptr->fusion_op_.template target< + TensorView* (*)(TensorView*, + const std::vector&, + bool, + DataType)>() + << "] [other: 0x" << std::hex + << (size_t)*child_ptr->fusion_op_.template target< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>() - << "]\n"; + TensorView* (*)(TensorView*, + const std::vector&, + bool, + DataType)>() + << "]\n"; } result = result && (keep_dim_ == child_ptr->keep_dim_); result = result && (dtype_ == child_ptr->dtype_); diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 18cfb12f38b..0472bc43772 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1774,97 +1774,100 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul) #undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP -#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = 0; \ - std::vector axes(arg.dims); \ - std::iota(axes.begin(), axes.end(), 0); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast< \ - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ - op_name), \ - axes, \ - false, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - int axis, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast< \ - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ - op_name), \ - {axis}, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("axis"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ - py::return_value_policy::reference); \ - nvf_ops.def( \ - op_str, \ - [](FusionDefinition::Operators& self, \ - Tensor arg, \ - const std::vector& axes, \ - bool keepdim, \ - PrimDataType dtype) -> Tensor { \ - FUSER_PERF_SCOPE("Operators." op_str); \ - NVF_CHECK( \ - self.validUse(), "Attempting to add to a completed definition!"); \ - FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ - Tensor output = fd->defineTensor(ndims); \ - fd->defineRecord(new ReductionOpRecord( \ - {fd->recordingState(arg())}, \ - {fd->recordingState(output())}, \ - ("ops." op_str), \ - record_type, \ - static_cast< \ - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( \ - op_name), \ - axes, \ - keepdim, \ - dtype)); \ - return output; \ - }, \ - py::arg("arg"), \ - py::arg("axes"), \ - py::arg("keepdim") = false, \ - py::arg("dtype") = DataType::Null, \ +#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name, record_type) \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + NVF_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = 0; \ + std::vector axes(arg.dims); \ + std::iota(axes.begin(), axes.end(), 0); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast&, \ + bool, \ + DataType)>(op_name), \ + axes, \ + false, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("dtype") = DataType::Null, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + int axis, \ + bool keepdim, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + NVF_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = keepdim ? arg.dims : (arg.dims - 1); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast&, \ + bool, \ + DataType)>(op_name), \ + {axis}, \ + keepdim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axis"), \ + py::arg("keepdim") = false, \ + py::arg("dtype") = DataType::Null, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](FusionDefinition::Operators& self, \ + Tensor arg, \ + const std::vector& axes, \ + bool keepdim, \ + PrimDataType dtype) -> Tensor { \ + FUSER_PERF_SCOPE("Operators." op_str); \ + NVF_CHECK( \ + self.validUse(), "Attempting to add to a completed definition!"); \ + FusionDefinition* fd = self.fusion_definition; \ + size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ + Tensor output = fd->defineTensor(ndims); \ + fd->defineRecord(new ReductionOpRecord( \ + {fd->recordingState(arg())}, \ + {fd->recordingState(output())}, \ + ("ops." op_str), \ + record_type, \ + static_cast&, \ + bool, \ + DataType)>(op_name), \ + axes, \ + keepdim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axes"), \ + py::arg("keepdim") = false, \ + py::arg("dtype") = DataType::Null, \ py::return_value_policy::reference); NVFUSER_PYTHON_BINDING_REDUCTION_OP( diff --git a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp index 170531cfaad..e0eabf5122d 100644 --- a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp +++ b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp @@ -98,9 +98,10 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( - sum), + static_cast&, + bool, + DataType)>(sum), {0}, false, DataType::Float)); @@ -109,9 +110,10 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( - sum), + static_cast&, + bool, + DataType)>(sum), {0}, false, DataType::Float)); @@ -120,9 +122,10 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {out}, "ops.sum", serde::RecordType_ReductionSum, - static_cast< - TensorView* (*)(TensorView*, const std::vector&, bool, DataType)>( - sum), + static_cast&, + bool, + DataType)>(sum), {0}, false, DataType::Float)); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 451849e18a2..253c6dc46dc 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -59,11 +59,11 @@ std::pair generateSharedMemoryEpilogueHeuristics( properties->warpSize * vector_word; const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k; const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k; - const size_t smem_a = - (size_t)(ceilDiv(mk, round_to_factor) * round_to_factor * smem_double_buffer_stage) * + const size_t smem_a = (size_t)(ceilDiv(mk, round_to_factor) * + round_to_factor * smem_double_buffer_stage) * dataTypeSize(data_types[0]); - const size_t smem_b = - (size_t)(ceilDiv(nk, round_to_factor) * round_to_factor * smem_double_buffer_stage) * + const size_t smem_b = (size_t)(ceilDiv(nk, round_to_factor) * + round_to_factor * smem_double_buffer_stage) * dataTypeSize(data_types[1]); const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); diff --git a/csrc/utils.h b/csrc/utils.h index e55f191a542..7358f1c55b9 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -196,9 +196,8 @@ std::vector getSortedKeys( // Based on https://stackoverflow.com/a/9154394 template -static auto hasToStringHelper(int) -> decltype( - std::declval::type>().toString(), - std::true_type{}); +static auto hasToStringHelper(int) + -> decltype(std::declval::type>().toString(), std::true_type{}); template static auto hasToStringHelper(long) -> std::false_type; From 27c619a34eba10e418d45eab4ad52c1c19692838 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 16 Oct 2023 12:16:38 -0700 Subject: [PATCH 069/178] cleanup --- csrc/id_model/id_graphs.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 1a05bdc1fc9..a2d2e71652a 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1335,6 +1335,10 @@ std::unordered_map IterDomainGraphs:: // graph with the (promoted) inputs of iel_expr. If found, no need // to create a new expr to produce promoted outputs for (const ExprGroup& iel_use_group : non_promoted_input_uses) { + // No need to check itself + if (iel_expr == iel_use_group) { + continue; + } if (IdGraph::transformAtributesMatch( iel_expr->front(), iel_use_group->front())) { auto use_inps = From 4523cb27ee2098fc47b59c8963b488458e2f6c4f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 16 Oct 2023 14:10:38 -0700 Subject: [PATCH 070/178] WIP: Loop promotion with IEL (#1090) When finding reusable promotion domains, make sure to reuse loop-mapped domains since promotion domains should be part of mapped loop groups. This should fix #1054. Additional checks are also added to the Indexing19 test. --- csrc/id_model/id_graphs.cpp | 55 +++++++++++++++++++++++++++++-------- test/test_gpu_indexing.cpp | 27 ++++++++++++++++-- 2 files changed, 68 insertions(+), 14 deletions(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index a2d2e71652a..8f2fbbbcf41 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1700,37 +1700,64 @@ std::unordered_map IterDomainGraphs:: // Before replaying, check if there's already an expression like this, if so // use that for promotion. We're still only looking for representative iter // domains, so if there's already an expression that would produce something - // representative (matching in the exact graph) of what the new inputs would + // representative (matching in the IEL graph) of what the new inputs would // generate, just promote to that expressions outputs, don't bother // generating a new one. // - // Check all uses of the exact map the inputs are in, and look for one that - // would match. Grab all uses of the promoted inputs' groups in the exact - // map. + // Check all uses of the IEL map the inputs are in, and look for one that + // would match. Grab all uses of the promoted inputs' groups in the IEL + // map. Note that promotion should be to loop-mapped domains, so + // the IEL graph is used rather than the exact graph std::vector promoted_input_groups; ExprGroups promoted_input_uses; for (auto inp_id : promoted_inputs) { const auto& inp_exact_group = - idGraph(IdMappingMode::EXACT).toGroup(inp_id); + intersection_exact_loop_graph.toGroup(inp_id); promoted_input_groups.push_back(inp_exact_group); promoted_input_uses.pushBack( - idGraph(IdMappingMode::EXACT).getUniqueUses(inp_exact_group)); + intersection_exact_loop_graph.getUniqueUses(inp_exact_group)); } // Check every use to see if it matches - for (const ExprGroup& exact_use_group : promoted_input_uses) { + for (const ExprGroup& iel_use_group : promoted_input_uses) { + NVF_ERROR(!iel_use_group->empty()); // Check if all the attributes (including type) of the transform match if (!IdGraph::transformAtributesMatch( - iel_expr->front(), exact_use_group->front())) { + iel_expr->front(), iel_use_group->front())) { continue; } // Check if inputs all match if (promoted_input_groups != - idGraph(IdMappingMode::EXACT).inputGroups(exact_use_group)) { + intersection_exact_loop_graph.inputGroups(iel_use_group)) { + continue; + } + // Input mapping doesn't always mean expr and output + // mappings. Make sure the exprs are mapped, which automatically + // means the outputs are mapped in the case of the LOOP map + if (!idGraph(IdMappingMode::LOOP) + .disjointExprSets() + .permissiveAreMapped( + iel_expr->front(), iel_use_group->front())) { continue; } - replay = exact_use_group->front(); + // This is just an extra sanity check. Make sure all exprs in + // the use group are mapped + NVF_ERROR( + std::all_of( + iel_use_group->vector().begin(), + iel_use_group->vector().end(), + [&](Expr* iel_use) { + return idGraph(IdMappingMode::LOOP) + .disjointExprSets() + .permissiveAreMapped(iel_expr->front(), iel_use); + }), + "Not all mapped: ", + nvfuser::toString(iel_expr), + "\n", + nvfuser::toString(iel_use_group)); + + replay = iel_use_group->front(); break; } @@ -1801,11 +1828,17 @@ std::unordered_map IterDomainGraphs:: } }; - // Set up the loop promotion map of loops groups to promotion IDs + // Set up the loop promotion map of loop groups to promotion IDs for (const IdGroup& loop_group : idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { bool promoted = false; for (IterDomain* id : loop_group->vector()) { + // Additional domains are added to the LOOP graph after the IEL + // graph was built. Those auxiliary domains should be fine to + // ignore. + if (!intersection_exact_loop_graph.hasGroup(id)) { + continue; + } const auto& iel_group = intersection_exact_loop_graph.toGroup(id); if (auto iel_promotion_map_it = iel_promotion_map.find(iel_group); iel_promotion_map_it != iel_promotion_map.end()) { diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index ca81b0da900..e62205579c3 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -910,10 +910,16 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { auto promotion_map_it = promotion_map.find(merge_loop_group); ASSERT_TRUE(promotion_map_it != promotion_map.end()) << "Loop promotion not found for merge loop group"; + auto merge_out_promotion_id = promotion_map_it->second; ASSERT_EQ( - id_model.idGraph(IdMappingMode::EXACT).toGroup(promotion_map_it->second), + id_model.idGraph(IdMappingMode::EXACT).toGroup(merge_out_promotion_id), id_model.idGraph(IdMappingMode::EXACT).toGroup(ref_merge_out)) << "Merge loop group should be promoted to " << ref_merge_out->toString(); + ASSERT_NE( + id_model.idGraph(IdMappingMode::LOOP).toGroup(merge_out_promotion_id), + id_model.idGraph(IdMappingMode::LOOP).toGroup(ref_merge_out)) + << "Should not be loop-mapped with ref: " + << merge_out_promotion_id->toString(); // Get the corresponding reference ID in tv10 auto getRefId = [&](TensorView* tv, IterDomain* id) -> IterDomain* { @@ -950,8 +956,16 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { << tv->toString() << ". Loop group: " << nvfuser::toString(loop_group); - auto promotion_exact_group = id_model.idGraph(IdMappingMode::EXACT) - .toGroup(promotion_map_it->second); + auto promotion_id = promotion_map_it->second; + + // Promotion ID should be loop-mapped + ASSERT_TRUE(loop_group->has(promotion_id)) + << "Loop promotion for " << id->toString() << " of " << tv->toString() + << " is promoted to an ID that isn't loop mapped: " + << promotion_id->toString() << std::endl; + + auto promotion_exact_group = + id_model.idGraph(IdMappingMode::EXACT).toGroup(promotion_id); auto ref_id = getRefId(tv, id); auto ref_exact_group = @@ -960,6 +974,13 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { ASSERT_EQ(promotion_exact_group, ref_exact_group) << "Invalid promotion: " << id->toString() << " of " << tv->toString() << ". Promotion group: " << nvfuser::toString(promotion_exact_group); + + auto ref_loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(ref_id); + ASSERT_NE(loop_group, ref_loop_group) + << "Invalid promotion: " << id->toString() << " of " << tv->toString() + << ". Should not be loop-mapped with ref: " + << nvfuser::toString(loop_group); } } From 86c574ab9263b0a85eab99c89e75649fb46124fc Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 16 Oct 2023 18:49:14 -0700 Subject: [PATCH 071/178] Verify loop mappings of leaf domains See #1055 --- test/test_gpu_indexing.cpp | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index e62205579c3..d215fe30f36 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -939,6 +939,16 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { } }; + // Check if id is a leaf of a consumer tensor of tv + auto isIdOfConsumerTensor = [&](IterDomain* id, TensorView* tv) -> bool { + auto consumer_tvs = ir_utils::consumerTvsOf(tv); + return std::any_of( + consumer_tvs.begin(), consumer_tvs.end(), [&](auto consumer_tv) { + auto all_ids = ir_utils::allIDsOf(consumer_tv); + return std::find(all_ids.begin(), all_ids.end(), id) != all_ids.end(); + }); + }; + // At this point, all of the IDs from the root until split are // validated. Validating the remaining IDs for (auto tv : {tv1, tv2, tv4, tv5, tv6, tv8, tv9}) { @@ -981,6 +991,23 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { << "Invalid promotion: " << id->toString() << " of " << tv->toString() << ". Should not be loop-mapped with ref: " << nvfuser::toString(loop_group); + + // If id is a leaf, make sure it isn't mapped with + auto leaf_id_it = + std::find(tv->getLeafDomain().begin(), tv->getLeafDomain().end(), id); + if (leaf_id_it != tv->getLeafDomain().end() && + std::distance(tv->getLeafDomain().begin(), leaf_id_it) >= + tv->getComputeAtPosition()) { + for (auto loop_mapped_id : *loop_group) { + if (loop_mapped_id == id) { + continue; + } + ASSERT_FALSE(isIdOfConsumerTensor(loop_mapped_id, tv)) + << "Invalid promotion: " << id->toString() << " of " + << tv->toString() << ". Found to mapped a consumer tensor: " + << loop_mapped_id->name(); + } + } } } From bff13b8a1d0248343602bbb6cbc28d70461e988e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 16 Oct 2023 19:03:17 -0700 Subject: [PATCH 072/178] minor change --- csrc/id_model/id_graph.cpp | 1 + csrc/id_model/id_graph.h | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index c7bb547baa6..2fcdfc31a6e 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include namespace nvfuser { diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index 77333f1a3ed..c35b55c84d0 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -199,6 +199,10 @@ class IdGraph { // do anything. bool isTrivialExprGroup(const ExprGroup& expr_group) const; + void setPropagateThroughExprs(bool b) { + propagate_through_exprs_ = b; + } + private: // Map expr0 and expr1 with eachother, update unique_definitions_ unique_uses_ // TODO: Make this variant hidden? From aadcb9db5be31086ff86ab7bc3e4f6d2932cddec Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 25 Oct 2023 11:55:26 -0700 Subject: [PATCH 073/178] Fix use of replayed input IDs (#1144) Don't remember but some of the C++ tests failed. --- csrc/id_model/id_graphs.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 8f2fbbbcf41..99141a9f221 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1712,6 +1712,12 @@ std::unordered_map IterDomainGraphs:: ExprGroups promoted_input_uses; for (auto inp_id : promoted_inputs) { + // inp_id may have been just replayed, in which case it should + // not exist in the IEL graph. It should be just ignored as it + // should not have any use yet. + if (!intersection_exact_loop_graph.hasGroup(inp_id)) { + continue; + } const auto& inp_exact_group = intersection_exact_loop_graph.toGroup(inp_id); promoted_input_groups.push_back(inp_exact_group); From 5f0e5c4055d2a9de87ac60ac0044193a1e1c8f5f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 25 Oct 2023 16:38:54 -0700 Subject: [PATCH 074/178] cleanup --- csrc/id_model/id_graphs.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 99141a9f221..ff81e929ed5 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1804,13 +1804,6 @@ std::unordered_map IterDomainGraphs:: } } - for (const IdGroup& group : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { - if (iel_promotion_map.find(group) == iel_promotion_map.end()) { - continue; - } - } - // TODO: cleanup // Set loop_promotion_map_[loop_group] = promotion. // Make sure the existing mapping, if exists, matches with the given From 811a4adbfcbdd397b4e823e005b1fc3481448955 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Oct 2023 14:29:47 -0700 Subject: [PATCH 075/178] Finalize loop promotion map (#1149) Fixes #1146 --- csrc/id_model/id_graph.cpp | 5 ++ csrc/id_model/id_graph.h | 2 + csrc/id_model/id_graphs.cpp | 126 ++++++++++++++++++++++-------------- 3 files changed, 86 insertions(+), 47 deletions(-) diff --git a/csrc/id_model/id_graph.cpp b/csrc/id_model/id_graph.cpp index 2fcdfc31a6e..eb9f69ddced 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/id_model/id_graph.cpp @@ -511,6 +511,11 @@ std::pair IdGraph::getUses(const IdGroup& id_group) const { } } +bool IdGraph::hasUses(const IdGroup& id_group) const { + NVF_ERROR(id_group); + return unique_uses_.find(id_group) != unique_uses_.end(); +} + std::string IdGraph::toString() const { std::stringstream ss; ss << "IdGraph { \n"; diff --git a/csrc/id_model/id_graph.h b/csrc/id_model/id_graph.h index c35b55c84d0..b2660f7537d 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/id_model/id_graph.h @@ -126,6 +126,8 @@ class IdGraph { //! TODO-NM: Rename to getMaybeUses. See getUses std::pair getUses(const IdGroup& id_group) const; + bool hasUses(const IdGroup& id_group) const; + std::string toString() const; // Checks if the expression is a trivial operation where an input is simply an diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index ff81e929ed5..46a78c5b30e 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -1804,59 +1804,54 @@ std::unordered_map IterDomainGraphs:: } } - // TODO: cleanup - // Set loop_promotion_map_[loop_group] = promotion. - // Make sure the existing mapping, if exists, matches with the given - // promotion. - auto setLoopPromotion = - [this](const IdGroup& loop_group, IterDomain* promotion) -> void { - if (auto it = loop_promotion_map_.find(loop_group); - it != loop_promotion_map_.end()) { - auto existing_promotion = it->second; + // Update the coverage map + exact_covered_ids = + computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); + + // Set up the loop promotion map of loop groups to promotion + // IDs. Note that the IEL promotion map is still incomplete in the + // sense that: + // + // - Not all loop graphs have promotions set at this point. + // - Multiple domains that are loop-mapped may have different + // promotions, one of which should cover the rest. + // + // Fill the gap, here we traverse the loop graph and for each loop + // group we examine each IEL group. If an IEL group has a promotion, + // we consider it as a candidate of the promotion of this loop + // group. If not, we include a domain of the IEL group as a + // candidate too. We also look at the inline promotion map since + // that may also contain the promotion the loop should be associated + // with. Once all candidates are obtained, we pick one that covers + // all the exact domains (cf. concrete domains in ComputeAtMap) + for (const IdGroup& loop_group : + loop_graph_copy.disjointIdSets().disjointSets()) { + IdGroups iel_groups = intersection_exact_loop_graph.toGroups(*loop_group); + // All exact groups covered by all iter domains in this loop group + IdGroups loop_group_covered_ids; + for (const IdGroup& iel_group : iel_groups) { + auto exact_group = + idGraph(IdMappingMode::EXACT).toGroup(iel_group->front()); + auto covered_it = exact_covered_ids.find(exact_group); NVF_ERROR( - idGraph(IdMappingMode::EXACT).toGroup(promotion) == - idGraph(IdMappingMode::EXACT).toGroup(existing_promotion), - "Different promotions found for ", - nvfuser::toString(loop_group), - ". ", - promotion->toString(), - ", ", - existing_promotion->toString()); - } else { - loop_promotion_map_.emplace(loop_group, promotion); + covered_it != exact_covered_ids.end(), + "Exact covered id not found for ", + nvfuser::toString(exact_group)); + loop_group_covered_ids.pushBack(covered_it->second); } - }; - // Set up the loop promotion map of loop groups to promotion IDs - for (const IdGroup& loop_group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { - bool promoted = false; - for (IterDomain* id : loop_group->vector()) { - // Additional domains are added to the LOOP graph after the IEL - // graph was built. Those auxiliary domains should be fine to - // ignore. - if (!intersection_exact_loop_graph.hasGroup(id)) { - continue; - } - const auto& iel_group = intersection_exact_loop_graph.toGroup(id); + VectorOfUniqueEntries representative_id_candidates; + + for (const IdGroup& iel_group : iel_groups) { if (auto iel_promotion_map_it = iel_promotion_map.find(iel_group); iel_promotion_map_it != iel_promotion_map.end()) { IterDomain* iel_promotion_id = iel_promotion_map_it->second; - setLoopPromotion(loop_group, iel_promotion_id); - promoted = true; + representative_id_candidates.pushBack(iel_promotion_id); + } else { + representative_id_candidates.pushBack(iel_group->front()); } } - if (promoted) { - continue; - } - - VERBOSE() << "No mapping in the IEL promotion map: " - << nvfuser::toString(loop_group) << std::endl; - - // No mapping in the IEL promotion map. If the loop group is still - // mapped in the loop group promotion map, that should be the - // correct promotion for this group if (auto loop_graph_copy_promotion_map_it = loop_graph_copy_promotion_map.find( loop_graph_copy.toGroup(loop_group->vector().at(0))); @@ -1864,14 +1859,51 @@ std::unordered_map IterDomainGraphs:: loop_graph_copy_promotion_map.end()) { VERBOSE() << "Found in loop promotion: " << nvfuser::toString(loop_group) << std::endl; - setLoopPromotion(loop_group, loop_graph_copy_promotion_map_it->second); - promoted = true; + representative_id_candidates.pushBack( + loop_graph_copy_promotion_map_it->second); } + VERBOSE() << "Loop promotion candidates: " << std::endl; + + // All candidates gathered + for (IterDomain* candidate_id : representative_id_candidates) { + auto covered_it = exact_covered_ids.find( + idGraph(IdMappingMode::EXACT).toGroup(candidate_id)); + NVF_ERROR(covered_it != exact_covered_ids.end()); + if (loop_group_covered_ids.subtract(covered_it->second).empty()) { + // Found + VERBOSE() << "Representative found: " << candidate_id->toString() + << std::endl; + const IdGroup& current_loop_group = + idGraph(IdMappingMode::LOOP).toGroup(loop_group->front()); + loop_promotion_map_.emplace(current_loop_group, candidate_id); + break; + } + } + } + + // Sanity check of the loop promotion map + for (const IdGroup& loop_group : + idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + // Non-leaf loop groups are not guaranteed to have valid + // promotions. See for example FusionRepro1713, where root domains + // are all grouped together but there's no valid promotion. + if (idGraph(IdMappingMode::LOOP).hasUses(loop_group)) { + continue; + } + auto promotion_it = loop_promotion_map_.find(loop_group); NVF_ERROR( - promoted, + promotion_it != loop_promotion_map_.end(), "Loop promotion not found for ", nvfuser::toString(loop_group)); + IterDomain* promotion = promotion_it->second; + // Make sure the promotion domain is also loop-mapped + NVF_ERROR( + loop_group->has(promotion), + "Loop promotion not loop-mapped. Loop group: ", + nvfuser::toString(loop_group), + ". Promotion domain: ", + promotion->name()); } VERBOSE() << "Loop promotion map:" << std::endl; From f3bbd8f1f8f91d2f9e0693939ffb42a1fdb4a035 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Oct 2023 16:21:57 -0700 Subject: [PATCH 076/178] minor --- csrc/disjoint_set.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index ff714b22844..90467f2faf3 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -253,6 +253,22 @@ class VectorOfUniqueEntries { return vector_.end(); } + auto rbegin() const { + return vector().rbegin(); + } + + auto rend() const { + return vector().rend(); + } + + auto rbegin() { + return vector_.begin(); + } + + auto rend() { + return vector_.end(); + } + std::string toString() const { std::stringstream ss; ss << "{ "; From f2448c24c75d1443de8994700ca0a9d53af6593b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Oct 2023 16:58:07 -0700 Subject: [PATCH 077/178] Renaming IterDomainGraphs to IdModel (#1161) I prefer to have `IdGraph` and `IdModel`, rather than `IdGraph` and `IdGraphs` --- csrc/device_lower/lower2device.cpp | 2 +- csrc/id_model/id_graphs.cpp | 72 ++++++++++++++---------------- csrc/id_model/id_graphs.h | 10 ++--- test/test_gpu_indexing.cpp | 2 +- 4 files changed, 39 insertions(+), 47 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index ebd827d8440..bb708d31958 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -311,7 +311,7 @@ void GpuLower::lower(Fusion* fusion) { replaceSymbolicSizes(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "replaceSymbolicSizes"); - IterDomainGraphs test(fusion_); + IdModel test(fusion_); // Build what's refered to as the compute at map. This map contains the // mappings of all iteration domains across the fusion. There are three types diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_graphs.cpp index 46a78c5b30e..de3154d943d 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_graphs.cpp @@ -24,7 +24,7 @@ namespace nvfuser { -void IterDomainGraphs::assertNoSelfMapping() { +void IdModel::assertNoSelfMapping() { if (hasSelfMapping()) { NVF_ERROR( !hasSelfMapping(), @@ -40,7 +40,7 @@ void IterDomainGraphs::assertNoSelfMapping() { } } -IterDomainGraphs::IterDomainGraphs( +IdModel::IdModel( const std::vector& exprs, const std::vector& additional_tvs, bool allow_self_mapping) { @@ -51,12 +51,10 @@ IterDomainGraphs::IterDomainGraphs( } } -IterDomainGraphs::IterDomainGraphs( - const std::vector& exprs, - bool allow_self_mapping) - : IterDomainGraphs(exprs, {}, allow_self_mapping) {} +IdModel::IdModel(const std::vector& exprs, bool allow_self_mapping) + : IdModel(exprs, {}, allow_self_mapping) {} -IterDomainGraphs::IterDomainGraphs(Fusion* fusion, bool allow_self_mapping) { +IdModel::IdModel(Fusion* fusion, bool allow_self_mapping) { std::vector inputs_and_outputs; { auto inp_tvs = ir_utils::filterByType(fusion->inputs()); @@ -76,19 +74,19 @@ IterDomainGraphs::IterDomainGraphs(Fusion* fusion, bool allow_self_mapping) { } } -const IdGraph& IterDomainGraphs::idGraph(IdMappingMode mode) const { +const IdGraph& IdModel::idGraph(IdMappingMode mode) const { auto graph_it = id_graphs_.find(mode); NVF_ERROR(graph_it != id_graphs_.end()); return graph_it->second; } -IdGraph& IterDomainGraphs::idGraph(IdMappingMode mode) { +IdGraph& IdModel::idGraph(IdMappingMode mode) { auto graph_it = id_graphs_.find(mode); NVF_ERROR(graph_it != id_graphs_.end()); return graph_it->second; } -Expr* IterDomainGraphs::idUse(IterDomain* id) const { +Expr* IdModel::idUse(IterDomain* id) const { auto use_it = id_uses_.find(id); if (use_it == id_uses_.end()) { return nullptr; @@ -96,7 +94,7 @@ Expr* IterDomainGraphs::idUse(IterDomain* id) const { return use_it->second.front(); } -Expr* IterDomainGraphs::idDef(IterDomain* id) const { +Expr* IdModel::idDef(IterDomain* id) const { auto def_it = id_definitions_.find(id); if (def_it == id_definitions_.end()) { return nullptr; @@ -139,7 +137,7 @@ namespace { // we pull multiple values of tv0 to compute tv3. c10::optional> detectMappablePair( const std::vector& ids, - const IterDomainGraphs& id_graph, + const IdModel& id_graph, IdMappingMode mode) { for (auto id1 : ids) { for (auto id2 : ids) { @@ -163,7 +161,7 @@ c10::optional> detectMappablePair( c10::optional> findFirstSelfMapping( const std::vector& all_tvs, - const IterDomainGraphs& id_graph) { + const IdModel& id_graph) { for (auto tv : all_tvs) { // For each tensor, make sure root, rfactor and leaf domains // should not include domains that are mapped with another domain @@ -213,7 +211,7 @@ findFirstSelfMapping( } // namespace -void IterDomainGraphs::buildIterDomainDefinitionsAndUses( +void IdModel::buildIterDomainDefinitionsAndUses( const std::vector& all_tvs) { for (auto tv : all_tvs) { VectorOfUniqueEntries root_domain_ids{ @@ -267,7 +265,7 @@ void IterDomainGraphs::buildIterDomainDefinitionsAndUses( } } -std::string IterDomainGraphs::toString() const { +std::string IdModel::toString() const { // Figure out which graphs are already initialized to make sure we add the new // expression to them. std::vector initialized_modes; @@ -302,9 +300,7 @@ std::string IterDomainGraphs::toString() const { } // Replay Expr but with the inputs provided. -Expr* IterDomainGraphs::addReplayAs( - std::vector new_inputs, - Expr* expr) { +Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { // Figure out which graphs are already initialized to make sure we add the new // expression to them. std::vector initialized_modes; @@ -429,7 +425,7 @@ Expr* IterDomainGraphs::addReplayAs( // Generate a new expr with the IterDomain inputs/outputs replaced based on map. // Replaced inputs/outputs should almost exact match with provided expr. -Expr* IterDomainGraphs::addExprWithReplacement( +Expr* IdModel::addExprWithReplacement( const std::unordered_map& old_2_new_ids, Expr* old_expr) { // Figure out which graphs are already initialized to make sure we add the new @@ -598,7 +594,7 @@ Expr* IterDomainGraphs::addExprWithReplacement( // Clone provided iter domain and return the new copy. Map that copy in relevant // maps. -IterDomain* IterDomainGraphs::cloneIterDomain(IterDomain* id) { +IterDomain* IdModel::cloneIterDomain(IterDomain* id) { // Figure out which graphs are already initialized to make sure we add the new // expression to them. std::vector initialized_modes; @@ -629,7 +625,7 @@ IterDomain* IterDomainGraphs::cloneIterDomain(IterDomain* id) { return id_copy; } -IdGraph IterDomainGraphs::initializeIdGraph(bool propagate_through_exprs) { +IdGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { IdGraph id_graph(propagate_through_exprs); for (const auto& [id, defs] : id_definitions_) { @@ -645,7 +641,7 @@ IdGraph IterDomainGraphs::initializeIdGraph(bool propagate_through_exprs) { return id_graph; } -void IterDomainGraphs::buildExactMap(const std::vector& exprs) { +void IdModel::buildExactMap(const std::vector& exprs) { for (auto expr : exprs) { TensorView* c_tv = ir_utils::getTvOutput(expr); @@ -693,7 +689,7 @@ void IterDomainGraphs::buildExactMap(const std::vector& exprs) { } } -void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { +void IdModel::buildPermissiveMap(const std::vector& exprs) { // Use the exact map as the starting map rather than the // almost-exact map. Almost exact is useful for index hoisting but // not necessary for permissive and loop maps @@ -748,7 +744,7 @@ void IterDomainGraphs::buildPermissiveMap(const std::vector& exprs) { idGraph(IdMappingMode::PERMISSIVE).mapThroughLoopSwizzles(); } -void IterDomainGraphs::buildAlmostExactMap() { +void IdModel::buildAlmostExactMap() { // Build almost exact map by forwarding through broadcast axes idGraph(IdMappingMode::ALMOSTEXACT) = idGraph(IdMappingMode::EXACT); idGraph(IdMappingMode::ALMOSTEXACT).mapThroughTrivialExprs(); @@ -756,8 +752,7 @@ void IterDomainGraphs::buildAlmostExactMap() { // TODO: Reenable after reenabling parallel propagation. // propagateLoopPTypes -void IterDomainGraphs::validatePTypes( - const std::vector& all_tvs) const { +void IdModel::validatePTypes(const std::vector& all_tvs) const { // VectorOfUniqueEntries leaf_ids; // for (auto tv : all_tvs) { // leaf_ids.pushBack(tv->domain()->leaf()); @@ -776,7 +771,7 @@ void IterDomainGraphs::validatePTypes( // } } -void IterDomainGraphs::propagateLoopPTypes() const { +void IdModel::propagateLoopPTypes() const { for (const auto& loop_disjoint_set : idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { ParallelType common_ptype = ParallelType::Serial; @@ -937,7 +932,7 @@ StatefulLoweringInfo buildInfo( } // namespace -void IterDomainGraphs::build( +void IdModel::build( const std::vector& exprs, const std::vector& additional_tvs) { // Initialize the required sets as if a permissive relationship is never @@ -1034,7 +1029,7 @@ void IterDomainGraphs::build( self_mapping_info_ = findFirstSelfMapping(all_tvs, *this); } -VectorOfUniqueEntries IterDomainGraphs::computeTerminalLoopIds( +VectorOfUniqueEntries IdModel::computeTerminalLoopIds( const StatefulLoweringInfo info) { VectorOfUniqueEntries terminal_loop_ids; for (const IdGroup& group : @@ -1076,7 +1071,7 @@ VectorOfUniqueEntries IterDomainGraphs::computeTerminalLoopIds( return terminal_loop_ids; } -IdGraph IterDomainGraphs::buildIntersection( +IdGraph IdModel::buildIntersection( const IdGraph& graph0, const IdGraph& graph1, bool propagate_exprs) { @@ -1098,7 +1093,7 @@ IdGraph IterDomainGraphs::buildIntersection( return intersection; } -void IterDomainGraphs::initializeLoopMap(StatefulLoweringInfo& info) { +void IdModel::initializeLoopMap(StatefulLoweringInfo& info) { // See Indexing20 example for why we shouldn't propagate when generating loop // groups idGraph(IdMappingMode::LOOP) = initializeIdGraph(false); @@ -1116,8 +1111,8 @@ void IterDomainGraphs::initializeLoopMap(StatefulLoweringInfo& info) { } } -std::unordered_map IterDomainGraphs:: - buildInlinePromotions(StatefulLoweringInfo& info) { +std::unordered_map IdModel::buildInlinePromotions( + StatefulLoweringInfo& info) { // Make an intersection of the exact and loop map. This will group together // entries in each loop group that are exact with each other. This provides a // better graph to do promotion and replays. @@ -1470,11 +1465,10 @@ std::unordered_map computeCoveredGroups( } }; // namespace -std::unordered_map IterDomainGraphs:: - buildLoopPromotionMap( - const std::vector& exprs, - StatefulLoweringInfo& info, - const std::unordered_map& stale_promotion_map) { +std::unordered_map IdModel::buildLoopPromotionMap( + const std::vector& exprs, + StatefulLoweringInfo& info, + const std::unordered_map& stale_promotion_map) { // Non-ca domains may also need to be promoted if parent domains are // promoted. @@ -1915,7 +1909,7 @@ std::unordered_map IterDomainGraphs:: return iel_promotion_map; } -std::unordered_map IterDomainGraphs::buildIndexGraph( +std::unordered_map IdModel::buildIndexGraph( const std::vector& exprs, const std::vector& all_tvs, StatefulLoweringInfo& info, diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_graphs.h index 440963288a6..6c69296e81f 100644 --- a/csrc/id_model/id_graphs.h +++ b/csrc/id_model/id_graphs.h @@ -82,21 +82,19 @@ struct StatefulLoweringInfo; // PERMISSIVE) // Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped // -class IterDomainGraphs : public PolymorphicBase { +class IdModel : public PolymorphicBase { public: - IterDomainGraphs( + IdModel( const std::vector& exprs, const std::vector& additional_tvs, bool allow_self_mapping = false); - IterDomainGraphs( - const std::vector& exprs, - bool allow_self_mapping = false); + IdModel(const std::vector& exprs, bool allow_self_mapping = false); // Same as the above constructor with fusion->exprs() excpet fusion may have // some dangling inputs/outputs that are expected to have IterDomain entries // even though there's no possible connections from them. - IterDomainGraphs(Fusion* fusion, bool allow_self_mapping = false); + IdModel(Fusion* fusion, bool allow_self_mapping = false); // Returns iter domain graph of provided mode. const IdGraph& idGraph(IdMappingMode mode) const; diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index d215fe30f36..bda3916d225 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -879,7 +879,7 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { tensor->inlineAt(1); } - IterDomainGraphs id_model(&fusion); + IdModel id_model(&fusion); // All of the IDs that are generated with merge operations from the // root domains should be mapped to the single group. From 57469b3e4851580bf5e4715f008bd13d6d90ecd8 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Oct 2023 17:35:12 -0700 Subject: [PATCH 078/178] Follow-up to #1161 (#1165) CC: @csarofeen --- csrc/id_model/{id_graphs.cpp => id_model.cpp} | 4 ++-- csrc/id_model/{id_graphs.h => id_model.h} | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) rename csrc/id_model/{id_graphs.cpp => id_model.cpp} (99%) rename csrc/id_model/{id_graphs.h => id_model.h} (99%) diff --git a/csrc/id_model/id_graphs.cpp b/csrc/id_model/id_model.cpp similarity index 99% rename from csrc/id_model/id_graphs.cpp rename to csrc/id_model/id_model.cpp index de3154d943d..422d7e4d029 100644 --- a/csrc/id_model/id_graphs.cpp +++ b/csrc/id_model/id_model.cpp @@ -505,13 +505,13 @@ Expr* IdModel::addExprWithReplacement( // Create the new expression with provided outputs auto replay = ReplacementTransformCloner::clone(replacement_map, old_expr); - // Add new output iter domains to id_definitions_/id_uses_ of IdGraphs + // Add new output iter domains to id_definitions_/id_uses_ of IdModel for (auto out_id : ir_utils::filterByType(replay->outputs())) { id_definitions_[out_id].pushBack(replay); id_uses_[out_id]; } - // Add new input iter domains to id_definitions_/id_uses_ of IdGraphs + // Add new input iter domains to id_definitions_/id_uses_ of IdModel for (auto inp_id : ir_utils::filterByType(replay->inputs())) { id_definitions_[inp_id]; id_uses_[inp_id].pushBack(replay); diff --git a/csrc/id_model/id_graphs.h b/csrc/id_model/id_model.h similarity index 99% rename from csrc/id_model/id_graphs.h rename to csrc/id_model/id_model.h index 6c69296e81f..bb00160162c 100644 --- a/csrc/id_model/id_graphs.h +++ b/csrc/id_model/id_model.h @@ -161,7 +161,7 @@ class IdModel : public PolymorphicBase { Expr* expr); // Make an exact copy of provided IterDomain (without rfactor set), and map - // the copy to the original in all registered IdGraphs. IterDomain copy will + // the copy to the original in all registered IdModel. IterDomain copy will // not have any registered uses or definitions. IterDomain* cloneIterDomain(IterDomain* id); From 5dde9f110dbf67e61153f485d40e9c9e1416167c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Oct 2023 17:57:32 -0700 Subject: [PATCH 079/178] build fix (#1166) --- CMakeLists.txt | 2 +- csrc/device_lower/lower2device.cpp | 2 +- csrc/id_model/id_model.cpp | 2 +- test/test_gpu_indexing.cpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2e110d38c6f..fa16f56370a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,7 +77,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/fusion.cpp ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp ${NVFUSER_SRCS_DIR}/id_model/id_graph.cpp - ${NVFUSER_SRCS_DIR}/id_model/id_graphs.cpp + ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp ${NVFUSER_SRCS_DIR}/id_model/transform_replay.cpp ${NVFUSER_SRCS_DIR}/id_model/visitor.cpp diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index bb708d31958..d34439aa8e4 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 422d7e4d029..8810b9c342f 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include +#include #include #include #include diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index bda3916d225..f853d34a47f 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -11,7 +11,7 @@ #include #include -#include +#include #include #include #include From e01e2e679d06b1194b9a6b1bc6e56dd4711e1afa Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 13 Nov 2023 18:03:59 -0800 Subject: [PATCH 080/178] Mechanical changes extracted from #1168 (#1293) --- CMakeLists.txt | 6 +- csrc/disjoint_set.h | 13 +- csrc/id_model/id_model.cpp | 273 +++++++++--------- csrc/id_model/id_model.h | 40 +-- csrc/id_model/to_string.cpp | 57 ++-- csrc/id_model/to_string.h | 30 +- csrc/id_model/visitor.cpp | 36 +-- csrc/id_model/visitor.h | 22 +- csrc/{id_model/id_graph.cpp => val_graph.cpp} | 231 ++++++++------- csrc/{id_model/id_graph.h => val_graph.h} | 109 +++---- test/test_gpu_indexing.cpp | 7 +- 11 files changed, 433 insertions(+), 391 deletions(-) rename csrc/{id_model/id_graph.cpp => val_graph.cpp} (80%) rename csrc/{id_model/id_graph.h => val_graph.h} (72%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 838cad3a067..8defa7dcf2c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,8 +77,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/executor_utils.cpp ${NVFUSER_SRCS_DIR}/fusion.cpp ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp - ${NVFUSER_SRCS_DIR}/id_model/id_graph.cpp - ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp + ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp ${NVFUSER_SRCS_DIR}/id_model/transform_replay.cpp ${NVFUSER_SRCS_DIR}/id_model/visitor.cpp @@ -195,7 +194,8 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/optimization/mark_alias.cpp ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp -) + ${NVFUSER_SRCS_DIR}/val_graph.cpp + ) # We don't link CUPTI for MSVC if(NOT MSVC) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 90467f2faf3..952c27ce220 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -100,8 +100,19 @@ class VectorOfUniqueEntries { return pushBack(other.vector()); } + // Returns true if any node was added + template < + typename VectorOfUniqueEntriesType, + typename VectorOfUniqueEntriesHash> + bool pushBack(const VectorOfUniqueEntries< + VectorOfUniqueEntriesType, + VectorOfUniqueEntriesHash>& other) { + return pushBack(other.vector()); + } + // Returns if any node was added - bool pushBack(const std::vector& other) { + template + bool pushBack(const std::vector& other) { bool any_added = false; for (const auto& entry : other) { auto added = pushBack(entry); diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 8810b9c342f..8acbb4a5386 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -74,13 +74,13 @@ IdModel::IdModel(Fusion* fusion, bool allow_self_mapping) { } } -const IdGraph& IdModel::idGraph(IdMappingMode mode) const { +const ValGraph& IdModel::idGraph(IdMappingMode mode) const { auto graph_it = id_graphs_.find(mode); NVF_ERROR(graph_it != id_graphs_.end()); return graph_it->second; } -IdGraph& IdModel::idGraph(IdMappingMode mode) { +ValGraph& IdModel::idGraph(IdMappingMode mode) { auto graph_it = id_graphs_.find(mode); NVF_ERROR(graph_it != id_graphs_.end()); return graph_it->second; @@ -144,7 +144,7 @@ c10::optional> detectMappablePair( if (id1 == id2) { continue; } - if (id_graph.idGraph(mode).disjointIdSets().permissiveAreMapped( + if (id_graph.idGraph(mode).disjointValSets().permissiveAreMapped( id1, id2)) { return std::make_pair(id1, id2); } @@ -276,7 +276,7 @@ std::string IdModel::toString() const { } auto& graph = graph_it->second; - if (graph.disjointIdSets().disjointSetMap().empty()) { + if (graph.disjointValSets().disjointSetMap().empty()) { continue; } @@ -311,7 +311,7 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { } auto& graph = graph_it->second; - if (graph.disjointIdSets().disjointSetMap().empty()) { + if (graph.disjointValSets().disjointSetMap().empty()) { continue; } @@ -340,7 +340,7 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { id_definitions_[new_inputs.back()]; id_uses_[new_inputs.back()]; for (auto mode : initialized_modes) { - idGraph(mode).initializeId(new_inputs.back(), {}, {}); + idGraph(mode).initializeVal(new_inputs.back(), {}, {}); idGraph(mode).mapIds(new_inputs.back(), tmp_input); } } @@ -392,7 +392,7 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { // Initialize output ids in map for (auto out_id : ir_utils::filterByType(replay->outputs())) { - idGraph(mode).initializeId(out_id, {replay}, {}); + idGraph(mode).initializeVal(out_id, {replay}, {}); } // Update uses of the inputs in the graphs @@ -438,7 +438,7 @@ Expr* IdModel::addExprWithReplacement( } auto& graph = graph_it->second; - if (graph.disjointIdSets().disjointSetMap().empty()) { + if (graph.disjointValSets().disjointSetMap().empty()) { continue; } @@ -526,9 +526,9 @@ Expr* IdModel::addExprWithReplacement( // Initialize any non-existant input ids, update existing ones for (auto inp_id : ir_utils::filterByType(replay->inputs())) { - if (!graph.disjointIdSets().mappingExists(inp_id)) { + if (!graph.disjointValSets().mappingExists(inp_id)) { // inp_id is not initialized in the map, initialize it - graph.initializeId(inp_id, {}, {replay}); + graph.initializeVal(inp_id, {}, {replay}); } else { // Update unique uses of existing input ids auto inp_group = graph.toGroup(inp_id); @@ -538,9 +538,9 @@ Expr* IdModel::addExprWithReplacement( // Initialize any non-existant output ids, update existing ones for (auto out_id : ir_utils::filterByType(replay->outputs())) { - if (!graph.disjointIdSets().mappingExists(out_id)) { + if (!graph.disjointValSets().mappingExists(out_id)) { // out_id is not initialized in the map, initialize it - graph.initializeId(out_id, {replay}, {}); + graph.initializeVal(out_id, {replay}, {}); } else { // out_id is already initialized, add the replay as a unique definition // of its group @@ -605,7 +605,7 @@ IterDomain* IdModel::cloneIterDomain(IterDomain* id) { } auto& graph = graph_it->second; - if (graph.disjointIdSets().disjointSetMap().empty()) { + if (graph.disjointValSets().disjointSetMap().empty()) { continue; } @@ -618,15 +618,15 @@ IterDomain* IdModel::cloneIterDomain(IterDomain* id) { id_definitions_[id_copy] = {}; for (auto mode : initialized_modes) { - idGraph(mode).initializeId(id_copy, {}, {}); + idGraph(mode).initializeVal(id_copy, {}, {}); idGraph(mode).mapIds(id, id_copy); } return id_copy; } -IdGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { - IdGraph id_graph(propagate_through_exprs); +ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { + ValGraph id_graph(propagate_through_exprs); for (const auto& [id, defs] : id_definitions_) { auto uses_it = id_uses_.find(id); @@ -635,7 +635,7 @@ IdGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { "Failed to initialize id: ", id->toString(), " as it's missing a definition entry."); - id_graph.initializeId(id, defs, uses_it->second); + id_graph.initializeVal(id, defs, uses_it->second); } return id_graph; @@ -759,7 +759,7 @@ void IdModel::validatePTypes(const std::vector& all_tvs) const { // } // for (const auto& disjoint_set : - // idGraph(IdMappingMode::EXACT).disjointIdSets().disjointSets()) { + // idGraph(IdMappingMode::EXACT).disjointValSets().disjointSets()) { // for (auto id : disjoint_set->vector()) { // auto id_ptype = id->getParallelType(); @@ -773,10 +773,10 @@ void IdModel::validatePTypes(const std::vector& all_tvs) const { void IdModel::propagateLoopPTypes() const { for (const auto& loop_disjoint_set : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { ParallelType common_ptype = ParallelType::Serial; for (auto id : loop_disjoint_set->vector()) { - auto id_ptype = id->getParallelType(); + auto id_ptype = id->as()->getParallelType(); NVF_ERROR( id_ptype == common_ptype || id_ptype == ParallelType::Serial || @@ -791,7 +791,7 @@ void IdModel::propagateLoopPTypes() const { } for (auto id : loop_disjoint_set->vector()) { - id->parallelize(common_ptype); + id->as()->parallelize(common_ptype); } } } @@ -800,7 +800,7 @@ namespace { struct StatefulLoweringInfo { // Tracks all p2c mappings in permissive maps even those not inlined between // producer and consumer - std::unordered_map> + std::unordered_map> p2c_permissive_maps; // All consumer ids in a deterministic order (ignores fusion->inputs()) @@ -808,7 +808,7 @@ struct StatefulLoweringInfo { // p2c mappings through the fusion within (including dependencies of) inlined // leaf domains. - std::unordered_map> + std::unordered_map> p2c_ca_permissive_maps; // All producer ids within (including dependencies of) inlined leaf domains, @@ -851,8 +851,8 @@ std::unordered_map resolvedRootBroadcasts( StatefulLoweringInfo buildInfo( const std::vector& exprs, - const IdGraph& exact_graph, - const IdGraph& permissive_graph) { + const ValGraph& exact_graph, + const ValGraph& permissive_graph) { StatefulLoweringInfo info; // Grab inlining relationships for (auto expr : exprs) { @@ -885,7 +885,7 @@ StatefulLoweringInfo buildInfo( if (p_id == other_exact_bcast) { continue; } - if (all_producer_ca_deps.has(other_exact_bcast)) { + if (all_producer_ca_deps.has(other_exact_bcast->as())) { // TODO-NM: Why is this here? Can be removed? NVF_ERROR( false, @@ -895,7 +895,8 @@ StatefulLoweringInfo buildInfo( p_id->name(), " of ", producer->toString()); - info.p2c_root_broadcast_resolution_map[other_exact_bcast] + info.p2c_root_broadcast_resolution_map[other_exact_bcast + ->as()] .pushBack(c_id); } } @@ -912,17 +913,20 @@ StatefulLoweringInfo buildInfo( if (entry.second.empty()) { continue; } - if (all_producer_ca_deps.has(entry.first)) { - info.p2c_ca_permissive_maps[entry.first].pushBack(entry.second); + if (all_producer_ca_deps.has(entry.first->as())) { + info.p2c_ca_permissive_maps[entry.first->as()].pushBack( + entry.second); } - info.p2c_permissive_maps[entry.first].pushBack(entry.second); + info.p2c_permissive_maps[entry.first->as()].pushBack( + entry.second); } for (const auto& entry : p2c_permissive_map) { if (entry.second.empty()) { continue; } - info.p2c_permissive_maps[entry.first].pushBack(entry.second); + info.p2c_permissive_maps[entry.first->as()].pushBack( + entry.second); } } } @@ -939,7 +943,7 @@ void IdModel::build( // found, then querying an empty permissive map will fail later. // Initialize disjoint sets for (auto mode : kIdMappingModes) { - id_graphs_[mode] = IdGraph(); + id_graphs_[mode] = ValGraph(); } std::vector tv_exprs; @@ -1032,22 +1036,22 @@ void IdModel::build( VectorOfUniqueEntries IdModel::computeTerminalLoopIds( const StatefulLoweringInfo info) { VectorOfUniqueEntries terminal_loop_ids; - for (const IdGroup& group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + for (const ValGroup& group : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { if (group->size() == 1) { - terminal_loop_ids.pushBack(group->front()); + terminal_loop_ids.pushBack(group->front()->as()); } // Don't select producer iter domains for (auto loop_id : *group) { - if (info.p2c_ca_permissive_maps.find(loop_id) != + if (info.p2c_ca_permissive_maps.find(loop_id->as()) != info.p2c_ca_permissive_maps.end()) { continue; } - auto uses_it = id_uses_.find(loop_id); + auto uses_it = id_uses_.find(loop_id->as()); if (uses_it == id_uses_.end()) { - terminal_loop_ids.pushBack(loop_id); + terminal_loop_ids.pushBack(loop_id->as()); continue; } @@ -1064,19 +1068,19 @@ VectorOfUniqueEntries IdModel::computeTerminalLoopIds( } if (!all_outs_in_loop_group) { - terminal_loop_ids.pushBack(loop_id); + terminal_loop_ids.pushBack(loop_id->as()); } } } return terminal_loop_ids; } -IdGraph IdModel::buildIntersection( - const IdGraph& graph0, - const IdGraph& graph1, +ValGraph IdModel::buildIntersection( + const ValGraph& graph0, + const ValGraph& graph1, bool propagate_exprs) { auto intersection = initializeIdGraph(propagate_exprs); - for (const auto& group0 : graph0.disjointIdSets().disjointSets()) { + for (const auto& group0 : graph0.disjointValSets().disjointSets()) { auto set_size = group0->size(); for (auto id0_i : c10::irange(set_size)) { auto id0 = group0->vector()[id0_i]; @@ -1084,7 +1088,7 @@ IdGraph IdModel::buildIntersection( auto id1 = group0->vector()[id1_i]; // id0 and id1 map in group0. If they also map in the group1, // add the mapping to the inersection. - if (graph1.disjointIdSets().strictAreMapped(id0, id1)) { + if (graph1.disjointValSets().strictAreMapped(id0, id1)) { intersection.mapIds(id0, id1); } } @@ -1103,15 +1107,15 @@ void IdModel::initializeLoopMap(StatefulLoweringInfo& info) { for (IterDomain* p_id : info.ordered_p_ca_ids) { auto entry_it = info.p2c_ca_permissive_maps.find(p_id); if (entry_it != info.p2c_ca_permissive_maps.end()) { - const VectorOfUniqueEntries& c_ids = entry_it->second; - for (IterDomain* c_id : c_ids) { + const VectorOfUniqueEntries& c_ids = entry_it->second; + for (Val* c_id : c_ids) { idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); } } } } -std::unordered_map IdModel::buildInlinePromotions( +std::unordered_map IdModel::buildInlinePromotions( StatefulLoweringInfo& info) { // Make an intersection of the exact and loop map. This will group together // entries in each loop group that are exact with each other. This provides a @@ -1137,7 +1141,7 @@ std::unordered_map IdModel::buildInlinePromotions( // smaller groups and this algorithm scales with the number of groups * // (number of entries in groups ^ 2) - IdGraph intersection_exact_loop_graph = buildIntersection( + ValGraph intersection_exact_loop_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); // Promotion logic is going to be on the intersection of the exact and loop @@ -1146,7 +1150,7 @@ std::unordered_map IdModel::buildInlinePromotions( // the map. // // iel stands for Intersection of the Exact and Loop graphs. - std::unordered_map iel_promotion_map; + std::unordered_map iel_promotion_map; // This should probably work just on terminating inputs, as we shouldn't be // able to modify a broadcast domain between root and rfactor which would be @@ -1173,19 +1177,20 @@ std::unordered_map IdModel::buildInlinePromotions( // Note again this process is only done for root domains. Once we // find promotion relationships for root domains, we propagate the // mappings to derived domains - for (const IdGroup& iel_group : - intersection_exact_loop_graph.disjointIdSets().disjointSets()) { + for (const ValGroup& iel_group : + intersection_exact_loop_graph.disjointValSets().disjointSets()) { NVF_ERROR(!iel_group->empty()); - if (!iel_group->front()->isBroadcast()) { + if (!iel_group->front()->as()->isBroadcast()) { continue; } // Collect all the exact groups of the resolutions of the broadcast id's - IdGroups resolved_exact_groups; - for (IterDomain* bcast_id : *iel_group) { + ValGroups resolved_exact_groups; + for (Val* bcast_id : *iel_group) { if (auto p2c_root_broadcast_resolution_map_it = - info.p2c_root_broadcast_resolution_map.find(bcast_id); + info.p2c_root_broadcast_resolution_map.find( + bcast_id->as()); p2c_root_broadcast_resolution_map_it != info.p2c_root_broadcast_resolution_map.end()) { resolved_exact_groups.pushBack( @@ -1202,7 +1207,7 @@ std::unordered_map IdModel::buildInlinePromotions( // The intersection of the exact groups that the broadcast domains can be // broadcasted to, and those that exist within the same loop groop are is // the promotion needed for this iel_group. - IdGroups loop_exact_resolved_intersection = + ValGroups loop_exact_resolved_intersection = resolved_exact_groups.intersect(loop_covered_exact_groups); if (loop_exact_resolved_intersection.empty()) { @@ -1217,16 +1222,16 @@ std::unordered_map IdModel::buildInlinePromotions( << "Invalid multiple broadcast resolution within shared loops detected, group:\n " << iel_group->toString() << "\nIs being broadcasted to:"; - for (const IdGroup& entry : loop_exact_resolved_intersection) { + for (const ValGroup& entry : loop_exact_resolved_intersection) { err_msg << "\n " << entry->toString(); } NVF_ERROR(false, err_msg.str()); } // loop_exact_resolved_intersection.size() must be 1 at this point - IdGroup exact_resolution_group = loop_exact_resolved_intersection.front(); + ValGroup exact_resolution_group = loop_exact_resolved_intersection.front(); - VectorOfUniqueEntries resolved_ids = + VectorOfUniqueEntries resolved_ids = exact_resolution_group->intersect(*loop_group); auto promoted_iel_groups = intersection_exact_loop_graph.toGroups(resolved_ids); @@ -1242,13 +1247,14 @@ std::unordered_map IdModel::buildInlinePromotions( << "Invalid multiple broadcast resolution within shared loops detected, group:\n " << iel_group->toString() << "\nIs being broadcasted to:"; - for (const IdGroup& entry : promoted_iel_groups) { + for (const ValGroup& entry : promoted_iel_groups) { err_msg << "\n " << entry->toString(); } NVF_ERROR(false, err_msg.str()); } - iel_promotion_map[iel_group] = promoted_iel_groups.front()->front(); + iel_promotion_map[iel_group] = + promoted_iel_groups.front()->front()->as(); } // Propagate promotion mappings from root domains to derived domains @@ -1265,7 +1271,7 @@ std::unordered_map IdModel::buildInlinePromotions( for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) { NVF_ERROR(!iel_expr->empty()); - std::vector input_groups = + std::vector input_groups = intersection_exact_loop_graph.inputGroups(iel_expr); // Check if any inputs need promotion indicating this expr group needs to @@ -1273,10 +1279,10 @@ std::unordered_map IdModel::buildInlinePromotions( std::vector promoted_inputs; bool an_input_was_promoted = false; - for (const IdGroup& inp : input_groups) { + for (const ValGroup& inp : input_groups) { auto inp_promo_it = iel_promotion_map.find(inp); if (inp_promo_it == iel_promotion_map.end()) { - promoted_inputs.push_back(inp->front()); + promoted_inputs.push_back(inp->front()->as()); } else { promoted_inputs.push_back(inp_promo_it->second); an_input_was_promoted = true; @@ -1288,7 +1294,7 @@ std::unordered_map IdModel::buildInlinePromotions( continue; } - IdGroups promoted_input_groups; + ValGroups promoted_input_groups; for (auto inp_id : promoted_inputs) { if (intersection_exact_loop_graph.hasGroup(inp_id)) { promoted_input_groups.pushBack( @@ -1318,7 +1324,7 @@ std::unordered_map IdModel::buildInlinePromotions( // seems perfectly fine to reuse the merge of iS17 and iS45. ExprGroups non_promoted_input_uses; - for (const IdGroup& iel_group : + for (const ValGroup& iel_group : promoted_input_groups.intersect(input_groups)) { non_promoted_input_uses.pushBack( intersection_exact_loop_graph.getUniqueUses(iel_group)); @@ -1334,7 +1340,7 @@ std::unordered_map IdModel::buildInlinePromotions( if (iel_expr == iel_use_group) { continue; } - if (IdGraph::transformAtributesMatch( + if (ValGraph::transformAtributesMatch( iel_expr->front(), iel_use_group->front())) { auto use_inps = ir_utils::filterByType(iel_use_group->front()->inputs()) @@ -1342,7 +1348,7 @@ std::unordered_map IdModel::buildInlinePromotions( bool inps_match = true; for (auto inp_i : c10::irange(use_inps.size())) { inps_match = inps_match && - intersection_exact_loop_graph.disjointIdSets().strictAreMapped( + intersection_exact_loop_graph.disjointValSets().strictAreMapped( use_inps[inp_i], promoted_inputs[inp_i]); } if (inps_match) { @@ -1357,7 +1363,7 @@ std::unordered_map IdModel::buildInlinePromotions( replay = addReplayAs(promoted_inputs, iel_expr->front()); } - std::vector out_groups = + std::vector out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); // Mark outputs as having a promoted iter domain @@ -1391,13 +1397,13 @@ std::unordered_map IdModel::buildInlinePromotions( namespace { -std::unordered_map updateMap( - const std::unordered_map& stale_map, - IdGraph& new_graph) { - std::unordered_map new_map; +std::unordered_map updateMap( + const std::unordered_map& stale_map, + ValGraph& new_graph) { + std::unordered_map new_map; for (const auto& [stale_key, mapped_id] : stale_map) { - const IdGroups& new_groups = new_graph.toGroups(*stale_key); + const ValGroups& new_groups = new_graph.toGroups(*stale_key); NVF_ERROR( new_groups.size() == 1, "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", @@ -1411,33 +1417,35 @@ std::unordered_map updateMap( return new_map; } -// Returns for each IdGroup in provided IdGraph what the input IdGroups are -// traversing on definitions. Ignoring broadcast IdGroups and resetting inputs -// at RFactor IdGroups. -std::unordered_map computeCoveredGroups( - const IdGraph& exact_graph, +// Returns for each ValGroup in provided IdGraph what the input ValGroups are +// traversing on definitions. Ignoring broadcast ValGroups and resetting inputs +// at RFactor ValGroups. +std::unordered_map computeCoveredGroups( + const ValGraph& exact_graph, const std::unordered_set& view_rfactor_ids) { // Map from an exact iter domain group, to all the exact iter domain groups it // covers - std::unordered_map covered_ids; + std::unordered_map covered_ids; - for (const IdGroup& id_group : exact_graph.disjointIdSets().disjointSets()) { + for (const ValGroup& id_group : + exact_graph.disjointValSets().disjointSets()) { // Initialize inputs if (exact_graph.getUniqueDefinitions(id_group).empty()) { covered_ids[id_group] = {id_group}; } // Initialize rfactor groups - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return view_rfactor_ids.find(id) != view_rfactor_ids.end(); + if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) { + return view_rfactor_ids.find(id->as()) != + view_rfactor_ids.end(); })) { covered_ids[id_group] = {id_group}; } // Initialize broadcast groups to empty since broadcast domains // don't matter for indexing - if (std::any_of(id_group->begin(), id_group->end(), [&](IterDomain* id) { - return id->isBroadcast(); + if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) { + return id->as()->isBroadcast(); })) { covered_ids[id_group] = {}; } @@ -1446,14 +1454,14 @@ std::unordered_map computeCoveredGroups( IdGraphStmtSort exact_stmt_sort(exact_graph); for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) { - std::vector input_groups = exact_graph.inputGroups(exact_expr); + std::vector input_groups = exact_graph.inputGroups(exact_expr); - IdGroups covered; - for (const IdGroup& inp_group : input_groups) { + ValGroups covered; + for (const ValGroup& inp_group : input_groups) { covered.pushBack(covered_ids.at(inp_group)); } - for (const IdGroup& output_group : exact_graph.outputGroups(exact_expr)) { + for (const ValGroup& output_group : exact_graph.outputGroups(exact_expr)) { // Don't overwrite initialized cases due to rfactor markings. if (covered_ids.find(output_group) == covered_ids.end()) { covered_ids[output_group] = covered; @@ -1465,10 +1473,10 @@ std::unordered_map computeCoveredGroups( } }; // namespace -std::unordered_map IdModel::buildLoopPromotionMap( +std::unordered_map IdModel::buildLoopPromotionMap( const std::vector& exprs, StatefulLoweringInfo& info, - const std::unordered_map& stale_promotion_map) { + const std::unordered_map& stale_promotion_map) { // Non-ca domains may also need to be promoted if parent domains are // promoted. @@ -1483,7 +1491,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // Map from an exact iter domain group, to all the exact iter domain groups it // covers; needs to be recomputed. - std::unordered_map exact_covered_ids = + std::unordered_map exact_covered_ids = computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); // Grab terminal iter domain in the loop groups. @@ -1500,36 +1508,39 @@ std::unordered_map IdModel::buildLoopPromotionMap( // have to be in the loop group) that covers all the exact groups // representative of the resolved transformations within the loop group. Only // the inlined loop groups will be covered here. - std::unordered_map loop_graph_copy_promotion_map; + std::unordered_map loop_graph_copy_promotion_map; // TODO: I'm uncertain if we can simply use the iel_promotion_map. Once this // system is in use we should test not recomputing the "concrete ids". - for (const IdGroup& loop_group : - loop_graph_copy.disjointIdSets().disjointSets()) { + for (const ValGroup& loop_group : + loop_graph_copy.disjointValSets().disjointSets()) { if (loop_group->size() == 1) { - loop_graph_copy_promotion_map[loop_group] = loop_group->front(); + loop_graph_copy_promotion_map[loop_group] = + loop_group->front()->as(); continue; } // Grab all the (potentially promoted) terminal iter domains in this group. // Save the exact group and the iter domain in this vector. - std::vector> exact_promoted_terminal_ids; + std::vector> exact_promoted_terminal_ids; for (auto loop_id : *loop_group) { // If not a terminal id in the group skip - if (!terminal_loop_ids.has(loop_id)) { + if (!terminal_loop_ids.has(loop_id->as())) { continue; } // Grab the iel entry - const IdGroup& iel_group = intersection_exact_loop_graph.toGroup(loop_id); + const ValGroup& iel_group = + intersection_exact_loop_graph.toGroup(loop_id); auto iel_promo_it = iel_promotion_map.find(iel_group); if (iel_promo_it == iel_promotion_map.end()) { // If this terminal ID doesn't have a promotion associated with it, save // the terminal ID. exact_promoted_terminal_ids.emplace_back( - idGraph(IdMappingMode::EXACT).toGroup(loop_id), loop_id); + idGraph(IdMappingMode::EXACT).toGroup(loop_id), + loop_id->as()); } else { // If this terminal ID has a promotion, grab the promoted ID. exact_promoted_terminal_ids.emplace_back( @@ -1539,11 +1550,12 @@ std::unordered_map IdModel::buildLoopPromotionMap( } // All the exact groups of the iter domains in the loop group - IdGroups exact_groups = idGraph(IdMappingMode::EXACT).toGroups(*loop_group); + ValGroups exact_groups = + idGraph(IdMappingMode::EXACT).toGroups(*loop_group); // All exact groups covered by all iter domains in this loop group - IdGroups loop_group_covered_ids; - for (const IdGroup& exact_group : exact_groups) { + ValGroups loop_group_covered_ids; + for (const ValGroup& exact_group : exact_groups) { auto covered_it = exact_covered_ids.find(exact_group); NVF_ERROR(covered_it != exact_covered_ids.end()); loop_group_covered_ids.pushBack(covered_it->second); @@ -1555,7 +1567,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // exact groups of loop_group_covered_ids. If so, that's the correct // promoted iter domain of this group. for (const auto& entry : exact_promoted_terminal_ids) { - const IdGroup& terminal_id_group = entry.first; + const ValGroup& terminal_id_group = entry.first; IterDomain* terminal_id = entry.second; auto covered_it = exact_covered_ids.find(terminal_id_group); NVF_ERROR(covered_it != exact_covered_ids.end()); @@ -1572,15 +1584,15 @@ std::unordered_map IdModel::buildLoopPromotionMap( err_msg << nvfuser::toString(loop_group, 0, true); err_msg << "\nnone of the terminal iter domains of this group:\n "; for (const auto& entry : exact_promoted_terminal_ids) { - const IdGroup& terminal_id_group = entry.first; - const IdGroups& covered_id_groups = + const ValGroup& terminal_id_group = entry.first; + const ValGroups& covered_id_groups = exact_covered_ids.at(terminal_id_group); err_msg << " " << nvfuser::toString(terminal_id_group, 0, true) << " -(covers)-> " << nvfuser::toString(covered_id_groups) << std::endl; } err_msg << "iter domains in this group cover all id groups:\n"; - for (const IdGroup& covered_group : loop_group_covered_ids) { + for (const ValGroup& covered_group : loop_group_covered_ids) { err_msg << " " << nvfuser::toString(covered_group, 0, true); } // NVF_ERROR(false, err_msg.str()); @@ -1607,10 +1619,10 @@ std::unordered_map IdModel::buildLoopPromotionMap( IdGraphStmtSort(intersection_exact_loop_graph).exprs()) { NVF_ERROR(!iel_expr->empty()); - std::vector iel_inp_groups = + std::vector iel_inp_groups = intersection_exact_loop_graph.inputGroups(iel_expr); - std::vector iel_out_groups = + std::vector iel_out_groups = intersection_exact_loop_graph.outputGroups(iel_expr); // When replaying the transformations we can't blindly apply loop promotion @@ -1636,13 +1648,13 @@ std::unordered_map IdModel::buildLoopPromotionMap( // So if we have an iel_expr make sure it's inputs and outputs are not in // the same loop group. - IdGroups inp_loop_groups; - for (const IdGroup& iel_inp_group : iel_inp_groups) { + ValGroups inp_loop_groups; + for (const ValGroup& iel_inp_group : iel_inp_groups) { inp_loop_groups.pushBack(loop_graph_copy.toGroup(iel_inp_group->front())); } - IdGroups out_loop_groups; - for (const IdGroup& iel_out_group : iel_out_groups) { + ValGroups out_loop_groups; + for (const ValGroup& iel_out_group : iel_out_groups) { out_loop_groups.pushBack(loop_graph_copy.toGroup(iel_out_group->front())); } @@ -1655,14 +1667,14 @@ std::unordered_map IdModel::buildLoopPromotionMap( bool an_input_was_promoted = false; // Promote inputs for replay - for (const IdGroup& iel_inp_group : iel_inp_groups) { + for (const ValGroup& iel_inp_group : iel_inp_groups) { // Promote loops based on the loop promotion map. If the loop promotion // map should be used and has an entry we should use that promotion. This // happen when an iel expression is across a loop group boundary. // Signifying and capturing instances when we traverse across an inlined // loop group to a non-inlined loop group boundary (think of the iel graph // projected onto the loop graph). - const IdGroup& loop_copy_group = + const ValGroup& loop_copy_group = loop_graph_copy.toGroup(iel_inp_group->front()); auto inp_loop_promo_it = loop_graph_copy_promotion_map.find(loop_copy_group); @@ -1677,7 +1689,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // of the iel graph. auto inp_promo_it = iel_promotion_map.find(iel_inp_group); if (inp_promo_it == iel_promotion_map.end()) { - promoted_inputs.push_back(iel_inp_group->front()); + promoted_inputs.push_back(iel_inp_group->front()->as()); } else { promoted_inputs.push_back(inp_promo_it->second); an_input_was_promoted = true; @@ -1702,7 +1714,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // would match. Grab all uses of the promoted inputs' groups in the IEL // map. Note that promotion should be to loop-mapped domains, so // the IEL graph is used rather than the exact graph - std::vector promoted_input_groups; + std::vector promoted_input_groups; ExprGroups promoted_input_uses; for (auto inp_id : promoted_inputs) { @@ -1723,7 +1735,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( for (const ExprGroup& iel_use_group : promoted_input_uses) { NVF_ERROR(!iel_use_group->empty()); // Check if all the attributes (including type) of the transform match - if (!IdGraph::transformAtributesMatch( + if (!ValGraph::transformAtributesMatch( iel_expr->front(), iel_use_group->front())) { continue; } @@ -1779,7 +1791,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( for (auto i : c10::irange(replay_out_ids.size())) { if (!idGraph(IdMappingMode::EXACT) - .disjointIdSets() + .disjointValSets() .strictAreMapped(replay_out_ids[i], output_groups[i]->front())) { // Promote if necessary, if the output is already in the same exact map // it doesn't need a promotion. @@ -1818,12 +1830,12 @@ std::unordered_map IdModel::buildLoopPromotionMap( // that may also contain the promotion the loop should be associated // with. Once all candidates are obtained, we pick one that covers // all the exact domains (cf. concrete domains in ComputeAtMap) - for (const IdGroup& loop_group : - loop_graph_copy.disjointIdSets().disjointSets()) { - IdGroups iel_groups = intersection_exact_loop_graph.toGroups(*loop_group); + for (const ValGroup& loop_group : + loop_graph_copy.disjointValSets().disjointSets()) { + ValGroups iel_groups = intersection_exact_loop_graph.toGroups(*loop_group); // All exact groups covered by all iter domains in this loop group - IdGroups loop_group_covered_ids; - for (const IdGroup& iel_group : iel_groups) { + ValGroups loop_group_covered_ids; + for (const ValGroup& iel_group : iel_groups) { auto exact_group = idGraph(IdMappingMode::EXACT).toGroup(iel_group->front()); auto covered_it = exact_covered_ids.find(exact_group); @@ -1836,13 +1848,14 @@ std::unordered_map IdModel::buildLoopPromotionMap( VectorOfUniqueEntries representative_id_candidates; - for (const IdGroup& iel_group : iel_groups) { + for (const ValGroup& iel_group : iel_groups) { if (auto iel_promotion_map_it = iel_promotion_map.find(iel_group); iel_promotion_map_it != iel_promotion_map.end()) { IterDomain* iel_promotion_id = iel_promotion_map_it->second; representative_id_candidates.pushBack(iel_promotion_id); } else { - representative_id_candidates.pushBack(iel_group->front()); + representative_id_candidates.pushBack( + iel_group->front()->as()); } } @@ -1868,7 +1881,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // Found VERBOSE() << "Representative found: " << candidate_id->toString() << std::endl; - const IdGroup& current_loop_group = + const ValGroup& current_loop_group = idGraph(IdMappingMode::LOOP).toGroup(loop_group->front()); loop_promotion_map_.emplace(current_loop_group, candidate_id); break; @@ -1877,8 +1890,8 @@ std::unordered_map IdModel::buildLoopPromotionMap( } // Sanity check of the loop promotion map - for (const IdGroup& loop_group : - idGraph(IdMappingMode::LOOP).disjointIdSets().disjointSets()) { + for (const ValGroup& loop_group : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { // Non-leaf loop groups are not guaranteed to have valid // promotions. See for example FusionRepro1713, where root domains // are all grouped together but there's no valid promotion. @@ -1913,7 +1926,7 @@ std::unordered_map IdModel::buildIndexGraph( const std::vector& exprs, const std::vector& all_tvs, StatefulLoweringInfo& info, - std::unordered_map stale_promotion_map) { + std::unordered_map stale_promotion_map) { NVF_ERROR(false, "Not implemented yet."); } diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index bb00160162c..1315c4fe125 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -9,8 +9,8 @@ #include #include -#include #include +#include #include #include @@ -19,7 +19,7 @@ namespace nvfuser { -class IdGraph; +class ValGraph; namespace { // Convenience to store some intermediate data across a few lowering build @@ -97,8 +97,8 @@ class IdModel : public PolymorphicBase { IdModel(Fusion* fusion, bool allow_self_mapping = false); // Returns iter domain graph of provided mode. - const IdGraph& idGraph(IdMappingMode mode) const; - IdGraph& idGraph(IdMappingMode mode); + const ValGraph& idGraph(IdMappingMode mode) const; + ValGraph& idGraph(IdMappingMode mode); // IterDomains from the original fusion are only allowed to be used once in // the IterDomain graph, id->uses() are not directly used as there's no bounds @@ -165,7 +165,7 @@ class IdModel : public PolymorphicBase { // not have any registered uses or definitions. IterDomain* cloneIterDomain(IterDomain* id); - const std::unordered_map loopPromotionMap() const { + const std::unordered_map loopPromotionMap() const { return loop_promotion_map_; } @@ -187,7 +187,7 @@ class IdModel : public PolymorphicBase { // Iterates over all IterDomains in id_definitions_ and calls initializeID on // a new IdGraph and returns it. - IdGraph initializeIdGraph(bool propagate_through_exprs = true); + ValGraph initializeIdGraph(bool propagate_through_exprs = true); // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs // and first output of expr @@ -222,9 +222,9 @@ class IdModel : public PolymorphicBase { // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and // graph1. - IdGraph buildIntersection( - const IdGraph& graph0, - const IdGraph& graph1, + ValGraph buildIntersection( + const ValGraph& graph0, + const ValGraph& graph1, bool propagate_exprs = true); // !! END Helper functions to build loop promotion and index map!! @@ -232,19 +232,19 @@ class IdModel : public PolymorphicBase { // Start loop map by grouping inlined iter domains void initializeLoopMap(StatefulLoweringInfo& info); - // Returns map of IdGroups in the loop map to a representative IterDomain that - // contains all resolved transformations that the terminal IterDomains should - // be promoted to. The returned promotions are valid only for inlined iter - // domains. - std::unordered_map buildInlinePromotions( + // Returns map of ValGroups in the loop map to a representative IterDomain + // that contains all resolved transformations that the terminal IterDomains + // should be promoted to. The returned promotions are valid only for inlined + // iter domains. + std::unordered_map buildInlinePromotions( StatefulLoweringInfo& info); // Returns a similar thing to buildInlinePromotions but also includes iter // domains that are not inlined. - std::unordered_map buildLoopPromotionMap( + std::unordered_map buildLoopPromotionMap( const std::vector& exprs, StatefulLoweringInfo& info, - const std::unordered_map& stale_promotion_map); + const std::unordered_map& stale_promotion_map); // Builds idGraph(IdMappingMode::INDEX) and returns the iter domain promotion // map to go from leaf domains of each (consumer only?) tensor to their @@ -253,14 +253,14 @@ class IdModel : public PolymorphicBase { const std::vector& exprs, const std::vector& all_tvs, StatefulLoweringInfo& info, - std::unordered_map stale_promotion_map); + std::unordered_map stale_promotion_map); // Returns the terminal rfactor or input iter domains each group in the almost // exact map covers (in the almost exact map). This effectively returns all // the input almost exact iter domain groups for each almost exact iter domain // group. RFactor axes are considered an "input" as all broadcast dimensions // have to be resolved by or before the rfactor iter domain. - std::unordered_map buildCoveredAlmostExact(); + std::unordered_map buildCoveredAlmostExact(); // ======= END Iteration domain build process in order called ======= @@ -272,7 +272,7 @@ class IdModel : public PolymorphicBase { // Using an array here might be nice, but it seems hard to use an enum as an // array key // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum - std::unordered_map id_graphs_; + std::unordered_map id_graphs_; // If multiple transformations occur IterDomains could have multiple uses, // however only one should be active in the given Fusion. When we resolve loop @@ -291,7 +291,7 @@ class IdModel : public PolymorphicBase { self_mapping_info_ = c10::nullopt; // Promotion domain for each loop group - std::unordered_map loop_promotion_map_; + std::unordered_map loop_promotion_map_; std::unordered_set view_rfactor_ids_; }; diff --git a/csrc/id_model/to_string.cpp b/csrc/id_model/to_string.cpp index 885e4171e57..87057559f6c 100644 --- a/csrc/id_model/to_string.cpp +++ b/csrc/id_model/to_string.cpp @@ -37,6 +37,19 @@ std::string indent(int size = 0) { } } // namespace +std::string toString(const std::vector& id_group, int indent_size) { + std::vector names; + names.reserve(id_group.size()); + for (auto id : id_group) { + names.push_back(id->name()); + } + std::sort(names.begin(), names.end()); + + std::stringstream ss; + ss << indent(indent_size) << "{" << names << "}"; + return ss.str(); +} + std::string toString( const std::vector& id_group, int indent_size) { @@ -52,7 +65,7 @@ std::string toString( return ss.str(); } -std::string toString(const IdGroup& id_group, int indent_size, bool with_ptr) { +std::string toString(const ValGroup& id_group, int indent_size, bool with_ptr) { std::stringstream ss; ss << indent(indent_size) << "idg" << (with_ptr ? "(" : "") << toString(id_group.get(), with_ptr) << (with_ptr ? ")" : "") @@ -61,7 +74,7 @@ std::string toString(const IdGroup& id_group, int indent_size, bool with_ptr) { } std::string toString( - const std::vector& id_groups, + const std::vector& id_groups, int indent_size, bool with_ptr) { std::stringstream ss; @@ -71,7 +84,7 @@ std::string toString( unsigned int pos = 0; - for (const IdGroup& id_group : id_groups) { + for (const ValGroup& id_group : id_groups) { unsigned int min_id_name = std::numeric_limits::max(); for (auto id : *id_group) { if (id->name() < min_id_name) { @@ -96,7 +109,7 @@ std::string toString( } std::string toString( - const IdGroups& id_groups, + const ValGroups& id_groups, int indent_size, bool with_ptr) { std::stringstream ss; @@ -106,7 +119,7 @@ std::string toString( unsigned int pos = 0; - for (const IdGroup& id_group : id_groups) { + for (const ValGroup& id_group : id_groups) { unsigned int min_id_name = std::numeric_limits::max(); for (auto id : *id_group) { if (id->name() < min_id_name) { @@ -130,13 +143,13 @@ std::string toString( return ss.str(); } -std::string toInlineString(const std::vector& id_groups) { +std::string toInlineString(const std::vector& id_groups) { // Track position in id_groups and its min iter domain name in the set std::vector> group_name_info; unsigned int pos = 0; - for (const IdGroup& id_group : id_groups) { + for (const ValGroup& id_group : id_groups) { unsigned int min_id_name = std::numeric_limits::max(); for (auto id : *id_group) { if (id->name() < min_id_name) { @@ -192,7 +205,7 @@ std::string toString( } std::string toString( - const IdGraph& id_graph, + const ValGraph& id_graph, const std::vector& expr_groups, int indent_size, bool with_ptr) { @@ -222,8 +235,8 @@ std::string toString( auto pos = group_name_info[i].second; const ExprGroup& expr_group = expr_groups[pos]; - auto inputs = IdGroups(id_graph.inputGroups(expr_group)); - auto outputs = IdGroups(id_graph.outputGroups(expr_group)); + auto inputs = ValGroups(id_graph.inputGroups(expr_group)); + auto outputs = ValGroups(id_graph.outputGroups(expr_group)); ss << indent(indent_size + 1) << toInlineString(inputs.vector()) << " --" << toString(expr_group, 0, with_ptr) << "--> " @@ -235,7 +248,7 @@ std::string toString( } std::string toString( - const IdGraph& id_graph, + const ValGraph& id_graph, const ExprGroups& expr_groups, int indent_size, bool with_ptr) { @@ -265,8 +278,8 @@ std::string toString( auto pos = group_name_info[i].second; auto expr_group = expr_groups.vector()[pos]; - auto inputs = IdGroups(id_graph.inputGroups(expr_group)); - auto outputs = IdGroups(id_graph.outputGroups(expr_group)); + auto inputs = ValGroups(id_graph.inputGroups(expr_group)); + auto outputs = ValGroups(id_graph.outputGroups(expr_group)); ss << indent(indent_size + 1) << toInlineString(inputs.vector()) << " --" << toString(expr_group, 0, with_ptr) << "--> " @@ -278,16 +291,16 @@ std::string toString( } std::string idGroupsString( - const IdGraph& id_graph, + const ValGraph& id_graph, int indent_size, bool with_ptr) { - IdGroups id_groups( - id_graph.disjointIdSets().disjointSets().begin(), - id_graph.disjointIdSets().disjointSets().end()); + ValGroups id_groups( + id_graph.disjointValSets().disjointSets().begin(), + id_graph.disjointValSets().disjointSets().end()); return toString(id_groups, indent_size, with_ptr); } std::string exprGroupsString( - const IdGraph& id_graph, + const ValGraph& id_graph, int indent_size, bool with_ptr) { ExprGroups expr_groups( @@ -297,11 +310,11 @@ std::string exprGroupsString( } std::string definitionsString( - const IdGraph& id_graph, + const ValGraph& id_graph, int indent_size, bool with_ptr) { ExprGroups defs; - for (const IdGroup& id_group : id_graph.disjointIdSets().disjointSets()) { + for (const ValGroup& id_group : id_graph.disjointValSets().disjointSets()) { auto definition_pair = id_graph.getDefinitions(id_group); if (definition_pair.second) { for (const ExprGroup& expr_group : definition_pair.first) { @@ -313,11 +326,11 @@ std::string definitionsString( } std::string usesString( - const IdGraph& id_graph, + const ValGraph& id_graph, int indent_size, bool with_ptr) { ExprGroups uses; - for (const IdGroup& id_group : id_graph.disjointIdSets().disjointSets()) { + for (const ValGroup& id_group : id_graph.disjointValSets().disjointSets()) { auto definition_pair = id_graph.getUses(id_group); if (definition_pair.second) { for (const ExprGroup& expr_group : definition_pair.first) { diff --git a/csrc/id_model/to_string.h b/csrc/id_model/to_string.h index f58cf00a2a2..d64d92a8fed 100644 --- a/csrc/id_model/to_string.h +++ b/csrc/id_model/to_string.h @@ -7,34 +7,36 @@ // clang-format on #pragma once -#include #include +#include #include #include namespace nvfuser { +std::string toString(const std::vector& id_group, int indent_size = 0); + std::string toString( const std::vector& id_group, int indent_size = 0); std::string toString( - const IdGroup& id_group, + const ValGroup& id_group, int indent_size = 0, bool with_ptr = false); std::string toString( - const std::vector& id_groups, + const std::vector& id_groups, int indent_size = 0, bool with_ptr = false); std::string toString( - const IdGroups& id_groups, + const ValGroups& id_groups, int indent_size = 0, bool with_ptr = false); -std::string toInlineString(const std::vector& id_groups); -std::string toInlineString(const IdGroups& id_groups); +std::string toInlineString(const std::vector& id_groups); +std::string toInlineString(const ValGroups& id_groups); std::string toString(const std::vector& expr_group, int indent_size = 0); std::string toString( @@ -43,41 +45,41 @@ std::string toString( bool with_ptr = false); std::string toString( - const IdGraph& id_graph, + const ValGraph& id_graph, const std::vector& expr_group, int indent_size = 0, bool with_ptr = false); std::string toString( - const IdGraph& id_graph, + const ValGraph& id_graph, const ExprGroup& expr_groups, int indent_size = 0, bool with_ptr = false); std::string toString( - const IdGraph& id_graph, + const ValGraph& id_graph, const std::vector& expr_groups, int indent_size = 0, bool with_ptr = false); std::string toString( - const IdGraph& id_graph, + const ValGraph& id_graph, const ExprGroups& expr_groups, int indent_size = 0, bool with_ptr = false); std::string idGroupsString( - const IdGraph& id_graph, + const ValGraph& id_graph, int indent_size = 0, bool with_ptr = false); std::string exprGroupsString( - const IdGraph& id_graph, + const ValGraph& id_graph, int indent_size = 0, bool with_ptr = false); std::string definitionsString( - const IdGraph& id_graph, + const ValGraph& id_graph, int indent_size = 0, bool with_ptr = false); std::string usesString( - const IdGraph& id_graph, + const ValGraph& id_graph, int indent_size = 0, bool with_ptr = false); diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index 0f33135ad5f..24a210ed810 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -10,15 +10,15 @@ namespace nvfuser { void IdGraphVisitor::traverse() { - IdGroups all_ids; + ValGroups all_ids; ExprGroups all_exprs; { // Initialize IDs to traverse. If sub_selection is provided, only // traverse IDs that are included in the set are traversed. if (sub_selection_.empty()) { - all_ids = IdGroups( - graph().disjointIdSets().disjointSets().begin(), - graph().disjointIdSets().disjointSets().end()); + all_ids = ValGroups( + graph().disjointValSets().disjointSets().begin(), + graph().disjointValSets().disjointSets().end()); } else { for (auto id : sub_selection_) { if (graph().hasGroup(id)) { @@ -36,13 +36,13 @@ void IdGraphVisitor::traverse() { graph().disjointExprSets().disjointSets().begin(), graph().disjointExprSets().disjointSets().end()); } else { - for (const IdGroup& id_group : all_ids) { + for (const ValGroup& id_group : all_ids) { for (const ExprGroup& def : graph().getUniqueDefinitions(id_group)) { if (all_exprs.has(def)) { continue; } - auto inp_groups = IdGroups(graph().inputGroups(def)); - auto out_groups = IdGroups(graph().outputGroups(def)); + auto inp_groups = ValGroups(graph().inputGroups(def)); + auto out_groups = ValGroups(graph().outputGroups(def)); if (inp_groups.subtract(all_ids).empty() && out_groups.subtract(all_ids).empty()) { all_exprs.pushBack(def); @@ -53,12 +53,12 @@ void IdGraphVisitor::traverse() { } // There could be IterDomains in from or to that are between other from and // to nodes. Make sure to clear those out. - IdGroups terminating_inputs; - IdGroups terminating_outputs; + ValGroups terminating_inputs; + ValGroups terminating_outputs; { - IdGroups not_inputs; - IdGroups not_outputs; + ValGroups not_inputs; + ValGroups not_outputs; for (const ExprGroup& expr_group : all_exprs) { if (graph().isTrivialExprGroup(expr_group)) { // Expression is just a loop to its current group, ignore @@ -70,14 +70,14 @@ void IdGraphVisitor::traverse() { } terminating_inputs = - IdGroups(all_ids.begin(), all_ids.end()).subtract(not_inputs); + ValGroups(all_ids.begin(), all_ids.end()).subtract(not_inputs); terminating_outputs = - IdGroups(all_ids.begin(), all_ids.end()).subtract(not_outputs); + ValGroups(all_ids.begin(), all_ids.end()).subtract(not_outputs); } - IdGroups to_visit_ids = terminating_inputs; - IdGroups visited_ids; + ValGroups to_visit_ids = terminating_inputs; + ValGroups visited_ids; ExprGroups to_visit_exprs; ExprGroups visited_exprs; @@ -85,12 +85,12 @@ void IdGraphVisitor::traverse() { auto is_expr_ready = [&](const ExprGroup& expr_group) { auto inp_groups = graph().inputGroups(expr_group); return std::all_of( - inp_groups.begin(), inp_groups.end(), [&](IdGroup id_group) { + inp_groups.begin(), inp_groups.end(), [&](ValGroup id_group) { return visited_ids.has(id_group) || id_group->empty(); }); }; - auto is_id_ready = [&](const IdGroup& id_group) { + auto is_id_ready = [&](const ValGroup& id_group) { auto unique_defs = graph().getUniqueDefinitions(id_group); return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { @@ -129,7 +129,7 @@ void IdGraphVisitor::traverse() { std::swap(to_visit_exprs, still_to_visit_exprs); - IdGroups still_to_visit_ids; + ValGroups still_to_visit_ids; while (!to_visit_ids.empty()) { auto current_id_group = to_visit_ids.popFront(); NVF_ERROR(!current_id_group->empty()); diff --git a/csrc/id_model/visitor.h b/csrc/id_model/visitor.h index 0c5122979f9..2c13b9efae0 100644 --- a/csrc/id_model/visitor.h +++ b/csrc/id_model/visitor.h @@ -8,19 +8,19 @@ #pragma once #include -#include #include +#include namespace nvfuser { // Iterates through an IterDomain Graph in topological order, calling handle on // all Id and all Expr groups in a forward topological order. // -// Warning: Expr groups that have an input and output in the same IdGroup are +// Warning: Expr groups that have an input and output in the same ValGroup are // ignored. // // Warning: This is not a great iterator if there's a desire to minimize paths -// traveled to simply visit all IdGroups in order. See ExprsBetween to see how +// traveled to simply visit all ValGroups in order. See ExprsBetween to see how // we might minimize paths. class IdGraphVisitor { public: @@ -36,7 +36,7 @@ class IdGraphVisitor { // If sub_selection is assumed to be a set of iter domains by which form a // sub-regrion of the IdGraph provided. Only that sub-region will be visited. IdGraphVisitor( - const IdGraph& id_graph, + const ValGraph& id_graph, const VectorOfUniqueEntries sub_selection = {}) : id_graph_(id_graph), sub_selection_(sub_selection) {} @@ -44,17 +44,17 @@ class IdGraphVisitor { IdGraphVisitor(IdGraphVisitor&& other) = default; - virtual void handle(IdGroup id_group) = 0; + virtual void handle(ValGroup id_group) = 0; virtual void handle(ExprGroup expr_group) = 0; void traverse(); - const IdGraph& graph() { + const ValGraph& graph() { return id_graph_; }; private: - const IdGraph& id_graph_; + const ValGraph& id_graph_; const VectorOfUniqueEntries sub_selection_; }; @@ -62,7 +62,7 @@ class IdGraphVisitor { class IdGraphStmtSort : public IdGraphVisitor { public: IdGraphStmtSort( - const IdGraph& id_graph, + const ValGraph& id_graph, const VectorOfUniqueEntries sub_selection = {}) : IdGraphVisitor(id_graph, sub_selection) { IdGraphVisitor::traverse(); @@ -74,7 +74,7 @@ class IdGraphStmtSort : public IdGraphVisitor { return sorted_exprs_; } - IdGroups ids() const { + ValGroups ids() const { return sorted_ids_; } @@ -82,7 +82,7 @@ class IdGraphStmtSort : public IdGraphVisitor { protected: using IdGraphVisitor::handle; - void handle(IdGroup id_group) override { + void handle(ValGroup id_group) override { sorted_ids_.pushBack(id_group); } @@ -91,7 +91,7 @@ class IdGraphStmtSort : public IdGraphVisitor { } ExprGroups sorted_exprs_; - IdGroups sorted_ids_; + ValGroups sorted_ids_; }; } // namespace nvfuser diff --git a/csrc/id_model/id_graph.cpp b/csrc/val_graph.cpp similarity index 80% rename from csrc/id_model/id_graph.cpp rename to csrc/val_graph.cpp index b45cbfdf3d8..06e0ebdfc36 100644 --- a/csrc/id_model/id_graph.cpp +++ b/csrc/val_graph.cpp @@ -5,10 +5,10 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include #include #include #include +#include namespace nvfuser { @@ -17,8 +17,8 @@ using UnorderedSetOfExprGroup = std::unordered_set; using DequeOfExprGroup = std::deque; } // namespace -IdGraph::IdGraph(const IdGraph& other) - : disjoint_ids_(other.disjoint_ids_), +ValGraph::ValGraph(const ValGraph& other) + : disjoint_vals_(other.disjoint_vals_), disjoint_exprs_(other.disjoint_exprs_), unique_definitions_(), unique_uses_() { @@ -46,27 +46,27 @@ IdGraph::IdGraph(const IdGraph& other) } } -IdGraph& IdGraph::operator=(const IdGraph& other) { - disjoint_ids_.clear(); +ValGraph& ValGraph::operator=(const ValGraph& other) { + disjoint_vals_.clear(); disjoint_exprs_.clear(); unique_definitions_.clear(); unique_uses_.clear(); - IdGraph copy(other); + ValGraph copy(other); std::swap(*this, copy); return *this; } // Return if there's a group entry in the graph for this expr -bool IdGraph::hasGroup(Expr* expr) const { +bool ValGraph::hasGroup(Expr* expr) const { return disjoint_exprs_.mappingExists(expr); } // Return if there's a group entry in the graph for this id -bool IdGraph::hasGroup(IterDomain* id) const { - return disjoint_ids_.mappingExists(id); +bool ValGraph::hasGroup(Val* id) const { + return disjoint_vals_.mappingExists(id); } -const ExprGroup& IdGraph::toGroup(Expr* expr) const { +const ExprGroup& ValGraph::toGroup(Expr* expr) const { auto disjoint_set_it = disjoint_exprs_.disjointSetMap().find(expr); NVF_ERROR( disjoint_set_it != disjoint_exprs_.disjointSetMap().end(), @@ -75,17 +75,17 @@ const ExprGroup& IdGraph::toGroup(Expr* expr) const { return disjoint_set_it->second; } -const IdGroup& IdGraph::toGroup(IterDomain* id) const { - auto disjoint_set_it = disjoint_ids_.disjointSetMap().find(id); +const ValGroup& ValGraph::toGroup(Val* id) const { + auto disjoint_set_it = disjoint_vals_.disjointSetMap().find(id); NVF_ERROR( - disjoint_set_it != disjoint_ids_.disjointSetMap().end(), + disjoint_set_it != disjoint_vals_.disjointSetMap().end(), "\nId group could not be found in graph associated with: ", id->toString(), "\n"); return disjoint_set_it->second; } -ExprGroups IdGraph::toGroups(const VectorOfUniqueEntries& exprs) const { +ExprGroups ValGraph::toGroups(const VectorOfUniqueEntries& exprs) const { ExprGroups expr_groups; for (auto expr : exprs) { expr_groups.pushBack(toGroup(expr)); @@ -93,36 +93,33 @@ ExprGroups IdGraph::toGroups(const VectorOfUniqueEntries& exprs) const { return expr_groups; } -IdGroups IdGraph::toGroups( - const VectorOfUniqueEntries& ids) const { - IdGroups id_groups; +ValGroups ValGraph::toGroups(const VectorOfUniqueEntries& ids) const { + ValGroups id_groups; for (auto id : ids) { id_groups.pushBack(toGroup(id)); } return id_groups; } -std::vector IdGraph::outputGroups(const ExprGroup& expr) const { - std::vector output_groups; - for (auto id_output : - ir_utils::filterByType(expr->front()->outputs())) { +std::vector ValGraph::outputGroups(const ExprGroup& expr) const { + std::vector output_groups; + for (auto id_output : ir_utils::filterByType(expr->front()->outputs())) { output_groups.push_back(toGroup(id_output)); } return output_groups; } -std::vector IdGraph::inputGroups(const ExprGroup& expr) const { - std::vector input_groups; - for (auto id_input : - ir_utils::filterByType(expr->front()->inputs())) { +std::vector ValGraph::inputGroups(const ExprGroup& expr) const { + std::vector input_groups; + for (auto id_input : ir_utils::filterByType(expr->front()->inputs())) { input_groups.push_back(toGroup(id_input)); } return input_groups; } -ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { +ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { DequeOfExprGroup to_visit; - for (const IdGroup& of_id_group : of) { + for (const ValGroup& of_id_group : of) { if (const auto& [group_uses, found] = getUses(of_id_group); found) { to_visit.insert(to_visit.end(), group_uses.begin(), group_uses.end()); } @@ -133,7 +130,7 @@ ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { ExprGroup current_expr = to_visit.front(); to_visit.pop_front(); visited.emplace(current_expr); - for (const IdGroup& output_id : outputGroups(current_expr)) { + for (const ValGroup& output_id : outputGroups(current_expr)) { if (const auto& [group_uses, found] = getUses(output_id); found) { for (const ExprGroup& group_use : group_uses) { if (visited.count(group_use)) { @@ -148,9 +145,9 @@ ExprGroups IdGraph::allUsesOf(const IdGroups& of) const { return visited; } -ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { +ExprGroups ValGraph::allDefinitionsOf(const ValGroups& of) const { DequeOfExprGroup to_visit; - for (const IdGroup& of_id_group : of) { + for (const ValGroup& of_id_group : of) { if (const auto& [group_defs, found] = getDefinitions(of_id_group); found) { to_visit.insert(to_visit.end(), group_defs.begin(), group_defs.end()); } @@ -161,7 +158,7 @@ ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { ExprGroup current_expr = to_visit.front(); to_visit.pop_front(); visited.emplace(current_expr); - for (const IdGroup& input_id : inputGroups(current_expr)) { + for (const ValGroup& input_id : inputGroups(current_expr)) { if (const auto& [group_defs, found] = getDefinitions(input_id); found) { for (const ExprGroup& group_def : group_defs) { if (visited.count(group_def)) { @@ -176,7 +173,7 @@ ExprGroups IdGraph::allDefinitionsOf(const IdGroups& of) const { return visited; } -ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) +ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) const { ExprGroups all_uses_of_from = allUsesOf(from); ExprGroups all_definitions_of_to = allDefinitionsOf(to); @@ -187,12 +184,12 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // There could be IterDomains in from or to that are between other from and // to nodes. Make sure to clear those out. - IdGroups terminating_inputs; - IdGroups terminating_outputs; + ValGroups terminating_inputs; + ValGroups terminating_outputs; { - IdGroups not_inputs; - IdGroups not_outputs; - IdGroups all_id_groups; + ValGroups not_inputs; + ValGroups not_outputs; + ValGroups all_id_groups; for (const ExprGroup& expr_group : all_exprs) { if (isTrivialExprGroup(expr_group)) { @@ -200,8 +197,8 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) continue; } - std::vector inp_groups = inputGroups(expr_group); - std::vector out_groups = outputGroups(expr_group); + std::vector inp_groups = inputGroups(expr_group); + std::vector out_groups = outputGroups(expr_group); all_id_groups.pushBack(inp_groups); not_outputs.pushBack(inp_groups); @@ -218,7 +215,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // index is assigned to each leaf of a domain and as we traverse backwards // we're effectively accumulating indexing math. We'll only keep the fewest // expression lists to get to the iter domain. - std::unordered_map required_ind_exprs_ids; + std::unordered_map required_ind_exprs_ids; std::unordered_map required_ind_exprs_exprs; // Return if all output IterDomain groups of an expression group have @@ -228,7 +225,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) return std::all_of( output_groups.begin(), output_groups.end(), - [&](const IdGroup& output_group) { + [&](const ValGroup& output_group) { return required_ind_exprs_ids.find(output_group) != required_ind_exprs_ids.end(); }); @@ -237,7 +234,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // Returns all expression groups in required_ind_exprs_ids of outputs auto requiredExprsOutputs = [&](ExprGroup expr_group) -> ExprGroups { ExprGroups all_output_required_exprs; - for (const IdGroup& output_id_group : outputGroups(expr_group)) { + for (const ValGroup& output_id_group : outputGroups(expr_group)) { auto id_group_exprs_it = required_ind_exprs_ids.find(output_id_group); NVF_ERROR( id_group_exprs_it != required_ind_exprs_ids.end(), @@ -258,7 +255,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) return true; }; - auto processIdGroup = [&](IdGroup id_group) -> bool { + auto processValGroup = [&](ValGroup id_group) -> bool { // Track if we've grabed any of the uses required indexing expressions. bool initialized = false; // Expression group of all indexing expressions required for this iter @@ -300,7 +297,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) }; // Backward traversal from the terminating outputs - IdGroups to_visit_ids = terminating_outputs; + ValGroups to_visit_ids = terminating_outputs; ExprGroups to_visit_exprs; while (!to_visit_ids.empty() || !to_visit_exprs.empty()) { @@ -320,8 +317,9 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) } if (processExprGroup(currently_visiting_exprs)) { something_was_processed = true; - std::vector inp_groups = inputGroups(currently_visiting_exprs); - for (const IdGroup& inp_group : inp_groups) { + std::vector inp_groups = + inputGroups(currently_visiting_exprs); + for (const ValGroup& inp_group : inp_groups) { to_visit_ids.pushBack(inp_group); } } else { @@ -331,7 +329,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) std::swap(to_visit_exprs, still_to_visit_exprs); - IdGroups still_to_visit_ids; + ValGroups still_to_visit_ids; while (!to_visit_ids.empty()) { auto currently_visiting_ids = to_visit_ids.popFront(); if (required_ind_exprs_ids.find(currently_visiting_ids) != @@ -339,7 +337,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) continue; } - if (processIdGroup(currently_visiting_ids)) { + if (processValGroup(currently_visiting_ids)) { something_was_processed = true; auto definitions_pair = getDefinitions(currently_visiting_ids); if (definitions_pair.second) { @@ -366,9 +364,9 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) // We want to traverse the expressions registered in required_ind_exprs_ids, // let's create a strict "uses path" - std::unordered_map uses_path; + std::unordered_map uses_path; for (const auto& entry : required_ind_exprs_ids) { - const IdGroup& id = entry.first; + const ValGroup& id = entry.first; const ExprGroups& traverse_exprs = entry.second; if (auto all_uses = getUses(id); all_uses.second) { uses_path[id] = traverse_exprs.intersect(all_uses.first); @@ -382,7 +380,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) ExprGroups sorted_exprs; ExprGroups to_visit_expr_groups; - for (const IdGroup& inp : terminating_inputs) { + for (const ValGroup& inp : terminating_inputs) { auto use_it = uses_path.find(inp); if (use_it == uses_path.end()) { // This can happen for a trivial traversal where inputs and outputs are @@ -395,7 +393,7 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) } } - IdGroups visited = terminating_inputs; + ValGroups visited = terminating_inputs; while (!to_visit_expr_groups.empty()) { bool something_processed = false; @@ -403,13 +401,13 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) while (!to_visit_expr_groups.empty()) { auto currently_visiting = to_visit_expr_groups.popFront(); auto inputs = inputGroups(currently_visiting); - if (std::all_of(inputs.begin(), inputs.end(), [&](IdGroup inp_id) { + if (std::all_of(inputs.begin(), inputs.end(), [&](ValGroup inp_id) { return visited.has(inp_id); })) { something_processed = true; sorted_exprs.pushBack(currently_visiting); auto outputs = outputGroups(currently_visiting); - for (const IdGroup& out_id : outputs) { + for (const ValGroup& out_id : outputs) { visited.pushBack(out_id); auto use_pair = getUses(out_id); if (!use_pair.second) { @@ -428,11 +426,10 @@ ExprGroups IdGraph::getExprsBetween(const IdGroups& from, const IdGroups& to) return sorted_exprs; } -std::unordered_map> IdGraph:: - buildMapBetween( - const std::vector& from, - const std::vector& to) const { - std::unordered_map from_ids2set; +std::unordered_map> ValGraph::buildMapBetween( + const std::vector& from, + const std::vector& to) const { + std::unordered_map from_ids2set; for (auto from_id : from) { if (!hasGroup(from_id)) { @@ -443,7 +440,7 @@ std::unordered_map> IdGraph:: // Map from the sets associated with the IterDomains in to, to those iter // domains - std::unordered_map> set2to_ids; + std::unordered_map> set2to_ids; for (auto to_id : to) { if (!hasGroup(to_id)) { @@ -459,10 +456,9 @@ std::unordered_map> IdGraph:: } } - std::unordered_map> - from_ids2to_ids; + std::unordered_map> from_ids2to_ids; for (auto from_id : from) { - from_ids2to_ids[from_id] = VectorOfUniqueEntries(); + from_ids2to_ids[from_id] = VectorOfUniqueEntries(); auto from_it = from_ids2set.find(from_id); NVF_ERROR(from_it != from_ids2set.end()); @@ -477,15 +473,14 @@ std::unordered_map> IdGraph:: return from_ids2to_ids; } -std::unordered_map> IdGraph:: - buildMapBetween( - const VectorOfUniqueEntries& from, - const VectorOfUniqueEntries& to) const { +std::unordered_map> ValGraph::buildMapBetween( + const VectorOfUniqueEntries& from, + const VectorOfUniqueEntries& to) const { return buildMapBetween(from.vector(), to.vector()); } -std::pair IdGraph::getDefinitions( - const IdGroup& id_group) const { +std::pair ValGraph::getDefinitions( + const ValGroup& id_group) const { if (!id_group) { return {{}, false}; } @@ -498,7 +493,7 @@ std::pair IdGraph::getDefinitions( } } -std::pair IdGraph::getUses(const IdGroup& id_group) const { +std::pair ValGraph::getUses(const ValGroup& id_group) const { if (!id_group) { return {{}, false}; } @@ -511,12 +506,12 @@ std::pair IdGraph::getUses(const IdGroup& id_group) const { } } -bool IdGraph::hasUses(const IdGroup& id_group) const { +bool ValGraph::hasUses(const ValGroup& id_group) const { NVF_ERROR(id_group); return unique_uses_.find(id_group) != unique_uses_.end(); } -std::string IdGraph::toString() const { +std::string ValGraph::toString() const { std::stringstream ss; ss << "IdGraph { \n"; ss << "Disjoint Ids:\n" @@ -526,8 +521,8 @@ std::string IdGraph::toString() const { return ss.str(); } -std::vector> IdGraph::isTrivialExpr(Expr* expr) { - std::vector> mapped_ids; +std::vector> ValGraph::isTrivialExpr(Expr* expr) { + std::vector> mapped_ids; if (auto merge = dynamic_cast(expr)) { if (merge->inner()->extent()->isOneInt()) { mapped_ids.push_back({merge->outer(), merge->out()}); @@ -554,7 +549,7 @@ std::vector> IdGraph::isTrivialExpr(Expr* expr) { return mapped_ids; } -bool IdGraph::transformAtributesMatch(Expr* first, Expr* second) { +bool ValGraph::transformAtributesMatch(Expr* first, Expr* second) { if (first == nullptr || second == nullptr) { return false; } @@ -594,12 +589,12 @@ bool IdGraph::transformAtributesMatch(Expr* first, Expr* second) { return true; } -void IdGraph::initializeId( - IterDomain* id, +void ValGraph::initializeVal( + Val* val, const VectorOfUniqueEntries& definitions, const VectorOfUniqueEntries& uses) { - const IdGroup& id_disjoint_set = - disjointIdSets().initializeSet(id).first->second; + const ValGroup& id_disjoint_set = + disjointValSets().initializeSet(val).first->second; ExprGroups def_groups; for (auto def : definitions) { @@ -622,16 +617,16 @@ void IdGraph::initializeId( NVF_ERROR(unique_uses_.emplace(id_disjoint_set, use_groups).second); } -bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { +bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { if (!transformAtributesMatch(first, second)) { return false; } - auto first_ids = ir_utils::filterByType( - forward ? first->inputs() : first->outputs()) - .vector(); + auto first_ids = + ir_utils::filterByType(forward ? first->inputs() : first->outputs()) + .vector(); - auto second_ids = ir_utils::filterByType( + auto second_ids = ir_utils::filterByType( forward ? second->inputs() : second->outputs()) .vector(); @@ -647,7 +642,7 @@ bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { // inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) ? { for (const auto i : c10::irange(first_ids.size())) { - if (!disjointIdSets().permissiveAreMapped( + if (!disjointValSets().permissiveAreMapped( first_ids.at(i), second_ids.at(i))) { return false; } @@ -669,12 +664,12 @@ bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { auto extent_o_match = extent_0o->sameAs(extent_1o) || (extent_0o->isConstInt() && extent_1o->isConstInt() && extent_0o->evaluate() == extent_1o->evaluate()) || - disjointIdSets().permissiveAreMapped(merge0->outer(), merge1->outer()); + disjointValSets().permissiveAreMapped(merge0->outer(), merge1->outer()); auto extent_i_match = extent_0i->sameAs(extent_1i) || (extent_0i->isConstInt() && extent_1i->isConstInt() && extent_0i->evaluate() == extent_1i->evaluate()) || - disjointIdSets().permissiveAreMapped(merge0->inner(), merge1->inner()); + disjointValSets().permissiveAreMapped(merge0->inner(), merge1->inner()); if (!(extent_o_match || extent_i_match)) { return false; @@ -701,37 +696,37 @@ bool IdGraph::exprsMap(Expr* first, Expr* second, bool forward) const { return true; } -const ExprGroups& IdGraph::getUniqueDefinitions(const IdGroup& group) const { +const ExprGroups& ValGraph::getUniqueDefinitions(const ValGroup& group) const { auto unique_defs_it = unique_definitions_.find(group); NVF_ERROR( unique_defs_it != unique_definitions_.end(), - "Definition not found for IdGroup: ", + "Definition not found for ValGroup: ", group->toString()); return unique_defs_it->second; } -const ExprGroups& IdGraph::getUniqueUses(const IdGroup& group) const { +const ExprGroups& ValGraph::getUniqueUses(const ValGroup& group) const { auto unique_uses_it = unique_uses_.find(group); NVF_ERROR( unique_uses_it != unique_uses_.end(), - "Uses not found for IdGroup: ", + "Uses not found for ValGroup: ", group->toString()); return unique_uses_it->second; } -void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { +void ValGraph::mapIds(Val* id0, Val* id1) { if (id0 == id1) { return; } - if (disjointIdSets().strictAreMapped(id0, id1)) { + if (disjointValSets().strictAreMapped(id0, id1)) { return; } // Definitions and uses are based on the groups of id0 and id1, don't merge // them into a single group until we grab all definitions and uses for later // processing. - IdGroup orig_id_group0 = toGroup(id0); - IdGroup orig_id_group1 = toGroup(id1); + ValGroup orig_id_group0 = toGroup(id0); + ValGroup orig_id_group1 = toGroup(id1); const ExprGroups& orig_defs0 = getUniqueDefinitions(orig_id_group0); const ExprGroups& orig_defs1 = getUniqueDefinitions(orig_id_group1); const ExprGroups& orig_uses0 = getUniqueUses(orig_id_group0); @@ -740,7 +735,7 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { // Map the iter domains together before we traverse across definitions and // uses. Traversing definitions and uses could use the new property of id0 and // id1 being mapped. - disjointIdSets().mapEntries(id0, id1); + disjointValSets().mapEntries(id0, id1); auto new_id_group = toGroup(id0); unique_definitions_[new_id_group] = orig_defs0.computeUnion(orig_defs1); @@ -780,7 +775,7 @@ void IdGraph::mapIds(IterDomain* id0, IterDomain* id1) { unique_uses_.erase(orig_id_group1); } -void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { +void ValGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { if (!exprsMap(expr0, expr1, forward)) { return; } @@ -797,7 +792,7 @@ void IdGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { } } -void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { +void ValGraph::mapExprs(Expr* expr0, Expr* expr1) { if (expr0 == expr1) { return; } @@ -814,35 +809,35 @@ void IdGraph::mapExprs(Expr* expr0, Expr* expr1) { auto expr_new_group = toGroup(expr0); // Update unique uses of producers - IdGroups producers; + ValGroups producers; for (auto expr : std::vector{expr0, expr1}) { - for (auto input_id : ir_utils::filterByType(expr->inputs())) { + for (auto input_id : ir_utils::filterByType(expr->inputs())) { producers.pushBack(toGroup(input_id)); } } - for (const IdGroup& producer_group : producers) { + for (const ValGroup& producer_group : producers) { unique_uses_.at(producer_group).erase(expr0_orig_group); unique_uses_.at(producer_group).erase(expr1_orig_group); unique_uses_.at(producer_group).pushBack(expr_new_group); } // Update unique definitinos of consumers - IdGroups consumers; + ValGroups consumers; for (auto expr : std::vector{expr0, expr1}) { - for (auto output_id : ir_utils::filterByType(expr->outputs())) { + for (auto output_id : ir_utils::filterByType(expr->outputs())) { consumers.pushBack(toGroup(output_id)); } } - for (const IdGroup& consumer_group : consumers) { + for (const ValGroup& consumer_group : consumers) { unique_definitions_.at(consumer_group).erase(expr0_orig_group); unique_definitions_.at(consumer_group).erase(expr1_orig_group); unique_definitions_.at(consumer_group).pushBack(expr_new_group); } } -bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { +bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { if (first == nullptr || second == nullptr) { return false; } @@ -855,10 +850,10 @@ bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { propagate_through_exprs_, "Asked to propagate expression mappings on a graph that has propagate_exprs_ disabled."); - auto first_ids = ir_utils::filterByType( - forward ? first->outputs() : first->inputs()) - .vector(); - auto second_ids = ir_utils::filterByType( + auto first_ids = + ir_utils::filterByType(forward ? first->outputs() : first->inputs()) + .vector(); + auto second_ids = ir_utils::filterByType( forward ? second->outputs() : second->inputs()) .vector(); NVF_ERROR( @@ -874,7 +869,7 @@ bool IdGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return true; } -void IdGraph::mapThroughLoopSwizzles() { +void ValGraph::mapThroughLoopSwizzles() { std::vector all_swizzles; for (const auto& expr_set : disjointExprSets().disjointSets()) { @@ -894,7 +889,7 @@ void IdGraph::mapThroughLoopSwizzles() { } } -void IdGraph::mapThroughTrivialExprs() { +void ValGraph::mapThroughTrivialExprs() { // Grab all expressions std::vector exprs; @@ -906,7 +901,7 @@ void IdGraph::mapThroughTrivialExprs() { for (auto expr : exprs) { // If not trivial continue - auto mapped_ids = IdGraph::isTrivialExpr(expr); + auto mapped_ids = ValGraph::isTrivialExpr(expr); if (mapped_ids.empty()) { continue; } @@ -920,7 +915,7 @@ void IdGraph::mapThroughTrivialExprs() { } } -void IdGraph::removeTrivialExprs() { +void ValGraph::removeTrivialExprs() { ExprGroups trivial_expr_groups; // This seems like it shouls just be a copy if. for (const ExprGroup& expr_group : disjointExprSets().disjointSets()) { @@ -942,9 +937,9 @@ void IdGraph::removeTrivialExprs() { // Complexity here is not great. We might want a better complexity version when // erasing multiple expr_groups. -void IdGraph::eraseExprGroup(const ExprGroup& expr_group) { +void ValGraph::eraseExprGroup(const ExprGroup& expr_group) { // Erase entries that exist in unique_definitions_ and unique_uses_ - for (const IdGroup& id_group : disjointIdSets().disjointSets()) { + for (const ValGroup& id_group : disjointValSets().disjointSets()) { // Make sure the entries exists NVF_ERROR( unique_definitions_.find(id_group) != unique_definitions_.end(), @@ -964,9 +959,9 @@ void IdGraph::eraseExprGroup(const ExprGroup& expr_group) { } } -bool IdGraph::isTrivialExprGroup(const ExprGroup& expr_group) const { - return !IdGroups(inputGroups(expr_group)) - .intersect(IdGroups(outputGroups(expr_group))) +bool ValGraph::isTrivialExprGroup(const ExprGroup& expr_group) const { + return !ValGroups(inputGroups(expr_group)) + .intersect(ValGroups(outputGroups(expr_group))) .empty(); } diff --git a/csrc/id_model/id_graph.h b/csrc/val_graph.h similarity index 72% rename from csrc/id_model/id_graph.h rename to csrc/val_graph.h index b2660f7537d..4740ed54f07 100644 --- a/csrc/id_model/id_graph.h +++ b/csrc/val_graph.h @@ -16,31 +16,31 @@ namespace nvfuser { -using IdGroup = std::shared_ptr>; -using IdGroups = VectorOfUniqueEntries; +using ValGroup = std::shared_ptr>; +using ValGroups = VectorOfUniqueEntries; using ExprGroup = std::shared_ptr>; using ExprGroups = VectorOfUniqueEntries; -class IdGraph { +class ValGraph { public: - IdGraph() = default; + ValGraph() = default; - IdGraph(const IdGraph& other); - IdGraph(IdGraph&& other) = default; + ValGraph(const ValGraph& other); + ValGraph(ValGraph&& other) = default; - IdGraph& operator=(const IdGraph& other); - IdGraph& operator=(IdGraph&& other) = default; + ValGraph& operator=(const ValGraph& other); + ValGraph& operator=(ValGraph&& other) = default; - IdGraph(bool propagate_through_exprs) + ValGraph(bool propagate_through_exprs) : propagate_through_exprs_(propagate_through_exprs) {} // Returns the disjoint IterDomain set. - const DisjointSets& disjointIdSets() const { - return disjoint_ids_; + const DisjointSets& disjointValSets() const { + return disjoint_vals_; } - DisjointSets& disjointIdSets() { - return disjoint_ids_; + DisjointSets& disjointValSets() { + return disjoint_vals_; } // Returns the disjoint Expr set. @@ -56,53 +56,60 @@ class IdGraph { bool hasGroup(Expr* expr) const; // Return if there's a group entry in the graph for this id - bool hasGroup(IterDomain* id) const; + bool hasGroup(Val* id) const; // Convert expr to its exprGroup, assert that it exists. const ExprGroup& toGroup(Expr* expr) const; - // Convert iter domain to its IdGroup, assert that it exists. - const IdGroup& toGroup(IterDomain* id) const; + // Convert iter domain to its ValGroup, assert that it exists. + const ValGroup& toGroup(Val* id) const; // Convert unique vector of expressions to unique vector of its groups ExprGroups toGroups(const VectorOfUniqueEntries& exprs) const; // Convert unique vector of IterDomain to unique vector of its groups - IdGroups toGroups(const VectorOfUniqueEntries& ids) const; + ValGroups toGroups(const VectorOfUniqueEntries& ids) const; + + template + ValGroups toGroups(const VectorOfUniqueEntries& vals) const { + ValGroups val_groups; + for (auto val : vals) { + val_groups.pushBack(toGroup(val)); + } + return val_groups; + } // Return output/input iter domain groups of provided expr // Note that the same IdGroup can show up multiple times, so the // output type cannot be VectorOfUniqueEntries - std::vector outputGroups(const ExprGroup& expr) const; - std::vector inputGroups(const ExprGroup& expr) const; + std::vector outputGroups(const ExprGroup& expr) const; + std::vector inputGroups(const ExprGroup& expr) const; - // Recursively traverses uses of the IdGroups in 'of' and returns all - // ExprGroups that have a use in their definition of provided of IdGroups. - ExprGroups allUsesOf(const IdGroups& of) const; + // Recursively traverses uses of the ValGroups in 'of' and returns all + // ExprGroups that have a use in their definition of provided of ValGroups. + ExprGroups allUsesOf(const ValGroups& of) const; - // Recursively traverses definitions of the IdGroups in 'of' and returns all - // ExprGroups used in this history of defining the 'of' IdGroups. - ExprGroups allDefinitionsOf(const IdGroups& of) const; + // Recursively traverses definitions of the ValGroups in 'of' and returns all + // ExprGroups used in this history of defining the 'of' ValGroups. + ExprGroups allDefinitionsOf(const ValGroups& of) const; // Return sorted expressions to go from the provided IterDomains in from to // the provided IterDomains in to with provided mode. Minimal expressions to // get from 'from' to 'to' returned. - ExprGroups getExprsBetween(const IdGroups& from, const IdGroups& to) const; + ExprGroups getExprsBetween(const ValGroups& from, const ValGroups& to) const; // Supports one to many mappings, uses the disjoint sets of the provided mode // to produce mappings between from and to. If multiple IterDomains in to map // to a single iter domain in from, the order of the IterDomains in value of // the map is preserved to be the order provided in to. - std::unordered_map> - buildMapBetween( - const std::vector& from, - const std::vector& to) const; + std::unordered_map> buildMapBetween( + const std::vector& from, + const std::vector& to) const; // Alias of the above on unique vector entries - std::unordered_map> - buildMapBetween( - const VectorOfUniqueEntries& from, - const VectorOfUniqueEntries& to) const; + std::unordered_map> buildMapBetween( + const VectorOfUniqueEntries& from, + const VectorOfUniqueEntries& to) const; //! Returns //! (1) The expressions associated with the definitions of the provided @@ -116,7 +123,7 @@ class IdGraph { //! Iter Domain set based on the provided mode. //! //! TODO-NM: ExprGroups is a real container. Consider returning a reference - std::pair getDefinitions(const IdGroup& id_group) const; + std::pair getDefinitions(const ValGroup& id_group) const; //! Same as iterDomainGroupDefinitions but for uses instead of //! definitions @@ -124,22 +131,22 @@ class IdGraph { //! TODO-NM: ExprGroups is a real container. Consider returning a //! reference //! TODO-NM: Rename to getMaybeUses. See getUses - std::pair getUses(const IdGroup& id_group) const; + std::pair getUses(const ValGroup& id_group) const; - bool hasUses(const IdGroup& id_group) const; + bool hasUses(const ValGroup& id_group) const; std::string toString() const; // Checks if the expression is a trivial operation where an input is simply an // output of the transformation. Returns the mapped iter domains if found. - static std::vector> isTrivialExpr(Expr* expr); + static std::vector> isTrivialExpr(Expr* expr); // Returns if all atributes of the ID transforms first and second are the same static bool transformAtributesMatch(Expr* first, Expr* second); // Initializes entries for the provided IterDomain in the IterDomainGraphs - void initializeId( - IterDomain* id, + void initializeVal( + Val* val, const VectorOfUniqueEntries& definitions, const VectorOfUniqueEntries& uses); @@ -156,24 +163,24 @@ class IdGraph { // Returns entry in unique_definitions_ for provided group in provided mode, // otherwise errors if no entry is found. - const ExprGroups& getUniqueDefinitions(const IdGroup& group) const; + const ExprGroups& getUniqueDefinitions(const ValGroup& group) const; // Returns entry in unique_uses_ for provided group in provided mode, // otherwise errors if no entry is found. - const ExprGroups& getUniqueUses(const IdGroup& group) const; + const ExprGroups& getUniqueUses(const ValGroup& group) const; public: - void addUniqueUses(const IdGroup& id_group, const ExprGroup& uses) { + void addUniqueUses(const ValGroup& id_group, const ExprGroup& uses) { unique_uses_.at(id_group).pushBack(uses); } - void addUniqueDefinitions(const IdGroup& id_group, const ExprGroup& defs) { + void addUniqueDefinitions(const ValGroup& id_group, const ExprGroup& defs) { unique_definitions_.at(id_group).pushBack(defs); } // Set id0 and id1 to mapped in disjointIdsSet[mode], attempt to propagate // new mapping through id0/id1 definitions/uses. - void mapIds(IterDomain* id0, IterDomain* id1); + void mapIds(Val* id0, Val* id1); // Checks if expr0 and expr1 should map together, maps them together, and if // expression propagation is on, propagates mapping through them. This should @@ -237,21 +244,21 @@ class IdGraph { // Using an array here might be nice, but it seems hard to use an enum as an // array key // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum - DisjointSets disjoint_ids_; + DisjointSets disjoint_vals_; // Keeps a disjoint set entry for all Expressions for all mapping mode types. DisjointSets disjoint_exprs_; - // Definitions of IdGroup. There can be multiple definitions due to + // Definitions of ValGroup. There can be multiple definitions due to // replays. - // TODO-NM: IdGroup by a new definition ExprGroup would not be used + // TODO-NM: ValGroup by a new definition ExprGroup would not be used // by existing uses. Does it make sense to represent uses and defs // this way? In other words, there is a traversal path from a - // definition ExprGroup to an IdGroup and its use ExprGroup, but + // definition ExprGroup to an ValGroup and its use ExprGroup, but // that does't guarantee the path actually exist - std::unordered_map unique_definitions_; + std::unordered_map unique_definitions_; - std::unordered_map unique_uses_; + std::unordered_map unique_uses_; }; } // namespace nvfuser diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index f853d34a47f..a55e68dadec 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -883,12 +883,12 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { // All of the IDs that are generated with merge operations from the // root domains should be mapped to the single group. - const IdGroup& merge_loop_group = + const ValGroup& merge_loop_group = id_model.idGraph(IdMappingMode::LOOP).toGroup(tv1->getRootDomain().at(0)); for (auto tv : {tv1, tv2, tv4, tv5, tv6, tv8, tv9}) { for (auto id : ir_utils::allIDsOf(tv)) { if (dynamic_cast(id->definition()) == nullptr) { - const IdGroup& loop_group = + const ValGroup& loop_group = id_model.idGraph(IdMappingMode::LOOP).toGroup(id); ASSERT_EQ(loop_group, merge_loop_group) << "Unexpected loop group: " << nvfuser::toString(loop_group); @@ -1002,7 +1002,8 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { if (loop_mapped_id == id) { continue; } - ASSERT_FALSE(isIdOfConsumerTensor(loop_mapped_id, tv)) + ASSERT_FALSE( + isIdOfConsumerTensor(loop_mapped_id->as(), tv)) << "Invalid promotion: " << id->toString() << " of " << tv->toString() << ". Found to mapped a consumer tensor: " << loop_mapped_id->name(); From f7f4d842140e5c4dab3a60c87d73969d99641ce6 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 29 Nov 2023 15:52:39 -0800 Subject: [PATCH 081/178] Updating IterDomainGraphs branch with main (#1412) Merging the current main, which includes #1168 --- CMakeLists.txt | 5 +- csrc/compute_at_map.h | 8 ++ csrc/device_lower/lower2device.cpp | 12 +++ csrc/disjoint_set.h | 93 ++++++++------------- csrc/expr_evaluator.cpp | 3 +- csrc/id_model/id_model.cpp | 20 ++--- csrc/id_model/validation_utils.cpp | 128 +++++++++++++++++++++++++++++ csrc/id_model/validation_utils.h | 40 +++++++++ csrc/id_model/visitor.cpp | 8 +- csrc/ir/utils.cpp | 17 ++-- csrc/ir/utils.h | 3 +- csrc/options.cpp | 1 + csrc/options.h | 1 + csrc/val_graph.cpp | 16 ++-- 14 files changed, 261 insertions(+), 94 deletions(-) create mode 100644 csrc/id_model/validation_utils.cpp create mode 100644 csrc/id_model/validation_utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ba58ae9875..dd4b140f5ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,9 +77,10 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/executor_utils.cpp ${NVFUSER_SRCS_DIR}/fusion.cpp ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp - ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp + ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp ${NVFUSER_SRCS_DIR}/id_model/transform_replay.cpp + ${NVFUSER_SRCS_DIR}/id_model/validation_utils.cpp ${NVFUSER_SRCS_DIR}/id_model/visitor.cpp ${NVFUSER_SRCS_DIR}/index_compute.cpp ${NVFUSER_SRCS_DIR}/instrumentation.cpp @@ -195,7 +196,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp ${NVFUSER_SRCS_DIR}/val_graph.cpp - ) +) # We don't link CUPTI for MSVC if(NOT MSVC) diff --git a/csrc/compute_at_map.h b/csrc/compute_at_map.h index 7aede5ea379..4079a0da035 100644 --- a/csrc/compute_at_map.h +++ b/csrc/compute_at_map.h @@ -18,6 +18,8 @@ namespace nvfuser { +class IdModelValidator; + // There's four modes of these iter domain mappings all uniquely important in // the lowering process. // @@ -169,6 +171,9 @@ class IterDomainGraph { std::optional> self_mapping_info_ = std::nullopt; + + // Temporary interface exposure for validating IdModel + friend class IdModelValidator; }; using DoubleBufferIndices = std::unordered_map; @@ -377,6 +382,9 @@ class ComputeAtMap { // Shortcut to access the fusion this computeAt map was // built from. Fusion* fusion_; + + // Temporary interface exposure for validating IdModel + friend class IdModelValidator; }; } // namespace nvfuser diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index b2f92bdea41..bdcf5fb7ac1 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -384,6 +385,17 @@ void GpuLower::analysis(Fusion* fusion) { // information. compute_at_map_ = std::make_shared(fusion_); + // Transitory testing of IdModel if enabled. No existing + // functionality should be affected. New IterDomains may be created, + // so it is expected that generated code may use diffrent variable + // names + if (isOptionEnabled(EnableOption::IdModel)) { + IdModel id_model(fusion_); + // Only the exact graph is genereated at this moment + IdModelValidator::checkExactGraphEquivalence( + id_model.idGraph(IdMappingMode::EXACT)); + } + resolveComputeWith(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "resolveComputeWith", true); diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 952c27ce220..528b594d866 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -44,22 +44,13 @@ class VectorOfUniqueEntries { public: VectorOfUniqueEntries() = default; - VectorOfUniqueEntries(const std::initializer_list& initializer) { - for (auto entry : initializer) { - pushBack(entry); - } - } + VectorOfUniqueEntries(const std::initializer_list& initializer) + : VectorOfUniqueEntries(initializer.begin(), initializer.end()) {} - VectorOfUniqueEntries(const VectorOfUniqueEntries& other) - : vector_(other.vector()), set_(other.set()) {} + VectorOfUniqueEntries(const VectorOfUniqueEntries& other) = default; - VectorOfUniqueEntries& operator=(const VectorOfUniqueEntries& other) { - if (this != &other) { - vector_ = other.vector(); - set_ = other.set(); - } - return *this; - } + VectorOfUniqueEntries& operator=(const VectorOfUniqueEntries& other) = + default; template VectorOfUniqueEntries(InputIt first, InputIt last) { @@ -123,7 +114,7 @@ class VectorOfUniqueEntries { // Returns a new VectorOfUniqueEntries with entries that are in both this and // other, order is preserved as this. - VectorOfUniqueEntries intersect( + VectorOfUniqueEntries computeIntersect( const VectorOfUniqueEntries& other) const { VectorOfUniqueEntries intersection; for (const auto& entry : vector()) { @@ -136,7 +127,7 @@ class VectorOfUniqueEntries { // Returns a new VectorOfUniqueEntries with entries that are in this but not // in other. - VectorOfUniqueEntries subtract( + VectorOfUniqueEntries computeSubtract( const VectorOfUniqueEntries& other) const { VectorOfUniqueEntries subtraction; for (const auto& entry : vector()) { @@ -151,8 +142,7 @@ class VectorOfUniqueEntries { // other. VectorOfUniqueEntries computeUnion( const VectorOfUniqueEntries& other) const { - const VectorOfUniqueEntries& this_ref = *this; - VectorOfUniqueEntries union_(this_ref); + VectorOfUniqueEntries union_(*this); for (const auto& entry : other.vector()) { union_.pushBack(entry); } @@ -307,6 +297,9 @@ class VectorOfUniqueEntries { template > class DisjointSets { public: + using DisjointSetMap = std:: + unordered_map>, Hash>; + DisjointSets() = default; DisjointSets(const DisjointSets& other); @@ -325,9 +318,7 @@ class DisjointSets { // Warning: returned values should never be modified. This accessor isn't // strictly safe as VectorOfUniqueEntries is not returned as a const. - const std:: - unordered_map>, Hash>& - disjointSetMap() const { + const DisjointSetMap& disjointSetMap() const { return disjoint_set_maps_; } @@ -349,13 +340,7 @@ class DisjointSets { } // Initializes a new set for provided entry - std::pair< - typename std::unordered_map< - T, - std::shared_ptr>, - Hash>::iterator, - bool> - initializeSet(T entry) { + std::pair initializeSet(T entry) { auto disjoint_set_maps_it = disjoint_set_maps_.find(entry); if (disjoint_set_maps_it != disjoint_set_maps_.end()) { return std::make_pair(disjoint_set_maps_it, false); @@ -388,40 +373,32 @@ class DisjointSets { std::make_shared>()); auto new_set = disjoint_sets_.back(); - if (set_0_found) { - auto set_0 = set_it_0->second; - for (auto set_0_entry : *set_0) { - NVF_ERROR(set_0_entry != entry1); - new_set->pushBack(set_0_entry); - disjoint_set_maps_[set_0_entry] = new_set; + // Add an entry to new_set along with the other entries previously + // grouped together with the entry. The existing set is erased. + auto mergeSets = [this](const T& entry, auto& new_set) { + if (auto it = disjoint_set_maps_.find(entry); + it != disjoint_set_maps_.end()) { + auto existing_set = it->second; + for (const auto& existing_entry : *existing_set) { + new_set->pushBack(existing_entry); + disjoint_set_maps_[existing_entry] = new_set; + } + disjoint_sets_.erase(std::find( + disjoint_sets_.begin(), disjoint_sets_.end(), existing_set)); + } else { + new_set->pushBack(entry); + disjoint_set_maps_[entry] = new_set; } - disjoint_sets_.erase( - std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set_0)); - // Erase invalidates iterators, regrab. - set_it_1 = disjoint_set_maps_.find(entry1); - set_1_found = set_it_1 != disjoint_set_maps_.end(); - } else { - new_set->pushBack(entry0); - disjoint_set_maps_[entry0] = new_set; - } + }; + + mergeSets(entry0, new_set); // This should be after we enter a new set in case it doesn't exist. if (entry0 == entry1) { return; } - if (set_1_found) { - auto set_1 = set_it_1->second; - for (auto set_1_entry : *set_1) { - new_set->pushBack(set_1_entry); - disjoint_set_maps_[set_1_entry] = new_set; - } - disjoint_sets_.erase( - std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set_1)); - } else { - new_set->pushBack(entry1); - disjoint_set_maps_[entry1] = new_set; - } + mergeSets(entry1, new_set); } // Will assert if provided entry0 is not in any disjoint set, otherwise @@ -451,7 +428,8 @@ class DisjointSets { return disjoint_set_maps_.find(entry) != disjoint_set_maps_.end(); } - // Erases element if it exists in the disjoint set, returns if element found. + // Erases element if it exists in the disjoint set. Returns true if element + // found. bool erase(T entry) { auto entry_it = disjoint_set_maps_.find(entry); if (entry_it == disjoint_set_maps_.end()) { @@ -512,8 +490,7 @@ class DisjointSets { private: // Disjoint sets - std::unordered_map>, Hash> - disjoint_set_maps_; + DisjointSetMap disjoint_set_maps_; // Keep a list of disjoint_sets that's deterministic to iterate over // diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index e6dce264fc1..e66a4a2a9ae 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -224,8 +224,9 @@ const PolymorphicValue& ExpressionEvaluator::getValue( } auto it = known_values_.find(value); - if (it != known_values_.end()) + if (it != known_values_.end()) { return it->second; + } if (&additional_known_values != &known_values_) { it = additional_known_values.find(value); diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 8acbb4a5386..c99f7325ec2 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -960,7 +960,7 @@ void IdModel::build( all_tvs.begin(), all_tvs.end()); for (auto additional_tv : additional_tvs) { if (all_added_tvs.find(additional_tv) == all_added_tvs.end()) { - all_tvs.push_back(additional_tv); + all_tvs.pushBack(additional_tv); } } } @@ -971,7 +971,7 @@ void IdModel::build( FusionGuard fg(all_tvs.front()->fusion()); // Add uses and definitions to all iter domains. - buildIterDomainDefinitionsAndUses(all_tvs); + buildIterDomainDefinitionsAndUses(all_tvs.vector()); // Initialize the maps with all the IterDomains used in the provded // expressions. @@ -989,7 +989,7 @@ void IdModel::build( // Only build loop map during lowering // TODO: make this configurable if (true || FusionGuard::getCurFusion()->isA()) { - validatePTypes(all_tvs); + validatePTypes(all_tvs.vector()); StatefulLoweringInfo info = buildInfo( tv_exprs, @@ -1030,7 +1030,7 @@ void IdModel::build( // Debug, make sure there's no self mapping in TensorView's during lowering // that would invalidate lowering assumptions. - self_mapping_info_ = findFirstSelfMapping(all_tvs, *this); + self_mapping_info_ = findFirstSelfMapping(all_tvs.vector(), *this); } VectorOfUniqueEntries IdModel::computeTerminalLoopIds( @@ -1208,7 +1208,7 @@ std::unordered_map IdModel::buildInlinePromotions( // broadcasted to, and those that exist within the same loop groop are is // the promotion needed for this iel_group. ValGroups loop_exact_resolved_intersection = - resolved_exact_groups.intersect(loop_covered_exact_groups); + resolved_exact_groups.computeIntersect(loop_covered_exact_groups); if (loop_exact_resolved_intersection.empty()) { // No resolution @@ -1232,7 +1232,7 @@ std::unordered_map IdModel::buildInlinePromotions( ValGroup exact_resolution_group = loop_exact_resolved_intersection.front(); VectorOfUniqueEntries resolved_ids = - exact_resolution_group->intersect(*loop_group); + exact_resolution_group->computeIntersect(*loop_group); auto promoted_iel_groups = intersection_exact_loop_graph.toGroups(resolved_ids); @@ -1325,7 +1325,7 @@ std::unordered_map IdModel::buildInlinePromotions( ExprGroups non_promoted_input_uses; for (const ValGroup& iel_group : - promoted_input_groups.intersect(input_groups)) { + promoted_input_groups.computeIntersect(input_groups)) { non_promoted_input_uses.pushBack( intersection_exact_loop_graph.getUniqueUses(iel_group)); } @@ -1571,7 +1571,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( IterDomain* terminal_id = entry.second; auto covered_it = exact_covered_ids.find(terminal_id_group); NVF_ERROR(covered_it != exact_covered_ids.end()); - if (loop_group_covered_ids.subtract(covered_it->second).empty()) { + if (loop_group_covered_ids.computeSubtract(covered_it->second).empty()) { loop_promotion_id = terminal_id; break; } @@ -1660,7 +1660,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // The inputs should be promoted based on the loop promotion map. bool loop_promote_inputs = - !inp_loop_groups.subtract(out_loop_groups).empty(); + !inp_loop_groups.computeSubtract(out_loop_groups).empty(); std::vector promoted_inputs; @@ -1877,7 +1877,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( auto covered_it = exact_covered_ids.find( idGraph(IdMappingMode::EXACT).toGroup(candidate_id)); NVF_ERROR(covered_it != exact_covered_ids.end()); - if (loop_group_covered_ids.subtract(covered_it->second).empty()) { + if (loop_group_covered_ids.computeSubtract(covered_it->second).empty()) { // Found VERBOSE() << "Representative found: " << candidate_id->toString() << std::endl; diff --git a/csrc/id_model/validation_utils.cpp b/csrc/id_model/validation_utils.cpp new file mode 100644 index 00000000000..f996180591c --- /dev/null +++ b/csrc/id_model/validation_utils.cpp @@ -0,0 +1,128 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +#include + +namespace nvfuser { + +void IdModelValidator::checkExactGraphEquivalence(const ValGraph& exact_graph) { + // Empty graph + if (exact_graph.disjointValSets().disjointSets().empty()) { + return; + } + + auto all_exprs = exact_graph.disjointExprSets().getAllElements(); + if (std::find_if(all_exprs.begin(), all_exprs.end(), [](Expr* expr) { + return expr->isA(); + }) != all_exprs.end()) { + // Ignoring a fusion with swizzle + return; + } + + Fusion* fusion = exact_graph.disjointValSets() + .disjointSets() + .at(0) + ->vector() + .at(0) + ->fusion(); + ComputeAtMap ca_map(fusion); + + DisjointSets& ca_map_exact_sets = ca_map.id_graph_.exact_nodes_; + + // Propgate mappings through expressions in ComputeAtMap. Since we + // want to traverse and update ca_map_exact_sets, once updated, the + // traversal of the ID groups cannot continue and needs to be + // restarted. The algorithm seems terriblly inefficient, but + // shuldn't matter as this is just for transitory validations + bool updated = true; + while (updated) { + updated = false; + for (const auto& set : ca_map_exact_sets.disjointSets()) { + auto uses = ca_map.uniqueExactUses(set->vector().front()); + auto use_count = uses.size(); + // Note that it should be fine to continue updating the map with + // the loop below as it should only modify output domain groups + for (size_t i = 0; i < use_count; ++i) { + auto use_i = uses.at(i); + for (size_t j = i + 1; j < use_count; ++j) { + auto use_j = uses.at(j); + if (!IterDomainGraph::exprsMap( + use_i, use_j, true, ca_map_exact_sets)) { + continue; + } + auto num_outputs = use_i->outputs().size(); + NVF_ERROR(use_j->outputs().size() == num_outputs); + for (size_t output_i = 0; output_i < num_outputs; ++output_i) { + auto out_i = use_i->output(output_i)->as(); + auto out_j = use_j->output(output_i)->as(); + if (!ca_map_exact_sets.strictAreMapped(out_i, out_j)) { + ca_map_exact_sets.mapEntries(out_i, out_j); + updated = true; + } + } + } + } + // If updated, the previous sets returned by + // ca_map_exact_sets.disjointSets() may contain stale sets + if (updated) { + ca_map.build(fusion); + break; + } + } + } + + const DisjointSets& id_model_exact_sets = exact_graph.disjointValSets(); + + if (id_model_exact_sets.size() != ca_map_exact_sets.size()) { + std::stringstream ss; + ss << "Mismatched number of groups: " << id_model_exact_sets.size() << ", " + << ca_map_exact_sets.size() << "\n"; + + ss << "IdModel exact sets:\n"; + for (const auto& id_set : id_model_exact_sets.disjointSets()) { + ss << "\t" << nvfuser::toString(id_set->vector()) << "\n"; + } + + ss << "ComputeAtMap exact sets:\n"; + for (const auto& id_set : ca_map_exact_sets.disjointSets()) { + ss << "\t" << nvfuser::toString(id_set->vector()) << "\n"; + } + + NVF_ERROR(false, ss.str()); + } + + for (const auto& id_model_id_set : id_model_exact_sets.disjointSets()) { + NVF_ERROR(!id_model_id_set->empty()); + NVF_ERROR( + ca_map_exact_sets.mappingExists( + id_model_id_set->front()->as()), + "Not found in ComputeAtMap: ", + id_model_id_set->front()->toString()); + + const auto& ca_map_id_set = ca_map_exact_sets.getDisjointSetOf( + id_model_id_set->front()->as()); + + std::unordered_set ca_map_id_set_cast; + std::copy( + ca_map_id_set.begin(), + ca_map_id_set.end(), + std::inserter(ca_map_id_set_cast, ca_map_id_set_cast.end())); + + NVF_ERROR( + id_model_id_set->set() == ca_map_id_set_cast, + "Mismatched ID set: ", + nvfuser::toString(id_model_id_set->vector()), + ", ", + nvfuser::toString(ca_map_id_set.vector())); + } +} + +} // namespace nvfuser diff --git a/csrc/id_model/validation_utils.h b/csrc/id_model/validation_utils.h new file mode 100644 index 00000000000..647fe1ba925 --- /dev/null +++ b/csrc/id_model/validation_utils.h @@ -0,0 +1,40 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Note that this class is a friend of ComputeAtMap as it needs to +// have private access +class IdModelValidator { + public: + // Validate a given exact graph of IdModel by comparing it with + // ComputeAtMap. Their maps should + // be almost the same but there are some differences. + // - In ComputeAtMap, swizzles are just skipped no matter what swizzle + // type is used, so only swizzle outputs are mapped. In IdModel, + // only swizzle inputs are mapped, except for Loop swizzles where + // their inputs and outputs are mapped. + // - In ComputeAtMap, mappings are local. For example, if domain x0 is + // split to x1 and x2, and also domain y0 is split to y1 and + // y2. Suppose x0 and y0 are exactly mapped and the two splits are + // also considered exactly the same, IdModel maps x1 and y1, and x2 + // and y2, respectively, whereas that doesn't happen with ComputeAtMap + // + // Accounting for the first difference doesn't seem trivial, so when + // swizzle is used we give up validating the exact graph. The second + // difference is whether mappings are propagated, which can be + // accounted for by updating the ComputeAtMap as is done in IdModel. + static void checkExactGraphEquivalence(const ValGraph& exact_graph); +}; + +} // namespace nvfuser diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index 24a210ed810..9feb9e0fcff 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -43,8 +43,8 @@ void IdGraphVisitor::traverse() { } auto inp_groups = ValGroups(graph().inputGroups(def)); auto out_groups = ValGroups(graph().outputGroups(def)); - if (inp_groups.subtract(all_ids).empty() && - out_groups.subtract(all_ids).empty()) { + if (inp_groups.computeSubtract(all_ids).empty() && + out_groups.computeSubtract(all_ids).empty()) { all_exprs.pushBack(def); } } @@ -70,10 +70,10 @@ void IdGraphVisitor::traverse() { } terminating_inputs = - ValGroups(all_ids.begin(), all_ids.end()).subtract(not_inputs); + ValGroups(all_ids.begin(), all_ids.end()).computeSubtract(not_inputs); terminating_outputs = - ValGroups(all_ids.begin(), all_ids.end()).subtract(not_outputs); + ValGroups(all_ids.begin(), all_ids.end()).computeSubtract(not_outputs); } ValGroups to_visit_ids = terminating_inputs; diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 79ecd64d08f..ccbe0679b03 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -134,7 +134,7 @@ std::vector normalizeOld2New( // All available new positions std::set all_positions; - for (decltype(ndims) i{0}; i < ndims; i++) { + for (auto i : c10::irange(ndims)) { all_positions.insert((int)i); } @@ -412,19 +412,14 @@ std::vector allTvs(Fusion* fusion) { return uniqueEntries(all_tvs); } -std::vector allTvsOfExprs(const std::vector& exprs) { - std::vector all_tvs; - std::unordered_set added; +VectorOfUniqueEntries allTvsOfExprs( + const std::vector& exprs) { + VectorOfUniqueEntries all_tvs; for (auto expr : exprs) { auto input_tvs = ir_utils::filterByType(expr->inputs()); auto output_tvs = ir_utils::filterByType(expr->outputs()); - for (bool input : {true, false}) { - auto& tvs = input ? input_tvs : output_tvs; - for (auto tv : tvs) { - if (added.emplace(tv).second) { - all_tvs.push_back(tv); - } - } + for (const auto& tvs : {input_tvs, output_tvs}) { + all_tvs.pushBack(tvs.begin(), tvs.end()); } } return all_tvs; diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 86afdaa557f..15b185a3852 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -314,7 +314,8 @@ std::vector outputTvsOf(std::vector tvs); std::vector allTvs(Fusion* fusion); // returns all tensor views used in the provided expressions -std::vector allTvsOfExprs(const std::vector& exprs); +VectorOfUniqueEntries allTvsOfExprs( + const std::vector& exprs); // returns all tensor views in fusion that are used between outputs and inputs // except the specified set. diff --git a/csrc/options.cpp b/csrc/options.cpp index 03ef8512b59..1fd5708ebe3 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -151,6 +151,7 @@ template <> std::unordered_map> Options< EnableOption>::getOptionsFromEnv() { const std::unordered_map available_options = { + {"id_model", EnableOption::IdModel}, {"kernel_db", EnableOption::KernelDb}, {"kernel_profile", EnableOption::KernelProfile}, {"memory_promotion", EnableOption::MemoryPromotion}, diff --git a/csrc/options.h b/csrc/options.h index 3ece7935f81..a9913817f90 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -76,6 +76,7 @@ enum class DebugDumpOption { //! These can be set through the `NVFUSER_ENABLE` environment variable //! enum class EnableOption { + IdModel, //! Enable IdModel KernelDb, //! Enable Kernel Database KernelProfile, //! Enable intra-kernel performance profiling MemoryPromotion, //! Enable promotion of memory types for non-pointwise ops diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 06e0ebdfc36..1d7bf1c5431 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -180,7 +180,8 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) // All of the expressions between from and to. Not all will be used as we // just want to define each iter domain group once. - ExprGroups all_exprs = all_uses_of_from.intersect(all_definitions_of_to); + ExprGroups all_exprs = + all_uses_of_from.computeIntersect(all_definitions_of_to); // There could be IterDomains in from or to that are between other from and // to nodes. Make sure to clear those out. @@ -206,8 +207,8 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) all_id_groups.pushBack(out_groups); not_inputs.pushBack(out_groups); } - terminating_inputs = all_id_groups.subtract(not_inputs); - terminating_outputs = all_id_groups.subtract(not_outputs); + terminating_inputs = all_id_groups.computeSubtract(not_inputs); + terminating_outputs = all_id_groups.computeSubtract(not_outputs); } // Track all expressions to get from outputs to this IterDomain. We @@ -272,7 +273,8 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) // Only worry about expressions between inputs and outputs we're // looking at. - for (const ExprGroup& use_group : uses_pair.first.intersect(all_exprs)) { + for (const ExprGroup& use_group : + uses_pair.first.computeIntersect(all_exprs)) { auto use_required_ind_exprs_it = required_ind_exprs_exprs.find(use_group); if (use_required_ind_exprs_it == required_ind_exprs_exprs.end()) { // If there isn't an entry for the use expression it wasn't @@ -369,7 +371,7 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) const ValGroup& id = entry.first; const ExprGroups& traverse_exprs = entry.second; if (auto all_uses = getUses(id); all_uses.second) { - uses_path[id] = traverse_exprs.intersect(all_uses.first); + uses_path[id] = traverse_exprs.computeIntersect(all_uses.first); } else { uses_path[id] = {}; continue; @@ -413,7 +415,7 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) if (!use_pair.second) { continue; } - still_to_visit.pushBack(use_pair.first.intersect(all_exprs)); + still_to_visit.pushBack(use_pair.first.computeIntersect(all_exprs)); } } else { still_to_visit.pushBack(currently_visiting); @@ -961,7 +963,7 @@ void ValGraph::eraseExprGroup(const ExprGroup& expr_group) { bool ValGraph::isTrivialExprGroup(const ExprGroup& expr_group) const { return !ValGroups(inputGroups(expr_group)) - .intersect(ValGroups(outputGroups(expr_group))) + .computeIntersect(ValGroups(outputGroups(expr_group))) .empty(); } From f073d049175f9c4f8e28f58d3f5220093968404f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 29 Nov 2023 17:01:49 -0800 Subject: [PATCH 082/178] Merging current main to IterDomainGraphs (#1417) --- .clang-tidy | 1 + .github/workflows/build.yml | 1 + .github/workflows/lint.yml | 3 + .github/workflows/nvfuser-ci-trigger.yml | 2 +- CMakeLists.txt | 108 +- README.md | 40 + benchmark/matmul.cpp | 75 +- cmake/Dependencies.cmake | 36 +- cmake/FlatBuffers.cmake | 1 - csrc/codegen.cpp | 283 +- csrc/compute_at_map.cpp | 36 +- .../analysis/sync_information.cpp | 3 +- csrc/device_lower/lower2device.cpp | 6 +- csrc/device_lower/pass/fusion_simplifier.cpp | 14 +- csrc/device_lower/pass/fusion_simplifier.h | 4 +- csrc/device_lower/pass/index.cpp | 24 +- csrc/device_lower/pass/index.h | 3 +- csrc/device_lower/pass/inline_ptx.cpp | 147 +- csrc/device_lower/pass/predicate.cpp | 45 +- csrc/device_lower/validation.cpp | 145 +- csrc/disjoint_set.h | 2 +- csrc/dynamic_transform.cpp | 23 +- csrc/executor.cpp | 64 +- csrc/executor_params.cpp | 14 + csrc/executor_params.h | 2 + csrc/executor_utils.cpp | 6 +- csrc/fusion.cpp | 8 +- csrc/fusion.h | 15 +- csrc/fusion_profiler.cpp | 84 +- csrc/fusion_segmenter.cpp | 40 +- csrc/index_compute.cpp | 52 +- csrc/index_compute.h | 3 +- csrc/ir/builder.cpp | 10 + csrc/ir/builder.h | 1 + csrc/ir/interface_nodes.h | 4 +- csrc/ir/internal_base_nodes.h | 19 +- csrc/ir/internal_nodes.h | 48 +- csrc/ir/iostream.cpp | 2 +- csrc/ir/nodes.cpp | 176 +- csrc/ir/printer.h | 2 +- csrc/iter_visitor.cpp | 30 +- csrc/iter_visitor.h | 30 + csrc/kernel.cpp | 10 + csrc/kernel.h | 11 +- csrc/kernel_cache.cpp | 12 +- csrc/kernel_ir.cpp | 76 +- csrc/kernel_ir.h | 11 + csrc/linked_hash_map.h | 37 +- csrc/macros.h | 6 + csrc/mma_type.cpp | 200 +- csrc/mma_type.h | 290 +- csrc/multidevice/communication.cpp | 106 +- csrc/multidevice/communication.h | 74 +- csrc/multidevice/communicator.cpp | 52 +- csrc/multidevice/communicator.h | 45 +- csrc/multidevice/device_mesh.h | 9 - csrc/multidevice/lower_communication.cpp | 56 +- csrc/multidevice/utils.cpp | 31 + csrc/multidevice/utils.h | 19 + csrc/ops/arith.cpp | 3 +- csrc/optimization/alias_analysis.cpp | 337 +- csrc/optimization/alias_analysis.h | 23 +- csrc/optimization/mark_alias.cpp | 66 +- csrc/polymorphic_value.h | 12 + csrc/predicate_compute.cpp | 8 +- csrc/python_frontend/fusion_cache.cpp | 266 +- csrc/python_frontend/fusion_cache.h | 25 +- csrc/python_frontend/fusion_definition.cpp | 6 +- csrc/python_frontend/fusion_record.h | 312 +- csrc/python_frontend/fusion_state.cpp | 8 +- csrc/python_frontend/python_bindings.cpp | 449 ++- .../test/test_nvfuser_fusion_cache.cpp | 10 +- .../test/test_nvfuser_fusion_definition.cpp | 10 +- .../test/test_nvfuser_fusion_record.cpp | 32 +- csrc/root_domain_map.cpp | 4 +- csrc/scheduler/cache_policy_refiner.cpp | 11 +- csrc/scheduler/matmul.cpp | 966 +++-- csrc/scheduler/matmul_heuristic.h | 4 +- csrc/scheduler/matmul_utils.cpp | 64 +- csrc/scheduler/mma_utils.cpp | 874 ++-- csrc/scheduler/mma_utils.h | 73 +- csrc/scheduler/normalization_inner.cpp | 22 +- csrc/scheduler/normalization_inner_outer.cpp | 27 +- csrc/scheduler/normalization_outer.cpp | 22 +- csrc/scheduler/normalization_utils.cpp | 3 +- csrc/scheduler/normalization_utils.h | 1 + csrc/scheduler/reduction_heuristic.h | 3 +- csrc/scheduler/registry_utils.cpp | 7 +- csrc/serde/factory.h | 21 +- csrc/serde/fusion_cache.fbs | 18 +- csrc/serde/fusion_record_serde.cpp | 485 ++- csrc/serde/fusion_record_serde.h | 5 +- csrc/serde/polymorphic_value_serde.cpp | 37 +- csrc/serde/polymorphic_value_serde.h | 13 +- csrc/tensor_view.cpp | 37 +- csrc/transform_replay.cpp | 81 +- csrc/transform_view.cpp | 32 +- csrc/type.cpp | 5 + csrc/type.h | 9 +- lib/dynamic_type/CMakeLists.txt | 6 +- python_tests/pytest_input_generators.py | 20 +- python_tests/pytest_opinfos.py | 36 +- python_tests/test_python_frontend.py | 272 +- runtime/memory.cu | 128 +- runtime/tensorcore.cu | 270 -- test/multidevice.cpp | 112 +- test/multidevice.h | 8 +- test/test_alias.cpp | 401 +- test/test_allocation_domain.cpp | 96 +- test/test_dynamic_transform.cpp | 90 +- test/test_external_src.cpp | 2 +- test/test_fusion_profiler.cpp | 16 +- test/test_gather.cpp | 55 +- test/test_gpu1.cpp | 79 +- test/test_gpu2.cpp | 57 +- test/test_gpu3.cpp | 155 - test/test_gpu_compute_with.cpp | 8 +- test/test_gpu_indexing.cpp | 85 +- test/test_gpu_tensorcore.cpp | 3498 ++++------------- test/test_gpu_utils.cpp | 2 +- test/test_gpu_view.cpp | 247 +- test/test_linked_hash_map.cpp | 61 + test/test_loop_rotation.cpp | 84 +- test/test_matmul_sass.cpp | 32 +- test/test_matmul_scheduler.cpp | 1318 +++++-- test/test_memory.cpp | 119 +- test/test_mma.cpp | 173 + test/test_multidevice_communications.cpp | 113 +- test/test_multidevice_pipeline.cpp | 7 + test/test_no_op.cpp | 27 +- test/test_optimization_pass.cpp | 40 +- test/test_pointwise.cpp | 179 + test/test_resize.cpp | 169 +- test/test_scalar_hoisting.cpp | 21 +- test/test_swizzle.cpp | 17 +- test/test_tensor_factories.cpp | 96 +- test/utils.cpp | 128 +- test/utils.h | 31 +- test/validator.h | 23 +- tools/codediff/diff_report.py | 35 +- tools/codediff/run_command.sh | 2 + .../codediff/templates/command_env_info.html | 16 +- version.txt | 2 +- 143 files changed, 7486 insertions(+), 7898 deletions(-) create mode 100644 README.md create mode 100644 csrc/multidevice/utils.cpp create mode 100644 csrc/multidevice/utils.h delete mode 100644 runtime/tensorcore.cu create mode 100644 test/test_mma.cpp create mode 100644 test/test_pointwise.cpp diff --git a/.clang-tidy b/.clang-tidy index e641c2a41de..591a39a4062 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -40,6 +40,7 @@ modernize-*, -modernize-use-using, -modernize-use-trailing-return-type, -modernize-use-nodiscard, +-modernize-loop-convert, performance-*, -performance-noexcept-move-constructor, -performance-unnecessary-value-param, diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 18cd198d203..8a05c3aaa26 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,6 +30,7 @@ jobs: tools/apt-install-things.sh & tools/pip-install-things.sh & source tools/setup-env.sh + sudo rm -rf /usr/lib/gcc/x86_64-linux-gnu/13 export CC=clang export CXX=clang++ wait diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 79548759c42..3e64b3d7cea 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -30,6 +30,9 @@ jobs: tools/pip-install-things.sh & source tools/setup-env.sh + # clang-tidy does not work well with gcc-13 headers, remove them + sudo rm -rf /usr/lib/gcc/x86_64-linux-gnu/13 + # Install lintrunner pip install lintrunner diff --git a/.github/workflows/nvfuser-ci-trigger.yml b/.github/workflows/nvfuser-ci-trigger.yml index 5c8bd273abf..8d95a00f5c5 100644 --- a/.github/workflows/nvfuser-ci-trigger.yml +++ b/.github/workflows/nvfuser-ci-trigger.yml @@ -16,7 +16,7 @@ jobs: # This job only runs for pull request comments if: | - contains(',xwang233,jjsjann123,chang-l,csarofeen,drzejan2,IvanYashchuk,jacobhinkle,kevinstephano,liqiangxl,mmigdal-nv,naoyam,ptrblck,rdspring1,samnordmann,zasdfgbnm,crcrpar,nWEIdia,Priya2698,wujingyue,tfogal,protonu,', format(',{0},', github.actor)) && + contains(',xwang233,jjsjann123,chang-l,csarofeen,drzejan2,IvanYashchuk,jacobhinkle,kevinstephano,liqiangxl,mmigdal-nv,naoyam,ptrblck,rdspring1,samnordmann,zasdfgbnm,crcrpar,nWEIdia,Priya2698,wujingyue,tfogal,protonu,cowanmeg,', format(',{0},', github.actor)) && startsWith(github.event.comment.body, '!build') steps: - name: Check if comment is issued by authorized person diff --git a/CMakeLists.txt b/CMakeLists.txt index dd4b140f5ca..1d6dfdc73c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,13 @@ option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF) option(NVFUSER_BUILD_WITH_ASAN "Build nvFuser with asan" OFF) if(NOT NVFUSER_CPP_STANDARD) - set(NVFUSER_CPP_STANDARD 17) + set(NVFUSER_CPP_STANDARD 20) +endif() + +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 11.4) + message(FATAL_ERROR "GCC < 11.4 has compiler bugs and can not compile nvFuser.") + endif() endif() string(APPEND CMAKE_CXX_FLAGS " -Wno-psabi") @@ -31,12 +37,10 @@ set(ATEN_CUDA_ROOT "${TORCH_INSTALL_PREFIX}/include/ATen") string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS}") include(cmake/FlatBuffers.cmake) -if(BUILD_NVFUSER_BENCHMARK) - include(cmake/Dependencies.cmake) -endif() +include(cmake/Dependencies.cmake) # set CUDA_ARCH for cu tests. -if(EXISTS ${TORCH_CUDA_ARCH_LIST}) +if(TORCH_CUDA_ARCH_LIST) set(ARCH_FLAGS) cuda_select_nvcc_arch_flags(ARCH_FLAGS ${TORCH_CUDA_ARCH_LIST}) list(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS}) @@ -139,6 +143,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/multidevice/pipeline.cpp ${NVFUSER_SRCS_DIR}/multidevice/pipeline_ir.cpp ${NVFUSER_SRCS_DIR}/multidevice/runtime.cpp + ${NVFUSER_SRCS_DIR}/multidevice/utils.cpp ${NVFUSER_SRCS_DIR}/mutator.cpp ${NVFUSER_SRCS_DIR}/non_divisible_split.cpp ${NVFUSER_SRCS_DIR}/ops/alias.cpp @@ -216,44 +221,45 @@ endif() set(NVFUSER_CODEGEN ${PROJECT_NAME}_codegen) add_library(${NVFUSER_CODEGEN} SHARED ${NVFUSER_SRCS}) -target_compile_options(${NVFUSER_CODEGEN} PRIVATE -Wall -Wno-unused-function) -target_compile_options(${NVFUSER_CODEGEN} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") if(NOT MSVC) - target_compile_options(${NVFUSER_CODEGEN} PRIVATE -Werror) + target_compile_options(${NVFUSER_CODEGEN} PRIVATE + -Wall -Wno-unused-function + -Werror + ) endif() -# NB: This must be target_compile_definitions, not target_compile_options, -# as the latter is not respected by nvcc target_compile_definitions(${NVFUSER_CODEGEN} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") - -# Link flatbuffers for serialization support -target_link_libraries(${NVFUSER_CODEGEN} PRIVATE flatbuffers) - -# For kernel_db, linking STL Filesystem Library for backward compatability with C++14 -target_link_libraries(${NVFUSER_CODEGEN} PRIVATE stdc++fs) -target_link_libraries(${NVFUSER_CODEGEN} PRIVATE dynamic_type) -target_link_libraries(${NVFUSER_CODEGEN} PRIVATE ${CUDA_NVRTC_LIB} ${LIBNVTOOLSEXT} ${LIBCUPTI}) -list(APPEND CUDA_INCLUDE_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/include/) -target_include_directories(${NVFUSER_CODEGEN} PRIVATE ${CUDA_INCLUDE_DIRS}) - -# TODO: we should guard the include of gloo -target_include_directories(${NVFUSER_CODEGEN} PRIVATE ${PROJECT_SOURCE_DIR}/third_party/gloo) +target_include_directories(${NVFUSER_CODEGEN} SYSTEM PRIVATE + ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include + ${CMAKE_SOURCE_DIR}/third_party/gloo # TODO: guard this on usage + ${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/include + ${CUDA_INCLUDE_DIRS} +) target_include_directories(${NVFUSER_CODEGEN} PUBLIC "$" "$" ) set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD ${NVFUSER_CPP_STANDARD}) +# Ensure we don't link against libcuda; we'll dlopen it ourselves. +list(FILTER TORCH_LIBRARIES EXCLUDE REGEX "libcuda\.so") +target_link_libraries(${NVFUSER_CODEGEN} PRIVATE + flatbuffers + stdc++fs # for compatibility with C++14 + dynamic_type + ${CUDA_NVRTC_LIB} + ${LIBNVTOOLSEXT} + ${LIBCUPTI} + ${TORCH_LIBRARIES} + dl +) + if(NVFUSER_BUILD_WITH_ASAN) add_compile_options(-fsanitize=address) add_link_options(-fsanitize=address) endif() -list(FILTER TORCH_LIBRARIES EXCLUDE REGEX "libcuda\.so") - -target_link_libraries(${NVFUSER_CODEGEN} PRIVATE torch ${TORCH_LIBRARIES} dl) - # this is to find pip installed nvrtc/nvtx .so set_target_properties(${NVFUSER_CODEGEN} PROPERTIES INSTALL_RPATH "$ORIGIN/../../nvidia/cuda_runtime/lib:$ORIGIN/../../nvidia/cuda_nvrtc/lib:$ORIGIN/../../nvidia/nvtx/lib:$ORIGIN/../../nvidia/cuda_cupti/lib:$ORIGIN/../../torch/lib") @@ -266,7 +272,7 @@ add_custom_command( DEPENDS ${NVFUSER_ROOT}/csrc/serde/fusion_cache.fbs DEPENDS flatc - COMMAND ${CMAKE_CURRENT_BINARY_DIR}/third_party/flatbuffers/flatc --gen-object-api -o ${NVFUSER_ROOT}/csrc/serde/ -c -b ${NVFUSER_ROOT}/csrc/serde/fusion_cache.fbs + COMMAND ${CMAKE_CURRENT_BINARY_DIR}/third_party/flatbuffers/flatc --scoped-enums -o ${NVFUSER_ROOT}/csrc/serde/ -c -b ${NVFUSER_ROOT}/csrc/serde/fusion_cache.fbs COMMENT "Generating fusion_cache_generated header from fusion_cache.fbs" VERBATIM ) @@ -319,6 +325,9 @@ if(BUILD_PYTHON) set(NVFUSER "${PROJECT_NAME}") add_library(${NVFUSER} MODULE ${NVFUSER_PYTHON_SRCS}) set_property(TARGET ${NVFUSER} PROPERTY CXX_STANDARD ${NVFUSER_CPP_STANDARD}) + target_include_directories(${NVFUSER} SYSTEM PRIVATE + ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include + ) # setup python API version add_custom_command( @@ -344,6 +353,7 @@ if(BUILD_PYTHON) target_compile_definitions(${NVFUSER} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") if(NOT MSVC) + target_compile_options(${NVFUSER} PRIVATE -Wall -Wno-unused-function) target_compile_options(${NVFUSER} PRIVATE -Werror) set_target_properties(${NVFUSER} PROPERTIES SUFFIX ".so") else() @@ -355,7 +365,6 @@ if(BUILD_PYTHON) target_compile_definitions(${NVFUSER} PRIVATE EXTENSION_NAME=_C) - target_compile_options(${NVFUSER} PRIVATE -Wall -Wno-unused-function) target_link_libraries(${NVFUSER} PRIVATE ${TORCH_LIBRARIES}) target_link_libraries(${NVFUSER} PRIVATE "${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so") target_link_libraries(${NVFUSER} PRIVATE dynamic_type) @@ -403,6 +412,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/test/test_matmul_scheduler.cpp ${NVFUSER_ROOT}/test/test_mbarrier.cpp ${NVFUSER_ROOT}/test/test_memory.cpp + ${NVFUSER_ROOT}/test/test_mma.cpp ${NVFUSER_ROOT}/test/test_gpu_view.cpp ${NVFUSER_ROOT}/test/test_gpu_transpose.cpp ${NVFUSER_ROOT}/test/test_gpu_utils.cpp @@ -420,6 +430,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/test/test_scalar_hoisting.cpp ${NVFUSER_ROOT}/test/test_no_op.cpp ${NVFUSER_ROOT}/test/test_linked_hash_map.cpp + ${NVFUSER_ROOT}/test/test_pointwise.cpp ) # We don't link CUPTI for MSVC @@ -448,14 +459,17 @@ if(BUILD_TEST) add_executable(${NVFUSER_TESTS} ${JIT_TEST_SRCS}) set_property(TARGET ${NVFUSER_TESTS} PROPERTY CXX_STANDARD ${NVFUSER_CPP_STANDARD}) target_compile_definitions(${NVFUSER_TESTS} PRIVATE USE_GTEST) - target_link_libraries(${NVFUSER_TESTS} PRIVATE ${NVFUSER_CODEGEN} ${NVFUSER_TESTS_KERNELS} dynamic_type gtest_main gmock_main flatbuffers) target_include_directories(${NVFUSER_TESTS} PRIVATE "${NVFUSER_ROOT}") - - target_compile_options(${NVFUSER_TESTS} PRIVATE -Wall -Wno-unused-function) - target_link_libraries(${NVFUSER_TESTS} PRIVATE ${TORCH_LIBRARIES}) + target_include_directories(${NVFUSER_TESTS} SYSTEM PRIVATE + ${CMAKE_SOURCE_DIR}/third_party/googletest/googletest/include + ${CMAKE_SOURCE_DIR}/third_party/googletest/googlemock/include + ) + target_link_libraries(${NVFUSER_TESTS} PRIVATE ${NVFUSER_CODEGEN} ${NVFUSER_TESTS_KERNELS} dynamic_type GTest::gtest_main GTest::gmock_main flatbuffers ${TORCH_LIBRARIES}) if(NOT MSVC) - set_property(SOURCE ${JIT_TEST_SRCS} APPEND PROPERTY COMPILE_OPTIONS "-Werror") + target_compile_options(${NVFUSER_TESTS} PRIVATE + -Wall -Wno-unused-function -Werror + ) endif() endif() @@ -501,17 +515,26 @@ if(BUILD_NVFUSER_BENCHMARK) add_executable(${NVFUSER_BENCHMARK} ${BENCHMARK_SRCS}) set_property(TARGET ${NVFUSER_BENCHMARK} PROPERTY CXX_STANDARD ${NVFUSER_CPP_STANDARD}) - target_compile_options(${NVFUSER_BENCHMARK} PRIVATE -Wall -Wno-unused-function) - target_link_libraries(${NVFUSER_BENCHMARK} PRIVATE dynamic_type) - target_link_libraries(${NVFUSER_BENCHMARK} PRIVATE ${TORCH_LIBRARIES}) - target_link_libraries(${NVFUSER_BENCHMARK} PRIVATE benchmark::benchmark) + target_include_directories(${NVFUSER_BENCHMARK} SYSTEM PRIVATE + ${CMAKE_SOURCE_DIR}/third_party/benchmark/include + ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include + ${CMAKE_SOURCE_DIR}/third_party/googletest/googletest/include + ) + target_include_directories(${NVFUSER_BENCHMARK} PUBLIC ${NVFUSER_ROOT}) + target_link_libraries(${NVFUSER_BENCHMARK} PRIVATE + dynamic_type + ${TORCH_LIBRARIES} + benchmark::benchmark + ${NVFUSER_CODEGEN} + ) + add_dependencies(${NVFUSER_BENCHMARK} flatc build_flatbuffer_config) if(NOT MSVC) - target_compile_options(${NVFUSER_BENCHMARK} PRIVATE -Werror -Wno-deprecated-copy) + target_compile_options(${NVFUSER_BENCHMARK} PRIVATE + -Wall -Wno-unused-function + -Werror -Wno-deprecated-copy + ) endif() - - target_link_libraries(${NVFUSER_BENCHMARK} PRIVATE ${NVFUSER_CODEGEN}) - target_include_directories(${NVFUSER_BENCHMARK} PRIVATE ${NVFUSER_ROOT}) endif() # --- generate runtime files @@ -542,7 +565,6 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/memory.cu ${NVFUSER_ROOT}/runtime/random_numbers.cu ${NVFUSER_ROOT}/runtime/tensor.cu - ${NVFUSER_ROOT}/runtime/tensorcore.cu ${NVFUSER_ROOT}/runtime/tuple.cu ${NVFUSER_ROOT}/runtime/type_traits.cu ${NVFUSER_ROOT}/runtime/warp.cu diff --git a/README.md b/README.md new file mode 100644 index 00000000000..2d1b313f1f0 --- /dev/null +++ b/README.md @@ -0,0 +1,40 @@ +# Fuser + +A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser") + +## Installation + +We publish nightly wheel packages on https://pypi.nvidia.com + +built-env | cuda 11.8 | cuda 12.1 +:---: | :---: | :---: +torch 2.1 | nvfuser-cu118-torch21 | nvfuser-cu121-torch21 +torch nightly wheel | nvfuser-cu118 | nvfuser-cu121 + +Note that nvfuser built against torch-2.1 isn't compatible with nightly pytorch wheel, so ensure you pick the right version suiting your environment. + +You can instll a given nvfuser version with `pip install --pre nvfuser-cu121 --extra-index-url https://pypi.nvidia.com` + +As we build against nightly torch wheel and there's no compatibility promised on nightly wheels, we have explicitly marked the nightly torch wheel as an optinoal dependency. You can choose to install the torch wheel along with nvfuser package. e.g. +`pip install --pre "nvfuser-cu121[torch]" --extra-index-url https://pypi.nvidia.com`. +Note that this may uninstall your local pytorch installation and install the compatible nightly pytorch. + +Versioned nvfuser will be published on pypi.org [WIP] + +PyPI: [https://pypi.org/project/nvfuser/](https://pypi.org/search/?q=nvfuser) + + +## Developer + +Getting started: https://github.com/NVIDIA/Fuser/wiki/Getting-started +Build: https://github.com/NVIDIA/Fuser/wiki/Building-fuser-project + +Supported compilers: +- gcc 11.4+ +- clang14+ + +Supported C++ standard: +- C++17 +- C++20 + +We are actively considering dropping C++17 support diff --git a/benchmark/matmul.cpp b/benchmark/matmul.cpp index 2605eb29eb1..7a7a5f9746e 100644 --- a/benchmark/matmul.cpp +++ b/benchmark/matmul.cpp @@ -40,14 +40,11 @@ bool hasRequiredSmemSize(size_t required_size) { return; \ } -// util to track support matmul operand layout. -using MatmulLayout = MmaOptions::MmaLayout; - // TODO: separate compute and schedule definition once the can schedule // logic and pattern matching is ready. void setupMatmul( Fusion* fusion, - MatmulLayout layout, + MmaLayout layout, MatmulParams params, bool turing_or_later // TODO: This is a temporary solution. Remove this! ) { @@ -122,7 +119,7 @@ void checkMatch(at::Tensor expect, at::Tensor result, int64_t k) { static void SingleMatmulBase( benchmark::State& benchmark_state, - MatmulLayout layout, + MmaLayout layout, MatmulParams params) { std::vector input_mnk{ benchmark_state.range(0), @@ -184,7 +181,7 @@ static void SingleMatmulBase( static void Baseline_Matmul( benchmark::State& benchmark_state, - MatmulLayout layout) { + MmaLayout layout) { std::vector input_mnk{ benchmark_state.range(0), benchmark_state.range(1), @@ -221,7 +218,7 @@ size_t getSmemSize(GemmTile cta_tile, int stage_number) { MatmulParams getMatmulParams( GemmTile cta_tile, int stage_number, - MatmulLayout layout, + MmaLayout layout, int splitk_factor = 1) { MatMulTileOptions gemm_tile; gemm_tile.cta_tile = cta_tile; @@ -230,7 +227,7 @@ MatmulParams getMatmulParams( gemm_tile.instruction_tile = GemmTile(16, 16, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16; + params.mma_macro = MmaMacro::Ampere_16_16_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -274,7 +271,7 @@ int computeAutoSplitKFactor( // for comparing against the first kernel in Cutlass's two-kernel split-K. static void SingleMatmulPartitionedK( benchmark::State& benchmark_state, - MatmulLayout layout, + MmaLayout layout, MatmulParams params, int64_t splitk_factor) { int64_t M = benchmark_state.range(0); @@ -353,7 +350,7 @@ static void SingleMatmulPartitionedK( static void NvFuserScheduler_Matmul( benchmark::State& benchmark_state, - MatmulLayout layout, + MmaLayout layout, int splitk_factor = 1, bool partitionedk = false) { int num_warps = benchmark_state.range(3); @@ -512,7 +509,7 @@ static std::vector splitKNs(long int tileN = 128) { { 65536 } #define Layouts \ - { MatmulLayout::TT, MatmulLayout::TN, MatmulLayout::NT, MatmulLayout::NN } + { MmaLayout::TT, MmaLayout::TN, MmaLayout::NT, MmaLayout::NN } #define NumWarps \ { 4, 8 } #define NumStages \ @@ -592,43 +589,43 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { } } -#define EagerModeBenchmark(layout) \ - BENCHMARK_CAPTURE( \ - Baseline_Matmul, eagermode_legacyshapes_##layout, MatmulLayout::layout) \ - ->Unit(benchmark::kMicrosecond) \ - ->UseManualTime() \ - ->Apply([](benchmark::internal::Benchmark* b) { \ - return MatmulShape( \ - b, sizeProduct(LegacyMs, LegacyNs, LegacyKs)); \ - }); \ - BENCHMARK_CAPTURE( \ - Baseline_Matmul, eagermode_timmshapes_##layout, MatmulLayout::layout) \ - ->Unit(benchmark::kMicrosecond) \ - ->UseManualTime() \ - ->Apply([](benchmark::internal::Benchmark* b) { \ - return MatmulShape(b, TIMMShapes); \ - }); \ - BENCHMARK_CAPTURE( \ - Baseline_Matmul, eagermode_splitkshapes_##layout, MatmulLayout::layout) \ - ->Unit(benchmark::kMicrosecond) \ - ->UseManualTime() \ - ->Apply([](benchmark::internal::Benchmark* b) { \ - return MatmulShape( \ - b, sizeProduct(SplitKMs, splitKNs(), SplitKKs)); \ +#define EagerModeBenchmark(layout) \ + BENCHMARK_CAPTURE( \ + Baseline_Matmul, eagermode_legacyshapes_##layout, MmaLayout::layout) \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime() \ + ->Apply([](benchmark::internal::Benchmark* b) { \ + return MatmulShape( \ + b, sizeProduct(LegacyMs, LegacyNs, LegacyKs)); \ + }); \ + BENCHMARK_CAPTURE( \ + Baseline_Matmul, eagermode_timmshapes_##layout, MmaLayout::layout) \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime() \ + ->Apply([](benchmark::internal::Benchmark* b) { \ + return MatmulShape(b, TIMMShapes); \ + }); \ + BENCHMARK_CAPTURE( \ + Baseline_Matmul, eagermode_splitkshapes_##layout, MmaLayout::layout) \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime() \ + ->Apply([](benchmark::internal::Benchmark* b) { \ + return MatmulShape( \ + b, sizeProduct(SplitKMs, splitKNs(), SplitKKs)); \ }); #define NvfuserMatmulBenchmark(layout) \ BENCHMARK_CAPTURE( \ NvFuserScheduler_Matmul, \ nvfuser_nosplitk_legacyshapes_##layout, \ - MatmulLayout::layout) \ + MmaLayout::layout) \ ->Unit(benchmark::kMicrosecond) \ ->UseManualTime() \ ->Apply(MatmulShapeWarpStageAutoSplitK); \ BENCHMARK_CAPTURE( \ NvFuserScheduler_Matmul, \ nvfuser_nosplitk_timmshapes_##layout, \ - MatmulLayout::layout) \ + MmaLayout::layout) \ ->Unit(benchmark::kMicrosecond) \ ->UseManualTime() \ ->Apply([](benchmark::internal::Benchmark* b) { \ @@ -637,7 +634,7 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { BENCHMARK_CAPTURE( \ NvFuserScheduler_Matmul, \ nvfuser_nosplitk_splitkshapes_##layout, \ - MatmulLayout::layout) \ + MmaLayout::layout) \ ->Unit(benchmark::kMicrosecond) \ ->UseManualTime() \ ->Apply([](benchmark::internal::Benchmark* b) { \ @@ -655,7 +652,7 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { BENCHMARK_CAPTURE( \ NvFuserScheduler_Matmul, \ nvfuser_auto_splitk_##layout, \ - MatmulLayout::layout, \ + MmaLayout::layout, \ -1) \ ->Unit(benchmark::kMicrosecond) \ ->UseManualTime() \ @@ -665,7 +662,7 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { BENCHMARK_CAPTURE( \ NvFuserScheduler_Matmul, \ nvfuser_auto_partitionedk_##layout, \ - MatmulLayout::layout, \ + MmaLayout::layout, \ -1, \ true) \ ->Unit(benchmark::kMicrosecond) \ diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 343c52387b8..28475b16503 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -12,18 +12,7 @@ set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE) set(INSTALL_GTEST OFF CACHE BOOL "Install gtest." FORCE) set(BUILD_GMOCK ON CACHE BOOL "Build gmock." FORCE) -# Add googletest subdirectory but make sure our INCLUDE_DIRECTORIES do not bleed into it. -# This is because libraries installed into the root conda env (e.g. MKL) add a global /opt/conda/include directory, -# and if there is gtest installed in conda, the third_party/googletest/**.cc source files would try to include headers -# from /opt/conda/include/gtest/**.h instead of its own. Once we have proper target-based include directories, -# this shouldn't be necessary anymore. -get_property(INC_DIR_temp DIRECTORY PROPERTY INCLUDE_DIRECTORIES) -set_property(DIRECTORY PROPERTY INCLUDE_DIRECTORIES "") -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest) -set_property(DIRECTORY PROPERTY INCLUDE_DIRECTORIES ${INC_DIR_temp}) - -include_directories(BEFORE SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest/googletest/include) -include_directories(BEFORE SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest/googlemock/include) +add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/googletest) # We will not need to test benchmark lib itself. set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark testing as we don't need it.") @@ -31,17 +20,16 @@ set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark testing as we don set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable benchmark install to avoid overwriting vendor install.") if(NOT USE_SYSTEM_BENCHMARK) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/benchmark) + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/benchmark) else() -add_library(benchmark SHARED IMPORTED) -find_library(BENCHMARK_LIBRARY benchmark) -if(NOT BENCHMARK_LIBRARY) - message(FATAL_ERROR "Cannot find google benchmark library") + add_library(benchmark SHARED IMPORTED) + find_library(BENCHMARK_LIBRARY benchmark) + if(NOT BENCHMARK_LIBRARY) + message(FATAL_ERROR "Cannot find google benchmark library") + endif() + message("-- Found benchmark: ${BENCHMARK_LIBRARY}") + set_property(TARGET benchmark PROPERTY IMPORTED_LOCATION ${BENCHMARK_LIBRARY}) endif() -message("-- Found benchmark: ${BENCHMARK_LIBRARY}") -set_property(TARGET benchmark PROPERTY IMPORTED_LOCATION ${BENCHMARK_LIBRARY}) -endif() -include_directories(${CMAKE_CURRENT_LIST_DIR}/../third_party/benchmark/include) # Recover build options. set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE) @@ -50,7 +38,7 @@ set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" F # Without this is cross compiling we end up having to blow build directory # and rebuild from scratch. if(CMAKE_CROSSCOMPILING) -if(COMPILE_HAVE_STD_REGEX) - set(RUN_HAVE_STD_REGEX 0 CACHE INTERNAL "Cache RUN_HAVE_STD_REGEX output for cross-compile.") -endif() + if(COMPILE_HAVE_STD_REGEX) + set(RUN_HAVE_STD_REGEX 0 CACHE INTERNAL "Cache RUN_HAVE_STD_REGEX output for cross-compile.") + endif() endif() diff --git a/cmake/FlatBuffers.cmake b/cmake/FlatBuffers.cmake index 83c97204197..e17728dec8d 100644 --- a/cmake/FlatBuffers.cmake +++ b/cmake/FlatBuffers.cmake @@ -8,4 +8,3 @@ option(FLATBUFFERS_BUILD_FLATHASH "Enable the build of flathash" OFF) # Add FlatBuffers directly to our build. This defines the `flatbuffers` target. add_subdirectory(${FlatBuffers_Src_Dir}) -include_directories(BEFORE SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../third_party/flatbuffers/include) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 90d8de1a678..6c1595ec75b 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -534,17 +534,34 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { void handle(const kir::TensorIndex* ti) final { bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global && kernel_->summary().sync_map->needsRawSync(ti->view()).hasBID(); + bool is_pointer = isPointerType(ti->index()->dtype()); + if (is_pointer) { + bool is_u32_ptr = ti->index()->dtype() == DataType::SMemAddress; + if (is_u32_ptr) { + // DataType::SMemAddress is implemented as uint32_t in C++. The problem + // for this implementation is, the type promotion rule in C++ for + // uint32_t mismatch with the type promotion rule for + // DataType::SMemAddress in nvFuser. As a workaround, we always cast to + // the correct type in the generated code. + code_ << "(uint32_t)("; + } + code_ << genInline(ti->index()); + if (is_u32_ptr) { + code_ << ")"; + } + return; + } bool different_dtype = ti->view()->dtype() != ti->dtype(); if (is_volatile) { code_ << "*(volatile " << ti->getDataType().value() << "*)&"; } if (different_dtype) { - code_ << "*reinterpret_cast<" << ti->getDataType().value() << "*>(&"; + code_ << "(*reinterpret_cast<" << ti->getDataType().value() << "*>(&"; } code_ << genVariableName(ti->view()) << "[" << genInline(ti->index()) << "]"; if (different_dtype) { - code_ << ")"; + code_ << "))"; } } @@ -560,28 +577,6 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { code_ << genVariableName(tv); } - // Utility function to emit a cp.async intrinsic - void genCpAsync(const LoadStoreOp* ldst, size_t vec_size) { - auto dtype = ldst->in()->getDataType().value(); - - bool is_cg = ldst->opType() == LoadStoreOpType::CpAsync && - ldst->cacheOp() == CacheOp::Global; - std::string name = (is_cg ? "Ampere::cpAsyncCg" : "Ampere::cpAsyncCa"); - - ArgumentBuilder template_args; - template_args.arg(dtype); - template_args.arg(vec_size); - - ArgumentBuilder func_args; - func_args.arg(genInline(ldst->out()->as()->index())); - func_args.arg(genInline(ldst->in()->as()->index())); - if (ldst->predicate() != nullptr) { - func_args.arg(genInline(ldst->predicate())); - } - - indent() << genCall(name, template_args, func_args) << ";\n"; - } - void genCpAsyncBulkTensorTile(const LoadStoreOp* ldst) { auto in = ldst->in()->as(); auto out = ldst->out()->as(); @@ -619,20 +614,6 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << genCall(func_name, func_args) << ";\n"; } - void genLdMatrix(const LoadStoreOp* ldst) { - auto dtype = ldst->in()->getDataType().value(); - - bool is_transpose = (ldst->opType() == LoadStoreOpType::LdMatrixTranspose); - std::string name = - (is_transpose ? "Turing::ldMatrixT" : "Turing::ldMatrix"); - - ArgumentBuilder func_args; - func_args.arg(gen(ldst->out())); - func_args.arg(genInline(ldst->in()->as()->index())); - - indent() << genCall(name, func_args) << ";\n"; - } - void handle(const GetMetaData* gop) final { if (print_inline_) { code_ << gen(gop->in()); @@ -714,6 +695,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { code_ << cast_str.value(); } else if (op_type == UnaryOpType::BitCast) { code_ << "std::bit_cast<" << uop->out()->dtype() << ">"; + } else if (op_type == UnaryOpType::RefCast) { + code_ << "(*reinterpret_cast<" << uop->out()->dtype() << "*>(&"; } else { code_ << op_type; if (needFloatSuffix(op_type) && @@ -723,6 +706,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } code_ << "(" << gen(uop->in()) << ")"; + if (op_type == UnaryOpType::RefCast) { + code_ << "))"; + } } if (!print_inline_) { @@ -1052,109 +1038,6 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } } - std::string genArchString(MmaOptions::MacroType macro) { - std::stringstream ss; - if (isVolta(macro)) { - ss << "Volta"; - } else if (isTuring(macro)) { - ss << "Turing"; - } else if (isAmpere(macro)) { - ss << "Ampere"; - } else { - NVF_ERROR(false, "mma macro unknown arch"); - } - return ss.str(); - } - - std::string genMmaOp(const MmaOp* mma, bool init = false) { - std::stringstream ss; - auto macro = mma->macro(); - ss << genArchString(macro) << "::"; - if (init) { - ss << "init"; - } - ss << toString(macro); - - // clang-tidy: bugprone-unchecked-optional-access - // clang-tidy assumes that function result is unstable, so we need a copy. - auto mma_layout_opt = mma->layout(); - NVF_ERROR(mma_layout_opt.has_value(), "mma unknown input layout"); - if (isTuring(macro) || isAmpere(macro)) { - NVF_ERROR( - mma_layout_opt == MmaOptions::MmaLayout::TN, - "MMAs in Turing and Ampere are TN only, transpose is handled either " - "via ldmatrix.trans for fp16 or explicitly for other types."); - } - if (!init) { - ss << toString(mma_layout_opt.value()); - } - if (!init && isAmpere(macro)) { - if (mma->inA()->getDataType().value() == DataType::Half) { - ss << "F16"; - } else { - ss << "BF16"; - } - } - return ss.str(); - } - - static int getInputARegisterSize(MmaOptions::MacroType macro) { - switch (macro) { - case MmaOptions::MacroType::Volta_16_16_4: - return 2; - case MmaOptions::MacroType::Turing_16_8_16: - case MmaOptions::MacroType::Turing_16_16_16: - case MmaOptions::MacroType::Ampere_16_8_16: - case MmaOptions::MacroType::Ampere_16_16_16: - return 4; - default: - NVF_ERROR(false, "unknown macro"); - break; - } - return -1; - } - - static int getInputBRegisterSize(MmaOptions::MacroType macro) { - switch (macro) { - case MmaOptions::MacroType::Volta_16_16_4: - case MmaOptions::MacroType::Turing_16_8_16: - case MmaOptions::MacroType::Ampere_16_8_16: - return 2; - case MmaOptions::MacroType::Turing_16_16_16: - case MmaOptions::MacroType::Ampere_16_16_16: - return 4; - default: - NVF_ERROR(false, "unknown macro"); - break; - } - return -1; - } - - void genMmaOperands(const MmaOp* mma) { - std::stringstream ss; - auto macro = mma->macro(); - auto in_a = mma->inA()->as()->view(); - auto dtype = in_a->getDataType().value(); - indent() << kTab << "(reinterpret_cast*>(&" - << genVariableName(mma->inA()->as()->view()) - << ")[" << genInline(mma->inA()->as()->index()) - << "])" - << ",\n"; - indent() << kTab << "(reinterpret_cast*>(&" - << genVariableName(mma->inB()->as()->view()) - << ")[" << genInline(mma->inB()->as()->index()) - << "])"; - } - - void handle(const MmaOp* mma) final { - indent() << genMmaOp(mma) << "(\n"; - indent() << kTab << gen(mma->out()) << ",\n"; - genMmaOperands(mma); - code_ << ");\n"; - } - std::string genReductionOp(BinaryOpType op_type, DataType data_type) { std::stringstream lambda; lambda << "[](" << data_type << " &a, " << data_type << " b) " @@ -1312,6 +1195,12 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { void handle(const LoadStoreOp* ldst) final { auto optype = ldst->opType(); + NVF_ERROR( + optype != LoadStoreOpType::LdMatrix && + optype != LoadStoreOpType::LdMatrixTranspose && + optype != LoadStoreOpType::CpAsync, + "ldmatrix and cp.async should be lowered as kir::Asm"); + if (ldst->out()->isA()) { auto out_ti = ldst->out()->as(); auto out_tv = out_ti->view(); @@ -1324,7 +1213,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { auto mma = dynamic_cast(out_tv->definition()); NVF_ERROR(mma != nullptr, "CodeGen: mma op not in mma loop"); NVF_ERROR(optype == LoadStoreOpType::Set); - indent() << genMmaOp(mma, true) << "(" << gen(ldst->out()) << ");\n"; + indent() << "(" << gen(ldst->out()) << ").set(0);\n"; return; } @@ -1339,25 +1228,6 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { "Vectorized store/load requires input and output datatypes match."); } - // dispatch ldmatrix - if (optype == LoadStoreOpType::LdMatrix || - optype == LoadStoreOpType::LdMatrixTranspose) { - NVF_ERROR(is_vector_op, "LdMatrix: Vectorization required: ", ldst); - genLdMatrix(ldst); - return; - } - - // dispatch cp.async - if (optype == LoadStoreOpType::CpAsync) { - if (ldst->cacheOp() == CacheOp::Global) { - NVF_ERROR( - is_vector_op && vector_word_size == 8, - "cp.async.cg only support vectorize 8"); - } - genCpAsync(ldst, vector_word_size); - return; - } - // dispatch cp.async.bulk.tensor.tile if (optype == LoadStoreOpType::CpAsyncBulkTensorTile) { genCpAsyncBulkTensorTile(ldst); @@ -2875,17 +2745,62 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { if (asm_->volatile_()) { code_ << " volatile"; } - bool multiline = + bool multiline = asm_->hasBooleanInput() || (asm_->code().size() + (asm_->inputs().size() + asm_->outputs().size()) * 5 > 80); + if (!multiline) { + // If any of the operand is an array type, force using multiline + for (const auto& l : + std::array>, 2>{ + asm_->inputs(), asm_->outputs()}) { + for (const auto& v : l.get()) { + if (std::holds_alternative(v->dtype().type)) { + multiline = true; + break; + } + } + } + } code_ << "("; if (multiline) { code_ << "\n"; block_nest_level_++; indent(); } - code_ << "\"" << asm_->code() << "\\n\""; + + if (asm_->hasBooleanInput()) { + code_ << "\"{\\n\"\n"; + int64_t boolean_counter = 0; + int64_t counter = 0; + for (auto input : asm_->inputs()) { + if (input->dtype() == DataType::Bool) { + indent() << "\" .reg .pred p" << boolean_counter << "; \\n\"\n"; + indent() << "\" setp.ne.b32 p" << boolean_counter << ", %" << counter + << ", 0;\\n\"\n"; + boolean_counter++; + } + if (std::holds_alternative(input->dtype().type)) { + counter += (int64_t)std::get(input->dtype().type).size; + } else { + counter++; + } + } + indent() << "\" " << asm_->code(); + } else { + code_ << "\"" << asm_->code(); + } + + auto parameters = asm_->parameters(); + if (!parameters.empty()) { + code_ << " " << parameters; + } + code_ << R"(;\n")"; + + if (asm_->hasBooleanInput()) { + code_ << "\n"; + indent() << R"("}\n")"; + } auto next_section = [&]() { if (multiline) { @@ -2899,15 +2814,40 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { [&](const auto& constraints_and_registers) { bool first = true; for (auto [constraint, register_] : constraints_and_registers) { - if (!first) { - code_ << ", "; + auto next_line = [&]() { + code_ << ","; if (multiline) { code_ << "\n"; - indent(); + indent() << " "; + } else { + code_ << " "; } + }; + if (!first) { + next_line(); } first = false; - code_ << "\"" << constraint << "\"(" << gen(register_) << ")"; + if (std::holds_alternative(register_->dtype().type)) { + for (auto i : c10::irange( + std::get(register_->dtype().type).size)) { + if (i > 0) { + next_line(); + } + code_ << "\"" << constraint << "\"(" << gen(register_) << "[" + << i << "]" + << ")"; + } + } else { + code_ << "\"" << constraint << "\"("; + if (register_->dtype() == DataType::Bool) { + code_ << "(uint32_t)("; + } + code_ << gen(register_); + if (register_->dtype() == DataType::Bool) { + code_ << ")"; + } + code_ << ")"; + } } }; @@ -2945,15 +2885,6 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } } - void handle(const kir::CpAsyncBulkS2GWait* cpasync_wait) final { - indent() << "Hopper::cpAsyncBulkS2GPartialReadBarrier<" - << cpasync_wait->keepStages() << ">();\n"; - } - - void handle(const kir::CpAsyncBulkS2GCommit* cpasync_wait) final { - indent() << "Hopper::cpAsyncBulkS2GCommit();\n"; - } - void handle(const kir::GridSync* sync) final { // Use a custom synchronization method if enabled bool bidx = sync->syncDims().get(ParallelType::BIDx); diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index cf479dfed83..6955309c072 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -806,22 +806,34 @@ void ComputeAtMap::allocateIndexVariables() { // and we only need one index variable for each set. for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) { ParallelType ptype = ParallelType::Serial; + + // We don't allocate any index variable for domains which + // are parallelized accross devices + if (auto result = std::find_if( + loop_disjoint_set->vector().begin(), + loop_disjoint_set->vector().end(), + [](IterDomain* id) { return id->isDeviceDim(); }); + result != loop_disjoint_set->vector().end()) { + loop_index_variable_map_[loop_disjoint_set.get()] = fusion_->zeroVal(); + continue; + } + // first allocate thread and grid parallel indices: // The validation pass will check that the parallel bindings within the // loop nodes are consistent so all the loops within this disjoint set // will be realized implicitly using parallel index variables. - auto result = std::find_if( - loop_disjoint_set->vector().begin(), - loop_disjoint_set->vector().end(), - [](IterDomain* id) { - // Halo extended parallel loops currently are handled - // differently and an index variable would still - // be allocated in this case. - return id->isThread() && - (GpuLower::current()->haloInfo()->getExtent(id) == nullptr); - }); - if (result != loop_disjoint_set->vector().end()) { + if (auto result = std::find_if( + loop_disjoint_set->vector().begin(), + loop_disjoint_set->vector().end(), + [](IterDomain* id) { + // Halo extended parallel loops currently are handled + // differently and an index variable would still + // be allocated in this case. + return id->isThread() && + (GpuLower::current()->haloInfo()->getExtent(id) == nullptr); + }); + result != loop_disjoint_set->vector().end()) { ptype = (*result)->getParallelType(); loop_index_variable_map_[loop_disjoint_set.get()] = NamedScalar::getParallelIndex(ptype); @@ -1425,7 +1437,7 @@ std::string ComputeAtMap::toString() const { << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE); ss << "Permissive-Resize map:\n" << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE_RESIZE); - ss << "Permissive-Relaxed-Resize map:\n" + ss << "Innermost map:\n" << idGraphNodesToString(*this, IdMappingMode::INNERMOST); ss << "Consumer maps:\n"; for (auto key : getSortedKeys(id_graph_.consumers(), Statement::lessThan)) { diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 2f6912b119f..f7c922ad5e8 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -777,7 +777,8 @@ SyncMap::SyncMap(Fusion* fusion) { raw_dims.toString()); } else if (raw_dims.hasTID()) { NVF_ERROR( - producer->getMemoryType() == MemoryType::Global || + ir_utils::isLdMatrixOp(producer->definition()) || + producer->getMemoryType() == MemoryType::Global || producer->getMemoryType() == MemoryType::Shared, "Inconsistent parallelization found between TV", producer->name(), diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index bdcf5fb7ac1..dac936eb583 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -268,7 +268,7 @@ GpuLower::GpuLower(Fusion* fusion, const CompileParams& cparams) // printed in verbose mode of lowering. The function must take a // const std::vector& and return a std::vector. {{"LoopNestGenerator", LoopNestGenerator::loweredExprs}, - {"unarySetOpInserter", unarySetOpInserter}, + {"loadStoreOpInserter", loadStoreOpInserter}, {"insertAllocations", insertAllocations}, {"insertRawThreadSynchronization", insertRawThreadSynchronization}, {"reuseMemoryAllocations", reuseMemoryAllocations}, @@ -494,9 +494,7 @@ void GpuLower::analysis(Fusion* fusion) { dumpExprsIfEnabled(fusion_->exprs(), "build doubleBufferInfo", true); compute_at_map_->allocateIndexVariables(); - dumpExprsIfEnabled(fusion_->exprs(), "allocateIndexVariables", true); - // Run our passes keeping the lowered expressions and forwarding - // them + dumpExprsIfEnabled(fusion_->exprs(), "allocateIndexVariables"); } kir::Kernel* GpuLower::kernel() const { diff --git a/csrc/device_lower/pass/fusion_simplifier.cpp b/csrc/device_lower/pass/fusion_simplifier.cpp index e5eaf21f182..3ef8f973faf 100644 --- a/csrc/device_lower/pass/fusion_simplifier.cpp +++ b/csrc/device_lower/pass/fusion_simplifier.cpp @@ -16,18 +16,18 @@ namespace nvfuser { namespace { -// Replaces Transpose, Shift, Gather, and View Ops with Unary Ops. -class UnaryOpInserter : private kir::ExprMutator { +// Replaces Transpose, Shift, Gather, and View Ops with LoadStoreOps. +class LoadStoreOpInserter : private kir::ExprMutator { public: static std::vector insert(const std::vector& exprs) { - UnaryOpInserter inserter(exprs); + LoadStoreOpInserter inserter(exprs); return inserter.exprs_; } private: using kir::ExprMutator::handle; - UnaryOpInserter(const std::vector& exprs) { + LoadStoreOpInserter(const std::vector& exprs) { kir::ExprMutator::traverseAndInsert(exprs); } @@ -89,9 +89,9 @@ class UnaryOpInserter : private kir::ExprMutator { } // namespace -// Transpose, Shift, Gather, and View Ops with Unary Set Ops -std::vector unarySetOpInserter(const std::vector& exprs) { - return UnaryOpInserter::insert(exprs); +// Transpose, Shift, Gather, and View Ops with LoadStoreOps. +std::vector loadStoreOpInserter(const std::vector& exprs) { + return LoadStoreOpInserter::insert(exprs); } } // namespace nvfuser diff --git a/csrc/device_lower/pass/fusion_simplifier.h b/csrc/device_lower/pass/fusion_simplifier.h index 2612d215ec1..4ccd331089c 100644 --- a/csrc/device_lower/pass/fusion_simplifier.h +++ b/csrc/device_lower/pass/fusion_simplifier.h @@ -17,7 +17,7 @@ namespace nvfuser { -// Transpose, Shift, Gather, and View Ops with Unary Set Ops -std::vector unarySetOpInserter(const std::vector& exprs); +// Transpose, Shift, Gather, and View Ops with LoadStoreOps +std::vector loadStoreOpInserter(const std::vector& exprs); } // namespace nvfuser diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index cfc08db2d24..822dc1d5c72 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -24,7 +24,8 @@ Val* IndexLowering::lowerSrcIndex( Val* src, Val* dst, const std::unordered_map& override_index, - bool generate_pointer) const { + bool generate_pointer, + DataType as_type) const { if (auto tv = dynamic_cast(src)) { NVF_ERROR(dst->isA()); return Index::getProducerIndex( @@ -33,7 +34,8 @@ Val* IndexLowering::lowerSrcIndex( for_loops_, getRotatedLoop(), override_index, - generate_pointer); + generate_pointer, + as_type); } else { return src; } @@ -1352,6 +1354,18 @@ void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) { pushBack(IrBuilder::create(0)); } +static DataType getMmaInputAType(MmaMacro macro) { + int size = getM(macro) * getK(macro) / 32 /* threads per warp */ / + 2 /* halves per 32bit register */; + return ArrayType{std::make_shared(DataType::UInt32), (size_t)size}; +} + +static DataType getMmaInputBType(MmaMacro macro) { + int size = getN(macro) * getK(macro) / 32 /* threads per warp */ / + 2 /* halves per 32bit register */; + return ArrayType{std::make_shared(DataType::UInt32), (size_t)size}; +} + static inline DataType getMmaOutType(TensorView* mma_out) { int64_t size = 1; for (auto id : mma_out->getLeafDomain()) { @@ -1398,8 +1412,10 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { } void IndexLowering::handle(const MmaOp* mma) { - const auto a = lowerSrcIndex(mma->inA(), mma->out()); - const auto b = lowerSrcIndex(mma->inB(), mma->out()); + const auto a = lowerSrcIndex( + mma->inA(), mma->out(), {}, false, getMmaInputAType(mma->macro())); + const auto b = lowerSrcIndex( + mma->inB(), mma->out(), {}, false, getMmaInputBType(mma->macro())); const auto out = lowerDstIndex( mma->out(), {}, false, getMmaOutType(mma->out()->as())); auto mma_indexed = IrBuilder::create( diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index fbeb9776dd3..3983776a777 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -108,7 +108,8 @@ class IndexLowering : private OptOutConstDispatch { Val* val, Val* dst, const std::unordered_map& override_index = {}, - bool generate_pointer = false) const; + bool generate_pointer = false, + DataType as_type = DataType::Null) const; Val* lowerDstIndex( Val* dst, diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 6b4bcd63418..134d172e6c4 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -7,9 +7,13 @@ // clang-format on #include +#include #include +#include #include +#include + namespace nvfuser { class LowerToInlinePtx : public kir::ExprMutator { @@ -20,7 +24,7 @@ class LowerToInlinePtx : public kir::ExprMutator { registerReplace( commit, IrBuilder::create( - "cp.async.commit_group;", + "cp.async.commit_group", std::vector{}, std::vector{}, kir::Asm::Options{true})); @@ -31,13 +35,13 @@ class LowerToInlinePtx : public kir::ExprMutator { Expr* replace = nullptr; if (stages > 0) { replace = IrBuilder::create( - "cp.async.wait_group %0;", + "cp.async.wait_group", std::vector{}, std::vector{IrBuilder::create(stages)}, kir::Asm::Options{true}); } else { replace = IrBuilder::create( - "cp.async.wait_all;", + "cp.async.wait_all", std::vector{}, std::vector{}, kir::Asm::Options{true}); @@ -45,6 +49,143 @@ class LowerToInlinePtx : public kir::ExprMutator { registerReplace(wait, replace); } + + void handle(kir::CpAsyncBulkS2GCommit* commit) override { + registerReplace( + commit, + IrBuilder::create( + "cp.async.bulk.commit_group", + std::vector{}, + std::vector{}, + kir::Asm::Options{true})); + } + + void handle(kir::CpAsyncBulkS2GWait* wait) override { + auto stages = wait->keepStages(); + registerReplace( + wait, + IrBuilder::create( + "cp.async.bulk.wait_group.read", + std::vector{}, + std::vector{IrBuilder::create(stages)}, + kir::Asm::Options{true, true})); + } + + void handle(LoadStoreOp* ldst) override { + if (ir_utils::isLdMatrixOp(ldst)) { + auto op = ldst->opType(); + std::stringstream ss; + ss << "ldmatrix.sync.aligned.x" + << std::get(ldst->out()->dtype().type).size; + if (op == LoadStoreOpType::LdMatrixTranspose) { + ss << ".trans"; + } + ss << ".m8n8.shared.b16"; + registerReplace( + ldst, + IrBuilder::create( + ss.str(), + std::vector{ldst->out()}, + std::vector{ldst->in()}, + kir::Asm::Options{true})); + return; + } else if (ir_utils::isCpAsyncOp(ldst)) { + auto out_tv = ldst->out()->as()->view(); + auto vec_size = + ir_utils::getVectorizeSize(out_tv) * dataTypeSize(out_tv->dtype()); + std::stringstream ss; + ss << "cp.async."; + if (ldst->cacheOp() == CacheOp::AllLevels) { + ss << "ca"; + } else { + ss << "cg"; + NVF_ERROR( + vec_size == 16, "cp.async.cg only support vectorize 16 bytes"); + } + ss << ".shared.global"; + registerReplace( + ldst, + IrBuilder::create( + ss.str(), + std::vector{}, + std::vector{ + ldst->out(), + ldst->in(), + IrBuilder::create(vec_size), + ldst->predicate()}, + kir::Asm::Options{true})); + } + } + + void handle(MmaOp* mma) override { + // Constants definitions based on MMA PTX instruction documentation: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma + const int m = 16; + const int n = 8; + const int k = mma->isAmpere() ? 16 : 8; + + std::string op; + { + std::stringstream op_ss; + op_ss << "mma.sync.aligned.m" << m << "n" << n << "k" << k + << ".row.col.f32"; + if (mma->inA()->as()->view()->getDataType().value() == + DataType::BFloat16) { + op_ss << ".bf16.bf16"; + } else { + op_ss << ".f16.f16"; + } + op_ss << ".f32"; + op = op_ss.str(); + } + + int64_t split_n = mma->n() / n; + int64_t split_k = mma->k() / k; + + // If factor == 1, then do nothing, otherwise, view array as + // array, factor> + auto maybe_outer_split = [](DataType dtype, int64_t factor) -> DataType { + if (factor == 1) { + return dtype; + } + const auto& array = std::get(dtype.type); + return ArrayType{ + std::make_shared( + ArrayType{array.type, array.size / (size_t)factor}), + (size_t)factor}; + }; + + DataType accumulator_type = maybe_outer_split(mma->out()->dtype(), split_n); + DataType a_type = maybe_outer_split(mma->inA()->dtype(), split_k); + DataType b_type = maybe_outer_split(mma->inB()->dtype(), split_n); + if (split_n > 1) { + // array, split_k>, split_n> + auto& item_type = *std::get(b_type.type).type; + item_type = maybe_outer_split(item_type, split_k); + } else { + // array, split_k> + b_type = maybe_outer_split(b_type, split_k); + } + + auto accumulator = + IrBuilder::maybeRefCastExpr(accumulator_type, mma->out()); + auto a = IrBuilder::maybeRefCastExpr(a_type, mma->inA()); + auto b = IrBuilder::maybeRefCastExpr(b_type, mma->inB()); + + for (auto in : c10::irange(split_n)) { + auto acc = + split_n == 1 ? accumulator : IrBuilder::getItemExpr(accumulator, in); + auto bb = split_n == 1 ? b : IrBuilder::getItemExpr(b, in); + for (auto ik : c10::irange(split_k)) { + auto aa = split_k == 1 ? a : IrBuilder::getItemExpr(a, ik); + auto bbb = split_k == 1 ? bb : IrBuilder::getItemExpr(bb, ik); + auto mma_asm = IrBuilder::create( + op, std::vector{acc}, std::vector{aa, bbb, acc}); + registerInsertBefore(mma, mma_asm); + } + } + registerRemove(mma); + } }; std::vector lowerToInlinePtx(const std::vector& exprs) { diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index 555600a5aa6..939e30cbf60 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -89,37 +89,20 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { setWritePredicate(expr); } - // Note: [Predicate Inversion for CpAsync] - // Today for vectorized support the pattern is: - // Initialize buffer -> predicated load - // For memcpy async: - // If we initialized and then loaded (without sync) it would be undefined - // behavior. - // Initialize only the "virtual out of boundary" accesses. - // Memory allocated, but outside the virtual tensor space. - // Virtual tensor space today is effectively what would be allocated in - // global memory. Then only copy the "within bound" accesses. - // This is a WAR today based on how our system is set up. - // We would want to have a separate concept of SMEM space from Virtual or - // GMEM space, so that we know we're only working with the allocated - // SMEM. - // If we hit outside the allocated SMEM bad things happen. - // Today asserting in predicate removal making sure that the virtual and - // SMEM boundaries line up based on the IterDomains. - // - // TODO: in a follow up we need to extend the predicate - // infrastructure to generate predicate for both gmem - // and smem, and the predicate removal will need to - // be extended as well for the perf critical regions. - if (isPredicatedInitForCpAsync(expr)) { - invertPredicateForGmemToSharedMemInitialize(expr); + // According to: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async + // cp.async has a built-in mechanism `ignore-src` to ignore the source and + // fill zero. We can just invert the predicate and use it as `ignore-src`. + if (ir_utils::isCpAsyncOp(expr)) { + invertPredicate(expr); } kir::ExprMutator::dispatch(expr); } // Invert the predicate of given expr. - void invertPredicateForGmemToSharedMemInitialize(Expr* expr) { + void invertPredicate(Expr* expr) { + NVF_ERROR(expr != nullptr); auto pred = expr->predicate()->value(); Val* invert = SimplifyingIrBuilder::logicalNotExpr(pred); invert = @@ -127,18 +110,6 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { expr->predicate()->setValue(invert); } - // Detect if this expr is an initialization for vectorized - // cp asyc with predicates. - bool isPredicatedInitForCpAsync(Expr* expr) { - // Match the pattern: - // If(pred) - // TV = 0; - // where TV is the output of cp async. - auto maybe_init = ir_utils::getMaybePredicatedSingleton(expr); - return maybe_init.has_value() && - ir_utils::isCpAsyncInit(maybe_init.value()); - } - void setWritePredicate(Expr* expr) { if (expr->writePredicate() != nullptr) { auto write_cond = generateConditional(expr->writePredicate()); diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 5a9dc1b5bd3..9e2bc68e413 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -876,10 +876,14 @@ namespace { //! specialization of tidx as lane id. void validateMmaTensors(MmaOp* mma) { bool tidx_validated = false; - std::vector to_validate = { - mma->inA()->as(), - mma->inB()->as(), - mma->out()->as()}; + std::vector to_validate = {mma->out()->as()}; + + if (ir_utils::isLdMatrixOp(mma->inA()->definition())) { + to_validate.push_back(mma->inA()->as()); + } + if (ir_utils::isLdMatrixOp(mma->inB()->definition())) { + to_validate.push_back(mma->inB()->as()); + } for (auto tv : to_validate) { for (auto id : tv->getLeafDomain()) { @@ -896,10 +900,18 @@ void validateMmaTensors(MmaOp* mma) { GpuLower::current()->parallelDimensionMap(); NVF_ERROR( lower_utils::isExtentEqualToMaxParallelTypeExtent(id) && - paralel_dim_map.get(ptype)->isConstInt() && - paralel_dim_map.get(ptype)->evaluate() == - at::cuda::warp_size(), - "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp"); + paralel_dim_map.get(ptype)->isConstInt(), + "TIDx is reserved for lane id in mma kernels"); + if (mma->isHopper()) { + NVF_ERROR( + paralel_dim_map.get(ptype)->evaluate() == + at::cuda::warp_size() * 4, + "TIDx must be exactly a warp group for Hopper"); + } else { + NVF_ERROR( + paralel_dim_map.get(ptype)->evaluate() == at::cuda::warp_size(), + "TIDx must be exactly a warp for Turing/Ampere"); + } tidx_validated = true; } } @@ -907,10 +919,23 @@ void validateMmaTensors(MmaOp* mma) { } // Note: this check will be relaxed in a follow up. - auto validate_operand = [](const TensorView* tv) { - NVF_ERROR( - tv->getMemoryType() == MemoryType::Local, - "Only supporting register input for mma ops, up to sm80 all mma ops have to take register inputs."); + auto validate_operand = [mma](const TensorView* tv, MmaOperand operand) { + if (mma->isHopper()) { + if (operand == MmaOperand::B) { + NVF_ERROR( + tv->getMemoryType() == MemoryType::Shared, + "Only supporting smem input for Hopper mma input B"); + } else { + NVF_ERROR( + tv->getMemoryType() == MemoryType::Local || + tv->getMemoryType() == MemoryType::Shared, + "Only supporting register or shared memory input for Hopper mma input A"); + } + } else { + NVF_ERROR( + tv->getMemoryType() == MemoryType::Local, + "Only supporting register input for mma input on Ampere/Turing"); + } NVF_ERROR( std::all_of( @@ -932,8 +957,8 @@ void validateMmaTensors(MmaOp* mma) { tv); }; - validate_operand(mma->inA()->as()); - validate_operand(mma->inB()->as()); + validate_operand(mma->inA()->as(), MmaOperand::A); + validate_operand(mma->inB()->as(), MmaOperand::B); // Additionally validate that mma is not directly taking a double buffered // register input as the double buffer indexing is currently not compatible @@ -946,68 +971,6 @@ void validateMmaTensors(MmaOp* mma) { "MMA op cannot directly take double buffered register input, put a set stage before."); } -//! Note and TODO: -//! Currently relying on ldmatrix to -//! obtain the correct data layout for turing/ampere -//! mma's. -//! This restriction will eventually not -//! be necessary once the scatter swizzle is ready. -void validateTuringMmaInput(TensorView* tv) { - // Pattern matching here to make sure LDMatrix is the right format. - // Format is done through swizzling in the scheduling and - // we check that swizzling to make sure it's correctly setup for LDMatrix. - // We could in theory support patterns LDMatrix doesn't support, - // but that would also mean the MMA isn't supported and - // so we would have to lower to something completely different. - - // MemCpy async is a more generic utility that we can use. - // Currently only allowed input paths are: - // ldmatrix -> mma or - // ldmatrix -> broadcast -> mma - // We actually wouldn't want too much flexibility here since - // this path is very perf critical. But the check itself - // can be made cleaner once we have the correct swizzle - // labeling. - // The most generic support would involve build out to - // support any pointwise ops that does not change the - // datalayout. - auto tv_def = tv->definition(); - NVF_ERROR(tv_def); - if (tv_def->isA()) { - tv_def = tv_def->input(0)->definition(); - } - NVF_ERROR(tv_def); - NVF_ERROR(ir_utils::isLdMatrixOp(tv_def)); -} - -// Output of ldmatrix is swizzled with the mma format, so it -// currently should not be fused with any pointwise ops. This -// check is to protect against these cases. -// This would also not be needed once scatter swizzle ready, should -// just become a swizzle format check if we wanted to fuse ldmatrix -// with any op other than mma. -void validateLdMatrixOutput(TensorView* tv) { - const auto& out_uses = tv->fusion()->unordered_uses(tv); - if (out_uses.empty()) { - return; - } - // TODO: restricting to single use pipelines for now which - // is true to matmul mainloop. This Could be relaxed to - // support more complex mma usage. - NVF_ERROR(out_uses.size() == 1); - auto out_use = *(out_uses.begin()); - - if (out_use->isA()) { - validateLdMatrixOutput(out_use->output(0)->as()); - return; - } - - NVF_ERROR( - out_use->isA(), - "validateLdMatrixOutput: currently only supports single mma use for ldmatrix", - out_use); -} - void validateSizeMemoryOp(LoadStoreOp* ldst) { if (!ldst->out()->isA()) { return; @@ -1042,18 +1005,6 @@ void validateSizeMemoryOp(LoadStoreOp* ldst) { } } -// Checks that the memory ops are supported on the targeted GPU -void validateArchMemoryOp(LoadStoreOp* ldst) { - switch (ldst->opType()) { - case LoadStoreOpType::LdMatrix: - case LoadStoreOpType::LdMatrixTranspose: - validateLdMatrixOutput(ldst->out()->as()); - return; - default: - return; - } -} - } // namespace //! Validate data format and GPU arch compatibility of scheduled @@ -1064,26 +1015,8 @@ void validateMma(Fusion* fusion) { for (auto expr : exprs) { if (auto mma = dynamic_cast(expr)) { validateMmaTensors(mma); - - switch (mma->macro()) { - case MmaOptions::MacroType::Volta_16_16_4: - break; - case MmaOptions::MacroType::Turing_16_8_16: - case MmaOptions::MacroType::Turing_16_16_16: - case MmaOptions::MacroType::Ampere_16_8_16: - case MmaOptions::MacroType::Ampere_16_16_16: - // Check that operands come from ldmatrix, can be - // relaxed once swizzles can be labeled on iterdomains. - validateTuringMmaInput(mma->inA()->as()); - validateTuringMmaInput(mma->inB()->as()); - break; - default: - NVF_ERROR(false, "validate mma: unsupported macro"); - break; - } } if (auto ldst = dynamic_cast(expr)) { - validateArchMemoryOp(ldst); validateSizeMemoryOp(ldst); } } diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 528b594d866..eff54c62276 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -86,7 +86,7 @@ class VectorOfUniqueEntries { return false; } - // Returns if any node was added + // Returns true if any node was added bool pushBack(const VectorOfUniqueEntries& other) { return pushBack(other.vector()); } diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 56500e936fc..fb0c091632e 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -750,9 +750,12 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { std::vector root_dom = updateIdVec(td->root()); std::vector rfactor_dom = td->hasRFactor() - ? updateIdVec(td->maybeRFactor()) + ? updateIdVec(td->rfactor()) + : std::vector(); + std::vector leaf_domain = updateIdVec(td->leaf()); + std::vector alloc_dom = td->hasAllocation() + ? updateIdVec(td->allocation()) : std::vector(); - std::vector domain = updateIdVec(td->leaf()); if (!mutated) { return; @@ -761,8 +764,16 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { // Update the contiguity vector. Drop the contig val if mutated to broadcast auto contig = td->contiguity(); - for (const auto i : c10::irange(td->maybeRFactor().size())) { - auto original_id = td->maybeRFactor().at(i); + const auto& new_maybe_alloc = td->hasAllocation() ? alloc_dom + : td->hasRFactor() ? rfactor_dom + : root_dom; + const auto& original_alloc = td->maybeAllocation(); + NVF_ERROR( + new_maybe_alloc.size() == original_alloc.size(), + "rank of allocation domain shouldn't change in concretization"); + + for (const auto i : c10::irange(original_alloc.size())) { + auto original_id = original_alloc.at(i); if (original_id->getIterType() != IterType::Symbolic) { continue; } @@ -772,7 +783,7 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { "Unexpected to have a non-contig symbolic domain: ", original_id->toString()); - auto updated_id = td->hasRFactor() ? rfactor_dom.at(i) : root_dom.at(i); + auto updated_id = new_maybe_alloc.at(i); // If the concretized ID is a broadcast domain, drop the contig val if (updated_id->isBroadcast()) { @@ -781,7 +792,7 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { } Val* mutated_val = IrBuilder::create( - td->container(), root_dom, rfactor_dom, domain, contig); + td->container(), root_dom, rfactor_dom, alloc_dom, leaf_domain, contig); registerConcretization(td, mutated_val); } diff --git a/csrc/executor.cpp b/csrc/executor.cpp index 3a54ff88170..6232dc5b03a 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -171,7 +171,7 @@ std::string FusionExecutor::getStructuredCode( if (isDebugDumpEnabled(DebugDumpOption::CudaToFile) || isDebugDumpEnabled(DebugDumpOption::DebugInfo)) { std::stringstream file_name; - file_name << "__tmp_kernel" << getGlobalFusionCount() << ".cu"; + file_name << "__tmp_kernel_" << kernel_id_ << ".cu"; debug() << "PRINTING: " << file_name.str() << std::endl; std::ofstream out(file_name.str()); out << code << std::endl; @@ -369,7 +369,11 @@ void FusionExecutor::compileFusion( // If the loaded external source code is empty, revert to the default codegen. // The external_structured_code is moved to structured_code and explicitly // cleared to avoid use-after-move scenarios. - auto structured_code = getStructuredCodeFromExternalFiles(fusion_id_); + // Note: we index these with getGlobalFusionCount() instead of fusion_id_ in + // order to match the numbering of files output with + // NVFUSER_DUMP=cuda_to_file + auto structured_code = + getStructuredCodeFromExternalFiles(getGlobalFusionCount()); if (structured_code.empty()) { structured_code = getStructuredCode(); } @@ -918,46 +922,33 @@ int64_t IndexOfFusionInput(const Val* in, const Fusion* fusion) { } // Returns the at::Tensor allocated for `out_info`. -// -// TODO: clean up the API so we explicitly pass in the input alias. This way, we -// can remove `args` and `kernel`, which unnecessary expose information of -// unrelated arguments. at::Tensor allocateOutput( const FusionExecutor::GlobalBufferInfo& out_info, - const KernelArgumentHolder& args, + Val* aliased_in, + const AliasInfo* alias_info, + const at::Tensor& aliased_in_tensor, const c10::Device& device, - const kir::Kernel* kernel, ExpressionEvaluator& ee) { TensorView* out_tv = out_info.tv; - auto alias_it = kernel->ioAlias().find(out_tv); // Note: aliased output is not returned as output. But we still need it // for kernel execution, so would need to push them to args - if (alias_it != kernel->ioAlias().end()) { - const auto aliased_in_index = - IndexOfFusionInput(alias_it->second.first, kernel); - const PolymorphicValue& in_val = *args[aliased_in_index]; - NVF_ERROR( - in_val.is(), - "Alias io only supports tensor. Found ", - PolymorphicValue_functions::toString(in_val)); - at::Tensor in_tensor = in_val.as(); - - switch (alias_it->second.second.type) { + if (aliased_in != nullptr) { + switch (alias_info->type) { case AliasType::InplaceUpdate: - // Unlike for `AliasType::PointerCast`, don't use + // Unlike for `AliasType::PointerArithmetic`, don't use // ExpressionEvaluator to compute the output tensor. This is because // the output tensor may hold different data from the input, e.g., an // updated running mean. `ExpressionEvaluator::evaluate(out_tv)` // would trigger non-trivial host computation. - return in_tensor; + return aliased_in_tensor; - case AliasType::PointerCast: - auto* in_tv = kernel->inputs()[aliased_in_index]->as(); - ee.bind(in_tv, in_tensor); + case AliasType::PointerArithmetic: + auto* in_tv = aliased_in->as(); + ee.bind(in_tv, aliased_in_tensor); at::Tensor out_tensor = ee.evaluate(out_tv).as(); NVF_ERROR( - in_tensor.data_ptr() == out_tensor.data_ptr(), + out_tensor.is_alias_of(aliased_in_tensor), "ExpressionEvaluator failed to evaluate ", out_tv->toString(), " as an alias of ", @@ -997,8 +988,25 @@ std::vector allocateOutputs( std::vector outputs; outputs.reserve(output_info.size()); for (const auto output_idx : c10::irange(output_info.size())) { - outputs.push_back( - allocateOutput(output_info[output_idx], inputs, device, kernel, ee)); + Val* out = kernel->outputs()[output_idx]; + auto [aliased_in, alias_info] = kernel->getOutputAlias(out); + at::Tensor aliased_in_tensor; + if (aliased_in != nullptr) { + const PolymorphicValue& aliased_in_val = + *inputs[IndexOfFusionInput(aliased_in, kernel)]; + NVF_ERROR( + aliased_in_val.is(), + "Alias io only supports tensor. Found ", + PolymorphicValue_functions::toString(aliased_in_val)); + aliased_in_tensor = aliased_in_val.as(); + } + outputs.push_back(allocateOutput( + output_info[output_idx], + aliased_in, + alias_info, + aliased_in_tensor, + device, + ee)); } return outputs; } diff --git a/csrc/executor_params.cpp b/csrc/executor_params.cpp index ea40a09aefd..cee80f8b1f4 100644 --- a/csrc/executor_params.cpp +++ b/csrc/executor_params.cpp @@ -12,6 +12,20 @@ namespace nvfuser { +std::string CompileParams::toString() const { + std::stringstream ss; + ss << "Compile Parameters: index_type = "; + if (index_type.has_value()) { + ss << index_type.value() << ", "; + } else { + ss << "NotSet, "; + } + ss << "maxrregcount = " << maxrregcount << ", " + << "enable_magic_zero = " << enable_magic_zero << ", " + << "enable_ptxas_verbose = " << enable_ptxas_verbose << "\n"; + return ss.str(); +} + void LaunchParams::assertValid() { NVF_ERROR( bdimx() * bdimy() * bdimz() > 0 && diff --git a/csrc/executor_params.h b/csrc/executor_params.h index 981b403929a..0fb0607a3d1 100644 --- a/csrc/executor_params.h +++ b/csrc/executor_params.h @@ -37,6 +37,8 @@ struct CompileParams { bool operator!=(const CompileParams& other) const { return !(*this == other); } + + std::string toString() const; }; class LaunchParams { diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index 14ace582dd5..549b647d11d 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -50,7 +50,6 @@ #include #include #include -#include #include #include #include @@ -99,7 +98,6 @@ std::string kernelPreamble() { ss << nvfuser_resources::broadcast_cu; ss << nvfuser_resources::welford_cu; ss << nvfuser_resources::warp_cu; - ss << nvfuser_resources::tensorcore_cu; ss << nvfuser_resources::memory_cu; ss << nvfuser_resources::fused_welford_helper_cu; ss << nvfuser_resources::fused_reduction_cu; @@ -970,6 +968,10 @@ void fillCompileOptions( std::optional opt_block_size) { nvrtc_compile_driver.setOption("--std=c++17"); + // Suppress warnings for functions that are defined but unused, since we have + // many unused functions in the preamble. + nvrtc_compile_driver.setOption("--diag-suppress=177"); + // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_) // which gives better backwards compatibility to work on older driver, // (since older driver doesn't necessarily recognize PTX emitted by new diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index c90e6a34e3e..ee623065499 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -334,14 +334,14 @@ std::vector Fusion::exprs() { namespace { -bool allOutputsArePointerCasts(Fusion* fusion) { +bool allOutputsArePointerArithmetics(Fusion* fusion) { for (Val* out : fusion->outputs()) { const auto& [in, info] = fusion->getOutputAlias(out); if (in == nullptr) { return false; } NVF_ERROR(info != nullptr); - if (info->type != AliasType::PointerCast) { + if (info->type != AliasType::PointerArithmetic) { return false; } } @@ -355,7 +355,7 @@ bool Fusion::isNoOp() { return true; } - if (allOutputsArePointerCasts(this)) { + if (allOutputsArePointerArithmetics(this)) { return true; } @@ -816,7 +816,7 @@ void Fusion::aliasOutputToInput(Val* output, Val* input, const AliasType type) { } } -std::pair Fusion::getOutputAlias(Val* output) { +std::pair Fusion::getOutputAlias(Val* output) const { if (auto search = io_alias_.find(output); search != io_alias_.end()) { const std::pair& in_val_and_info = search->second; return {in_val_and_info.first, &in_val_and_info.second}; diff --git a/csrc/fusion.h b/csrc/fusion.h index b006bd77ffa..37b68928619 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -90,10 +90,10 @@ enum class AliasType : int { // For example, the tensor storing BatchNorm's running mean. The output EMA is // updated in place. InplaceUpdate, - // For example, the output of a ViewOp is merely a pointer cast of the input. - // In this case, we use `ExpressionEvaluator` (instead of a kernel) to compute - // the output tensor. - PointerCast, + // For example, the output of a ViewOp is merely a pointer arithmetic of the + // input. In this case, we use `ExpressionEvaluator` (instead of a kernel) to + // cheaply compute the output tensor. + PointerArithmetic, }; struct AliasInfo { @@ -244,7 +244,7 @@ class Fusion : public IrContainer { //! Returns the aliased input of a given output along with an `AliasInfo` //! describing how they alias. Returns when `output` is not //! aliased. - std::pair getOutputAlias(Val* output); + std::pair getOutputAlias(Val* output) const; // mark input at index to be permuted by permutation void setPermutationOnInput(int index, std::vector permutation) { @@ -276,11 +276,6 @@ class Fusion : public IrContainer { return is_during_update_uses_; } - // TODO: Have getOutputAlias expose AliasInfo and then remove this method. - const std::unordered_map>& ioAlias() const { - return io_alias_; - } - // NOTE: [Fusion managed data] // // Fusion-managed data is a mechanism to communicate data that survives fusion diff --git a/csrc/fusion_profiler.cpp b/csrc/fusion_profiler.cpp index 8cdeaa07cf9..3b43f763472 100644 --- a/csrc/fusion_profiler.cpp +++ b/csrc/fusion_profiler.cpp @@ -44,16 +44,11 @@ void record_cupti_activity(CUpti_Activity* pRecord, FILE* pFileHandle) { KernelProfile prof; prof.name.assign(demangle(pKARecord->name)); - size_t kernel_start = prof.name.find("kernel"); - size_t nvfuser_start = prof.name.find("nvfuser"); - NVF_ERROR( - kernel_start != std::string::npos || - nvfuser_start != std::string::npos, - "Failed to find kernel name start position.") - - size_t start = std::min(kernel_start, nvfuser_start); - size_t end = prof.name.find('('); - prof.name = prof.name.substr(start, end - start); + size_t start = prof.name.find("nvfuser"); + if (start != std::string::npos) { + size_t end = prof.name.find('(', start); + prof.name = prof.name.substr(start, end - start); + } prof.device = (int)pKARecord->deviceId; prof.stream = pKARecord->streamId; prof.correlation_id = pKARecord->correlationId; @@ -402,31 +397,13 @@ void FusionProfile::reset() { } std::array column_strs{ - "Fus#", - "NSegs", - "CuEvtTm(ms)", - "HstTm(ms)", - "CmpTm(ms)", - "KerTm(ms)", - "EffBw(GB/s)", - "%PeakBw", - "S-Seg#", - "S-KerName", - "S-KerTm(ms)", - "S-CmpTm(ms)", - "S-EffBw(GB/s)", - "S-%PeakBw", - "S-In(MB)", - "S-Out(MB)", - "S-Smem[Dyn,Stat]", - "S-Regs", - "S-Grid", - "S-Block", - "S-Cluster", - "S-Dev", - "S-Stm", - "S-PeakBw(GB/s)", - "S-DeviceName"}; + "Fus#", "NSegs", "CuEvtTm(ms)", "HstTm(ms)", + "CmpTm(ms)", "KerTm(ms)", "EffBw(GB/s)", "%PeakBw", + "S-Seg#", "S-KerTm(ms)", "S-CmpTm(ms)", "S-EffBw(GB/s)", + "S-%PeakBw", "S-In(MB)", "S-Out(MB)", "S-Smem[Dyn,Stat]", + "S-Regs", "S-Grid", "S-Block", "S-Cluster", + "S-Dev", "S-Stm", "S-PkBw(GB/s)", "S-DeviceName", + "S-KerName"}; std::ostream& operator<<(std::ostream& os, const FusionProfile& fp) { if (fp.fusion_id == 0) { @@ -442,29 +419,30 @@ std::ostream& operator<<(std::ostream& os, const FusionProfile& fp) { << std::get<7>(column_strs); os << " " << std::setw(6) << std::get<8>(column_strs) << " " - << std::setw(10) << std::get<9>(column_strs) << " " << std::setw(11) - << std::get<10>(column_strs); + << std::setw(9) << std::get<9>(column_strs); if (fp.verbose) { - os << " " << std::setw(11) << std::get<11>(column_strs); + os << " " << std::setw(11) << std::get<10>(column_strs); } - os << " " << std::setw(13) << std::get<12>(column_strs) << " " - << std::setw(9) << std::get<13>(column_strs) << " " << std::setw(9) - << std::get<14>(column_strs) << " " << std::setw(9) - << std::get<15>(column_strs) << " " << std::setw(16) - << std::get<16>(column_strs) << " " << std::setw(6) + os << " " << std::setw(13) << std::get<11>(column_strs) << " " + << std::setw(9) << std::get<12>(column_strs) << " " << std::setw(9) + << std::get<13>(column_strs) << " " << std::setw(9) + << std::get<14>(column_strs) << " " << std::setw(16) + << std::get<15>(column_strs) << " " << std::setw(6) + << std::get<16>(column_strs) << " " << std::setw(16) << std::get<17>(column_strs) << " " << std::setw(16) - << std::get<18>(column_strs) << " " << std::setw(16) - << std::get<19>(column_strs); + << std::get<18>(column_strs); if (fp.verbose) { - os << " " << std::setw(16) << std::get<20>(column_strs) << " " - << std::setw(5) << std::get<21>(column_strs) << " " << std::setw(5) - << std::get<22>(column_strs) << " " << std::setw(14) - << std::get<23>(column_strs) << " " << std::setw(20) - << std::get<24>(column_strs); + os << " " << std::setw(16) << std::get<19>(column_strs) << " " + << std::setw(5) << std::get<20>(column_strs) << " " << std::setw(5) + << std::get<21>(column_strs) << " " << std::setw(12) + << std::get<22>(column_strs) << " " << std::setw(20) + << std::get<23>(column_strs); } + + os << " " << std::setw(20) << std::get<24>(column_strs); } os << std::endl; @@ -517,8 +495,7 @@ std::ostream& operator<<(std::ostream& os, const FusionProfile& fp) { smem << "[" << kp.dynamic_shared_mem << ", " << kp.static_shared_mem << "]"; os << std::setfill(' ') << std::right << std::fixed << " " << std::setw(6) - << idx << " " << std::setw(10) << kp.name << " " << std::setw(11) - << std::setprecision(3) << kp.time_ms; + << idx << " " << std::setw(11) << std::setprecision(3) << kp.time_ms; if (fp.verbose) { os << " " << std::setw(11) << std::setprecision(3) @@ -537,9 +514,10 @@ std::ostream& operator<<(std::ostream& os, const FusionProfile& fp) { if (fp.verbose) { os << " " << std::setw(16) << cluster.str() << " " << std::setw(5) << kp.device << " " << std::setw(5) << kp.stream << " " - << std::setw(14) << std::setprecision(2) << kp.peak_bandwidth_gbs + << std::setw(12) << std::setprecision(2) << kp.peak_bandwidth_gbs << " " << std::setw(20) << kp.device_name; } + os << " " << std::setw(20) << kp.name; os << std::endl; ++idx; } diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 093a1b2b4f8..06d57c28de6 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -883,7 +883,7 @@ std::vector SegmentedFusion::castInputOutputToLowerPrecision( // // To avoid this discrepancy, when this is done with virtual merged // groups, bundle all edges to the merged groups and process them - // together. This way, only one instane of the cast-back expr should + // together. This way, only one instance of the cast-back expr should // be inserted. // // Note that this analysis and replacement would be much simpler if we @@ -1554,7 +1554,9 @@ std::unique_ptr SegmentedFusion::makeFusion(SegmentedGroup* sg) { } } - for (auto out : getAllOutputs(sg)) { + // note, we would want to keep output consistent and not artificially drop + // duplicates. + for (auto out : sg->output_vals) { fusion_segment->addOutput(complete_to_segment_map.clone(out)); } @@ -3480,36 +3482,40 @@ void SegmentCandidateFinder::forwardInputs() { // treated as complete fusion inputs. VectorOfUniqueEntries forwarded_inputs; { - std::deque to_visit; + std::deque to_visit; for (auto inp : completeFusion()->inputs()) { + // Add all uses of input if all of those uses are UnaryOps + // If any of these ops are not UnaryOps then we if (std::all_of(inp->uses().begin(), inp->uses().end(), [](Expr* expr) { return expr->isA(); })) { - to_visit.insert(to_visit.end(), inp->uses().begin(), inp->uses().end()); + for (auto use : inp->uses()) { + to_visit.push_back(use->as()); + } } } while (!to_visit.empty()) { - auto expr = to_visit.front(); + UnaryOp* uop = to_visit.front(); to_visit.pop_front(); - if (!expr->isA() || expr->output(0)->isFusionOutput()) { + if (uop->out()->isFusionOutput()) { continue; } - // expr is a unary op so there is a single output. Here we look at that + // uop is a UnaryOp so there is a single output. Here we look at that // output's further uses - const auto& output_uses = expr->output(0)->uses(); + const auto& output_uses = uop->out()->uses(); - if (output_uses.size() == 1) { - // If there is a single use, visit it to try and extend the chain of - // unaryOps - to_visit.emplace_back(output_uses.at(0)); + if (output_uses.size() == 1 && output_uses[0]->isA()) { + // If there is a single use which is also a UnaryOp, visit it to try + // and extend the chain of unaryOps + to_visit.emplace_back(output_uses[0]->as()); } else { - // If there are either no more uses, or more than one use, we cannot - // extend the chain of unary Ops. In either case, finalize this chain by - // saving the expr and its output. - excluded_inp_unary_exprs_.pushBack(expr); - forwarded_inputs.pushBack(expr->output(0)); + // If there are either no more uses, more than one use, or one use that + // is not a UnaryOp, then we cannot extend the chain of unary Ops. In + // these cases we finalize this chain by saving the uop and its output. + excluded_inp_unary_exprs_.pushBack(uop); + forwarded_inputs.pushBack(uop->out()); } } } diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 1c13990537a..314ee914e5e 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -7,8 +7,10 @@ // clang-format on #include +#include #include #include + #include #include #include @@ -1323,7 +1325,8 @@ bool isParallelLoopIndexSubstitutedAsZero( // to consumer but they should still be detected as same // parallel type. In a follow up may want to extend // find_matching_parallel_domain to cover this case. - if (within_mma_loops && loop_id->getParallelType() == ParallelType::TIDx) { + if ((within_mma_loops || ir_utils::isLdMatrixOp(tv->definition())) && + loop_id->getParallelType() == ParallelType::TIDx) { return true; } @@ -1662,6 +1665,7 @@ std::vector Index::getNonGlobalProducerStridedIndices( const std::vector& loops, const std::unordered_set& rotated_loops, const std::unordered_map& override_index) { + bool is_mma_input = consumer_tv->definition()->isA(); const auto gpu_lower = GpuLower::current(); // Replay producer to look like consumer so we can index on producer since our // loop nests look like consumer @@ -1753,7 +1757,9 @@ std::vector Index::getNonGlobalProducerStridedIndices( for (auto alloc_id : alloc_dom) { // Already taken care of because we can detect no indexing required if (alloc_id->isBroadcast() || alloc_id->isReduction() || - alloc_id->isStride()) { + alloc_id->isStride() || + (alloc_id->isThread() && + producer_tv->getMemoryType() == MemoryType::Local)) { skip_indexing.insert(alloc_id); continue; } @@ -1766,6 +1772,23 @@ std::vector Index::getNonGlobalProducerStridedIndices( std::vector strided_inds( alloc_dom.size(), GpuLower::current()->kernel()->zeroVal()); + + // MMA operation op is a special operation that our automatic "zero domain" + // analysis of our current indexing approach does not work. So we need to + // manually specify which dimensions are used for MMA allocation. + std::function is_mma_allocation; + if (is_mma_input) { + int size = (int)alloc_dom.size(); + const IterDomain* allocation0 = alloc_dom.at(size - 3); + const IterDomain* allocation1 = alloc_dom.at(size - 2); + const IterDomain* allocation2 = alloc_dom.at(size - 1); + is_mma_allocation = [=](const IterDomain* id) { + return id == allocation0 || id == allocation1 || id == allocation2; + }; + } else { + is_mma_allocation = [](const IterDomain* id) { return false; }; + } + for (const auto i : c10::irange(alloc_dom.size())) { if (skip_indexing.count(alloc_dom[i])) { continue; @@ -1810,13 +1833,15 @@ std::vector Index::getNonGlobalProducerStridedIndices( continue; } - auto alloc_ext_j = extent_map.find(alloc_dom[j]) == extent_map.end() + auto alloc_ext_j = (extent_map.find(alloc_dom[j]) == extent_map.end() || + is_mma_allocation(alloc_dom[j])) ? alloc_dom[j]->extent() : extent_map.at(alloc_dom[j]); alloc_ext_j = getHaloExtentOfRootAxis(alloc_dom[j], alloc_ext_j); - if (zero_domain_map.count(alloc_dom[j]) == 0) { + if (zero_domain_map.count(alloc_dom[j]) == 0 || + is_mma_allocation(alloc_dom[j])) { if (stride == nullptr) { stride = alloc_ext_j; } else { @@ -2191,7 +2216,9 @@ std::vector Index::getNonGlobalConsumerStridedIndices( alloc_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(alloc_dom.size())) { if (alloc_dom[i]->isReduction() || alloc_dom[i]->isBroadcast() || - alloc_dom[i]->isStride()) { + alloc_dom[i]->isStride() || + (alloc_dom[i]->isThread() && + consumer_tv->getMemoryType() == MemoryType::Local)) { continue; } @@ -2363,7 +2390,8 @@ kir::TensorIndex* Index::getProducerIndex( const std::vector& loops, const std::unordered_set& rotated_loops, const std::unordered_map& override_index, - bool generate_pointer) { + bool generate_pointer, + DataType as_type) { auto index = getProducerStridedIndices( producer, consumer, @@ -2372,7 +2400,17 @@ kir::TensorIndex* Index::getProducerIndex( override_index, generate_pointer); index = GpuLower::current()->commonScalarMap().hoistScalar(index, loops); - return SimplifyingIrBuilder::create(producer, index); + if (ir_utils::isLdMatrixOp(consumer->definition())) { + if (at::cuda::getCurrentDeviceProperties()->major < 8) { + // For Turing, unused indices for ldmatrix needs to be aligned, although + // they are not used. + auto orig_index = index; + index = IrBuilder::create(index->dtype()); + IrBuilder::create( + UnaryOpType::AdjustPartialLdMatrixAddrInTuring, index, orig_index); + } + } + return IrBuilder::create(producer, index, as_type); } Val* Index::getConsumerStridedIndices( diff --git a/csrc/index_compute.h b/csrc/index_compute.h index 0ecdcd9ce65..a01b5638bf7 100644 --- a/csrc/index_compute.h +++ b/csrc/index_compute.h @@ -478,7 +478,8 @@ class Index { const std::vector& loops, const std::unordered_set& rotated_loops, const std::unordered_map& override_index = {}, - bool generate_pointer = false); + bool generate_pointer = false, + DataType as_type = DataType::Null); // Consumer index dispatch static kir::TensorIndex* getConsumerIndex( diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index de5ac19e28b..d8e9b777b37 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -122,6 +122,16 @@ Val* IrBuilder::maybeCastExpr(DataType dtype, Val* val) { return result; } +Val* IrBuilder::maybeRefCastExpr(DataType dtype, Val* val) { + NVF_CHECK(val != nullptr, "val is a nullptr in bitCastExpr."); + if (val->dtype() == dtype) { + return val; + } + auto result = create(dtype); + IrBuilder::create(UnaryOpType::RefCast, result, val); + return result; +} + Val* IrBuilder::addressExpr(Val* val) { NVF_CHECK(val != nullptr, "val is a nullptr in addressExpr."); auto result = create( diff --git a/csrc/ir/builder.h b/csrc/ir/builder.h index 24eee7b6e39..21df7433e65 100644 --- a/csrc/ir/builder.h +++ b/csrc/ir/builder.h @@ -62,6 +62,7 @@ class IrBuilder { static Val* absExpr(Val* val); static Val* setExpr(Val* val); static Val* maybeCastExpr(DataType dtype, Val* val); + static Val* maybeRefCastExpr(DataType dtype, Val* val); static Val* addressExpr(Val* val); static NamedScalar* setExprNamedScalar(const std::string& name, Val* val); static NamedScalar* addressExprNamedScalar(const std::string& name, Val* val); diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 43160899165..b881b6b81bb 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -120,6 +120,8 @@ class TensorView : public Val { std::string toInlineString(int indent_size = 0) const override; + void printTransforms() const; + TensorDomain* domain() const { return domain_; } @@ -414,7 +416,7 @@ class TensorView : public Val { //! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to //! have a matching thread swizzle with the mma operand/result. //! More detail on usage see [WarpMmaSwizzler] in scheduler/mma_utils.h . - void applyMmaSwizzle(MmaOptions options); + void applyMmaSwizzle(MmaOperand operand); //! Returns if this tensor view has swizzle operator on its tensor domain. //! This is the temporary flag for indicating that the new swizzle diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 5a575dba518..017cebe073c 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -120,16 +120,24 @@ class IterDomain : public Val { static std::vector clone( const std::vector& domains); - static IterDomain* merge(IterDomain* outer, IterDomain* inner); + //! When `rfactor_domain` is true, also set the `is_rfactor_domain_` flag of + //! the result IterDomain. + static IterDomain* merge( + IterDomain* outer, + IterDomain* inner, + bool rfactor_domain = false); //! start_offset and stop_offset defines partial split. Only root //! domains are allowed to have non-zero start and stop offsets. + //! When `rfactor_domain` is true, also set the `is_rfactor_domain_` flag of + //! both result IterDomains. static std::pair split( IterDomain* in, Val* factor, bool inner_split, Val* start_offset = nullptr, - Val* stop_offset = nullptr); + Val* stop_offset = nullptr, + bool rfactor_domain = false); //! trim_out_of_bounds controls how the values outside start and stop //! positions are treated. The option is only valid with root @@ -141,7 +149,8 @@ class IterDomain : public Val { IterDomain* in, Val* factor, bool inner_split, - bool trim_out_of_bounds); + bool trim_out_of_bounds, + bool rfactor_domain = false); //! Resize an IterDomain by expanding both the left and right sides //! by given widths. The resulting IterDomain has an extent of @@ -229,6 +238,10 @@ class IterDomain : public Val { return (isBlockDim() || isThreadDim()); } + bool isDeviceDim() const { + return isParallelTypeDeviceDim(getParallelType()); + } + void parallelize(ParallelType t); ParallelType getParallelType() const { diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index b72810b9656..089b2bf41f9 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -52,6 +52,9 @@ class FullOp : public Expr { std::string toString(int indent_size = 0) const override; std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; Val* getFillValue() const { return inputs().back(); @@ -237,6 +240,9 @@ class IotaOp : public Expr { std::string toString(int indent_size = 0) const override; std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; DataType dtype() const { return *start()->getDataType(); @@ -287,6 +293,9 @@ class EyeOp : public Expr { std::string toString(int indent_size = 0) const override; std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; DataType dtype() const { return attribute(0); @@ -1325,7 +1334,7 @@ class GroupedWelfordOp : public Expr { class MmaOp : public Expr { public: using AxesData = std::vector; - using MmaLayoutOpt = std::optional; + using MmaLayoutOpt = std::optional; using Expr::Expr; MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init); @@ -1336,7 +1345,7 @@ class MmaOp : public Expr { Val* in_a, Val* in_b, Val* init, - const MmaOptions::MacroType& options, + const MmaMacro& options, const MmaLayoutOpt& input_layout); NVFUSER_DECLARE_CLONE_AND_CREATE @@ -1365,10 +1374,34 @@ class MmaOp : public Expr { } const auto& macro() const { - return attribute(ATTR_POS_MACRO); + return attribute(ATTR_POS_MACRO); + } + + int m() const { + return getM(macro()); + } + + int n() const { + return getN(macro()); + } + + int k() const { + return getK(macro()); + } + + bool isTuring() const { + return nvfuser::isTuring(macro()); + } + + bool isAmpere() const { + return nvfuser::isAmpere(macro()); + } + + bool isHopper() const { + return nvfuser::isHopper(macro()); } - void configureOptions(MmaOptions options); + void setMacro(MmaMacro options); auto layout() const { return attribute(ATTR_POS_INPUT_LAYOUT); @@ -1434,6 +1467,10 @@ class ExpandOp : public Expr { std::vector expanded_extents() const { return {inputs().begin() + 1, inputs().end()}; } + + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; }; //! Shift @@ -1550,6 +1587,9 @@ class ViewAsScalar : public Expr { std::string toString(int indent_size = 0) const override; std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; Val* out() const { return output(0); diff --git a/csrc/ir/iostream.cpp b/csrc/ir/iostream.cpp index 181b394c8e8..824e6cb16f1 100644 --- a/csrc/ir/iostream.cpp +++ b/csrc/ir/iostream.cpp @@ -88,7 +88,7 @@ void IrTransformPrinter::handle(Fusion* f) { } } -void IrTransformPrinter::printTransforms(TensorView* tv) { +void IrTransformPrinter::printTransforms(const TensorView* tv) { const auto& root_domain = tv->getRootDomain(); os() << " root domain : (" << toDelimitedString(root_domain) << ")\n"; diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index ada6748056c..1e781b6b899 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -67,6 +67,20 @@ std::string FullOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector FullOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + std::vector shape; + for (auto i : c10::irange(inputs.size() - 1)) { + shape.push_back((int)inputs.at(i)); + } + DataType dtype = getFillValue()->getDataType().value(); + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype)); + using namespace PolymorphicValue_functions; + return {at::full(shape, toScalar(inputs.back()), options)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(FullOp) SelectOp::SelectOp( @@ -210,7 +224,11 @@ std::vector TorchGatherOp::evaluate( const auto& input = inputs.at(0).as(); const auto& index = inputs.at(1).as(); auto dimension = dim(); - return {at::gather(input, dimension, index)}; + if (exactSizes()) { + return {at::take_along_dim(input, index, dimension)}; + } else { + return {at::gather(input, dimension, index)}; + } } NVFUSER_DEFINE_CLONE_AND_CREATE(TorchGatherOp) @@ -294,6 +312,31 @@ std::string IotaOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector IotaOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); + int64_t length = (int64_t)inputs.at(0); + + if (isIntegralType(dtype())) { + int64_t start = (int64_t)inputs.at(1); + int64_t step = (int64_t)inputs.at(2); + int64_t end = start + step * length; + return {at::arange(start, end, step, options)}; + } else if (isFloatingPointType(dtype())) { + double start = (double)inputs.at(1); + double step = (double)inputs.at(2); + // Due to rounding error, it can be hard to guarantee the size of + // the output of arange to be exactly length, so we generate a + // larger tensor and truncate it to length. + double end = start + step * ((double)length + 1); + return {at::arange(start, end, step, options).narrow(0, 0, length)}; + } else { + NVF_ERROR(false, "Unsupported dtype in IotaOp evaluator: ", dtype()); + } +} + NVFUSER_DEFINE_CLONE_AND_CREATE(IotaOp) EyeOp::EyeOp(IrBuilderPasskey passkey, Val* out, DataType dtype) @@ -321,6 +364,19 @@ std::string EyeOp::toString(int indent_size) const { std::string EyeOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector EyeOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto options = + at::TensorOptions().device(at::kCUDA).dtype(data_type_to_aten(dtype())); + int64_t nrows = (int64_t)inputs.at(0); + if (inputs.size() > 1) { + int64_t ncols = (int64_t)inputs.at(1); + return {at::eye(nrows, ncols, options)}; + } else { + return {at::eye(nrows, options)}; + } +} NVFUSER_DEFINE_CLONE_AND_CREATE(EyeOp) @@ -373,6 +429,9 @@ std::vector UnaryOp::evaluate( case UnaryOpType::ToUnsignedSmemAddr: return {(int64_t)(unsigned)in}; break; + case UnaryOpType::AdjustPartialLdMatrixAddrInTuring: + return {in}; + break; case UnaryOpType::Dereference: if (*out()->getDataType() == DataType::Float) { return {PolymorphicValue((double)*(float*)in)}; @@ -380,6 +439,52 @@ std::vector UnaryOp::evaluate( NVF_ERROR( false, "dtype not supported in evaluator: ", *out()->getDataType()); } + break; + case UnaryOpType::Sigmoid: + return {in.as().sigmoid()}; + break; + case UnaryOpType::Tanh: + return {in.as().tanh()}; + break; + case UnaryOpType::Relu: + return {at::relu(in.as())}; + break; + case UnaryOpType::Gelu: + return {at::gelu(in.as())}; + break; + case UnaryOpType::Exp: + return {at::exp(in.as())}; + break; + case UnaryOpType::Sin: + return {in.as().sin()}; + break; + case UnaryOpType::Cos: + return {in.as().cos()}; + break; + case UnaryOpType::BitCast: + NVF_CHECK( + dataTypeSize(input(0)->dtype()) == dataTypeSize(out()->dtype()), + "BitCast only works for types of the same size"); + if (isComplexType(input(0)->dtype()) && + std::holds_alternative(out()->dtype().type)) { + // view_as_real case. + auto vec_type = std::get(out()->dtype().type); + auto inp_scalar_type = getTypeFromComplexType(input(0)->dtype()); + NVF_CHECK( + *vec_type.type == inp_scalar_type, + "Output type must be the same as the scalar type of the complex input."); + NVF_CHECK( + vec_type.size == 2, + "Expected output to be array of size 2, found array of size ", + vec_type.size); + return {in.as()}; + } else { + return {in.as().view(data_type_to_aten(out()->dtype()))}; + } + break; + case UnaryOpType::Rsqrt: + return {in.as().rsqrt()}; + break; default: NVF_CHECK( false, @@ -1791,7 +1896,7 @@ struct MmaOpDetails { // and output AxesData batch_axes; // A placeholder for mma input layout - std::optional input_layout = std::nullopt; + std::optional input_layout = std::nullopt; }; // A helper structure with pieces of information about TensorView @@ -1823,7 +1928,7 @@ TensorViewDetails getDetailsFor(const std::vector& dims) { return details; } -MmaOptions::MmaLayout getInputLayout( +MmaLayout getInputLayout( const TensorViewDetails& in_a, const TensorViewDetails& in_b, const MmaOp::AxesData& m_axes, @@ -1837,7 +1942,7 @@ MmaOptions::MmaLayout getInputLayout( (k_axes.front() < in_a.bcasts.front()) && (in_b.bcasts.front() < k_axes.front()) && (in_b.bcasts.front() < n_axes.front())) { - return MmaOptions::MmaLayout::TT; + return MmaLayout::TT; } // TN layout (b - broadcast, r - reduction): // A = [M, b, K] @@ -1847,7 +1952,7 @@ MmaOptions::MmaLayout getInputLayout( (in_a.bcasts.front() < k_axes.front()) && (in_b.bcasts.front() < n_axes.front()) && (in_b.bcasts.front() < k_axes.front())) { - return MmaOptions::MmaLayout::TN; + return MmaLayout::TN; } // NT layout (b - broadcast, r - reduction): // A = [K, M, b] @@ -1857,7 +1962,7 @@ MmaOptions::MmaLayout getInputLayout( (m_axes.front() < in_a.bcasts.front()) && (k_axes.front() < in_b.bcasts.front()) && (in_b.bcasts.front() < n_axes.front())) { - return MmaOptions::MmaLayout::NT; + return MmaLayout::NT; } // NN layout (b - broadcast, r - reduction): // A = [b, K, M] @@ -1866,7 +1971,7 @@ MmaOptions::MmaLayout getInputLayout( if ((in_a.bcasts.front() < k_axes.front()) && (k_axes.front() < m_axes.front()) && (n_axes.front() < k_axes.front()) && (k_axes.front() < in_b.bcasts.front())) { - return MmaOptions::MmaLayout::NN; + return MmaLayout::NN; } NVF_ERROR(false, "Unsupported input layout"); @@ -2044,7 +2149,7 @@ MmaOp::MmaOp( // ATTR_POS_INIT addAttribute(init); // ATTR_POS_MACRO - addDataAttribute(MmaOptions::MacroType::NoMMA); + addDataAttribute(MmaMacro::NoMMA); // ATTR_POS_M_AXES addDataAttribute(AxesData{}); // ATTR_POS_N_AXES @@ -2078,10 +2183,10 @@ MmaOp::MmaOp( Val* in_a, Val* in_b, Val* init, - const MmaOptions::MacroType& macro, + const MmaMacro& macro, const MmaLayoutOpt& input_layout) : MmaOp(passkey, out, in_a, in_b, init) { - attribute(ATTR_POS_MACRO) = macro; + attribute(ATTR_POS_MACRO) = macro; const auto input_layout_ = attribute(ATTR_POS_INPUT_LAYOUT); if (input_layout_.has_value()) { @@ -2110,13 +2215,9 @@ std::string MmaOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } -void MmaOp::configureOptions(MmaOptions options) { - MmaOptions::MacroType& macro = - attribute(ATTR_POS_MACRO); - NVF_ERROR( - options.macro != MmaOptions::MacroType::NoMMA, - "Un-configured mma type from options."); - macro = options.macro; +void MmaOp::setMacro(MmaMacro macro) { + NVF_ERROR(macro != MmaMacro::NoMMA, "Unspecified mma type"); + attribute(ATTR_POS_MACRO) = macro; } NVFUSER_DEFINE_CLONE_AND_CREATE(MmaOp) @@ -2151,6 +2252,17 @@ std::string ExpandOp::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector ExpandOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto& in = inputs.at(0).as(); + std::vector expanded_size; + for (auto i : c10::irange(1, inputs.size())) { + expanded_size.push_back((int64_t)inputs.at(i)); + } + return {at::expand_copy(in, expanded_size)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp) ShiftOp::ShiftOp( @@ -2300,6 +2412,13 @@ std::string ViewAsScalar::toInlineString(int indent_size) const { NVF_CHECK(false, "Tensor op can not be printed inline"); } +std::vector ViewAsScalar::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const at::Tensor& in = inputs.at(0).as(); + return {at::view_as_real(in)}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(ViewAsScalar) ViewOp::ViewOp(IrBuilderPasskey passkey, Val* out, Val* in) : Expr(passkey) { @@ -2725,7 +2844,10 @@ std::vector IterDomain::clone( // domains is enforced by predicates. Note that since only root // domains have valid start and stop, it's not possible to contiguous // predication. -IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { +IterDomain* IterDomain::merge( + IterDomain* outer, + IterDomain* inner, + bool rfactor_domain) { NVF_CHECK( outer->isReduction() == inner->isReduction(), "Merging IterDomains requires that their iteration types match. ", @@ -2786,6 +2908,7 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { .parallel_type(outer->getParallelType()) .expanded_extent(expanded_extent) .iter_type(itype) + .is_rfactor_domain(rfactor_domain) .build(); IrBuilder::create(outer->container(), merged_id, outer, inner); @@ -2801,7 +2924,8 @@ std::pair IterDomain::split( Val* factor, bool inner_split, Val* start_offset, - Val* stop_offset) { + Val* stop_offset, + bool rfactor_domain) { NVF_CHECK( factor->isIntegralScalar(), "Cannot split by non-integer value ", factor); @@ -2829,6 +2953,7 @@ std::pair IterDomain::split( : nullptr) .parallel_type(in->getParallelType()) .iter_type(in->getIterType()) + .is_rfactor_domain(rfactor_domain) .build(); // inner loop IterDomain @@ -2840,6 +2965,7 @@ std::pair IterDomain::split( : nullptr) .parallel_type(in->getParallelType()) .iter_type(in->getIterType()) + .is_rfactor_domain(rfactor_domain) .build(); IrBuilder::create( @@ -2858,10 +2984,12 @@ std::pair IterDomain::split( IterDomain* in, Val* factor, bool inner_split, - bool trim_out_of_bounds) { + bool trim_out_of_bounds, + bool rfactor_domain) { auto start_offset = trim_out_of_bounds ? in->start() : nullptr; auto stop_offset = trim_out_of_bounds ? in->stopOffset() : nullptr; - return IterDomain::split(in, factor, inner_split, start_offset, stop_offset); + return IterDomain::split( + in, factor, inner_split, start_offset, stop_offset, rfactor_domain); } std::pair IterDomain::stridedSplit(int64_t factor) { @@ -2996,7 +3124,11 @@ IterDomain* IterDomain::resize( } auto resized_id = - IterDomainBuilder(in->container()->zeroVal(), resized_id_size) + IterDomainBuilder( + in->container()->zeroVal(), + // Set immediate constant size of 1 if resize produces broadcast + iter_type == IterType::Broadcast ? in->fusion()->oneVal() + : resized_id_size) .is_rfactor_domain(mark_as_rfactor) .iter_type(iter_type) .build(); diff --git a/csrc/ir/printer.h b/csrc/ir/printer.h index edb0a5d3c35..4e858c99aa0 100644 --- a/csrc/ir/printer.h +++ b/csrc/ir/printer.h @@ -53,7 +53,7 @@ class IrTransformPrinter : public IrPrinter { void handle(Fusion* f) override; - void printTransforms(TensorView* tv); + void printTransforms(const TensorView* tv); }; } // namespace nvfuser diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 2fafa1fe25e..0c3be5a2c7e 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -469,8 +469,9 @@ void BackwardVisitor::traverseTo( { size_t pos = 0; - for (auto expr : exprs) + for (auto expr : exprs) { traversal_exprs_[expr] = pos++; + } } // All stmts we've called handle on @@ -679,8 +680,9 @@ class DependentVals : public IterVisitor { std::unordered_set boundary_; std::vector next(Val* v) override { - if (boundary_.find(v) != boundary_.end()) + if (boundary_.find(v) != boundary_.end()) { return std::vector(); + } return IterVisitor::next(v); } @@ -1046,6 +1048,9 @@ void DeadCodeRemover::handle(TensorView* tv) { } bool DeadCodeRemover::registerReplacement(Val* old_val, Val* new_val) { + // Mark new val live + markLiveRecursive(new_val); + vals_to_replace_.emplace_back(old_val, new_val); if (old_val->isFusionInput()) { @@ -1102,9 +1107,9 @@ void DeadCodeRemover::markLiveRecursive(Statement* stmt) { return; } markLive(stmt); - if (stmt->isVal() && stmt->asVal()->definition()) { - markLiveRecursive(stmt); - } else { + if (stmt->isVal() && stmt->asVal()->definition() != nullptr) { + markLiveRecursive(stmt->asVal()->definition()); + } else if (stmt->isExpr()) { auto expr = stmt->asExpr(); for (const auto inp : expr->outputs()) { markLive(inp); @@ -1116,7 +1121,20 @@ void DeadCodeRemover::markLiveRecursive(Statement* stmt) { } bool DeadCodeRemover::markDead(Statement* stmt) { - return (bool)live_statements_.erase(stmt); + if (auto e = dynamic_cast(stmt)) { + // If this is an expression, ensure it is not marked as a future live use + // of any of its inputs + for (Val* inp : e->inputs()) { + if (std::find(inp->uses().begin(), inp->uses().end(), e) == + inp->uses().end()) { + auto fu_it = future_uses_.find(inp); + if (fu_it != future_uses_.end()) { + fu_it->second.erase(e); + } + } + } + } + return live_statements_.erase(stmt); } bool DeadCodeRemover::modifyFusion() const { diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 910ccca3cff..2a0e5b92188 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -462,6 +462,14 @@ class DeadCodeRemover : BackwardVisitor { //! Check whether all uses have been marked dead inline bool allUsesDead(Val* val) const { + auto fu_it = future_uses_.find(val); + if (fu_it != future_uses_.end() && !fu_it->second.empty()) { + // Regardless of whether current uses are marked dead, this appears in a + // replacement expression, so it has a future live use and we should keep + // it. + return false; + } + return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) { return isDead(use); }); @@ -480,6 +488,21 @@ class DeadCodeRemover : BackwardVisitor { //! Mark a single Statement as being alive. inline void markLive(Statement* stmt) { live_statements_.insert(stmt); + if (auto e = dynamic_cast(stmt)) { + // Check if this expression is already in uses() for each of its inputs + // and if not, record it in future_uses_ + for (Val* inp : e->inputs()) { + if (std::find(inp->uses().begin(), inp->uses().end(), e) == + inp->uses().end()) { + auto fu_it = future_uses_.find(inp); + if (fu_it == future_uses_.end()) { + future_uses_.emplace(inp, std::unordered_set({e})); + } else { + fu_it->second.insert(e); + } + } + } + } } //! Ensure that a Statement and its upstream Statements are alive. If it is an @@ -529,6 +552,13 @@ class DeadCodeRemover : BackwardVisitor { //! them separately here. std::vector vals_to_remove_; std::vector exprs_to_remove_; + + //! This holds additional _future_ uses of each val. val->uses() only returns + //! currently live uses, so until we have finalized all replacements, new uses + //! will not appear there. The mapping below gets populated whenever we mark + //! an expression as live, if that expression is not already in inp->uses() + //! for any of its inputs. + std::unordered_map> future_uses_; }; } // namespace nvfuser diff --git a/csrc/kernel.cpp b/csrc/kernel.cpp index 89b7ded5066..99214dc6f6b 100644 --- a/csrc/kernel.cpp +++ b/csrc/kernel.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -295,6 +296,15 @@ class ValidateAllocation : private OptOutConstDispatch { } // namespace +Kernel::Kernel(Fusion* fusion, PrimDataType index_type) + : Fusion(*fusion), index_type_(index_type) { + // Index type must be resolved to either int32 or int64 + NVF_ERROR( + index_type_ == PrimDataType::Int || index_type_ == PrimDataType::Int32 || + "Invalid index type: ", + index_type_); +} + // TODO(kir): Kernel IR validation void Kernel::finalize(std::vector top_level_exprs) { NVF_ERROR(top_level_exprs_.empty()); diff --git a/csrc/kernel.h b/csrc/kernel.h index f93cf69c600..0a41601a1c1 100644 --- a/csrc/kernel.h +++ b/csrc/kernel.h @@ -175,16 +175,7 @@ class Kernel final : public Fusion { // we do something like generate an initialization statement for a reduction // TV, we may want to continue to do fusion like analysis on the original // expression. - // TODO: Assert index type is int or int32 - Kernel(Fusion* fusion, PrimDataType index_type = PrimDataType::Int) - : Fusion(*fusion), index_type_(index_type) { - // Index type must be resolved to either int32 or int64 - NVF_ERROR( - index_type_ == PrimDataType::Int || - index_type_ == PrimDataType::Int32 || "Invalid index type: ", - index_type_); - } - + Kernel(Fusion* fusion, PrimDataType index_type = PrimDataType::Int); Kernel() = delete; // No move or copy semantics diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index c1f17ad7862..97f640ce32f 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -518,18 +518,12 @@ std::vector FusionExecutorCache::runFusionWithInputs( // Removing aliased outputs, since those are updated by the Fusion. It is not // semantically correct to actually return them as outputs from // fusion. - const auto& io_alias = fusion->ioAlias(); - auto should_remove = [&io_alias](Val* out_val) -> bool { - if (auto alias_it = io_alias.find(out_val); alias_it != io_alias.end()) { - return alias_it->second.second.hide_output; - } - return false; - }; - NVF_ERROR(fusion->outputs().size() == outputs.size()); size_t new_size = 0; for (size_t out_index = 0; out_index < outputs.size(); out_index++) { - if (!should_remove(fusion->outputs()[out_index])) { + const AliasInfo* alias_info = + fusion->getOutputAlias(fusion->outputs()[out_index]).second; + if (alias_info == nullptr || !alias_info->hide_output) { outputs[new_size] = outputs[out_index]; new_size++; } diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index 71a5b932cee..4081f3c181f 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -242,7 +242,33 @@ Asm::Asm( namespace { -const char* dataTypeToPTXConstraints(DataType dt) { +// If value is a kir::TensorIndex, and its index is a pointer type, then +// return the pointer type. Otherwise return the value's dtype. +DataType getTypeOrIndexType(Val* value) { + if (auto ti = dynamic_cast(value)) { + if (isPointerType(ti->index()->dtype())) { + return ti->index()->dtype(); + } + } + return value->dtype(); +} + +const char* getPTXConstraints(Val* value) { + DataType dt = getTypeOrIndexType(value); + if (dt == DataType::Bool) { + return "r"; + } + if (auto ti = dynamic_cast(value)) { + // If the index type is a pointer type, then we directly uses the pointer in + // the generated code, instead of generating something like T0[i]. For this + // case we should use the pointer type as the constraint. + if (isPointerType(ti->index()->dtype())) { + dt = ti->index()->dtype(); + } + } + if (std::holds_alternative(dt.type)) { + dt = *std::get(dt.type).type; + } auto size = dataTypeSize(dt); switch (size) { case 2: @@ -272,7 +298,7 @@ std::vector> Asm::constraintsAndOutputs() const { std::string prefix = "="; for (auto out : outputs()) { NVF_ERROR(!out->isConst()); - result.emplace_back(prefix + dataTypeToPTXConstraints(out->dtype()), out); + result.emplace_back(prefix + getPTXConstraints(out), out); } return result; } @@ -283,13 +309,51 @@ std::vector> Asm::constraintsAndInputs() const { if (in->isConst()) { constraint = "n"; } else { - constraint = dataTypeToPTXConstraints(in->dtype()); + constraint = getPTXConstraints(in); } result.emplace_back(constraint, in); } return result; } +std::string Asm::parameters() const { + int64_t counter = 0; + int64_t bool_counter = 0; + std::stringstream ss; + auto gen = [&counter, &bool_counter, &ss](Val* v) { + DataType dtype = getTypeOrIndexType(v); + if (counter > 0) { + ss << ", "; + } + if (isPointerType(dtype)) { + ss << "[%" << counter++ << "]"; + } else if (dtype == DataType::Bool) { + ss << "p" << bool_counter++; + } else if (std::holds_alternative(dtype.type)) { + ss << "%" << counter++; + } else if (std::holds_alternative(dtype.type)) { + auto type = std::get(dtype.type); + ss << "{"; + for (auto i : c10::irange(type.size)) { + if (i > 0) { + ss << ", "; + } + ss << "%" << counter++; + } + ss << "}"; + } else { + NVF_ERROR(false, "Unsupported data type ", dtype); + } + }; + for (auto out : outputs()) { + gen(out); + } + for (auto in : inputs()) { + gen(in); + } + return ss.str(); +} + std::string Asm::toString(int indent_size) const { std::stringstream ss; indent(ss, indent_size) << "asm"; @@ -785,8 +849,8 @@ bool ForLoop::isUnrollable() const { // dimension, cannot be bound to a parallel dimension, must not be // vectorized. return start()->isConstScalar() && stop()->isConstScalar() && - !iter_domain()->isThread() && !iter_domain()->isBroadcast() && - !vectorize(); + !iter_domain()->isThread() && !iter_domain()->isDeviceDim() && + !iter_domain()->isBroadcast() && !vectorize(); } bool ForLoop::isUnrolled() const { @@ -860,7 +924,7 @@ bool ForLoop::isTrivial() const { // These loops are not materialized if (vectorize() || iter_domain()->isBroadcast() || iter_domain()->isStride() || iter_domain()->isMma() || - iter_domain()->isBulk()) { + iter_domain()->isBulk() || iter_domain()->isDeviceDim()) { return true; } diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index e9224e93e76..f3728b73ace 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -229,8 +229,19 @@ class Asm final : public Expr { return options().memory; } + bool hasBooleanInput() const { + for (auto input : inputs()) { + if (input->dtype() == DataType::Bool) { + return true; + } + } + return false; + } + std::vector> constraintsAndOutputs() const; std::vector> constraintsAndInputs() const; + + std::string parameters() const; }; //! Allocate is a lower level Node that describes a buffer of memory that diff --git a/csrc/linked_hash_map.h b/csrc/linked_hash_map.h index 093dfafc7e1..890b7063e85 100644 --- a/csrc/linked_hash_map.h +++ b/csrc/linked_hash_map.h @@ -25,6 +25,7 @@ class LinkedHashMap { public: using value_type = std::pair; using const_iterator = typename std::list::const_iterator; + using iterator = typename std::list::iterator; LinkedHashMap() = default; LinkedHashMap(const LinkedHashMap&) = delete; @@ -32,11 +33,15 @@ class LinkedHashMap { LinkedHashMap(LinkedHashMap&&) = default; LinkedHashMap& operator=(LinkedHashMap&&) = default; - std::pair erase(const K& key); + // Returns the value associated with `key` and the list iterator following the + // removed element. + std::pair erase(const K& key); void insert(const_iterator i, const K& key, const V& value); + void insert(const_iterator i, const K& key, V&& value); void pushBack(const K& key, const V& value); + void pushBack(const K& key, V&& value); const_iterator begin() const { return order_.begin(); @@ -47,15 +52,16 @@ class LinkedHashMap { private: std::list order_; - std::unordered_map key_to_index_; + std::unordered_map key_to_index_; }; template -std::pair::const_iterator> LinkedHashMap:: - erase(const K& key) { - const_iterator index = key_to_index_.at(key); +std::pair::iterator> LinkedHashMap::erase( + const K& key) { + iterator index = key_to_index_.at(key); key_to_index_.erase(key); - return {index->second, order_.erase(index)}; + V value = std::move(index->second); + return {std::move(value), order_.erase(index)}; } template @@ -63,8 +69,18 @@ void LinkedHashMap::insert( LinkedHashMap::const_iterator i, const K& key, const V& value) { - bool inserted = - key_to_index_.emplace(key, order_.insert(i, {key, value})).second; + auto j = order_.emplace(i, key, value); + bool inserted = key_to_index_.emplace(key, j).second; + NVF_CHECK(inserted, "Key already existed"); +} + +template +void LinkedHashMap::insert( + LinkedHashMap::const_iterator i, + const K& key, + V&& value) { + auto j = order_.emplace(i, key, std::move(value)); + bool inserted = key_to_index_.emplace(key, j).second; NVF_CHECK(inserted, "Key already existed"); } @@ -73,4 +89,9 @@ void LinkedHashMap::pushBack(const K& key, const V& value) { insert(order_.end(), key, value); } +template +void LinkedHashMap::pushBack(const K& key, V&& value) { + insert(order_.end(), key, std::move(value)); +} + } // namespace nvfuser diff --git a/csrc/macros.h b/csrc/macros.h index 375e03c8567..b845cf18323 100644 --- a/csrc/macros.h +++ b/csrc/macros.h @@ -15,3 +15,9 @@ #if defined(__GLIBCXX__) && __GLIBCXX__ >= 20230714 #define STD_UNORDERED_SET_SUPPORTS_INCOMPLETE_TYPE 1 #endif + +#if __cplusplus < 202002L +#define IS_CPP20 0 +#else +#define IS_CPP20 1 +#endif diff --git a/csrc/mma_type.cpp b/csrc/mma_type.cpp index 76861c776b0..ebc9c1bb62f 100644 --- a/csrc/mma_type.cpp +++ b/csrc/mma_type.cpp @@ -12,145 +12,23 @@ namespace nvfuser { -MmaOp* MmaOptions::mmaOp() const { - NVF_ERROR( - accumulator_tv != nullptr && accumulator_tv->definition() != nullptr, - "Invalid accumulator_tv."); - auto mma_op = dynamic_cast(accumulator_tv->definition()); - NVF_ERROR(mma_op != nullptr, "accumulator tv not an output of mma op"); - return mma_op; +GemmTile getMmaOpShape(MmaMacro macro) { + return {getM(macro), getN(macro), getK(macro)}; } -MmaBuilder::MmaBuilder( - MmaOptions::MacroType macro, - MatMulTileOptions gemm_tile) { - option_.macro = macro; -} - -MmaBuilder& MmaBuilder::layout(MmaOptions::MmaLayout layout) { - option_.layout = layout; - return *this; -} - -MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) { - option_.operand = a_or_b; - return *this; -} - -// TODO: validate op config -MmaOptions MmaBuilder::build() const { - NVF_CHECK( - option_.accumulator_tv != nullptr, - "Please configure accumulator tv before using swizzle options.") - return option_; -} - -void MmaBuilder::configureMma(MmaOp* mma) const { - NVF_CHECK(mma, "configureMma: invalid op object ", mma); - mma->configureOptions(option_); -} - -void MmaBuilder::accumulatorTv(TensorView* tv) { - NVF_CHECK( - tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register"); - NVF_CHECK(tv->definition(), "Input cannot be accumulator tv"); - NVF_CHECK( - tv->definition()->isA(), - "Requires mma op output for reduction tv"); - option_.accumulator_tv = tv; -} - -namespace { - -// Utility to get ldmatrix direction a mma layout and operand -LoadStoreOpType getLdMatrixType(MmaOptions options) { - bool transpose = false; - switch (options.macro) { - case MmaOptions::MacroType::Turing_16_8_16: - case MmaOptions::MacroType::Ampere_16_8_16: - case MmaOptions::MacroType::Ampere_16_16_16: - case MmaOptions::MacroType::Turing_16_16_16: - // Turing mma assumes TN as default - transpose = (options.operand == MmaOptions::Operand::A && - !isOperandTransposed(options)) || - (options.operand == MmaOptions::Operand::B && - isOperandTransposed(options)); - break; - default: - NVF_ERROR(false, "unsupported op with ldmatrix"); - break; - } - return transpose ? LoadStoreOpType::LdMatrixTranspose - : LoadStoreOpType::LdMatrix; -} - -} // namespace - -LoadStoreOpType MmaBuilder::ldMatrix() const { - return getLdMatrixType(option_); -} - -bool isVolta(MmaOptions::MacroType macro) { - return macro == MmaOptions::MacroType::Volta_16_16_4; -} - -bool isTuring(MmaOptions::MacroType macro) { - return macro == MmaOptions::MacroType::Turing_16_8_16 || - macro == MmaOptions::MacroType::Turing_16_16_16; -} - -bool isAmpere(MmaOptions::MacroType macro) { - return macro == MmaOptions::MacroType::Ampere_16_8_8 || - macro == MmaOptions::MacroType::Ampere_16_8_16 || - macro == MmaOptions::MacroType::Ampere_16_16_16; -} - -bool isOperandTransposed(MmaOptions options) { - switch (options.operand) { - case MmaOptions::Operand::A: - return options.layout == MmaOptions::MmaLayout::TT || - options.layout == MmaOptions::MmaLayout::TN; - case MmaOptions::Operand::B: - return options.layout == MmaOptions::MmaLayout::TT || - options.layout == MmaOptions::MmaLayout::NT; - default: - NVF_CHECK(false, "isOperandTransposed: please specify operand"); - } - return false; -} - -GemmTile getMmaOpShape(MmaOptions::MacroType macro) { - switch (macro) { - case MmaOptions::MacroType::Volta_16_16_4: - return {16, 16, 4}; - case MmaOptions::MacroType::Turing_16_8_16: - case MmaOptions::MacroType::Ampere_16_8_16: - return {16, 8, 16}; - case MmaOptions::MacroType::Turing_16_16_16: - case MmaOptions::MacroType::Ampere_16_16_16: - return {16, 16, 16}; - case MmaOptions::MacroType::Ampere_16_8_8: - return {16, 8, 8}; - case MmaOptions::MacroType::NoMMA: - return {1, 1, 1}; - } - - NVF_ERROR(false, "unknown MMA macro"); -} - -std::string toString(MmaOptions::MmaLayout input_layout) { +std::string toString(MmaLayout input_layout) { std::stringstream ss; switch (input_layout) { - case MmaOptions::MmaLayout::TT: + case MmaLayout::TT: ss << "TT"; break; - case MmaOptions::MmaLayout::TN: + case MmaLayout::TN: ss << "TN"; break; - case MmaOptions::MmaLayout::NT: + case MmaLayout::NT: ss << "NT"; break; - case MmaOptions::MmaLayout::NN: + case MmaLayout::NN: ss << "NN"; break; default: @@ -159,30 +37,6 @@ std::string toString(MmaOptions::MmaLayout input_layout) { return ss.str(); } -std::string toString(MmaOptions::MacroType mt) { - std::stringstream ss; - switch (mt) { - case MmaOptions::MacroType::NoMMA: - ss << "NoOp"; - break; - case MmaOptions::MacroType::Volta_16_16_4: - ss << "M16N16K4"; - break; - case MmaOptions::MacroType::Turing_16_8_16: - case MmaOptions::MacroType::Ampere_16_8_16: - ss << "M16N8K16"; - break; - case MmaOptions::MacroType::Turing_16_16_16: - case MmaOptions::MacroType::Ampere_16_16_16: - ss << "M16N16K16"; - break; - default: - NVF_ERROR(false, "undefined mma type"); - break; - } - return ss.str(); -} - std::string toString(const GemmTile& tile) { std::stringstream ss; ss << "[" << tile.m << ", " << tile.n << ", " << tile.k << "]"; @@ -198,32 +52,34 @@ std::string toString(const MatMulTileOptions& opts) { return ss.str(); } -std::string toString(MmaOptions::MacroType mt, bool) { - switch (mt) { - case MmaOptions::MacroType::Ampere_16_8_8: - return "Ampere_16_8_8"; - case MmaOptions::MacroType::Ampere_16_8_16: - return "Ampere_16_8_16"; - case MmaOptions::MacroType::Ampere_16_16_16: - return "Ampere_16_16_16"; - case MmaOptions::MacroType::NoMMA: +std::string toString(MmaMacro macro) { + std::stringstream ss; + auto underlying = static_cast(macro); + switch (underlying.arch) { + case MmaMacroEncode::Arch::NoMma: return "NoOp"; - case MmaOptions::MacroType::Turing_16_8_16: - return "Turing_16_8_16"; - case MmaOptions::MacroType::Turing_16_16_16: - return "Turing_16_16_16"; - case MmaOptions::MacroType::Volta_16_16_4: - return "Volta_16_16_4"; + case MmaMacroEncode::Arch::Volta: + ss << "Volta"; + break; + case MmaMacroEncode::Arch::Turing: + ss << "Turing"; + break; + case MmaMacroEncode::Arch::Ampere: + ss << "Ampere"; + break; + case MmaMacroEncode::Arch::Hopper: + ss << "Hopper"; + break; } - NVF_ERROR(false, "Unsupported mma type"); - return "Unsupported"; + ss << "_" << underlying.m << "_" << underlying.n << "_" << underlying.k; + return ss.str(); } -size_t hash(MmaOptions::MacroType macro) { +size_t hash(MmaMacro macro) { return std::hash{}(static_cast(macro)); } -size_t hash(MmaOptions::MmaLayout input_layout) { +size_t hash(MmaLayout input_layout) { return std::hash{}(static_cast(input_layout)); } diff --git a/csrc/mma_type.h b/csrc/mma_type.h index 0f75297089b..bee5c54b7b2 100644 --- a/csrc/mma_type.h +++ b/csrc/mma_type.h @@ -6,10 +6,20 @@ */ // clang-format on #pragma once + +#include + #include #include #include +#include + +#if IS_CPP20 +#include +#endif +#include + namespace nvfuser { constexpr std::string_view MATMUL_LOG_PREFIX = "[MATMUL DEBUG] "; @@ -20,14 +30,16 @@ enum class MatmulDomain { M = 0, N, K }; //! Named descriptors of TensorView roles in fusion //! INPUT_A - a producer of MMA input A //! INPUT_B - a producer of MMA input B -//! OUTPUT_D - the main consumer of MMA op results //! INPUT_C - a producer of a tensor used in fusion epilogue, //! for example tensor used in beta scaling fusion +//! OUTPUT_D - the main consumer of MMA op results +//! OUTPUT_AUX - fusion outputs that are consumers of OUTPUT_D //! //! Naming convention is based on the following formula: //! D = alpha * A x B + beta * C +//! AUX = relu(D) //! Note: bias vector tensors will be assigned to INPUT_C role. -enum class MatmulRole { INPUT_A = 0, INPUT_B, OUTPUT_D, INPUT_C }; +enum class MatmulRole { INPUT_A = 0, INPUT_B, OUTPUT_D, INPUT_C, OUTPUT_AUX }; //! The expected number of occurances of core TensorView roles in fusion static constexpr size_t MATMUL_CORE_ROLES_EXPECTED_COUNT = 1; @@ -71,149 +83,175 @@ struct MatMulTileOptions { } }; -//! Information for configuring and lowering mma ops -struct MmaOptions { - //! Type of mma instrinsic macro to use - //! This will translate to which mma intrinsic from runtime string - //! to be generated to implement the mma op. The current plan - //! is to have exactly one macro for each - //! (arch, datatype, operand layout) triple, though there - //! exists multiple possibilities for some cases, e.g. for Turing and fp16 - //! one can use 16_8_8 or 16_8_16. - //! Will consider adding more choices that the scheduler can pick from - //! when our perf target becomes more fine grained, which is more likely in - //! latency bound kernels. - enum class MacroType { - NoMMA = 0, - Volta_16_16_4, - Ampere_16_8_16, - Ampere_16_16_16, - Turing_16_8_16, - Turing_16_16_16, - Ampere_16_8_8 // place holder for tf32 - }; - - //! [Operand Layout Convention] - //! Operand layout, T=transposed/row_major, N=normal/col_major - //! Ordered by position of K - //! NT : K,M x K,N -> M,N - //! TT : M,K X K,N -> M,N - //! TN : M,K X N,K -> M,N - //! NN : K,M X N,K -> M,N - //! TODO: NN is currently not supported on pre-Turing and Hopper wgmma - enum class MmaLayout { NT = 0, TT, TN, NN }; - - //! Utility to annotate which input of mma this option struct describes - enum class Operand { Accumulator = 0, A, B }; - - //! Utility to annotate which mma macro this config uses. - MacroType macro = MacroType::NoMMA; - - //! Utility to annotate transposition of operands - MmaLayout layout = MmaLayout::TT; - - //! Utility to annotate which input of mma this option struct describes - Operand operand = Operand::A; - - bool operator==(const MmaOptions& other) const { - return macro == other.macro && layout == other.layout && - operand == other.operand; +enum class MmaMacro : uint64_t; + +struct MmaMacroEncode { + enum class Arch { NoMma, Volta, Turing, Ampere, Hopper } arch : 16; + unsigned m : 16; + unsigned n : 16; + unsigned k : 16; + + constexpr operator uint64_t() { +#if IS_CPP20 && !defined(__clang__) + // std::bit_cast for bit field is not supported by clang yet + return std::bit_cast(*this); +#else + return (uint64_t)arch << 48 | (uint64_t)m << 32 | (uint64_t)n << 16 | + (uint64_t)k; +#endif } - // The accumulator tensorview register supplied by the - // scheduler interface. Each mma builder is responsible - // for the parameters of one mma op, so the options struct - // would need a pointer to keep track of which mma op it - // is describing. - // Tracking mma expressions would not be stable as the expression - // can get deleted by mutate passes. - TensorView* accumulator_tv = nullptr; - - //! Returns the mma op that this options parameter list - //! is describing. See comment on accumulator_tv. - MmaOp* mmaOp() const; + constexpr operator MmaMacro(); + + constexpr MmaMacroEncode(MmaMacro macro); + + constexpr MmaMacroEncode(Arch arch, unsigned m, unsigned n, unsigned k) + : arch(arch), m(m), n(n), k(k) {} }; -//! User interface for configuring the mma and mma related -//! operators by specifying the mma instruction tile type -//! input data layout, and the operand position of a tensor. -class MmaBuilder { - public: - //! Initialized a mma builder, for the given mma instruction type. - //! TODO: the mma implementation is generic and should not have - //! strong dependency on the actual matmul tiling shapes. The - //! MatMulTileOptions provided in here is a WAR for mma format and - //! should be removed once there is support for labeling swizzles - //! on iterdomains. - MmaBuilder(MmaOptions::MacroType macro, MatMulTileOptions gemm_tile); - - //! User configuration function: - //! Specifies the input matrix layout for the mma instruction. - //! see [Operand Layout Convention]. - MmaBuilder& layout(MmaOptions::MmaLayout layout); - - //! User configuration function: - //! Specifies which element in the mma op this builder is generating - //! parameters for, i.e. A or B. This is useful when generating - //! data swizzles for different elements of mma. - //! - Operand::Accumulator means the parameters describe accumulator in mma - //! op. - //! - This option is ignored when configuring the mma operator itself. - MmaBuilder& operand(MmaOptions::Operand a_or_b); - - //! Generates the matching ldmatrix instruction type for the - //! specified mma option. - LoadStoreOpType ldMatrix() const; - - //! Store the accumulator tv register reference in mma builder - //! to avoid automatic matching of which mma ops. - void accumulatorTv(TensorView* tv); - - //! Fill in mma options in scheduling time. - //! Each mma op in Fusion IR must be configured once before lowering. - //! Mma options are configuration parameters used in lowering to mma - //! instrinsics, mainly the type of mma macro to use and input data layout - //! etc. - //! - //! TODO: This step will very likely be removed in a follow up PR. All of - //! the options configured here could actually be inferred from fusion IR - //! once we are feature complete. - void configureMma(MmaOp* mma) const; - - //! Export all the parameters with user's configurations applied. - MmaOptions build() const; - - private: - MmaOptions option_; +static_assert(sizeof(MmaMacroEncode) == sizeof(uint64_t)); + +//! Type of mma instrinsic macro to use +//! This will translate to which mma intrinsic from runtime string +//! to be generated to implement the mma op. The current plan +//! is to have exactly one macro for each +//! (arch, datatype, operand layout) triple, though there +//! exists multiple possibilities for some cases, e.g. for Turing and fp16 +//! one can use 16_8_8 or 16_8_16. +//! Will consider adding more choices that the scheduler can pick from +//! when our perf target becomes more fine grained, which is more likely in +//! latency bound kernels. + +#define MACRO(arch, m, n, k) \ + arch##_##m##_##n##_##k = MmaMacroEncode(MmaMacroEncode::Arch::arch, m, n, k) + +enum class MmaMacro : uint64_t { + NoMMA = 0, + + MACRO(Turing, 16, 8, 8), + MACRO(Turing, 16, 8, 16), + MACRO(Turing, 16, 16, 16), + + MACRO(Ampere, 16, 8, 16), + MACRO(Ampere, 16, 16, 16), + + MACRO(Hopper, 64, 8, 16), + MACRO(Hopper, 64, 16, 16), + MACRO(Hopper, 64, 24, 16), + MACRO(Hopper, 64, 32, 16), + MACRO(Hopper, 64, 40, 16), + MACRO(Hopper, 64, 48, 16), + MACRO(Hopper, 64, 56, 16), + MACRO(Hopper, 64, 64, 16), + MACRO(Hopper, 64, 72, 16), + MACRO(Hopper, 64, 80, 16), + MACRO(Hopper, 64, 88, 16), + MACRO(Hopper, 64, 96, 16), + MACRO(Hopper, 64, 104, 16), + MACRO(Hopper, 64, 112, 16), + MACRO(Hopper, 64, 120, 16), + MACRO(Hopper, 64, 128, 16), + MACRO(Hopper, 64, 136, 16), + MACRO(Hopper, 64, 144, 16), + MACRO(Hopper, 64, 152, 16), + MACRO(Hopper, 64, 160, 16), + MACRO(Hopper, 64, 168, 16), + MACRO(Hopper, 64, 176, 16), + MACRO(Hopper, 64, 184, 16), + MACRO(Hopper, 64, 192, 16), + MACRO(Hopper, 64, 200, 16), + MACRO(Hopper, 64, 208, 16), + MACRO(Hopper, 64, 216, 16), + MACRO(Hopper, 64, 224, 16), + MACRO(Hopper, 64, 232, 16), + MACRO(Hopper, 64, 240, 16), + MACRO(Hopper, 64, 248, 16), + MACRO(Hopper, 64, 256, 16), }; +#undef MACRO + +constexpr MmaMacroEncode::operator MmaMacro() { +#if IS_CPP20 && !defined(__clang__) + // std::bit_cast for bit field is not supported by clang yet + return std::bit_cast(*this); +#else + return static_cast(static_cast(*this)); +#endif +} + +constexpr MmaMacroEncode::MmaMacroEncode(MmaMacro macro) +#if IS_CPP20 && !defined(__clang__) +{ + // std::bit_cast for bit field is not supported by clang yet + *this = std::bit_cast(macro); +} +#else + : arch((Arch)(toUnderlying(macro) >> 48)), + m((toUnderlying(macro) >> 32) & 0xFFFF), + n((toUnderlying(macro) >> 16) & 0xFFFF), + k(toUnderlying(macro) & 0xFFFF) { +} +#endif + +//! [Operand Layout Convention] +//! Operand layout, T=transposed/row_major, N=normal/col_major +//! Ordered by position of K +//! NT : K,M x K,N -> M,N +//! TT : M,K X K,N -> M,N +//! TN : M,K X N,K -> M,N +//! NN : K,M X N,K -> M,N +enum class MmaLayout { NT = 0, TT, TN, NN }; + +//! Utility to annotate which input of mma this option struct describes +enum class MmaOperand { Accumulator = 0, A, B }; + //! GPU arch check for macro type -bool isVolta(MmaOptions::MacroType macro); -bool isTuring(MmaOptions::MacroType macro); -bool isAmpere(MmaOptions::MacroType macro); +inline bool isTuring(MmaMacro macro) { + return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Turing; +} + +inline bool isAmpere(MmaMacro macro) { + return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Ampere; +} + +inline bool isHopper(MmaMacro macro) { + return MmaMacroEncode(macro).arch == MmaMacroEncode::Arch::Hopper; +} + +//! Get the m size from macro type +inline int getM(MmaMacro macro) { + return MmaMacroEncode(macro).m; +} + +//! Get the n size from macro type +inline int getN(MmaMacro macro) { + return MmaMacroEncode(macro).n; +} -//! Returns true if the given option describes a transposed operand -bool isOperandTransposed(MmaOptions options); +//! Get the k size from macro type +inline int getK(MmaMacro macro) { + return MmaMacroEncode(macro).k; +} // Unpacked constants from macro type: // exact numbers are defined by each individual instruction. -int getOutputRegisterSize(MmaOptions::MacroType macro); -int getInputARegisterSize(MmaOptions::MacroType macro); -int getInputBRegisterSize(MmaOptions::MacroType macro); +int getOutputRegisterSize(MmaMacro macro); +int getInputARegisterSize(MmaMacro macro); +int getInputBRegisterSize(MmaMacro macro); // Unpack MMA op shape -GemmTile getMmaOpShape(MmaOptions::MacroType macro); +GemmTile getMmaOpShape(MmaMacro macro); // MMA stringify utils -std::string toString(MmaOptions::MacroType macro); -std::string toString(MmaOptions::MmaLayout input_layout); +std::string toString(MmaLayout input_layout); std::string toString(const GemmTile& tile); std::string toString(const MatMulTileOptions& opts); -std::string toString(MmaOptions::MacroType macro, bool); +std::string toString(MmaMacro macro); // MMA hash utils -size_t hash(MmaOptions::MacroType macro); -size_t hash(MmaOptions::MmaLayout input_layout); +size_t hash(MmaMacro macro); +size_t hash(MmaLayout input_layout); size_t hash(const GemmTile& tile); size_t hash(const MatMulTileOptions& opts); } // namespace nvfuser diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index c37af5c3544..52ed4ab9fe2 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -6,6 +6,9 @@ */ // clang-format on #ifdef USE_DISTRIBUTED +#ifdef USE_C10D_NCCL +#include +#endif #include @@ -107,7 +110,9 @@ std::string Communication::toString(int indent) const { Broadcast::Broadcast(CommParams params) : Communication(params, "broadcast") {} -c10::intrusive_ptr Broadcast::post(Communicator& comm) { +c10::intrusive_ptr Broadcast::post( + Communicator& comm, + std::optional backend) { post_common(*this, comm); if (comm.deviceId() == params_.root) { @@ -126,7 +131,7 @@ c10::intrusive_ptr Broadcast::post(Communicator& comm) { return nullptr; } - return comm.getBackendForTeam(params_.team) + return comm.getBackendForTeam(params_.team, backend) ->broadcast( comm.deviceId() == params_.root ? params_.src_bufs : params_.dst_bufs, {.rootRank = root_relative_index_}); @@ -137,7 +142,9 @@ Gather::Gather(CommParams params) : Communication(params, "gather") { NVF_ERROR(params_.team.size() > 1, "the team size must be greater than 1"); } -c10::intrusive_ptr Gather::post(Communicator& comm) { +c10::intrusive_ptr Gather::post( + Communicator& comm, + std::optional backend) { post_common(*this, comm); // This is used to change the representation of the buffers to match c10d // ProcessGroup API @@ -149,7 +156,7 @@ c10::intrusive_ptr Gather::post(Communicator& comm) { assertBufferCount(params_.dst_bufs, 0); } auto work = - comm.getBackendForTeam(params_.team) + comm.getBackendForTeam(params_.team, backend) ->gather( buf_list, params_.src_bufs, {.rootRank = root_relative_index_}); if (comm.deviceId() == params_.root) { @@ -165,13 +172,15 @@ Allgather::Allgather(CommParams params) NVF_ERROR(params_.team.size() > 1, "the team size must be greater than 1"); } -c10::intrusive_ptr Allgather::post(Communicator& comm) { +c10::intrusive_ptr Allgather::post( + Communicator& comm, + std::optional backend) { post_common(*this, comm); // This is used to change the representation of the buffers to match c10d // ProcessGroup API std::vector> buf_list; buf_list = {std::move(params_.dst_bufs)}; - auto work = comm.getBackendForTeam(params_.team) + auto work = comm.getBackendForTeam(params_.team, backend) ->allgather(buf_list, params_.src_bufs, {}); params_.dst_bufs = std::move(buf_list.back()); return work; @@ -182,7 +191,9 @@ Scatter::Scatter(CommParams params) : Communication(params, "scatter") { NVF_ERROR(params_.team.size() > 1, "the team size must be greater than 1"); } -c10::intrusive_ptr Scatter::post(Communicator& comm) { +c10::intrusive_ptr Scatter::post( + Communicator& comm, + std::optional backend) { post_common(*this, comm); // This is used to change the representation of the buffers to match c10d // ProcessGroup API @@ -194,7 +205,7 @@ c10::intrusive_ptr Scatter::post(Communicator& comm) { assertBufferCount(params_.src_bufs, 0); } auto work = - comm.getBackendForTeam(params_.team) + comm.getBackendForTeam(params_.team, backend) ->scatter( params_.dst_bufs, buf_list, {.rootRank = root_relative_index_}); if (comm.deviceId() == params_.root) { @@ -203,13 +214,87 @@ c10::intrusive_ptr Scatter::post(Communicator& comm) { return work; } +Reduce::Reduce(CommParams params) : Communication(params, "reduce") { + assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); + assertBufferCount(params_.src_bufs, 1); + NVF_ERROR(params_.team.size() > 1, "the team size must be greater than 1"); +} + +c10::intrusive_ptr Reduce::post( + Communicator& comm, + std::optional backend) { + if (comm.deviceId() == params_.root) { + assertBufferCount(params_.dst_bufs, 1); + } else { + assertBufferCount(params_.dst_bufs, 0); + } + post_common(*this, comm); + auto& buf = + (comm.deviceId() == params_.root) ? params_.dst_bufs : params_.src_bufs; + c10d::ReduceOptions options = { + .reduceOp = params_.redOp, .rootRank = root_relative_index_}; + auto team_backend = comm.getBackendForTeam(params_.team, backend); +#ifdef USE_C10D_NCCL + if (backend == CommunicatorBackend::nccl) { + auto nccl_backend = + dynamic_cast(team_backend.get()); + return nccl_backend->_reduce_oop(buf, params_.src_bufs, options); + } +#endif + if (comm.deviceId() == params_.root) { + doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); + } + return team_backend->reduce(buf, options); +} + +Allreduce::Allreduce(CommParams params) + : Communication(params, "allreduce", false) { + assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); + assertBufferCount(params_.src_bufs, 1); + assertBufferCount(params_.dst_bufs, 1); + NVF_ERROR(params_.team.size() > 1, "the team size must be greater than 1"); +} + +c10::intrusive_ptr Allreduce::post( + Communicator& comm, + std::optional backend) { + post_common(*this, comm); + doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); + return comm.getBackendForTeam(params_.team, backend) + ->allreduce(params_.dst_bufs, {.reduceOp = params_.redOp}); +} + +ReduceScatter::ReduceScatter(CommParams params) + : Communication(params, "reduce_scatter", false) { + assertBufferCount(params_.src_bufs, params_.team.size()); + assertBufferCount(params_.dst_bufs, 1); + NVF_ERROR(params_.team.size() > 1, "the team size must be greater than 1"); +} + +c10::intrusive_ptr ReduceScatter::post( + Communicator& comm, + std::optional backend) { + post_common(*this, comm); + // This is used to change the representation of the buffers to match c10d + // ProcessGroup API + std::vector> buf_list = {std::move(params_.src_bufs)}; + auto work = comm.getBackendForTeam(params_.team, backend) + ->reduce_scatter( + params_.dst_bufs, buf_list, {.reduceOp = params_.redOp}); + params_.src_bufs = std::move(buf_list.back()); + return work; +} + SendRecv::SendRecv(CommParams params) : Communication(params, "send/recv") { + assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); NVF_ERROR( params_.team.size() == 1 || params_.team.size() == 2, "the team size should be 1 or 2"); } -c10::intrusive_ptr SendRecv::post(Communicator& comm) { +c10::intrusive_ptr SendRecv::post( + Communicator& comm, + std::optional backend) { post_common(*this, comm); if (comm.deviceId() == params_.root) { @@ -230,7 +315,8 @@ c10::intrusive_ptr SendRecv::post(Communicator& comm) { (params_.team.at(0) == params_.root) ? params_.team.at(1) : params_.team.at(0), params_.root, - params_.dst_bufs.empty() ? params_.src_bufs : params_.dst_bufs); + params_.dst_bufs.empty() ? params_.src_bufs : params_.dst_bufs, + backend); } } // namespace nvfuser diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index edcbeec3ded..505ba4dcc05 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -10,6 +10,8 @@ #include #include +#include +#include namespace nvfuser { @@ -22,6 +24,7 @@ struct CommParams { std::vector src_bufs; std::vector dst_bufs; Team team; // should not have duplicate + c10d::ReduceOp::RedOpType redOp = c10d::ReduceOp::RedOpType::UNUSED; }; /* @@ -64,7 +67,9 @@ class Communication { // Triggers the execution of the communication. This is a non-blocking call. // The communication can be posted multiple times - virtual c10::intrusive_ptr post(Communicator& comm) = 0; + virtual c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) = 0; protected: // argument "name" is only used for printing @@ -95,7 +100,9 @@ Copies the root's src buffer to each device's dst buffer class Broadcast : public Communication { public: Broadcast(CommParams params); - c10::intrusive_ptr post(Communicator& comm) override; + c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) override; }; /* @@ -112,7 +119,9 @@ root's buffers. class Gather : public Communication { public: Gather(CommParams params); - c10::intrusive_ptr post(Communicator& comm) override; + c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) override; }; /* @@ -127,7 +136,9 @@ buffers class Allgather : public Communication { public: Allgather(CommParams params); - c10::intrusive_ptr post(Communicator& comm) override; + c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) override; }; /* @@ -143,7 +154,56 @@ The order of the buffers matches the order of the receiver devices class Scatter : public Communication { public: Scatter(CommParams params); - c10::intrusive_ptr post(Communicator& comm) override; + c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) override; +}; + +/* +Reduce the src buffers to the root's dst buffer. + +Requirements: + - the root is set and belongs to the team + - the root has one src buffers and one dst buffer + - non-roots have one src buffer and no dst buffer + - all buffers have the same size +*/ +class Reduce : public Communication { + public: + Reduce(CommParams params); + c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) override; +}; + +/* +Reduce the src buffers to the dst buffer. + +Requirements: + - all devices have one src buffer and one dst buffer + - all buffers have the same size +*/ +class Allreduce : public Communication { + public: + Allreduce(CommParams params); + c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) override; +}; + +/* +Reduce all the src buffers and shard the result to the dst buffers. + +Requirements: + - all devices have src buffer and one dst buffer + - all buffers have the same size +*/ +class ReduceScatter : public Communication { + public: + ReduceScatter(CommParams params); + c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) override; }; /* @@ -164,7 +224,9 @@ buffer class SendRecv : public Communication { public: SendRecv(CommParams params); - c10::intrusive_ptr post(Communicator& comm) override; + c10::intrusive_ptr post( + Communicator& comm, + std::optional backend = std::nullopt) override; }; } // namespace nvfuser diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index 56e55d92c56..5d5ed9e619e 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -101,11 +101,13 @@ bool parseEnv( return true; } -inline std::string getTeamKey(const Team& team) { +inline std::string getTeamKey(const Team& team, CommunicatorBackend backend) { + std::string backend_str = + (backend == CommunicatorBackend::ucc) ? "ucc" : "nccl"; return std::accumulate( std::begin(team), std::end(team), - std::string{}, + std::string{backend_str}, [](const std::string& a, const RankType& b) { return a.empty() ? std::to_string(b) : a + ',' + std::to_string(b); }); @@ -114,7 +116,7 @@ inline std::string getTeamKey(const Team& team) { // creates and return a process group backend c10::intrusive_ptr createBackend( CommunicatorBackend backend, - ::c10::intrusive_ptr store, + c10::intrusive_ptr store, RankType rank, int64_t size) { #ifdef USE_C10D_NCCL @@ -135,7 +137,9 @@ c10::intrusive_ptr createBackend( #if defined(USE_C10D_UCC) && defined(NVFUSER_BUILD_WITH_UCC) if (backend == CommunicatorBackend::ucc) { - return c10::make_intrusive<::c10d::ProcessGroupUCC>(store, rank, size); + constexpr auto timeout = std::chrono::milliseconds(30 * 60 * 1000); + return c10d::ProcessGroupUCC::createProcessGroupUCC( + store, rank, size, timeout); } #endif NVF_CHECK(false, "no distributed backend available"); @@ -145,12 +149,14 @@ Communicator::Communicator( CommunicatorBackend backend, RankType server_local_rank) : is_available_(false), - backend_type_(backend), + default_backend_(backend), rank_(0), size_(0), local_rank_(0), local_size_(0), - master_port_(0) { + master_port_(0), + ucc_available_(false), + nccl_available_(false) { // retrieves rank and communicator size is_available_ = parseEnv( rank_, size_, local_rank_, local_size_, master_addr_, master_port_); @@ -173,15 +179,20 @@ Communicator::Communicator( store_opts.port = master_port_ ? master_port_ : comm_master_port_default; store_ = c10::make_intrusive(master_addr_, store_opts); - // creates the world's backend - std::vector all_ranks(size_); - std::iota(all_ranks.begin(), all_ranks.end(), 0); - world_ = getBackendForTeam(all_ranks); +#if defined(USE_C10D_UCC) && defined(NVFUSER_BUILD_WITH_UCC) + ucc_available_ = true; +#endif + +#ifdef USE_C10D_NCCL + nccl_available_ = true; +#endif } c10::intrusive_ptr Communicator::getBackendForTeam( - const Team& team) { - std::string team_key = getTeamKey(team); + const Team& team, + std::optional backend) { + CommunicatorBackend b = getBackend(backend); + std::string team_key = getTeamKey(team, b); // check if backend associated with the team is present in the cache if (backends_.find(team_key) == backends_.end()) { // create the backend and cache it @@ -195,7 +206,7 @@ c10::intrusive_ptr Communicator::getBackendForTeam( // generate a string key which is unique to the team // create the team and cache it backends_[team_key] = createBackend( - backend_type_, + b, c10::make_intrusive(team_key, store_), team_rank, static_cast(team.size())); @@ -207,15 +218,26 @@ c10::intrusive_ptr Communicator::sendRecv( DeviceIdxType receiver, DeviceIdxType sender, std::vector& tensors, + std::optional backend, int tag) { NVF_ERROR( deviceId() == sender || deviceId() == receiver, "only sender or receiver should post the sendRecv"); NVF_ERROR(sender != receiver, "cannot send to self"); + + auto world = getWorld(backend); if (deviceId() == sender) { - return world_->send(tensors, static_cast(dIdToRank(receiver)), tag); + return world->send(tensors, static_cast(dIdToRank(receiver)), tag); } - return world_->recv(tensors, static_cast(dIdToRank(sender)), tag); + return world->recv(tensors, static_cast(dIdToRank(sender)), tag); +} + +c10::intrusive_ptr Communicator::getWorld( + std::optional backend) { + std::vector all_ranks(size_); + std::iota(all_ranks.begin(), all_ranks.end(), 0); + + return getBackendForTeam(all_ranks, backend); } } // namespace nvfuser diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index 54015e98cdb..a042722cd53 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -31,10 +31,14 @@ namespace nvfuser { using RankType = DeviceIdxType; -// Supported backends. TODO: only tested with nccl for now +// Supported backends. TODO: gloo untested enum class CommunicatorBackend { nccl, ucc, gloo }; +#ifdef USE_C10D_NCCL constexpr CommunicatorBackend comm_backend_default = CommunicatorBackend::nccl; +#else +constexpr CommunicatorBackend comm_backend_default = CommunicatorBackend::ucc; +#endif constexpr int comm_server_local_rank_default = 0; constexpr int comm_master_port_default = c10d::TCPStoreOptions::kDefaultPort; // 29500 @@ -63,20 +67,28 @@ class Communicator { return local_size_; } + // sets the communicator's default backend + void setDefaultBackend(CommunicatorBackend backend) { + default_backend_ = backend; + } + // performs a send/receive p2p data transfer c10::intrusive_ptr sendRecv( DeviceIdxType receiver, DeviceIdxType sender, std::vector& tensor, + std::optional backend = std::nullopt, int tag = 0); // performs a blocking barrier in the communicator - void barrier() const { - world_->barrier()->wait(); + void barrier(std::optional backend = std::nullopt) { + getWorld(backend)->barrier()->wait(); } // returns the backend associated with a team - c10::intrusive_ptr getBackendForTeam(const Team& team); + c10::intrusive_ptr getBackendForTeam( + const Team& team, + std::optional backend); // returns the device associated with the current process auto device() const { @@ -88,6 +100,21 @@ class Communicator { return rankToDiD(rank_); } + // returns world backend for communicator backend or default backend if not + // specified. + c10::intrusive_ptr getWorld( + std::optional backend = std::nullopt); + + // returns if a backend is available for creation + bool isBackendAvailable(CommunicatorBackend backend) const { + if (backend == CommunicatorBackend::ucc) { + return ucc_available_; + } else if (backend == CommunicatorBackend::nccl) { + return nccl_available_; + } + return false; + } + private: // returns the rank corresponding to a device index RankType dIdToRank(DeviceIdxType d_id) const { @@ -99,18 +126,22 @@ class Communicator { return static_cast(rank); } + CommunicatorBackend getBackend(std::optional backend) { + return backend.value_or(default_backend_); + } + bool is_available_; - CommunicatorBackend backend_type_; + CommunicatorBackend default_backend_; RankType rank_; int64_t size_; RankType local_rank_; int64_t local_size_; std::string master_addr_; int master_port_; + bool ucc_available_; + bool nccl_available_; // stores the world's store used for the backend init c10::intrusive_ptr store_; - // stores the world's backend - c10::intrusive_ptr world_; // cache for the created backends. The keys are strings generated from Teams std::unordered_map> backends_; }; diff --git a/csrc/multidevice/device_mesh.h b/csrc/multidevice/device_mesh.h index 3b258d90c00..3f495f1073b 100644 --- a/csrc/multidevice/device_mesh.h +++ b/csrc/multidevice/device_mesh.h @@ -40,15 +40,6 @@ class DeviceMesh final { return std::find(vector_.begin(), vector_.end(), device) != vector_.end(); } - // returns the relative index of a device in the mesh - // Throws if the device is not found - DeviceIdxType findIndex(const DeviceIdxType device) const { - auto it = std::find(vector_.begin(), vector_.end(), device); - NVF_ERROR( - it != vector_.end(), "device index ", device, " is not in the mesh"); - return std::distance(vector_.begin(), it); - } - private: void setDevices(std::vector devices) { vector_ = devices; diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index 081ff40ceba..9702d2dae51 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -10,28 +10,12 @@ #include #include #include +#include namespace nvfuser { namespace { -// Returns whether a TensorView has its first axis parallelized on Didx -// Checks that the other axis are not parallelized on Didx -bool isParallelD(TensorView* tv) { - std::vector is_parallel_d; - for (IterDomain* id : tv->getLeafDomain()) { - is_parallel_d.push_back(isParallelTypeDeviceDim(id->getParallelType())); - } - // Currently, only the most external dim is allowed to be parallelized - NVF_ERROR(tv->getMaybeRFactorDomain() == tv->getLeafDomain()); - for (auto i : c10::irange(1, is_parallel_d.size())) { - NVF_ERROR( - !is_parallel_d.at(i), - "only the outmost dimension can be device-parallelized"); - } - return is_parallel_d.empty() ? false : is_parallel_d.at(0); -} - inline bool isDeviceInvolved( DeviceIdxType my_device_index, DeviceIdxType root, @@ -72,8 +56,7 @@ CommParams createParamsForGatherScatter( } if (mesh.has(my_device_index)) { - auto sliced_buf = - buf.index({static_cast(mesh.findIndex(my_device_index)), "..."}); + auto sliced_buf = buf.index({0, "..."}); ((is_scatter) ? params.dst_bufs : params.src_bufs) = {sliced_buf}; } @@ -154,8 +137,7 @@ void lowerToAllgather( params.dst_bufs.push_back( output_tensor.index({static_cast(i), "..."})); } - params.src_bufs = { - input_tensor.index({mesh.findIndex(my_device_index), "..."})}; + params.src_bufs = {input_tensor.index({0, "..."})}; comms.push_back(std::make_shared(std::move(params))); } @@ -208,17 +190,17 @@ void lowerToBroadcastOrP2P( // Adds several Broadcast or Send/Recv communications to the vector 'comms' // For now, we assume that this function is called only if -// the input and output have the same parallelization (given by -// the argument "is_parallelized"). Later we could support more general cases. +// the input and output have the same sharding. Later we could support more +// general cases. void lowerToBroadcastOrP2P( DeviceIdxType my_device_index, const DeviceMesh& sender_mesh, const DeviceMesh& receiver_mesh, at::Tensor input_tensor, at::Tensor output_tensor, - bool is_parallelized, + bool is_sharded, std::vector>& comms) { - if (is_parallelized) { + if (is_sharded) { // if the inputs and ouputs are parallelized, // we create as many Broadcast as that will be handled in parallel for (auto i : c10::irange(sender_mesh.vector().size())) { @@ -229,8 +211,8 @@ void lowerToBroadcastOrP2P( my_device_index, sender_mesh.vector().at(i), DeviceMesh({receiver_mesh.vector().at(i)}), - input_tensor.index({static_cast(i), "..."}), - output_tensor.index({static_cast(i), "..."}), + input_tensor.index({0, "..."}), + output_tensor.index({0, "..."}), comms); } } else { @@ -274,21 +256,21 @@ std::vector> lowerCommunication( c->out()->as()->getStage()->descriptor()->mesh; // Stores whether the I/O has its first axis parallelized on Didx - bool is_input_parallel_d = - isParallelD(input_tv) && sender_mesh.vector().size() > 1; - bool is_output_parallel_d = - isParallelD(output_tv) && receiver_mesh.vector().size() > 1; + const bool is_input_sharded = + isSharded(input_tv) && sender_mesh.vector().size() > 1; + const bool is_output_sharded = + isSharded(output_tv) && receiver_mesh.vector().size() > 1; NVF_ERROR( - !is_input_parallel_d || + !is_input_sharded || sender_mesh.vector().size() == static_cast(input_tensor.size(0)), - "the size of the mesh", + "the size of the mesh ", sender_mesh.vector().size(), " doesn't match the size of the tensor ", input_tensor.size(0)); NVF_ERROR( - !is_output_parallel_d || + !is_output_sharded || receiver_mesh.vector().size() == static_cast(output_tensor.size(0)), "the size of the mesh", @@ -302,7 +284,7 @@ std::vector> lowerCommunication( return {}; } - if (!is_input_parallel_d && is_output_parallel_d) { + if (!is_input_sharded && is_output_sharded) { lowerToScatter( my_device_index, sender_mesh, @@ -310,7 +292,7 @@ std::vector> lowerCommunication( input_tensor, output_tensor, comms); - } else if (is_input_parallel_d && !is_output_parallel_d) { + } else if (is_input_sharded && !is_output_sharded) { if (receiver_mesh.vector() == sender_mesh.vector()) { lowerToAllgather( my_device_index, sender_mesh, input_tensor, output_tensor, comms); @@ -330,7 +312,7 @@ std::vector> lowerCommunication( receiver_mesh, input_tensor, output_tensor, - is_input_parallel_d, + is_input_sharded, comms); } return comms; diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp new file mode 100644 index 00000000000..1164f933b8c --- /dev/null +++ b/csrc/multidevice/utils.cpp @@ -0,0 +1,31 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include + +#include + +namespace nvfuser { + +bool isSharded(TensorView* tv) { + std::vector is_sharded; + for (IterDomain* id : TensorDomain::noReductions(tv->getLeafDomain())) { + is_sharded.push_back(id->isDeviceDim()); + } + // Currently, only the most external dim is allowed to be sharded + NVF_ERROR(tv->getMaybeRFactorDomain() == tv->getLeafDomain()); + for (auto i : c10::irange(1, is_sharded.size())) { + NVF_ERROR( + !is_sharded.at(i), + "only the outmost dimension can be device-parallelized"); + } + return is_sharded.empty() ? false : is_sharded.at(0); +} + +} // namespace nvfuser diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h new file mode 100644 index 00000000000..3664ec88b1d --- /dev/null +++ b/csrc/multidevice/utils.h @@ -0,0 +1,19 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser { + +// Returns whether a TensorView has its first non-reduction axis parallelized +// on Didx +// Checks that the other non-reduction axis are not parallelized on Didx +bool isSharded(TensorView*); + +} // namespace nvfuser diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 5ea0805d94a..438ad5dff0b 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -1294,7 +1294,8 @@ TensorView* reductionOp( for (unsigned int axis : uint_axes) { auto id = tv_root[axis]; is_trivial_reduction[axis] = id->isBroadcast() && - !id->hasExpandedExtent() && id->extent()->isOneInt(); + !id->hasExpandedExtent() && id->extent()->isConstInt() && + id->extent()->evaluate().as() == 1; if (!is_trivial_reduction[axis]) { reduction_axes.push_back((int)axis + offset); } else if (!keep_dim) { diff --git a/csrc/optimization/alias_analysis.cpp b/csrc/optimization/alias_analysis.cpp index 0cfbb7a37d6..8916a0e3aae 100644 --- a/csrc/optimization/alias_analysis.cpp +++ b/csrc/optimization/alias_analysis.cpp @@ -13,75 +13,14 @@ #include #include #include +#include #include +#include namespace nvfuser::optimization { namespace { -bool isContiguous(const std::vector>& contiguity) { - for (const auto& id_contiguity : contiguity) { - // We skip std::nullopt contiguity. It represents a broadcast or reduction - // dimension, which is of size 1 and always contiguous. - if (id_contiguity.has_value() && id_contiguity.value() == false) { - return false; - } - } - return true; -} - -// Returns whether the input TensorView is allocated contiguously. -bool isContiguous(const TensorView& tv) { - return isContiguous(tv.getContiguity()); -} - -// Whether a ViewOp transforms any expanded broadcast IterDomain in the input. -// This is a corner case in which we can't always turn `out` into an alias. -// -// For example, given -// -// t0 = makeContigConcreteTensor({4, 5}); -// t1 = broadcast(t0, {false, false, true}); -// t2 = expand(t1, {4, 5, 6}); -// -// `reshape(t2, {40, 3})` and `reshape(t2, {4, 30})` because both merge the -// expanded broadcast IterDomain (6) or a subspace of it with preceding -// IterDomains. However, the output of `reshape(t2, {20, 6})` can simply be an -// alias because the expanded broadcast IterDomain is forwarded not transformed. -// -// As a future improvement, when an expanded broadcast dimension is only split, -// the output of the reshape can be an alias. However, nvFuser currently decides -// to materialize the expansion, making the output not an alias (#1126). -// -// Obviously, this function assumes `in` and `out` are the input and output -// TensorView of the same ViewOp. -bool transformsExpandedBroadcastIterDomain(TensorView* in, TensorView* out) { - const std::vector& in_rfactor = in->getMaybeRFactorDomain(); - const std::vector& out_root = out->getRootDomain(); - const std::vector& out_rfactor = out->getMaybeRFactorDomain(); - - std::unordered_set expanded_broadcast_dims; - for (size_t i = 0, size = in_rfactor.size(); i < size; i++) { - IterDomain* id = in_rfactor[i]; - if (id->isBroadcast() && id->hasExpandedExtent()) { - expanded_broadcast_dims.insert(out_root[i]); - } - } - - const std::vector transforms = DependencyCheck::getAllExprsBetween( - {out_root.begin(), out_root.end()}, - {out_rfactor.begin(), out_rfactor.end()}); - - for (const auto* transform : transforms) { - for (Val* input : transform->inputs()) { - if (expanded_broadcast_dims.count(input)) { - return true; - } - } - } - return false; -} - // Finds aliases between `expr`'s inputs and outputs and stores the findings in // `analysis`. // @@ -98,49 +37,164 @@ class AliasFinder : public OptOutConstDispatch { void handle(const ViewOp* view) override; void handle(const LoadStoreOp* ldst) override; + void handle(const SliceOp* slice) override; private: AliasAnalysisResult& analysis_; }; +// Computes `Split`'s output contiguity. Returns the outer contiguity and then +// the inner contiguity. +std::pair, std::optional> splitContiguity( + const std::optional& contiguity) { + // Credits to @jacobhinkle: + // https://github.com/NVIDIA/Fuser/pull/1124#discussion_r1368682735 + if (!contiguity.has_value()) { + return {std::nullopt, std::nullopt}; + } + if (*contiguity) { + return {true, true}; + } else { + return {true, false}; + } +} + +// Computes `Merge`'s output contiguity. Returns a pair +// ``. `mergeable` indicates whether the two IterDomains +// can be merged without materialization. For example, there's no way to merge +// `outer=f,inner=t` while keeping the output as an alias, because a dimension +// can only have one stride. `contiguity` is the contiguity of the merged output +// IterDomain. +// +// Credits to @jacobhinkle: +// https://github.com/NVIDIA/Fuser/pull/1124#discussion_r1368682735 +std::pair> mergeContiguity( + const IterDomain* outer_id, + const std::optional& outer_contiguity, + const IterDomain* inner_id, + const std::optional& inner_contiguity) { + // Statuses `b` and `e` are represented in the IR with isBroadcast() and + // hasExpandedExtent(). Status `C` means stops propagating because we know we + // can't alias at that point. + // + // o\i | t f b e + // ----+----------- + // t | t f t C + // f | C C f C + // b | t f b e + // e | C C e e + if (!outer_contiguity.has_value() && !outer_id->hasExpandedExtent()) { + return {true, inner_contiguity}; + } + if (!inner_contiguity.has_value() && !inner_id->hasExpandedExtent()) { + return {true, outer_contiguity}; + } + + // o\i | t f b e + // ----+----------- + // t | t f C + // f | C C C + // b | + // e | C C e + if (outer_id->hasExpandedExtent() && inner_id->hasExpandedExtent()) { + return {true, std::nullopt}; + } + if (outer_id->hasExpandedExtent() || inner_id->hasExpandedExtent()) { + return {false, std::nullopt}; + } + + // o\i | t f b e + // ----+----------- + // t | t f + // f | C C + // b | + // e | + if (*outer_contiguity) { + return {true, inner_contiguity}; + } + return {false, std::nullopt}; +} + void AliasFinder::handle(const ViewOp* view) { TensorView* in = view->in(); TensorView* out = view->out(); + const std::vector& in_rfactor = in->getMaybeRFactorDomain(); + const std::vector& out_root = out->getRootDomain(); + const std::vector& out_rfactor = out->getMaybeRFactorDomain(); + Layout in_layout = analysis_.preferredLayout(in); - const std::vector& out_allocation = - out->getMaybeAllocationDomain(); - if (in_layout.allocation_domain == in->getMaybeRFactorDomain() && - isContiguous(in_layout.contiguity) && - out_allocation == out->getMaybeRFactorDomain() && isContiguous(*out) && - !transformsExpandedBroadcastIterDomain(in, out)) { - // This is a sufficient but not necessary condition for `out` to alias - // `in`. Both `in` and `out` are allocated contiguously per the - // rfactor domain. Also, the ViewOp can't transform any expanded broadcast - // IterDomain. - analysis_.add( - out, - in, - {out_allocation, - TensorDomain::getContiguityFilledWith(out_allocation, true)}); + if (!ir_utils::computePermutation(in_rfactor, in_layout.allocation_domain) + .has_value()) { + // Give up when `in`'s allocation domain is not an rfactor permutation. + return; } -} -void AliasFinder::handle(const LoadStoreOp* permute) { - TensorView* out = dynamic_cast(permute->out()); - if (!out->hasRFactor()) { - // Not a permute. It's actually an easier case to propagate aliases. I'm - // too lazy. - return; + std::unordered_map in_rfactor_to_out_root = + PairwiseRootDomainMap(in, out).mapBroadcast(true).mapProducerToConsumer(); + + // Collect the allocation order of `in`'s rfactor domain and thus `out`'s root + // domain. + LinkedHashMap> allocation_to_contiguity; + for (const auto i : c10::irange(in_layout.allocation_domain.size())) { + IterDomain* in_allocation_id = in_layout.allocation_domain[i]; + if (!in_rfactor_to_out_root.count(in_allocation_id)) { + // `in_allocation_id` is a reduction product. + continue; + } + IterDomain* out_root_id = in_rfactor_to_out_root.at(in_allocation_id); + allocation_to_contiguity.pushBack(out_root_id, in_layout.contiguity[i]); } - // Another lazy move: we could check compatibility and only give up when - // the allocation domain is incompatible with what we prefer for aliasing. - if (out->hasAllocation()) { - return; + // Replay `Expr`s from `out`'s root to `out`'s rfactor on `out`'s root. + // Stop when an `Expr` requires a data copy; otherwise generate the allocation + // order of `out`'s rfactor domain and the corresponding contiguity flags. + for (Expr* transform : DependencyCheck::getAllExprsBetween( + {out_root.begin(), out_root.end()}, + {out_rfactor.begin(), out_rfactor.end()})) { + if (Split* split = dynamic_cast(transform)) { + const auto [contiguity, split_i] = + allocation_to_contiguity.erase(split->in()); + auto [outer_contiguity, inner_contiguity] = splitContiguity(contiguity); + allocation_to_contiguity.insert( + split_i, split->outer(), outer_contiguity); + allocation_to_contiguity.insert( + split_i, split->inner(), inner_contiguity); + } else if (Merge* merge = dynamic_cast(transform)) { + const auto [outer_contiguity, inner_i] = + allocation_to_contiguity.erase(merge->outer()); + if (inner_i == allocation_to_contiguity.end() || + inner_i->first != merge->inner()) { + // Outer and inner are not adjacent in allocation order. + return; + } + const auto [inner_contiguity, merge_i] = + allocation_to_contiguity.erase(merge->inner()); + const auto [mergeable, contiguity] = mergeContiguity( + merge->outer(), outer_contiguity, merge->inner(), inner_contiguity); + if (!mergeable) { + return; + } + allocation_to_contiguity.insert(merge_i, merge->out(), contiguity); + } else { + NVF_ERROR( + false, "Expect Split or Merge, but found: ", transform->toString()); + } } - TensorView* in = permute->in()->as(); + Layout out_layout; + for (const auto& [allocation_id, contiguity] : allocation_to_contiguity) { + out_layout.allocation_domain.push_back(allocation_id); + out_layout.contiguity.push_back(contiguity); + } + analysis_.add(out, in, std::move(out_layout)); +} + +void AliasFinder::handle(const LoadStoreOp* permute) { + TensorView* in = dynamic_cast(permute->in()); + if (in == nullptr) { + return; + } // Look at the preferred layout not `in`'s current layout. Layout in_layout = analysis_.preferredLayout(in); if (!ir_utils::computePermutation( @@ -150,6 +204,7 @@ void AliasFinder::handle(const LoadStoreOp* permute) { return; } + TensorView* out = permute->out()->as(); // Compute `out`'s preferred allocation domain for aliasing. // // For example, @@ -167,20 +222,96 @@ void AliasFinder::handle(const LoadStoreOp* permute) { // 1. Construct the map from `in`'s rfactor to `out`'s root: // {i0->i3,i1->i4,i2->i5}. // 2. Apply the map to `in`'s allocation and get [i5,i3,i4]. - std::unordered_map in_rfactor_to_out_root; - for (auto i : c10::irange(out->getRootDomain().size())) { - in_rfactor_to_out_root[in->getMaybeRFactorDomain()[i]] = - out->getRootDomain()[i]; - } + std::unordered_map in_rfactor_to_out_root = + PairwiseRootDomainMap(in, out).mapBroadcast(true).mapProducerToConsumer(); Layout out_layout; for (const auto i : c10::irange(in_layout.allocation_domain.size())) { - IterDomain* allocation_id = in_layout.allocation_domain[i]; + IterDomain* in_allocation_id = in_layout.allocation_domain[i]; + if (!in_rfactor_to_out_root.count(in_allocation_id)) { + // `in_allocation_id` is a reduction product. + continue; + } out_layout.allocation_domain.push_back( - in_rfactor_to_out_root.at(allocation_id)); + in_rfactor_to_out_root.at(in_allocation_id)); out_layout.contiguity.push_back(in_layout.contiguity[i]); } - analysis_.add(out, in, out_layout); + analysis_.add(out, in, std::move(out_layout)); +} + +// For future improvement, a PadOp with negative padding amount can also be +// treated as a slice. +void AliasFinder::handle(const SliceOp* slice) { + TensorView* in = slice->in(); + TensorView* out = slice->out(); + + const std::vector& in_rfactor = in->getMaybeRFactorDomain(); + const std::vector& out_root = out->getRootDomain(); + const std::vector& out_rfactor = out->getMaybeRFactorDomain(); + + std::unordered_map in_rfactor_to_out_root = + PairwiseRootDomainMap(in, out).mapBroadcast(true).mapProducerToConsumer(); + + const auto out_rank = out_rfactor.size(); + std::unordered_map out_root_to_rfactor; + out_root_to_rfactor.reserve(out_rank); + for (auto i : c10::irange(out_rank)) { + out_root_to_rfactor[out_root[i]] = out_rfactor[i]; + } + + Layout in_layout = analysis_.preferredLayout(in); + if (!ir_utils::computePermutation(in_rfactor, in_layout.allocation_domain) + .has_value()) { + // Give up when `in`'s allocation domain is not an rfactor permutation. + return; + } + + // Inherit the allocation order from the input. However, refine the + // contiguity flags. + Layout out_layout; + out_layout.allocation_domain.reserve(out_rank); + for (IterDomain* in_allocation_id : in_layout.allocation_domain) { + if (!in_rfactor_to_out_root.count(in_allocation_id)) { + // `in_allocation_id` is a reduction product. + continue; + } + IterDomain* out_root_id = in_rfactor_to_out_root.at(in_allocation_id); + out_layout.allocation_domain.push_back(out_root_to_rfactor.at(out_root_id)); + } + + // Scan through the allocation domain in minor-to-major order. If an + // IterDomain is sliced, the next non-broadcast IterDomain has to be marked + // non-contiguous. For example, + // + // in = makeContigConcreteTensor({16, 128, 3072}); + // out = slice(in, {0, 0, 0}, {16, 128, 1024}); + // + // For `out` to alias `in`, its contiguity has to be updated to [t, f, t]. + out_layout.contiguity.resize(out_rank); + bool next_non_broadcast_is_non_contiguous = false; + for (auto i = static_cast(out_rank) - 1; i >= 0; i--) { + if (out_layout.allocation_domain[i]->isBroadcast()) { + out_layout.contiguity[i] = std::nullopt; + } else if (next_non_broadcast_is_non_contiguous) { + out_layout.contiguity[i] = false; + next_non_broadcast_is_non_contiguous = false; + } else { + out_layout.contiguity[i] = in_layout.contiguity[i]; + } + + // A broadcast dimension can be a slicing product as well. + std::vector dependencies = DependencyCheck::getAllExprsBetween( + {out_root.begin(), out_root.end()}, {out_layout.allocation_domain[i]}); + if (std::find_if( + dependencies.begin(), dependencies.end(), [](const Expr* expr) { + return expr->isA(); + }) != dependencies.end()) { + // out_layout.allocation_domain[i] is sliced. + next_non_broadcast_is_non_contiguous = true; + } + } + + analysis_.add(out, in, std::move(out_layout)); } } // namespace @@ -188,17 +319,17 @@ void AliasFinder::handle(const LoadStoreOp* permute) { void AliasAnalysisResult::add( const TensorView* alias, const TensorView* source, - const Layout& layout) { - std::pair& old_source = alias_to_source_[alias]; + Layout&& layout) { + auto [i, inserted] = alias_to_source_.emplace( + alias, std::make_pair(source, std::move(layout))); NVF_ERROR( - old_source.first == nullptr, + inserted, "The current implementation of alias analysis shouldn't find two sources for an alias. However, it's trying to make ", alias->toString(), " an alias of ", source->toString(), " while it's already an alias of ", - old_source.first->toString()); - old_source = {source, layout}; + i->second.first->toString()); } const Val* AliasAnalysisResult::findRoot(const Val* alias) const { diff --git a/csrc/optimization/alias_analysis.h b/csrc/optimization/alias_analysis.h index 6bc371fb143..cdafa743c6c 100644 --- a/csrc/optimization/alias_analysis.h +++ b/csrc/optimization/alias_analysis.h @@ -24,6 +24,10 @@ struct Layout { class AliasAnalysisResult { public: AliasAnalysisResult() = default; + AliasAnalysisResult(const AliasAnalysisResult&) = delete; + AliasAnalysisResult& operator=(const AliasAnalysisResult&) = delete; + AliasAnalysisResult(AliasAnalysisResult&&) = default; + AliasAnalysisResult& operator=(AliasAnalysisResult&&) = default; // Returns itself if `alias` doesn't alias anything. const Val* findRoot(const Val* alias) const; @@ -34,15 +38,7 @@ class AliasAnalysisResult { // Marks `source` as the immediate aliasing source of `alias` and sets the // preferred layout. - void add( - const TensorView* alias, - const TensorView* source, - const Layout& layout); - - AliasAnalysisResult(const AliasAnalysisResult&) = delete; - AliasAnalysisResult& operator=(const AliasAnalysisResult&) = delete; - AliasAnalysisResult(AliasAnalysisResult&&) = default; - AliasAnalysisResult& operator=(AliasAnalysisResult&&) = default; + void add(const TensorView* alias, const TensorView* source, Layout&& layout); private: // Maps aliases (e.g. the output of a View) to their direct sources (e.g. the @@ -55,9 +51,16 @@ class AliasAnalysisResult { }; // Finds aliases of the fusion inputs. The analysis should be conservative -- -// when the analysis says B is an alias of input A, +// when the analysis says B is an alias of input A and that B's layout +// (allocation domain and contiguity) is compatible with the preferred layout, // `ExpressionEvaluator::evaluate(B)` should produce an `at::Tensor` that's an // alias of the `at::Tensor` bound to A. +// +// Currently, for implementation convenience, AliasAnalysis ignores allocation +// domains of non-fusion-input TensorViews. It produces preferred layouts for +// these TensorViews and expects the user to resolve any incompatibility. +// MarkAliasPass, its only user at this moment, marks an output as an alias only +// when its allocation domain is empty. I'm happy to revisit this contract. AliasAnalysisResult findAliases(Fusion* fusion); } // namespace nvfuser::optimization diff --git a/csrc/optimization/mark_alias.cpp b/csrc/optimization/mark_alias.cpp index 5d20ca315f3..3fe5cbf3003 100644 --- a/csrc/optimization/mark_alias.cpp +++ b/csrc/optimization/mark_alias.cpp @@ -17,37 +17,49 @@ void MarkAliasPass::runPass(Fusion* fusion) { const AliasAnalysisResult alias_analysis = findAliases(fusion); for (TensorView* out : ir_utils::filterByType(fusion->outputs())) { - if (const Val* in = alias_analysis.findRoot(out); in->isFusionInput()) { - fusion->aliasOutputToInput( - out, - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(in), - AliasType::PointerCast); + // Lazy move: we could check compatibility and only give up when + // the allocation domain is incompatible with what we prefer for + // aliasing. + if (out->hasAllocation()) { if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { - debug() << "MarkAliasPass marked " << out->toString() - << " as an alias of " << in->toString() << std::endl; + debug() << "MarkAliasPass skipped " << out->toString() + << " because it already has an allocation domain:" << std::endl + << out->domain()->toString(1, /*leaf_only=*/false) << std::endl; } + continue; + } - // A scalar `out` triggers a corner case that crashes - // `validateDomainEquivalence`. - if (!out->isZeroDim()) { - const Layout out_layout = alias_analysis.preferredLayout(out); - if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { - debug() << "MarkAliasPass changed the layout of " << out->toString() - << std::endl; - debug() << " Old TensorDomain:" << std::endl; - debug() << out->domain()->toString(4, /*leaf_only=*/false) - << std::endl; - } - out->setAllocationDomain( - out_layout.allocation_domain, out_layout.contiguity); - if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { - debug() << " New TensorDomain:" << std::endl; - debug() << out->domain()->toString(4, /*leaf_only=*/false) - << std::endl; - } - } + const Val* in = alias_analysis.findRoot(out); + if (!in->isFusionInput()) { + continue; + } + + fusion->aliasOutputToInput( + out, + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(in), + AliasType::PointerArithmetic); + if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { + debug() << "MarkAliasPass marked " << out->toString() + << " as an alias of " << in->toString() << std::endl; + } + + // When `out` is a scalar, `out->setAllocationDomain` triggers a corner case + // that crashes `validateDomainEquivalence`. + if (out->isZeroDim()) { + continue; + } + + const Layout out_layout = alias_analysis.preferredLayout(out); + if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { + debug() << "MarkAliasPass changed the layout of " << out->toString() + << std::endl; + debug() << " Old TensorDomain:" << std::endl; + debug() << out->domain()->toString(4, /*leaf_only=*/false) << std::endl; + debug() << " New layout:" << out_layout.toString() << std::endl; } + out->setAllocationDomain( + out_layout.allocation_domain, out_layout.contiguity); } } diff --git a/csrc/polymorphic_value.h b/csrc/polymorphic_value.h index e3b0e232c97..30a27cbdba4 100644 --- a/csrc/polymorphic_value.h +++ b/csrc/polymorphic_value.h @@ -319,6 +319,9 @@ inline PolymorphicValue abs(const PolymorphicValue& a) { if (a.is>()) { return std::abs(a.as>()); } + if (a.is()) { + return a.as().abs(); + } NVF_ERROR( false, "PolymorphicValue abs not implemented for ", a.type().name()); } @@ -373,6 +376,15 @@ inline PolymorphicValue toTensor( false, "PolymorphicValue toTensor not implemented for ", x.type().name()); } +// Convert PolymorphicValue to c10::Scalar. +inline c10::Scalar toScalar(const PolymorphicValue& x) { + if (x.is>()) { + return (c10::complex)x.as>(); + } else { + return (c10::Scalar)x; + } +} + } // namespace PolymorphicValue_functions } // namespace nvfuser diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 7d90c07d279..12667497298 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -455,13 +455,7 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) { const auto gpu_lower = GpuLower::current(); - // FIXME: - // Needed to keep the predicate of cp.async initialization to get the - // inverted predicate, - // see [Predicate Inversion for CpAsync]. In a follow up both this part and - // the [Predicate Inversion for CpAsync] should be cleaned up together. - if (gpu_lower->predicateElimination().canOmitPredicate(tv_expr) && - !ir_utils::isCpAsyncInit(tv_expr)) { + if (gpu_lower->predicateElimination().canOmitPredicate(tv_expr)) { return; } diff --git a/csrc/python_frontend/fusion_cache.cpp b/csrc/python_frontend/fusion_cache.cpp index 759b53de32d..578ab96edbe 100644 --- a/csrc/python_frontend/fusion_cache.cpp +++ b/csrc/python_frontend/fusion_cache.cpp @@ -5,6 +5,9 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include +#include + #include #include #include @@ -15,8 +18,167 @@ #include namespace fs = std::filesystem; +#ifdef _WIN32 +#include +#else +#include +#include +#endif + namespace nvfuser::python_frontend { +namespace { +// Generate temporary file for this FusionCacheBuffer +std::string getSerdeTmpFile() { +#ifdef _WIN32 + const unsigned int pid = GetCurrentProcessId(); +#else + const unsigned int pid = getpid(); +#endif // _WIN32 + std::stringstream ss; + ss << "nvf_serde_tmp_" << pid; + return ss.str(); +} + +std::string getSerdeFile() { + auto device_prop = at::cuda::getCurrentDeviceProperties(); + int cuda_major = 0; + int cuda_minor = 0; + NVFUSER_NVRTC_SAFE_CALL(nvrtcVersion(&cuda_major, &cuda_minor)); + + std::stringstream ss; + ss << "nvf_serde"; + ss << "_device" << device_prop->major << "_" << device_prop->minor; + ss << "_cuda" << cuda_major << "_" << cuda_minor; + return ss.str(); +} + +// Get std::filesystem::path to specified file in nvfuser kernel database +// directory. +fs::path getSerdeFilePath(const std::string& file_name) { + fs::path kernel_db_path = fs::temp_directory_path() / "nvfuser_kernel_db"; + if (!fs::is_directory(kernel_db_path)) { + try { + fs::create_directory(kernel_db_path); + } catch (const std::exception& e) { + NVF_ERROR( + "Unable to create nvFuser Kernel DB directory! ", + kernel_db_path.string(), + e.what()); + } + } + return kernel_db_path / file_name; +} + +using BinaryBuffer = std::vector; +BinaryBuffer openFusionCache(std::string filename) { + FUSER_PERF_SCOPE("Flatbuffers::openFusionCache"); + auto file_handle = std::fopen(filename.c_str(), "rb"); + NVF_CHECK(file_handle != nullptr, "Failed to open FusionCache buffer."); + + auto file_path = fs::path(filename.c_str()); + auto file_size = fs::file_size(file_path); + NVF_CHECK(file_size > 0, "FusionCache buffer is empty."); + + BinaryBuffer buffer(file_size); + size_t read_status = + std::fread(buffer.data(), sizeof(uint8_t), file_size, file_handle); + NVF_CHECK( + read_status == file_size, "Failed to read entire FusionCache buffer.\n"); + return buffer; +} + +// This check function only throws errors if strict flag is enabled. +const serde::FusionCache* verifyFusionCache( + const BinaryBuffer& buffer, + bool strict) { + FUSER_PERF_SCOPE("Flatbuffers::verifyFusionCache"); + auto fusion_cache_buffer = serde::GetFusionCache(buffer.data()); + + // Check flatbuffer integrity + flatbuffers::Verifier v(buffer.data(), buffer.size()); + if (!fusion_cache_buffer->Verify(v)) { + NVF_CHECK(!strict, "Failed to verify the integrity of FusionCache buffer."); + return nullptr; + } + + // Check schema version + if (!serde::FusionCacheBufferHasIdentifier(buffer.data())) { + NVF_CHECK( + !strict, + "Failed to verify the schema version of the FusionCache buffer"); + return nullptr; + } + + // Check device major and minor versions + auto device_prop = at::cuda::getCurrentDeviceProperties(); + if (device_prop->major != fusion_cache_buffer->device_major() || + device_prop->minor != fusion_cache_buffer->device_minor()) { + NVF_CHECK( + !strict, + "Expected cuda version ", + device_prop->major, + ".", + device_prop->minor, + " but flatbuffer has cuda version ", + fusion_cache_buffer->device_major(), + ".", + fusion_cache_buffer->device_minor()); + return nullptr; + } + + // Check cuda installation + int cuda_major = 0; + int cuda_minor = 0; + NVFUSER_NVRTC_SAFE_CALL(nvrtcVersion(&cuda_major, &cuda_minor)); + if (cuda_major != fusion_cache_buffer->cuda_major() || + cuda_minor != fusion_cache_buffer->cuda_minor()) { + NVF_CHECK( + !strict, + "Expected cuda version ", + cuda_major, + ".", + cuda_minor, + " but flatbuffer has cuda version ", + fusion_cache_buffer->cuda_major(), + ".", + fusion_cache_buffer->cuda_minor()); + return nullptr; + } + + return fusion_cache_buffer; +} + +} // namespace + +void serialize() { + auto tmp_file_path = getSerdeFilePath(getSerdeTmpFile()); + FusionCache::get()->serialize(tmp_file_path); + + // Save to a per-process temporary file to avoid multi-process contention. + // Then, rename the temporary file to the actual file. If the actual file + // already exists, then the rename may fail or replace the actual file. + // Files replaced through this process should remain extant if they are being + // read because of UNIX filesystem properties, but this behavior is + // unverified. + auto file_path = getSerdeFilePath(getSerdeFile()); + std::error_code rename_ec; + fs::rename(tmp_file_path, file_path, rename_ec); + + // Failed to replace common workspace, so remove the temporary file. + if (rename_ec) { + try { + fs::remove(tmp_file_path); + std::cout + << "Removed temporary file because we could not replace common workspace. Exception:\t" + << rename_ec.message() << std::endl; + } catch (const std::exception& e) { + std::cout << "Failed to delete temporary file. Exception:\t" << e.what() + << std::endl; + } + } +} + // FusionCache static data member definitions for singleton usage std::mutex FusionCache::singleton_lock_; FusionCache* FusionCache::singleton_ = nullptr; @@ -52,7 +214,7 @@ TrieNode::TrieNode(RecordFunctor* rec, TrieNode* _parent, size_t _fusion_id) trie_node_lock() {} bool TrieNode::isTerminal() const { - return (record.get()->recordType() == serde::RecordType_End); + return (record.get()->recordType() == serde::RecordType::End); } flatbuffers::Offset TrieNode::serialize( @@ -76,11 +238,13 @@ flatbuffers::Offset TrieNode::serialize( isTerminal()); } -FusionCache* FusionCache::get(size_t max_fusions) { +FusionCache* FusionCache::get( + size_t max_fusions, + bool load_from_default_workspace) { FUSER_PERF_SCOPE("FusionCache::get"); std::lock_guard guard(singleton_lock_); if (singleton_ == nullptr) { - singleton_ = new FusionCache(max_fusions); + singleton_ = new FusionCache(max_fusions, load_from_default_workspace); } NVF_CHECK( max_fusions >= singleton_->fusions_.size(), @@ -106,7 +270,7 @@ void FusionCache::print(std::ostream& os) const { std::vector rev_fusion_records; TrieNode* end = node->parent; while (end) { - if (end->record->recordType() != serde::RecordType_Start) { + if (end->record->recordType() != serde::RecordType::Start) { rev_fusion_records.emplace_back(end); } end = end->parent; @@ -151,16 +315,16 @@ void FusionCache::stats(std::ostream& os) const { } } -void FusionCache::reset() { +void FusionCache::reset(bool load_from_default_workspace) { std::lock_guard guard(singleton_lock_); if (singleton_ != nullptr) { auto max_fusions = singleton_->max_fusions_; delete singleton_; - singleton_ = new FusionCache(max_fusions); + singleton_ = new FusionCache(max_fusions, load_from_default_workspace); } } -FusionCache::FusionCache(size_t max_fusions) +FusionCache::FusionCache(size_t max_fusions, bool load_from_default_workspace) : max_fusions_(max_fusions), root_(nullptr), fusions_(), @@ -168,6 +332,27 @@ FusionCache::FusionCache(size_t max_fusions) user_def_input_encodings_() { RecordFunctor* start = new StartRecord(); root_ = std::make_unique(start); + + // Deserialize cache hierarchy from common workspace automatically + auto file_path = getSerdeFilePath(getSerdeFile()).native(); + if (load_from_default_workspace && fs::exists(file_path)) { + const BinaryBuffer& buffer = openFusionCache(file_path); + const serde::FusionCache* fc = + verifyFusionCache(buffer, false /* strict */); + // The saved workspace can become out-of-date between nvfuser updates. + if (fc != nullptr) { + // Only deserialize if the current binary is valid. + deserialize(buffer, fc); + } else { + try { + fs::remove(file_path); + std::cout << "Delete incompatible workspace." << std::endl; + } catch (const std::exception& e) { + std::cout << "Failed to delete workspace. Exception:\t" << e.what() + << std::endl; + } + } + } } // In order to keep queries fast, this method does not lock. @@ -188,6 +373,7 @@ std::optional FusionCache::queryChildren( return std::optional(trie_node->second.get()); } } + FusionSchedules* FusionCache::queryFusionSchedules(size_t fusion_id) const { NVF_CHECK( fusion_id < fusions_.size(), @@ -242,7 +428,7 @@ TrieNode* FusionCache::createChild(TrieNode* node, RecordFunctor* rec) { child = child_node.value(); } else { size_t fusion_id = 0; - if (rec->recordType() == serde::RecordType_End) { + if (rec->recordType() == serde::RecordType::End) { NVF_CHECK( (fusions_.size() + 1) <= max_fusions_, "The number of fusions in nvfuser has exceeded ", @@ -263,7 +449,7 @@ TrieNode* FusionCache::createChild(TrieNode* node, RecordFunctor* rec) { child = node->children[new_rec].get(); NVF_CHECK(child, "Created child of TrieNode should not be null!"); ++(child->visits); - if (rec->recordType() == serde::RecordType_End) { + if (rec->recordType() == serde::RecordType::End) { terminal_nodes_.push_back(node->children[new_rec].get()); } if (isDebugDumpEnabled(DebugDumpOption::PythonFrontendDebug)) { @@ -366,6 +552,11 @@ void FusionCache::serialize(std::string filename) const { schedule->auto_gen_schedules->serialize(builder)); } + auto device_prop = at::cuda::getCurrentDeviceProperties(); + int cuda_major = 0; + int cuda_minor = 0; + NVFUSER_NVRTC_SAFE_CALL(nvrtcVersion(&cuda_major, &cuda_minor)); + // 6. Build FusionCache flatbuffer object // See table definition for FusionCache in serde/fusion_cache.fbs auto fusion_cache = serde::CreateFusionCacheDirect( @@ -374,7 +565,11 @@ void FusionCache::serialize(std::string filename) const { &fb_nodes, &terminal_node_idx, &fb_auto_gen_schedules, - FusionExecutor::getGlobalFusionCount()); + FusionExecutor::getGlobalFusionCount(), + device_prop->major, + device_prop->minor, + cuda_major, + cuda_minor); builder.Finish(fusion_cache, "NV00" /* file_identifier */); // 6. Write flatbuffer binary to file @@ -388,41 +583,6 @@ void FusionCache::serialize(std::string filename) const { std::fclose(file_handle); } -namespace { -typedef std::vector BinaryBuffer; - -BinaryBuffer openFusionCache(std::string filename) { - FUSER_PERF_SCOPE("Flatbuffers::openFusionCache"); - auto file_handle = std::fopen(filename.c_str(), "rb"); - NVF_CHECK(file_handle != nullptr, "Failed to open FusionCache buffer."); - - auto file_path = fs::path(filename.c_str()); - auto file_size = fs::file_size(file_path); - NVF_CHECK(file_size > 0, "FusionCache buffer is empty."); - - BinaryBuffer buffer(file_size); - size_t read_status = - std::fread(buffer.data(), sizeof(uint8_t), file_size, file_handle); - NVF_CHECK( - read_status == file_size, "Failed to read entire FusionCache buffer.\n"); - return buffer; -} - -const serde::FusionCache* verifyFusionCache(const BinaryBuffer& buffer) { - FUSER_PERF_SCOPE("Flatbuffers::verifyFusionCache"); - auto fusion_cache_buffer = serde::GetFusionCache(buffer.data()); - flatbuffers::Verifier v(buffer.data(), buffer.size()); - NVF_CHECK( - fusion_cache_buffer->Verify(v), - "Failed to verify the integrity of FusionCache buffer."); - NVF_CHECK( - serde::FusionCacheBufferHasIdentifier(buffer.data()), - "Failed to verify the schema version of the FusionCache buffer"); - return fusion_cache_buffer; -} - -} // namespace - void FusionCache::deserialize(std::string filename) { // See table definition for FusionCache in serde/fusion_cache.fbs // 0. Load flatbuffer binary from file @@ -430,8 +590,18 @@ void FusionCache::deserialize(std::string filename) { NVF_CHECK( fusions_.empty(), "Deserialization is prohibited if FusionCache is already populated."); - auto buffer = openFusionCache(filename); - auto fusion_cache_buffer = verifyFusionCache(buffer); + const BinaryBuffer& buffer = openFusionCache(filename); + const serde::FusionCache* fusion_cache_buffer = + verifyFusionCache(buffer, true /* strict */); + deserialize(buffer, fusion_cache_buffer); +} + +void FusionCache::deserialize( + const BinaryBuffer& buffer, + const serde::FusionCache* fusion_cache_buffer) { + // See table definition for FusionCache in serde/fusion_cache.fbs + FUSER_PERF_SCOPE("FusionCache::deserialize"); + NVF_CHECK(fusion_cache_buffer != nullptr, "Fusion Cache buffer is invalid."); // 0. Set static fusion count in Fusion Executor FusionExecutor::setGlobalFusionCount( @@ -488,7 +658,7 @@ void FusionCache::deserialize(std::string filename) { fb_trie_node->children()->size() == 0, "This terminal node should not have any children.") NVF_CHECK( - fb_trie_node->record()->type() == serde::RecordType_End, + fb_trie_node->record()->type() == serde::RecordType::End, "This terminal node should have an EndRecord RecordFunctor") NVF_CHECK( trie_ptr->fusion_id == fb_trie_node->fusion_id(), diff --git a/csrc/python_frontend/fusion_cache.h b/csrc/python_frontend/fusion_cache.h index 579d47d4f76..cf21274dad6 100644 --- a/csrc/python_frontend/fusion_cache.h +++ b/csrc/python_frontend/fusion_cache.h @@ -118,7 +118,7 @@ struct TrieNode { class FusionCache { //! The constructor is private given the FusionCache is only constructed //! as a singleton. - FusionCache(size_t max_fusions); + FusionCache(size_t max_fusions, bool load_from_default_workspace); public: //! Copy and Assignment of the FusionCache is not supported @@ -129,7 +129,9 @@ class FusionCache { //! The next 4 public methods are the python interface methods //! Gets a pointer to the singleton and creates a new one if necessary - static FusionCache* get(size_t max_fusions = 8192); + static FusionCache* get( + size_t max_fusions = 8192, + bool load_from_default_workspace = true); //! Number of fusions cached size_t numFusions() const; //! print cache contents @@ -137,7 +139,8 @@ class FusionCache { //! print cache stats void stats(std::ostream& os) const; //! Reset Cache to an empty state - static void reset(); + static void reset(bool load_from_default_workspace = false); + //! Serialize Fusion Cache using flatbuffers void serialize(std::string filename) const; //! Deserialize Fusion Cache using flatbuffers @@ -174,6 +177,12 @@ class FusionCache { TrieNode* rootTriePtr(); private: + using BinaryBuffer = std::vector; + //! Deserialize Fusion Cache + void deserialize( + const BinaryBuffer& buffer, + const serde::FusionCache* fusion_cache_buffer); + //! The static pointer to the FusionCache static FusionCache* singleton_; //! Lock for accessing the singleton by multiple threads @@ -199,4 +208,14 @@ class FusionCache { InputsIdLookup user_def_input_encodings_; }; +//! Serialize Fusion Cache to common workspace +//! /tmp/nvfuser_kernel_db/nvf_serde_[cuda_major]_[cuda_minor]_[nvrtc_major]_[nvrtc_minor] +//! +//! '''python +//! # Use atexit to automatically call serialize on program exit +//! import atexit +//! atexit.register(nvfuser.serialize) +//! ''' +void serialize(); + } // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 0b0c0a3befb..099186bb304 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -301,21 +301,21 @@ std::optional FusionDefinition::id() const { Scalar FusionDefinition::defineScalar() { FUSER_PERF_SCOPE("FusionDefinition::defineScalar"); Scalar out(recording_state_.size(), this); - recording_state_.emplace_back(out(), serde::StateType_Scalar); + recording_state_.emplace_back(out(), serde::StateType::Scalar); return out; } Tensor FusionDefinition::defineTensor(size_t dims) { FUSER_PERF_SCOPE("FusionDefinition::defineTensor"); Tensor out(recording_state_.size(), dims, this); - recording_state_.emplace_back(out(), serde::StateType_Tensor); + recording_state_.emplace_back(out(), serde::StateType::Tensor); return out; } Vector FusionDefinition::defineVector(size_t size) { FUSER_PERF_SCOPE("FusionDefinition::defineVector"); Vector out(recording_state_.size(), size, this); - recording_state_.emplace_back(out(), serde::StateType_Vector); + recording_state_.emplace_back(out(), serde::StateType::Vector); return out; } diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 21e1245ff9e..e9a8198b8e7 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -106,7 +106,7 @@ struct RecordFunctor { //! Abstraction for storing data specific to a record functor. virtual std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const { - return {serde::RecordData_NONE, flatbuffers::Offset()}; + return {serde::RecordData::NONE, flatbuffers::Offset()}; } //! The base serialize function that handles args, outputs, name and @@ -327,7 +327,7 @@ struct ReshapeOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.reshape", - serde::RecordType_ReshapeOp) { + serde::RecordType::ReshapeOp) { arg_names_[1] = "new_shape"; } ~ReshapeOpRecord() override = default; @@ -353,7 +353,7 @@ struct PadOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.pad", - serde::RecordType_PadOp), + serde::RecordType::PadOp), pad_widths_(std::move(pad_widths)) {} ~PadOpRecord() override = default; RecordFunctor* clone() final { @@ -407,7 +407,7 @@ struct PadOpRecord : RecordFunctor { } TensorView* output = nullptr; - if (args_.at(1).stype == serde::StateType_Scalar) { + if (args_.at(1).stype == serde::StateType::Scalar) { output = pad(arg, val_widths, fd.getFusionState(args_.at(1).index)); } else { // default: None output = pad(arg, val_widths); @@ -435,7 +435,7 @@ struct PadOpRecord : RecordFunctor { os << w; } os << "]"; - if (args_.at(1).stype == serde::StateType_Scalar) { + if (args_.at(1).stype == serde::StateType::Scalar) { // fill value was given os << ", " << args_.at(1); } @@ -445,7 +445,7 @@ struct PadOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Pad, + serde::RecordData::Pad, serde::CreatePadDirect(builder, &pad_widths_).Union()}; } @@ -520,12 +520,12 @@ struct DimsOpRecord : RecordFunctor { } void operator()(FusionState& fd) final { - if constexpr (op_type == serde::RecordType_PermuteOp) { + if constexpr (op_type == serde::RecordType::PermuteOp) { auto arg = fd.getFusionState(args_.at(0).index)->template as(); auto output = permute(arg, dims_); fd.setFusionState(outputs_.at(0).index, output); - } else if constexpr (op_type == serde::RecordType_StrideOrderOp) { + } else if constexpr (op_type == serde::RecordType::StrideOrderOp) { auto arg = fd.getFusionState(args_.at(0).index)->template as(); auto output = set(arg); @@ -544,9 +544,9 @@ struct DimsOpRecord : RecordFunctor { void print(std::ostream& os, bool close_function = true) const final { RecordFunctor::print(os, false); - if constexpr (op_type == serde::RecordType_PermuteOp) { + if constexpr (op_type == serde::RecordType::PermuteOp) { os << ", dims=["; - } else if constexpr (op_type == serde::RecordType_StrideOrderOp) { + } else if constexpr (op_type == serde::RecordType::StrideOrderOp) { os << ", stride_order=["; } else { NVF_ERROR(false, "op_type is not recognized by dims operator."); @@ -569,7 +569,7 @@ struct DimsOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Dims, + serde::RecordData::Dims, serde::CreateDimsDirect(builder, &dims_).Union()}; } @@ -588,7 +588,7 @@ struct SqueezeOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.squeeze", - serde::RecordType_SqueezeOp), + serde::RecordType::SqueezeOp), original_shape_(std::move(original_shape)), dims_(std::move(dims)) {} ~SqueezeOpRecord() override = default; @@ -677,7 +677,7 @@ struct SqueezeOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Squeeze, + serde::RecordData::Squeeze, serde::CreateSqueezeDirect(builder, &original_shape_, &dims_).Union()}; } @@ -703,7 +703,7 @@ struct BroadcastInDimOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.broadcast_in_dim", - serde::RecordType_BroadcastInDim), + serde::RecordType::BroadcastInDim), output_ndims_(output_ndims), broadcast_dims_(std::move(broadcast_dims)) { arg_names_[1] = "shape"; @@ -804,7 +804,7 @@ struct BroadcastInDimOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_BroadcastInDim, + serde::RecordData::BroadcastInDim, serde::CreateBroadcastInDimDirect( builder, output_ndims_, &broadcast_dims_) .Union()}; @@ -831,7 +831,7 @@ struct BroadcastOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), _name, - serde::RecordType_BroadcastOp), + serde::RecordType::BroadcastOp), is_broadcast_dim_(std::move(is_broadcast_dim)) {} ~BroadcastOpRecord() override = default; RecordFunctor* clone() final { @@ -891,7 +891,7 @@ struct BroadcastOpRecord : RecordFunctor { serde::BroadcastBuilder bcast_builder(builder); bcast_builder.add_broadcast_dims(fb_broadcast_dims); auto expr_data = bcast_builder.Finish(); - return {serde::RecordData_Broadcast, expr_data.Union()}; + return {serde::RecordData::Broadcast, expr_data.Union()}; } private: @@ -980,8 +980,8 @@ struct CastOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Dtype, - serde::CreateDtype(builder, toUnderlying(dtype_)).Union()}; + serde::RecordData::Dtype, + serde::CreateDtype(builder, nvfuser::toUnderlying(dtype_)).Union()}; } private: @@ -1000,7 +1000,7 @@ struct CatOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.cat", - serde::RecordType_CatOp), + serde::RecordType::CatOp), dim_(dim) {} ~CatOpRecord() override = default; RecordFunctor* clone() final { @@ -1070,7 +1070,7 @@ struct CatOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Dimension, + serde::RecordData::Dimension, serde::CreateDimension(builder, dim_).Union()}; } @@ -1083,7 +1083,7 @@ struct CatOpRecord : RecordFunctor { //! The accompanying Fusion Cache Entry holds a Fusion Object. struct EndRecord : RecordFunctor { - EndRecord() : RecordFunctor({}, {}, "end", serde::RecordType_End) {} + EndRecord() : RecordFunctor({}, {}, "end", serde::RecordType::End) {} ~EndRecord() override = default; RecordFunctor* clone() final { return new EndRecord(*this); @@ -1121,7 +1121,7 @@ struct TensorRecord : RecordFunctor { {}, std::move(_outputs), "define_tensor", - serde::RecordType_Tensor), + serde::RecordType::Tensor), shape_(std::move(_shape)), contiguity_(std::move(_contiguity)), stride_order_(std::move(_stride_order)), @@ -1324,16 +1324,16 @@ struct TensorRecord : RecordFunctor { flatbuffers::FlatBufferBuilder& builder) const final { auto fb_sizes = builder.CreateVector(shape_); - auto mapOptionalToEnum = [](std::optional v) -> int { + auto mapOptionalToEnum = [](std::optional v) -> serde::Contiguity { if (!v.has_value()) { - return serde::Contiguity_None; + return serde::Contiguity::None; } else if (v.value()) { - return serde::Contiguity_Contiguous; + return serde::Contiguity::Contiguous; } else { - return serde::Contiguity_Strided; + return serde::Contiguity::Strided; } }; - std::vector contiguity_enum; + std::vector contiguity_enum; std::transform( contiguity_.cbegin(), contiguity_.cend(), @@ -1349,7 +1349,7 @@ struct TensorRecord : RecordFunctor { tensor_builder.add_dtype(toUnderlying(dtype_)); tensor_builder.add_is_cpu(is_cpu_); auto expr_data = tensor_builder.Finish(); - return {serde::RecordData_Tensor, expr_data.Union()}; + return {serde::RecordData::Tensor, expr_data.Union()}; } private: @@ -1477,7 +1477,7 @@ struct OutputRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Output, + serde::RecordData::Output, serde::CreateOutputDirect(builder, &stride_order_).Union()}; } @@ -1623,7 +1623,7 @@ struct ReductionOpRecord : RecordFunctor { flatbuffers::FlatBufferBuilder& builder) const final { // TODO add dtype return { - serde::RecordData_Reduction, + serde::RecordData::Reduction, serde::CreateReductionDirect( builder, &axes_, keep_dim_, toUnderlying(dtype_)) .Union()}; @@ -1651,7 +1651,7 @@ struct IndexSelectOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.index_select", - serde::RecordType_IndexSelectOp), + serde::RecordType::IndexSelectOp), dim_(dim) {} ~IndexSelectOpRecord() override = default; RecordFunctor* clone() final { @@ -1685,7 +1685,7 @@ struct IndexSelectOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Dimension, + serde::RecordData::Dimension, serde::CreateDimension(builder, dim_).Union()}; } @@ -1703,7 +1703,7 @@ struct TorchGatherOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.gather", - serde::RecordType_TorchGatherOp), + serde::RecordType::TorchGatherOp), dim_(dim) {} ~TorchGatherOpRecord() override = default; RecordFunctor* clone() final { @@ -1737,7 +1737,7 @@ struct TorchGatherOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Dimension, + serde::RecordData::Dimension, serde::CreateDimension(builder, dim_).Union()}; } @@ -1757,7 +1757,7 @@ struct TakeAlongAxisOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.take_along_axis", - serde::RecordType_TakeAlongAxisOp), + serde::RecordType::TakeAlongAxisOp), dim_(dim) {} ~TakeAlongAxisOpRecord() override = default; RecordFunctor* clone() final { @@ -1791,7 +1791,7 @@ struct TakeAlongAxisOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Dimension, + serde::RecordData::Dimension, serde::CreateDimension(builder, dim_).Union()}; } @@ -1812,7 +1812,7 @@ struct ScalarRecord : RecordFunctor { {}, std::move(_outputs), "define_scalar", - serde::RecordType_Scalar), + serde::RecordType::Scalar), value_( dtype.has_value() ? castToDtype(std::move(value), dtype.value()) : std::move(value)), @@ -1903,7 +1903,7 @@ struct ScalarRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Scalar, + serde::RecordData::Scalar, serde::serializeScalar(builder, value_, dtype_).Union()}; } @@ -1929,7 +1929,7 @@ struct SliceOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.slice", - serde::RecordType_SliceOp), + serde::RecordType::SliceOp), start_indices_(std::move(start_indices)), end_indices_(std::move(end_indices)), strides_(std::move(strides)) {} @@ -2019,7 +2019,7 @@ struct SliceOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Slice, + serde::RecordData::Slice, serde::CreateSliceDirect( builder, &start_indices_, &end_indices_, &strides_) .Union()}; @@ -2042,7 +2042,7 @@ struct SliceOpRecord : RecordFunctor { //! Fusion Cache. struct StartRecord : RecordFunctor { - StartRecord() : RecordFunctor({}, {}, "start", serde::RecordType_Start) {} + StartRecord() : RecordFunctor({}, {}, "start", serde::RecordType::Start) {} ~StartRecord() override = default; RecordFunctor* clone() final { return new StartRecord(*this); @@ -2148,7 +2148,7 @@ struct NormOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Norm, + serde::RecordData::Norm, serde::CreateNormDirect(builder, &axes_, correction_, keep_dim_) .Union()}; } @@ -2173,7 +2173,7 @@ struct VarianceOpRecord : NormOpRecord { std::move(args), std::move(outputs), "ops.var", - serde::RecordType_VarianceOp, + serde::RecordType::VarianceOp, std::move(axes), correction, keep_dim) {} @@ -2202,7 +2202,7 @@ struct VarianceMeanOpRecord : NormOpRecord { std::move(args), std::move(outputs), "ops.var_mean", - serde::RecordType_VarianceMeanOp, + serde::RecordType::VarianceMeanOp, std::move(axes), correction, keep_dim) {} @@ -2229,7 +2229,7 @@ struct BatchNormOpRecord : RecordFunctor { std::move(args), std::move(outputs), "ops.batch_norm", - serde::RecordType_BatchNormOp), + serde::RecordType::BatchNormOp), training_(training), channels_last_(channels_last) {} ~BatchNormOpRecord() override = default; @@ -2255,16 +2255,16 @@ struct BatchNormOpRecord : RecordFunctor { void operator()(FusionState& fd) final { auto x = fd.getFusionState(args_.at(0).index)->as(); - auto weight = (args_.at(1).stype == serde::StateType_Tensor) + auto weight = (args_.at(1).stype == serde::StateType::Tensor) ? fd.getFusionState(args_.at(1).index)->as() : nullptr; - auto bias = (args_.at(2).stype == serde::StateType_Tensor) + auto bias = (args_.at(2).stype == serde::StateType::Tensor) ? fd.getFusionState(args_.at(2).index)->as() : nullptr; - auto running_mean = (args_.at(3).stype == serde::StateType_Tensor) + auto running_mean = (args_.at(3).stype == serde::StateType::Tensor) ? fd.getFusionState(args_.at(3).index)->as() : nullptr; - auto running_var = (args_.at(4).stype == serde::StateType_Tensor) + auto running_var = (args_.at(4).stype == serde::StateType::Tensor) ? fd.getFusionState(args_.at(4).index)->as() : nullptr; auto momentum = fd.getFusionState(args_.at(5).index)->as(); @@ -2296,7 +2296,7 @@ struct BatchNormOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_BatchNorm, + serde::RecordData::BatchNorm, serde::CreateBatchNorm(builder, training_, channels_last_).Union()}; } @@ -2314,7 +2314,7 @@ struct TensorSizesRecord : RecordFunctor { std::move(args), std::move(outputs), "ops.tensor_sizes", - serde::RecordType_TensorSizes) { + serde::RecordType::TensorSizes) { always_returns_tuple_ = true; } ~TensorSizesRecord() override = default; @@ -2348,7 +2348,7 @@ struct ShapeOpRecord : RecordFunctor { std::move(args), std::move(outputs), "ops.shape", - serde::RecordType_ShapeOp) {} + serde::RecordType::ShapeOp) {} ~ShapeOpRecord() override = default; RecordFunctor* clone() final { return new ShapeOpRecord(*this); @@ -2378,7 +2378,7 @@ struct SizeOpRecord : RecordFunctor { std::move(args), std::move(outputs), "ops.size", - serde::RecordType_SizeOp), + serde::RecordType::SizeOp), dim_(dim) {} ~SizeOpRecord() override = default; RecordFunctor* clone() final { @@ -2410,7 +2410,7 @@ struct SizeOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { - return {serde::RecordData_Size, serde::CreateSize(builder, dim_).Union()}; + return {serde::RecordData::Size, serde::CreateSize(builder, dim_).Union()}; } void print(std::ostream& os, bool close_function = true) const final { @@ -2434,7 +2434,7 @@ struct AtOpRecord : RecordFunctor { std::move(args), std::move(outputs), "ops.at", - serde::RecordType_AtOp), + serde::RecordType::AtOp), index_(index) {} ~AtOpRecord() override = default; RecordFunctor* clone() final { @@ -2460,7 +2460,8 @@ struct AtOpRecord : RecordFunctor { void operator()(FusionState& fd) final { NVF_CHECK( - args_.at(0).stype == serde::StateType_Vector, "Expected Vector State!"); + args_.at(0).stype == serde::StateType::Vector, + "Expected Vector State!"); const std::vector& arg = fd.getFusionStateVector(args_.at(0).index); auto result = at(arg, index_); fd.setFusionState(outputs_.at(0).index, result); @@ -2468,7 +2469,7 @@ struct AtOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { - return {serde::RecordData_At, serde::CreateAt(builder, index_).Union()}; + return {serde::RecordData::At, serde::CreateAt(builder, index_).Union()}; } void print(std::ostream& os, bool close_function = true) const final { @@ -2487,78 +2488,48 @@ struct FullOpRecord : RecordFunctor { FullOpRecord( std::vector _args, std::vector _outputs, - std::vector shape, PrimDataType dtype) : RecordFunctor( std::move(_args), std::move(_outputs), "ops.full", - serde::RecordType_FullOp), - shape_(std::move(shape)), - dtype_(dtype) {} + serde::RecordType::FullOp), + dtype_(dtype) { + setArgName(0, "shape"); + setArgName(1, "fill_value"); + } ~FullOpRecord() override = default; RecordFunctor* clone() final { return new FullOpRecord(*this); } //! Child specific hash function in lower 32 bits. - //! | 31 --- 24 | 23 -------------------------- 0 | - //! | Dtype | Shape hash code | + //! | 31 -------------------------------------- 0 | + //! | Dtype | size_t hash() const final { auto result = RecordFunctor::hash(); - size_t shape_hash = 0; - for (auto p : shape_) { - shape_hash ^= static_cast(p); - } - result |= ((static_cast(dtype_) & 0xff) << 24); - result |= (shape_hash & 0xffff); + result |= (static_cast(dtype_) & 0xffffffff); return result; } bool operator==(const RecordFunctor& other) const final { auto result = false; if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && - shape_ == child_ptr->shape_ && dtype_ == child_ptr->dtype_; + result = RecordFunctor::operator==(other) && dtype_ == child_ptr->dtype_; } return result; } void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index); + const std::vector& shape = fd.getFusionStateVector(args_.at(0).index); + auto fill_value = fd.getFusionState(args_.at(1).index); - std::vector nvf_shape(shape_.size(), nullptr); - for (const auto idx : c10::irange(shape_.size())) { - nvf_shape[idx] = IrBuilder::create(shape_.at(idx)); - } - auto output = full(nvf_shape, arg, dtype_); + auto output = full(shape, fill_value, dtype_); fd.setFusionState(outputs_.at(0).index, output); } void print(std::ostream& os, bool close_function = true) const override { - bool first_output = true; - for (auto& output : outputs_) { - if (first_output) { - first_output = false; - } else { - os << ", "; - } - os << output; - } - os << " = " - << "fd." << name_ << "("; - os << "fill_value=" << args_.at(0); - os << ", shape=["; - bool first_arg = true; - for (auto p : shape_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << p; - } - os << "]"; + RecordFunctor::print(os, false); os << ", dtype=" << dtypeToPyString(dtype_); if (close_function) { os << ")"; @@ -2568,15 +2539,12 @@ struct FullOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_TensorCreation, - serde::CreateTensorCreationDirect( - builder, &shape_, toUnderlying(dtype_)) + serde::RecordData::TensorCreationSymbolic, + serde::CreateTensorCreationSymbolic(builder, toUnderlying(dtype_)) .Union()}; } private: - //! Represents shape of new tensor - std::vector shape_; //! Type of output PrimDataType dtype_; }; @@ -2590,7 +2558,7 @@ struct IotaOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), "ops.iota", - serde::RecordType_IotaOp), + serde::RecordType::IotaOp), dtype_(dtype) {} ~IotaOpRecord() override = default; RecordFunctor* clone() final { @@ -2614,10 +2582,10 @@ struct IotaOpRecord : RecordFunctor { void operator()(FusionState& fd) final { auto length = fd.getFusionState(args_.at(0).index); - auto start = (args_.at(1).stype == serde::StateType_Scalar) + auto start = (args_.at(1).stype == serde::StateType::Scalar) ? fd.getFusionState(args_.at(1).index)->as() : nullptr; - auto step = (args_.at(2).stype == serde::StateType_Scalar) + auto step = (args_.at(2).stype == serde::StateType::Scalar) ? fd.getFusionState(args_.at(2).index)->as() : nullptr; auto output = iota(length, start, step, dtype_); @@ -2635,8 +2603,8 @@ struct IotaOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Dtype, - serde::CreateDtype(builder, toUnderlying(dtype_)).Union()}; + serde::RecordData::Dtype, + serde::CreateDtype(builder, nvfuser::toUnderlying(dtype_)).Union()}; } private: @@ -2645,55 +2613,47 @@ struct IotaOpRecord : RecordFunctor { }; //! Specialized Record Functors for random ops. -struct RandomOpRecord : RecordFunctor { - RandomOpRecord( +template +struct RandomDistOpRecord : RecordFunctor { + RandomDistOpRecord( std::vector _args, std::vector _outputs, - std::vector output_shape, - std::string _name, PrimDataType dtype) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - _name, - serde::RecordType_RandomOp), - output_shape_(std::move(output_shape)), + : RecordFunctor(std::move(_args), std::move(_outputs), "", RType), dtype_(dtype) { - if (args_.size() == 4) { - // seed and offset were provided in addition to the usual 2 arguments - setArgName(2, "rng_seed"); - setArgName(3, "rng_offset"); + if constexpr (RType == serde::RecordType::UniformDistOp) { + name_ = "ops.uniform"; + } else if constexpr (RType == serde::RecordType::NormalDistOp) { + name_ = "ops.normal"; + } else { + static_assert( + (RType == serde::RecordType::NormalDistOp) || + (RType == serde::RecordType::UniformDistOp)); + } + setArgName(2, "shape"); + if (args_.size() == 5) { + setArgName(3, "rng_seed"); + setArgName(4, "rng_offset"); } } - ~RandomOpRecord() override = default; + ~RandomDistOpRecord() override = default; RecordFunctor* clone() final { - return new RandomOpRecord(*this); + return new RandomDistOpRecord(*this); } //! Child specific hash function in lower 32 bits. - //! | 31 -------------- 16 | 15 -------------- 0 | - //! | distribution hash | output_shape hash | + //! | 31 --------------------------------------- 0 | + //! | Dtype | size_t hash() const final { auto result = RecordFunctor::hash(); - return result | (output_shape_.size() & 0xffff) | - (std::hash{}(name_.c_str()) & 0xffff << 16); + return result | (static_cast(dtype_) & 0xffffffff); } bool operator==(const RecordFunctor& other) const final { auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { + if (auto child_ptr = dynamic_cast(&other)) { result = RecordFunctor::operator==(other); - if (result) { - result = (output_shape_.size() == child_ptr->output_shape_.size()); - if (result) { - for (size_t i = 0; i < output_shape_.size(); ++i) { - if (output_shape_[i] != child_ptr->output_shape_[i]) { - result = false; - break; - } - } - } - } + result = result && (dtype_ == child_ptr->dtype_); } return result; } @@ -2701,51 +2661,37 @@ struct RandomOpRecord : RecordFunctor { void operator()(FusionState& fd) final { auto arg1 = fd.getFusionState(args_.at(0).index); auto arg2 = fd.getFusionState(args_.at(1).index); + const std::vector& output_shape = + fd.getFusionStateVector(args_.at(2).index); - std::vector output_shape(output_shape_.size(), nullptr); - std::transform( - output_shape_.begin(), - output_shape_.end(), - output_shape.begin(), - [&fd](const State& state) { - return fd.getFusionState(state.index)->template as(); - }); Val* output = nullptr; - if (name_.compare("ops.uniform") == 0) { - if (args_.size() == 2) { // stochastic uniform + if constexpr (RType == serde::RecordType::UniformDistOp) { + if (args_.size() == 3) { // stochastic uniform output = uniform(output_shape, arg1, arg2, dtype_); - } else if (args_.size() == 4) { // provided seed and offset - auto seed = fd.getFusionState(args_.at(2).index); - auto offset = fd.getFusionState(args_.at(3).index); + } else if (args_.size() == 5) { // provided seed and offset + auto seed = fd.getFusionState(args_.at(3).index); + auto offset = fd.getFusionState(args_.at(4).index); output = uniform(output_shape, arg1, arg2, dtype_, seed, offset); } - } else if (name_.compare("ops.normal") == 0) { - if (args_.size() == 2) { // stochastic normal + } else if constexpr (RType == serde::RecordType::NormalDistOp) { + if (args_.size() == 3) { // stochastic normal output = normal(output_shape, arg1, arg2, dtype_); - } else if (args_.size() == 4) { // provided seed and offset - auto seed = fd.getFusionState(args_.at(2).index); - auto offset = fd.getFusionState(args_.at(3).index); + } else if (args_.size() == 5) { // provided seed and offset + auto seed = fd.getFusionState(args_.at(3).index); + auto offset = fd.getFusionState(args_.at(4).index); output = normal(output_shape, arg1, arg2, dtype_, seed, offset); } } else { - NVF_ERROR(false, "random distribution not recognized:", name_); + static_assert( + (RType == serde::RecordType::NormalDistOp) || + (RType == serde::RecordType::UniformDistOp)); } + fd.setFusionState(outputs_.at(0).index, output); } void print(std::ostream& os, bool close_function = true) const final { RecordFunctor::print(os, false); - os << ", shape=["; - bool first_arg = true; - for (auto shape : output_shape_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << shape; - } - os << "]"; os << ", dtype=" << dtypeToPyString(dtype_); if (close_function) { os << ")"; @@ -2754,21 +2700,13 @@ struct RandomOpRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { - std::vector fb_shape; - fb_shape.reserve(output_shape_.size()); - for (auto& it : output_shape_) { - fb_shape.emplace_back(it.index, it.stype); - } return { - serde::RecordData_TensorCreationSymbolic, - serde::CreateTensorCreationSymbolicDirect( - builder, &fb_shape, toUnderlying(dtype_)) + serde::RecordData::TensorCreationSymbolic, + serde::CreateTensorCreationSymbolic(builder, toUnderlying(dtype_)) .Union()}; } private: - //! Represents the tensor dimensions of the output tensor. - std::vector output_shape_; //! DataType of output PrimDataType dtype_; }; @@ -2784,7 +2722,7 @@ struct VectorRecord : RecordFunctor { std::move(_args), std::move(_outputs), "define_vector", - serde::RecordType_Vector), + serde::RecordType::Vector), dtype_(dtype) {} ~VectorRecord() override = default; RecordFunctor* clone() final { @@ -2816,7 +2754,7 @@ struct VectorRecord : RecordFunctor { dtype_); for (size_t i = 0; i < args_.size(); ++i) { NVF_CHECK( - args_.at(i).stype == serde::StateType_Scalar, + args_.at(i).stype == serde::StateType::Scalar, "Unsupported State type!"); output.at(i) = fd.getFusionState(args_.at(i).index); } @@ -2852,8 +2790,8 @@ struct VectorRecord : RecordFunctor { std::pair> recordData( flatbuffers::FlatBufferBuilder& builder) const final { return { - serde::RecordData_Vector, - serde::CreateVector(builder, toUnderlying(dtype_)).Union()}; + serde::RecordData::Vector, + serde::CreateVector(builder, nvfuser::toUnderlying(dtype_)).Union()}; }; private: diff --git a/csrc/python_frontend/fusion_state.cpp b/csrc/python_frontend/fusion_state.cpp index 77ad8a095d8..4625d80c1c8 100644 --- a/csrc/python_frontend/fusion_state.cpp +++ b/csrc/python_frontend/fusion_state.cpp @@ -31,13 +31,13 @@ bool State::operator!=(const State& other) const { // Generalized printing of State std::ostream& operator<<(std::ostream& os, const State& state) { - if (state.stype == serde::StateType_Scalar) { + if (state.stype == serde::StateType::Scalar) { os << "S"; - } else if (state.stype == serde::StateType_Tensor) { + } else if (state.stype == serde::StateType::Tensor) { os << "T"; - } else if (state.stype == serde::StateType_Vector) { + } else if (state.stype == serde::StateType::Vector) { os << "V"; - } else if (state.stype == serde::StateType_None) { + } else if (state.stype == serde::StateType::None) { os << "None"; } else { NVF_ERROR(false, "Unsupported StateType"); diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 92627ac21bd..bd70fc9100b 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -115,8 +115,6 @@ Tensor broadcast_in_dim_fn( std::vector& broadcast_dims) { FUSER_PERF_SCOPE("Operators.broadcast_in_dim"); FusionDefinition* fd = op.fusion_definition; - NVF_CHECK(!fd->completed(), "Attempting to add to a completed definition!"); - NVF_CHECK(op.validUse(), "Attempting to add to a completed definition!"); Vector output_shape = ShapeAsVector(generic_output_shape, *fd); NVF_CHECK( @@ -132,6 +130,23 @@ Tensor broadcast_in_dim_fn( return output; } +template +Tensor full_op_fn( + FusionDefinition::Operators& self, + ShapeType generic_output_shape, + Scalar fill_value, + PrimDataType dtype) { + NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); + FusionDefinition* fd = self.fusion_definition; + Vector output_shape = ShapeAsVector(generic_output_shape, *fd); + Tensor output = fd->defineTensor(output_shape.size); + fd->defineRecord(new FullOpRecord( + {fd->recordingState(output_shape()), fd->recordingState(fill_value())}, + {fd->recordingState(output())}, + dtype)); + return output; +} + template Tensor reshape_fn( FusionDefinition::Operators& self, @@ -149,6 +164,47 @@ Tensor reshape_fn( return output; } +template +Tensor random_dist_op_fn( + FusionDefinition::Operators& self, + Scalar arg1, + Scalar arg2, + ShapeType generic_new_shape, + std::optional rng_seed, + std::optional rng_offset, + PrimDataType dtype) { + static_assert( + (RType == serde::RecordType::NormalDistOp) || + (RType == serde::RecordType::UniformDistOp)); + NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); + NVF_CHECK( + isFloatingPointType(dtype), + "Random distributions only create floating point types! ", + dtype); + FusionDefinition* fd = self.fusion_definition; + Vector new_shape = ShapeAsVector(generic_new_shape, *fd); + + Tensor output = fd->defineTensor(new_shape.size); + std::vector arg_states = { + fd->recordingState(arg1()), + fd->recordingState(arg2()), + fd->recordingState(new_shape()), + }; + if (rng_seed.has_value() && rng_offset.has_value()) { + arg_states.push_back(fd->recordingState(rng_seed.value()())); + arg_states.push_back(fd->recordingState(rng_offset.value()())); + } else { + NVF_CHECK( + !rng_seed.has_value() && !rng_offset.has_value(), + "rng_seed and rng_offset must be provided together!"); + } + + fd->defineRecord(new RandomDistOpRecord( + arg_states, {fd->recordingState(output())}, dtype)); + + return output; +} + struct DimInfo { int64_t index; int64_t size; @@ -215,29 +271,43 @@ std::vector> computeContiguity( // `stride order` vector corresponds to the order for each logical domain in // physical memory; For any 0 <= i < n , we know the dimension i has the // stride_order[i]-th smallest stride. +// An exception to this are implicit broadcast dimensions, i.e. dimensions +// with `stride == 0`, where we would maintain their semantical position // `contiguity` vector to whether or not indexing could be collaped // corresponding to each physical domain; // // e.g. Given size and stride as follow: -// sizes = [2, 1, 3, 1, 4, 3] -// strides = [12, 4, 4, 4, 1, 0] -// we would compute stride order as: [5, 4, 3, 2, 1, 0]. Since the original -// stride is in descending order. Note that there's more than one way to define -// a stride order when we have equal strides. In the context of index -// collapsing, how we resolve that shouldn't matter, hence we just go with -// preserving their original order. Similarly, we compute contiguity as: [True, -// None, True, None, True, None], Since the physical order is the same as the -// logical order, this one is trivial to compute. +// sizes = [2, 2, 2, 2] +// strides = [8, 4, 2, 1] +// Obviously the stride order as: [3, 2, 1, 0] for row-major order, i.e. stride +// in descending order and contiguity flag will be [True, True, True, True] +// +// e.g. Given size and stride as follow: +// sizes = [2, 1, 3, 1, 4] +// strides = [24, 4, 8, 4, 2] +// Note that there are a few explicit broadcast dimensions, dimensions with size +// == 1 and stride != 0. The stride for explicit broadcast dimensions +// participates in stride order computation. The reason is that, frameworks +// could assign meaningful stride to an explicit broadcast dimensions to hint +// memory format, which could be used to deduce the desired output memory +// format. We use stable sort to break tie when two dimension has equal stride, +// i.e. try to preserve their semantical order. Hence, we would compute stride +// order as: [4, 2, 3, 1, 0]. In the context of index, collapsing, how we +// resolve that shouldn't matter. With sorted sizes & strides: +// sorted_size = [2, 3, 1, 1, 4] +// sorted_strides = [24, 8, 4, 4, 2] +// Here, we compute contiguity as: [True, True, None, None, False] // // e.g. Given size and stride as follow: -// sizes = [2, 3, 1, 5, 4] -// strides = [28, 4, 14, 0, 1] -// stride_order would be: [4, 2, 3, 0, 1], marking the order of strides in the -// vector. Meanwhile, contiguity would be computed on the physical domain, i.e. -// on sorted sizes & strides. -// sorted_size = [2, 1, 3, 4, 5] -// sorted_strides = [28, 14, 4, 1, 0] -// contiguity would be: [False, None, True, True, None] +// sizes = [2, 2, 2, 2] +// strides = [8, 4, 0, 2] +// The stride of implicit broadcast dimensions, dimensions with stride == 0, +// does not participate in stride order computation and preserves their +// semantical position in stride order. The logic behind this is so that we +// would not unnecessarily introduce permutated alloc_domain for a naive +// unsqueeze/expanded operation when it doesn't improve indexing. For the given +// example, computed stride_order would be: [3, 2, 1, 0] and contiguity would +// be: [True, True, None, False] // // This function returns a pair of std::pair>, std::vector> @@ -246,19 +316,39 @@ computeTensorDescriptor( const std::vector& strides) { NVF_CHECK( sizes.size() == strides.size(), - "compute_tensor_descriptor: Sizes and strides must have the same number of dimensions"); - std::vector dim_info_vec; + "compute_tensor_descriptor: " + "Sizes and strides must have the same number of dimensions"); + std::vector non_broadcast_dim_info_vec; + std::vector stride_zero_dims; for (auto i : c10::irange(sizes.size())) { - // NOTE: not supporting negative stride yet. - NVF_CHECK(strides[i] >= 0, "negative stride on tensor is not supported"); - dim_info_vec.emplace_back(DimInfo{(int64_t)i, sizes[i], strides[i]}); + // NOTE: not supporting negative stride yet, but we can probably allow it on + // broadcast dims + NVF_CHECK( + strides[i] >= 0, + "negative stride on tensor is not supported: strides[", + i, + "]=", + strides[i]); + DimInfo dim_info{(int64_t)i, sizes[i], strides[i]}; + if (strides[i] != 0) { + non_broadcast_dim_info_vec.push_back(dim_info); + } else { + stride_zero_dims.push_back(dim_info); + } } - // sort by stride - std::sort( - dim_info_vec.begin(), - dim_info_vec.end(), + // sort non-broadcast dimensions by stride + std::stable_sort( + non_broadcast_dim_info_vec.begin(), + non_broadcast_dim_info_vec.end(), [](const auto& l, const auto& r) { return l.stride > r.stride; }); + // combine dimensions while preserving the semantical position of broadcast + // dimensions + for (const auto& dim_info : stride_zero_dims) { + non_broadcast_dim_info_vec.insert( + non_broadcast_dim_info_vec.begin() + dim_info.index, dim_info); + } + // Dimensions are marked contiguous by inspecting the current dimension and // one to the right towards the inner dimension while skipping over broadcast // dimensions. @@ -266,31 +356,34 @@ computeTensorDescriptor( // dimension to it's right and needs to have stride equal to 1 in order to be // marked contiguous. for (int64_t i = 0; i < (int64_t)sizes.size();) { - dim_info_vec[i].stride_order = (int64_t)sizes.size() - 1 - i; - if (!dim_info_vec[i].isBroadcast()) { + non_broadcast_dim_info_vec[i].stride_order = (int64_t)sizes.size() - 1 - i; + if (!non_broadcast_dim_info_vec[i].isBroadcast()) { auto l = i++; int64_t expected = 1; for (; i < (int64_t)sizes.size(); i++) { - dim_info_vec[i].stride_order = (int64_t)sizes.size() - 1 - i; - if (!dim_info_vec[i].isBroadcast()) { - expected = dim_info_vec[i].stride * dim_info_vec[i].size; + non_broadcast_dim_info_vec[i].stride_order = + (int64_t)sizes.size() - 1 - i; + if (!non_broadcast_dim_info_vec[i].isBroadcast()) { + expected = non_broadcast_dim_info_vec[i].stride * + non_broadcast_dim_info_vec[i].size; break; } } - dim_info_vec[l].contiguity = (dim_info_vec[l].stride == expected); + non_broadcast_dim_info_vec[l].contiguity = + (non_broadcast_dim_info_vec[l].stride == expected); } else { i++; } } std::vector stride_order_vec(sizes.size(), -1); - for (const auto& dim_info : dim_info_vec) { + for (const auto& dim_info : non_broadcast_dim_info_vec) { stride_order_vec[dim_info.index] = dim_info.stride_order; } std::vector> contiguity_vec; std::transform( - dim_info_vec.begin(), - dim_info_vec.end(), + non_broadcast_dim_info_vec.begin(), + non_broadcast_dim_info_vec.end(), std::back_inserter(contiguity_vec), [](const DimInfo& val) { return val.contiguity; }); @@ -315,6 +408,7 @@ void initNvFuserPythonBindings(PyObject* module) { nvfuser.def("compute_contiguity", computeContiguity); nvfuser.def("compute_tensor_descriptor", computeTensorDescriptor); + nvfuser.def("serialize", serialize); //! Binding the FusionCache that holds a cache of Fusions //! This is only bound to provide an interface to get the number of fusions @@ -325,10 +419,14 @@ void initNvFuserPythonBindings(PyObject* module) { "get", &FusionCache::get, py::arg("max_fusions") = int(8192), + py::arg("load_from_default_workspace") = true, py::return_value_policy::reference) .def("num_fusions", &FusionCache::numFusions) .def_static( - "reset", &FusionCache::reset, py::return_value_policy::reference) + "reset", + &FusionCache::reset, + py::arg("load_from_default_workspace") = false, + py::return_value_policy::reference) .def( "serialize", [](FusionCache& self, std::string filename) { @@ -339,7 +437,7 @@ void initNvFuserPythonBindings(PyObject* module) { .def( "deserialize", [](FusionCache& self, std::string filename) { - FUSER_PERF_SCOPE("FusionCache.serialize (string)"); + FUSER_PERF_SCOPE("FusionCache.deserialize (string)"); self.deserialize(filename); }, py::arg("filename")) @@ -553,7 +651,7 @@ void initNvFuserPythonBindings(PyObject* module) { !self.completed(), "Attempting to add to a completed definition!"); self.defineRecord(new OutputRecord( - {self.recordingState(output())}, serde::RecordType_OutputVal)); + {self.recordingState(output())}, serde::RecordType::OutputVal)); }, py::arg("output")) .def( @@ -569,10 +667,11 @@ void initNvFuserPythonBindings(PyObject* module) { self.defineRecord(new OutputRecord( {self.recordingState(output()), self.recordingState(alias_input.value()())}, - serde::RecordType_OutputTv)); + serde::RecordType::OutputTv)); } else { self.defineRecord(new OutputRecord( - {self.recordingState(output())}, serde::RecordType_OutputTv)); + {self.recordingState(output())}, + serde::RecordType::OutputTv)); } }, py::arg("output"), @@ -601,7 +700,7 @@ void initNvFuserPythonBindings(PyObject* module) { "duplicated elements in stride_order detected!"); self.defineRecord(new OutputRecord( {self.recordingState(output())}, - serde::RecordType_OutputTv, + serde::RecordType::OutputTv, stride_order)); }, py::arg("output"), @@ -836,7 +935,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(input())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Unary_TV, \ + serde::RecordType::Unary_TV, \ static_cast(op_name))); \ return output; \ }, \ @@ -853,7 +952,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(input())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Unary_VAL, \ + serde::RecordType::Unary_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -928,7 +1027,7 @@ void initNvFuserPythonBindings(PyObject* module) { "Operator stride_order expects `stride_order` argument to have the same length as input!"); FusionDefinition* fd = self.fusion_definition; Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new DimsOpRecord( + fd->defineRecord(new DimsOpRecord( {fd->recordingState(arg())}, {fd->recordingState(output())}, std::move(stride_order), @@ -963,7 +1062,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(rng_offset())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_TV_VAL_VAL, \ + serde::RecordType::Ternary_TV_VAL_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -991,7 +1090,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(input())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Unary_TV, \ + serde::RecordType::Unary_TV, \ static_cast(op_name))); \ return output; \ }, \ @@ -1008,7 +1107,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(input())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Unary_VAL, \ + serde::RecordType::Unary_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1032,7 +1131,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_TV, \ + serde::RecordType::Binary_TV, \ static_cast(op_name))); \ return output; \ }, \ @@ -1059,7 +1158,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_TV, \ + serde::RecordType::Binary_TV, \ static_cast(op_name))); \ return output; \ }, \ @@ -1078,7 +1177,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_TV_VAL, \ + serde::RecordType::Binary_TV_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1097,7 +1196,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_VAL_TV, \ + serde::RecordType::Binary_VAL_TV, \ static_cast(op_name))); \ return output; \ }, \ @@ -1116,7 +1215,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_VAL, \ + serde::RecordType::Binary_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1161,7 +1260,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_TV, \ + serde::RecordType::Binary_TV, \ static_cast(op_name))); \ return output; \ }, \ @@ -1176,7 +1275,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_TV_VAL, \ + serde::RecordType::Binary_TV_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1191,7 +1290,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_VAL_TV, \ + serde::RecordType::Binary_VAL_TV, \ static_cast(op_name))); \ return output; \ }, \ @@ -1206,7 +1305,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg1()), fd->recordingState(arg2())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Binary_VAL, \ + serde::RecordType::Binary_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1262,7 +1361,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_TV_TV_VAL, \ + serde::RecordType::Ternary_TV_TV_VAL, \ static_cast( \ op_name))); \ return output; \ @@ -1285,7 +1384,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_TV_VAL_VAL, \ + serde::RecordType::Ternary_TV_VAL_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1307,7 +1406,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_VAL_TV_VAL, \ + serde::RecordType::Ternary_VAL_TV_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1329,7 +1428,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_VAL, \ + serde::RecordType::Ternary_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1357,7 +1456,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_VAL, \ + serde::RecordType::Ternary_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1380,7 +1479,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_TV, \ + serde::RecordType::Ternary_TV, \ static_cast< \ TensorView* (*)(TensorView*, TensorView*, TensorView*)>( \ op_name))); \ @@ -1405,7 +1504,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_TV_TV_VAL, \ + serde::RecordType::Ternary_TV_TV_VAL, \ static_cast( \ op_name))); \ return output; \ @@ -1429,7 +1528,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_TV_VAL_TV, \ + serde::RecordType::Ternary_TV_VAL_TV, \ static_cast( \ op_name))); \ return output; \ @@ -1453,7 +1552,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_VAL_TV_TV, \ + serde::RecordType::Ternary_VAL_TV_TV, \ static_cast( \ op_name))); \ return output; \ @@ -1476,7 +1575,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_VAL_VAL_TV, \ + serde::RecordType::Ternary_VAL_VAL_TV, \ static_cast(op_name))); \ return output; \ }, \ @@ -1498,7 +1597,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_TV_VAL_VAL, \ + serde::RecordType::Ternary_TV_VAL_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1520,7 +1619,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_VAL_TV_VAL, \ + serde::RecordType::Ternary_VAL_TV_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1548,7 +1647,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_VAL, \ + serde::RecordType::Ternary_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1570,7 +1669,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg3())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_TV_VAL_VAL, \ + serde::RecordType::Ternary_TV_VAL_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1600,7 +1699,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg4())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_Alpha_VAL, \ + serde::RecordType::Ternary_Alpha_VAL, \ static_cast(op_name))); \ return output; \ }, \ @@ -1629,7 +1728,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg4())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_Alpha_TV, \ + serde::RecordType::Ternary_Alpha_TV, \ static_cast< \ TensorView* (*)(TensorView*, TensorView*, TensorView*, Val*)>( \ op_name))); \ @@ -1656,7 +1755,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg4())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_Alpha_TV_TV_VAL, \ + serde::RecordType::Ternary_Alpha_TV_TV_VAL, \ static_cast< \ TensorView* (*)(TensorView*, TensorView*, Val*, Val*)>( \ op_name))); \ @@ -1683,7 +1782,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg4())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_Alpha_TV_VAL_TV, \ + serde::RecordType::Ternary_Alpha_TV_VAL_TV, \ static_cast< \ TensorView* (*)(TensorView*, Val*, TensorView*, Val*)>( \ op_name))); \ @@ -1710,7 +1809,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg4())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_Alpha_VAL_TV_TV, \ + serde::RecordType::Ternary_Alpha_VAL_TV_TV, \ static_cast< \ TensorView* (*)(Val*, TensorView*, TensorView*, Val*)>( \ op_name))); \ @@ -1737,7 +1836,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg4())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_Alpha_VAL_VAL_TV, \ + serde::RecordType::Ternary_Alpha_VAL_VAL_TV, \ static_cast( \ op_name))); \ return output; \ @@ -1763,7 +1862,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg4())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_Alpha_TV_VAL_VAL, \ + serde::RecordType::Ternary_Alpha_TV_VAL_VAL, \ static_cast( \ op_name))); \ return output; \ @@ -1789,7 +1888,7 @@ void initNvFuserPythonBindings(PyObject* module) { fd->recordingState(arg4())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_Ternary_Alpha_VAL_TV_VAL, \ + serde::RecordType::Ternary_Alpha_VAL_TV_VAL, \ static_cast( \ op_name))); \ return output; \ @@ -1896,13 +1995,13 @@ void initNvFuserPythonBindings(PyObject* module) { py::return_value_policy::reference); NVFUSER_PYTHON_BINDING_REDUCTION_OP( - "max", max, serde::RecordType::RecordType_ReductionMax) + "max", max, serde::RecordType::ReductionMax) NVFUSER_PYTHON_BINDING_REDUCTION_OP( - "min", min, serde::RecordType::RecordType_ReductionMin) + "min", min, serde::RecordType::ReductionMin) NVFUSER_PYTHON_BINDING_REDUCTION_OP( - "prod", prod, serde::RecordType::RecordType_ReductionProd) + "prod", prod, serde::RecordType::ReductionProd) NVFUSER_PYTHON_BINDING_REDUCTION_OP( - "sum", sum, serde::RecordType::RecordType_ReductionSum) + "sum", sum, serde::RecordType::ReductionSum) #undef NVFUSER_PYTHON_BINDING_REDUCTION_OP #define NVFUSER_PYTHON_BINDING_CAST_OP(op_str, op_name) \ @@ -1920,7 +2019,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_CastTv, \ + serde::RecordType::CastTv, \ static_cast(op_name), \ dtype)); \ return output; \ @@ -1942,7 +2041,7 @@ void initNvFuserPythonBindings(PyObject* module) { {fd->recordingState(arg())}, \ {fd->recordingState(output())}, \ ("ops." op_str), \ - serde::RecordType_CastVal, \ + serde::RecordType::CastVal, \ static_cast(op_name), \ dtype)); \ return output; \ @@ -1954,6 +2053,48 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_CAST_OP("cast", castOp) #undef NVFUSER_PYTHON_BINDING_CAST_OP +#define NVFUSER_ALL_VECTOR_TYPES(fn, ...) \ + fn(Vector, __VA_ARGS__); \ + fn(py::list, __VA_ARGS__); \ + fn(py::tuple, __VA_ARGS__); + +#define NVFUSER_RANDOM_DIST_OP_HELPER( \ + vec_type, op_str, op_type, arg1_str, arg2_str) \ + nvf_ops.def( \ + op_str, \ + random_dist_op_fn, \ + py::arg(arg1_str), \ + py::arg(arg2_str), \ + py::arg("shape"), \ + py::kw_only(), \ + py::arg("rng_seed") = py::none(), \ + py::arg("rng_offset") = py::none(), \ + py::arg("dtype") = DataType::Float, \ + py::return_value_policy::reference); + +#define NVFUSER_PYTHON_BINDING_RANDOM_DIST_OP(...) \ + NVFUSER_ALL_VECTOR_TYPES(NVFUSER_RANDOM_DIST_OP_HELPER, __VA_ARGS__) + + NVFUSER_PYTHON_BINDING_RANDOM_DIST_OP( + "normal", serde::RecordType::NormalDistOp, "mean", "std") + NVFUSER_PYTHON_BINDING_RANDOM_DIST_OP( + "uniform", serde::RecordType::UniformDistOp, "minval", "maxval") +#undef NVFUSER_PYTHON_BINDING_RANDOM_DIST_OP +#undef NVFUSER_RANDOM_DIST_OP_HELPER + +#define NVFUSER_FULL_OP_HELPER(vec_type, ...) \ + nvf_ops.def( \ + "full", \ + full_op_fn, \ + py::arg("shape"), \ + py::arg("fill_value"), \ + py::arg("dtype"), \ + py::return_value_policy::reference); + + // NOTE: The second argument is a dummy to satisfy the macro + NVFUSER_ALL_VECTOR_TYPES(NVFUSER_FULL_OP_HELPER, false) +#undef NVFUSER_FULL_OP_HELPER + nvf_ops.def( "batch_norm", [](FusionDefinition::Operators& self, @@ -1975,16 +2116,15 @@ void initNvFuserPythonBindings(PyObject* module) { Tensor invstd = fd->defineTensor(1); auto weight_state = weight.has_value() ? fd->recordingState(weight.value()()) - : State(0, serde::StateType::StateType_None); - auto bias_state = bias.has_value() - ? fd->recordingState(bias.value()()) - : State(0, serde::StateType::StateType_None); + : State(0, serde::StateType::None); + auto bias_state = bias.has_value() ? fd->recordingState(bias.value()()) + : State(0, serde::StateType::None); auto running_mean_state = running_mean.has_value() ? fd->recordingState(running_mean.value()()) - : State(0, serde::StateType::StateType_None); + : State(0, serde::StateType::None); auto running_var_state = running_var.has_value() ? fd->recordingState(running_var.value()()) - : State(0, serde::StateType::StateType_None); + : State(0, serde::StateType::None); fd->defineRecord(new BatchNormOpRecord( {fd->recordingState(arg()), weight_state, @@ -2178,7 +2318,7 @@ void initNvFuserPythonBindings(PyObject* module) { Tensor output = fd->defineTensor(arg.dims); auto value_state = value.has_value() ? fd->recordingState(value.value()()) - : State(0, serde::StateType_None); + : State(0, serde::StateType::None); fd->defineRecord(new PadOpRecord( {fd->recordingState(arg()), value_state}, {fd->recordingState(output())}, @@ -2261,7 +2401,7 @@ void initNvFuserPythonBindings(PyObject* module) { FusionDefinition* fd = self.fusion_definition; Tensor output = fd->defineTensor(arg.dims); self.fusion_definition->defineRecord( - new DimsOpRecord( + new DimsOpRecord( {fd->recordingState(arg())}, {fd->recordingState(output())}, std::move(dims), @@ -2498,27 +2638,6 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("arg"), py::arg("new_shape"), py::return_value_policy::reference); - nvf_ops.def( - "full", - [](FusionDefinition::Operators& self, - std::vector& shape, - Scalar fill_value, - PrimDataType dtype) -> Tensor { - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(shape.size()); - fd->defineRecord(new FullOpRecord( - {fd->recordingState(fill_value())}, - {fd->recordingState(output())}, - std::move(shape), - dtype)); - return output; - }, - py::arg("shape"), - py::arg("fill_value"), - py::arg("dtype"), - py::return_value_policy::reference); nvf_ops.def( "iota", [](FusionDefinition::Operators& self, @@ -2532,9 +2651,9 @@ void initNvFuserPythonBindings(PyObject* module) { Tensor output = fd->defineTensor(1); auto start_state = start.has_value() ? fd->recordingState(start.value()()) - : State(0, serde::StateType_None); + : State(0, serde::StateType::None); auto step_state = step.has_value() ? fd->recordingState(step.value()()) - : State(0, serde::StateType_None); + : State(0, serde::StateType::None); fd->defineRecord(new IotaOpRecord( {fd->recordingState(length()), start_state, step_state}, {fd->recordingState(output())}, @@ -2599,102 +2718,6 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("correction") = 1, py::arg("keepdim") = false, py::return_value_policy::reference); - nvf_ops.def( - "uniform", - [](FusionDefinition::Operators& self, - Scalar minval, - Scalar maxval, - std::vector& shape, - PrimDataType dtype, - std::optional rng_seed, - std::optional rng_offset) -> Tensor { - FUSER_PERF_SCOPE("Operators.uniform"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(shape.size()); - std::vector output_shape_states( - shape.size(), State(0, serde::StateType_Scalar)); - std::transform( - shape.begin(), - shape.end(), - output_shape_states.begin(), - [&fd](const Scalar& s) { return fd->recordingState(s()); }); - std::vector arg_states = { - fd->recordingState(minval()), - fd->recordingState(maxval()), - }; - if (rng_seed.has_value()) { - NVF_CHECK( - rng_offset.has_value(), - "When providing rng_seed, rng_offset must also be provided"); - arg_states.push_back(fd->recordingState(rng_seed.value()())); - arg_states.push_back(fd->recordingState(rng_offset.value()())); - } - fd->defineRecord(new RandomOpRecord( - arg_states, - {fd->recordingState(output())}, - output_shape_states, - "ops.uniform", - dtype)); - return output; - }, - py::arg("minval"), - py::arg("maxval"), - py::arg("shape"), - py::arg("dtype") = DataType::Float, - py::kw_only(), - py::arg("rng_seed") = py::none(), - py::arg("rng_offset") = py::none(), - py::return_value_policy::reference); - nvf_ops.def( - "normal", - [](FusionDefinition::Operators& self, - Scalar mean, - Scalar std, - std::vector& shape, - PrimDataType dtype, - std::optional rng_seed, - std::optional rng_offset) -> Tensor { - FUSER_PERF_SCOPE("Operators.normal"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(shape.size()); - std::vector output_shape_states( - shape.size(), State(0, serde::StateType_Scalar)); - std::transform( - shape.begin(), - shape.end(), - output_shape_states.begin(), - [&fd](const Scalar& s) { return fd->recordingState(s()); }); - std::vector arg_states = { - fd->recordingState(mean()), - fd->recordingState(std()), - }; - if (rng_seed.has_value()) { - NVF_CHECK( - rng_offset.has_value(), - "When providing rng_seed, rng_offset must also be provided"); - arg_states.push_back(fd->recordingState(rng_seed.value()())); - arg_states.push_back(fd->recordingState(rng_offset.value()())); - } - fd->defineRecord(new RandomOpRecord( - arg_states, - {fd->recordingState(output())}, - output_shape_states, - "ops.normal", - dtype)); - return output; - }, - py::arg("mean"), - py::arg("std"), - py::arg("shape"), - py::arg("dtype") = DataType::Float, - py::kw_only(), - py::arg("rng_seed") = py::none(), - py::arg("rng_offset") = py::none(), - py::return_value_policy::reference); //! The ScedOperators class is a nested class of FusionDefinition to allow the //! user to query the class for the list of schedule operators. //! diff --git a/csrc/python_frontend/test/test_nvfuser_fusion_cache.cpp b/csrc/python_frontend/test/test_nvfuser_fusion_cache.cpp index f8c7f17629c..99568b10878 100644 --- a/csrc/python_frontend/test/test_nvfuser_fusion_cache.cpp +++ b/csrc/python_frontend/test/test_nvfuser_fusion_cache.cpp @@ -57,7 +57,7 @@ TEST_F(NVFuserTest, PyFusionCache_CUDA) { // record to an empty cache. { std::unique_ptr test_record(new TensorRecord( - {State(0, serde::StateType_Tensor)}, {3}, {true}, DataType::Float)); + {State(0, serde::StateType::Tensor)}, {3}, {true}, DataType::Float)); TrieNode* root = fc->rootTriePtr(); TrieNode* node = nullptr; @@ -113,9 +113,9 @@ TEST_F(NVFuserTest, PyFusionCache_CUDA) { // record to a cache with 1 fusion. { std::unique_ptr cached_record(new TensorRecord( - {State(0, serde::StateType_Tensor)}, {3}, {true}, DataType::Float)); + {State(0, serde::StateType::Tensor)}, {3}, {true}, DataType::Float)); std::unique_ptr new_record(new ScalarRecord( - {State(1, serde::StateType_Scalar)}, + {State(1, serde::StateType::Scalar)}, std::monostate{}, DataType::Float)); TrieNode* root = fc->rootTriePtr(); @@ -158,9 +158,9 @@ TEST_F(NVFuserTest, PyFusionCache_CUDA) { // This tends to flush out pointer problems in the cache. { std::unique_ptr test_record(new TensorRecord( - {State(0, serde::StateType_Tensor)}, {3}, {true}, DataType::Float)); + {State(0, serde::StateType::Tensor)}, {3}, {true}, DataType::Float)); std::unique_ptr dummy_record(new TensorRecord( - {State(0, serde::StateType_Tensor)}, {3}, {true}, DataType::Float)); + {State(0, serde::StateType::Tensor)}, {3}, {true}, DataType::Float)); TrieNode* root = fc->rootTriePtr(); TrieNode* node = nullptr; diff --git a/csrc/python_frontend/test/test_nvfuser_fusion_definition.cpp b/csrc/python_frontend/test/test_nvfuser_fusion_definition.cpp index 7a80e1d9c2b..8f30cb28aac 100644 --- a/csrc/python_frontend/test/test_nvfuser_fusion_definition.cpp +++ b/csrc/python_frontend/test/test_nvfuser_fusion_definition.cpp @@ -68,7 +68,7 @@ TEST_F(NVFuserTest, FusionDefinition_CUDA) { {fd.recordingState(t0()), fd.recordingState(s1())}, {fd.recordingState(t2())}, "ops.add", - serde::RecordType_Binary_TV_VAL, + serde::RecordType::Binary_TV_VAL, static_cast(add))); SUCCEED(); } catch (const std::exception& e) { @@ -77,7 +77,7 @@ TEST_F(NVFuserTest, FusionDefinition_CUDA) { try { fd.defineRecord(new OutputRecord( - {fd.recordingState(t2())}, serde::RecordType_OutputTv)); + {fd.recordingState(t2())}, serde::RecordType::OutputTv)); SUCCEED(); } catch (const std::exception& e) { FAIL() << "Unexpected assert during Output Record creation! " << e.what(); @@ -85,7 +85,7 @@ TEST_F(NVFuserTest, FusionDefinition_CUDA) { try { fd.defineRecord(new OutputRecord( - {fd.recordingState(s1())}, serde::RecordType_OutputVal)); + {fd.recordingState(s1())}, serde::RecordType::OutputVal)); FAIL() << "Expected an assert for too many records!"; } catch (...) { SUCCEED(); @@ -148,7 +148,7 @@ TEST_F(NVFuserTest, FusionDefinition_CUDA) { {fd.recordingState(t0()), fd.recordingState(s1())}, {fd.recordingState(t2())}, "ops.add", - serde::RecordType_Binary_TV_VAL, + serde::RecordType::Binary_TV_VAL, static_cast(add))); SUCCEED(); } catch (const std::exception& e) { @@ -157,7 +157,7 @@ TEST_F(NVFuserTest, FusionDefinition_CUDA) { try { fd.defineRecord(new OutputRecord( - {fd.recordingState(t2())}, serde::RecordType_OutputTv)); + {fd.recordingState(t2())}, serde::RecordType::OutputTv)); SUCCEED(); } catch (const std::exception& e) { FAIL() << "Unexpected assert during Output Record creation! " << e.what(); diff --git a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp index e0eabf5122d..70cf4350cf6 100644 --- a/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp +++ b/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp @@ -25,29 +25,29 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { // OpRecord Equality Check { - auto t0 = State(0, serde::StateType_Tensor); - auto s1 = State(1, serde::StateType_Scalar); - auto out = State(2, serde::StateType_Tensor); + auto t0 = State(0, serde::StateType::Tensor); + auto s1 = State(1, serde::StateType::Scalar); + auto out = State(2, serde::StateType::Tensor); std::unique_ptr test_record1( new OpRecord( {t0, s1}, {out}, "ops.mul", - serde::RecordType_Binary_TV_VAL, + serde::RecordType::Binary_TV_VAL, static_cast(mul))); std::unique_ptr test_record2( new OpRecord( {t0, s1}, {out}, "ops.mul", - serde::RecordType_Binary_TV_VAL, + serde::RecordType::Binary_TV_VAL, static_cast(mul))); std::unique_ptr test_record3( new OpRecord( {t0, s1}, {out}, "ops.mul", - serde::RecordType_Binary_TV_VAL, + serde::RecordType::Binary_TV_VAL, static_cast(mul))); EXPECT_TRUE(*test_record1 == *test_record2); @@ -57,14 +57,14 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { // CastOpRecord Equality Check { - auto t0 = State(0, serde::StateType_Tensor); - auto out = State(1, serde::StateType_Tensor); + auto t0 = State(0, serde::StateType::Tensor); + auto out = State(1, serde::StateType::Tensor); std::unique_ptr test_record1( new CastOpRecord( {t0}, {out}, "ops.cast", - serde::RecordType_CastTv, + serde::RecordType::CastTv, static_cast(castOp), DataType::Half)); std::unique_ptr test_record2( @@ -72,7 +72,7 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {t0}, {out}, "ops.cast", - serde::RecordType_CastTv, + serde::RecordType::CastTv, static_cast(castOp), DataType::Half)); std::unique_ptr test_record3( @@ -80,7 +80,7 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {t0}, {out}, "ops.cast", - serde::RecordType_CastTv, + serde::RecordType::CastTv, static_cast(castOp), DataType::Half)); @@ -91,13 +91,13 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { // ReductionOpRecord Equality Check { - auto t0 = State(0, serde::StateType_Tensor); - auto out = State(1, serde::StateType_Tensor); + auto t0 = State(0, serde::StateType::Tensor); + auto out = State(1, serde::StateType::Tensor); std::unique_ptr test_record1(new ReductionOpRecord( {t0}, {out}, "ops.sum", - serde::RecordType_ReductionSum, + serde::RecordType::ReductionSum, static_cast&, bool, @@ -109,7 +109,7 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {t0}, {out}, "ops.sum", - serde::RecordType_ReductionSum, + serde::RecordType::ReductionSum, static_cast&, bool, @@ -121,7 +121,7 @@ TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) { {t0}, {out}, "ops.sum", - serde::RecordType_ReductionSum, + serde::RecordType::ReductionSum, static_cast&, bool, diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index eae4480d6c7..18d1269f7e0 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -170,7 +170,9 @@ std::unordered_map PairwiseRootDomainMap::map( // Condition 3: when the producer ID is a removed broadcast domain, there is // no mapping for it. if (!squeeze_flags.empty() && squeeze_flags.at(itp)) { - NVF_ERROR(producer_id->isBroadcast()); + // Dynamic IterDomains can be squeezed, in which case they must concretize + // to broadcasts + NVF_ERROR(producer_id->isBroadcast() || producer_id->isSymbolic()); itp++; continue; } diff --git a/csrc/scheduler/cache_policy_refiner.cpp b/csrc/scheduler/cache_policy_refiner.cpp index 3d4e5a82b5a..b7f87de9417 100644 --- a/csrc/scheduler/cache_policy_refiner.cpp +++ b/csrc/scheduler/cache_policy_refiner.cpp @@ -18,6 +18,11 @@ namespace nvfuser { namespace { +template +void vlog(const Args&... args) { + scheduler_debug_utils::log("[cache_policy_refiner] ", args...); +} + // Returns whether a pointwise expression `expr` expands its input operand // `in_tv`. bool pointwiseExpands(const Expr* expr, const TensorView* in_tv) { @@ -115,11 +120,11 @@ const Expr* findExpand(const LoadStoreOp* ldst) { // Returns true if the cache policy is changed. bool refineCachePolicy(LoadStoreOp* ldst) { - scheduler_debug_utils::log("Processing ", ldst->toString()); + vlog("Processing ", ldst->toString()); const Expr* expand = findExpand(ldst); if (expand == nullptr) { - scheduler_debug_utils::log( + vlog( "Skipped ", ldst->toString(), " because we cannot find the using expand."); @@ -127,7 +132,7 @@ bool refineCachePolicy(LoadStoreOp* ldst) { } auto target_cache_op = CacheOp::AllLevels; - scheduler_debug_utils::log( + vlog( "Changed the cache op of ", ldst->toString(), " from ", diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index e6dc2801cd7..2edad60902d 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -154,397 +154,386 @@ void swizzleSharedMemory( const int64_t tile_size_y = shared_mem_tv->axis(-1 - skip)->extent()->evaluate().as(); - if (isTuring(params.mma_macro) || isAmpere(params.mma_macro)) { - // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. - // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit - // (i.e. float) - const int64_t data_type_size = - (int64_t)dataTypeSize(*shared_mem_tv->getDataType()); - NVF_ERROR(data_type_size == 2 || data_type_size == 4); - - // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. - // For epilogue, threads in a warp is organized as 8 rows x 4 columns. - // Each thread vectorized write 2 items, so 8 items per row. - //--0--1--2--3 - //--4--5--6--7 - //--8--9--10-11 - //--12-13-14-15 - //--16-17-18-19 - //--20-21-22-23 - //--24-25-26-27 - //--28-29-30-31 - constexpr int64_t n_rows = 8; - constexpr int64_t n_cols = 8; - - // Column size of the tile needs to be multiples of 8 for ldmatrix to work. - NVF_ERROR( - tile_size_x >= n_rows && tile_size_x % n_rows == 0 && - tile_size_y >= n_cols && tile_size_y % n_cols == 0, - "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", - tile_size_x, - "x", - tile_size_y); - - /* Note [How to remove bank conflict for ldmatrix?] - * - * **This note is interleaved with code, I suggest reading this note like - * reading a jupyter notebook** - * - * Our task is to make sure different rows does not fall into the same - * bank of shared memory. - * - * Introduction to bank conflict can be found at page 54-72 of: - * https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf - * - * When we talk about bank conflict removal, we are talking about the - * following task: - * "there are 32 banks, and each bank contains one 4-byte word, we want to - * make sure different lanes in a warp does not access different word - * addresses in the same bank" - * For example, if thread 0 is accessing word address 1, and thread 1 is - * accessing word address 33, then these two threads will have a bank - * conflict because they are accessing different word addresses in the same - * bank. However, if thread 0 is accessing byte address 4 and thread 1 is - * accessing byte address 6 then there will be no bank conflict because 4 - * and 6 both belong to word 1. - */ - - constexpr int64_t smem_bytes_per_word = 4; - constexpr int64_t smem_banks = 32; - - /* but here, for our convenience, because ldmatrix always use vectorized - * access of 8 items = 16 bytes = 4 words, we further group words into - * units: we consider each 4 words as a "unit", and each 4 banks as a - * "megabank". So we can rephrase our task as: - * "there are 8 megabanks, and each megabanks contains one 4-word unit, we - * want to make sure different lanes in a warp does not access different - * unit addresses in the same megabank" - * In this terminology, matrices are in the row major format, each matrix - * has 8 rows, and each row has exactly one unit. - */ - - constexpr int64_t items_per_unit = n_cols; - const int64_t bytes_per_unit = items_per_unit * data_type_size; - const int64_t words_per_unit = bytes_per_unit / smem_bytes_per_word; - const int64_t num_megabanks = smem_banks / words_per_unit; - - /* In the following example, each CTA tile contains 2 rows and 3 colums of - * matrices, each 8x8 size: - * +----------+----------+----------+ - * | matrix 0 | matrix 1 | matrix 2 | - * +----------+----------+----------+ - * | matrix 3 | matrix 4 | matrix 5 | - * +----------+----------+----------+ - * The addresses of different rows in the same matrix are offset by 3 units. - * In this perspective, loading a matrix is a strided memory access with the - * following stride (in units): - */ - - // number of units per row - int64_t row_stride = tile_size_y / items_per_unit; - - /* So the bank conflicting problem is now converted to the following game: - * I have a clock that has one pointer and `num_megabanks` ticks. I start - * my game by making my pointer pointing to somewhere, and turn forward - * the pointer `n_rows` times, each time by `row_stride` ticks. - * This problem can be well modeled by modular arithmetic in number theory - * using the concept "integers modulo n" a.k.a. "Z/nZ"[1]. - * Take n = 6 as an example, Z/6Z only has 6 elements: 0, 1, 2, 3, 4, 5. - * Additions and multiplications are defined in a cyclic manner: - * 5 + 1 = 0, 5 + 2 = 1, 5 + 3 = 2, 5 + 4 = 3, ... - * 2 * 1 = 2, 2 * 2 = 4, 2 * 3 = 0, 2 * 4 = 2, ... - * With this definition, Z is mapped to Z/nZ naturally by i -> i % n [2] - * - * It worth mention that Z/nZ is a "commutative ring", that is, we can use - * addition and multiplication rules just like using normal integers: - * a + b = b + a, a * (b + c) = a * b + a * c, ... - * In short, we can reason about Z/nZ just like we are reasoning about - * integers, except that every number is automatically "% n". - * - * Reference: - * [1] https://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n - * [2] The % is under Euclidean definition, that is -1 % 6 is 5 instead of - * -1, see [The Mathematics of Integer Arithmetic] for more detail. But - * we are only interested in non-negative numbers here, so there is no - * need to worry about this problem - */ - - // row_stride in Z/nZ, where n is num_megabanks: - // assert(row_stride >= 0); - // assert(num_megabanks >= 0); - int64_t row_stride_znz = row_stride % num_megabanks; - /* Consider the following function in Z/nZ: - * f(i; init) = init + i * stride - * where init is the initial position of the pointer in the clock when we - * start the game, and stride is the number of ticks we move forward each - * time, and i is the number of times we move forward. For a fixed init, we - * abbrivate f(i; init) as f(i). - * - * In our problem, f(i) is the megabank of the `i`th row of the matrix, and - * `init` is the megabank of the 0th row of the matrix. - * - * One very important property of f(i) is: - * - if f(i1) == f(i2), then for every j, f(i1 + j) = f(i2 + j) - * This property is true because: - * f(i1 + j) = f(i1) + j * stride = f(i2) + j * stride = f(i2 + j) - * - * The above property tells us, as we turn the clock forward: - * - initially, we will go to a never-visited tick in each turn, but, - * - at some point, we will return back to our original position, and, - * - after we return, we start repeat the pervious pattern again and again. - * - * As an example, consider f(i) where init = 0, stride = 6, under Z/8Z: - * i 0 1 2 3 4 5 6 7 - * f(i) 0 6 4 2 0 6 4 2 - * We can see that f(i) is repeating a pattern of four unique numbers - * "0 6 4 2" twice. In our bank conflict problem, this means we are using 4 - * different megabanks, and we have a 2-way conflict. - * - * The question of interest is, does the above observation generalize? That - * is, does f(i) always repeat a pattern of p unique numbers q times? Note - * that p and q must satisfy p * q = n. - * - * The answer to the above question is: yes! Consider the following - * equation: - * f(i1 + j) == f(i1) - * We want to know what is the smallest positive number j that makes the - * above equation true. Because this tells us in how many steps we will see - * repeat. This equation can be simplified as: - * f(i1 + j) == f(i1) + j * stride == f(i1) - * ==> j * stride == 0 - * - * An important tool to study this equation is multiplicative inverse: - * https://en.wikipedia.org/wiki/Modular_multiplicative_inverse - * A number i has multiplicative inverse `minv(i)` in Z/nZ if and only if it - * coprime with n. `minv(i)` is the number that `i * minv(i) == 1`. So in - * Z/nZ, the equation `ax = b` has solution `x = minv(a)*b` if a has - * multiplicative inverse. For example, in Z/15Z, `minv(2) = 8` because - * (2 * 8) % 15 = 1 - * - * stride has an multiplicative inverse if and only if stride coprime with - * n, that is, g := gcd(stride, n) == 1. In such case, the solution to our - * equation j * stride == 0 is j = minv(stride) * 0 = 0, that is: f(i) does - * not repeat, that is: there is no bank conflict. - */ - - int64_t g = std::gcd(num_megabanks, row_stride_znz); - if (g == 1) { - return; // No need to swizzle in this case. - } + // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. + // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit + // (i.e. float) + const int64_t data_type_size = + (int64_t)dataTypeSize(*shared_mem_tv->getDataType()); + NVF_ERROR(data_type_size == 2 || data_type_size == 4); + + // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. + // For epilogue, threads in a warp is organized as 8 rows x 4 columns. + // Each thread vectorized write 2 items, so 8 items per row. + //--0--1--2--3 + //--4--5--6--7 + //--8--9--10-11 + //--12-13-14-15 + //--16-17-18-19 + //--20-21-22-23 + //--24-25-26-27 + //--28-29-30-31 + constexpr int64_t n_rows = 8; + constexpr int64_t n_cols = 8; + + // Column size of the tile needs to be multiples of 8 for ldmatrix to work. + NVF_ERROR( + tile_size_x >= n_rows && tile_size_x % n_rows == 0 && + tile_size_y >= n_cols && tile_size_y % n_cols == 0, + "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", + tile_size_x, + "x", + tile_size_y); + + /* Note [How to remove bank conflict for ldmatrix?] + * + * **This note is interleaved with code, I suggest reading this note like + * reading a jupyter notebook** + * + * Our task is to make sure different rows does not fall into the same + * bank of shared memory. + * + * Introduction to bank conflict can be found at page 54-72 of: + * https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf + * + * When we talk about bank conflict removal, we are talking about the + * following task: + * "there are 32 banks, and each bank contains one 4-byte word, we want to + * make sure different lanes in a warp does not access different word + * addresses in the same bank" + * For example, if thread 0 is accessing word address 1, and thread 1 is + * accessing word address 33, then these two threads will have a bank + * conflict because they are accessing different word addresses in the same + * bank. However, if thread 0 is accessing byte address 4 and thread 1 is + * accessing byte address 6 then there will be no bank conflict because 4 + * and 6 both belong to word 1. + */ + + constexpr int64_t smem_bytes_per_word = 4; + constexpr int64_t smem_banks = 32; + + /* but here, for our convenience, because ldmatrix always use vectorized + * access of 8 items = 16 bytes = 4 words, we further group words into + * units: we consider each 4 words as a "unit", and each 4 banks as a + * "megabank". So we can rephrase our task as: + * "there are 8 megabanks, and each megabanks contains one 4-word unit, we + * want to make sure different lanes in a warp does not access different + * unit addresses in the same megabank" + * In this terminology, matrices are in the row major format, each matrix + * has 8 rows, and each row has exactly one unit. + */ + + constexpr int64_t items_per_unit = n_cols; + const int64_t bytes_per_unit = items_per_unit * data_type_size; + const int64_t words_per_unit = bytes_per_unit / smem_bytes_per_word; + const int64_t num_megabanks = smem_banks / words_per_unit; + + /* In the following example, each CTA tile contains 2 rows and 3 colums of + * matrices, each 8x8 size: + * +----------+----------+----------+ + * | matrix 0 | matrix 1 | matrix 2 | + * +----------+----------+----------+ + * | matrix 3 | matrix 4 | matrix 5 | + * +----------+----------+----------+ + * The addresses of different rows in the same matrix are offset by 3 units. + * In this perspective, loading a matrix is a strided memory access with the + * following stride (in units): + */ + + // number of units per row + int64_t row_stride = tile_size_y / items_per_unit; + + /* So the bank conflicting problem is now converted to the following game: + * I have a clock that has one pointer and `num_megabanks` ticks. I start + * my game by making my pointer pointing to somewhere, and turn forward + * the pointer `n_rows` times, each time by `row_stride` ticks. + * This problem can be well modeled by modular arithmetic in number theory + * using the concept "integers modulo n" a.k.a. "Z/nZ"[1]. + * Take n = 6 as an example, Z/6Z only has 6 elements: 0, 1, 2, 3, 4, 5. + * Additions and multiplications are defined in a cyclic manner: + * 5 + 1 = 0, 5 + 2 = 1, 5 + 3 = 2, 5 + 4 = 3, ... + * 2 * 1 = 2, 2 * 2 = 4, 2 * 3 = 0, 2 * 4 = 2, ... + * With this definition, Z is mapped to Z/nZ naturally by i -> i % n [2] + * + * It worth mention that Z/nZ is a "commutative ring", that is, we can use + * addition and multiplication rules just like using normal integers: + * a + b = b + a, a * (b + c) = a * b + a * c, ... + * In short, we can reason about Z/nZ just like we are reasoning about + * integers, except that every number is automatically "% n". + * + * Reference: + * [1] https://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n + * [2] The % is under Euclidean definition, that is -1 % 6 is 5 instead of + * -1, see [The Mathematics of Integer Arithmetic] for more detail. But + * we are only interested in non-negative numbers here, so there is no + * need to worry about this problem + */ + + // row_stride in Z/nZ, where n is num_megabanks: + // assert(row_stride >= 0); + // assert(num_megabanks >= 0); + int64_t row_stride_znz = row_stride % num_megabanks; + /* Consider the following function in Z/nZ: + * f(i; init) = init + i * stride + * where init is the initial position of the pointer in the clock when we + * start the game, and stride is the number of ticks we move forward each + * time, and i is the number of times we move forward. For a fixed init, we + * abbrivate f(i; init) as f(i). + * + * In our problem, f(i) is the megabank of the `i`th row of the matrix, and + * `init` is the megabank of the 0th row of the matrix. + * + * One very important property of f(i) is: + * - if f(i1) == f(i2), then for every j, f(i1 + j) = f(i2 + j) + * This property is true because: + * f(i1 + j) = f(i1) + j * stride = f(i2) + j * stride = f(i2 + j) + * + * The above property tells us, as we turn the clock forward: + * - initially, we will go to a never-visited tick in each turn, but, + * - at some point, we will return back to our original position, and, + * - after we return, we start repeat the pervious pattern again and again. + * + * As an example, consider f(i) where init = 0, stride = 6, under Z/8Z: + * i 0 1 2 3 4 5 6 7 + * f(i) 0 6 4 2 0 6 4 2 + * We can see that f(i) is repeating a pattern of four unique numbers + * "0 6 4 2" twice. In our bank conflict problem, this means we are using 4 + * different megabanks, and we have a 2-way conflict. + * + * The question of interest is, does the above observation generalize? That + * is, does f(i) always repeat a pattern of p unique numbers q times? Note + * that p and q must satisfy p * q = n. + * + * The answer to the above question is: yes! Consider the following + * equation: + * f(i1 + j) == f(i1) + * We want to know what is the smallest positive number j that makes the + * above equation true. Because this tells us in how many steps we will see + * repeat. This equation can be simplified as: + * f(i1 + j) == f(i1) + j * stride == f(i1) + * ==> j * stride == 0 + * + * An important tool to study this equation is multiplicative inverse: + * https://en.wikipedia.org/wiki/Modular_multiplicative_inverse + * A number i has multiplicative inverse `minv(i)` in Z/nZ if and only if it + * coprime with n. `minv(i)` is the number that `i * minv(i) == 1`. So in + * Z/nZ, the equation `ax = b` has solution `x = minv(a)*b` if a has + * multiplicative inverse. For example, in Z/15Z, `minv(2) = 8` because + * (2 * 8) % 15 = 1 + * + * stride has an multiplicative inverse if and only if stride coprime with + * n, that is, g := gcd(stride, n) == 1. In such case, the solution to our + * equation j * stride == 0 is j = minv(stride) * 0 = 0, that is: f(i) does + * not repeat, that is: there is no bank conflict. + */ + + int64_t g = std::gcd(num_megabanks, row_stride_znz); + if (g == 1) { + return; // No need to swizzle in this case. + } - /* For the case where stride does not coprime with n, we note that - * j * stride == 0 in Z/nZ is equivalent to (j * stride) % n = 0 in Z. We - * can write stride and n as: - * stride = s * g, n = m * g - * According to Theorem 4.13 in [The Mathematics of Integer Arithmetic], we - * have: - * (j * stride) % n = 0 - * ==> (j * s) % m * g = 0 - * ==> (j * s) % m = 0 - * which is equivalent to j * s == 0 in Z/mZ. Because s coprime with m, we - * further get: - * j == 0 (in Z/mZ) - * That is, j is a multiple of m in Z. So the smallest positive j that make - * the equation hold is n / g. - * - * That is: f(i) always repeat a pattern of n/g unique numbers g times. - * In other word: we are using n/g megabanks, and we have a g-way bank - * conflict. - * - * Let's use the word "pattern" to refer to the set of values of `f` at - * different `i`, that is: - * pattern k = { f(i; init=k) | i in Z/nZ } - * For the example of stride = 6 under Z/8Z, we have the following patterns - * f(i): 01234567 - * pattern 0: x_x_x_x_ - * pattern 1: _x_x_x_x - * (x => occupied, _ => unoccupied) - */ - - int64_t repeated_pattern_size = num_megabanks / g; - - if (repeated_pattern_size >= n_rows) { - return; // No need to swizzle in this case. - } + /* For the case where stride does not coprime with n, we note that + * j * stride == 0 in Z/nZ is equivalent to (j * stride) % n = 0 in Z. We + * can write stride and n as: + * stride = s * g, n = m * g + * According to Theorem 4.13 in [The Mathematics of Integer Arithmetic], we + * have: + * (j * stride) % n = 0 + * ==> (j * s) % m * g = 0 + * ==> (j * s) % m = 0 + * which is equivalent to j * s == 0 in Z/mZ. Because s coprime with m, we + * further get: + * j == 0 (in Z/mZ) + * That is, j is a multiple of m in Z. So the smallest positive j that make + * the equation hold is n / g. + * + * That is: f(i) always repeat a pattern of n/g unique numbers g times. + * In other word: we are using n/g megabanks, and we have a g-way bank + * conflict. + * + * Let's use the word "pattern" to refer to the set of values of `f` at + * different `i`, that is: + * pattern k = { f(i; init=k) | i in Z/nZ } + * For the example of stride = 6 under Z/8Z, we have the following patterns + * f(i): 01234567 + * pattern 0: x_x_x_x_ + * pattern 1: _x_x_x_x + * (x => occupied, _ => unoccupied) + */ + + int64_t repeated_pattern_size = num_megabanks / g; + + if (repeated_pattern_size >= n_rows) { + return; // No need to swizzle in this case. + } - /* Now we know that we have a g-way bank conflict. How do we remove this - * bank conflict? The answer is to mix the storage of different matrices. - * We first split the matrices along the row axis into g pieces, each piece - * has n/g rows. With this split, each piece occupies exactly one pattern. - * We want to use some non-traditional storage to let different pieces of - * the same matrix to occupy different patterns. - * - * Because Z/nZ has n items, each pattern has n/g different items, so we - * have in total g different patterns. We want to find the corresponding - * `init` values of these g different patterns. - * - * Consider two different init values `init1` and `init2`. When do they - * represent the same pattern? They represent the same pattern if and only - * if `f(0; init2)` falls on the pattern of `init1`, that is, there exist an - * i such that - * f(i; init1) == f(0; init2) - * which simplifies to - * init1 + i * stride == init2 - * ==> init2 - init1 == i * stride - * What values can `i * stride` be? It can be an arbitrary multiple of g: - * i * stride in Z/nZ is (i * stride) % n in Z. Let m = n/g, according to - * Theorem 4.13 in [The Mathematics of Integer Arithmetic] - * (i * stride) % n = (i * s) % m * g - * Because s coprime with m, we know that for an arbitrary value `j` in - * Z/mZ, we can take `i = minv(s) * j` to make `i * s == j`. - * - * That said, for init values that are off by a multiple of g they - * correspond to the same pattern, otherwise they belongs to different - * patterns. So, we can use - * init = 0, 1, ..., g - 1 - * to canonically represent g patterns. Let's call the above - * `init` values "pattern id". - * - * Now we have the idea about how to remove bank conflict: We can do an - * inner split of our row dimension by `repeated_pattern_size` to get - * (repeat, pattern), then different indices of the "repeat" dimension will - * be using the same megabank, and different indices of the "pattern" - * dimension will be using different megabank. We don't need to touch the - * "pattern" dimension, but we need to play with the "repeat" dimension to - * interleave it with matrice ids so that each matrix is distributed across - * different banks. - * - * For example, if we have repeated_pattern_size = 4, we would want to do - * something like below: - * +----------+----------+ - * 0| | | - * 1| matrix 0 | matrix 1 | - * 2| | | - * 3| | | - * +----------+----------+ - * 4| | | - * 5| matrix 1 | matrix 0 | - * 6| | | - * 7| | | - * +----------+----------+ - * - * We can consider each repeated_pattern_size rows as a gigarow, and each - * repeated_pattern_size megabanks as a gigabank. Note that megabank is a - * contiguous chunk of banks, but gigabank is not contiguous. Indeed, - * nearby megabanks in a gigabank has a distance of `g` megabanks - */ + /* Now we know that we have a g-way bank conflict. How do we remove this + * bank conflict? The answer is to mix the storage of different matrices. + * We first split the matrices along the row axis into g pieces, each piece + * has n/g rows. With this split, each piece occupies exactly one pattern. + * We want to use some non-traditional storage to let different pieces of + * the same matrix to occupy different patterns. + * + * Because Z/nZ has n items, each pattern has n/g different items, so we + * have in total g different patterns. We want to find the corresponding + * `init` values of these g different patterns. + * + * Consider two different init values `init1` and `init2`. When do they + * represent the same pattern? They represent the same pattern if and only + * if `f(0; init2)` falls on the pattern of `init1`, that is, there exist an + * i such that + * f(i; init1) == f(0; init2) + * which simplifies to + * init1 + i * stride == init2 + * ==> init2 - init1 == i * stride + * What values can `i * stride` be? It can be an arbitrary multiple of g: + * i * stride in Z/nZ is (i * stride) % n in Z. Let m = n/g, according to + * Theorem 4.13 in [The Mathematics of Integer Arithmetic] + * (i * stride) % n = (i * s) % m * g + * Because s coprime with m, we know that for an arbitrary value `j` in + * Z/mZ, we can take `i = minv(s) * j` to make `i * s == j`. + * + * That said, for init values that are off by a multiple of g they + * correspond to the same pattern, otherwise they belongs to different + * patterns. So, we can use + * init = 0, 1, ..., g - 1 + * to canonically represent g patterns. Let's call the above + * `init` values "pattern id". + * + * Now we have the idea about how to remove bank conflict: We can do an + * inner split of our row dimension by `repeated_pattern_size` to get + * (repeat, pattern), then different indices of the "repeat" dimension will + * be using the same megabank, and different indices of the "pattern" + * dimension will be using different megabank. We don't need to touch the + * "pattern" dimension, but we need to play with the "repeat" dimension to + * interleave it with matrice ids so that each matrix is distributed across + * different banks. + * + * For example, if we have repeated_pattern_size = 4, we would want to do + * something like below: + * +----------+----------+ + * 0| | | + * 1| matrix 0 | matrix 1 | + * 2| | | + * 3| | | + * +----------+----------+ + * 4| | | + * 5| matrix 1 | matrix 0 | + * 6| | | + * 7| | | + * +----------+----------+ + * + * We can consider each repeated_pattern_size rows as a gigarow, and each + * repeated_pattern_size megabanks as a gigabank. Note that megabank is a + * contiguous chunk of banks, but gigabank is not contiguous. Indeed, + * nearby megabanks in a gigabank has a distance of `g` megabanks + */ - NVF_ERROR( - n_rows % repeated_pattern_size == 0, - "Can not partition matrix into megarows"); - int64_t num_gigarows = n_rows / repeated_pattern_size; - int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size - - // -2 -1 - // [row, col] - if (repeated_pattern_size > 1) { - shared_mem_tv->split(-2 - skip, repeated_pattern_size); - } - shared_mem_tv->split(-1 - skip, n_cols); - // -4 -3 -2 -1 - // [gigarow id, gigarow, matrix id, matrix] - shared_mem_tv->split(-2 - skip, num_gigabanks); - // -5 -4 -3 -2 -1 - // [gigarow id, gigarow, y outer, gigabank id, matrix] - // Note that megabanks inside a gigabank are not contiguous, so the gigabank - // id is -2 instead of -3 - - /* We want to evenly distribute gigarows across gigabanks, for example, if - * we have 7 gigarows and 3 gigabanks, then we might distribute them as: - * +---+ - * |x | - * | x | - * | x| - * |x | - * | x | - * | x| - * |x | - * +---+ - * considering all matrices, this is a swizzle function like: - * +---+ - * |012| - * |201| - * |120| - * |012| - * |201| - * |120| - * |012| - * +---+ - * which is a cyclic shift. - * - * Note that because num_gigabanks (a.k.a. g) divide num_megabanks and - * row_stride_znz (which is row_stride % num_megabanks), g should also - * divide row_stride, because according to the fundamental - * division-with-remainder property (see comment in expr_simplifier.h): - * row_stride = q * num_megabanks + row_stride_znz - * which means, we can just consider each num_gigabanks matrices as a group, - * and we always have complete groups (i.e. no group has less than - * num_gigabanks matrices). Interleaving the memory of matrices within each - * group should be enough to fully remove bank conflict. - */ - - /* To further simplify the problem, if we assume: */ - NVF_ERROR( - num_gigarows % num_gigabanks == 0, - "Requires non-square swizzle, which is not supported yet"); - /* Then we can partition gigarows into full waves, each wave has - * num_gigabanks gigarows. This partition creates square dimensions, making - * the swizzle implementation easier */ - - // -5 -4 -3 -2 -1 - // [gigarow id, gigarow, y outer, gigabank id, matrix] + NVF_ERROR( + n_rows % repeated_pattern_size == 0, + "Can not partition matrix into megarows"); + int64_t num_gigarows = n_rows / repeated_pattern_size; + int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size + + // -2 -1 + // [row, col] + if (repeated_pattern_size > 1) { + shared_mem_tv->split(-2 - skip, repeated_pattern_size); + } + shared_mem_tv->split(-1 - skip, n_cols); + // -4 -3 -2 -1 + // [gigarow id, gigarow, matrix id, matrix] + shared_mem_tv->split(-2 - skip, num_gigabanks); + // -5 -4 -3 -2 -1 + // [gigarow id, gigarow, y outer, gigabank id, matrix] + // Note that megabanks inside a gigabank are not contiguous, so the gigabank + // id is -2 instead of -3 + + /* We want to evenly distribute gigarows across gigabanks, for example, if + * we have 7 gigarows and 3 gigabanks, then we might distribute them as: + * +---+ + * |x | + * | x | + * | x| + * |x | + * | x | + * | x| + * |x | + * +---+ + * considering all matrices, this is a swizzle function like: + * +---+ + * |012| + * |201| + * |120| + * |012| + * |201| + * |120| + * |012| + * +---+ + * which is a cyclic shift. + * + * Note that because num_gigabanks (a.k.a. g) divide num_megabanks and + * row_stride_znz (which is row_stride % num_megabanks), g should also + * divide row_stride, because according to the fundamental + * division-with-remainder property (see comment in expr_simplifier.h): + * row_stride = q * num_megabanks + row_stride_znz + * which means, we can just consider each num_gigabanks matrices as a group, + * and we always have complete groups (i.e. no group has less than + * num_gigabanks matrices). Interleaving the memory of matrices within each + * group should be enough to fully remove bank conflict. + */ + + /* To further simplify the problem, if we assume: */ + NVF_ERROR( + num_gigarows % num_gigabanks == 0, + "Requires non-square swizzle, which is not supported yet"); + /* Then we can partition gigarows into full waves, each wave has + * num_gigabanks gigarows. This partition creates square dimensions, making + * the swizzle implementation easier */ + + // -5 -4 -3 -2 -1 + // [gigarow id, gigarow, y outer, gigabank id, matrix] + int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; + shared_mem_tv->split(axis_of_gigarow_id - skip, num_gigabanks); + // -6 -5 -4 -3 -2 -1 + // [wave id, wave, gigarow, y outer, gigabank id, matrix] + + // swizzle wave with gigabank id to make threads in a wave access different + // gigabank. Apply swizzle only when shared_mem_tv is stored in shared + // memory. + // TODO: This is a temporary workaround for the following issue: + // For the mma output, we have the following schedule: + // rFactor: [...., X, Y] -> mma-swizzle transformations -> leaf + // For epilogue smem tensor, the schedule is + // rFactor: [...., X, Y] -> split -> [...., X1, X2, X3, Y1, Y2, Y3] + // -> swizzle X2, Y2 -> [...., X1, X2', X3, Y1, Y2', Y3] + // -> merge back -> [...., X', Y'] + // -> mma-swizzle transformations -> leaf + // The mma-swizzle transformations for the mma output and epilogue smem + // tensor are the same. In indexing, we do require {X, X'} and {Y, Y'} to be + // mapped in CA map, however, we currently can not handle that. So we have + // to do the same split and merge to the mma output without actually + // applying the swizzle, and this check is to detect and handle this + // specific case. We should remove this special handling when we fix our CA + // mapping. + if (shared_mem_tv->getMemoryType() == MemoryType::Shared) { int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; - shared_mem_tv->split(axis_of_gigarow_id - skip, num_gigabanks); - // -6 -5 -4 -3 -2 -1 - // [wave id, wave, gigarow, y outer, gigabank id, matrix] - - // swizzle wave with gigabank id to make threads in a wave access different - // gigabank. Apply swizzle only when shared_mem_tv is stored in shared - // memory. - // TODO: This is a temporary workaround for the following issue: - // For the mma output, we have the following schedule: - // rFactor: [...., X, Y] -> mma-swizzle transformations -> leaf - // For epilogue smem tensor, the schedule is - // rFactor: [...., X, Y] -> split -> [...., X1, X2, X3, Y1, Y2, Y3] - // -> swizzle X2, Y2 -> [...., X1, X2', X3, Y1, Y2', Y3] - // -> merge back -> [...., X', Y'] - // -> mma-swizzle transformations -> leaf - // The mma-swizzle transformations for the mma output and epilogue smem - // tensor are the same. In indexing, we do require {X, X'} and {Y, Y'} to be - // mapped in CA map, however, we currently can not handle that. So we have - // to do the same split and merge to the mma output without actually - // applying the swizzle, and this check is to detect and handle this - // specific case. We should remove this special handling when we fix our CA - // mapping. - if (shared_mem_tv->getMemoryType() == MemoryType::Shared) { - int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; - if (isPowOf2(num_gigabanks)) { - shared_mem_tv->swizzle( - Swizzle2DType::XOR, axis_of_gigarow_id - skip, -2 - skip); - } else { - shared_mem_tv->swizzle( - Swizzle2DType::CyclicShift, axis_of_gigarow_id - skip, -2 - skip); - } + if (isPowOf2(num_gigabanks)) { + shared_mem_tv->swizzle( + Swizzle2DType::XOR, axis_of_gigarow_id - skip, -2 - skip); + } else { + shared_mem_tv->swizzle( + Swizzle2DType::CyclicShift, axis_of_gigarow_id - skip, -2 - skip); } + } - if (repeated_pattern_size > 1) { - shared_mem_tv->merge(-6 - skip); - } - shared_mem_tv->merge(-5 - skip); - - // merge back tile_size_y - shared_mem_tv->merge(-3 - skip); - shared_mem_tv->merge(-2 - skip); - - } else if (isVolta(params.mma_macro)) { - // TODO: Volta is slightly more complex, and a fixed recipe would - // not scale. In a follow up this would be inferred from the mma - // macro layout themselves as we already have them registered in - // the utils. - return; - } else { - NVF_ERROR(false, "Prolog swizzle: unsupported mma macro"); + if (repeated_pattern_size > 1) { + shared_mem_tv->merge(-6 - skip); } + shared_mem_tv->merge(-5 - skip); + + // merge back tile_size_y + shared_mem_tv->merge(-3 - skip); + shared_mem_tv->merge(-2 - skip); } //! Generates the prolog schedule on the shared memory buffer @@ -707,6 +696,10 @@ void scheduleFusionInputsForEpilogue( } // namespace void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { + static const bool should_unroll = true; + auto cached_and_forked_outputs = + scheduler_utils::cacheAndForkOutputs(fusion, should_unroll); + const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion); // NOTE: the contents of roles_map have been already validated during @@ -723,7 +716,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Core roles: there can be only one... TV with assigned core role TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); - TensorView* d = roles_map.at(MatmulRole::OUTPUT_D).front(); // Collect mma swizzle info auto mma = mma_ops.front(); @@ -731,11 +723,9 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { NVF_ERROR( mma_layout_opt.has_value(), "fusion mma op has undefined input layout"); const auto mma_layout = mma_layout_opt.value(); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion); + const auto fusion_layout = mma_utils::getMmaLayout(fusion); NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg()); - auto mma_builder = - MmaBuilder(params.mma_macro, params.tile_sizes).layout(mma_layout); const auto& gemm_tile = params.tile_sizes; const bool has_epilogue = !mma->out()->isFusionOutput(); @@ -777,15 +767,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Currently the support is for a, b, c and d as fusion inputs/outputs // aka. no prolog fusion yet. - mma_builder.configureMma(mma); - - // TODO: - // Beyond this point, mma_builder really just becomes a populated - // list of parameters to describe the mma swizzles that should - // be annotated on the tensor domain. Conceptually the mma builder - // object should be separated to 2 parts, one as scheduler utility - // and the other as matmul heuristic parameters, which we are - // starting to build out. + mma->setMacro(params.mma_macro); // Setup register and shared memory stages: // TODO: this section goes to a separate matmul util, @@ -796,10 +778,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { auto bb = mma->inB()->as(); // Setup accumulator register. - auto dc = d->cacheBefore(); - // Mma object is valid only because cacheBefore has been done on - // TV which is not output of MmaOp, as there is an epilogue - auto mma_result = has_epilogue ? mma->out()->as() : dc; + auto mma_result = mma->out()->as(); // Unswizzle mma result in shared memory auto smem_epilogue = @@ -808,9 +787,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Clear MmaOp pointer, it's not needed from now on mma = nullptr; - // Set accumulation tv for mma op. - mma_builder.accumulatorTv(mma_result); - // Staging register for global memory load TensorView *ar = a, *br = b; @@ -829,75 +805,45 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { TensorView* acr = nullptr; TensorView* bcr = nullptr; - // Different paths because Volta swizzle needs to - // involve the broadcast dimensions that are concretized - // at mma, while Ampere ones should be done before - // the broadcast op to be able to use cp.async. - // TODO: - // Also a few additional parameters should be introduced - // to control this stage of scheduling. - if (isVolta(params.mma_macro)) { - acw_smem = ab->cacheAfter(); - bcw_smem = bb->cacheAfter(); - // Cache again to be able to vectorize. - acw_smem = acw_smem->cacheAfter(); - bcw_smem = bcw_smem->cacheAfter(); - - acr = acw_smem->cacheAfter(); - bcr = bcw_smem->cacheAfter(); - if (params.double_buffer_options.double_buffer_smem_read) { - // Provide another copy op between the double buffered - // smem load register and the actual mma ops to avoid - // complication in double buffered fragment iteration. - ab = acr->cacheAfter(); - bb = bcr->cacheAfter(); - } else { - ab = acr; - bb = bcr; - } - - } else { - // Use cp.async as requested in scheduler params. - LoadStoreOpType load_op = LoadStoreOpType::Set; - CacheOp cache_op = CacheOp::Unspecified; - if (params.async_gmem_load_operands) { - load_op = LoadStoreOpType::CpAsync; - cache_op = CacheOp::Global; - } + // Use cp.async as requested in scheduler params. + LoadStoreOpType load_op = LoadStoreOpType::Set; + CacheOp cache_op = CacheOp::Unspecified; + if (params.async_gmem_load_operands) { + load_op = LoadStoreOpType::CpAsync; + cache_op = CacheOp::Global; + } - acw_smem = ar->cacheAfter(load_op, cache_op); - bcw_smem = br->cacheAfter(load_op, cache_op); - NVF_ERROR(acw_smem->uses().size() == 1); - NVF_ERROR(bcw_smem->uses().size() == 1); - if (auto ldst = dynamic_cast(acw_smem->uses().at(0))) { - acr = ldst->out()->as(); - if (ldst->hasInnerTranspose()) { - ldst->setOpType(LoadStoreOpType::LdMatrixTranspose); - } else { - ldst->setOpType(LoadStoreOpType::LdMatrix); - } + acw_smem = ar->cacheAfter(load_op, cache_op); + bcw_smem = br->cacheAfter(load_op, cache_op); + NVF_ERROR(acw_smem->uses().size() == 1); + NVF_ERROR(bcw_smem->uses().size() == 1); + if (auto ldst = dynamic_cast(acw_smem->uses().at(0))) { + acr = ldst->out()->as(); + if (ldst->hasInnerTranspose()) { + ldst->setOpType(LoadStoreOpType::LdMatrixTranspose); } else { - acr = acw_smem->cacheAfter(LoadStoreOpType::LdMatrix); + ldst->setOpType(LoadStoreOpType::LdMatrix); } - if (auto ldst = dynamic_cast(bcw_smem->uses().at(0))) { - bcr = ldst->out()->as(); - if (ldst->hasInnerTranspose()) { - ldst->setOpType(LoadStoreOpType::LdMatrixTranspose); - } else { - ldst->setOpType(LoadStoreOpType::LdMatrix); - } + } else { + acr = acw_smem->cacheAfter(LoadStoreOpType::LdMatrix); + } + if (auto ldst = dynamic_cast(bcw_smem->uses().at(0))) { + bcr = ldst->out()->as(); + if (ldst->hasInnerTranspose()) { + ldst->setOpType(LoadStoreOpType::LdMatrixTranspose); } else { - bcr = bcw_smem->cacheAfter(LoadStoreOpType::LdMatrix); + ldst->setOpType(LoadStoreOpType::LdMatrix); } - - // For Turing and Ampere, the layout of the MmaOp is always TN - NVF_ERROR( - mma_layout == MmaOptions::MmaLayout::TN, - "MMAs in Turing and Ampere are TN only, transpose is handled either " - "via ldmatrix.trans for fp16 or explicitly for other types."); - mma_builder.layout(fusion_layout.getData()); + } else { + bcr = bcw_smem->cacheAfter(LoadStoreOpType::LdMatrix); } + // For Turing and Ampere, the layout of the MmaOp is always TN + NVF_ERROR( + mma_layout == MmaLayout::TN, + "MMAs in Turing and Ampere are TN only, transpose is handled either " + "via ldmatrix.trans for fp16 or explicitly for other types."); + // Make a CTA tile // ------------------------------------------------------------------ mma_utils::canonicalizeMmaTvOrdering(mma_result); @@ -948,10 +894,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { splitk_sum = mma_result; mma_result = splitk_sum->rFactor({-4, -1}); - // the accumulator must be the output of the MMA op, which is now the - // rfactor TV - mma_builder.accumulatorTv(mma_result); - num_splitk_dims = 1; } @@ -988,35 +930,47 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { moveInnerBroadcastLeft(ab); moveInnerBroadcastLeft(bb); } - ab->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - bb->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + ab->applyMmaSwizzle(MmaOperand::A); + bb->applyMmaSwizzle(MmaOperand::B); // Propagate mma input swizzle up the DAG // to all the tensors before mma op and after shared mem read. - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - ab, - -1, - {acw_smem}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType()); - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - bb, - -1, - {bcw_smem}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType()); + auto propagate_mma_input_schedule_to = [&](TensorView* a_boundary, + TensorView* b_boundary) { + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + ab, + -1, + {a_boundary}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType()); + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + bb, + -1, + {b_boundary}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType()); + }; + propagate_mma_input_schedule_to(acw_smem, bcw_smem); - mma_result->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + mma_result->applyMmaSwizzle(MmaOperand::Accumulator); // Set parallelization: // TODO: this section goes to a separate matmul util, // and needs more configurability. // ------------------------------------------------------------------ - // Vectorize smem stores/loads: - acr->axis(-1)->parallelize(ParallelType::Vectorize); - bcr->axis(-1)->parallelize(ParallelType::Vectorize); + acr->setAllocationDomain(acr->getLeafDomain(), true); + bcr->setAllocationDomain(bcr->getLeafDomain(), true); + mma_utils::WarpMmaSwizzler::scheduleLdMatrix(acr, MmaOperand::A); + mma_utils::WarpMmaSwizzler::scheduleLdMatrix(bcr, MmaOperand::B); + + // -5 -4 -3 -2 -1 or -5 -4 -3 -2 -1 + //[8mi, 4k, 2ko, 2mo, 2ki] [8ni, 4k, 2ko, 1no, 2ki] + for (auto tv : {ab, bb}) { + tv->merge(-5); + tv->axis(-4)->parallelize(ParallelType::TIDx); + } + propagate_mma_input_schedule_to(acr, bcr); // Parallelization strategy: // Here the top two rows indicate how we can index each axis. The third row @@ -1088,23 +1042,27 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { .propagateToBoundary()); smem_epilogue->axis(-1)->parallelize(ParallelType::Vectorize); - // Schedule output tensor differently for better global memory access - // pattern. - scheduleOutputTensor(mma_result, d, gemm_tile); - d->axis(-1)->parallelize(ParallelType::Vectorize); + for (auto [dc, d] : cached_and_forked_outputs) { + // Schedule output tensor differently for better global memory access + // pattern. + scheduleOutputTensor(mma_result, d, gemm_tile); + d->axis(-1)->parallelize(ParallelType::Vectorize); - // Propagate output tensor transformations back to smem_epilogue - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - d, -1, {smem_epilogue}); + // Propagate output tensor transformations back to smem_epilogue + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + d, -1, {smem_epilogue}); + } } else { - scheduler_utils::BoundedDirectionalTransformPropagator::forward( - mma_result, - -1, - {d}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType() - .propagateToBoundary()); - d->axis(-1)->parallelize(ParallelType::Vectorize); + for (auto [dc, d] : cached_and_forked_outputs) { + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + mma_result, + -1, + {d}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + d->axis(-1)->parallelize(ParallelType::Vectorize); + } } // propagate output transformations to all inputs that are part of epilogue // operations, input tvs with non-core roles diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index f3baafd4ee3..c43be585cc5 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -67,7 +67,7 @@ class MatmulParams : public HeuristicParams { MatMulTileOptions tile_sizes = {}; //! Specify the type of MMA op to be used in generated kernel. - MmaOptions::MacroType mma_macro = MmaOptions::MacroType::NoMMA; + MmaMacro mma_macro = MmaMacro::NoMMA; //! Specify CTA rastrization order. TileRasterizationOrder cta_order = TileRasterizationOrder::RowMajor; @@ -105,7 +105,7 @@ class MatmulParams : public HeuristicParams { std::stringstream ss; ss << "\n===== Matmul Parameters ========\n" << (tag.empty() ? "" : "Tag: ") << tag << "\n" - << "MMA macro: " << nvfuser::toString(mma_macro, true) << "\n" + << "MMA macro: " << nvfuser::toString(mma_macro) << "\n" << double_buffer_options.toString() << "\n" << nvfuser::toString(tile_sizes) << "\n" << "Rotate ldmatrix out of main loop: " diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index c853ec2aacb..bfb84355af7 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -36,24 +36,21 @@ namespace nvfuser { namespace { -using MatmulLayout = MmaOptions::MmaLayout; //! Access to the structure should be done with labels defined in //! MmaOptions::MmaDomains. using ProblemShape = std::array; //! A helper for deciding the type of MMA op for given fusion and problem shape. -inline std::optional getMmaOp( +inline std::optional getMmaOp( const int dev_version, const ProblemShape& problem) { - using MacroType = MmaOptions::MacroType; + using MacroType = MmaMacro; // NOTE: A temp condition const ProblemShape::value_type n_extend = problem[(size_t)MatmulDomain::N]; const bool use_small_n = ((n_extend % 8) == 0) && ((n_extend % 16) != 0); switch (dev_version) { - case 70: - return MacroType::Volta_16_16_4; case 75: return (use_small_n) ? MacroType::Turing_16_8_16 : MacroType::Turing_16_16_16; @@ -71,7 +68,7 @@ inline std::optional getMmaOp( //! A wrapper for core heuristics initialization inline bool initCoreHeuristics( std::shared_ptr params, - const MmaOptions::MacroType& mma_op, + const MmaMacro& mma_op, const ProblemShape& problem_shape) { const GemmTile instruction_tile = getMmaOpShape(mma_op); GemmTile warp_tile = {-1, -1, -1}; @@ -81,25 +78,20 @@ inline bool initCoreHeuristics( // warp tile shape { - if (isAmpere(mma_op) || isTuring(mma_op)) { - // Initial target: - // - 1 MMA ops per thread in a warp (32 threads), warp tile should be - // then 32x bigger than instruction tile, - // - start with [4, 4, 2] shape, later it should depend on problem - // shape and have bigger impact on CTA tile shape - - const DimType m_ratio = 4; - const DimType n_ratio = 4; - const DimType k_ratio = 2; - - warp_tile = { - instruction_tile.m * m_ratio, - instruction_tile.n * n_ratio, - instruction_tile.k * k_ratio}; - } else { - // No support for Volta - return false; - } + // Initial target: + // - 1 MMA ops per thread in a warp (32 threads), warp tile should be + // then 32x bigger than instruction tile, + // - start with [4, 4, 2] shape, later it should depend on problem + // shape and have bigger impact on CTA tile shape + + const DimType m_ratio = 4; + const DimType n_ratio = 4; + const DimType k_ratio = 2; + + warp_tile = { + instruction_tile.m * m_ratio, + instruction_tile.n * n_ratio, + instruction_tile.k * k_ratio}; } // cta tile shape @@ -189,7 +181,6 @@ std::string isMatmulFusionDefinitionSupported( ir_utils::filterByType(fusion_outputs).vector(); constexpr size_t minimal_number_of_inputs = 2; - constexpr size_t expected_number_of_outputs = 1; // Quick checks - MmaOp { @@ -214,11 +205,6 @@ std::string isMatmulFusionDefinitionSupported( if (minimal_number_of_inputs > fusion_inputs.size()) { return "Fusion inputs contain at least one non-TensorView object"; } - - // Fusion has only TVs as outputs, and we expect only one object in the list - if ((expected_number_of_outputs != fusion_outputs_tvs.size())) { - return "Fusion has more than a single TensorView object in its outputs"; - } } // Fusion topology check @@ -255,21 +241,23 @@ std::string isMatmulFusionDefinitionSupported( entry = roles_map.find(MatmulRole::OUTPUT_D); if (entry != roles_map.end()) { - if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { - tvs_with_roles.insert(entry->second.begin(), entry->second.end()); - } else { - return "There is more than a single fusion output that can be MMA output"; - } + tvs_with_roles.insert(entry->second.begin(), entry->second.end()); } else { return "No candidate in fusion outputs MMA output"; } - // Non-core roles are optional, no requirements for their presence + // Non-core input roles are optional, no requirements for definitions entry = roles_map.find(MatmulRole::INPUT_C); if (entry != roles_map.end()) { tvs_with_roles.insert(entry->second.begin(), entry->second.end()); } + // Non-core output roles are optional, no requirements for definitions + entry = roles_map.find(MatmulRole::OUTPUT_AUX); + if (entry != roles_map.end()) { + tvs_with_roles.insert(entry->second.begin(), entry->second.end()); + } + const auto in_out_tvs_count = fusion_inputs_tvs.size() + fusion_outputs_tvs.size(); if (in_out_tvs_count != tvs_with_roles.size()) { @@ -332,7 +320,7 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // #2 { - const auto input_layout_opt = mma_utils::getMatmulLayout(fusion); + const auto input_layout_opt = mma_utils::getMmaLayout(fusion); if (!input_layout_opt.isValid()) { return input_layout_opt.getErrorMsg(); } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 04e59fc3c65..43f902cead0 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -502,10 +502,6 @@ bool canValidateIsInnerDim( IterDomain* root, IterDomain* leaf, int inner_dim_size) { - // Accept boundary case for Volta. - if (leaf == root && leaf->isBroadcast()) { - return true; - } auto expr = leaf->definition(); if (!leaf->extent()->isConstInt()) { return false; @@ -581,59 +577,7 @@ void checkDimSize( } } -void WarpMmaSwizzler::scheduleMmaWarpOutput( - TensorView* tv, - MmaOptions options) { - auto macro = options.macro; - switch (macro) { - case MmaOptions::MacroType::Volta_16_16_4: - scheduleVoltaM16N16K4Fp32Output(tv, options); - if (tv->definition()->isA()) { - setWarpMapped(tv, 5); - } - break; - case MmaOptions::MacroType::Turing_16_8_16: - case MmaOptions::MacroType::Ampere_16_8_16: - scheduleTuringM16N8K16MmaWarpOutput(tv, options); - if (tv->definition()->isA()) { - setWarpMapped(tv, 4); - } - break; - case MmaOptions::MacroType::Turing_16_16_16: - case MmaOptions::MacroType::Ampere_16_16_16: - scheduleTuringM16N16K16MmaWarpOutput(tv, options); - if (tv->definition()->isA()) { - setWarpMapped(tv, 4); - } - break; - default: - NVF_CHECK( - false, "scheduleMmaWarp: unsupported mma option ", toString(macro)); - break; - } -} - -void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) { - // Schedules operand for inner most 3 contiguous dimensions - // Assumes M, N, K - - switch (options.macro) { - case MmaOptions::MacroType::Volta_16_16_4: - scheduleVoltaOperandRead(tv, options); - break; - case MmaOptions::MacroType::Turing_16_8_16: - case MmaOptions::MacroType::Ampere_16_8_16: - case MmaOptions::MacroType::Turing_16_16_16: - case MmaOptions::MacroType::Ampere_16_16_16: - scheduleTuringOperandRead(tv, options); - break; - default: - NVF_CHECK(false, "WarpMmaSwizzler: please specify macro"); - break; - } -} - -void WarpMmaSwizzler::setWarpMapped(TensorView* tv, int number_of_dims) { +static void setWarpMapped(TensorView* tv, int number_of_dims) { for (int id : c10::irange(number_of_dims)) { tv->axis(-id - 1)->toMmaSwizzled(); } @@ -728,600 +672,212 @@ std::unordered_set getMmaDomainSet( return {mma_domains.begin(), mma_domains.end()}; } -// [MMA dimension matching] -// Returns all the axes that correspond to the given mma dimension. This is the -// first relaxation step on the mma check. -// Mma operations concerns 3 dimensions, namely, the M, N, -// and K dimension, more details see [Operand Layout Convention] in mma_type.h. -// The current implementation, for best effort safety, supports the patterns -// where the root axes can be classified into one of the 3 dimension types. -// This is a helpful initial step into defining tensor contraction -// optimizations. -// -// A concrete example: -// T0 [I0, I1, I2, R3, I4, I5] = mma(T1[I01, B11, B21, I31, I41, B51], T2[B02, -// I12, B22, I32, I42, I52], {3}; -// In this case some example querries: -// K dimension of T0 = {R3} -// M dimension of T1 = {I01} -// N dimension of T2 = {I52} -// etc. -std::vector getMmaRootDimensions( - TensorView* tv, - MmaOp* mma, - MmaDimension dimension) { - // Build a fusion-level root domain map - // so we can use the mma swizzles on non-immediate tensor operands, for - // example loadstore staging ops. - ComputeAtRootDomainMap root_map; - root_map.build(); - - // FIXME: - // Several optimization is possible at this stage but assuming we don't have - // a lot of mma ops in a fusion this could be lower priority. - // First it'd be nice not having to build root map every time this function - // is called. That'd require some explicit boundary where we "lock" the - // compute in the fusion so the root map stays valid. - // Second it'd reduce complexity of the below matching by an order if we have - // something similar to "disjointSetOf" in idGraph, for just the root domains - // at scheduler composing time. - auto mma_root_dimensions = getMmaDomains(mma, dimension); - auto mma_accumulator_tv = mma->out()->as(); - - std::vector result; +} // namespace - // Need to use root domain for accumulator tv and maybe rfactor domain - // otherwise. See [Use Root Domain in Accumulator TV]. - auto is_mma_output = - tv->definition() != nullptr && tv->definition()->isA(); - const auto& tv_root_domain = - is_mma_output ? tv->getRootDomain() : tv->getMaybeRFactorDomain(); - - // Loop through tensorview's root domains and accumulate all the - // root domain IterDomain's that maps to any of the collected - // mma root dimension from the mma accumulator tv. - for (auto tv_id : tv_root_domain) { - if (std::any_of( - mma_root_dimensions.begin(), - mma_root_dimensions.end(), - [&](IterDomain* mma_id) { - return root_map.canMap( - tv->domain(), tv_id, mma_accumulator_tv->domain(), mma_id); - })) { - result.push_back(tv_id); - } +void WarpMmaSwizzler::scheduleLdMatrix(TensorView* tv, MmaOperand operand) { + bool transpose = tv->definition()->as()->opType() == + LoadStoreOpType::LdMatrixTranspose; + // For A, we have an extra outer dim (-6), which is the "warp group". For + // Hopper, mma instructions executes on warp group level. For Turing/Ampere, + // this dim will just have extent 1. + + // A B + // -6 -5 -4 -3 -2 -1 or -5 -4 -3 -2 -1 + //[4moo, 8mi, 4k, 2ko, 2mo, 2ki] [8ni, 4k, 2ko, 1no, 2ki] + tv->reorder({{-2, -4}, {-3, -5}}); + // A B + // -6 -5 -4 -3 -2 -1 or -5 -4 -3 -2 -1 + //[4moo, 2ko, 2mo, 8mi, 4k, 2ki] [2ko, 1no, 8ni, 4k, 2ki] + tv->merge(-2); + // A B + // -5 -4 -3 -2 -1 or -4 -3 -2 -1 + //[4moo, 2ko, 2mo, 8mi, 8k] [2ko, 1no, 8ni, 8k] + if (transpose) { + tv->reorder({{-2, -1}}); + // A B + // -5 -4 -3 -2 -1 or -4 -3 -2 -1 + //[4moo, 2ko, 2mo, 8k, 8mi] [2ko, 1no, 8k, 8ni] } - return result; -} - -//! Utility function to help check that the innermost 3 iterdomains -//! are also the corresponding innermost {m,n,k} dimensions of -//! the root id's that are participating in the mma operation. -//! This is a format check before the warp mma swizzler applies mma -//! swizzles to make sure that the swizzler is applying the right -//! swizzles to the right axes. -//! This check will be relaxed as we build out the mma usage patterns. -void validateMmaRootInnerMNK( - TensorView* tv, - MmaOptions options, - int m, - int n, - int k) { - auto mma = options.mmaOp(); - auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M); - auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N); - auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K); - - NVF_CHECK( - !m_dims.empty() && !n_dims.empty() && !k_dims.empty(), - "validateMmaRootInnerMNK: MMA Axes incomplete"); - - // Still check the innermost dims of each at the current state: - NVF_ERROR(tv->nDims() >= 3); - NVF_ERROR( - canValidateIsInnerDim(m_dims.back(), tv->axis(-3), m), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); - NVF_ERROR( - canValidateIsInnerDim(n_dims.back(), tv->axis(-2), n), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); - NVF_ERROR( - canValidateIsInnerDim(k_dims.back(), tv->axis(-1), k), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); -} - -//! Utility function to help check that the innermost 3 iterdomains -//! are also the corresponding innermost {m,n} dimensions of -//! the root id's that are participating in the mma operation. -//! This is a format check before the warp mma swizzler applies mma -//! swizzles to make sure that the swizzler is applying the right -//! swizzles to the right axes. -//! This check will be relaxed as we build out the mma usage patterns. -void validateMmaRootInnerMN(TensorView* tv, MmaOptions options, int m, int n) { - auto mma = options.mmaOp(); - auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M); - auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N); - - NVF_CHECK( - !m_dims.empty() && !n_dims.empty(), - "validateMmaRootInnerMNK: MMA Axes incomplete"); - - // Still check the innermost dims of each at the current state: - NVF_ERROR(tv->nDims() >= 2); - NVF_ERROR( - canValidateIsInnerDim(m_dims.back(), tv->axis(-2), m), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); - NVF_ERROR( - canValidateIsInnerDim(n_dims.back(), tv->axis(-1), n), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); -} - -//! Performs checks on tv given to schedule ld matrix. -//! Currently only allowed ones are either: -//! 1. direct output of an ldmatrix op or -//! 2. direct output of a broadcast op following a ldmatrix op -//! Returns true if the tv is an immediate output of ldmatrix op -//! -//! TODO: this check is a WAR with pattern matching for now. -//! The two patterns mentioned above are the only supported use -//! cases of ldmatrix currently. This restriction can be greatly -//! relaxed after the iterdomain swizzle infrastructure, which -//! will provide the capability to directly model the exact -//! data format of ldmatrix output. -bool checkLdMatrixTv(TensorView* tv) { - // First check if tv is an ldmatrix output: - auto tv_def = tv->definition(); - NVF_CHECK(tv_def != nullptr, "ldmatrix : invalid tv"); - bool is_immediate_output = true; - if (!ir_utils::isLdMatrixOp(tv_def)) { - // Only allow one broadcast in between tv and the ldmatrix op - NVF_CHECK( - tv_def->isA(), - "ldmatrix: only allow serial broadcast between ldmatrix and mma"); - tv_def = tv_def->input(0)->definition(); - NVF_CHECK(tv_def != nullptr, "ldmatrix : invalid tv"); - is_immediate_output = false; + tv->merge(-4); + tv->merge(-3); + if (operand == MmaOperand::A) { + // For A, we have an extra outer dim which is the warp group. Merge it back + // here so that TIDx represent a warp group, instead of a single warp. + tv->merge(-3); } - NVF_CHECK( - ir_utils::isLdMatrixOp(tv_def), - "ldmatrix : invalid op type: ", - tv_def->toString()); - NVF_CHECK( - tv->nDims() >= 2, - "ldmatrix: scheduled tv needs to be at least 2 dimensional"); - NVF_CHECK( - !tv->axis(-1)->isBroadcast(), "ldmatrix: unsupported scheduled axes"); - NVF_CHECK( - !tv->axis(-1)->isReduction(), "ldmatrix: unsupported scheduled axes"); - NVF_CHECK( - !tv->axis(-2)->isBroadcast(), "ldmatrix: unsupported scheduled axes"); - NVF_CHECK( - !tv->axis(-2)->isReduction(), "ldmatrix: unsupported scheduled axes"); - return is_immediate_output; -} - -void scheduleVoltaA(TensorView* tv, MmaOptions options) { - // Assumed: - // [..., 16, 16 ,4] - // [..., M, BN, K] - // Some validation: - validateMmaRootInnerMNK(tv, options, 16, 16, 4); - bool transposed = isOperandTransposed(options); - - tv->split(-3, 4); - - // Split out 16 from the bcast - tv->split(-2, 16); - tv->split(-2, 8); - - // -6 -5 -4 -3 -2 -1 - //[Mo4, Mi4, Noo, No2, Ni8, K] - - if (transposed) { - tv->reorder({{-5, -3}, {-3, -5}}); - // -6 -5 -4 -3 -2 -1 - //[Mo4, No2, Noo, Mi4, Ni8, K] - - } else { - tv->reorder({{-5, -1}, {-3, -5}, {-1, -3}}); - // -6 -5 -4 -3 -2 -1 - //[Mo4, No2, Noo, K, Ni8, Mi4] + // A B + // -2 -1 or -2 -1 + //[128, 8] [16, 8] + + // The extent of axis(-2) is the number of threads that contains useful + // addresses. We can not parallelize axis(-2) directly if the extent is less + // than 32. Instead, we should split axis(-1) and merge it to axis(-2) to + // get a complete warp of 32 threads. This makes sure that, during lowering, + // our system can correctly compute the buffer size. + int64_t num_tidx_with_addr = tv->axis(-2)->extent()->evaluate().as(); + if (num_tidx_with_addr < 32) { + int64_t factor = 32 / num_tidx_with_addr; + tv->split(-1, factor, false); + tv->reorder({{-2, -3}, {-3, -2}}); + // -3 -2 -1 + // [factor, num_tidx_with_addr, 8/factor] + // For indexing, we only care about what we get when the index of axis(-3) + // is 0. For higher values, they are garbage, and abandoned. + tv->merge(-3); } - tv->merge(-6); - tv->merge(-5); - tv->merge(-4); + // A B + // -2 -1 or -2 -1 + //[128, 8] [32, 4] - //[Warp, Ni8, K/Mi4] - tv->axis(-3)->parallelize(ParallelType::TIDx); + tv->axis(-2)->parallelize(ParallelType::TIDx); + // TODO: this is not really vectorization. Change its parallel type to Mma. + tv->axis(-1)->parallelize(ParallelType::Vectorize); + setWarpMapped(tv, 2); } -void scheduleVoltaB(TensorView* tv, MmaOptions options) { - // Assumed: - // [..., 16,16,4] - // [..., BM, N, K] - // Some validation: - validateMmaRootInnerMNK(tv, options, 16, 16, 4); - - bool transposed = isOperandTransposed(options); - tv->split(-3, 16); - tv->split(-3, 8); +void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOperand operand) { + // This function works for all mma ops, regardless of the architecture. + // Operand A and B are slightly different in the sense that operand A can be + // (>=16)x16 matrix, but operand B can only be 8x16 or 16x16. For operand A, + // the Hopper one is the most general one. For earlier architectures, we will + // have some dimensions with size 1 after split, this is fine. Memory format + // for hopper mma: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-a + NVF_ERROR(tv->nDims() >= 2); + // A B + // -2 -1 or -2 -1 + //[64m, 16k] [8n, 16k] tv->split(-2, 8); + tv->split(-1, 2); tv->split(-2, 4); - // -7 -6 -5 -4 -3 -2 -1 - //[Moo, Mo2, Mi8, No2, Nio2, Nii4, K] - tv->reorder({{-6, -4}, {-5, -6}, {-4, -3}, {-3, -5}}); + // A B + // -5 -4 -3 -2 -1 or -5 -4 -3 -2 -1 + //[8m, 8m, 2k, 4k, 2k'] [1n, 8n, 2k, 4k, 2k'] - // -7 -6 -5 -4 -3 -2 -1 - //[Moo, Mi8, Nio2, Mo2, No2, Nii4, K ] - if (transposed) { - tv->reorder({{-2, -1}, {-1, -2}}); - // -7 -6 -5 -4 -3 -2 -1 - //[Moo, Mi8, Nio2, Mo2, No2, K, Nii4] + if (operand == MmaOperand::A) { + // For A, we need to have an extra outer dim (-6) for warp group. + tv->split(-5, 2); + // On Ampere and Turing, the extent of dim -6 after the split below will be + // just 1. On Hopper, the dim -6 will be 4 because Hopper warp group + // instructions have 4x larger m extend than Ampere/Turing. } - tv->merge(-5); - tv->merge(-4); - tv->merge(-3); - - //[Moo, Mi8, Warp, K/Nii4] - tv->axis(-2)->parallelize(ParallelType::TIDx); -} - -void scheduleLdMatrix(TensorView* tv, MmaOptions options) { - // Check if tv should use ldmatrix layout and - // if tv is immediate output of ldmatrix - bool is_immediate_output = checkLdMatrixTv(tv); - - // Check mma option is supported - NVF_CHECK( - options.macro == MmaOptions::MacroType::Ampere_16_8_16 || - options.macro == MmaOptions::MacroType::Ampere_16_16_16 || - options.macro == MmaOptions::MacroType::Turing_16_8_16 || - options.macro == MmaOptions::MacroType::Turing_16_16_16, - "scheduleLdMatrix: unknown macro for ldmatrix"); - - if (options.operand == MmaOptions::Operand::A) { - NVF_ERROR(tv->nDims() >= 2); - // validation: - auto mma = options.mmaOp(); - auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M); - auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K); - bool transposed = - (options.layout == MmaOptions::MmaLayout::NN || - options.layout == MmaOptions::MmaLayout::NT); - - NVF_ERROR( - canValidateIsInnerDim(m_dims.back(), tv->axis(-2), 16), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); - NVF_ERROR( - canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain", - tv->toString()); - - //[16m, 16k] - tv->split(-2, 8); - tv->split(-1, 8); - - // -4 -3 -2 -1 - //[2o, 8o, 2i, 8i] - tv->reorder({{-4, -3}, {-3, -2}, {-2, -4}}); - - // -4 -3 -2 -1 - // [2i, 2o, 8o, 8i] - - if (transposed) { - tv->reorder({{-1, -2}, {-2, -1}}); - } - - tv->merge(-4); - tv->merge(-3); - // [warp, 8i/o] - - tv->axis(-2)->parallelize(ParallelType::TIDx); - } else if (options.operand == MmaOptions::Operand::B) { - auto mma = options.mmaOp(); - auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N); - auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K); - bool transposed = - (options.layout == MmaOptions::MmaLayout::NT || - options.layout == MmaOptions::MmaLayout::TT); - - NVF_ERROR( - canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); - - // Each ldmatrix 4 would be loading an effective 16x16x16 tile, which is 2x - // the - // size of regular 16x8x16 tile supported by largest mma operation. The - // swizzle also needs to be different to take this into account. - // TODO: - // Using an emulated 16x16x16 mma tile is a temporary step to enable the - // widest load possible for scheduler bring up phase. - // A unifying step would be needed in a follow up to support all these - // swizzles - // with a single affine utility. - bool use_ldmatrix4 = canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 16); - - if (use_ldmatrix4) { - // [... N16, K16] - tv->split(-2, 8); - tv->split(-1, 8); - - // -4 -3 -2 -1 - // [... N2o, N8, K2o, K8] - tv->reorder({{-3, -2}, {-2, -3}}); - // [... N2o, K2o, N8, K8] - - if (transposed) { - tv->reorder({{-1, -2}, {-2, -1}}); - } - - tv->merge(-4); - tv->merge(-3); + // A B + // -6 -5 -4 -3 -2 -1 or -5 -4 -3 -2 -1 + //[4m, 2m, 8m, 2k, 4k, 2k'] [1n, 8n, 2k, 4k, 2k'] + + tv->reorder({{-4, -5}, {-5, -2}, {-2, -4}}); + + // A B + // -6 -5 -4 -3 -2 -1 or -5 -4 -3 -2 -1 + //[4m, 8m, 4k, 2k, 2m, 2k'] [8n, 4k, 2k, 1n, 2k'] + + // ldmatrix loads multiple 8x8 matrices from shared memory to registers in a + // swizzled memory format. + // +--------+--------+ + // | | | + // | 8x8 | 8x8 | + // | | | + // +--------+--------+ + // | | | + // | 8x8 | 8x8 | + // | | | + // +--------+--------+ + // If n_major is true, these 8x8 matrices are visited in the order of: + // top left -> top right -> bottom left -> bottom right. + // If n_major is false, these 8x8 matrices are visited in the order of: + // top left -> bottom left -> top right -> bottom right. + // + // In principle, only `n_major = false` should be needed. But unfortunately, + // we are taking advantage of the ldmatrix large load in a pretty hacky way. + // For example, for Turing, only m16n8k8 is supported by hardware. But we are + // also using a fake m16n8k16 and m16n16k16, which uses a single large + // ldmatrix to load data to register, and run multiple mma instructions to + // consume these data. In the future, we should only keep the m16n8k8 macro, + // and schedule m16n8k16 and m16n16k16 more correctly than this current way. + bool n_major = + operand == MmaOperand::B && tv->axis(-2)->extent()->evaluate() > 1; + if (n_major) { + tv->reorder({{-2, -3}, {-3, -2}}); + // -5 -4 -2 -3 -1 + //[8n, 4k, 1n, 2k, 2k'] + } - // [Warp, K8] - tv->axis(-2)->parallelize(ParallelType::TIDx); - } else { - // validation: - NVF_ERROR( - canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); - - if (transposed) { - // [8, 16] - tv->split(-2, 4); - - // [2i, 4i, 16] - tv->reorder({{-1, -2}, {-2, -1}}); - // [2i, 16, 4i] - - tv->merge(-3); - // [warp, 4i] - } else { - //[8, 16] - tv->split(-1, 4); - tv->split(-2, 2); - - // 0 1 2 3 - //[8, oo2,oi2,i4] - tv->reorder({{-4, -2}, {-2, -4}}); - - // 0 1 2 3 - //[oi2, oo2, 8,i4] - - tv->merge(-4); - tv->merge(-3); - // 0 1 - //[warp, i4] + bool set_allocation = ir_utils::isLdMatrixOp(tv->definition()); + if (!set_allocation) { + for (auto u : tv->uses()) { + if (u->isA()) { + set_allocation = true; + break; } - - tv->axis(-2)->parallelize(ParallelType::TIDx); } - } else { - NVF_ERROR(false, "unreachable"); } - - if (is_immediate_output) { - tv->axis(-1)->parallelize(ParallelType::Vectorize); + if (set_allocation) { + tv->setAllocationDomain(tv->getLeafDomain(), true); } } -} // namespace - -void WarpMmaSwizzler::scheduleVoltaOperandRead( - TensorView* tv, - MmaOptions options) { - switch (options.operand) { - case MmaOptions::Operand::A: - scheduleVoltaA(tv, options); - setWarpMapped(tv, 3); - break; - case MmaOptions::Operand::B: - scheduleVoltaB(tv, options); - setWarpMapped(tv, 4); - break; - default: - NVF_CHECK(false, "WarpMmaSwizzler: please specify operand"); - } -} - -// Fp32 and Fp16 outputs have different layouts on volta, -// but we only support fp32 accumulate at this stage. -void WarpMmaSwizzler::scheduleVoltaM16N16K4Fp32Output( - TensorView* tv, - const MmaOptions& options) { - // Assume last 2 dims [M16, N16] or [M16, N16, R] - bool is_reduction = tv->axis(-1)->isReduction(); +void WarpMmaSwizzler::scheduleMmaWarpOutput(TensorView* tv) { + // This function works for all mma ops, regardless of the architecture. The + // Hopper one is the most general one. For earlier architectures, we will have + // some dimensions with size 1 after split, this is fine. + // Memory format for hopper mma: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d - // Make sure instruction tile size is correct. - if (is_reduction) { - validateMmaRootInnerMNK(tv, options, 16, 16, 4); - } else { - validateMmaRootInnerMN(tv, options, 16, 16); - } - - int m_pos = is_reduction ? -3 : -2; + // Assume last 2 dims, for example [M64, N24] or [M64, N24, R] + NVF_ERROR(tv->nDims() >= 2); + bool is_mma_output = tv->definition()->isA(); + + int m_pos = is_mma_output ? -3 : -2; + int n_pos = is_mma_output ? -2 : -1; + + // m n + // [M64, N24 (,R)] + tv->split(m_pos--, 8); + tv->split(m_pos--, 2); + // m n + // [M4, M2, M8, N24 (,R)] + tv->split(n_pos, 8); + tv->split(n_pos, 2); + + n_pos -= 2; + m_pos -= 2; + // m n + // [M4, M2, M8, N3, N4, N2 (,R)] + + tv->reorder({{m_pos + 1, n_pos + 1}, {n_pos + 1, m_pos + 2}}); + // m n + // [M4, M8, N4, N3, M2, N2 (,R)] + tv->merge(m_pos++); + tv->merge(m_pos++); - // Assumed: // m - // [..., 16,16, (4)] - // [..., M, N, (R)] - tv->split(m_pos, 4); - tv->split(m_pos, 2); - tv->split(m_pos + 1, 8); - tv->split(m_pos + 1, 4); - tv->split(m_pos + 1, 2); - - // m-5 m-4 m-3 m-2 m-1 m m+1 m+2 - // [..., Mo4, Mio2, Mii2, No2, Nio2, Niio2, Niii2, (R)] - tv->reorder( - {{m_pos - 4, m_pos - 1}, - {m_pos - 3, m_pos - 2}, - {m_pos - 2, m_pos - 4}, - {m_pos - 1, m_pos}, - {m_pos, m_pos - 3}}); - - // m-5 m-4 m-3 m-2 m-1 m m+1 m+2 - // [..., Mo4, No2, Niio2, Mii2, Mio2, Nio2, Niii2, (R)] - - tv->merge(m_pos - 5); - tv->merge(m_pos - 4); - tv->merge(m_pos - 3); - - // m-2 m-1 m m+1 m+2 - //[Warp, Mio2, Nio2, Niii2, (R)] - tv->reorder({{m_pos - 1, m_pos}}); - // m-2 m-1 m m+1 m+2 - //[Warp, Nio2, Mio2, Niii2, (R)] - tv->axis(m_pos - 2)->parallelize(ParallelType::TIDx); - - if (is_reduction && tv->definition()->isA()) { - // Set instruction loops for mma reduce output - for (int pos : c10::irange(5)) { - if (!tv->axis(-pos - 1)->isThread()) { - tv->axis(-pos - 1)->parallelize(ParallelType::Mma); - } - tv->axis(-pos - 1)->toMmaSwizzled(); - } - } -} - -void WarpMmaSwizzler::scheduleTuringOperandRead( - TensorView* tv, - MmaOptions options) { - scheduleLdMatrix(tv, options); - setWarpMapped(tv, 2); -} - -void WarpMmaSwizzler::scheduleTuringM16N8K16MmaWarpOutput( - TensorView* tv, - const MmaOptions& options) { - // Assume last 2 dims [M16, N8] or [M16, N8, R] - // Locate instruction m - bool is_reduction = tv->axis(-1)->isReduction(); - - // Make sure instruction tile size is correct. - if (is_reduction) { - validateMmaRootInnerMNK(tv, options, 16, 8, 16); - } else { - validateMmaRootInnerMN(tv, options, 16, 8); + // [WarpGroup128, N3, M2, N2 (,R)] + + if (is_mma_output) { + tv->split(-1, 2); + tv->split(-2, 4); + m_pos -= 2; + // m + // [WarpGroup128, N3, M2, N2, Ro, R4, R2] } - int m_pos = is_reduction ? -3 : -2; - - // m - // [16, 8 (,R)] - tv->split(m_pos, 8); - tv->split(m_pos + 1, 2); - - // m - // [2o, 8o, 4i, 2i (,R)] - tv->merge(m_pos - 1); - - // m - // [2o, Warp, 2i (,R)] NVF_CHECK(tv->definition() != nullptr); - if (is_reduction && tv->definition()->isA()) { - // Set instruction loops for mma reduce - for (int pos : c10::irange(4)) { - tv->axis(-pos - 1)->parallelize(ParallelType::Mma); - } - } - tv->axis(m_pos)->parallelize(ParallelType::TIDx); -} - -void WarpMmaSwizzler::scheduleTuringM16N16K16MmaWarpOutput( - TensorView* tv, - const MmaOptions& options) { - // Assume last 2 dims [M16, N8] or [M16, N8, R] - // Locate instruction m - bool is_reduction = tv->axis(-1)->isReduction(); - - // Make sure instruction tile size is correct. - if (is_reduction) { - validateMmaRootInnerMNK(tv, options, 16, 16, 16); - } else { - validateMmaRootInnerMN(tv, options, 16, 16); - } - - int m_pos = is_reduction ? -3 : -2; - // m - // [16, 16 (,R)] - - tv->split(m_pos + 1, 8); - // m - // [16, n2, 8 (,R)] - tv->reorder({{m_pos, m_pos - 1}, {m_pos - 1, m_pos}}); - - // m - // [n2, 16, 8 (,R)] - tv->split(m_pos, 8); - tv->split(m_pos + 1, 2); - - // m - // [2o, 8o, 4i, 2i (,R)] - tv->merge(m_pos - 1); - - // m - // [2o, Warp, 2i (,R)] - NVF_CHECK(tv->definition() != nullptr); - if (is_reduction && tv->definition()->isA()) { + if (is_mma_output) { // Set instruction loops for mma reduce - for (int pos : c10::irange(5)) { - tv->axis(-pos - 1)->parallelize(ParallelType::Mma); + int pos = -1; + while (pos > m_pos) { + tv->axis(pos--)->parallelize(ParallelType::Mma); } + setWarpMapped(tv, 7); } - - tv->axis(m_pos)->parallelize(ParallelType::TIDx); -} - -namespace { - -bool isMmaInitLoop(const kir::Scope& loop_body) { - for (auto expr : loop_body.exprs()) { - if (auto inner_loop = dynamic_cast(expr)) { - if (!isMmaInitLoop(inner_loop->body())) { - return false; - } - } else if (auto ldst = dynamic_cast(expr)) { - if (!ir_utils::isTvOp(ldst)) { - return false; - } - if (auto ti = dynamic_cast(ldst->output(0))) { - if (!ti->view()->definition() || - !ti->view()->definition()->isA()) { - return false; - } - } - if (auto tv = dynamic_cast(ldst->output(0))) { - if (!tv->definition() || !tv->definition()->isA()) { - return false; - } - } - } else if (auto ite = dynamic_cast(expr)) { - if (!isMmaInitLoop(ite->thenBody())) { - return false; - } - if (!isMmaInitLoop(ite->elseBody())) { - return false; - } - } else { - return false; - } - } - return true; -} - -} // namespace - -bool isMmaInitLoop(const kir::ForLoop* loop) { - return isMmaInitLoop(loop->body()); } void canonicalizeMmaTvOrdering(TensorView* tv) { @@ -1467,7 +1023,7 @@ ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion) { return ProblemIterDomains{m, n, k}; } -MatmulProblemLayoutOpt getMatmulLayout(Fusion* fusion) { +MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion) { ComputeAtMap ca_map(fusion); const auto mma_input_candidates = ir_utils::filterByType(fusion->inputs()).vector(); @@ -1530,16 +1086,16 @@ MatmulProblemLayoutOpt getMatmulLayout(Fusion* fusion) { } if ((mk_found && kn_found) && !(km_found || nk_found)) { - return MmaOptions::MmaLayout::TT; + return MmaLayout::TT; } if ((km_found && kn_found) && !(mk_found || nk_found)) { - return MmaOptions::MmaLayout::NT; + return MmaLayout::NT; } if ((mk_found && nk_found) && !(km_found || kn_found)) { - return MmaOptions::MmaLayout::TN; + return MmaLayout::TN; } if ((km_found && nk_found) && !(mk_found || kn_found)) { - return MmaOptions::MmaLayout::NN; + return MmaLayout::NN; } return {"Failed to decide fusion inputs' data layout."}; @@ -1563,9 +1119,8 @@ RolesMapOpt getTensorsRoles(Fusion* fusion) { return mma_output_domains.getErrorMsg(); } - const auto findRolesByDomains = [](const DependenciesMap& deps_map, - RolesMap& roles_map, - const bool processing_output) { + const auto findInputRolesByDomains = [](const DependenciesMap& deps_map, + RolesMap& roles_map) { for (const auto& entry : deps_map) { const auto& domains = entry.second; const auto begin = domains.begin(); @@ -1575,38 +1130,79 @@ RolesMapOpt getTensorsRoles(Fusion* fusion) { bool has_n = (end != std::find(begin, end, MatmulDomain::N)); bool has_k = (end != std::find(begin, end, MatmulDomain::K)); - if (!processing_output && has_m && has_k && !has_n) { + if (has_m && has_k && !has_n) { roles_map[MatmulRole::INPUT_A].push_back(entry.first); continue; } - if (!processing_output && has_n && has_k && !has_m) { + if (has_n && has_k && !has_m) { roles_map[MatmulRole::INPUT_B].push_back(entry.first); continue; } - if (!processing_output && has_m && has_n && !has_k) { + if (has_m && has_n && !has_k) { roles_map[MatmulRole::INPUT_C].push_back(entry.first); continue; } // Bias vectors are assigned to INPUT_C role - if (!processing_output && has_m && !has_n && !has_k) { + if (has_m && !has_n && !has_k) { roles_map[MatmulRole::INPUT_C].push_back(entry.first); continue; } + } + + for (auto& [role, tvs] : roles_map) { + // NOTE: sort input roles in descending order by uses() size, and + // if equal then by name() to ensure the stable ordering of tensor + // views in collections assigned to the supported roles + std::sort(tvs.begin(), tvs.end(), [](TensorView* a, TensorView* b) { + return (a->uses().size() == b->uses().size()) + ? (a->name() < b->name()) + : (a->uses().size() > b->uses().size()); + }); + } + }; + + const auto findOutputRolesByDomains = [](const DependenciesMap& deps_map, + RolesMap& roles_map) { + std::vector storage; + storage.reserve(deps_map.size()); + + for (const auto& entry : deps_map) { + const auto& domains = entry.second; + const auto begin = domains.begin(); + const auto end = domains.end(); + + bool has_m = (end != std::find(begin, end, MatmulDomain::M)); + bool has_n = (end != std::find(begin, end, MatmulDomain::N)); // NOTE: depending on fusion definition k domain may appear in the output: // - for mma_output == fusion output k domain is present // - for mma_output != fusion output (fusion with epilogue) k domain // is not present - if (processing_output && has_m && has_n) { - roles_map[MatmulRole::OUTPUT_D].push_back(entry.first); - continue; + + // NOTE: the core fusion output tensors are the ones with m and n + // domains + if (has_m && has_n) { + storage.push_back(entry.first); } } - for (auto& [role, tvs] : roles_map) { - // sort tvs by name() - std::sort(tvs.begin(), tvs.end(), [](TensorView* a, TensorView* b) { - return a->name() < b->name(); - }); + + // NOTE: sort output roles in descending order by uses() size, and + // if equal then by name() to ensure the stable ordering of tensor + // views in collections assigned to the supported roles + std::sort(storage.begin(), storage.end(), [](TensorView* a, TensorView* b) { + return (a->uses().size() == b->uses().size()) + ? (a->name() < b->name()) + : (a->uses().size() > b->uses().size()); + }); + + if (!storage.empty()) { + // NOTE: currently, we pick as a reference tensor one with `m` and `n` + // IterDomains and the most uses + auto pos = storage.begin(); + roles_map[MatmulRole::OUTPUT_D].push_back(*pos); + for (++pos; pos != storage.end(); ++pos) { + roles_map[MatmulRole::OUTPUT_AUX].push_back(*pos); + } } }; @@ -1619,18 +1215,16 @@ RolesMapOpt getTensorsRoles(Fusion* fusion) { RolesMap roles_map; // Handle fusion input TensorView objects - bool handling_output = false; resolveTvToMatmulDomainsMapping( deps_map, mma_input_candidates, m, n, k, ca_map); - findRolesByDomains(deps_map, roles_map, handling_output); + findInputRolesByDomains(deps_map, roles_map); deps_map.clear(); // Handle fusion output TensorView objects - handling_output = true; resolveTvToMatmulDomainsMapping( deps_map, mma_output_candidates, m, n, k, ca_map); - findRolesByDomains(deps_map, roles_map, handling_output); + findOutputRolesByDomains(deps_map, roles_map); return roles_map; } diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index aaacf67876a..215a59e3560 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -70,7 +70,7 @@ void canonicalizeMmaTvOrdering(TensorView* tv); //! This class is used to implement the thread swizzle format //! required for the mma macros, cf. PTX ISA 9.7.13.4. //! -//! The mma instructions (Volta through Ampere) require specific +//! The mma instructions (Volta and later arch) require specific //! thread mapping within a warp for both the mma inputs and //! mma outputs. All mma swizzle patterns seen so far turned out //! to be affine, so we could use the normal scheduler interface @@ -89,7 +89,7 @@ void canonicalizeMmaTvOrdering(TensorView* tv); //! as follows: //! //! Step 1. Before scheduling, the mma op needs to be configured with a macro -//! type, either manually or inferred (eg. Volta_16_16_4). +//! type, either manually or inferred (eg. Ampere_16_8_8). //! //! Step 2. Scheduler can tile the outer dimensions based on any heuristics, //! i.e. the CTA tiling, warp tiling, splitK etc. @@ -100,7 +100,7 @@ void canonicalizeMmaTvOrdering(TensorView* tv); //! tensordomain (see [Operand Layout Convention] for exact definition). //! //! For example before calling WarpMmaSwizzler, the domain could look like: -//! [TileM, TileN, TileK, Im(16), In(16), Rk(4)], to use Volta_16_16_4. +//! [TileM, TileN, TileK, Im(16), In(8), Rk(8)], to use Ampere_16_8_8. //! The rightmost 3 iterdomains need to be the innermost component of their //! corresponding root id, similar to vectorization except this requirement //! applies to all 3 rightmost dims. @@ -160,45 +160,29 @@ class WarpMmaSwizzler { //! Applies the output mma swizzling to the given tv, should be used //! on mma output or tv's involved in epilog fusion, i.e. bias. //! The rightmost iterdomains must follow the m,n,k convention before calling. - static void scheduleMmaWarpOutput(TensorView* tv, MmaOptions options); + static void scheduleMmaWarpOutput(TensorView* tv); - //! Applies the input mma swizzling to the given tv, should be used - //! on mma input or tv's involved in any fusion before mma, but after smem - //! read. + //! Applies the input mma swizzling to the given tv as its allocation domain, + //! should be used on mma input or tv's involved in any fusion before mma, but + //! after smem read. //! The rightmost iterdomains must follow the m,n,k convention before calling. - static void scheduleOperandRead( - TensorView* tv, - MmaOptions options = MmaOptions()); - - private: - //! Operand swizzle implementations for Volta mma. - static void scheduleVoltaOperandRead(TensorView* tv, MmaOptions options); - - //! Accumulator swizzle implementations for Volta mma. - static void scheduleVoltaM16N16K4Fp32Output( - TensorView* tv, - const MmaOptions& options); - - //! Operand swizzle implementations for Turing and Ampere mma. - static void scheduleTuringOperandRead(TensorView* tv, MmaOptions options); - - //! Accumulator swizzle implementation for Turing and Ampere mma. - static void scheduleTuringM16N8K16MmaWarpOutput( - TensorView* tv, - const MmaOptions& options); - - //! Accumulator swizzle implementation for emulated 16x16x16 mma tile - //! that enables using ldmatrix.x4. - //! Note: - //! Keeping both this option and the ldmatrix.x2 variant above for - //! now for wider scheduler exploration space. Eventually both of - //! these can be unified with a single affine utility. - static void scheduleTuringM16N16K16MmaWarpOutput( - TensorView* tv, - const MmaOptions& options); - - //! Utility to lock the transformed dimensions from further transforms. - static void setWarpMapped(TensorView* tv, int number_of_dims); + static void scheduleOperandRead(TensorView* tv, MmaOperand operand); + + //! Note [schedule of ldmatrix] + //! If you look at the doc of ldmatrix and mma for Turing and Ampere: + //! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-float + //! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix + //! you will find that, the memory layout of the output of ldmatrix, which + //! matches with the input layout of MMA instruction, mismatch with the index + //! that each thread uses to call ldmatrix. In nvFuser, we schedule the + //! allocation domain of the ldmatrix output and mma inputs to be consistent + //! with the memory layout of the output of ldmatrix, and we schedule the + //! leaf domain of the ldmatrix output to be consistent with the index that + //! each thread uses to call ldmatrix. This function is used to schedule the + //! leaf domain of the ldmatrix output. The allocation domain of the ldmatrix + //! output and mma inputs are scheduled in scheduleOperandRead, which must be + //! called before this function. + static void scheduleLdMatrix(TensorView* tv, MmaOperand operand); }; void checkDimSize( @@ -206,9 +190,6 @@ void checkDimSize( std::vector axis, std::vector expect); -// Returns if the loopnest is initializing for an mma op. -bool isMmaInitLoop(const kir::ForLoop* loop); - //! A constant with minimum number of fusion inputs that could be MMA inputs. //! TODO: update for square matmuls where both inputs are the same tensor constexpr size_t MIN_MATMUL_INPUTS_NUMBER = 2; @@ -254,7 +235,7 @@ class DataWrapperOpt { } }; -using MatmulProblemLayoutOpt = DataWrapperOpt; +using MatmulProblemLayoutOpt = DataWrapperOpt; using ProblemIterDomainsOpt = DataWrapperOpt; using RolesMapOpt = DataWrapperOpt; @@ -270,11 +251,11 @@ using DependenciesMap = std::map; //! - matmul layout which contains information about transposition of matmul //! inputs, it is based on the order of key domains (M,N K) in fusion input //! tensors, -//! - mma layout, some architectures (e.g. Volta) support all combination of +//! - mma layout, some architectures (e.g. Hopper) support all combination of //! transposition of inputs in mma instructions, while other (e.g. Turing, //! Ampere) the only supported transposition is TN which means that mma //! instruction first input is transposed, the second input is non-transposed. -MatmulProblemLayoutOpt getMatmulLayout(Fusion* fusion); +MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion); //! Returns wrapped collection of IterDomains that can be used to get //! problem shape with runtime info. diff --git a/csrc/scheduler/normalization_inner.cpp b/csrc/scheduler/normalization_inner.cpp index e5c0bbbe235..149f9bbb95d 100644 --- a/csrc/scheduler/normalization_inner.cpp +++ b/csrc/scheduler/normalization_inner.cpp @@ -184,12 +184,16 @@ std::shared_ptr innerPersistentHeuristicSharedMemory( const int64_t n_tensor_inputs, const int64_t max_input_dtype_size, const int64_t max_persistent_buffer_size, - const size_t max_vectorize_factor) { + const size_t max_vectorize_factor, + const bool project_to_input, + const PrimDataType index_type) { const auto dev_prop = at::cuda::getCurrentDeviceProperties(); auto rparams = std::make_shared(); rparams->shared_mem_persistent_buffer = true; rparams->persistent_kernel = true; rparams->fastest_dim = true; + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; // Inner reduction domain // This heuristic is only used for cases with large total_reduction_numel. // e.g. layer_norm with hidden size larger than 64K for fp16 or 32K for fp32. @@ -251,7 +255,9 @@ std::shared_ptr innerPersistentHeuristic( const int64_t n_tensor_inputs, const int64_t max_input_dtype_size, const int64_t max_persistent_buffer_size, - const size_t vectorize_factor) { + const size_t vectorize_factor, + const bool project_to_input, + const PrimDataType index_type) { if (max_persistent_buffer_size > scheduler_utils::register_file_size) { // use shared memory for persistent buffer return innerPersistentHeuristicSharedMemory( @@ -261,7 +267,9 @@ std::shared_ptr innerPersistentHeuristic( (int64_t)n_tensor_inputs, (int64_t)max_input_dtype_size, max_persistent_buffer_size, - vectorize_factor); + vectorize_factor, + project_to_input, + index_type); } // Set some targets for parallelization @@ -710,6 +718,8 @@ std::shared_ptr innerPersistentHeuristic( rparams->cparams.maxrregcount = (int)nvrtc_register_per_thread; rparams->persistent_kernel = true; rparams->fastest_dim = true; + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; // Inner reduction domain rparams->cross_block_inner_reduction = true; @@ -805,9 +815,9 @@ std::shared_ptr getInnerPersistentHeuristics( prop.n_tensor_inputs, prop.max_dtype_size, prop.max_persistent_buffer_size, - prop.vectorize_factor); - rparams->project_persistent_buffers = prop.project_persistent_buffers; - rparams->cparams.index_type = runtime_info.getIndexType(); + prop.vectorize_factor, + prop.project_persistent_buffers, + prop.index_type); return rparams; } diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 712e13824e8..2f38de36c5f 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -267,7 +267,11 @@ bool InnerOuterPersistentKernelScheduler::canScheduleRunTime( if (persistent_buffer_size > available_persistent_buffer_size) { scheduler_debug_utils::canScheduleRejectReason( heuristicType(), - "not enough registers or shared memory for persistence"); + "not enough registers or shared memory for persistence. Needed ", + persistent_buffer_size, + " bytes but only ", + available_persistent_buffer_size, + " bytes are available."); return false; } @@ -365,8 +369,12 @@ std::shared_ptr innerOuterPersistentHeuristic( const int64_t inner_dim_numel, const int64_t max_persistent_buffer_size, const size_t tmp_gmem_dtype_size, - const size_t vectorize_factor) { + const size_t vectorize_factor, + const bool project_to_input, + const PrimDataType index_type) { auto rparams = std::make_shared(); + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; // Parameters for inner reduction: // Reduction dim: inner_vect, inner_batch, bdimx and bdimy // Iteration dim: gdimy @@ -609,17 +617,18 @@ std::shared_ptr persistentHeuristic( const size_t tmp_gmem_dtype_size, const int64_t max_persistent_buffer_size, size_t vectorize_factor, - bool project_persistent_buffers) { - std::shared_ptr rparams; + bool project_persistent_buffers, + const PrimDataType index_type) { const int64_t outer_dim_numel = total_iteration_numel; const int64_t inner_dim_numel = inner_most_dimension_numel; - rparams = innerOuterPersistentHeuristic( + auto rparams = innerOuterPersistentHeuristic( outer_dim_numel, inner_dim_numel, max_persistent_buffer_size, tmp_gmem_dtype_size, - vectorize_factor); - rparams->project_persistent_buffers = project_persistent_buffers; + vectorize_factor, + project_persistent_buffers, + index_type); return rparams; } @@ -766,8 +775,8 @@ std::shared_ptr getInnerOuterPersistentHeuristics( tmp_gmem_dtype_size, max_persistent_size, vectorize_factor, - project_persistent_buffers); - heuristic->cparams.index_type = runtime_info.getIndexType(); + project_persistent_buffers, + runtime_info.getIndexType()); return heuristic; } diff --git a/csrc/scheduler/normalization_outer.cpp b/csrc/scheduler/normalization_outer.cpp index 19ce4aec6b1..b5496cc7b7c 100644 --- a/csrc/scheduler/normalization_outer.cpp +++ b/csrc/scheduler/normalization_outer.cpp @@ -294,7 +294,9 @@ std::shared_ptr gridOuterPersistentHeuristic( const int64_t n_tensor_inputs, const int64_t max_input_dtype_size, const int64_t max_persistent_buffer_size, - const size_t vectorize_factor) { + const size_t vectorize_factor, + const bool project_to_input, + const PrimDataType index_type) { auto outer_params = normalization_scheduler_utils::getGridOuterNormalizationParams( total_reduction_numel, @@ -310,6 +312,8 @@ std::shared_ptr gridOuterPersistentHeuristic( auto rparams = std::make_shared(); rparams->persistent_kernel = true; + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; rparams->cross_block_inner_reduction = true; rparams->cross_grid_inner_reduction = true; rparams->grid_dim_iter_dom = ParallelType::BIDx; @@ -367,7 +371,9 @@ std::shared_ptr outerPersistentHeuristic( const int64_t n_tensor_inputs, const int64_t max_input_dtype_size, const int64_t max_persistent_buffer_size, - const size_t vectorize_factor) { + const size_t vectorize_factor, + const bool project_to_input, + const PrimDataType index_type) { // Set some targets for parallelization const int64_t n_elems = total_reduction_numel * total_iteration_numel; const auto dev_prop = at::cuda::getCurrentDeviceProperties(); @@ -402,7 +408,9 @@ std::shared_ptr outerPersistentHeuristic( n_tensor_inputs, max_input_dtype_size, max_persistent_buffer_size, - vectorize_factor); + vectorize_factor, + project_to_input, + index_type); } // Compute maximum number of reductions we could do in the same kernel based @@ -565,6 +573,8 @@ std::shared_ptr outerPersistentHeuristic( auto gdimx = ceilDiv(total_iteration_numel, hp.bdimx.get()); rparams->batches_per_block_inner_reduction = hp.batches_per_block.get(); rparams->persistent_kernel = true; + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; rparams->fastest_dim = false; rparams->cross_block_inner_reduction = true; @@ -645,9 +655,9 @@ std::shared_ptr getOuterPersistentHeuristics( prop.n_tensor_inputs, prop.max_dtype_size, prop.max_persistent_buffer_size, - prop.vectorize_factor); - rparams->project_persistent_buffers = prop.project_persistent_buffers; - rparams->cparams.index_type = runtime_info.getIndexType(); + prop.vectorize_factor, + prop.project_persistent_buffers, + prop.index_type); return rparams; } diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp index 05e0fc3a84a..81ea7eae3cc 100644 --- a/csrc/scheduler/normalization_utils.cpp +++ b/csrc/scheduler/normalization_utils.cpp @@ -928,7 +928,8 @@ PersistentKernelProperties getPersistentKernelProperties( .n_tensor_inputs = n_tensor_inputs, .max_dtype_size = max_dtype_size, .vectorize_factor = vectorize_factor, - .project_persistent_buffers = project_persistent_buffers}; + .project_persistent_buffers = project_persistent_buffers, + .index_type = runtime_info.getIndexType()}; } bool checkOpsAndInputs(Fusion* fusion, ScheduleHeuristic schedule_heuristic) { diff --git a/csrc/scheduler/normalization_utils.h b/csrc/scheduler/normalization_utils.h index 059d4c4a9cc..dad168ab306 100644 --- a/csrc/scheduler/normalization_utils.h +++ b/csrc/scheduler/normalization_utils.h @@ -230,6 +230,7 @@ struct PersistentKernelProperties { int64_t max_dtype_size; int64_t vectorize_factor; bool project_persistent_buffers; + PrimDataType index_type; }; PersistentKernelProperties getPersistentKernelProperties( Fusion* fusion, diff --git a/csrc/scheduler/reduction_heuristic.h b/csrc/scheduler/reduction_heuristic.h index 87a63c93aa1..25fc152f312 100644 --- a/csrc/scheduler/reduction_heuristic.h +++ b/csrc/scheduler/reduction_heuristic.h @@ -273,7 +273,8 @@ class ReductionParams : public HeuristicParams { ss << "\ncomputeWith persistent buffers"; } - ss << "\n" << lparams.toString() << "\n"; + ss << "\n" << lparams.toString(); + ss << cparams.toString() << "\n"; ss << "====================================\n"; return ss.str(); } diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index 50bdd5d9427..73eb5916ab3 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -238,9 +238,10 @@ bool isConnectedFusionGraph(Fusion* fusion) { } // Map aliased outputs - for (const auto& [out, in_info] : fusion->ioAlias()) { - Val* in = in_info.first; - component_sets.mapEntries(out, in); + for (Val* out : fusion->outputs()) { + if (Val* in = fusion->getOutputAlias(out).first; in != nullptr) { + component_sets.mapEntries(out, in); + } } // Check connected-ness: diff --git a/csrc/serde/factory.h b/csrc/serde/factory.h index 9f2fb59a23c..e5f9140f5f8 100644 --- a/csrc/serde/factory.h +++ b/csrc/serde/factory.h @@ -9,6 +9,7 @@ #include #include +#include #include namespace nvfuser::serde { @@ -28,20 +29,24 @@ class Factory { Factory(size_t num_parsers) : parsers_(num_parsers, nullptr){}; - void registerParser(int serde_type, SerdeParser parser) { + template + void registerParser(SerdeEnum serde_type, SerdeParser parser) { + auto serde_integer = nvfuser::toUnderlying(serde_type); NVF_ERROR( - serde_type >= 0 && serde_type < (int)parsers_.size(), + serde_integer >= 0 && serde_integer < (int)parsers_.size(), "RegisterParser: Invalid serde type: ", - serde_type); - parsers_.at(serde_type) = parser; + serde_integer); + parsers_.at(serde_integer) = parser; } - BaseTypePtr parse(int serde_type, const SerdeBuffer* buffer) { + template + BaseTypePtr parse(SerdeEnum serde_type, const SerdeBuffer* buffer) { + auto serde_integer = nvfuser::toUnderlying(serde_type); NVF_ERROR( - serde_type >= 0 && serde_type < (int)parsers_.size(), + serde_integer >= 0 && serde_integer < (int)parsers_.size(), "Deserialize: Invalid serde type: ", - serde_type); - return parsers_.at(serde_type)(buffer); + serde_integer); + return parsers_.at(serde_integer)(buffer); } private: diff --git a/csrc/serde/fusion_cache.fbs b/csrc/serde/fusion_cache.fbs index 6b224c69353..c1740b99ef3 100644 --- a/csrc/serde/fusion_cache.fbs +++ b/csrc/serde/fusion_cache.fbs @@ -70,12 +70,12 @@ enum RecordType: int { Ternary_Alpha_VAL_VAL_TV, Ternary_Alpha_TV_VAL_VAL, Ternary_Alpha_VAL_TV_VAL, + NormalDistOp, OutputTv, OutputVal, PadOp, PermuteOp, StrideOrderOp, - RandomOp, ReductionMax, ReductionMin, ReductionProd, @@ -89,6 +89,7 @@ enum RecordType: int { Start, Tensor, TensorSizes, + UniformDistOp, VarianceOp, VarianceMeanOp, Vector, @@ -115,7 +116,6 @@ union RecordData { Scalar, Size, Tensor, - TensorCreation, TensorCreationSymbolic, Vector, } @@ -315,17 +315,9 @@ table Tensor { is_cpu: bool; } -// Data for FullOpRecord -// The shape is defined with constant numbers. -table TensorCreation { - shape: [long]; - dtype: long; -} - -// Data for RandomOpRecord +// Data for RandomDistOpRecord // The shape is symbolic. table TensorCreationSymbolic { - shape: [State]; dtype: long; } @@ -457,6 +449,10 @@ table FusionCache { auto_gen_schedules: [FusionExecutorCache]; // static fusion executor counter global_fusion_count: long; + device_major: long; + device_minor: long; + cuda_major: long; + cuda_minor: long; } root_type FusionCache; diff --git a/csrc/serde/fusion_record_serde.cpp b/csrc/serde/fusion_record_serde.cpp index 1dec1990fd0..300d801e0da 100644 --- a/csrc/serde/fusion_record_serde.cpp +++ b/csrc/serde/fusion_record_serde.cpp @@ -15,7 +15,7 @@ namespace nvfuser::serde { std::vector parseStateArgs( - const flatbuffers::Vector* args) { + const flatbuffers::Vector* args) { std::vector result; for (auto s : *args) { result.emplace_back(s->index(), s->type()); @@ -23,13 +23,13 @@ std::vector parseStateArgs( return result; } -std::optional mapContiguityEnumToOptional(int v) { +std::optional mapContiguityEnumToOptional(Contiguity v) { switch (v) { - case serde::Contiguity_Strided: + case Contiguity::Strided: return std::optional(false); - case serde::Contiguity_Contiguous: + case Contiguity::Contiguous: return std::optional(true); - case serde::Contiguity_None: + case Contiguity::None: return std::nullopt; } NVF_ERROR(false, "Invalid contiguity type."); @@ -39,8 +39,8 @@ std::optional mapContiguityEnumToOptional(int v) { template python_frontend::RecordFunctor* deserializeOpRecord( const std::unordered_map& str_to_func_map, - serde::RecordType record_type, - const serde::RecordFunctor* buffer) { + RecordType record_type, + const RecordFunctor* buffer) { NVF_ERROR( str_to_func_map.find(buffer->name()->str()) != str_to_func_map.end(), "Missing mapping from operation string to nvfuser function in serde deserialization."); @@ -58,8 +58,8 @@ python_frontend::RecordFunctor* deserializeReductionRecord( const std::vector&, bool, nvfuser::DataType)> fusion_op, - serde::RecordType record_type, - const serde::RecordFunctor* buffer) { + RecordType record_type, + const RecordFunctor* buffer) { auto data = buffer->data_as_Reduction(); return new python_frontend::ReductionOpRecord( parseStateArgs(buffer->args()), @@ -73,299 +73,266 @@ python_frontend::RecordFunctor* deserializeReductionRecord( } void RecordFunctorFactory::registerAllParsers() { - auto deserializeStartRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeStartRecord = [](const RecordFunctor* buffer) { return new python_frontend::StartRecord(); }; - registerParser(serde::RecordType_Start, deserializeStartRecord); + registerParser(RecordType::Start, deserializeStartRecord); - auto deserializeEndRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeEndRecord = [](const RecordFunctor* buffer) { return new python_frontend::EndRecord(); }; - registerParser(serde::RecordType_End, deserializeEndRecord); + registerParser(RecordType::End, deserializeEndRecord); // Unary Ops - auto unary_tv_parser = [&](const serde::RecordFunctor* buffer) { + auto unary_tv_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord( - unary_tv, serde::RecordType_Unary_TV, buffer); + unary_tv, RecordType::Unary_TV, buffer); }; - registerParser(serde::RecordType_Unary_TV, unary_tv_parser); + registerParser(RecordType::Unary_TV, unary_tv_parser); - auto unary_val_parser = [&](const serde::RecordFunctor* buffer) { + auto unary_val_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord( - unary_val, serde::RecordType_Unary_VAL, buffer); + unary_val, RecordType::Unary_VAL, buffer); }; - registerParser(serde::RecordType_Unary_VAL, unary_val_parser); + registerParser(RecordType::Unary_VAL, unary_val_parser); // Binary Ops - auto binary_tv_parser = [&](const serde::RecordFunctor* buffer) { + auto binary_tv_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< binary_tv_fn, TensorView*, TensorView*, - TensorView*>(binary_tv, serde::RecordType_Binary_TV, buffer); + TensorView*>(binary_tv, RecordType::Binary_TV, buffer); }; - registerParser(serde::RecordType_Binary_TV, binary_tv_parser); + registerParser(RecordType::Binary_TV, binary_tv_parser); - auto binary_tv_val_parser = [&](const serde::RecordFunctor* buffer) { + auto binary_tv_val_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< binary_tv_val_fn, TensorView*, TensorView*, - Val*>(binary_tv_val, serde::RecordType_Binary_TV_VAL, buffer); + Val*>(binary_tv_val, RecordType::Binary_TV_VAL, buffer); }; - registerParser(serde::RecordType_Binary_TV_VAL, binary_tv_val_parser); + registerParser(RecordType::Binary_TV_VAL, binary_tv_val_parser); - auto binary_val_tv_parser = [&](const serde::RecordFunctor* buffer) { + auto binary_val_tv_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< binary_val_tv_fn, TensorView*, Val*, - TensorView*>(binary_val_tv, serde::RecordType_Binary_VAL_TV, buffer); + TensorView*>(binary_val_tv, RecordType::Binary_VAL_TV, buffer); }; - registerParser(serde::RecordType_Binary_VAL_TV, binary_val_tv_parser); + registerParser(RecordType::Binary_VAL_TV, binary_val_tv_parser); - auto binary_val_parser = [&](const serde::RecordFunctor* buffer) { + auto binary_val_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord( - binary_val, serde::RecordType_Binary_VAL, buffer); + binary_val, RecordType::Binary_VAL, buffer); }; - registerParser(serde::RecordType_Binary_VAL, binary_val_parser); + registerParser(RecordType::Binary_VAL, binary_val_parser); // Ternary Ops - auto ternary_tv_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_tv_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_tv_fn, TensorView*, TensorView*, TensorView*, - TensorView*>(ternary_tv, serde::RecordType_Ternary_TV, buffer); + TensorView*>(ternary_tv, RecordType::Ternary_TV, buffer); }; - registerParser(serde::RecordType_Ternary_TV, ternary_tv_parser); + registerParser(RecordType::Ternary_TV, ternary_tv_parser); - auto ternary_tv_tv_val_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_tv_tv_val_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_tv_tv_val_fn, TensorView*, TensorView*, TensorView*, - Val*>(ternary_tv_tv_val, serde::RecordType_Ternary_TV_TV_VAL, buffer); + Val*>(ternary_tv_tv_val, RecordType::Ternary_TV_TV_VAL, buffer); }; - registerParser(serde::RecordType_Ternary_TV_TV_VAL, ternary_tv_tv_val_parser); + registerParser(RecordType::Ternary_TV_TV_VAL, ternary_tv_tv_val_parser); - auto ternary_tv_val_tv_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_tv_val_tv_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_tv_val_tv_fn, TensorView*, TensorView*, Val*, - TensorView*>( - ternary_tv_val_tv, serde::RecordType_Ternary_TV_VAL_TV, buffer); + TensorView*>(ternary_tv_val_tv, RecordType::Ternary_TV_VAL_TV, buffer); }; - registerParser(serde::RecordType_Ternary_TV_VAL_TV, ternary_tv_val_tv_parser); + registerParser(RecordType::Ternary_TV_VAL_TV, ternary_tv_val_tv_parser); - auto ternary_val_tv_tv_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_val_tv_tv_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_val_tv_tv_fn, TensorView*, Val*, TensorView*, - TensorView*>( - ternary_val_tv_tv, serde::RecordType_Ternary_VAL_TV_TV, buffer); + TensorView*>(ternary_val_tv_tv, RecordType::Ternary_VAL_TV_TV, buffer); }; - registerParser(serde::RecordType_Ternary_VAL_TV_TV, ternary_val_tv_tv_parser); + registerParser(RecordType::Ternary_VAL_TV_TV, ternary_val_tv_tv_parser); - auto ternary_val_val_tv_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_val_val_tv_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_val_val_tv_fn, TensorView*, Val*, Val*, TensorView*>( - ternary_val_val_tv, serde::RecordType_Ternary_VAL_VAL_TV, buffer); + ternary_val_val_tv, RecordType::Ternary_VAL_VAL_TV, buffer); }; - registerParser( - serde::RecordType_Ternary_VAL_VAL_TV, ternary_val_val_tv_parser); + registerParser(RecordType::Ternary_VAL_VAL_TV, ternary_val_val_tv_parser); - auto ternary_tv_val_val_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_tv_val_val_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_tv_val_val_fn, TensorView*, TensorView*, Val*, - Val*>(ternary_tv_val_val, serde::RecordType_Ternary_TV_VAL_VAL, buffer); + Val*>(ternary_tv_val_val, RecordType::Ternary_TV_VAL_VAL, buffer); }; - registerParser( - serde::RecordType_Ternary_TV_VAL_VAL, ternary_tv_val_val_parser); + registerParser(RecordType::Ternary_TV_VAL_VAL, ternary_tv_val_val_parser); - auto ternary_val_tv_val_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_val_tv_val_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_val_tv_val_fn, TensorView*, Val*, TensorView*, - Val*>(ternary_val_tv_val, serde::RecordType_Ternary_VAL_TV_VAL, buffer); + Val*>(ternary_val_tv_val, RecordType::Ternary_VAL_TV_VAL, buffer); }; - registerParser( - serde::RecordType_Ternary_VAL_TV_VAL, ternary_val_tv_val_parser); + registerParser(RecordType::Ternary_VAL_TV_VAL, ternary_val_tv_val_parser); - auto ternary_val_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_val_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord( - ternary_val, serde::RecordType_Ternary_VAL, buffer); + ternary_val, RecordType::Ternary_VAL, buffer); }; - registerParser(serde::RecordType_Ternary_VAL, ternary_val_parser); + registerParser(RecordType::Ternary_VAL, ternary_val_parser); // Ternary-Alpha Ops - auto ternary_alpha_tv_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_alpha_tv_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_alpha_tv_fn, TensorView*, TensorView*, TensorView*, TensorView*, - Val*>(ternary_alpha_tv, serde::RecordType_Ternary_Alpha_TV, buffer); + Val*>(ternary_alpha_tv, RecordType::Ternary_Alpha_TV, buffer); }; - registerParser(serde::RecordType_Ternary_Alpha_TV, ternary_alpha_tv_parser); + registerParser(RecordType::Ternary_Alpha_TV, ternary_alpha_tv_parser); - auto ternary_alpha_tv_tv_val_parser = - [&](const serde::RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_tv_tv_val_fn, - TensorView*, - TensorView*, - TensorView*, - Val*, - Val*>( - ternary_alpha_tv_tv_val, - serde::RecordType_Ternary_Alpha_TV_TV_VAL, - buffer); - }; + auto ternary_alpha_tv_tv_val_parser = [&](const RecordFunctor* buffer) { + return deserializeOpRecord< + ternary_alpha_tv_tv_val_fn, + TensorView*, + TensorView*, + TensorView*, + Val*, + Val*>( + ternary_alpha_tv_tv_val, RecordType::Ternary_Alpha_TV_TV_VAL, buffer); + }; registerParser( - serde::RecordType_Ternary_Alpha_TV_TV_VAL, - ternary_alpha_tv_tv_val_parser); - - auto ternary_alpha_tv_val_tv_parser = - [&](const serde::RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_tv_val_tv_fn, - TensorView*, - TensorView*, - Val*, - TensorView*, - Val*>( - ternary_alpha_tv_val_tv, - serde::RecordType_Ternary_Alpha_TV_VAL_TV, - buffer); - }; + RecordType::Ternary_Alpha_TV_TV_VAL, ternary_alpha_tv_tv_val_parser); + + auto ternary_alpha_tv_val_tv_parser = [&](const RecordFunctor* buffer) { + return deserializeOpRecord< + ternary_alpha_tv_val_tv_fn, + TensorView*, + TensorView*, + Val*, + TensorView*, + Val*>( + ternary_alpha_tv_val_tv, RecordType::Ternary_Alpha_TV_VAL_TV, buffer); + }; registerParser( - serde::RecordType_Ternary_Alpha_TV_VAL_TV, - ternary_alpha_tv_val_tv_parser); - - auto ternary_alpha_val_tv_tv_parser = - [&](const serde::RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_val_tv_tv_fn, - TensorView*, - Val*, - TensorView*, - TensorView*, - Val*>( - ternary_alpha_val_tv_tv, - serde::RecordType_Ternary_Alpha_VAL_TV_TV, - buffer); - }; + RecordType::Ternary_Alpha_TV_VAL_TV, ternary_alpha_tv_val_tv_parser); + + auto ternary_alpha_val_tv_tv_parser = [&](const RecordFunctor* buffer) { + return deserializeOpRecord< + ternary_alpha_val_tv_tv_fn, + TensorView*, + Val*, + TensorView*, + TensorView*, + Val*>( + ternary_alpha_val_tv_tv, RecordType::Ternary_Alpha_VAL_TV_TV, buffer); + }; registerParser( - serde::RecordType_Ternary_Alpha_VAL_TV_TV, - ternary_alpha_val_tv_tv_parser); - - auto ternary_alpha_val_val_tv_parser = - [&](const serde::RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_val_val_tv_fn, - TensorView*, - Val*, - Val*, - TensorView*, - Val*>( - ternary_alpha_val_val_tv, - serde::RecordType_Ternary_Alpha_VAL_VAL_TV, - buffer); - }; + RecordType::Ternary_Alpha_VAL_TV_TV, ternary_alpha_val_tv_tv_parser); + + auto ternary_alpha_val_val_tv_parser = [&](const RecordFunctor* buffer) { + return deserializeOpRecord< + ternary_alpha_val_val_tv_fn, + TensorView*, + Val*, + Val*, + TensorView*, + Val*>( + ternary_alpha_val_val_tv, RecordType::Ternary_Alpha_VAL_VAL_TV, buffer); + }; registerParser( - serde::RecordType_Ternary_Alpha_VAL_VAL_TV, - ternary_alpha_val_val_tv_parser); - - auto ternary_alpha_tv_val_val_parser = - [&](const serde::RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_tv_val_val_fn, - TensorView*, - TensorView*, - Val*, - Val*, - Val*>( - ternary_alpha_tv_val_val, - serde::RecordType_Ternary_Alpha_TV_VAL_VAL, - buffer); - }; + RecordType::Ternary_Alpha_VAL_VAL_TV, ternary_alpha_val_val_tv_parser); + + auto ternary_alpha_tv_val_val_parser = [&](const RecordFunctor* buffer) { + return deserializeOpRecord< + ternary_alpha_tv_val_val_fn, + TensorView*, + TensorView*, + Val*, + Val*, + Val*>( + ternary_alpha_tv_val_val, RecordType::Ternary_Alpha_TV_VAL_VAL, buffer); + }; registerParser( - serde::RecordType_Ternary_Alpha_TV_VAL_VAL, - ternary_alpha_tv_val_val_parser); - - auto ternary_alpha_val_tv_val_parser = - [&](const serde::RecordFunctor* buffer) { - return deserializeOpRecord< - ternary_alpha_val_tv_val_fn, - TensorView*, - Val*, - TensorView*, - Val*, - Val*>( - ternary_alpha_val_tv_val, - serde::RecordType_Ternary_Alpha_VAL_TV_VAL, - buffer); - }; + RecordType::Ternary_Alpha_TV_VAL_VAL, ternary_alpha_tv_val_val_parser); + + auto ternary_alpha_val_tv_val_parser = [&](const RecordFunctor* buffer) { + return deserializeOpRecord< + ternary_alpha_val_tv_val_fn, + TensorView*, + Val*, + TensorView*, + Val*, + Val*>( + ternary_alpha_val_tv_val, RecordType::Ternary_Alpha_VAL_TV_VAL, buffer); + }; registerParser( - serde::RecordType_Ternary_Alpha_VAL_TV_VAL, - ternary_alpha_val_tv_val_parser); + RecordType::Ternary_Alpha_VAL_TV_VAL, ternary_alpha_val_tv_val_parser); - auto ternary_alpha_val_parser = [&](const serde::RecordFunctor* buffer) { + auto ternary_alpha_val_parser = [&](const RecordFunctor* buffer) { return deserializeOpRecord< ternary_alpha_val_fn, Val*, Val*, Val*, Val*, - Val*>(ternary_alpha_val, serde::RecordType_Ternary_Alpha_VAL, buffer); + Val*>(ternary_alpha_val, RecordType::Ternary_Alpha_VAL, buffer); }; - registerParser(serde::RecordType_Ternary_Alpha_VAL, ternary_alpha_val_parser); + registerParser(RecordType::Ternary_Alpha_VAL, ternary_alpha_val_parser); // END OpRecord Parsers // START Reduction Parsers - auto reduction_max_parser = [](const serde::RecordFunctor* buffer) { - return deserializeReductionRecord( - max, serde::RecordType_ReductionMax, buffer); + auto reduction_max_parser = [](const RecordFunctor* buffer) { + return deserializeReductionRecord(max, RecordType::ReductionMax, buffer); }; - registerParser(serde::RecordType_ReductionMax, reduction_max_parser); + registerParser(RecordType::ReductionMax, reduction_max_parser); - auto reduction_min_parser = [](const serde::RecordFunctor* buffer) { - return deserializeReductionRecord( - min, serde::RecordType_ReductionMin, buffer); + auto reduction_min_parser = [](const RecordFunctor* buffer) { + return deserializeReductionRecord(min, RecordType::ReductionMin, buffer); }; - registerParser(serde::RecordType_ReductionMin, reduction_min_parser); + registerParser(RecordType::ReductionMin, reduction_min_parser); - auto reduction_prod_parser = [](const serde::RecordFunctor* buffer) { - return deserializeReductionRecord( - prod, serde::RecordType_ReductionProd, buffer); + auto reduction_prod_parser = [](const RecordFunctor* buffer) { + return deserializeReductionRecord(prod, RecordType::ReductionProd, buffer); }; - registerParser(serde::RecordType_ReductionProd, reduction_prod_parser); + registerParser(RecordType::ReductionProd, reduction_prod_parser); - auto reduction_sum_parser = [](const serde::RecordFunctor* buffer) { - return deserializeReductionRecord( - sum, serde::RecordType_ReductionSum, buffer); + auto reduction_sum_parser = [](const RecordFunctor* buffer) { + return deserializeReductionRecord(sum, RecordType::ReductionSum, buffer); }; - registerParser(serde::RecordType_ReductionSum, reduction_sum_parser); + registerParser(RecordType::ReductionSum, reduction_sum_parser); // END Reduction Parsers - auto deserializeBatchNormRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeBatchNormRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_BatchNorm(); return new python_frontend::BatchNormOpRecord( parseStateArgs(buffer->args()), @@ -373,176 +340,179 @@ void RecordFunctorFactory::registerAllParsers() { data->training(), data->channels_last()); }; - registerParser(serde::RecordType_BatchNormOp, deserializeBatchNormRecord); + registerParser(RecordType::BatchNormOp, deserializeBatchNormRecord); - auto deserializeBroadcastRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeBroadcastRecord = [](const RecordFunctor* buffer) { return new python_frontend::BroadcastOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), buffer->name()->str(), parseBoolVector(buffer->data_as_Broadcast()->broadcast_dims())); }; - registerParser(serde::RecordType_BroadcastOp, deserializeBroadcastRecord); + registerParser(RecordType::BroadcastOp, deserializeBroadcastRecord); - auto deserializeCatRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeCatRecord = [](const RecordFunctor* buffer) { return new python_frontend::CatOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), buffer->data_as_Dimension()->dim()); }; - registerParser(serde::RecordType_CatOp, deserializeCatRecord); - - auto deserializeBroadcastInDimRecord = - [](const serde::RecordFunctor* buffer) { - auto data = buffer->data_as_BroadcastInDim(); - return new python_frontend::BroadcastInDimOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - data->output_size(), - parseVector(data->broadcast_dims())); - }; - registerParser( - serde::RecordType_BroadcastInDim, deserializeBroadcastInDimRecord); + registerParser(RecordType::CatOp, deserializeCatRecord); + + auto deserializeBroadcastInDimRecord = [](const RecordFunctor* buffer) { + auto data = buffer->data_as_BroadcastInDim(); + return new python_frontend::BroadcastInDimOpRecord( + parseStateArgs(buffer->args()), + parseStateArgs(buffer->outputs()), + data->output_size(), + parseVector(data->broadcast_dims())); + }; + registerParser(RecordType::BroadcastInDim, deserializeBroadcastInDimRecord); - auto deserializeCastTvRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeCastTvRecord = [](const RecordFunctor* buffer) { std::function fusion_op = static_cast(castOp); return new python_frontend::CastOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), buffer->name()->str(), - serde::RecordType_CastTv, + RecordType::CastTv, fusion_op, mapToNvfuserDtype(buffer->data_as_Dtype()->dtype())); }; - registerParser(serde::RecordType_CastTv, deserializeCastTvRecord); + registerParser(RecordType::CastTv, deserializeCastTvRecord); - auto deserializeCastValRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeCastValRecord = [](const RecordFunctor* buffer) { std::function fusion_op = static_cast(castOp); return new python_frontend::CastOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), buffer->name()->str(), - serde::RecordType_CastVal, + RecordType::CastVal, fusion_op, mapToNvfuserDtype(buffer->data_as_Dtype()->dtype())); }; - registerParser(serde::RecordType_CastVal, deserializeCastValRecord); + registerParser(RecordType::CastVal, deserializeCastValRecord); - auto deserializeScalarRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeScalarRecord = [](const RecordFunctor* buffer) { return new python_frontend::ScalarRecord( parseStateArgs(buffer->outputs()), deserializePolymorphicValue(buffer->data_as_Scalar()), mapToNvfuserDtype(buffer->data_as_Scalar()->dtype())); }; - registerParser(serde::RecordType_Scalar, deserializeScalarRecord); + registerParser(RecordType::Scalar, deserializeScalarRecord); - auto deserializeFullRecord = [](const serde::RecordFunctor* buffer) { - auto data = buffer->data_as_TensorCreation(); + auto deserializeFullRecord = [](const RecordFunctor* buffer) { + auto data = buffer->data_as_TensorCreationSymbolic(); return new python_frontend::FullOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), - parseVector(data->shape()), mapToNvfuserDtype(data->dtype())); }; - registerParser(serde::RecordType_FullOp, deserializeFullRecord); + registerParser(RecordType::FullOp, deserializeFullRecord); - auto deserializeIotaRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeIotaRecord = [](const RecordFunctor* buffer) { return new python_frontend::IotaOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), mapToNvfuserDtype(buffer->data_as_Dtype()->dtype())); }; - registerParser(serde::RecordType_IotaOp, deserializeIotaRecord); + registerParser(RecordType::IotaOp, deserializeIotaRecord); - auto deserializeTorchGatherRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeTorchGatherRecord = [](const RecordFunctor* buffer) { return new python_frontend::TorchGatherOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), buffer->data_as_Dimension()->dim()); }; - registerParser(serde::RecordType_TorchGatherOp, deserializeTorchGatherRecord); + registerParser(RecordType::TorchGatherOp, deserializeTorchGatherRecord); - auto deserializeTakeAlongAxisRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeTakeAlongAxisRecord = [](const RecordFunctor* buffer) { return new python_frontend::TakeAlongAxisOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), buffer->data_as_Dimension()->dim()); }; - registerParser( - serde::RecordType_TakeAlongAxisOp, deserializeTakeAlongAxisRecord); + registerParser(RecordType::TakeAlongAxisOp, deserializeTakeAlongAxisRecord); - auto deserializeIndexSelectRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeIndexSelectRecord = [](const RecordFunctor* buffer) { return new python_frontend::IndexSelectOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), buffer->data_as_Dimension()->dim()); }; - registerParser(serde::RecordType_IndexSelectOp, deserializeIndexSelectRecord); + registerParser(RecordType::IndexSelectOp, deserializeIndexSelectRecord); - auto deserializeOutputTvRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeOutputTvRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Output(); return new python_frontend::OutputRecord( parseStateArgs(buffer->args()), - serde::RecordType_OutputTv, + RecordType::OutputTv, parseVector(data->stride_order())); }; - registerParser(serde::RecordType_OutputTv, deserializeOutputTvRecord); + registerParser(RecordType::OutputTv, deserializeOutputTvRecord); - auto deserializeOutputValRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeOutputValRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Output(); return new python_frontend::OutputRecord( parseStateArgs(buffer->args()), - serde::RecordType_OutputVal, + RecordType::OutputVal, parseVector(data->stride_order())); }; - registerParser(serde::RecordType_OutputVal, deserializeOutputValRecord); + registerParser(RecordType::OutputVal, deserializeOutputValRecord); - auto deserializePadRecord = [](const serde::RecordFunctor* buffer) { + auto deserializePadRecord = [](const RecordFunctor* buffer) { return new python_frontend::PadOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), parseVector(buffer->data_as_Pad()->pad_widths())); }; - registerParser(serde::RecordType_PadOp, deserializePadRecord); + registerParser(RecordType::PadOp, deserializePadRecord); - auto deserializePermuteRecord = [](const serde::RecordFunctor* buffer) { - return new python_frontend::DimsOpRecord( + auto deserializePermuteRecord = [](const RecordFunctor* buffer) { + return new python_frontend::DimsOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), parseVector(buffer->data_as_Dims()->dims()), buffer->name()->str()); }; - registerParser(serde::RecordType_PermuteOp, deserializePermuteRecord); + registerParser(RecordType::PermuteOp, deserializePermuteRecord); - auto deserializeStrideOrderRecord = [](const serde::RecordFunctor* buffer) { - return new python_frontend::DimsOpRecord( + auto deserializeStrideOrderRecord = [](const RecordFunctor* buffer) { + return new python_frontend::DimsOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), parseVector(buffer->data_as_Dims()->dims()), buffer->name()->str()); }; - registerParser(serde::RecordType_StrideOrderOp, deserializeStrideOrderRecord); + registerParser(RecordType::StrideOrderOp, deserializeStrideOrderRecord); - auto deserializeRandomRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeNormalDistRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_TensorCreationSymbolic(); - return new python_frontend::RandomOpRecord( + return new python_frontend::RandomDistOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), - parseStateArgs(data->shape()), - buffer->name()->str(), mapToNvfuserDtype(data->dtype())); }; - registerParser(serde::RecordType_RandomOp, deserializeRandomRecord); + registerParser(RecordType::NormalDistOp, deserializeNormalDistRecord); - auto deserializeReshapeRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeUniformDistRecord = [](const RecordFunctor* buffer) { + auto data = buffer->data_as_TensorCreationSymbolic(); + return new python_frontend::RandomDistOpRecord( + parseStateArgs(buffer->args()), + parseStateArgs(buffer->outputs()), + mapToNvfuserDtype(data->dtype())); + }; + registerParser(RecordType::UniformDistOp, deserializeUniformDistRecord); + + auto deserializeReshapeRecord = [](const RecordFunctor* buffer) { return new python_frontend::ReshapeOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); }; - registerParser(serde::RecordType_ReshapeOp, deserializeReshapeRecord); + registerParser(RecordType::ReshapeOp, deserializeReshapeRecord); - auto deserializeSliceRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeSliceRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Slice(); return new python_frontend::SliceOpRecord( parseStateArgs(buffer->args()), @@ -551,9 +521,9 @@ void RecordFunctorFactory::registerAllParsers() { parseVector(data->end_indices()), parseVector(data->strides())); }; - registerParser(serde::RecordType_SliceOp, deserializeSliceRecord); + registerParser(RecordType::SliceOp, deserializeSliceRecord); - auto deserializeSqueezeRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeSqueezeRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Squeeze(); return new python_frontend::SqueezeOpRecord( parseStateArgs(buffer->args()), @@ -561,9 +531,9 @@ void RecordFunctorFactory::registerAllParsers() { parseVector(data->original_shape()), parseVector(data->squeeze_dims())); }; - registerParser(serde::RecordType_SqueezeOp, deserializeSqueezeRecord); + registerParser(RecordType::SqueezeOp, deserializeSqueezeRecord); - auto deserializeTensorRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeTensorRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Tensor(); std::vector> contiguity; @@ -581,39 +551,39 @@ void RecordFunctorFactory::registerAllParsers() { data->is_cpu(), parseVector(data->stride_order())); }; - registerParser(serde::RecordType_Tensor, deserializeTensorRecord); + registerParser(RecordType::Tensor, deserializeTensorRecord); - auto deserializeTensorSizesRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeTensorSizesRecord = [](const RecordFunctor* buffer) { return new python_frontend::TensorSizesRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); }; - registerParser(serde::RecordType_TensorSizes, deserializeTensorSizesRecord); + registerParser(RecordType::TensorSizes, deserializeTensorSizesRecord); - auto deserializeShapeOpRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeShapeOpRecord = [](const RecordFunctor* buffer) { return new python_frontend::ShapeOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); }; - registerParser(serde::RecordType_ShapeOp, deserializeShapeOpRecord); + registerParser(RecordType::ShapeOp, deserializeShapeOpRecord); - auto deserializeSizeOpRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeSizeOpRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Size(); return new python_frontend::SizeOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), data->dim()); }; - registerParser(serde::RecordType_SizeOp, deserializeSizeOpRecord); + registerParser(RecordType::SizeOp, deserializeSizeOpRecord); - auto deserializeAtOpRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeAtOpRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_At(); return new python_frontend::AtOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), data->index()); }; - registerParser(serde::RecordType_AtOp, deserializeAtOpRecord); + registerParser(RecordType::AtOp, deserializeAtOpRecord); - auto deserializeVarianceRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeVarianceRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Norm(); return new python_frontend::VarianceOpRecord( parseStateArgs(buffer->args()), @@ -622,9 +592,9 @@ void RecordFunctorFactory::registerAllParsers() { data->correction(), data->keep_dim()); }; - registerParser(serde::RecordType_VarianceOp, deserializeVarianceRecord); + registerParser(RecordType::VarianceOp, deserializeVarianceRecord); - auto deserializeVarianceMeanRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeVarianceMeanRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Norm(); return new python_frontend::VarianceMeanOpRecord( parseStateArgs(buffer->args()), @@ -633,17 +603,16 @@ void RecordFunctorFactory::registerAllParsers() { data->correction(), data->keep_dim()); }; - registerParser( - serde::RecordType_VarianceMeanOp, deserializeVarianceMeanRecord); + registerParser(RecordType::VarianceMeanOp, deserializeVarianceMeanRecord); - auto deserializeVectorRecord = [](const serde::RecordFunctor* buffer) { + auto deserializeVectorRecord = [](const RecordFunctor* buffer) { auto data = buffer->data_as_Vector(); return new python_frontend::VectorRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), mapToNvfuserDtype(data->dtype())); }; - registerParser(serde::RecordType_Vector, deserializeVectorRecord); + registerParser(RecordType::Vector, deserializeVectorRecord); } void RecordFunctorFactory::setupFunctionMaps() { diff --git a/csrc/serde/fusion_record_serde.h b/csrc/serde/fusion_record_serde.h index 0e35f622606..bb6c7463a9e 100644 --- a/csrc/serde/fusion_record_serde.h +++ b/csrc/serde/fusion_record_serde.h @@ -70,9 +70,10 @@ typedef std::function //! RecordFunctor table. We create an enum type for each RecordFunctor class. //! Each template specialization has a unique RecordType and parser function. class RecordFunctorFactory - : public Factory { + : public Factory { public: - RecordFunctorFactory() : Factory((serde::RecordType_MAX + 1)) { + RecordFunctorFactory() + : Factory((nvfuser::toUnderlying(RecordType::MAX) + 1)) { setupFunctionMaps(); registerAllParsers(); } diff --git a/csrc/serde/polymorphic_value_serde.cpp b/csrc/serde/polymorphic_value_serde.cpp index ee6a620a75e..5300087a55c 100644 --- a/csrc/serde/polymorphic_value_serde.cpp +++ b/csrc/serde/polymorphic_value_serde.cpp @@ -14,14 +14,13 @@ namespace nvfuser::serde { namespace { -nvfuser::PolymorphicValue makeCpuScalarTensor( - const serde::ScalarCpu* scalar_cpu) { +nvfuser::PolymorphicValue makeCpuScalarTensor(const ScalarCpu* scalar_cpu) { NVF_ERROR(scalar_cpu != nullptr); auto scalar = deserializePolymorphicValue(scalar_cpu->scalar_value()); return nvfuser::PolymorphicValue_functions::toTensor(scalar, at::kCPU); } -nvfuser::PolymorphicValue getMetaTensorArg(const serde::TensorArg* tensor) { +nvfuser::PolymorphicValue getMetaTensorArg(const TensorArg* tensor) { NVF_ERROR(tensor != nullptr); if (tensor->strides() != nullptr) { auto meta_tensor = at::detail::empty_strided_meta( @@ -44,7 +43,7 @@ nvfuser::PolymorphicValue getMetaTensorArg(const serde::TensorArg* tensor) { } // namespace -nvfuser::PolymorphicValue deserializePolymorphicValue(const serde::Scalar* c) { +nvfuser::PolymorphicValue deserializePolymorphicValue(const Scalar* c) { if (!c->has_value()) { return {}; } @@ -69,26 +68,26 @@ nvfuser::PolymorphicValue deserializePolymorphicValue(const serde::Scalar* c) { } void PolymorphicValueFactory::registerAllParsers() { - auto deserializeScalar = [](const serde::PolymorphicValue* buffer) { + auto deserializeScalar = [](const PolymorphicValue* buffer) { return deserializePolymorphicValue(buffer->data_as_Scalar()); }; - registerParser(serde::PolymorphicValueData_Scalar, deserializeScalar); + registerParser(PolymorphicValueData::Scalar, deserializeScalar); - auto deserializeScalarCpu = [](const serde::PolymorphicValue* buffer) { + auto deserializeScalarCpu = [](const PolymorphicValue* buffer) { return makeCpuScalarTensor(buffer->data_as_ScalarCpu()); }; - registerParser(serde::PolymorphicValueData_ScalarCpu, deserializeScalarCpu); + registerParser(PolymorphicValueData::ScalarCpu, deserializeScalarCpu); // TODO Encode ptr field which corresponds to the aten tensor's data pointer. // It is used during scheduling for vectorization. A meta aten tensor assumes // that the pointer address is zero. - auto deserializeTensorArg = [](const serde::PolymorphicValue* buffer) { + auto deserializeTensorArg = [](const PolymorphicValue* buffer) { return getMetaTensorArg(buffer->data_as_TensorArg()); }; - registerParser(serde::PolymorphicValueData_TensorArg, deserializeTensorArg); + registerParser(PolymorphicValueData::TensorArg, deserializeTensorArg); } -flatbuffers::Offset serializeScalarCpu( +flatbuffers::Offset serializeScalarCpu( flatbuffers::FlatBufferBuilder& builder, const at::Tensor& tensor) { NVF_ERROR( @@ -118,7 +117,7 @@ flatbuffers::Offset serializeScalarCpu( } } -flatbuffers::Offset serializePolymorphicValue( +flatbuffers::Offset serializePolymorphicValue( flatbuffers::FlatBufferBuilder& builder, std::shared_ptr v) { NVF_ERROR(!v->is(), "PolymorphicValue is a std::monostate."); @@ -139,9 +138,9 @@ flatbuffers::Offset serializePolymorphicValue( if (tensor.is_cpu() && tensor.numel() == 1) { // CPU Scalar auto fb_scalar_data = serializeScalarCpu(builder, tensor); - auto data = serde::CreateScalarCpu(builder, fb_scalar_data); + auto data = CreateScalarCpu(builder, fb_scalar_data); return CreatePolymorphicValue( - builder, PolymorphicValueData_ScalarCpu, data.Union()); + builder, PolymorphicValueData::ScalarCpu, data.Union()); } else { // GPU Tensor // Convert IntArrayRef to std::vector for flatbuffer compatibility @@ -158,23 +157,23 @@ flatbuffers::Offset serializePolymorphicValue( strides_fb.push_back(tensor.stride(dim)); } - auto data = serde::CreateTensorArg( + auto data = CreateTensorArg( builder, (size_t)tensor.data_ptr(), builder.CreateVector(sizes_fb), builder.CreateVector(strides_fb), nvfuser::toUnderlying(tensor.scalar_type())); return CreatePolymorphicValue( - builder, PolymorphicValueData_TensorArg, data.Union()); + builder, PolymorphicValueData::TensorArg, data.Union()); } } else { auto data = serializeScalar(builder, *v, getDataType(*v)); return CreatePolymorphicValue( - builder, PolymorphicValueData_Scalar, data.Union()); + builder, PolymorphicValueData::Scalar, data.Union()); } } -flatbuffers::Offset serializeScalar( +flatbuffers::Offset serializeScalar( flatbuffers::FlatBufferBuilder& builder, const nvfuser::PolymorphicValue& v, nvfuser::DataType t) { @@ -206,7 +205,7 @@ flatbuffers::Offset serializeScalar( builder_.add_imag_value(std::imag(c)); return builder_.Finish(); } - NVF_ERROR(false, "Unable to convert ", v.type().name(), " to serde::Scalar."); + NVF_ERROR(false, "Unable to convert ", v.type().name(), " to Scalar."); } } // namespace nvfuser::serde diff --git a/csrc/serde/polymorphic_value_serde.h b/csrc/serde/polymorphic_value_serde.h index 5ebcb6c3e1e..8c37a768dc1 100644 --- a/csrc/serde/polymorphic_value_serde.h +++ b/csrc/serde/polymorphic_value_serde.h @@ -22,9 +22,10 @@ namespace nvfuser::serde { //! KernelArgumentHolder, which is used to schedule the fusion in //! FusionKernelRuntime and to run a kernel in FusionExecutor. class PolymorphicValueFactory - : public Factory { + : public Factory { public: - PolymorphicValueFactory() : Factory((serde::PolymorphicValueData_MAX + 1)) { + PolymorphicValueFactory() + : Factory((nvfuser::toUnderlying(PolymorphicValueData::MAX) + 1)) { registerAllParsers(); } @@ -32,17 +33,17 @@ class PolymorphicValueFactory void registerAllParsers(); }; -nvfuser::PolymorphicValue deserializePolymorphicValue(const serde::Scalar* c); +nvfuser::PolymorphicValue deserializePolymorphicValue(const Scalar* c); -flatbuffers::Offset serializePolymorphicValue( +flatbuffers::Offset serializePolymorphicValue( flatbuffers::FlatBufferBuilder& builder, std::shared_ptr v); -flatbuffers::Offset serializeScalarCpu( +flatbuffers::Offset serializeScalarCpu( flatbuffers::FlatBufferBuilder& builder, const at::Tensor& tensor); -flatbuffers::Offset serializeScalar( +flatbuffers::Offset serializeScalar( flatbuffers::FlatBufferBuilder& builder, const nvfuser::PolymorphicValue& v, nvfuser::DataType t); diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index ae13fb32343..6c9657b2f80 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -255,6 +255,10 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) compute_with_pos_(src->compute_with_pos_), promote_reuse_(src->promote_reuse_) {} +void TensorView::printTransforms() const { + IrTransformPrinter(std::cout).printTransforms(this); +} + // sets cpu_scalar_ value, which is special handling for CPU based zero-dim // tensors (i.e. CPU Tensors that only have one value). This is only used if // on an input value, otherwise ignored. This is important as special handling @@ -1375,17 +1379,20 @@ bool TensorView::isEmptyTensor() const { }); } -void TensorView::applyMmaSwizzle(MmaOptions options) { - switch (options.operand) { - case MmaOptions::Operand::Accumulator: - mma_utils::WarpMmaSwizzler::scheduleMmaWarpOutput(this, options); +void TensorView::applyMmaSwizzle(MmaOperand operand) { + switch (operand) { + case MmaOperand::Accumulator: + mma_utils::WarpMmaSwizzler::scheduleMmaWarpOutput(this); if (definition()->isA()) { setAllocationDomain(getLeafDomain(), true); } break; - case MmaOptions::Operand::A: - case MmaOptions::Operand::B: - mma_utils::WarpMmaSwizzler::scheduleOperandRead(this, options); + case MmaOperand::A: + case MmaOperand::B: + mma_utils::WarpMmaSwizzler::scheduleOperandRead(this, operand); + if (ir_utils::isLdMatrixOp(definition())) { + mma_utils::WarpMmaSwizzler::scheduleLdMatrix(this, operand); + } break; default: NVF_ERROR(false, "unknown operand flag"); @@ -1483,7 +1490,21 @@ TensorViewBuilder& TensorViewBuilder::strideOrder( NVF_CHECK(ndims_ == 0 || ndims_ == stride_order.size()); ndims_ = stride_order.size(); } - stride_order_ = std::move(stride_order); + + // TODO: this shouldn't be necessary. For details see issue + // https://github.com/NVIDIA/Fuser/issues/1399 + // + // skip stride_order if its alloc_domain is in the same order as with rfactor + // domain. We don't need this and we should be able to just use stride_order_, + // but currently alloc_domain support isn't ideal and could prevent + // vectorization. Adding this workaround to restore performance. + if (std::adjacent_find( + stride_order.begin(), stride_order.end(), [](int64_t l, int64_t r) { + return l <= r; + }) != stride_order.end()) { + // stride_order is not in descending order, we cannot skip it. + stride_order_ = std::move(stride_order); + } return *this; } diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index 5b11c89031d..543361b91f0 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -517,48 +517,17 @@ std::pair TransformReplay::replayPasC( } } - if (!opt.replay_allocation) { - TensorDomain* replayed = IrBuilder::create( - producer->container(), - producer->getRootDomain(), - producer->getRFactorDomain(), - producer->getAllocationDomain(), - new_IDs, - producer->domain()->contiguity()); - return {replayed, producer_pos}; - } + NVF_ERROR( + !opt.replay_allocation, + "replayAllocation is not implemented yet for TransformReplay::replayPasC"); TensorDomain* replayed = IrBuilder::create( producer->container(), producer->getRootDomain(), producer->getRFactorDomain(), - std::vector{}, + producer->getAllocationDomain(), new_IDs, producer->domain()->contiguity()); - - if (consumer->hasAllocation()) { - auto replay_PasC = BestEffortReplay( - new_IDs, - consumer->getLeafDomain(), - root_map.mapConsumerToProducer(consumer->domain(), replayed)); - const auto& c2p_map = replay_PasC.getReplay(); - std::vector new_allocation_domain; - new_allocation_domain.reserve(consumer->getAllocationDomain().size()); - for (auto id : consumer->getAllocationDomain()) { - auto it = c2p_map.find(id); - NVF_CHECK( - it != c2p_map.end(), - "Unable to replayPasC: can not map ", - id->toString(), - " in the allocation domain of consumer tensor ", - consumer->toString(), - " to producer tensor ", - producer->toString()); - new_allocation_domain.emplace_back(it->second); - } - replayed->setAllocationDomain(std::move(new_allocation_domain), true); - } - return {replayed, producer_pos}; } @@ -572,8 +541,9 @@ std::pair TransformReplay::replayCasP( // If this is a reduction operation, we may call transform_replay on the same // tensor view. When this happens, just return thet target view. - if (consumer == producer) + if (consumer == producer) { return {consumer->domain(), consumer->nDims()}; + } if (producer_pos < 0) { producer_pos += (int64_t)producer->nDims() + 1; @@ -801,12 +771,18 @@ std::pair TransformReplay::replayCasP( return {replayed, consumer_pos}; } + NVF_ERROR( + consumer->definition()->isA() && !consumer->hasRFactor(), + "TransformReplay::replayCasP currently replays allocation only for Set. " + "Other ops (e.g. `consumer = broadcast(producer)`) can break. " + "See https://github.com/NVIDIA/Fuser/pull/1291#discussion_r1391999007 for details."); + TensorDomain* replayed = IrBuilder::create( consumer->container(), consumer->getRootDomain(), consumer->getRFactorDomain(), - std::vector{}, - new_IDs, + /*allocation=*/std::vector{}, + /*leaf=*/new_IDs, consumer->domain()->contiguity()); if (producer->hasAllocation()) { @@ -815,21 +791,24 @@ std::pair TransformReplay::replayCasP( producer->getLeafDomain(), root_map.mapProducerToConsumer(producer->domain(), replayed)); const auto& p2c_map = replay_CasP.getReplay(); + + auto producer_rank = producer->getAllocationDomain().size(); std::vector new_allocation_domain; - new_allocation_domain.reserve(producer->getAllocationDomain().size()); - for (auto id : producer->getAllocationDomain()) { - auto it = p2c_map.find(id); - NVF_CHECK( - it != p2c_map.end(), - "Unable to replayCasP: can not map ", - id->toString(), - " in the allocation domain of producer tensor ", - producer->toString(), - " to consumer tensor ", - consumer->toString()); - new_allocation_domain.emplace_back(it->second); + new_allocation_domain.reserve(producer_rank); + std::vector> new_contiguity; + new_contiguity.reserve(producer_rank); + + for (auto i : c10::irange(producer_rank)) { + IterDomain* id = producer->getAllocationDomain()[i]; + // We won't find reduction IterDomains in the map. See + // AllocationDomainTest.CacheBefore. + if (auto it = p2c_map.find(id); it != p2c_map.end()) { + new_allocation_domain.push_back(it->second); + new_contiguity.push_back(producer->getContiguity()[i]); + } } - replayed->setAllocationDomain(std::move(new_allocation_domain), true); + replayed->setAllocationDomain( + std::move(new_allocation_domain), std::move(new_contiguity)); } return {replayed, consumer_pos}; } diff --git a/csrc/transform_view.cpp b/csrc/transform_view.cpp index d3e8db37cef..f138ce0ba8e 100644 --- a/csrc/transform_view.cpp +++ b/csrc/transform_view.cpp @@ -211,14 +211,8 @@ class MergeTransform final : public ViewTransform { "Didn't expect to apply view transformations on an iter domain", " starting at a non-zero position."); - auto merged_extent = mul(outer_id->extent(), inner_id->extent()); - auto new_merged_id = - IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), merged_extent) - .is_rfactor_domain(true) - .build(); - - IrBuilder::create(new_merged_id, outer_id, inner_id); + IterDomain::merge(outer_id, inner_id, /*rfactor_domain*/ true); current_transformed_domain.erase( current_transformed_domain.begin() + index_); @@ -277,23 +271,13 @@ class SplitTransform final : public ViewTransform { "Didn't expect to apply view transformations on an iter domain", " starting at a non-zero position."); - Val* remainder = ceilDiv(id->extent(), factor); - - // outer loop IterDomain - IterDomain* factor_id = - IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), factor) - .parallel_type(id->getParallelType()) - .iter_type(id->getIterType()) - .is_rfactor_domain(true) - .build(); - - // inner loop IterDomain - IterDomain* remainder_id = - IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), remainder) - .is_rfactor_domain(true) - .build(); - - IrBuilder::create(factor_id, remainder_id, id, factor, false); + auto [factor_id, remainder_id] = IterDomain::split( + id, + factor, + /*inner_split=*/false, + /*start_offset=*/nullptr, + /*stop_offset=*/nullptr, + /*rfactor_domain=*/true); current_transformed_domain.erase( current_transformed_domain.begin() + index_); diff --git a/csrc/type.cpp b/csrc/type.cpp index a51b19659e1..189c6d18717 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -338,6 +338,7 @@ bool needFloatSuffix(UnaryOpType t) { case UnaryOpType::IsReal: case UnaryOpType::Print: case UnaryOpType::ToUnsignedSmemAddr: + case UnaryOpType::AdjustPartialLdMatrixAddrInTuring: return false; default: return true; @@ -456,6 +457,8 @@ static const char* unary_op_type2string(UnaryOpType t) { return "std::imag"; case UnaryOpType::ToUnsignedSmemAddr: return "toSmem"; + case UnaryOpType::AdjustPartialLdMatrixAddrInTuring: + return "Turing::adjustPartialLdMatrixAddrInTuring"; default: NVF_ERROR(false, "No string found for unary op type."); } @@ -765,6 +768,8 @@ static const char* id_map_mode_type2string(IdMappingMode t) { return "permissive"; case IdMappingMode::PERMISSIVE_RESIZE: return "permissive_resize"; + case IdMappingMode::INNERMOST: + return "innermost"; default: // Don't try to print t as it would recursively call this function NVF_ERROR(false, "Unexpected IdMappingMode Type."); diff --git a/csrc/type.h b/csrc/type.h index 17efdc38480..e104a61e4f7 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -512,6 +512,10 @@ inline bool hasCompatibleDataType( int max_digits10(DataType dtype); enum class UnaryOpType { + Cast, + BitCast, + RefCast, + Abs, Acos, Acosh, @@ -520,7 +524,6 @@ enum class UnaryOpType { Asinh, Atan, Atanh, - Cast, Ceil, Cos, Cosh, @@ -542,7 +545,6 @@ enum class UnaryOpType { Log10, Log1p, Log2, - BitCast, Neg, Real, Reciprocal, @@ -574,7 +576,8 @@ enum class UnaryOpType { IsReal, // Special unary ops - ToUnsignedSmemAddr + ToUnsignedSmemAddr, + AdjustPartialLdMatrixAddrInTuring }; // TODO: Order of this list is important as it affects type promotion. it's not diff --git a/lib/dynamic_type/CMakeLists.txt b/lib/dynamic_type/CMakeLists.txt index 73250ef7a40..57dae786520 100644 --- a/lib/dynamic_type/CMakeLists.txt +++ b/lib/dynamic_type/CMakeLists.txt @@ -24,7 +24,11 @@ if(BUILD_TEST) test/unary_ops.cpp ) target_include_directories(${target} PUBLIC src) - target_link_libraries(${target} PRIVATE gtest_main gmock_main) + target_include_directories(${target} SYSTEM PRIVATE + ${CMAKE_SOURCE_DIR}/third_party/googletest/googletest/include + ${CMAKE_SOURCE_DIR}/third_party/googletest/googlemock/include + ) + target_link_libraries(${target} PRIVATE GTest::gtest_main GTest::gmock_main) set_property(TARGET ${target} PROPERTY CXX_STANDARD ${std_version}) endfunction() diff --git a/python_tests/pytest_input_generators.py b/python_tests/pytest_input_generators.py index 947c67d6407..5c3038712ab 100644 --- a/python_tests/pytest_input_generators.py +++ b/python_tests/pytest_input_generators.py @@ -21,6 +21,7 @@ complex_dtypes, ) from nvfuser import DataType +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype MINIMUM_SYMBOLIC_SIZE = -1 INT64_MAX = 2**63 - 1 @@ -736,16 +737,11 @@ def full_error_generator( op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs ): # torch.full(size, fill_value, dtype=None) - - make_arg = partial( - make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad - ) - # Error: Trying to create tensor with negative dimension negative_input_shape = [2, -2] yield SampleInput( negative_input_shape, make_number(dtype), dtype - ), RuntimeError, "extent_int >= 0" + ), RuntimeError, "The value -2 at index 1 was neither symbolic(-1), zero_element(0), broadcast(1), or static(>1)." def gather_generator( @@ -1011,6 +1007,18 @@ def permute_error_generator( ), RuntimeError, "argument to have the same length as input" +def random_dist_error_generator( + op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs +): + # Checking that non-supported dtypes fail + yield SampleInput( + make_number(torch.float), + make_number(torch.float), + [2, 2], + dtype=torch_dtype_to_nvfuser_dtype(dtype), + ), RuntimeError, "Random distributions only create floating point types" + + def reduction_generator( op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs ): diff --git a/python_tests/pytest_opinfos.py b/python_tests/pytest_opinfos.py index 796485b2c64..6b5c9cce804 100644 --- a/python_tests/pytest_opinfos.py +++ b/python_tests/pytest_opinfos.py @@ -35,6 +35,7 @@ pad_error_generator, permute_generator, permute_error_generator, + random_dist_error_generator, reduction_error_generator, reshape_generator, reshape_error_generator, @@ -49,8 +50,9 @@ ) from pytest_utils import ( bool_int_dtypes, - int_dtypes, + complex_dtypes, full_precision_float_dtypes, + int_dtypes, int_float_dtypes, float_complex_dtypes, ArgumentType, @@ -1064,6 +1066,38 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): ) tensor_creation_ops.append(iota_opinfo) +# NOTE: normal's python API does not produce value based errors given most parameters are +# symbolic as Scalar or Vector parameters. The dtype parameter is checked to make sure the +# user does not ask for non-floating point random numbers. +uniform_opinfo = OpInfo( + lambda fd: fd.ops.normal, + "normal", + dtypes=(bool_int_dtypes + complex_dtypes), + error_input_generator=random_dist_error_generator, + symbolic_parameter_list=( + ArgumentType.ConstantScalar, + ArgumentType.ConstantScalar, + ArgumentType.Constant, + ), +) +tensor_creation_ops.append(uniform_opinfo) + +# NOTE: uniform's python API does not produce value based errors given most parameters are +# symbolic as Scalar or Vector parameters. The dtype parameter is checked to make sure the +# user does not ask for non-floating point random numbers. +uniform_opinfo = OpInfo( + lambda fd: fd.ops.uniform, + "uniform", + dtypes=(bool_int_dtypes + complex_dtypes), + error_input_generator=random_dist_error_generator, + symbolic_parameter_list=( + ArgumentType.ConstantScalar, + ArgumentType.ConstantScalar, + ArgumentType.Constant, + ), +) +tensor_creation_ops.append(uniform_opinfo) + """ End Tensor Creation """ # Puts all opinfos into the "opinfos" list diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index 3ed6c2bd8b9..868dec315f5 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -28,9 +28,15 @@ version, compute_contiguity, compute_tensor_descriptor, + serialize as nv_serialize, ) from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype +# Test automatic serialization to common workplace +import atexit + +atexit.register(nv_serialize) + RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM @@ -41,6 +47,13 @@ def is_pre_volta(): return prop.major < 7 +def is_pre_ampere(): + if not RUN_NVFUSER: + return False + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + return prop.major < 8 + + def serde_check(test_fn: Callable): """ A decorator to verify that serialization works with the given exec_nvfuser function. @@ -954,8 +967,7 @@ def fusion_func(fd: FusionDefinition): t0 = fd.from_pytorch(inputs[0]) s_mean = fd.define_scalar(mean) s_std = fd.define_scalar(std) - size = fd.ops.tensor_sizes(t0) - t1 = fd.ops.normal(s_mean, s_std, size, DataType.Double) + t1 = fd.ops.normal(s_mean, s_std, t0.shape(), dtype=DataType.Double) fd.add_output(t1) nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) @@ -992,8 +1004,7 @@ def fusion_func(fd: FusionDefinition): t0 = fd.from_pytorch(inputs[0]) s_lo = fd.define_scalar(lo) s_hi = fd.define_scalar(hi) - size = fd.ops.tensor_sizes(t0) - t1 = fd.ops.uniform(s_lo, s_hi, size, DataType.Double) + t1 = fd.ops.uniform(s_lo, s_hi, t0.shape(), dtype=DataType.Double) fd.add_output(t1) nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) @@ -1376,25 +1387,61 @@ def test_compute_contiguity(self): self.assertEqual(compute_contiguity(sizes, strides), contiguity) def test_compute_tensor_descriptor(self): - sizes = [2, 1, 3, 1, 4, 3] - strides = [12, 4, 4, 4, 1, 0] - contiguity = [True, None, True, None, True, None] - stride_order = [5, 4, 3, 2, 1, 0] - computed_contiguity, computed_stride_order = compute_tensor_descriptor( - sizes, strides - ) - self.assertEqual(computed_contiguity, contiguity) - self.assertEqual(computed_stride_order, stride_order) - - sizes = [2, 3, 1, 5, 4] - strides = [28, 4, 14, 0, 1] - contiguity = [False, None, True, True, None] - stride_order = [4, 2, 3, 0, 1] - computed_contiguity, computed_stride_order = compute_tensor_descriptor( - sizes, strides + configs = ( + ( + # size + [2, 1, 3, 1, 4, 3], + # stride + [12, 4, 4, 4, 1, 0], + # expected contiguity + [True, None, True, None, True, None], + # expected stride_order + [5, 4, 3, 2, 1, 0], + ), + ( + [2, 3, 1, 5, 4], + [28, 4, 14, 0, 1], + [False, None, True, None, True], + [4, 2, 3, 1, 0], + ), + ( + [2, 2, 1, 1, 2, 2, 2], + [8, 4, 3, 9, 2, 0, 1], + [None, True, True, None, True, None, True], + [5, 4, 3, 6, 2, 1, 0], + ), + ( + [2, 2, 1, 2, 4, 2], + [2, 32, 1, 8, 0, 4], + [False, True, True, False, None, None], + [2, 5, 0, 4, 1, 3], + ), + ( + [2, 2, 2, 2], + [8, 4, 2, 1], + [True, True, True, True], + [3, 2, 1, 0], + ), + ( + [2, 1, 3, 1, 4], + [24, 4, 8, 4, 2], + [True, True, None, None, False], + [4, 2, 3, 1, 0], + ), + ( + [2, 2, 2, 2], + [8, 4, 0, 2], + [True, True, None, False], + [3, 2, 1, 0], + ), ) - self.assertEqual(computed_contiguity, contiguity) - self.assertEqual(computed_stride_order, stride_order) + + for sizes, strides, contiguity, stride_order in configs: + computed_contiguity, computed_stride_order = compute_tensor_descriptor( + sizes, strides + ) + self.assertEqual(computed_contiguity, contiguity) + self.assertEqual(computed_stride_order, stride_order) def test_stride_order_with_explicit_broadcast(self): inputs = [ @@ -2472,14 +2519,15 @@ def fusion_func(fd: FusionDefinition, *, deterministic) -> None: t1 = fd.from_pytorch(inputs[0]) a = fd.define_scalar(0.3, DataType.Float) b = fd.define_scalar(1.7, DataType.Float) - shape = [fd.define_scalar(5), fd.define_scalar(9)] randop = getattr(fd.ops, randopname) if deterministic: rng_seed = fd.define_scalar(DataType.Int) rng_offset = fd.define_scalar(DataType.Int) - u = randop(a, b, shape, rng_seed=rng_seed, rng_offset=rng_offset) + u = randop( + a, b, shape=[5, 9], rng_seed=rng_seed, rng_offset=rng_offset + ) else: - u = randop(a, b, shape) + u = randop(a, b, shape=[5, 9]) t2 = t1 * u fd.add_output(t2) @@ -2516,7 +2564,6 @@ def fusion_func(fd: FusionDefinition, *, deterministic) -> None: except AssertionError as e: print(f"Assertion failed for iteration {i} with seed {seed}") print(e) - break # Test expand to zero is replaced with expanded extent and not 1 # see https://github.com/NVIDIA/Fuser/issues/603 @@ -2580,6 +2627,30 @@ def dynamic_reshape(fd: FusionDefinition) -> None: self.assertEqual(y.shape, torch.Size([3, 2, 2])) self.assertEqual(x.flatten(), y.flatten()) + def test_allocation_domain_concretization(self): + inputs = [ + # we need an empty tensor here so we'll trigger `concretizeEmptyExtents` + torch.randn((0,), dtype=torch.float64, device="cuda:0").as_strided( + (1, 0, 1, 1), (0, 1, 1, 1) + ), + ] + + def fusion_func(fd: FusionDefinition) -> None: + T1 = fd.define_tensor( + shape=[1, -1, 1, 1], + contiguity=[True, None, None, None], + dtype=DataType.Double, + is_cpu=False, + stride_order=[0, 3, 2, 1], + ) + S1 = fd.define_scalar(2.0, dtype=DataType.Double) + T2 = fd.ops.mul(T1, S1) + fd.add_output(T2) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + torch_ref = inputs[0] * 2.0 + self.assertEqual(nvf_out[0], torch_ref) + def test_allocation_domain_index_select(self): inputs = [ torch.randn((252,), dtype=torch.float32, device="cuda:0").as_strided( @@ -2664,6 +2735,36 @@ def reshape(fd: FusionDefinition) -> None: self.assertEqual(y.data_ptr(), x.data_ptr()) + # Test that reshape to slice to sum with concrete sizes sets extents properly + # https://github.com/NVIDIA/Fuser/issues/1221 + def test_sum_sliced_reshape_to_broadcast(self): + inputs = [torch.randn((24, 128, 25, 32), dtype=torch.float32, device="cuda:0")] + + def fusion_func(fd: FusionDefinition) -> None: + T18 = fd.define_tensor( + shape=[-1, -1, -1, -1], + contiguity=[True, True, True, True], + dtype=DataType.Float, + is_cpu=False, + ) + S91 = fd.define_scalar(12, dtype=DataType.Int) + S92 = fd.define_scalar(128, dtype=DataType.Int) + S93 = fd.define_scalar(25, dtype=DataType.Int) + S94 = fd.define_scalar(32, dtype=DataType.Int) + S95 = fd.define_scalar(2, dtype=DataType.Int) + V96 = fd.define_vector([S91, S92, S93, S94, S95], dtype=DataType.Int) + T97 = fd.ops.reshape(T18, new_shape=V96) + T98 = fd.ops.slice( + T97, + start_indices=[0, 0, 0, 0, 0], + end_indices=[12, 128, 25, 32, 1], + strides=[1, 1, 1, 1, 1], + ) + T89 = fd.ops.sum(T98, axes=[4], keepdim=False, dtype=DataType.Null) + fd.add_output(T89) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + # This tests no dead code at definition does not cause a problem due to # removal of empty tensors # See https://github.com/NVIDIA/Fuser/pull/1270 @@ -2710,6 +2811,125 @@ def fusion_func(fd: FusionDefinition) -> None: self.assertEqual(nvf_out[0], t24) self.assertEqual(nvf_out[1], t11) + # This tests squeeze of dynamic input is handled properly + def test_issue1273(self): + inputs = [ + torch.randn((4,), dtype=torch.float32, device="cuda:0").as_strided( + (2, 2), (2, 1) + ), + 1e-05, + ] + + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.Float, + is_cpu=False, + ) + S1 = fd.define_scalar(None, dtype=DataType.Double) + T7 = fd.ops.reshape(T0, new_shape=[2, 1, 2]) + T8, T9 = fd.ops.var_mean(T7, axes=[2], correction=0, keepdim=False) + T14 = fd.ops.broadcast_in_dim(T8, shape=[2, 1, 1], broadcast_dims=[0, 1]) + T19 = fd.ops.broadcast_in_dim(T9, shape=[2, 1, 1], broadcast_dims=[0, 1]) + T20 = fd.ops.add(T14, S1) + T21 = fd.ops.rsqrt(T20) + T26 = fd.ops.broadcast_in_dim( + T19, shape=[2, 1, 2], broadcast_dims=[0, 1, 2] + ) + T27 = fd.ops.sub(T7, T26) + T32 = fd.ops.broadcast_in_dim( + T21, shape=[2, 1, 2], broadcast_dims=[0, 1, 2] + ) + T33 = fd.ops.mul(T27, T32) + T37 = fd.ops.reshape(T33, new_shape=[2, 2]) + fd.add_output(T37) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + t7 = inputs[0].reshape((2, 1, 2)) + t8 = t7.var(dim=2, unbiased=False) + t9 = t7.mean(dim=2) + t27 = t7 - t9.unsqueeze(-1).expand((2, 1, 2)) + t32 = torch.rsqrt(inputs[1] + t8.unsqueeze(-1)).expand((2, 1, 2)) + torch_ref = (t27 * t32).reshape((2, 2)) + self.assertEqual(nvf_out[0], torch_ref) + + # See https://github.com/NVIDIA/Fuser/issues/1246 + def test_issue1246(self): + inputs = [ + torch.randn((8388608,), dtype=torch.float32, device="cuda:0").as_strided( + (1, 32, 2048, 128), (8388608, 262144, 128, 1) + ), + torch.randn((0,), dtype=torch.float32, device="cuda:0").as_strided( + (1, 32, 2048, 0), (8388608, 262144, 128, 1) + ), + ] + + for final_mul in [False, True]: + + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[1, -1, -1, -1], + contiguity=[None, True, True, True], + dtype=DataType.Float, + is_cpu=False, + ) + T1 = fd.define_tensor( + shape=[1, -1, -1, -1], + contiguity=[None, True, False, True], + dtype=DataType.Float, + is_cpu=False, + ) + S2 = fd.define_scalar(2.00000, dtype=DataType.Double) + T3 = fd.ops.mul(T0, S2) + T4 = fd.ops.cat([T3, T1], dim=-1) + if final_mul: + # NOTE: original repro does not have this final op + S3 = fd.define_scalar(1.00000, dtype=DataType.Double) + T5 = fd.ops.mul(T4, S3) + fd.add_output(T5) + else: + fd.add_output(T4) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + torch_ref = torch.cat([2.0 * inputs[0], inputs[1]], dim=-1) + self.assertEqual(nvf_out[0], torch_ref) + + # Test that inputs are properly forwarded when an input is used in multiple + # UnaryOps, some having one and others having multiple further uses. + # See https://github.com/NVIDIA/Fuser/issues/1301#issuecomment-1812470502 + @unittest.skipIf(is_pre_ampere(), "Only supported on Ampere and newer devices.") + def test_issue1310(self): + inputs = [torch.randn((16, 128, 768), dtype=torch.bfloat16, device="cuda:0")] + + def fusion_func(fd: FusionDefinition) -> None: + T3 = fd.define_tensor( + shape=[-1, -1, -1], + contiguity=[True, True, True], + dtype=DataType.BFloat16, + is_cpu=False, + ) + T14 = fd.ops.cast( + T3, dtype=DataType.Float + ) # NOTE that RHS is same, but the result is assigned to different variables + T15 = fd.ops.cast( + T3, dtype=DataType.Float + ) # NOTE that RHS is same, but the result is assigned to different variables + T16 = fd.ops.sum(T15, axes=[0, 1], keepdim=False, dtype=DataType.Null) + T20 = fd.ops.sum(T14, axes=[0, 1], keepdim=False, dtype=DataType.Null) + T31 = fd.ops.sum(T14, axes=[2], keepdim=False, dtype=DataType.Null) + fd.add_output(T16) + fd.add_output(T20) + fd.add_output(T31) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + t14 = inputs[0].type(torch.float32) + t16 = t14.sum([0, 1]) + t31 = t14.sum([2]) + self.assertEqual(nvf_out[0], t16) + self.assertEqual(nvf_out[1], t16) # T16 == T20 + self.assertEqual(nvf_out[2], t31) + if __name__ == "__main__": run_tests() diff --git a/runtime/memory.cu b/runtime/memory.cu index ee27c33ece1..7b349527a39 100644 --- a/runtime/memory.cu +++ b/runtime/memory.cu @@ -24,8 +24,6 @@ __device__ inline unsigned toSmem(const void* raw_ptr) { namespace Turing { -namespace util { - // LdMatrix has .x1, .x2 and .x4 options, currently we actively use .x2 and // .x4. In .x2 option. the the address register of upper half warp (lane 16-31) // are un-used but on Turing [sm75,sm80) architecture these un-used addresses @@ -42,8 +40,8 @@ namespace util { // hardware. // The alignment requirement is lifted on sm80+, // so this function is a no-op on Ampere or above. -__device__ inline void adjustPartialLdMatrixAddrInTuring( - unsigned& addr_in_byte) { +__device__ inline unsigned adjustPartialLdMatrixAddrInTuring( + unsigned addr_in_byte) { const unsigned thread_id = threadIdx.x; // Upper half warp has 8 bytes offset from aligned in .x2 option // of ldmatrix. Currently no support for .x1 so assume always @@ -57,121 +55,13 @@ __device__ inline void adjustPartialLdMatrixAddrInTuring( // mask out the bits where adjust_mask has 1. addr_in_byte &= (~mask_out); } -} - -} // namespace util - -// Load Matrix (per warp instruction) is to take data from SMEM to Local Memory. -// Automatically handles vectorized loads/stores in the MMA operation. -// Loads 8x8 matrix into a warp. Thread 0-7 provide the ptr that is the start -// of each row. All other threads can simply point to something valid -// (including 0). -// The x2 modifier on the instruction will actually load 2x8 rows to make a -// 16x8, -// then thread 0-15 will specify the start of each row. -// Finally is an x4 modifier producing a 32x8 using addrs from 0-31 in each -// warp. - -__device__ inline void ldMatrix(Array& out, unsigned addr) { -#if (__CUDA_ARCH__ < 800) - util::adjustPartialLdMatrixAddrInTuring(addr); -#endif - asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];" - : "=r"(out[0]), "=r"(out[1]) - : "r"(addr)); -} - -// Same as previous, 8x8 matrix is vectorized loaded, then scattered (to perform -// transpose) so threads will hold 2 values down a column (instead of the -// previous instruction that's across a row). -__device__ inline void ldMatrixT(Array& out, unsigned addr) { -#if (__CUDA_ARCH__ < 800) - util::adjustPartialLdMatrixAddrInTuring(addr); -#endif - asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];" - : "=r"(out[0]), "=r"(out[1]) - : "r"(addr)); -} - -__device__ inline void ldMatrix(Array& out, unsigned addr) { - asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" - : "=r"(out[0]), "=r"(out[1]), "=r"(out[2]), "=r"(out[3]) - : "r"(addr)); -} - -__device__ inline void ldMatrixT(Array& out, unsigned addr) { - asm volatile( - "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" - : "=r"(out[0]), "=r"(out[1]), "=r"(out[2]), "=r"(out[3]) - : "r"(addr)); + return addr_in_byte; } } // namespace Turing #endif // Arch 75 -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - -namespace Ampere { - -// MMA instruction wrappers (sm_80+): - -// Global to SMEM load that is asynchronous, -// not guaranteed to be completed until cpAsyncBarrier() is called. -// if predicate is set to false, then gmem_ptr won't be read and smem_addr will -// be zero-initialized gmem_ptr must be `sizeof(dtype) * len` aligned -template -__device__ inline void cpAsyncCa( - unsigned smem_addr, - void const* gmem_ptr, - bool predicate) { - constexpr int byte_size = sizeof(dtype) * len; - - static_assert( - byte_size == 4 || byte_size == 8 || byte_size == 16, - "cp_async : unsupported byte size"); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.eq.b32 p, %3, 0;\n" - " cp.async.ca.shared.global [%0], [%1], %2, p;\n" - "}\n" ::"r"(smem_addr), - "l"(gmem_ptr), - "n"(byte_size), - "r"((int)predicate)); -} - -// Global to SMEM load that is asynchronous, -// The cache global variant, i.e. skip L1 caching. -// more details see: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators -// not guaranteed to be completed until cpAsyncBarrier() is called. -// if predicate is set to false, then gmem_ptr won't be read and smem_addr will -// be zero-initialized gmem_ptr must be 16B aligned -template -__device__ inline void cpAsyncCg( - unsigned smem_addr, - void const* gmem_ptr, - bool predicate) { - constexpr int byte_size = sizeof(dtype) * len; - - static_assert(byte_size == 16, "cp_async : unsupported byte size"); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.eq.b32 p, %2, 0;\n" - " cp.async.cg.shared.global [%0], [%1], 16, p;\n" - "}\n" ::"r"(smem_addr), - "l"(gmem_ptr), - "r"((int)predicate)); -} - -} // namespace Ampere - -#endif // Arch 80 - #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) namespace Hopper { @@ -284,18 +174,6 @@ struct CpAsyncBulkTensorTileS2GIndex { Array crds; }; -__device__ inline void cpAsyncBulkS2GCommit() { - asm volatile("cp.async.bulk.commit_group;"); -} - -template -__device__ inline void cpAsyncBulkS2GPartialReadBarrier() { - asm volatile("cp.async.bulk.wait_group.read %0;" - : - : "n"(keep_stages) - : "memory"); -} - __device__ inline void cpAsyncBulkTensorTileS2G( const CpAsyncBulkTensorTileS2GIndex<1>& dest, uint32_t smem_addr) { diff --git a/runtime/tensorcore.cu b/runtime/tensorcore.cu deleted file mode 100644 index 1373a9eb985..00000000000 --- a/runtime/tensorcore.cu +++ /dev/null @@ -1,270 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -// Utility macro for this file - -// MMA instruction wrappers: -// The wrappers are subroutines that implement matrix of size -// A(M,K) X B(K,N) = C(M,N) -// The naming of the wrappers follow similar naming conventions -// as the mma instructions. -// All the mma macros follow the namespace and naming like -// Arch::M (M-dim) N (N-dim) K(K-dim) (Layout), eg. -// Volta::M16N16K4TT, -// with the dimensions describing the size of the sub-matrices being -// multiplied by this wrapper. -// see [Operand Layout Convention] in mma_type.h for details on the layout -// notation. -namespace Volta { - -// MMA instruction wrappers (sm_70+): -// The instruction wrappers below are quarter-warp macros, which currently -// nvfuser doesn't explicitly model. -// So they are currently only meant to be -// used as building blocks in warp level mma macros - -// 8x8x4 mma instruction, per quarter warp (8 threads), fp32 accumulate -// per thread register: -// A[4] x B[4] -> C[8] - -__device__ inline void M16N16K4TT( - Array& C, - Array& A, - Array& B) { - asm("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(C[0]), - "=f"(C[1]), - "=f"(C[2]), - "=f"(C[3]), - "=f"(C[4]), - "=f"(C[5]), - "=f"(C[6]), - "=f"(C[7]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3]), - "f"(C[4]), - "f"(C[5]), - "f"(C[6]), - "f"(C[7])); -} - -__device__ inline void M16N16K4TN( - Array& C, - Array& A, - Array& B) { - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(C[0]), - "=f"(C[1]), - "=f"(C[2]), - "=f"(C[3]), - "=f"(C[4]), - "=f"(C[5]), - "=f"(C[6]), - "=f"(C[7]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3]), - "f"(C[4]), - "f"(C[5]), - "f"(C[6]), - "f"(C[7])); -} - -__device__ inline void M16N16K4NT( - Array& C, - Array& A, - Array& B) { - asm("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(C[0]), - "=f"(C[1]), - "=f"(C[2]), - "=f"(C[3]), - "=f"(C[4]), - "=f"(C[5]), - "=f"(C[6]), - "=f"(C[7]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3]), - "f"(C[4]), - "f"(C[5]), - "f"(C[6]), - "f"(C[7])); -} - -__device__ inline void M16N16K4NN( - Array& C, - Array& A, - Array& B) { - asm("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(C[0]), - "=f"(C[1]), - "=f"(C[2]), - "=f"(C[3]), - "=f"(C[4]), - "=f"(C[5]), - "=f"(C[6]), - "=f"(C[7]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3]), - "f"(C[4]), - "f"(C[5]), - "f"(C[6]), - "f"(C[7])); -} - -// Same initialization for now, will be different in interleaved -// macros -__device__ inline void initM16N16K4(Array& accumulator) { - accumulator.set(0); -} - -} // namespace Volta - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) - -namespace Turing { - -__device__ inline void initM16N8K16(Array& accumulator) { - accumulator.set(0); -} - -__device__ inline void M16N8K16TN( - Array& C, - Array& A, - Array& B) { - asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[0]), - "r"(A[1]), - "r"(B[0]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3])); - asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[2]), - "r"(A[3]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3])); -} - -__device__ inline void initM16N16K16(Array& accumulator) { - accumulator.set(0); -} - -__device__ inline void M16N16K16TN( - Array& C, - Array& A, - Array& B) { - auto* _C = reinterpret_cast*>(&C); - auto* _B = reinterpret_cast*>(&B); - M16N8K16TN(_C[0], A, _B[0]); - M16N8K16TN(_C[1], A, _B[1]); -} - -} // namespace Turing - -#endif // Arch 75 - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - -namespace Ampere { - -__device__ inline void initM16N8K16(Array& accumulator) { - accumulator.set(0); -} - -__device__ inline void M16N8K16TNF16( - Array& C, - Array& A, - Array& B) { - asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[0]), - "r"(A[1]), - "r"(A[2]), - "r"(A[3]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3])); -} - -__device__ inline void M16N8K16TNBF16( - Array& C, - Array& A, - Array& B) { - asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[0]), - "r"(A[1]), - "r"(A[2]), - "r"(A[3]), - "r"(B[0]), - "r"(B[1]), - "f"(C[0]), - "f"(C[1]), - "f"(C[2]), - "f"(C[3])); -} - -__device__ inline void initM16N16K16(Array& accumulator) { - accumulator.set(0); -} - -__device__ inline void M16N16K16TNF16( - Array& C, - Array& A, - Array& B) { - auto* _C = reinterpret_cast*>(&C); - auto* _B = reinterpret_cast*>(&B); - M16N8K16TNF16(_C[0], A, _B[0]); - M16N8K16TNF16(_C[1], A, _B[1]); -} - -__device__ inline void M16N16K16TNBF16( - Array& C, - Array& A, - Array& B) { - auto* _C = reinterpret_cast*>(&C); - auto* _B = reinterpret_cast*>(&B); - M16N8K16TNBF16(_C[0], A, _B[0]); - M16N8K16TNBF16(_C[1], A, _B[1]); -} - -} // namespace Ampere - -#endif // Arch 80 diff --git a/test/multidevice.cpp b/test/multidevice.cpp index a7b0ed3ecb0..c6522bf5810 100644 --- a/test/multidevice.cpp +++ b/test/multidevice.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -59,6 +60,9 @@ void MultiDeviceTest::TearDown() { void CommunicationTest::SetUp() { MultiDeviceTest::SetUp(); + if (!communicator->isBackendAvailable(GetParam())) { + GTEST_SKIP() << "Backend not available"; + } all_ranks = std::vector(communicator->size()); std::iota(all_ranks.begin(), all_ranks.end(), 0); } @@ -82,36 +86,77 @@ void CommunicationTest::resetDstBuffers() { namespace { +void unshardTv(TensorView* tv) { + for (IterDomain* id : tv->getLeafDomain()) { + if (id->isDeviceDim()) { + id->parallelize(ParallelType::Serial); + } + } +} + +void doSendRecv( + DeviceIdxType sender, + DeviceIdxType receiver, + at::Tensor send_buf, + at::Tensor recv_buf, + Communicator* communicator) { + CommParams params; + params.root = sender; + if (sender == receiver) { + params.team = {sender}; + } else { + params.team = {sender, receiver}; + } + if (send_buf.numel()) { + params.src_bufs = {send_buf}; + } + if (recv_buf.numel()) { + params.dst_bufs = {recv_buf}; + } + auto work = SendRecv(params).post(*communicator); + if (work) { + work->wait(); + } +} + // Send a possibly sharded tensor represented by a PipelineVal // to one "tester" device void SendToTester( PipelineVal* pVal, at::Tensor tensor, + at::Tensor tester_tensor, DeviceIdxType tester, Communicator* communicator) { std::vector buffer; auto& mesh = pVal->getStage()->descriptor()->mesh; - if (isParallelTypeDeviceDim(pVal->getOriginalVal() - ->as() - ->getRootDomain() - .at(0) - ->getParallelType())) { + if (isSharded(pVal->getOriginalVal()->as())) { for (DeviceIdxType j : c10::irange(mesh.vector().size())) { - buffer = {tensor.index({j, "..."})}; + at::Tensor send_buf, recv_buf; auto sender = mesh.vector().at(j); - if (tester != sender && - (communicator->deviceId() == sender || - communicator->deviceId() == tester)) { - communicator->sendRecv(tester, sender, buffer)->wait(); + if (communicator->deviceId() == sender || + communicator->deviceId() == tester) { + if (communicator->deviceId() == sender) { + send_buf = tensor.index({0, "..."}); + } + if (communicator->deviceId() == tester) { + recv_buf = tester_tensor.index({j, "..."}); + } + doSendRecv(sender, tester, send_buf, recv_buf, communicator); } } } else { - buffer = {tensor}; + at::Tensor send_buf, recv_buf; auto sender = mesh.vector().at(0); if (tester != sender && (communicator->deviceId() == sender || communicator->deviceId() == tester)) { - communicator->sendRecv(tester, sender, buffer)->wait(); + if (communicator->deviceId() == sender) { + send_buf = tensor; + } + if (communicator->deviceId() == tester) { + recv_buf = tester_tensor; + } + doSendRecv(sender, tester, send_buf, recv_buf, communicator); } } } @@ -130,20 +175,29 @@ void testValidateMultidevice( bool validate = true, bool set_mem_type_to_global = true, bool auto_schedule = false) { + std::vector unsharded_inputs; + std::vector unsharded_outputs; + // gathering all the inputs at tester for (auto i : c10::irange(inputs.size())) { + c10::IValue unsharded_input = inputs.at(i).deepcopy(); + unsharded_inputs.push_back(unsharded_input); SendToTester( runtime.pipeline()->inputs().at(i)->as(), inputs.at(i).toTensor(), + unsharded_inputs.at(i).toTensor(), tester, communicator); } // gathering all the outputs at tester for (auto i : c10::irange(outputs.size())) { + at::Tensor unsharded_output = at::clone(outputs.at(i)); + unsharded_outputs.push_back(unsharded_output); SendToTester( runtime.pipeline()->outputs().at(i)->as(), outputs.at(i), + unsharded_outputs.at(i), tester, communicator); } @@ -153,7 +207,12 @@ void testValidateMultidevice( std::stringstream ss; std::string indent = " "; ss << "Obtained final outputs:{\n"; - for (auto& t : outputs) { + for (auto& t : unsharded_outputs) { + ss << indent << t; + } + ss << "\n}\n"; + ss << "Reference (unsharded) input:{\n"; + for (auto& t : unsharded_inputs) { ss << indent << t; } ss << "\n}"; @@ -161,8 +220,9 @@ void testValidateMultidevice( } // sets all the memory type to global to avoid an execution error - if (set_mem_type_to_global) { - for (auto tv : ir_utils::filterByType(fusion_ptr->vals())) { + for (auto tv : ir_utils::filterByType(fusion_ptr->vals())) { + unshardTv(tv); + if (set_mem_type_to_global) { tv->setMemoryType(MemoryType::Global); } } @@ -172,11 +232,11 @@ void testValidateMultidevice( Fusion& fusion = *fusion_ptr.get(); if (auto_schedule) { FusionExecutorCache fec(std::move(fusion_ptr)); - ref_outputs = fec.runFusionWithInputs(inputs); + ref_outputs = fec.runFusionWithInputs(unsharded_inputs); } else { FusionExecutor fe; - fe.compileFusion(&fusion, inputs); - ref_outputs = fe.runFusion(inputs); + fe.compileFusion(&fusion, unsharded_inputs); + ref_outputs = fe.runFusion(unsharded_inputs); } if (print) { @@ -191,7 +251,13 @@ void testValidateMultidevice( } if (validate) { - testValidate(&fusion, outputs, inputs, ref_outputs, __LINE__, __FILE__); + testValidate( + &fusion, + unsharded_outputs, + unsharded_inputs, + ref_outputs, + __LINE__, + __FILE__); } } } @@ -237,6 +303,7 @@ void executeAndValidatePipeline( void PipelineTest::SetUp() { MultiDeviceTest::SetUp(); fusion = std::make_unique(); + communicator->setDefaultBackend(CommunicatorBackend::nccl); } void PipelineTest::validate() { @@ -246,7 +313,12 @@ void PipelineTest::validate() { void PipelineTestTwoStages::SetUp() { PipelineTest::SetUp(); - auto [mesh0, mesh1, is_stage0_sharded, is_stage1_sharded] = GetParam(); + auto [backend, mesh0, mesh1, is_stage0_sharded, is_stage1_sharded] = + GetParam(); + if (!communicator->isBackendAvailable(backend)) { + GTEST_SKIP() << "Backend not available"; + } + communicator->setDefaultBackend(backend); FusionGuard fg(fusion.get()); TensorView* tv0 = makeContigTensor(4); diff --git a/test/multidevice.h b/test/multidevice.h index fea36a5896d..a3150d5fa0d 100644 --- a/test/multidevice.h +++ b/test/multidevice.h @@ -49,7 +49,9 @@ class MultiDeviceTest : public NVFuserTest { bool do_barrier_at_test; }; -class CommunicationTest : public MultiDeviceTest { +class CommunicationTest + : public MultiDeviceTest, + public ::testing::WithParamInterface { protected: void SetUp() override; void validate(at::Tensor obtained, at::Tensor expected); @@ -57,6 +59,8 @@ class CommunicationTest : public MultiDeviceTest { static constexpr DeviceIdxType root = 0; static constexpr int tensor_size = 1024; static constexpr int number_of_repetitions = 8; + static constexpr c10d::ReduceOp::RedOpType red_op = + c10d::ReduceOp::RedOpType::SUM; CommParams params; std::vector all_ranks; }; @@ -73,7 +77,7 @@ class PipelineTest : public MultiDeviceTest { //(first stage's mesh, second stage's mesh, is first stage sharded, is second // stage sharded) using PipelineTestTwoStagesParams = - std::tuple; + std::tuple; class PipelineTestTwoStages : public PipelineTest, public ::testing::WithParamInterface { diff --git a/test/test_alias.cpp b/test/test_alias.cpp index f46ddf454b4..a92850bb092 100644 --- a/test/test_alias.cpp +++ b/test/test_alias.cpp @@ -20,30 +20,16 @@ namespace nvfuser { +using testing::_; +using testing::Each; using testing::ElementsAre; using testing::IsEmpty; +using testing::IsTrue; +using testing::Optional; using testing::Pair; -using testing::UnorderedElementsAre; using AliasAnalysisTest = NVFuserTest; -TEST_F(AliasAnalysisTest, View_ContiguousAndSameAllocationOrder) { - Fusion fusion; - FusionGuard fg(&fusion); - - const std::vector in_shape({2, 3, 4}); - const std::vector out_shape({2, 12}); - - TensorView* in = makeContigConcreteTensor(in_shape); - fusion.addInput(in); - TensorView* out = reshape(in, in_shape, out_shape); - fusion.addOutput(out); - - optimization::AliasAnalysisResult alias_analysis = - optimization::findAliases(&fusion); - EXPECT_EQ(alias_analysis.findRoot(out), in); -} - TEST_F(AliasAnalysisTest, View_SymbolicTensor) { Fusion fusion; FusionGuard fg(&fusion); @@ -79,7 +65,7 @@ TEST_F(AliasAnalysisTest, ChainOfViews) { EXPECT_EQ(alias_analysis.findRoot(out), in); } -TEST_F(AliasAnalysisTest, View_DifferentAllocationOrder) { +TEST_F(AliasAnalysisTest, View_Contiguous) { Fusion fusion; FusionGuard fg(&fusion); @@ -90,33 +76,59 @@ TEST_F(AliasAnalysisTest, View_DifferentAllocationOrder) { fusion.addInput(in); TensorView* out = reshape(in, in_shape, out_shape); fusion.addOutput(out); - out->setAllocationDomain( - {out->axis(1), out->axis(0)}, /*new_contiguity=*/true); optimization::AliasAnalysisResult alias_analysis = optimization::findAliases(&fusion); - EXPECT_EQ(alias_analysis.findRoot(out), out); + EXPECT_EQ(alias_analysis.findRoot(out), in); + optimization::Layout preferred_layout = alias_analysis.preferredLayout(out); + EXPECT_THAT( + preferred_layout.allocation_domain, + ElementsAre(out->axis(0), out->axis(1))); + EXPECT_THAT(preferred_layout.contiguity, Each(Optional(IsTrue()))); } -TEST_F(AliasAnalysisTest, View_NonContiguous) { +TEST_F(AliasAnalysisTest, View_MergeNonContiguous) { Fusion fusion; FusionGuard fg(&fusion); const std::vector in_shape({2, 3, 4}); const std::vector out_shape({2, 12}); - TensorView* in = makeContigConcreteTensor(in_shape); + TensorView* in = TensorViewBuilder() + .shape(in_shape) + .dtype(DataType::Float) + .contiguity({true, false, true}) + .build(); fusion.addInput(in); TensorView* out = reshape(in, in_shape, out_shape); fusion.addOutput(out); - out->setAllocationDomain( - {out->axis(0), out->axis(1)}, /*new_contiguity=*/{true, false}); optimization::AliasAnalysisResult alias_analysis = optimization::findAliases(&fusion); EXPECT_EQ(alias_analysis.findRoot(out), out); } +TEST_F(AliasAnalysisTest, Set) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* in = makeContigConcreteTensor({2, 3, 5}); + fusion.addInput(in); + TensorView* out = set(in); + fusion.addOutput(out); + + in->setAllocationDomain({in->axis(1), in->axis(2), in->axis(0)}, true); + + optimization::AliasAnalysisResult alias_analysis = + optimization::findAliases(&fusion); + EXPECT_EQ(alias_analysis.findRoot(out), in); + + const std::vector& out_rfactor = out->getMaybeRFactorDomain(); + EXPECT_THAT( + alias_analysis.preferredLayout(out).allocation_domain, + ElementsAre(out_rfactor[1], out_rfactor[2], out_rfactor[0])); +} + TEST_F(AliasAnalysisTest, Permute) { Fusion fusion; FusionGuard fg(&fusion); @@ -144,19 +156,20 @@ TEST_F(AliasAnalysisTest, View_SplitExpandedBroadcast) { TensorView* in = makeContigConcreteTensor({4, 5}); fusion.addInput(in); - TensorView* out = broadcast(in, {false, false, true}); - out = expand( - out, + TensorView* broadcast_out = broadcast(in, {false, false, true}); + TensorView* expand_out = expand( + broadcast_out, {IrBuilder::create(4), IrBuilder::create(5), IrBuilder::create(6)}); // tryStaticReshape used to fail to get the expanded extent, which is 6. - out = reshape(out, {IrBuilder::create(40), IrBuilder::create(3)}); + TensorView* out = reshape( + expand_out, {IrBuilder::create(40), IrBuilder::create(3)}); fusion.addOutput(out); optimization::AliasAnalysisResult alias_analysis = optimization::findAliases(&fusion); - EXPECT_EQ(alias_analysis.findRoot(out), out); + EXPECT_EQ(alias_analysis.findRoot(out), expand_out); } TEST_F(AliasAnalysisTest, View_ForwardExpandedBroadcast) { @@ -194,18 +207,64 @@ TEST_F(AliasAnalysisTest, View_MergeExpandedBroadcast) { TensorView* in = makeContigConcreteTensor({4, 5}); fusion.addInput(in); - TensorView* out = broadcast(in, {false, false, true}); - out = expand( - out, + TensorView* broadcast_out = broadcast(in, {false, false, true}); + TensorView* expand_out = expand( + broadcast_out, {IrBuilder::create(4), IrBuilder::create(5), IrBuilder::create(6)}); - out = reshape(out, {4, 5, 6}, {4, -1}); + TensorView* out = reshape(expand_out, {4, 5, 6}, {4, -1}); + fusion.addOutput(out); + + optimization::AliasAnalysisResult alias_analysis = + optimization::findAliases(&fusion); + EXPECT_EQ(alias_analysis.findRoot(out), expand_out); +} + +TEST_F(AliasAnalysisTest, TrivialSlice) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* in = makeContigConcreteTensor({2, 3}); + fusion.addInput(in); + TensorView* out = slice(in, {0, 0}, {2, 3}); + out = reshape(out, {2, 3}, {6}); + fusion.addOutput(out); + + optimization::AliasAnalysisResult alias_analysis = + optimization::findAliases(&fusion); + EXPECT_EQ(alias_analysis.findRoot(out), in); +} + +TEST_F(AliasAnalysisTest, MergeTriviallySlicedDimensions) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* in = makeContigConcreteTensor({2, 3, 5}); + fusion.addInput(in); + TensorView* out = slice(in, {0, 0, 0}, {2, 2, 5}); + out = reshape(out, {2, 2, 5}, {2, 10}); + fusion.addOutput(out); + + optimization::AliasAnalysisResult alias_analysis = + optimization::findAliases(&fusion); + EXPECT_EQ(alias_analysis.findRoot(out), in); +} + +TEST_F(AliasAnalysisTest, MergeSlicedDimensions) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* in = makeContigConcreteTensor({2, 3, 5}); + fusion.addInput(in); + TensorView* slice_out = slice(in, {0, 0, 0}, {2, 2, 5}); + TensorView* out = reshape(slice_out, {2, 2, 5}, {4, 5}); fusion.addOutput(out); optimization::AliasAnalysisResult alias_analysis = optimization::findAliases(&fusion); EXPECT_EQ(alias_analysis.findRoot(out), out); + EXPECT_EQ(alias_analysis.findRoot(slice_out), in); } using AliasTest = NVFuserTest; @@ -233,13 +292,37 @@ TEST_F(AliasTest, View) { EXPECT_EQ(in_tensor.data_ptr(), out_tensor.data_ptr()); // Verify output values. - testValidate( - fec.fusion(), - {out_tensor}, - {in_tensor}, - {in_tensor.view({2, 12})}, - __LINE__, - __FILE__); + testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); +} + +TEST_F(AliasTest, View_NoAliasForIncompatibleLayout) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const std::vector in_shape({2, 3, 4}); + const std::vector out_shape({2, 12}); + + TensorView* in = makeContigConcreteTensor(in_shape); + fusion->addInput(in); + TensorView* out = reshape(in, in_shape, out_shape); + fusion->addOutput(out); + + // I intentionally set the allocation order to be column major to break the + // alias. + out->setAllocationDomain({out->axis(1), out->axis(0)}, true); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = + at::randn({2, 3, 4}, at::dtype(at::kFloat).device(at::kCUDA, 0)); + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + ASSERT_EQ(out_tensors.size(), 1); + at::Tensor out_tensor = out_tensors[0]; + + // Verify `out_tensor` is not an alias of `in_tensor`. + EXPECT_FALSE(out_tensor.is_alias_of(in_tensor)); + + // Verify output values. + testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); } TEST_F(AliasTest, ViewPermute) { @@ -265,14 +348,244 @@ TEST_F(AliasTest, ViewPermute) { // Verify aliasing. EXPECT_EQ(in_tensor.data_ptr(), out_tensor.data_ptr()); + // Verify output values. + testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); +} + +TEST_F(AliasTest, DuplicateOutputs) { + // testing a complete fusion + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const std::vector in_shape({2, 3, 4}); + + TensorView* in = makeContigConcreteTensor(in_shape); + fusion->addInput(in); + TensorView* out = add(in, IrBuilder::create(3.141)); + fusion->addOutput(out); + fusion->addOutput(out); // duplicated outputs + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = + at::randn(in_shape, at::dtype(at::kFloat).device(at::kCUDA, 0)); + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + ASSERT_EQ(out_tensors.size(), 2); + at::Tensor out_tensor_0 = out_tensors[0]; + at::Tensor out_tensor_1 = out_tensors[1]; + + // Verify aliasing among duplicated outputs + EXPECT_TRUE(out_tensor_0.is_alias_of(out_tensor_1)); + // Verify no segmentation + NVF_CHECK( + !fec.getMostRecentKernelRuntime()->isSegmented(), + "segmentation is not supposed to happen"); + + at::Tensor expected_out_tensor = in_tensor.add(3.141); + // Verify output values. + testValidate( + fec.fusion(), + {expected_out_tensor, expected_out_tensor}, + {in_tensor}, + __LINE__, + __FILE__); +} + +TEST_F(AliasTest, SliceToSizeOne_Issue1353) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({4, 6, 7}); + fusion->addInput(in); + TensorView* out = slice(in, {0, 0, 0}, {4, 6, 1}); + fusion->addOutput(out); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({4, 6, 7}).cuda(); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + EXPECT_EQ(in_tensor.data_ptr(), out_tensor.data_ptr()); + EXPECT_THAT(out_tensor.strides(), ElementsAre(42, 7, _)); + + testValidate( + fec.fusion(), + {in_tensor.slice(/*dim=*/2, /*start=*/c10::nullopt, /*end=*/1)}, + {in_tensor}, + __LINE__, + __FILE__); +} + +TEST_F(AliasTest, SliceRightOfBroadcast) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({4, 1, 7}); + fusion->addInput(in); + TensorView* out = slice(in, {0, 0, 0}, {4, 1, 5}); + fusion->addOutput(out); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({4, 1, 7}).cuda(); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + EXPECT_EQ(in_tensor.data_ptr(), out_tensor.data_ptr()); + EXPECT_THAT(out_tensor.strides(), ElementsAre(7, _, 1)); + + testValidate( + fec.fusion(), + {in_tensor.slice(/*dim=*/2, /*start=*/c10::nullopt, /*end=*/5)}, + {in_tensor}, + __LINE__, + __FILE__); +} + +TEST_F(AliasTest, SliceViewPermute) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + constexpr int batches = 16; + constexpr int seq_length = 128; + constexpr int features = 1024; + constexpr int heads = 16; + + // The input tensor is a concatenation of [query, key, value], and therefore + // has a feature dimension of size `features * 3`. + TensorView* in = + makeContigConcreteTensor({batches, seq_length, features * 3}); + fusion->addInput(in); + std::vector splits({ + slice(in, {0, 0, 0}, {batches, seq_length, features}), + slice(in, {0, 0, features}, {batches, seq_length, features * 2}), + slice(in, {0, 0, features * 2}, {batches, seq_length, features * 3}), + }); + for (TensorView* split : splits) { + split = reshape( + split, + {batches, seq_length, features}, + {batches, seq_length, heads, features / heads}); + split = permute(split, {0, 2, 1, 3}); + fusion->addOutput(split); + } + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({batches, seq_length, features * 3}).cuda(); + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + EXPECT_EQ(out_tensors.size(), 3); + + for (const auto& out_tensor : out_tensors) { + EXPECT_TRUE(out_tensor.is_alias_of(in_tensor)); + } + + std::vector expected_out_tensors = + in_tensor.split(/*split_size=*/features, /*dim=*/-1); + for (auto& expected_out_tensor : expected_out_tensors) { + expected_out_tensor = + expected_out_tensor.view({batches, seq_length, heads, -1}) + .permute({0, 2, 1, 3}); + } + + testValidate( + fec.fusion(), + out_tensors, + {in_tensor}, + expected_out_tensors, + __LINE__, + __FILE__); +} + +TEST_F(AliasTest, DuplicateOutputsSegmentedFusion) { + // testing duplicated output in segmented fusion + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const std::vector in_shape({2, 3, 4}); + + TensorView* in = makeContigConcreteTensor(in_shape); + fusion->addInput(in); + TensorView* intermediate_tv = add(in, IrBuilder::create(3.141)); + TensorView* segment_tv = segment_set(intermediate_tv); + TensorView* out = mul(segment_tv, IrBuilder::create(2.0)); + + fusion->addOutput(intermediate_tv); + fusion->addOutput(intermediate_tv); + fusion->addOutput(out); + fusion->addOutput(out); // duplicated outputs + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = + at::randn(in_shape, at::dtype(at::kFloat).device(at::kCUDA, 0)); + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + ASSERT_EQ(out_tensors.size(), 4); + at::Tensor out_tensor_0 = out_tensors[0]; + at::Tensor out_tensor_1 = out_tensors[1]; + at::Tensor out_tensor_2 = out_tensors[2]; + at::Tensor out_tensor_3 = out_tensors[3]; + + // Verify aliasing among duplicated outputs + EXPECT_TRUE(out_tensor_0.is_alias_of(out_tensor_1)); + EXPECT_TRUE(out_tensor_2.is_alias_of(out_tensor_3)); + // Verify segmentation + NVF_CHECK( + fec.getMostRecentKernelRuntime()->fusionSegments()->groups().size() == 2, + "segmentation didn't happen as expected"); + + at::Tensor intermediate_tensor = in_tensor.add(3.141); + at::Tensor out_tensor = intermediate_tensor.mul(2.0); // Verify output values. testValidate( fec.fusion(), - {out_tensor}, + {intermediate_tensor, intermediate_tensor, out_tensor, out_tensor}, {in_tensor}, - {in_tensor.view({2, 12}).permute({1, 0})}, __LINE__, __FILE__); } +TEST_F(AliasTest, NotAllOutputsAlias) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({2, 3}); + TensorView* slice_out = slice(in, {0, 0}, {2, 2}); + TensorView* add_out = add(in, fusion->oneVal()); + fusion->addInput(in); + fusion->addOutput(slice_out); + fusion->addOutput(add_out); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({2, 3}).cuda(); + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + + // As a known limitation, nvFuser still generates code to copy data from `in` + // to `slice_out` despite the fact that `slice_out` is an alias. + testValidate( + fec.fusion(), + out_tensors, + {in_tensor}, + {in_tensor.slice(/*dim=*/1, /*start=*/0, /*end=*/2), in_tensor + 1.f}, + __LINE__, + __FILE__); + + at::Tensor slice_out_tensor = out_tensors[0]; + EXPECT_TRUE(slice_out_tensor.is_alias_of(in_tensor)); +} + +TEST_F(AliasTest, Set_NoAliasForIncompatibleLayout) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({2, 3, 5}); + fusion->addInput(in); + TensorView* out = set(in); + fusion->addOutput(out); + + // I intentionally set the allocation order to be different to block aliasing. + out->setAllocationDomain({out->axis(1), out->axis(2), out->axis(0)}, true); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({2, 3, 5}).cuda(); + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + ASSERT_EQ(out_tensors.size(), 1); + at::Tensor out_tensor = out_tensors[0]; + + // Verify `out_tensor` is not an alias of `in_tensor`. + EXPECT_FALSE(out_tensor.is_alias_of(in_tensor)); +} + } // namespace nvfuser diff --git a/test/test_allocation_domain.cpp b/test/test_allocation_domain.cpp index 71245855bfd..b758d9b7f9d 100644 --- a/test/test_allocation_domain.cpp +++ b/test/test_allocation_domain.cpp @@ -24,6 +24,8 @@ namespace nvfuser { class AllocationDomainTest : public NVFuserTest {}; +using ::testing::ElementsAre; + // A global->shared->global copy kernel, shared memory allocated transposed to // avoid bank conflict. TEST_F(AllocationDomainTest, TransposedIntermediate) { @@ -59,7 +61,7 @@ TEST_F(AllocationDomainTest, TransposedIntermediate) { FusionExecutor fe; fe.compileFusion(fusion_ptr.get(), {t0}); auto cg_outputs = fe.runFusion({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // A global->global copy kernel converting NCHW memory format into NHWC, with a @@ -103,7 +105,7 @@ TEST_F(AllocationDomainTest, NCHW4d_To_NHWC4d) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // A global->global copy kernel converting NCHW memory format into NHWC, with a @@ -144,7 +146,7 @@ TEST_F(AllocationDomainTest, NCHW4d_To_NHWC1d) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // A global->global copy kernel converting NCHW memory format into NHWC, with a @@ -186,7 +188,7 @@ TEST_F(AllocationDomainTest, NCHW4d_To_NHWC2d) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Reshape and transpose a 3d tensor into an NHWC tensor with a 3d allocation @@ -358,7 +360,7 @@ TEST_F(AllocationDomainTest, NHWC4d_To_NHWC4d) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // A global->global copy kernel where both inputs are NHWC memory format. The @@ -419,7 +421,7 @@ TEST_F(AllocationDomainTest, NHWC1d_To_NHWC4d) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // A global->global copy kernel where both inputs are NHWC memory format. The @@ -475,7 +477,7 @@ TEST_F(AllocationDomainTest, NHWC4d_To_NHWC1d) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // A global->global copy kernel where both inputs are NHWC memory format. The @@ -536,7 +538,7 @@ TEST_F(AllocationDomainTest, NHWC1d_To_NHWC1d) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // A global->global copy kernel where both inputs are NHWC memory format. The @@ -604,7 +606,7 @@ TEST_F(AllocationDomainTest, NHWC2d_To_NHWC2d) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Similar to NHWC4d_To_NHWC4d, but does a cacheBefore @@ -671,7 +673,7 @@ TEST_F(AllocationDomainTest, NHWC4d_To_NHWC4d_cacheBefore) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Similar to NHWC2d_To_NHWC2d, but does a cacheBefore @@ -748,7 +750,7 @@ TEST_F(AllocationDomainTest, NHWC2d_To_NHWC2d_cacheBefore) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Similar to NHWC4d_To_NHWC4d, but does a cacheAfter @@ -815,7 +817,7 @@ TEST_F(AllocationDomainTest, NHWC4d_To_NHWC4d_cacheAfter) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // NOT similar to NHWC2d_To_NHWC2d, because cacheAfter requires the @@ -886,7 +888,7 @@ TEST_F(AllocationDomainTest, NHWC2d_To_NHWC2d_cacheAfter) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Similar to NHWC4d_To_NHWC4d, but does a cacheFork @@ -960,7 +962,7 @@ TEST_F(AllocationDomainTest, NHWC4d_To_NHWC4d_cacheFork) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0, t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Similar to NHWC2d_To_NHWC2d, but does a cacheFork @@ -1050,7 +1052,7 @@ TEST_F(AllocationDomainTest, NHWC2d_To_NHWC2d_cacheFork) { ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); - testValidate(&fusion, cg_outputs, {t0}, {t0, t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } TEST_F(AllocationDomainTest, VectorizationIssue902) { @@ -1139,7 +1141,7 @@ TEST_F(NVFuserTest, AllocationDomainContiguityIssue1021) { auto outputs = fec.runFusionWithInputs({t0}); auto t1 = t0.add(5.0); - testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate(fusion, outputs, {t0}, __LINE__, __FILE__); } TEST_F(NVFuserTest, AllocationDomainContiguityForBroadcast) { @@ -1165,7 +1167,7 @@ TEST_F(NVFuserTest, AllocationDomainContiguityForBroadcast) { auto outputs = fec.runFusionWithInputs({t0}); auto t1 = t0.add(5.0); - testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate(fusion, outputs, {t0}, __LINE__, __FILE__); } TEST_F(NVFuserTest, AllocationDomainContiguityForExplicitBroadcast) { @@ -1192,7 +1194,7 @@ TEST_F(NVFuserTest, AllocationDomainContiguityForExplicitBroadcast) { auto outputs = fec.runFusionWithInputs({t0}); auto t1 = t0.add(5.0); - testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate(fusion, outputs, {t0}, __LINE__, __FILE__); } // Test that allocation domain can be used to vectorize overlapping tensors, @@ -1248,4 +1250,62 @@ TEST_F(AllocationDomainTest, VectorizeOverlappingTensor) { testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } +TEST_F(AllocationDomainTest, Issue1290_ContiguityWasMissing) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = TensorViewBuilder() + .ndims(2) + .dtype(DataType::Float) + .contiguity({false, true}) + .shape({-1, -1}) + .build(); + fusion->addInput(in); + TensorView* out1 = permute(in, {1, 0}); + fusion->addOutput(out1); + TensorView* out2 = add(out1, fusion->oneVal()); + fusion->addOutput(out2); + + at::Tensor in_tensor = at::randn({2 * 4}).cuda().as_strided({2, 3}, {4, 1}); + + FusionExecutorCache fec(std::move(fusion)); + fec.runFusionWithInputs({in_tensor}); + + // The initial issue was detected in the pointwise scheduler, so I added these + // checks to make sure it's a valid regression test. The transpose scheduler + // could accept this but decided not to because of a small problem size. + const std::vector& groups = + fec.getMostRecentKernelRuntime()->fusionSegments()->groups(); + ASSERT_EQ(groups.size(), 1); + SegmentedGroup* group = groups[0]; + EXPECT_EQ(group->heuristic(), ScheduleHeuristic::PointWise); +} + +TEST_F(AllocationDomainTest, Issue1290_ReplayCasPFailedDueToDifferentRanks) { + Fusion fusion; + FusionGuard fg(&fusion); + TensorView* in = makeContigConcreteTensor({2, 3}); + TensorView* out = sum(in, {1}); + fusion.addInput(in); + fusion.addOutput(out); + + out->setAllocationDomain({out->axis(0), out->axis(1)}, true); + out->cacheBefore(); + + at::Tensor in_tensor = at::randn({2, 3}).cuda(); + FusionExecutor fe; + fe.compileFusion(&fusion, {in_tensor}); + at::Tensor out_tensor = fe.runFusion({in_tensor})[0]; + EXPECT_THAT(out_tensor.sizes(), ElementsAre(2)); +} + +TEST_F(AllocationDomainTest, TrivialStrideOrderTensorViewBuilder) { + Fusion fusion; + FusionGuard fg(&fusion); + TensorView* tv0 = TensorViewBuilder().ndims(2).strideOrder({0, 1}).build(); + EXPECT_TRUE(tv0->hasAllocation()); + tv0 = TensorViewBuilder().ndims(2).strideOrder({1, 0}).build(); + EXPECT_TRUE(!tv0->hasAllocation()); +} + } // namespace nvfuser diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 8a06e436ae7..f097afc318e 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -212,9 +212,7 @@ TEST_F(NVFuserTest, DynamicTransform3_CUDA) { FusionExecutorCache fec(std::move(fusion_ptr)); auto cg_outputs = fec.runFusionWithInputs(inputs); - auto ref = t1 + t0.reshape(shape_after); - - testValidate(fec.fusion(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); + testValidate(fec.fusion(), cg_outputs, inputs, __LINE__, __FILE__); } // Test multiple patterns of reshape @@ -698,9 +696,8 @@ TEST_F(NVFuserTest, DynamicTransformFusionExecutorCache_CUDA) { auto t1 = at::randn({3, 4}, options); std::vector inputs = {t0, t1}; auto cg_outputs = executor_cache.runFusionWithInputs(inputs); - auto ref = t0 + t1; testValidate( - executor_cache.fusion(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); + executor_cache.fusion(), cg_outputs, inputs, __LINE__, __FILE__); NVF_CHECK( executor_cache.countRuntimes() == 1, "Expect to create a single runtime"); @@ -710,9 +707,8 @@ TEST_F(NVFuserTest, DynamicTransformFusionExecutorCache_CUDA) { auto t1 = at::randn({4, 3}, options); std::vector inputs = {t0, t1}; auto cg_outputs = executor_cache.runFusionWithInputs(inputs); - auto ref = t0.view({4, 3}) + t1; testValidate( - executor_cache.fusion(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); + executor_cache.fusion(), cg_outputs, inputs, __LINE__, __FILE__); auto num_rts = executor_cache.countRuntimes(); auto num_concs = executor_cache.countConcretizations(); NVF_CHECK(num_rts == 2, "Non-trivial reshape should create new runtime"); @@ -725,9 +721,8 @@ TEST_F(NVFuserTest, DynamicTransformFusionExecutorCache_CUDA) { auto t1 = at::randn({4, 3}, options); std::vector inputs = {t0, t1}; auto cg_outputs = executor_cache.runFusionWithInputs(inputs); - auto ref = t0.view({4, 3}) + t1; testValidate( - executor_cache.fusion(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); + executor_cache.fusion(), cg_outputs, inputs, __LINE__, __FILE__); auto num_rts = executor_cache.countRuntimes(); auto num_concs = executor_cache.countConcretizations(); NVF_CHECK( @@ -844,11 +839,8 @@ void reductionDynamicViewAddFusion( auto at_tv1 = (reshape_before_reduction) ? (at_x + at_bias) : at::sum(at_x, kReductionAxis); auto at_x_reshape = at::native::view(at_tv1, output_shape); - auto at_y = (reshape_before_reduction) - ? at::sum(at_x_reshape, kReductionAxis) - : at::add(at_x_reshape, at_bias); - testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } } @@ -952,7 +944,7 @@ void reductionDynamicPadAddFusion( auto at_x_pad = at::pad(at_x, pad_widths); auto at_y = at::sum(at_x_pad, kReductionAxis); - testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } } #undef CHECK_CACHE @@ -1018,9 +1010,7 @@ TEST_F(NVFuserTest, FusionDynamicSliceToBroadcast_CUDA) { at::Tensor at0 = at::randn({5}, options); std::vector aten_inputs = {at0}; auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); - auto at1 = at::slice(at0, 0, 0, 2); - auto at2 = at::slice(at1, 0, 0, 1); - testValidate(&fusion, outputs, aten_inputs, {at2}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // Test that empty input to cat is concretized away @@ -1048,8 +1038,7 @@ TEST_F(NVFuserTest, FusionDynamicEmptyCat1_CUDA) { at::Tensor at2 = at::randn({3}, options); std::vector aten_inputs = {at0, at1, at2}; auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); - auto at3 = at::cat({at0, at1, at2}, 0); - testValidate(&fusion, outputs, aten_inputs, {at3}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // Test that empty input to cat is concretized away @@ -1074,8 +1063,7 @@ TEST_F(NVFuserTest, FusionDynamicEmptyCat2_CUDA) { at::Tensor at1 = at::randn({0}, options); std::vector aten_inputs = {at0, at1}; auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); - auto at2 = at::cat({at0, at1}, 0); - testValidate(&fusion, outputs, aten_inputs, {at2}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); // Check that fusion consists only of tv2 = set(tv0) auto fkr = fusion_executor_cache.getMostRecentKernelRuntime(); @@ -1144,18 +1132,11 @@ TEST_F(NVFuserTest, Issue249_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor at_x = at::randn({2, 3, 4, 5}, options); - auto at_y = (at_x + at_x).reshape({2, 4, -1}); - auto at_z = at_y + at_y; auto outputs = fusion_executor_cache.runFusionWithInputs({at_x}); testValidate( - fusion_executor_cache.fusion(), - outputs, - {at_x}, - {at_z}, - __LINE__, - __FILE__); + fusion_executor_cache.fusion(), outputs, {at_x}, __LINE__, __FILE__); } // This is just like the test above, but uses an input scalar with value -1 @@ -1183,8 +1164,6 @@ TEST_F(NVFuserTest, Issue249InputNegative1_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor at_x = at::randn({2, 3, 4, 5}, options); - auto at_y = (at_x + at_x).reshape({2, 4, -1}); - auto at_z = at_y + at_y; // Dynamic reshape sizes that are not constant at definition must be explicit: // no -1 allowed @@ -1199,7 +1178,6 @@ TEST_F(NVFuserTest, Issue249InputNegative1_CUDA) { fusion_executor_cache.fusion(), outputs, {at_x, 2, 4, 15}, - {at_z}, __LINE__, __FILE__); } @@ -1244,7 +1222,7 @@ TEST_F(NVFuserTest, OptOutMutatorMutatedOutput) { auto outputs = fe.runFusion({t0}); - testValidate(fusion, outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(fusion, outputs, {t0}, __LINE__, __FILE__); } // Another test related to https://github.com/NVIDIA/Fuser/issues/852 @@ -1274,9 +1252,6 @@ TEST_F(NVFuserTest, OptOutMutatorRedefinedConstant) { c->definition(), nullptr); // Replacement value should not be redefined EXPECT_EQ(tv0->definition()->as()->getFillValue(), c); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::full({2}, 1L, options); - inlineMost(); FusionExecutor fe; @@ -1284,7 +1259,48 @@ TEST_F(NVFuserTest, OptOutMutatorRedefinedConstant) { auto outputs = fe.runFusion({3L}); - testValidate(fusion, outputs, {3L}, {t0}, __LINE__, __FILE__); + testValidate(fusion, outputs, {3L}, __LINE__, __FILE__); +} + +// Test that we can squeeze Symbolic IterDomains and that we properly detect +// improper concretizations where we have squeezed a dimension with extent +// other than 1. +// See https://github.com/NVIDIA/Fuser/issues/1273 +TEST_F(NVFuserTest, SymbolicSqueeze) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + auto s0 = IrBuilder::create(DataType::Index); + auto s1 = IrBuilder::create(DataType::Index); + fusion->addInput(tv0); + fusion->addInput(s0); + fusion->addInput(s1); + + auto tv1 = reshape(tv0, {s0, s1}); + auto tv2 = squeeze( + tv1, std::vector({false, true})); // Squeeze second dimension + fusion->addOutput(tv2); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({3, 2}, options); + std::vector valid_inputs = {t0, 6, 1}; + // An invalid input has a second dimension that cannot be squeezed + std::vector invalid_inputs = {t0, 2, 3}; + + auto outputs = fec.runFusionWithInputs(valid_inputs); + + testValidate(fusion, outputs, valid_inputs, __LINE__, __FILE__); + + // An informative error message should be given by + // SqueezeOp::checkConcretization + EXPECT_THAT( + [&]() { fec.runFusionWithInputs(invalid_inputs); }, + ::testing::ThrowsMessage(::testing::HasSubstr( + " must concretize to IterType::Broadcast but found"))); } } // namespace nvfuser diff --git a/test/test_external_src.cpp b/test/test_external_src.cpp index 16310c9ef53..9b7fbf3f016 100644 --- a/test/test_external_src.cpp +++ b/test/test_external_src.cpp @@ -116,7 +116,7 @@ TEST_F(ExternalSrcExample, Matmul_CUDA) { fe.compileRtc(cuda_src_str, "kernel1", true, PrimDataType::Int32); int M = 2048, N = 3456, K = 2048; - MmaOptions::MmaLayout layout = MmaOptions::MmaLayout::TN; + MmaLayout layout = MmaLayout::TN; auto inputs = matmulAtInput(M, N, K, layout); auto at_output = atMatmul(inputs.first, inputs.second, layout).to(at::kFloat); diff --git a/test/test_fusion_profiler.cpp b/test/test_fusion_profiler.cpp index b7847fe996c..ec7d0ae5d1a 100644 --- a/test/test_fusion_profiler.cpp +++ b/test/test_fusion_profiler.cpp @@ -21,7 +21,21 @@ namespace nvfuser { -class FusionProfilerTest : public NVFuserTest {}; +class FusionProfilerTest : public NVFuserTest { + protected: + void SetUp() override { + NVFuserTest::SetUp(); + saved_ = ProfilerOptionsGuard::getCurOptions(); + } + + void TearDown() override { + ProfilerOptionsGuard::getCurOptions() = saved_; + NVFuserTest::TearDown(); + } + + private: + Options saved_; +}; // RUN CMD: bin/nvfuser_tests --gtest_filter="*Profile1Segment*" TEST_F(FusionProfilerTest, Profile1Segment) { diff --git a/test/test_gather.cpp b/test/test_gather.cpp index 6f206899675..8128c8f2a45 100644 --- a/test/test_gather.cpp +++ b/test/test_gather.cpp @@ -286,18 +286,12 @@ TEST_F(IndexingOpTest, TorchGatherSumAdd_CUDA) { at::Tensor input2 = at::randn(input2_dims, options); // lookup at::Tensor input_idx = at::randint(0, input_dims[dim], index_dims, options_i); - at::Tensor output = at::zeros(index_dims, options); - - auto t_gather = at::gather(input, dim, input_idx); - auto t_sum = at::sum(t_gather.to(at::kDouble), {0}, true); - auto tv_out_ref = at::add(input2.to(at::kDouble), t_sum); std::vector aten_inputs = {input, input_idx, input2}; FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, {tv_out_ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } } } @@ -448,11 +442,7 @@ TEST_F(IndexingOpTest, TakeAlongBroadcastIndex_CUDA) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t4 = at::take_along_dim( - t0, t1.unsqueeze(0).unsqueeze(-1).expand(out_dims), 1); - auto ref = t4 + t2; - - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } } @@ -510,12 +500,7 @@ TEST_F(IndexingOpTest, GatherBroadcastInput_CUDA) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t4 = is_take_along ? at::take_along_dim(t0, t1.unsqueeze(-1), 1) - : at::gather(t0, 1, t1.unsqueeze(-1)); - auto ref = t4 + t2; - - testValidate( - &fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } } } @@ -604,8 +589,6 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorPointwise1_CUDA) { auto outputs = fe.runFusion(aten_inputs); - auto ref = at::take_along_dim(t0 + 1, t1.unsqueeze(-1), 1); - testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } @@ -642,8 +625,6 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorPointwise2_CUDA) { validateSegmentation( fec.getMostRecentKernelRuntime(), {ScheduleHeuristic::PointWise}); - auto ref = at::take_along_dim(t0 + 1, t1.unsqueeze(-1), 1); - testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } @@ -679,9 +660,7 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorReduction1_CUDA) { fec.getMostRecentKernelRuntime(), {ScheduleHeuristic::Reduction, ScheduleHeuristic::PointWise}); - auto ref = at::take_along_dim(t0.to(at::kDouble).sum({1}), t1, 0); - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // take_along_axis to broadcast, squeeze, then reduction. Segmented @@ -721,10 +700,7 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorReduction2_CUDA) { fec.getMostRecentKernelRuntime(), {ScheduleHeuristic::PointWise, ScheduleHeuristic::Reduction}); - auto t4 = at::take_along_dim(t0.to(at::kDouble) + 1, t1.unsqueeze(-1), 1); - auto ref = t4.squeeze(1).sum({0}); - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // take_along_axis then reduction. Should not be segmented. @@ -762,9 +738,7 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorReduction3_CUDA) { validateSegmentation( fec.getMostRecentKernelRuntime(), {ScheduleHeuristic::Reduction}); - auto ref = at::take_along_dim(t0.to(at::kDouble) + 1, t1, 1).sum({1}); - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // Similar to TakeAlongAxisIntermediateTensorReduction2, but no @@ -805,10 +779,7 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorReduction4_CUDA) { validateSegmentation( fec.getMostRecentKernelRuntime(), {ScheduleHeuristic::Reduction}); - auto ref = - at::take_along_dim(t0.to(at::kDouble) + 1, t1.unsqueeze(-1), 1).sum({0}); - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // Normalization then take_along_axis @@ -1074,9 +1045,7 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorTranspose1_CUDA) { validateSegmentation( fec.getMostRecentKernelRuntime(), {ScheduleHeuristic::Transpose}); - auto ref = at::take_along_dim(t0 + 1, t1.unsqueeze(0), 0).transpose(1, 2); - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // transpose then take_along_axis. Currently failed to pick the @@ -1119,9 +1088,7 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorTranspose2_CUDA) { validateSegmentation( fec.getMostRecentKernelRuntime(), {ScheduleHeuristic::PointWise}); - auto ref = at::take_along_dim(t0.transpose(1, 2), t1, 0); - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // transpose the dimension produced by take_along_axis. Currently not @@ -1165,9 +1132,7 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorTranspose3_CUDA) { validateSegmentation( fec.getMostRecentKernelRuntime(), {ScheduleHeuristic::PointWise}); - auto ref = at::take_along_dim(t0 + 1, t1.unsqueeze(0), 2).transpose(1, 2); - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(IndexingOpTest, TakeAlongAxisCrossEntropyLoss_CUDA) { diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp index ce3a6c12c6a..ff2fc44d4e5 100644 --- a/test/test_gpu1.cpp +++ b/test/test_gpu1.cpp @@ -90,12 +90,16 @@ TEST_F(NVFuserTest, FusionIrGraphGenerator_CUDA) { fusion.addOutput(tv6); - tv4->axis(2)->parallelize(ParallelType::BIDy); + tv5->reorder({{-1, 0}}); tv6->merge(0); tv6->split(0, 4); + TransformPropagatorWithCheck propagator(tv6); + MaxRootDomainInfoSpanningTree(tv6).traverse(&propagator); + + tv4->axis(2)->parallelize(ParallelType::BIDy); tv6->axis(0)->parallelize(ParallelType::BIDx); - tv5->reorder({{-1, 0}}); - tv2->computeAt(tv6, 1); + + inlineMost(); // Another checkpoint with more node types NVF_CHECK(!IrGraphGenerator::toGraphviz( @@ -149,12 +153,14 @@ TEST_F(NVFuserTest, FusionClear_CUDA) { fusion.addOutput(tv3); tv3->split(0, 4); - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::Unroll); tv3->axis(-1)->parallelize(ParallelType::TIDx); + + inlineMost(); } // 2. Clear the IR @@ -188,9 +194,12 @@ TEST_F(NVFuserTest, FusionClear_CUDA) { // tv3 [i2, i1, i0outer, i0inner{4}] tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); // tv3 [i0outer, i0inner{4}, i1, i2] - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + tv3->axis(1)->parallelize(ParallelType::BIDx); + + inlineMost(); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -227,12 +236,13 @@ TEST_F(NVFuserTest, FusionCopy_CUDA) { tv3->reorder({{0, 2}, {2, 0}}); tv3->split(-1, 4); tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); - - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); + + inlineMost(); } // Test copy before lowering @@ -301,12 +311,13 @@ TEST_F(NVFuserTest, FusionMove_CUDA) { tv3->reorder({{0, 2}, {2, 0}}); tv3->split(-1, 4); tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); - - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); + + inlineMost(); } std::stringstream original_ir; @@ -869,8 +880,10 @@ TEST_F(NVFuserTest, FusionOuterSplit_CUDA) { //[I0*I1*I2o{4}o, I0*I1*I2o{4}i{2}, I2i] tv2->reorder({{0, 1}, {1, 0}}); // I0*I1*I2o{4}i{2}, [I0*I1*I2o{4}o, I2i] + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - tv0->computeAt(tv2, -1); + inlineMost(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -909,8 +922,10 @@ TEST_F(NVFuserTest, FusionCodeGen_CUDA) { //[I0o, I0i{4}*I1, I2o, I2i{2}] tv2 = tv2->reorder({{0, 1}, {1, 0}, {3, 2}}); //[I0i{4}*I1, I0o, I2i{2}, I2o] + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - tv0->computeAt(tv2, -1); + inlineMost(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -945,13 +960,14 @@ TEST_F(NVFuserTest, FusionCodeGen2_CUDA) { //[I2, I1, I0o, I0i{4}] tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); // I0o, I0i{4}, I1, I2] - - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); + inlineMost(); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({16, 8, 8}, options); @@ -997,17 +1013,16 @@ TEST_F(NVFuserTest, FusionSimplePWise_CUDA) { // Split by n_threads tv3->split(0, 128); tv3->split(0, 4); - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); // Parallelize TV3 tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-2)->parallelize(ParallelType::Unroll); tv3->axis(-1)->parallelize(ParallelType::TIDx); + inlineMost(); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({64, 2, 128}, options); @@ -1055,17 +1070,16 @@ TEST_F(NVFuserTest, FusionSimplePWiseDtypeComplex_CUDA) { // Split by n_threads tv3->split(0, 128); tv3->split(0, 4); - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); // Parallelize TV3 tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-2)->parallelize(ParallelType::Unroll); tv3->axis(-1)->parallelize(ParallelType::TIDx); + inlineMost(); + auto options = at::TensorOptions().dtype(at::kComplexFloat).device(at::kCUDA, 0); @@ -1106,11 +1120,8 @@ TEST_F(NVFuserTest, FusionExecKernel_CUDA) { tv3->merge(0); tv3->split(0, 128); tv3->split(0, 4); - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); // Parallelize TV3 tv3->axis(0)->parallelize(ParallelType::BIDx); @@ -1119,6 +1130,8 @@ TEST_F(NVFuserTest, FusionExecKernel_CUDA) { tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); + inlineMost(); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::ones({1, 128}, options); diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index e4e40a461f6..678b7dc1bc4 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -683,19 +683,8 @@ TEST_F(NVFuserTest, FusionLSTMCell_CUDA) { aten_inputs.insert(aten_inputs.end(), chunked2.begin(), chunked2.end()); aten_inputs.insert(aten_inputs.end(), chunked3.begin(), chunked3.end()); - auto at_ingate = - chunked0[0].add(chunked0[1]).add(chunked0[2]).add(chunked0[3]).sigmoid(); - auto at_forgetgate = - chunked1[0].add(chunked1[1]).add(chunked1[2]).add(chunked1[3]).sigmoid(); - auto at_cellgate = - chunked2[0].add(chunked2[1]).add(chunked2[2]).add(chunked2[3]).tanh(); - auto at_outgate = - chunked3[0].add(chunked3[1]).add(chunked3[2]).add(chunked3[3]).sigmoid(); - auto at_cx = at::randn({batch_size, hidden_features}, options); aten_inputs.push_back(at_cx); - auto at_cy = at_forgetgate.mul(at_cx).add(at_ingate.mul(at_cellgate)); - auto at_hy = at_outgate.mul(at_cy.tanh()); auto lparams = schedulePointwise(&fusion, aten_inputs); @@ -703,8 +692,7 @@ TEST_F(NVFuserTest, FusionLSTMCell_CUDA) { fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); - testValidate( - &fusion, cg_outputs, aten_inputs, {at_cy, at_hy}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionReductionHalf_CUDA) { @@ -1219,11 +1207,6 @@ TEST_F(NVFuserTest, FusionBiasGeluFwd_CUDA) { auto at_input = at::randn(input_shape, options); auto at_bias = at::randn(bias_shape, options); - auto at_x = at_bias.to(c10::ScalarType::Double) + - at_input.to(c10::ScalarType::Double); - auto aten_output_double = - at_x * 0.5 * (1.0 + (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh()); - std::vector aten_inputs = {at_bias, at_input}; auto lparams = schedulePointwise(&fusion, aten_inputs); @@ -1231,13 +1214,7 @@ TEST_F(NVFuserTest, FusionBiasGeluFwd_CUDA) { fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); - testValidate( - &fusion, - cg_outputs, - aten_inputs, - {aten_output_double}, - __LINE__, - __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { @@ -1300,16 +1277,7 @@ TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { auto at_bias = at::randn(bias_shape, options); auto at_grad = at::randn(input_shape, options); - auto at_x = at_bias.to(c10::ScalarType::Double) + - at_input.to(c10::ScalarType::Double); - auto at_tanh_out = (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh(); - auto at_ff = 0.5 * at_x * - ((1 - at_tanh_out * at_tanh_out) * (k_079 + k_010 * at_x * at_x)) + - 0.5 * (1 + at_tanh_out); - auto at_out = at_ff * at_grad; - std::vector aten_inputs = {at_grad, at_bias, at_input}; - std::vector aten_outputs = {at_out, at_out}; auto lparams = schedulePointwise(&fusion, aten_inputs); @@ -1325,7 +1293,6 @@ TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { &fusion, cg_outputs, aten_inputs, - aten_outputs, __LINE__, __FILE__, "", @@ -3567,10 +3534,6 @@ TEST_F(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { auto outputs = executor_cache.runFusionWithInputs({at_x}); - auto t1 = at_x.add(1.0); - auto t2 = t1.sum({2}); - auto t3 = at::_softmax(t2.to(at::kDouble), -1, false); - auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); ASSERT_TRUE(optimized_fusion->isSegmented()) << "segmentation didn't happen"; ASSERT_EQ(optimized_fusion->fusionSegments()->groups().size(), 2) @@ -3588,8 +3551,7 @@ TEST_F(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { ASSERT_EQ(rparams->unroll_factor_inner_reduction, 2) << "Unexpected vectorization factor"; - testValidate( - executor_cache.fusion(), outputs, {at_x}, {t3}, __LINE__, __FILE__); + testValidate(executor_cache.fusion(), outputs, {at_x}, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionGridPersistence_CUDA) { @@ -5398,12 +5360,7 @@ TEST_F(NVFuserTest, FusionSBAR_CUDA) { executor.compileFusion(&fusion, inputs, lparams); outputs = executor.runFusion(inputs, lparams); - auto at_scale = at::mul(at_x, at_weight); - auto at_scale_bias = at::add(at_scale, at_bias); - auto pwise_add = at::add(at_scale_bias, at_y); - auto output = at::relu(pwise_add); - - testValidate(&fusion, outputs, inputs, {output}, __LINE__, __FILE__); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionSingleElement_CUDA) { @@ -8263,11 +8220,7 @@ TEST_F(NVFuserTest, FusionPointwiseBroadcast_CUDA) { fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); - auto at_x_add_bias = at_x + at_bias; - auto at_x_view = at::native::view(at_x_add_bias, output_shape); - auto aten_y = at::gelu(at_x_view); - - testValidate(&fusion, outputs, aten_inputs, {aten_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionPointwiseVectorize_CUDA) { diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 3569ffba4fa..a2677aa0f78 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -5124,161 +5124,6 @@ TEST_F(NVFuserTest, FusionReplayTrivialReductionAndBroadcast2_CUDA) { testValidate(&fusion, outputs, aten_inputs, {t0 + 1}, __LINE__, __FILE__); } -namespace { - -size_t getVecSizeForPointwise(FusionExecutorCache& fec) { - auto most_recent_params = - fec.getMostRecentKernelRuntime()->getMostRecentExecutorLog().params; - auto params = dynamic_cast(most_recent_params.get()); - if (params->vectorize) { - return params->unroll_factor; - } - return 1; -} - -} // namespace - -TEST_F(NVFuserTest, FusionVectorizeStrideContiguity2D_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = - TensorViewBuilder().ndims(2).contiguity({false, true}).build(); - fusion->addInput(tv0); - auto tv1 = set(tv0); - fusion->addOutput(tv1); - - FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); - - std::vector> size_and_vec{{17, 1}, {18, 2}, {32, 4}}; - - for (auto pair : size_and_vec) { - auto size = pair.first; - auto vec = pair.second; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({1000000, size}, options).narrow(1, 0, 16); - auto cg_outputs = fec.runFusionWithInputs({t0}); - - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); - - testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); - } -} - -TEST_F(NVFuserTest, FusionVectorizeStrideContiguity3D_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = - TensorViewBuilder().ndims(3).contiguity({false, true, true}).build(); - fusion->addInput(tv0); - auto tv1 = set(tv0); - fusion->addOutput(tv1); - - FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); - - std::vector> size_and_vec{{17, 1}, {10, 2}, {16, 4}}; - - for (auto pair : size_and_vec) { - auto size = pair.first; - auto vec = pair.second; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({1000000, size, 3}, options).narrow(1, 0, 8); - auto cg_outputs = fec.runFusionWithInputs({t0}); - - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); - - testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); - } -} - -TEST_F(NVFuserTest, FusionVectorizeStrideContiguity5D_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = TensorViewBuilder() - .ndims(5) - .contiguity({false, true, false, true, true}) - .build(); - fusion->addInput(tv0); - auto tv1 = set(tv0); - fusion->addOutput(tv1); - - FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - std::vector> sizes_and_vec{ - {9, 17, 1}, {9, 10, 2}, {9, 16, 4}}; - - for (auto tup : sizes_and_vec) { - auto size1 = std::get<0>(tup); - auto size2 = std::get<1>(tup); - auto vec = std::get<2>(tup); - at::Tensor t0 = at::randn({4, size1, 12345, size2, 3}, options) - .narrow(1, 0, 8) - .narrow(3, 0, 4); - auto cg_outputs = fec.runFusionWithInputs({t0}); - - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); - - testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); - } -} - -TEST_F(NVFuserTest, FusionVectorizeStrideContiguitySelfOverlapping_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = TensorViewBuilder() - .ndims(5) - .contiguity({false, true, false, true, true}) - .build(); - fusion->addInput(tv0); - auto tv1 = set(tv0); - fusion->addOutput(tv1); - - FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - std::vector> sizes_strides_and_vec{ - {4, 4, 4, 4}, - {4, 4, 2, 2}, - {4, 2, 4, 2}, - {2, 4, 4, 2}, - {4, 4, 1, 1}, - {4, 1, 4, 1}, - {1, 4, 4, 1}, - {2, 2, 2, 2}, - {2, 2, 1, 1}, - {2, 1, 2, 1}, - {1, 2, 2, 1}}; - - for (auto tup : sizes_strides_and_vec) { - auto size = std::get<0>(tup); - auto stride1 = std::get<1>(tup); - auto stride2 = std::get<2>(tup); - auto vec = std::get<3>(tup); - std::vector shape = {4, 4, 12345, size, 3}; - std::vector stride = { - stride1, (int64_t)stride2 * 12345, (int64_t)stride2, 3, 1}; - at::Tensor t0 = at::empty_strided(shape, stride, options); - t0.random_(); - auto cg_outputs = fec.runFusionWithInputs({t0}); - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); - testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); - } -} - TEST_F(NVFuserTest, FusionSimpleAmperePipeline_CUDA) { Fusion fusion; FusionGuard fg(&fusion); diff --git a/test/test_gpu_compute_with.cpp b/test/test_gpu_compute_with.cpp index 3c2f9ecce4e..b06c7b552b6 100644 --- a/test/test_gpu_compute_with.cpp +++ b/test/test_gpu_compute_with.cpp @@ -168,9 +168,7 @@ TEST_F(NVFuserTest, FusionComputeWith1_CUDA) { fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - auto ref = t0.sum({1}).unsqueeze(-1) + t0; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // StoreAt with 1D softmax @@ -267,9 +265,7 @@ TEST_F(NVFuserTest, FusionComputeWith3_CUDA) { fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - auto ref = t0.unsqueeze(0); - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Compute a tensor that has siblings with a consumer. All of the diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index a55e68dadec..79396971c26 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -68,16 +68,12 @@ TEST_F(NVFuserTest, FusionIndexing1_CUDA) { at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t1 = at::randn({w, x, y, z}, options); - auto t3 = t0.add(1.0); - auto aten_output = t3.add(t1); - std::vector aten_inputs = {t0, t1}; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // Same as 1 but merge starting from inner most dimension @@ -123,16 +119,12 @@ TEST_F(NVFuserTest, FusionIndexing2_CUDA) { at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t1 = at::randn({w, x, y, z}, options); - auto t3 = t0.add(1.0); - auto aten_output = t3.add(t1); - std::vector aten_inputs = {t0, t1}; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // Same compute as 1 and 2 but use a scheduler. @@ -155,9 +147,6 @@ TEST_F(NVFuserTest, FusionIndexing3_CUDA) { at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t1 = at::randn({w, x, y, z}, options); - auto t2 = t0.add(1.0); - auto aten_output = t2.add(t1); - std::vector aten_inputs = {t0, t1}; auto lparams = schedulePointwise(&fusion, aten_inputs); @@ -166,8 +155,7 @@ TEST_F(NVFuserTest, FusionIndexing3_CUDA) { fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // Same as 3 but use 3 dimensions and concrete sizes @@ -190,17 +178,13 @@ TEST_F(NVFuserTest, FusionIndexing4_CUDA) { at::Tensor t0 = at::randn({4, 8}, options); at::Tensor t1 = at::randn({4, 4, 8}, options); - auto t2 = t0.add(1.0); - auto aten_output = t2.add(t1); - std::vector aten_inputs = {t0, t1}; FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionIndexing5_CUDA) { @@ -228,17 +212,13 @@ TEST_F(NVFuserTest, FusionIndexing5_CUDA) { at::Tensor t0 = at::randn({7}, options); at::Tensor t1 = at::randn({5, 7, 11}, options); - auto t2 = t0.add(1.0); - auto aten_output = t2.unsqueeze(-1).add(t1); - std::vector aten_inputs = {t0, t1}; FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionIndexing6_CUDA) { @@ -411,13 +391,7 @@ TEST_F(NVFuserTest, FusionIndexing9_CUDA) { fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); - auto at_t1 = at_t0.unsqueeze(-1); - auto at_t2 = at_t1.mul(2.0); - - auto at_t4 = at_t3.add(at_t2); - - testValidate( - &fusion, cg_outputs, aten_inputs, {at_t2, at_t4}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionIndexing10_CUDA) { @@ -521,16 +495,12 @@ TEST_F(NVFuserTest, FusionIndexing11_CUDA) { at::Tensor t0 = at::randn({w, x, y, z}, options); at::Tensor t1 = at::randn({x}, options); - auto t3 = t1.add(1.0).unsqueeze(-1).unsqueeze(-1); - auto aten_output = t3.add(t0); - std::vector aten_inputs = {t0, t1}; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionIndexing12_CUDA) { @@ -610,20 +580,13 @@ TEST_F(NVFuserTest, FusionIndexing13_CUDA) { at::Tensor t1 = at::randn({y, z}, options); at::Tensor t2 = at::randn({x, y, z}, options); - auto t3 = t0.add(1.0); - auto t4 = t3.unsqueeze(-1); - auto t5 = t4.add(t1); - auto t6 = t5.add(t2); - std::vector aten_inputs = {t0, t1, t2}; - std::vector aten_outputs = {t6}; FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionIndexing14_CUDA) { @@ -659,18 +622,13 @@ TEST_F(NVFuserTest, FusionIndexing14_CUDA) { at::Tensor t0 = at::randn({1, y}, options); at::Tensor t1 = at::randn({x, y}, options); - auto t4 = t0 + 2 + 4; - auto t5 = t0 + 2 + t1 + 3; - std::vector aten_inputs = {t0, t1}; - std::vector aten_outputs = {t4, t5}; FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // This excercises indexing with broadcast root axes. Non-broadcast @@ -705,11 +663,7 @@ TEST_F(NVFuserTest, FusionIndexing15_CUDA) { fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - auto aten_output = - t0.unsqueeze(-1).expand({bx, by}).unsqueeze(-1).expand({bx, by, bz}) + t3; - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionIndexing16_CUDA) { @@ -734,18 +688,14 @@ TEST_F(NVFuserTest, FusionIndexing16_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({5, 4, 3}, options); at::Tensor t1 = at::randn({5, 3}, options); - auto t2 = t1.unsqueeze(1); - auto t3 = t0 + t2; std::vector aten_inputs = {t0, t1}; - std::vector aten_outputs = {t3}; FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(NVFuserTest, FusionIndexing17_CUDA) { @@ -774,24 +724,13 @@ TEST_F(NVFuserTest, FusionIndexing17_CUDA) { at::Tensor t0 = at::randn({5, 4, 3}, options); at::Tensor t1 = at::randn({4}, options); - auto t2 = t0; - auto t3 = t1; - - std::vector reduction_axes{0, 2}; - auto t4 = t2.sum(reduction_axes); - auto t5 = add(t4, t3); - auto t6 = t3.unsqueeze(0).unsqueeze(-1); - auto t7 = t2.add(t6); - std::vector aten_inputs = {t0, t1}; - std::vector aten_outputs = {t5, t7}; FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // TODO: Finish and enable test diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index ecaa866412f..07b1c29ee2d 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -5,6 +5,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include + #include #include @@ -43,7 +45,6 @@ #include #include #include -#include #include #include @@ -54,337 +55,12 @@ namespace nvfuser { using namespace at::indexing; -// MMA unit test for a single instruction tile. VoltaTT -TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [M, K] - auto tv0 = makeConcreteTensor({16, 4}, DataType::Half); - // [K, N] - auto tv1 = makeConcreteTensor({4, 16}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M, K, N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - - fusion.addOutput(tv2); - - // TODO: should be able to completely remove it - // in a follow up. - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 16, 4); - gemm_tile.warp_tile = GemmTile(16, 16, 4); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaLayout::TT); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - // Write A to smem - auto tv0cw = tv0b->cacheAfter(); - // Read A from smem - auto tv0cr = tv0cw->cacheAfter(); - - // Write B to smem - auto tv1cw = tv1b->cacheAfter(); - - // Read B from smem - auto tv1cr = tv1cw->cacheAfter(); - - // Register accumulator - auto tv2c = tv2->cacheBefore(); - - mma_builder.accumulatorTv(tv2c); - - // [M, K, N]->[M, N, K] - tv0cr->reorder({{-2, -1}, {-1, -2}}); - - // Schedule the instruction tile loops, which is the only - // part we have in this unit test. - // Assumes last 3 dims are mnk - // The innermost loops are dictated by the type of mma used, - // the scheduler needs to use mma_utils::WarpMmaSwizzler to - // get the right thread swizzle. Currently this is the only - // method allowed to schedule the 3/2 inner most loops of - // mma input/output. - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - - // [M, K, N]->[M, N, K] - tv1cr->reorder({{-2, -1}, {-1, -2}}); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - - // [M, K, N]->[M, N, K] - tv2c->reorder({{-2, -1}, {-1, -2}}); - - // Schedule the output instruction tile. - // Assumes last 3 dims are mnk - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - // Set memory type. - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 4}, options); - auto t1 = at::randn({4, 16}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); - - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -// MMA unit test for a single instruction tile. VoltaTN -TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [M, K] - auto tv0 = makeConcreteTensor({16, 4}, DataType::Half); - // [N, K] - auto tv1 = makeConcreteTensor({16, 4}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M, N, K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - // TODO: should be able to completely remove it - // in a follow up. - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 16, 4); - gemm_tile.warp_tile = GemmTile(16, 16, 4); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(); - auto tv2c = tv2->cacheBefore(); - - mma_builder.accumulatorTv(tv2c); - - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 4}, options); - auto t1 = at::randn({16, 4}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -// MMA unit test for a single instruction tile. VoltaNT -TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [K, M] - auto tv0 = makeConcreteTensor({4, 16}, DataType::Half); - // [K, N] - auto tv1 = makeConcreteTensor({4, 16}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [K, M, N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 16, 4); - gemm_tile.warp_tile = GemmTile(16, 16, 4); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaLayout::NT); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(); - auto tv2c = tv2->cacheBefore(); - - mma_builder.accumulatorTv(tv2c); - - // To MNK - tv0cr->reorder({{0, 2}, {1, 0}, {2, 1}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - - // To MNK - tv1cr->reorder({{0, 2}, {1, 0}, {2, 1}}); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - - tv2c->reorder({{0, 2}, {1, 0}, {2, 1}}); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({4, 16}, options); - auto t1 = at::randn({4, 16}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVoltaMMANN_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [K, M] - auto tv0 = makeConcreteTensor({4, 16}, DataType::Half); - // [N, K] - auto tv1 = makeConcreteTensor({16, 4}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [N, K, M] - auto tv0b = broadcast(tv0, {true, false, false}); - auto tv1b = broadcast(tv1, {false, false, true}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - - // Add implicit permute N, K, M -> M, N, K - tv2->reorder({{-1, 0}}); - tv2->commitLeafToRFactor(); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 16, 4); - gemm_tile.warp_tile = GemmTile(16, 16, 4); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaLayout::NN); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(); - auto tv2c = tv2->cacheBefore(); - - mma_builder.accumulatorTv(tv2c); - - // To MNK - tv0cr->reorder({{-1, 0}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - - // To MNK - tv1cr->reorder({{-1, 0}}); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({4, 16}, options); - auto t1 = at::randn({16, 4}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.t().to(at::kFloat).matmul(t1.t().to(at::kFloat)); - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -// Matmul test for Volta MMA: across supported layouts -TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) { +// Matmul test for Ampere MMA: across supported layouts +TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { // Keep multiples of 8 to keep vectorizable. - int M = 264, N = 136, K = 248; + int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -393,33 +69,36 @@ TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = matmul(tv0, tv1, layout, false); + auto tv2 = matmul(tv0, tv1, layout, true); fusion.addOutput(tv2); MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Volta_16_16_4; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); auto inputs = matmulAtInput(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, + 8, 0, fe.compileFusion( &fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams)); - // prologSwizzle on Volta is not supported yet - // ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -427,48 +106,49 @@ TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) { } } -// Matmul test for Volta MMA: across supported layouts -TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) { +TEST_F(NVFuserTest, FusionAmpereMatmulBFloat16_CUDA) { // Keep multiples of 8 to keep vectorizable. - int M = 264, N = 136, K = 248; + int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); + auto tv0 = makeContigTensor(2, DataType::BFloat16); + auto tv1 = makeContigTensor(2, DataType::BFloat16); fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = matmul(tv0, tv1, layout, false); + auto tv2 = matmul(tv0, tv1, layout, true); fusion.addOutput(tv2); MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Volta_16_16_4; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput(M, N, K, layout, at::kBFloat16); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, + 8, 0, fe.compileFusion( &fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams)); - // prologSwizzle on Volta is not supported yet - // ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -476,1029 +156,310 @@ TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) { } } -// MMA unit test on Ampere -TEST_F(NVFuserTest, FusionAmpereMMATN_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); +// Matmul test for Ampere MMA: with pipelined gmem load +TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { + // Keep multiples of 8 to keep vectorizable. + int M = 504, N = 136, K = 248; + REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); - Fusion fusion; - FusionGuard fg(&fusion); + // Gmem pipeline stage + for (auto stage : {3, 4}) { + for (auto layout : kAllSupportedMmaLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); - // [M, K] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [N, K] - auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); - // [M, N, K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); + auto tv2 = matmul(tv0, tv1, layout, true); - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + fusion.addOutput(tv2); - fusion.addOutput(tv2); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + MatmulParams params; + params.mma_macro = MmaMacro::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.smem_double_buffer_stage = stage; + scheduleMatmul(&fusion, params); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); + auto inputs = matmulAtInput(M, N, K, layout); - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } + } +} - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); - - // [M, N, K] -> [N, M, K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({8, 16}, options); +// Matmul test for Ampere MMA: checking CTA Swizzles +TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { + // Keep multiples of 8 to keep vectorizable. + int dim = 8192; + int M = dim, N = dim, K = dim; + const auto all_orders = { + MatmulParams::TileRasterizationOrder::RowMajor, + MatmulParams::TileRasterizationOrder::ColumnMajor}; - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); + REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + auto test = [&](MmaLayout layout, + MatmulParams::TileRasterizationOrder order, + int swizzle, + float& runtime) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} + fusion.addInput(tv0); + fusion.addInput(tv1); -// MMA unit test on Ampere -TEST_F(NVFuserTest, FusionAmpereMMATT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + auto tv2 = matmul(tv0, tv1, layout, true); - Fusion fusion; - FusionGuard fg(&fusion); + fusion.addOutput(tv2); - // [M, K] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [K, N] - auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + optimization::OptimizationPass::runPass( + &fusion); - // [M, N, K] - auto tv0b = broadcast(tv0, {false, true, false}); - // [M, K, N] - auto tv1b = broadcast(tv1, {true, false, false}); - // [M, N, K] - auto tv1t = transpose(tv1b, 1, 2); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto tv2 = fusedMultiplySum(tv0b, tv1t, {2}); + MatmulParams params; + params.mma_macro = MmaMacro::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 3; - fusion.addOutput(tv2); + params.cta_order = order; + params.grid_swizzle_factor = swizzle; - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + scheduleMatmul(&fusion, params); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TT); + auto inputs = matmulAtInput(M, N, K, layout); - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); + FusionExecutor fe; + fe.setMeasureKernelTimeFlag(true); + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + NVF_CHECK(cg_outputs[0].allclose(tref, 0.01, 0.01)); - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1t; - tv1cr->definition()->as()->setOpType( - LoadStoreOpType::LdMatrixTranspose); + int gdimx = fe.lastLaunchParams().gdimx(); + int gdimy = fe.lastLaunchParams().gdimy(); - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); + int expected_gdim_unswizzled = (dim + 128 - 1) / 128; + int expected_gdimx = expected_gdim_unswizzled * swizzle; + int expected_gdimy = (expected_gdim_unswizzled + swizzle - 1) / swizzle; - // [M, N, K] -> [N, M, K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + NVF_CHECK(gdimx == expected_gdimx); + NVF_CHECK(gdimy == expected_gdimy); - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); + runtime = fe.kernelTimeMs(); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({16, 8}, options); + // Check that mma op is not predicated. This is a regression test for + // https://github.com/NVIDIA/Fuser/issues/95 + class PredicateChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + bool found_mma = false; - FusionExecutor fe; + private: + void handle(kir::Asm* asm_) final { +#if IS_CPP20 + if (!asm_->code().starts_with("mma")) { +#else + if (asm_->code().substr(0, 3) != "mma") { +#endif + return; + } + found_mma = true; + for (auto expr : scope_exprs_) { + NVF_CHECK( + !expr->isA() || + expr->as()->predicate()->isTrivial(), + "MmaOp should't be predicated!", + " Get predicate ", + expr->as()->predicate()->toInlineString()); + } + } + } pred_checker; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); + GpuLower gpulw(&fusion); + pred_checker.handle(gpulw.run()->topLevelExprs()); + ASSERT_TRUE(pred_checker.found_mma); + }; - auto cg_outputs = fe.runFusion({t0, t1}); + // Checking only a single layout to keep runtime short (compilation overhead) + for (auto layout : {MmaLayout::TT}) { + for (auto order : all_orders) { + float runtime1 = 0; + test(layout, order, 1, runtime1); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + float runtime4 = 0; + test(layout, order, 4, runtime4); - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); + // GRID Swizzle requires further changes to work in main. So for now we + // don't assert the perf benefit here. + // NVF_CHECK(runtime4 < runtime1); + } + } } -// MMA unit test on Ampere -TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - // [K, M] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [K, N] - auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [K, M, N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); - - // [M, N, K] - auto tv0t = permute(tv0b, {1, 2, 0}); - auto tv1t = permute(tv1b, {1, 2, 0}); - auto tv2 = fusedMultiplySum(tv0t, tv1t, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::NT); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); +TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { + // Keep multiples of 8 to keep vectorizable. + int M = 504, N = 136, K = 248; + REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0t; - tv0cr->definition()->as()->setOpType( - LoadStoreOpType::LdMatrixTranspose); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1t; - tv1cr->definition()->as()->setOpType( - LoadStoreOpType::LdMatrixTranspose); + // Gmem pipeline stage + for (auto stage : {3, 4}) { + for (auto layout : kAllSupportedMmaLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); + fusion.addInput(tv0); + fusion.addInput(tv1); - // [M, N, K] -> [N, M, K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + auto tv2 = matmul(tv0, tv1, layout, true); - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); + fusion.addOutput(tv2); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({16, 8}, options); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); + MatmulParams params; + params.mma_macro = MmaMacro::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.smem_double_buffer_stage = stage; + params.double_buffer_options.double_buffer_smem_read = true; + scheduleMatmul(&fusion, params); - auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + auto inputs = matmulAtInput(M, N, K, layout); - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } + } } -// MMA unit test on Ampere -TEST_F(NVFuserTest, FusionAmpereMMANN_CUDA) { +// Matmul-Matmul fusion test on Ampere +TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); + int M = 512, N = 256, K1 = 128, K2 = 128; + + // Fusion definition (Both gemms are TN) + // [M,K1] + auto tv0 = makeContigConcreteTensor({M, K1}, DataType::Half); + // [K2,K1] + auto tv1 = makeContigConcreteTensor({K2, K1}, DataType::Half); + // [N,K2] + auto tv2 = makeContigConcreteTensor({N, K2}, DataType::Half); - // [K, M] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [N, K] - auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); fusion.addInput(tv0); fusion.addInput(tv1); + fusion.addInput(tv2); - // [K, M, N] - auto tv0b = broadcast(tv0, {false, false, true}); - // [M, N, K] + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); auto tv1b = broadcast(tv1, {true, false, false}); + auto tv2b = broadcast(tv2, {true, false, false}); - // [M, N, K] - auto tv0t = permute(tv0b, {1, 2, 0}); - auto tv2 = fusedMultiplySum(tv0t, tv1b, {2}); - - fusion.addOutput(tv2); + // [M,K2,R] + auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + auto tv3h = castOp(DataType::Half, tv3); + auto tv3b = broadcast(tv3h, {false, true, false}); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::NN); + auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); + fusion.addOutput(tv4); - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0t; - tv0cr->definition()->as()->setOpType( - LoadStoreOpType::LdMatrixTranspose); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); + // Fusion: + // Gemm(M,K2,K1) x Gemm(M,N,K2) - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); + MatMulTileOptions gemm_tile1, gemm_tile2; - // [M, N, K] -> [N, M, K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + // cta tile: + // To save register, n of cta tile 1 + // matches k of cta tile2 + gemm_tile1.cta_tile = GemmTile(128, 64, 32); + gemm_tile2.cta_tile = GemmTile(128, 32, 64); - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); + // Distribute to 2x2 warps + gemm_tile1.warp_tile = GemmTile(64, 32, 32); + gemm_tile2.warp_tile = GemmTile(64, 16, 64); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({8, 16}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.t().to(at::kFloat).matmul(t1.t().to(at::kFloat)); - - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -// Matmul test for Ampere MMA: across supported layouts -TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { - // Keep multiples of 8 to keep vectorizable. - int M = 504, N = 136, K = 248; - - for (auto layout : kAllSupportedMatmulLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = matmul(tv0, tv1, layout, true); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.double_buffer_smem_read = true; - params.double_buffer_options.smem_double_buffer_stage = 4; - scheduleMatmul(&fusion, params); - - auto inputs = matmulAtInput(M, N, K, layout); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, - {inputs.first, inputs.second}, - LaunchParams(), - matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); - } -} - -TEST_F(NVFuserTest, FusionAmpereMatmulBFloat16_CUDA) { - // Keep multiples of 8 to keep vectorizable. - int M = 504, N = 136, K = 248; - - for (auto layout : kAllSupportedMatmulLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::BFloat16); - auto tv1 = makeContigTensor(2, DataType::BFloat16); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = matmul(tv0, tv1, layout, true); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.double_buffer_smem_read = true; - params.double_buffer_options.smem_double_buffer_stage = 4; - scheduleMatmul(&fusion, params); - - auto inputs = matmulAtInput(M, N, K, layout, at::kBFloat16); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, - {inputs.first, inputs.second}, - LaunchParams(), - matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); - } -} - -// Matmul test for Ampere MMA: with pipelined gmem load -TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { - // Keep multiples of 8 to keep vectorizable. - int M = 504, N = 136, K = 248; - REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); - - // Gmem pipeline stage - for (auto stage : {3, 4}) { - for (auto layout : kAllSupportedMatmulLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = matmul(tv0, tv1, layout, true); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.smem_double_buffer_stage = stage; - scheduleMatmul(&fusion, params); - - auto inputs = matmulAtInput(M, N, K, layout); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, - {inputs.first, inputs.second}, - LaunchParams(), - matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); - } - } -} - -// Matmul test for Ampere MMA: checking CTA Swizzles -TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { - // Keep multiples of 8 to keep vectorizable. - int dim = 8192; - int M = dim, N = dim, K = dim; - const auto all_orders = { - MatmulParams::TileRasterizationOrder::RowMajor, - MatmulParams::TileRasterizationOrder::ColumnMajor}; - - REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); - - auto test = [&](MatmulLayout layout, - MatmulParams::TileRasterizationOrder order, - int swizzle, - float& runtime) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = matmul(tv0, tv1, layout, true); - - fusion.addOutput(tv2); - - optimization::OptimizationPass::runPass( - &fusion); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.double_buffer_smem_read = true; - params.double_buffer_options.smem_double_buffer_stage = 3; - - params.cta_order = order; - params.grid_swizzle_factor = swizzle; - - scheduleMatmul(&fusion, params); - - auto inputs = matmulAtInput(M, N, K, layout); - - FusionExecutor fe; - fe.setMeasureKernelTimeFlag(true); - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, - {inputs.first, inputs.second}, - LaunchParams(), - matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - NVF_CHECK(cg_outputs[0].allclose(tref, 0.01, 0.01)); - - int gdimx = fe.lastLaunchParams().gdimx(); - int gdimy = fe.lastLaunchParams().gdimy(); - - int expected_gdim_unswizzled = (dim + 128 - 1) / 128; - int expected_gdimx = expected_gdim_unswizzled * swizzle; - int expected_gdimy = (expected_gdim_unswizzled + swizzle - 1) / swizzle; - - NVF_CHECK(gdimx == expected_gdimx); - NVF_CHECK(gdimy == expected_gdimy); - - runtime = fe.kernelTimeMs(); - - // Check that mma op is not predicated. This is a regression test for - // https://github.com/NVIDIA/Fuser/issues/95 - class PredicateChecker : public kir::IrVisitor { - public: - using kir::IrVisitor::handle; - bool found_mma = false; - - private: - void handle(MmaOp* uop) final { - found_mma = true; - for (auto expr : scope_exprs_) { - NVF_CHECK( - !expr->isA() || - expr->as()->predicate()->isTrivial(), - "MmaOp should't be predicated!", - " Get predicate ", - expr->as()->predicate()->toInlineString()); - } - } - } pred_checker; - - GpuLower gpulw(&fusion); - pred_checker.handle(gpulw.run()->topLevelExprs()); - ASSERT_TRUE(pred_checker.found_mma); - }; - - // Checking only a single layout to keep runtime short (compilation overhead) - for (auto layout : {MatmulLayout::TT}) { - for (auto order : all_orders) { - float runtime1 = 0; - test(layout, order, 1, runtime1); - - float runtime4 = 0; - test(layout, order, 4, runtime4); - - // GRID Swizzle requires further changes to work in main. So for now we - // don't assert the perf benefit here. - // NVF_CHECK(runtime4 < runtime1); - } - } -} - -TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { - // Keep multiples of 8 to keep vectorizable. - int M = 504, N = 136, K = 248; - REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); - - // Gmem pipeline stage - for (auto stage : {3, 4}) { - for (auto layout : kAllSupportedMatmulLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = matmul(tv0, tv1, layout, true); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.smem_double_buffer_stage = stage; - params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(&fusion, params); - - auto inputs = matmulAtInput(M, N, K, layout); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, - {inputs.first, inputs.second}, - LaunchParams(), - matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); - } - } -} - -// Matmul-Matmul fusion test on Ampere -TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - int M = 512, N = 256, K1 = 128, K2 = 128; - - // Fusion definition (Both gemms are TN) - // [M,K1] - auto tv0 = makeContigConcreteTensor({M, K1}, DataType::Half); - // [K2,K1] - auto tv1 = makeContigConcreteTensor({K2, K1}, DataType::Half); - // [N,K2] - auto tv2 = makeContigConcreteTensor({N, K2}, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - auto tv2b = broadcast(tv2, {true, false, false}); - - // [M,K2,R] - auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - - auto tv3h = castOp(DataType::Half, tv3); - auto tv3b = broadcast(tv3h, {false, true, false}); - - auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); - - fusion.addOutput(tv4); - - // Fusion: - // Gemm(M,K2,K1) x Gemm(M,N,K2) - - MatMulTileOptions gemm_tile1, gemm_tile2; - - // cta tile: - // To save register, n of cta tile 1 - // matches k of cta tile2 - gemm_tile1.cta_tile = GemmTile(128, 64, 32); - gemm_tile2.cta_tile = GemmTile(128, 32, 64); - - // Distribute to 2x2 warps - gemm_tile1.warp_tile = GemmTile(64, 32, 32); - gemm_tile2.warp_tile = GemmTile(64, 16, 64); - - // Using Ampere mma macro - gemm_tile2.instruction_tile = GemmTile(16, 8, 16); - gemm_tile2.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder1 = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile1) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_builder2 = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile2) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 2 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 2, got ", - mma_ops.size()); - mma_builder1.configureMma(mma_ops[0]); - mma_builder2.configureMma(mma_ops[1]); - - // Global read for gemm 1 - auto tv0r = tv0->cacheAfter(); - auto tv1r = tv1->cacheAfter(); - - // Global read for gemm 2 - auto tv2r = tv2->cacheAfter(); - - // Gemm 1 main loop read - auto tv0cw = tv0r->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv1cw = tv1r->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - - // Gemm 1 accumulator reg - auto tv3c = tv3->cacheBefore(); - mma_builder1.accumulatorTv(tv3c); - - // Gemm 2 main loop read - auto tv3cw = tv3h->cacheAfter(); - auto tv3cr = tv3cw->cacheAfter(LoadStoreOpType::LdMatrix); - - auto tv2cw = tv2r->cacheAfter(); - auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); - - // Gemm 2 accumulator reg - auto tv4c = tv4->cacheBefore(); - mma_builder2.accumulatorTv(tv4c); - - // General idea is inlining gemm1's main loop inside gemm2's - - // Schedule gemm 2: - // ------------------------------------------------------------------ - tv4->split(-2, gemm_tile2.cta_tile.m); - tv4->split(-1, gemm_tile2.cta_tile.n); - - // 0 1 2 3 - // [Mo,M128, No, N128] - tv4->reorder({{1, 2}, {2, 1}}); - - // 0 1 2 3 - // [Mo,No, M128, N128] - tv2->computeAt(tv4, 2); - tv3->computeAt(tv4, 2); - - // Order K - // 0 1 2 3 4 5 - // [Mo,No, M128, N128, Ko, K32] - tv4c->split(-1, gemm_tile2.cta_tile.k); - tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); - - // 0 1 2 3 4 5 - // [Mo,No, Ko M128, N128, K32] - tv3->computeAt(tv4c, 3); // Implicitly defines cta tile of gemm1 - tv2r->computeAt(tv4c, 3); - - // Make warp tile - mma_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile2); - mma_utils::scheduleWarpTileWithNoReduction(tv4, gemm_tile2); - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] - tv3cr->computeAt(tv4c, -4); - tv2cr->computeAt(tv4c, -4); - - // Schedule tv2 gmem read and smem write: - // ---------------------------------------------------------------- - // [No,Ko,N,K] - tv2cw->merge(-2); - tv2r->merge(-2); - - // [No,Ko,i,wy,wx,v] - mma_utils::scheduleContiguousVectorLoad(tv2cw, gemm_tile2, 8); - mma_utils::scheduleContiguousVectorLoad(tv2r, gemm_tile2, 8); - tv2cw->setMemoryType(MemoryType::Shared); - - // Schedule tv2 gmem read and smem write: - // ---------------------------------------------------------------- - - // Schedule gemm 2 mma input - // --------------------------------------------------------------------------- - tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); - - // [... Mi, Ni, Ki] want [Ni, Mi, Ki] - tv3b->reorder({{-2, -3}, {-3, -2}}); - tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); - - tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); - tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); - - // Schedule mma output - // --------------------------------------------------------------------------- - tv4c->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); - tv4->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); - - // Schedule gemm 1: - // ------------------------------------------------------------------ - - // CTA tile: - tv0->computeAt(tv3, 2); - tv1->computeAt(tv3, 2); - - // Schedule K dim for gemm 1: - - // Order K - // 0 1 2 3 4 5 - // [Mo,No, M128, N128, Ko, K32] - tv3c->split(-1, gemm_tile1.cta_tile.k); - tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); - // 0 1 2 3 4 5 - // [Mo,No, Ko M128, N128, K32] - tv0r->computeAt(tv3c, 3); - tv1r->computeAt(tv3c, 3); - - // Make warp tile: - // ------------------------------------------------------------------------- - mma_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile1); - mma_utils::scheduleWarpTileWithNoReduction(tv3cw, gemm_tile1); - - tv0cr->computeAt(tv3c, -4); - tv1cr->computeAt(tv3c, -4); - - tv3->computeAt(tv3cw, -3); - - // Schedule gmem read and smem write: - // --------------------------------------------------------------------------- - // [Mo,Ko,M,K] - tv0cw->merge(-2); - tv0r->merge(-2); - mma_utils::scheduleContiguousVectorLoad(tv0cw, gemm_tile1, 8); - mma_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile1, 8); - tv0cw->setMemoryType(MemoryType::Shared); - // [Mo,Ko,i,wy,wx,v] - - // [No,Ko,N,K] - tv1cw->merge(-2); - tv1r->merge(-2); - // [No,Ko,i,wy,wx,v] - mma_utils::scheduleContiguousVectorLoad(tv1cw, gemm_tile1, 8); - mma_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile1, 8); - tv1cw->setMemoryType(MemoryType::Shared); - - // Schedule mma input - // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); - // [... Mi, Ni, Ki] want [Ni, Mi, Ki] - tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); - - tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); - tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); - - // Schedule mma output - // --------------------------------------------------------------------------- - tv3c->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3cw->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3h->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3cw->setMemoryType(MemoryType::Shared); - - // Parallelize - // 0 1 2 3 4 5 6 7 - // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - // Gemm 1 - tv3c->axis(4)->parallelize(ParallelType::TIDz); - tv3c->axis(5)->parallelize(ParallelType::TIDy); - - tv3->computeAt(tv3cw, -2); - tv3cw->axis(2)->parallelize(ParallelType::TIDz); - tv3cw->axis(3)->parallelize(ParallelType::TIDy); - - // Gemm 2 - tv4->axis(2)->parallelize(ParallelType::TIDz); - tv4->axis(3)->parallelize(ParallelType::TIDy); - tv4c->axis(4)->parallelize(ParallelType::TIDz); - tv4c->axis(5)->parallelize(ParallelType::TIDy); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::BIDy); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K1}, options); - auto t1 = at::randn({K2, K1}, options); - auto t2 = at::randn({N, K2}, options); - - auto tref = t0.to(at::kFloat) - .matmul(t1.t().to(at::kFloat)) - .matmul(t2.t().to(at::kFloat)); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams)); - - auto cg_outputs = fe.runFusion({t0, t1, t2}); - - // relaxed check for now, err accumulation is significant. - NVF_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); -} - -// Simplified Matmul-Softmax-Matmul test on Ampere -// (To be extended in follow ups) -TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - // Omitting outer dimensions and pointwise ops - - const int seql_q = 32; - const int seql_k = 128; - const int hidden_size = 1024; - const int num_heads = 16; - const int head_dim = hidden_size / num_heads; - - // Gemm 1: - // (80, 80, 64) - const int M1 = seql_q, N1 = seql_k, K1 = head_dim; - // (64, 80) - const int N2 = head_dim, K2 = seql_k; - - // Fusion definition (Both gemms are TN) - // [M,K1] - auto inp = makeContigConcreteTensor({M1, K1}, DataType::Half); - // Query matrix - auto qk = makeContigConcreteTensor({N1, K1}, DataType::Half); - // Second linear matrix - auto acc = makeContigConcreteTensor({N2, K2}, DataType::Half); - - fusion.addInput(inp); - fusion.addInput(qk); - fusion.addInput(acc); - - // [M,N,K] - auto tv0b = broadcast(inp, {false, true, false}); - auto tv1b = broadcast(qk, {true, false, false}); - auto tv2b = broadcast(acc, {true, false, false}); - - // [M,K2,R] - auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - - // Inline define softmax for now for scheduling - auto x = tv3; - const int kReductionAxis = 1; - const int kNumberOfDims = 2; - std::vector broadcast_mask(kNumberOfDims, false); - broadcast_mask[kReductionAxis] = true; - - auto max_val = max(x, {kReductionAxis}); - auto bcast_max = broadcast(max_val, broadcast_mask); - auto x_max_sub = sub(x, bcast_max); - auto exp_val = exp(x_max_sub); - auto sum_exp = sum(exp_val, {kReductionAxis}); - auto bcast_sum = broadcast(sum_exp, broadcast_mask); - auto recip = reciprocal(bcast_sum); - auto tv3sfm = mul(exp_val, recip); - - auto tv3h = castOp(DataType::Half, tv3sfm); - auto tv3b = broadcast(tv3h, {false, true, false}); - auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); - - fusion.addOutput(tv4); - - // Fusion: - // Gemm(M,K2,K1) x Gemm(M,N,K2) - MatMulTileOptions gemm_tile; - - // TODO: use very small tiles for now since - // alias pass is not re-using smem. Fix later. - gemm_tile.cta_tile = GemmTile(32, 128, 32); - - // Distribute to 2x2 warps - gemm_tile.warp_tile = GemmTile(16, 64, 32); - - // Using Ampere mma macro - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder1 = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_builder2 = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); + // Using Ampere mma macro + gemm_tile2.instruction_tile = GemmTile(16, 8, 16); + gemm_tile2.instruction_tile = GemmTile(16, 8, 16); auto mma_ops = ir_utils::getOpsOfType(&fusion); NVF_CHECK( 2 == mma_ops.size(), "Invalid number of MmaOp instances in fusion definition, expected 2, got ", mma_ops.size()); - mma_builder1.configureMma(mma_ops[0]); - mma_builder2.configureMma(mma_ops[1]); + mma_ops[0]->setMacro(MmaMacro::Ampere_16_8_16); + mma_ops[1]->setMacro(MmaMacro::Ampere_16_8_16); // Global read for gemm 1 - auto tv0r = inp->cacheAfter(); - auto tv1r = qk->cacheAfter(); + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); // Global read for gemm 2 - auto tv2r = acc->cacheAfter(); + auto tv2r = tv2->cacheAfter(); // Gemm 1 main loop read auto tv0cw = tv0r->cacheAfter(); @@ -1508,27 +469,23 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // Gemm 1 accumulator reg auto tv3c = tv3->cacheBefore(); - mma_builder1.accumulatorTv(tv3c); - - // Softmax conversion: - auto tv3ccr = tv3->cacheAfter(); - - // tv3ccr -> tv3h : softmax // Gemm 2 main loop read - auto tv3cr = tv3h->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv3cw = tv3h->cacheAfter(); + auto tv3cr = tv3cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2cw = tv2r->cacheAfter(); auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); // Gemm 2 accumulator reg auto tv4c = tv4->cacheBefore(); - mma_builder2.accumulatorTv(tv4c); + + // General idea is inlining gemm1's main loop inside gemm2's // Schedule gemm 2: // ------------------------------------------------------------------ - tv4->split(-2, gemm_tile.cta_tile.m); - tv4->split(-1, gemm_tile.cta_tile.n); + tv4->split(-2, gemm_tile2.cta_tile.m); + tv4->split(-1, gemm_tile2.cta_tile.n); // 0 1 2 3 // [Mo,M128, No, N128] @@ -1536,24 +493,24 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // 0 1 2 3 // [Mo,No, M128, N128] - acc->computeAt(tv4, 2); + tv2->computeAt(tv4, 2); tv3->computeAt(tv4, 2); // Order K // 0 1 2 3 4 5 // [Mo,No, M128, N128, Ko, K32] - tv4c->split(-1, gemm_tile.cta_tile.k); + tv4c->split(-1, gemm_tile2.cta_tile.k); tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); // 0 1 2 3 4 5 // [Mo,No, Ko M128, N128, K32] - tv3->computeAt(tv4c, 2); + tv3->computeAt(tv4c, 3); // Implicitly defines cta tile of gemm1 tv2r->computeAt(tv4c, 3); // Make warp tile - mma_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile); - mma_utils::scheduleWarpTileWithNoReduction(tv4, gemm_tile); - // -8 -7 -6 -5 -4 -3 -2 -1 + mma_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile2); + mma_utils::scheduleWarpTileWithNoReduction(tv4, gemm_tile2); + // -8 -7 -6 -5 -4 -3 -2 -1 // [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] tv3cr->computeAt(tv4c, -4); tv2cr->computeAt(tv4c, -4); @@ -1565,8 +522,8 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { tv2r->merge(-2); // [No,Ko,i,wy,wx,v] - mma_utils::scheduleContiguousVectorLoad(tv2cw, gemm_tile, 8); - mma_utils::scheduleContiguousVectorLoad(tv2r, gemm_tile, 8); + mma_utils::scheduleContiguousVectorLoad(tv2cw, gemm_tile2, 8); + mma_utils::scheduleContiguousVectorLoad(tv2r, gemm_tile2, 8); tv2cw->setMemoryType(MemoryType::Shared); // Schedule tv2 gmem read and smem write: @@ -1574,42 +531,33 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // Schedule gemm 2 mma input // --------------------------------------------------------------------------- - tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); + tv3cr->applyMmaSwizzle(MmaOperand::A); + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] tv3b->reorder({{-2, -3}, {-3, -2}}); - tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); + tv3b->applyMmaSwizzle(MmaOperand::A); - tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); - tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); + tv2cr->applyMmaSwizzle(MmaOperand::B); + tv2b->applyMmaSwizzle(MmaOperand::B); // Schedule mma output // --------------------------------------------------------------------------- - tv4c->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); - tv4->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); + tv4c->applyMmaSwizzle(MmaOperand::Accumulator); + tv4->applyMmaSwizzle(MmaOperand::Accumulator); // Schedule gemm 1: // ------------------------------------------------------------------ // CTA tile: - // [Mo, Mi128, N80] - - tv3->split(-1, gemm_tile.cta_tile.n); - // [Mo, Mi128, No, Ni128] - - tv3->reorder({{1, 2}, {2, 1}}); - - // [Mo, No, Mi128, Ni128] - inp->computeAt(tv3, 2); - qk->computeAt(tv3, 2); + tv0->computeAt(tv3, 2); + tv1->computeAt(tv3, 2); // Schedule K dim for gemm 1: // Order K // 0 1 2 3 4 5 // [Mo,No, M128, N128, Ko, K32] - tv3c->split(-1, gemm_tile.cta_tile.k); + tv3c->split(-1, gemm_tile1.cta_tile.k); tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); // 0 1 2 3 4 5 // [Mo,No, Ko M128, N128, K32] @@ -1618,21 +566,21 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // Make warp tile: // ------------------------------------------------------------------------- - mma_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile); - mma_utils::scheduleWarpTileWithNoReduction(tv3, gemm_tile); + mma_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile1); + mma_utils::scheduleWarpTileWithNoReduction(tv3cw, gemm_tile1); tv0cr->computeAt(tv3c, -4); tv1cr->computeAt(tv3c, -4); - // tv3->computeAt(tv3cw,-3); + tv3->computeAt(tv3cw, -3); // Schedule gmem read and smem write: // --------------------------------------------------------------------------- // [Mo,Ko,M,K] tv0cw->merge(-2); tv0r->merge(-2); - mma_utils::scheduleContiguousVectorLoad(tv0cw, gemm_tile, 8); - mma_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); + mma_utils::scheduleContiguousVectorLoad(tv0cw, gemm_tile1, 8); + mma_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile1, 8); tv0cw->setMemoryType(MemoryType::Shared); // [Mo,Ko,i,wy,wx,v] @@ -1640,838 +588,491 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { tv1cw->merge(-2); tv1r->merge(-2); // [No,Ko,i,wy,wx,v] - mma_utils::scheduleContiguousVectorLoad(tv1cw, gemm_tile, 8); - mma_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); + mma_utils::scheduleContiguousVectorLoad(tv1cw, gemm_tile1, 8); + mma_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile1, 8); tv1cw->setMemoryType(MemoryType::Shared); // Schedule mma input // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); + tv0cr->applyMmaSwizzle(MmaOperand::A); // [... Mi, Ni, Ki] want [Ni, Mi, Ki] tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); + tv0b->applyMmaSwizzle(MmaOperand::A); - tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); - tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); + tv1cr->applyMmaSwizzle(MmaOperand::B); + tv1b->applyMmaSwizzle(MmaOperand::B); - // // Schedule mma output - // // + // Schedule mma output // --------------------------------------------------------------------------- - tv3c->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - tv3->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); - - // mma_utils::WarpMmaSwizzler::scheduleMmaWarpOutput(tv3ccw, - // mma_builder1.build()); - - // Put tv3 result in smem - tv3->setMemoryType(MemoryType::Shared); - - // schedule a reg persistent softmax: from tv3 - // [Mo, M128, RN] - max_val->split(-1, 128); - // [Mo, M128, RN1, RN128] - max_val->split(-1, 4); - // Map to warp (2x2) - max_val->split(-4, 4); - max_val->split(-4, 2); - - // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] - auto max_rf = max_val->rFactor({-1}); - // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] - - // [Mo, M128, RN] - sum_exp->split(-1, 128); - // [Mo, M128, RN1, RN128] - sum_exp->split(-1, 4); - // Map to warp (2x2) - sum_exp->split(-4, 4); - sum_exp->split(-4, 2); - - // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] - auto sum_exp_rf = sum_exp->rFactor({-1}); - // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] - - exp_val->computeAt(sum_exp_rf, 4); - exp_val->split(-1, 128); - exp_val->split(-1, 4); - bcast_max->computeAt(exp_val, -2); - - // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] - - // Read from smem - tv3ccr->computeAt(max_rf, 4); - // [Mo, Mo32, My2, Mx2, N80] - tv3ccr->split(-1, 128); - tv3ccr->split(-1, 4); - // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] - - // Write to second gemm - tv3h->split(-1, 128); - tv3h->split(-1, 4); - // Map to warp (2x2) - tv3h->split(-4, 4); - tv3h->split(-4, 2); - - bcast_sum->computeAt(tv3h, -2); - - tv3h->setMemoryType(MemoryType::Shared); + tv3c->applyMmaSwizzle(MmaOperand::Accumulator); + tv3cw->applyMmaSwizzle(MmaOperand::Accumulator); + tv3h->applyMmaSwizzle(MmaOperand::Accumulator); + tv3->applyMmaSwizzle(MmaOperand::Accumulator); + tv3cw->setMemoryType(MemoryType::Shared); // Parallelize - tv4->axis(0)->parallelize(ParallelType::BIDx); - // 0 1 2 3 4 5 6 7 - // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - // Gemm 1 - tv3c->axis(4)->parallelize(ParallelType::TIDz); - tv3c->axis(5)->parallelize(ParallelType::TIDy); - tv3->axis(2)->parallelize(ParallelType::TIDz); - tv3->axis(3)->parallelize(ParallelType::TIDy); - - auto parallelize_non_reduced_val = [](TensorView* tv) { - tv->axis(-2)->parallelize(ParallelType::TIDx); - tv->axis(2)->parallelize(ParallelType::TIDz); - tv->axis(3)->parallelize(ParallelType::TIDy); - }; - - auto parallelize_reduced_val = [](TensorView* tv) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - tv->axis(2)->parallelize(ParallelType::TIDz); - tv->axis(3)->parallelize(ParallelType::TIDy); - }; - - parallelize_non_reduced_val(tv3h); - parallelize_non_reduced_val(max_rf); - parallelize_non_reduced_val(bcast_max); - parallelize_non_reduced_val(exp_val); - parallelize_non_reduced_val(sum_exp_rf); - parallelize_non_reduced_val(bcast_sum); - parallelize_non_reduced_val(recip); - - parallelize_reduced_val(max_val); - parallelize_reduced_val(sum_exp); - - // 0 1 2 3 4 5 6 7 - // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - // Gemm 2 - tv4->axis(2)->parallelize(ParallelType::TIDz); - tv4->axis(3)->parallelize(ParallelType::TIDy); - tv4c->axis(4)->parallelize(ParallelType::TIDz); - tv4c->axis(5)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M1, K1}, options); - auto t1 = at::randn({N1, K1}, options); - auto t2 = at::randn({N2, K2}, options); - - FusionExecutor fe; - - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams)); - - auto cg_outputs = fe.runFusion({t0, t1, t2}); - - auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - auto sg1 = at::_softmax(g1, -1, false); - auto gsg1 = sg1.matmul(t2.t().to(at::kFloat)); - - NVF_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); -} - -// MMA unit test on Turing -TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [M, K] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [N, K] - auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M, N, K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); - - // [M, N, K] -> [N, M, K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({8, 16}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, - 5, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} - -// MMA unit test on Turing -TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [M, K] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [K, N] - auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M, N, K] - auto tv0b = broadcast(tv0, {false, true, false}); - // [M, K, N] - auto tv1b = broadcast(tv1, {true, false, false}); - // [M, N, K] - auto tv1t = transpose(tv1b, 1, 2); - - auto tv2 = fusedMultiplySum(tv0b, tv1t, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TT); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1t; - tv1cr->definition()->as()->setOpType( - LoadStoreOpType::LdMatrixTranspose); + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + // Gemm 1 + tv3c->axis(4)->parallelize(ParallelType::TIDz); + tv3c->axis(5)->parallelize(ParallelType::TIDy); - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); + tv3->computeAt(tv3cw, -2); + tv3cw->axis(2)->parallelize(ParallelType::TIDz); + tv3cw->axis(3)->parallelize(ParallelType::TIDy); - // [M, N, K] -> [N, M, K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + // Gemm 2 + tv4->axis(2)->parallelize(ParallelType::TIDz); + tv4->axis(3)->parallelize(ParallelType::TIDy); + tv4c->axis(4)->parallelize(ParallelType::TIDz); + tv4c->axis(5)->parallelize(ParallelType::TIDy); - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::BIDy); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({16, 8}, options); + auto t0 = at::randn({M, K1}, options); + auto t1 = at::randn({K2, K1}, options); + auto t2 = at::randn({N, K2}, options); + + auto tref = t0.to(at::kFloat) + .matmul(t1.t().to(at::kFloat)) + .matmul(t2.t().to(at::kFloat)); FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, - 5, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams)); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + auto cg_outputs = fe.runFusion({t0, t1, t2}); - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); + // relaxed check for now, err accumulation is significant. + NVF_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); } -// MMA unit test on Turing -TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { +// Simplified Matmul-Softmax-Matmul test on Ampere +// (To be extended in follow ups) +TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + Fusion fusion; FusionGuard fg(&fusion); - // [K, M] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [K, N] - auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [K, M, N] - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, true, false}); - - // [M, N, K] - auto tv0t = permute(tv0b, {1, 2, 0}); - auto tv1t = permute(tv1b, {1, 2, 0}); - auto tv2 = fusedMultiplySum(tv0t, tv1t, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::NT); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0t; - tv0cr->definition()->as()->setOpType( - LoadStoreOpType::LdMatrixTranspose); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1t; - tv1cr->definition()->as()->setOpType( - LoadStoreOpType::LdMatrixTranspose); + // Omitting outer dimensions and pointwise ops - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); + const int seql_q = 32; + const int seql_k = 128; + const int hidden_size = 1024; + const int num_heads = 16; + const int head_dim = hidden_size / num_heads; - // [K,M,N] -> [N,M,K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + // Gemm 1: + // (80, 80, 64) + const int M1 = seql_q, N1 = seql_k, K1 = head_dim; + // (64, 80) + const int N2 = head_dim, K2 = seql_k; - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); + // Fusion definition (Both gemms are TN) + // [M,K1] + auto inp = makeContigConcreteTensor({M1, K1}, DataType::Half); + // Query matrix + auto qk = makeContigConcreteTensor({N1, K1}, DataType::Half); + // Second linear matrix + auto acc = makeContigConcreteTensor({N2, K2}, DataType::Half); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({16, 8}, options); + fusion.addInput(inp); + fusion.addInput(qk); + fusion.addInput(acc); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, - 5, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); + // [M,N,K] + auto tv0b = broadcast(inp, {false, true, false}); + auto tv1b = broadcast(qk, {true, false, false}); + auto tv2b = broadcast(acc, {true, false, false}); - auto cg_outputs = fe.runFusion({t0, t1}); + // [M,K2,R] + auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + // Inline define softmax for now for scheduling + auto x = tv3; + const int kReductionAxis = 1; + const int kNumberOfDims = 2; + std::vector broadcast_mask(kNumberOfDims, false); + broadcast_mask[kReductionAxis] = true; - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} + auto max_val = max(x, {kReductionAxis}); + auto bcast_max = broadcast(max_val, broadcast_mask); + auto x_max_sub = sub(x, bcast_max); + auto exp_val = exp(x_max_sub); + auto sum_exp = sum(exp_val, {kReductionAxis}); + auto bcast_sum = broadcast(sum_exp, broadcast_mask); + auto recip = reciprocal(bcast_sum); + auto tv3sfm = mul(exp_val, recip); -// MMA unit test on Ampere -TEST_F(NVFuserTest, FusionTuringMMANN_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); + auto tv3h = castOp(DataType::Half, tv3sfm); + auto tv3b = broadcast(tv3h, {false, true, false}); + auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); - // [K, M] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [N, K] - auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + fusion.addOutput(tv4); - // [K, M, N] - auto tv0b = broadcast(tv0, {false, false, true}); - // [M, N, K] - auto tv1b = broadcast(tv1, {true, false, false}); + // Fusion: + // Gemm(M,K2,K1) x Gemm(M,N,K2) + MatMulTileOptions gemm_tile; - // [M, N, K] - auto tv0t = permute(tv0b, {1, 2, 0}); - auto tv2 = fusedMultiplySum(tv0t, tv1b, {2}); + // TODO: use very small tiles for now since + // alias pass is not re-using smem. Fix later. + gemm_tile.cta_tile = GemmTile(32, 128, 32); - fusion.addOutput(tv2); + // Distribute to 2x2 warps + gemm_tile.warp_tile = GemmTile(16, 64, 32); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); + // Using Ampere mma macro gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::NN); - auto mma_ops = ir_utils::getOpsOfType(&fusion); NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", + 2 == mma_ops.size(), + "Invalid number of MmaOp instances in fusion definition, expected 2, got ", mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0t; - tv0cr->definition()->as()->setOpType( - LoadStoreOpType::LdMatrixTranspose); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); + mma_ops[0]->setMacro(MmaMacro::Ampere_16_8_16); + mma_ops[1]->setMacro(MmaMacro::Ampere_16_8_16); - // [M, N, K] -> [N, M, K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); + // Global read for gemm 1 + auto tv0r = inp->cacheAfter(); + auto tv1r = qk->cacheAfter(); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({8, 16}, options); + // Global read for gemm 2 + auto tv2r = acc->cacheAfter(); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, - 5, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); + // Gemm 1 main loop read + auto tv0cw = tv0r->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tref = t0.t().to(at::kFloat).matmul(t1.t().to(at::kFloat)); + // Gemm 1 accumulator reg + auto tv3c = tv3->cacheBefore(); - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); -} + // Softmax conversion: + auto tv3ccr = tv3->cacheAfter(); -// Matmul test for Turing MMA: across supported layouts -TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) { - // Keep multiples of 8 to keep vectorizable. - int M = 504, N = 136, K = 248; + // tv3ccr -> tv3h : softmax - for (auto layout : kAllSupportedMatmulLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); + // Gemm 2 main loop read + auto tv3cr = tv3h->cacheAfter(LoadStoreOpType::LdMatrix); - fusion.addInput(tv0); - fusion.addInput(tv1); + auto tv2cw = tv2r->cacheAfter(); + auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv2 = matmul(tv0, tv1, layout, true); + // Gemm 2 accumulator reg + auto tv4c = tv4->cacheBefore(); - fusion.addOutput(tv2); + // Schedule gemm 2: + // ------------------------------------------------------------------ + tv4->split(-2, gemm_tile.cta_tile.m); + tv4->split(-1, gemm_tile.cta_tile.n); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + // 0 1 2 3 + // [Mo,M128, No, N128] + tv4->reorder({{1, 2}, {2, 1}}); - MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Turing_16_8_16; - params.tile_sizes = gemm_tile; - scheduleMatmul(&fusion, params); + // 0 1 2 3 + // [Mo,No, M128, N128] + acc->computeAt(tv4, 2); + tv3->computeAt(tv4, 2); - auto inputs = matmulAtInput(M, N, K, layout); + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv4c->split(-1, gemm_tile.cta_tile.k); + tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); - } -} + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv3->computeAt(tv4c, 2); + tv2r->computeAt(tv4c, 3); -// Matmul test on ampere, using ampere memory ops -TEST_F(NVFuserTest, FusionAmpereMatmulTNcpAsync_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); + // Make warp tile + mma_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile); + mma_utils::scheduleWarpTileWithNoReduction(tv4, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] + tv3cr->computeAt(tv4c, -4); + tv2cr->computeAt(tv4c, -4); - int M = 255, N = 511, K = 88; + // Schedule tv2 gmem read and smem write: + // ---------------------------------------------------------------- + // [No,Ko,N,K] + tv2cw->merge(-2); + tv2r->merge(-2); - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + // [No,Ko,i,wy,wx,v] + mma_utils::scheduleContiguousVectorLoad(tv2cw, gemm_tile, 8); + mma_utils::scheduleContiguousVectorLoad(tv2r, gemm_tile, 8); + tv2cw->setMemoryType(MemoryType::Shared); - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); + // Schedule tv2 gmem read and smem write: + // ---------------------------------------------------------------- - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + // Schedule gemm 2 mma input + // --------------------------------------------------------------------------- + tv3cr->applyMmaSwizzle(MmaOperand::A); + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv3b->reorder({{-2, -3}, {-3, -2}}); + tv3b->applyMmaSwizzle(MmaOperand::A); - fusion.addOutput(tv2); + tv2cr->applyMmaSwizzle(MmaOperand::B); + tv2b->applyMmaSwizzle(MmaOperand::B); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + // Schedule mma output + // --------------------------------------------------------------------------- + tv4c->applyMmaSwizzle(MmaOperand::Accumulator); + tv4->applyMmaSwizzle(MmaOperand::Accumulator); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); + // Schedule gemm 1: + // ------------------------------------------------------------------ - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); + // CTA tile: + // [Mo, Mi128, N80] - auto tv0cw = tv0->cacheAfter(LoadStoreOpType::CpAsync); - auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv1cw = tv1->cacheAfter(LoadStoreOpType::CpAsync); - auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); + tv3->split(-1, gemm_tile.cta_tile.n); + // [Mo, Mi128, No, Ni128] - // Make a CTA tile - // ------------------------------------------------------------------ - // [M,N] - tv2->split(-2, gemm_tile.cta_tile.m); - tv2->split(-1, gemm_tile.cta_tile.n); + tv3->reorder({{1, 2}, {2, 1}}); - // 0 1 2 3 - // [Mo,M128, No, N128] - tv2->reorder({{1, 2}, {2, 1}}); + // [Mo, No, Mi128, Ni128] + inp->computeAt(tv3, 2); + qk->computeAt(tv3, 2); - // 0 1 2 3 - // [Mo,No, M128, N128] - tv0->computeAt(tv2, 2); - tv1->computeAt(tv2, 2); + // Schedule K dim for gemm 1: // Order K // 0 1 2 3 4 5 // [Mo,No, M128, N128, Ko, K32] - tv2c->split(-1, gemm_tile.cta_tile.k); - tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); - + tv3c->split(-1, gemm_tile.cta_tile.k); + tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); // 0 1 2 3 4 5 // [Mo,No, Ko M128, N128, K32] - tv0cw->computeAt(tv2c, 3); - tv1cw->computeAt(tv2c, 3); + tv0r->computeAt(tv3c, 3); + tv1r->computeAt(tv3c, 3); // Make warp tile: // ------------------------------------------------------------------------- - mma_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); - mma_utils::scheduleWarpTileWithNoReduction(tv2, gemm_tile); - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] - tv0cr->computeAt(tv2c, -4); - tv1cr->computeAt(tv2c, -4); + mma_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile); + mma_utils::scheduleWarpTileWithNoReduction(tv3, gemm_tile); + + tv0cr->computeAt(tv3c, -4); + tv1cr->computeAt(tv3c, -4); + + // tv3->computeAt(tv3cw,-3); // Schedule gmem read and smem write: // --------------------------------------------------------------------------- // [Mo,Ko,M,K] tv0cw->merge(-2); + tv0r->merge(-2); mma_utils::scheduleContiguousVectorLoad(tv0cw, gemm_tile, 8); + mma_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); tv0cw->setMemoryType(MemoryType::Shared); // [Mo,Ko,i,wy,wx,v] // [No,Ko,N,K] tv1cw->merge(-2); + tv1r->merge(-2); // [No,Ko,i,wy,wx,v] mma_utils::scheduleContiguousVectorLoad(tv1cw, gemm_tile, 8); + mma_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - // [... Mi, Ni, Ki] + tv0cr->applyMmaSwizzle(MmaOperand::A); + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv0b->applyMmaSwizzle(MmaOperand::A); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1cr->applyMmaSwizzle(MmaOperand::B); + tv1b->applyMmaSwizzle(MmaOperand::B); - // Schedule mma output + // // Schedule mma output + // // // --------------------------------------------------------------------------- - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - // Parallelize - // 0 1 2 3 4 5 6 7 8 9 10 - // [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)] - tv2c->axis(4)->parallelize(ParallelType::TIDz); - tv2c->axis(5)->parallelize(ParallelType::TIDy); - - // Parallelize - // 0 1 2 3 4 5 6 7 - // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::BIDy); - tv2->axis(2)->parallelize(ParallelType::TIDz); - tv2->axis(3)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -TEST_F(NVFuserTest, FusionAmpereStridedBatchedMatmulTN_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - int64_t M = 511, N = 123, K = 88, B0 = 3, B1 = 5; - - // [B0 ,M, B1, K] - auto tv0 = makeContigTensor(4, DataType::Half); - // [B0, N, B1, K] - auto tv1 = makeContigTensor(4, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [B0, M, N, B1, K] - auto tv0b = broadcast(tv0, {false, false, true, false, false}); - auto tv1b = broadcast(tv1, {false, true, false, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {4}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0r = tv0->cacheAfter(); - auto tv1r = tv1->cacheAfter(); - auto tv0cw = tv0r->cacheAfter(); - auto tv0cr = - tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1r->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); - auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); + tv3c->applyMmaSwizzle(MmaOperand::Accumulator); + tv3->applyMmaSwizzle(MmaOperand::Accumulator); - // Group the BATCHED DIMS: - // -4 -3 -2 -1 - // [B0, M, N, B1] - tv2->reorder({{-3, -2}, {-2, -1}, {-1, -4}}); + // Put tv3 result in smem + tv3->setMemoryType(MemoryType::Shared); - // -4 -3 -2 -1 - // [B0, B1, M, N] + // schedule a reg persistent softmax: from tv3 + // [Mo, M128, RN] + max_val->split(-1, 128); + // [Mo, M128, RN1, RN128] + max_val->split(-1, 4); + // Map to warp (2x2) + max_val->split(-4, 4); + max_val->split(-4, 2); - // Make a CTA tile - // ------------------------------------------------------------------ - // [B0, B1, M, N] - tv2->split(-2, gemm_tile.cta_tile.m); - tv2->split(-1, gemm_tile.cta_tile.n); + // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] + auto max_rf = max_val->rFactor({-1}); + // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] - // 0 1 2 3 4 5 - // [B0, B1, Mo, M128, No, N128] - tv2->reorder({{-3, -2}, {-2, -3}}); + // [Mo, M128, RN] + sum_exp->split(-1, 128); + // [Mo, M128, RN1, RN128] + sum_exp->split(-1, 4); + // Map to warp (2x2) + sum_exp->split(-4, 4); + sum_exp->split(-4, 2); - // 0 1 2 3 4 5 - // [B0, B1, Mo, No, M128, N128] + // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] + auto sum_exp_rf = sum_exp->rFactor({-1}); + // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] - // Merge the outer dims: - tv2->merge(0); - tv2->merge(0); + exp_val->computeAt(sum_exp_rf, 4); + exp_val->split(-1, 128); + exp_val->split(-1, 4); + bcast_max->computeAt(exp_val, -2); - // 0 1 2 3 - // [Mo, No, M128, N128] - tv0->computeAt(tv2, 2); - tv1->computeAt(tv2, 2); + // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] - // Order K - // 0 1 2 3 4 5 - // [Mo, No, M128, N128, Ko, K32] - tv2c->split(-1, gemm_tile.cta_tile.k); - tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + // Read from smem + tv3ccr->computeAt(max_rf, 4); + // [Mo, Mo32, My2, Mx2, N80] + tv3ccr->split(-1, 128); + tv3ccr->split(-1, 4); + // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] - // 0 1 2 3 4 5 - // [Mo, No, Ko, M128, N128, K32] - tv0r->computeAt(tv2c, 3); - tv1r->computeAt(tv2c, 3); + // Write to second gemm + tv3h->split(-1, 128); + tv3h->split(-1, 4); + // Map to warp (2x2) + tv3h->split(-4, 4); + tv3h->split(-4, 2); - // Make warp tile: - // ------------------------------------------------------------------------- - mma_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); - mma_utils::scheduleWarpTileWithNoReduction(tv2, gemm_tile); - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] - tv0cr->computeAt(tv2c, -4); - tv1cr->computeAt(tv2c, -4); + bcast_sum->computeAt(tv3h, -2); - // Schedule gmem read and smem write: - // --------------------------------------------------------------------------- - // [Mo, Ko, M, K] - tv0cw->merge(-2); - tv0r->merge(-2); - mma_utils::scheduleContiguousVectorLoad(tv0cw, gemm_tile, 8); - mma_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); - tv0cw->setMemoryType(MemoryType::Shared); - // [Mo, Ko, i, wy, wx, v] + tv3h->setMemoryType(MemoryType::Shared); - // [No, Ko, N, K] - tv1cw->merge(-2); - tv1r->merge(-2); - // [No, Ko, i, wy, wx, v] - mma_utils::scheduleContiguousVectorLoad(tv1cw, gemm_tile, 8); - mma_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); - tv1cw->setMemoryType(MemoryType::Shared); - // Schedule mma input - // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + // Parallelize + tv4->axis(0)->parallelize(ParallelType::BIDx); + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + // Gemm 1 + tv3c->axis(4)->parallelize(ParallelType::TIDz); + tv3c->axis(5)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDz); + tv3->axis(3)->parallelize(ParallelType::TIDy); - // [... Mi, Ni, Ki] want [Ni, Mi, Ki] - tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + auto parallelize_non_reduced_val = [](TensorView* tv) { + tv->axis(-2)->parallelize(ParallelType::TIDx); + tv->axis(2)->parallelize(ParallelType::TIDz); + tv->axis(3)->parallelize(ParallelType::TIDy); + }; - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + auto parallelize_reduced_val = [](TensorView* tv) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + tv->axis(2)->parallelize(ParallelType::TIDz); + tv->axis(3)->parallelize(ParallelType::TIDy); + }; - // Schedule mma output - // --------------------------------------------------------------------------- - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + parallelize_non_reduced_val(tv3h); + parallelize_non_reduced_val(max_rf); + parallelize_non_reduced_val(bcast_max); + parallelize_non_reduced_val(exp_val); + parallelize_non_reduced_val(sum_exp_rf); + parallelize_non_reduced_val(bcast_sum); + parallelize_non_reduced_val(recip); - // Parallelize - // 0 1 2 3 4 5 6 7 8 9 10 - // [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)] - tv2c->axis(4)->parallelize(ParallelType::TIDz); - tv2c->axis(5)->parallelize(ParallelType::TIDy); + parallelize_reduced_val(max_val); + parallelize_reduced_val(sum_exp); - // Parallelize // 0 1 2 3 4 5 6 7 // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::BIDy); - tv2->axis(2)->parallelize(ParallelType::TIDz); - tv2->axis(3)->parallelize(ParallelType::TIDy); + // Gemm 2 + tv4->axis(2)->parallelize(ParallelType::TIDz); + tv4->axis(3)->parallelize(ParallelType::TIDy); + tv4c->axis(4)->parallelize(ParallelType::TIDz); + tv4c->axis(5)->parallelize(ParallelType::TIDy); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({B0, M, B1, K}, options); - auto t1 = at::randn({B0, N, B1, K}, options); + auto t0 = at::randn({M1, K1}, options); + auto t1 = at::randn({N1, K1}, options); + auto t2 = at::randn({N2, K2}, options); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); + fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1, t2}); - // ref implementation: - auto ref_t0 = t0.permute({0, 2, 1, 3}) - .contiguous() - .view({B0 * B1, M, K}); // B0, B1, M, K - auto ref_t1 = t1.permute({0, 2, 3, 1}) - .contiguous() - .view({B0 * B1, K, N}); // B0, B1, K, N - auto ref_permuted = - ref_t0.to(at::kFloat).bmm(ref_t1.to(at::kFloat)); // B0*B1, M,N - auto ref = ref_permuted.view({B0, B1, M, N}) - .permute({0, 2, 3, 1}) - .contiguous(); // B0,M,N,B1 - NVF_CHECK(cg_outputs[0].allclose(ref, 0.0001, 0.0001)); + auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + auto sg1 = at::_softmax(g1, -1, false); + auto gsg1 = sg1.matmul(t2.t().to(at::kFloat)); + + NVF_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); } -// Matmul test on Ampere with a reshape on prolog -TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); +// Matmul test for Turing MMA: across supported layouts +TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) { + // Keep multiples of 8 to keep vectorizable. + int M = 504, N = 136, K = 248; + + for (auto layout : kAllSupportedMmaLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout, true); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + MatmulParams params; + params.mma_macro = MmaMacro::Turing_16_8_16; + params.tile_sizes = gemm_tile; + scheduleMatmul(&fusion, params); + + auto inputs = matmulAtInput(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + } +} +// Matmul test on ampere, using ampere memory ops +TEST_F(NVFuserTest, FusionAmpereMatmulTNcpAsync_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - int M = 511, N = 257, K = 88; - int Ko = 11, Ki = 8; - // [M,Ko,Ki] - auto tv0 = makeContigTensor(3, DataType::Half); + int M = 255, N = 511, K = 88; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); // [N,K] auto tv1 = makeContigTensor(2, DataType::Half); fusion.addInput(tv0); fusion.addInput(tv1); - auto tv0_reshape = reshape(tv0, {M, Ko, Ki}, {M, K}); - // [M,N,K] - auto tv0b = broadcast(tv0_reshape, {false, true, false}); + auto tv0b = broadcast(tv0, {false, true, false}); auto tv1b = broadcast(tv1, {true, false, false}); // Leaving both sets of mma inputs for volta outside @@ -2485,27 +1086,18 @@ TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - auto mma_ops = ir_utils::getOpsOfType(&fusion); NVF_CHECK( 1 == mma_ops.size(), "Invalid number of MmaOp instances in fusion definition, expected 1, got ", mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); + mma_ops.front()->setMacro(MmaMacro::Ampere_16_8_16); - auto tv0r = tv0->cacheAfter(); - auto tv1r = tv1->cacheAfter(); - auto tv0cw = tv0_reshape->cacheAfter(); - auto tv0cr = - tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1r->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv0cw = tv0->cacheAfter(LoadStoreOpType::CpAsync); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv1cw = tv1->cacheAfter(LoadStoreOpType::CpAsync); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); // Make a CTA tile // ------------------------------------------------------------------ @@ -2530,8 +1122,8 @@ TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { // 0 1 2 3 4 5 // [Mo,No, Ko M128, N128, K32] - tv0r->computeAt(tv2c, 3); - tv1r->computeAt(tv2c, 3); + tv0cw->computeAt(tv2c, 3); + tv1cw->computeAt(tv2c, 3); // Make warp tile: // ------------------------------------------------------------------------- @@ -2546,41 +1138,29 @@ TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { // --------------------------------------------------------------------------- // [Mo,Ko,M,K] tv0cw->merge(-2); - tv0r->merge(-2); - tv0_reshape->merge(-2); mma_utils::scheduleContiguousVectorLoad(tv0cw, gemm_tile, 8); - mma_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); tv0cw->setMemoryType(MemoryType::Shared); // [Mo,Ko,i,wy,wx,v] // [No,Ko,N,K] tv1cw->merge(-2); - tv1r->merge(-2); // [No,Ko,i,wy,wx,v] mma_utils::scheduleContiguousVectorLoad(tv1cw, gemm_tile, 8); - mma_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); tv1cw->setMemoryType(MemoryType::Shared); // Schedule mma input // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - - // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0cr->applyMmaSwizzle(MmaOperand::A); + // [... Mi, Ni, Ki] tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv0b->applyMmaSwizzle(MmaOperand::A); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1cr->applyMmaSwizzle(MmaOperand::B); + tv1b->applyMmaSwizzle(MmaOperand::B); // Schedule mma output // --------------------------------------------------------------------------- - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - // Inline the reshape op with the shared mem write minus - // the vectorization axes for now. - tv0_reshape->computeAt(tv0cw, -2); + tv2c->applyMmaSwizzle(MmaOperand::Accumulator); + tv2->applyMmaSwizzle(MmaOperand::Accumulator); // Parallelize // 0 1 2 3 4 5 6 7 8 9 10 @@ -2597,11 +1177,10 @@ TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { tv2->axis(3)->parallelize(ParallelType::TIDy); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, Ko, Ki}, options); + auto t0 = at::randn({M, K}, options); auto t1 = at::randn({N, K}, options); FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, 0, @@ -2609,154 +1188,143 @@ TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = - at::native::view(t0, {M, K}).to(at::kFloat).matmul(t1.t().to(at::kFloat)); + auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } -// Initial test case for in-CTA split K with VoltaMMA -TEST_F(NVFuserTest, FusionVoltaMatmulTNCrossWarp_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); +TEST_F(NVFuserTest, FusionAmpereStridedBatchedMatmulTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); - int M = 120, N = 264, K = 120; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); + int64_t M = 511, N = 123, K = 88, B0 = 3, B1 = 5; + // [B0 ,M, B1, K] + auto tv0 = makeContigTensor(4, DataType::Half); + // [B0, N, B1, K] + auto tv1 = makeContigTensor(4, DataType::Half); fusion.addInput(tv0); fusion.addInput(tv1); - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); + // [B0, M, N, B1, K] + auto tv0b = broadcast(tv0, {false, false, true, false, false}); + auto tv1b = broadcast(tv1, {false, true, false, false, false}); // Leaving both sets of mma inputs for volta outside // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + auto tv2 = fusedMultiplySum(tv0b, tv1b, {4}); fusion.addOutput(tv2); MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 16); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_ops = ir_utils::getOpsOfType(&fusion); NVF_CHECK( 1 == mma_ops.size(), "Invalid number of MmaOp instances in fusion definition, expected 1, got ", mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); + mma_ops.front()->setMacro(MmaMacro::Ampere_16_8_16); auto tv0r = tv0->cacheAfter(); auto tv1r = tv1->cacheAfter(); - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(); + auto tv0cw = tv0r->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2c = tv2->cacheBefore(); + // Group the BATCHED DIMS: + // -4 -3 -2 -1 + // [B0, M, N, B1] + tv2->reorder({{-3, -2}, {-2, -1}, {-1, -4}}); + + // -4 -3 -2 -1 + // [B0, B1, M, N] + // Make a CTA tile // ------------------------------------------------------------------ - // [M,N] + // [B0, B1, M, N] tv2->split(-2, gemm_tile.cta_tile.m); tv2->split(-1, gemm_tile.cta_tile.n); - // 0 1 2 3 - // [Mo,M128, No, N128] - tv2->reorder({{1, 2}, {2, 1}}); + // 0 1 2 3 4 5 + // [B0, B1, Mo, M128, No, N128] + tv2->reorder({{-3, -2}, {-2, -3}}); - // 0 1 2 3 - // [Mo,No, M128, N128] + // 0 1 2 3 4 5 + // [B0, B1, Mo, No, M128, N128] + + // Merge the outer dims: + tv2->merge(0); + tv2->merge(0); + + // 0 1 2 3 + // [Mo, No, M128, N128] tv0->computeAt(tv2, 2); tv1->computeAt(tv2, 2); // Order K - // 0 1 2 3 4 5 - // [Mo,No, M128, N128, Ko, K32] + // 0 1 2 3 4 5 + // [Mo, No, M128, N128, Ko, K32] tv2c->split(-1, gemm_tile.cta_tile.k); tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); - // 0 1 2 3 4 5 - // [Mo,No, Ko M128, N128, K32] + // 0 1 2 3 4 5 + // [Mo, No, Ko, M128, N128, K32] tv0r->computeAt(tv2c, 3); tv1r->computeAt(tv2c, 3); // Make warp tile: // ------------------------------------------------------------------------- mma_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); - auto tv2c_rf = tv2c->rFactor({-9, -4, -1}); - - // tv2c_rf is the actual output of the mma op after - // Rfactoring. - mma_builder.accumulatorTv(tv2c_rf); - mma_utils::scheduleWarpTileWithNoReduction(tv2, gemm_tile); - - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] - tv0cr->computeAt(tv2c_rf, -4); - tv1cr->computeAt(tv2c_rf, -4); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); // Schedule gmem read and smem write: // --------------------------------------------------------------------------- - // [Mo,No,Ko,M,N,K] - tv0cw->reorder({ - {-3, -2}, - {-2, -3}, - }); - // [Mo,No,Ko,N,M,K] + // [Mo, Ko, M, K] tv0cw->merge(-2); tv0r->merge(-2); mma_utils::scheduleContiguousVectorLoad(tv0cw, gemm_tile, 8); mma_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); tv0cw->setMemoryType(MemoryType::Shared); - // [Mo,Ko,i,wy,wx,v] + // [Mo, Ko, i, wy, wx, v] - // [Mo,No,Ko,M,N,K] + // [No, Ko, N, K] tv1cw->merge(-2); tv1r->merge(-2); - // [Mo,No,Ko,i,wy,wx,v] + // [No, Ko, i, wy, wx, v] mma_utils::scheduleContiguousVectorLoad(tv1cw, gemm_tile, 8); mma_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); tv1cw->setMemoryType(MemoryType::Shared); // Schedule mma input // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv0cr->applyMmaSwizzle(MmaOperand::A); + + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(MmaOperand::A); + + tv1cr->applyMmaSwizzle(MmaOperand::B); + tv1b->applyMmaSwizzle(MmaOperand::B); // Schedule mma output // --------------------------------------------------------------------------- - tv2c_rf->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0b->computeAt(tv0cw, -2); - tv1b->computeAt(tv1cw, -2); - - tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); - tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); - // Parallelize - // 0 1 2 3 4 5 6 7 8 9 10 - // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] - tv2c_rf->axis(0)->parallelize(ParallelType::BIDx); - tv2c_rf->axis(1)->parallelize(ParallelType::BIDy); - tv2c_rf->axis(3)->parallelize(ParallelType::TIDz); - tv2c_rf->axis(4)->parallelize(ParallelType::TIDy); + tv2c->applyMmaSwizzle(MmaOperand::Accumulator); + tv2->applyMmaSwizzle(MmaOperand::Accumulator); - tv2c->axis(2)->parallelize(ParallelType::TIDz); - tv2c->axis(3)->parallelize(ParallelType::TIDy); + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)] + tv2c->axis(4)->parallelize(ParallelType::TIDz); + tv2c->axis(5)->parallelize(ParallelType::TIDy); // Parallelize // 0 1 2 3 4 5 6 7 @@ -2764,36 +1332,56 @@ TEST_F(NVFuserTest, FusionVoltaMatmulTNCrossWarp_CUDA) { tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::BIDy); tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); + auto t0 = at::randn({B0, M, B1, K}, options); + auto t1 = at::randn({B0, N, B1, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams); + + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); + auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + + // ref implementation: + auto ref_t0 = t0.permute({0, 2, 1, 3}) + .contiguous() + .view({B0 * B1, M, K}); // B0, B1, M, K + auto ref_t1 = t1.permute({0, 2, 3, 1}) + .contiguous() + .view({B0 * B1, K, N}); // B0, B1, K, N + auto ref_permuted = + ref_t0.to(at::kFloat).bmm(ref_t1.to(at::kFloat)); // B0*B1, M,N + auto ref = ref_permuted.view({B0, B1, M, N}) + .permute({0, 2, 3, 1}) + .contiguous(); // B0,M,N,B1 + NVF_CHECK(cg_outputs[0].allclose(ref, 0.0001, 0.0001)); } -// Initial test case for cross-CTA split K with VoltaMMA -TEST_F(NVFuserTest, FusionVoltaMatmulTNCrossCTA_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); +// Matmul test on Ampere with a reshape on prolog +TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); - int M = 120, N = 264, K = 120; + int M = 511, N = 257, K = 88; + int Ko = 11, Ki = 8; - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); + // [M,Ko,Ki] + auto tv0 = makeContigTensor(3, DataType::Half); // [N,K] auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); fusion.addInput(tv1); + auto tv0_reshape = reshape(tv0, {M, Ko, Ki}, {M, K}); + // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); + auto tv0b = broadcast(tv0_reshape, {false, true, false}); auto tv1b = broadcast(tv1, {true, false, false}); // Leaving both sets of mma inputs for volta outside @@ -2805,24 +1393,21 @@ TEST_F(NVFuserTest, FusionVoltaMatmulTNCrossCTA_CUDA) { MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 16, 4); - - auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_ops = ir_utils::getOpsOfType(&fusion); NVF_CHECK( 1 == mma_ops.size(), "Invalid number of MmaOp instances in fusion definition, expected 1, got ", mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); + mma_ops.front()->setMacro(MmaMacro::Ampere_16_8_16); auto tv0r = tv0->cacheAfter(); auto tv1r = tv1->cacheAfter(); - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(); + auto tv0cw = tv0_reshape->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2c = tv2->cacheBefore(); // Make a CTA tile @@ -2844,91 +1429,65 @@ TEST_F(NVFuserTest, FusionVoltaMatmulTNCrossCTA_CUDA) { // 0 1 2 3 4 5 // [Mo,No, M128, N128, Ko, K32] tv2c->split(-1, gemm_tile.cta_tile.k); - tv2c->split(-2, 2, true); - // Order K - // 0 1 2 3 4 5 6 - // [Mo,No, M128, N128, Ko, K2CTA, K32] - tv2c->reorder({{2, 4}, {3, 5}, {4, 3}, {5, 2}}); - // 0 1 2 3 4 5 6 - // [Mo,No, K2CTA, Ko M128, N128, K32] - tv0r->computeAt(tv2c, 4); - tv1r->computeAt(tv2c, 4); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); // Make warp tile: // ------------------------------------------------------------------------- mma_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); - // -9 -8 -7 -6 -5 -4 -3 -2 -1 - // [Mo No K2CTA Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] - auto tv2c_rf = tv2c->rFactor({-9, -8, -1}); - - // tv2c_rf is the actual output of the mma op after - // Rfactoring. - mma_builder.accumulatorTv(tv2c_rf); - mma_utils::scheduleWarpTileWithNoReduction(tv2, gemm_tile); - - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mo No K2CTA Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] - tv0cr->computeAt(tv2c_rf, -4); - tv1cr->computeAt(tv2c_rf, -4); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); // Schedule gmem read and smem write: // --------------------------------------------------------------------------- - // [Mo,No,Ko,M,N,K] - tv0cw->reorder({ - {-3, -2}, - {-2, -3}, - }); - // [Mo,No,Ko,N,M,K] + // [Mo,Ko,M,K] tv0cw->merge(-2); tv0r->merge(-2); + tv0_reshape->merge(-2); mma_utils::scheduleContiguousVectorLoad(tv0cw, gemm_tile, 8); mma_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); tv0cw->setMemoryType(MemoryType::Shared); // [Mo,Ko,i,wy,wx,v] - // [Mo,No,Ko,M,N,K] + // [No,Ko,N,K] tv1cw->merge(-2); tv1r->merge(-2); - // [Mo,No,Ko,i,wy,wx,v] + // [No,Ko,i,wy,wx,v] mma_utils::scheduleContiguousVectorLoad(tv1cw, gemm_tile, 8); mma_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); tv1cw->setMemoryType(MemoryType::Shared); // Schedule mma input // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv0cr->applyMmaSwizzle(MmaOperand::A); + + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(MmaOperand::A); + + tv1cr->applyMmaSwizzle(MmaOperand::B); + tv1b->applyMmaSwizzle(MmaOperand::B); // Schedule mma output // --------------------------------------------------------------------------- - tv2c_rf->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0b->computeAt(tv0cw, -2); - tv1b->computeAt(tv1cw, -2); - - tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); - tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + tv2c->applyMmaSwizzle(MmaOperand::Accumulator); + tv2->applyMmaSwizzle(MmaOperand::Accumulator); + + // Inline the reshape op with the shared mem write minus + // the vectorization axes for now. + tv0_reshape->computeAt(tv0cw, -2); + // Parallelize - // 0 1 2 3 4 5 6 7 8 9 10 11 - // [Mo No K2CTA Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki] - tv2c_rf->axis(0)->parallelize(ParallelType::BIDx); - tv2c_rf->axis(1)->parallelize(ParallelType::BIDy); - tv2c_rf->axis(2)->parallelize(ParallelType::BIDz); - tv2c_rf->axis(5)->parallelize(ParallelType::TIDz); - tv2c_rf->axis(6)->parallelize(ParallelType::TIDy); - - // 0 1 2 3 4 5 6 7 8 - // [Mo No K2CTA Mwo Nwo Mw Nw Mi Ni] - tv2c->axis(0)->parallelize(ParallelType::BIDx); - tv2c->axis(1)->parallelize(ParallelType::BIDy); - tv2c->axis(2)->parallelize(ParallelType::BIDz); - tv2c->axis(3)->parallelize(ParallelType::TIDz); - tv2c->axis(4)->parallelize(ParallelType::TIDy); + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)] + tv2c->axis(4)->parallelize(ParallelType::TIDz); + tv2c->axis(5)->parallelize(ParallelType::TIDy); // Parallelize // 0 1 2 3 4 5 6 7 @@ -2939,13 +1498,21 @@ TEST_F(NVFuserTest, FusionVoltaMatmulTNCrossCTA_CUDA) { tv2->axis(3)->parallelize(ParallelType::TIDy); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); + auto t0 = at::randn({M, Ko, Ki}, options); auto t1 = at::randn({N, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams); + + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); + auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); + + auto tref = + at::native::view(t0, {M, K}).to(at::kFloat).matmul(t1.t().to(at::kFloat)); + NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } @@ -2975,10 +1542,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { auto tv0b = broadcast(tv0, {false, true, false}); auto tv1b = broadcast(tv1, {true, false, false}); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); fusion.addOutput(tv2); @@ -2988,7 +1551,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { 1 == mma_ops.size(), "Invalid number of MmaOp instances in fusion definition, expected 1, got ", mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); + mma_ops.front()->setMacro(MmaMacro::Turing_16_8_16); auto tv0cw = tv0->cacheAfter(LoadStoreOpType::CpAsync); auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); @@ -2996,8 +1559,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2c = tv2->cacheBefore(); - mma_builder.accumulatorTv(tv2c); - // Make a CTA tile // ------------------------------------------------------------------ // [M,N] @@ -3082,20 +1643,18 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { tv1cw->setMemoryType(MemoryType::Shared); // Schedule mma input - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv0cr->applyMmaSwizzle(MmaOperand::A); // [... Mi, Ni, Ki] tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv0b->applyMmaSwizzle(MmaOperand::A); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1cr->applyMmaSwizzle(MmaOperand::B); + tv1b->applyMmaSwizzle(MmaOperand::B); // Schedule mma output - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2c->applyMmaSwizzle(MmaOperand::Accumulator); + tv2->applyMmaSwizzle(MmaOperand::Accumulator); // Parallelize // 0 1 2 3 4 5 6 7 8 9 10 @@ -3134,7 +1693,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { REQUIRE_DEVICE_SMEM_SIZE(98384, 0); // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -3152,7 +1711,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 64); gemm_tile.instruction_tile = GemmTile(16, 16, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16; + params.mma_macro = MmaMacro::Ampere_16_16_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -3184,7 +1743,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -3203,7 +1762,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 16, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Turing_16_16_16; + params.mma_macro = MmaMacro::Turing_16_16_16; params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); @@ -3231,7 +1790,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { REQUIRE_DEVICE_SMEM_SIZE(98384, 0); // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { // Symmetric tile with 16x16x16 macro, // supports mn_size of multiple of 32, // and k size multiple of 16. @@ -3255,7 +1814,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 16, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16; + params.mma_macro = MmaMacro::Ampere_16_16_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -3298,7 +1857,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { REQUIRE_DEVICE_SMEM_SIZE(98384, 0); // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { // ASymmetric tile with 16x16x16 macro, for (int m_size : {256}) { for (int n_size : {32, 64, 96, 128}) { @@ -3321,7 +1880,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 16, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16; + params.mma_macro = MmaMacro::Ampere_16_16_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -3364,7 +1923,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { REQUIRE_DEVICE_SMEM_SIZE(98384, 0); // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { for (int k_size : {32, 48, 64}) { Fusion fusion; FusionGuard fg(&fusion); @@ -3385,7 +1944,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 16, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16; + params.mma_macro = MmaMacro::Ampere_16_16_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -3422,7 +1981,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 2048; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -3441,7 +2000,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 16, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16; + params.mma_macro = MmaMacro::Ampere_16_16_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -3468,401 +2027,12 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) { } } -// MMA and alpha unit test, for Ampere TN -TEST_F(NVFuserTest, FusionAmpereMMATNAlpha_CUDA) { - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - auto s0 = IrBuilder::create(DataType::Double); - // [M,K] - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [N,K] - auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); - fusion.addInput(s0); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - auto tv3 = mul(s0, tv2); - - fusion.addOutput(tv3); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = - tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); - auto tv3c = tv3->cacheBefore(); - - mma_builder.accumulatorTv(tv2); - - // [M,N,K] -> [N,M,K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv3c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv3->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - at::manual_seed(0); - const double alpha = 1.5; - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({8, 16}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, {alpha, t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({alpha, t0, t1}); - - auto t2 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - auto tref = t2.mul(alpha); - - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - -// MMA and alpha + beta unit test, for Ampere TN -TEST_F(NVFuserTest, FusionAmpereMMATNAlphaBeta_CUDA) { - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - // alpha - auto s0 = IrBuilder::create(DataType::Double); - // beta - auto s1 = IrBuilder::create(DataType::Double); - - // [M,K] - A - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [N,K] - B - auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); - // [M,N] - C - auto tv2 = makeConcreteTensor({16, 8}, DataType::Half); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - // ops: tv4 := alpha * (A x B) - auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - auto tv4 = mul(s0, tv3); - - // ops: tv5 := beta * C - auto tv5 = mul(s1, tv2); - // ops: tv6 := alpha * (A x B) + beta * C - auto tv6 = add(tv4, tv5); - - fusion.addInput(s0); - fusion.addInput(s1); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - fusion.addOutput(tv6); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = - tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); - auto tv6c = tv6->cacheBefore(); - - mma_builder.accumulatorTv(tv3); - - // [M,N,K] -> [N,M,K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - - // mma output := A x B - tv3->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - // alpha scaling result := alpha * (A x B) - tv4->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - // beta scaling result := beta * C - tv5->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - // final result cache := alpha * (A x B) + beta * C - tv6c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - // final result := alpha * (A x B) + beta * C - tv6->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - at::manual_seed(0); - const double alpha = 1.5; - const double beta = 1.5; - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({8, 16}, options); - auto t2 = at::randn({16, 8}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, {alpha, beta, t0, t1, t2}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({alpha, beta, t0, t1, t2}); - - auto t3 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - auto t4 = t3.mul(alpha); - - auto t5 = t2.to(at::kFloat).mul(beta); - auto t6 = t4.add(t5); - - NVF_CHECK(cg_outputs[0].allclose(t6, 0.0001, 0.0001)); -} - -// MMA and bias epilogue unit test, for Ampere TN -TEST_F(NVFuserTest, FusionAmpereMMATNBias_CUDA) { - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - // [M,K] - A - auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); - // [N,K] - B - auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); - // [M] - bias - auto tv2 = makeConcreteTensor({16}, DataType::Half); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - // [M,N] - auto tv2b = broadcast(tv2, {false, true}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - // ops: tv3 := A x B - auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); - - // ops: tv4 := (A x B) + bias - auto tv4 = add(tv3, tv2b); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - fusion.addOutput(tv4); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaLayout::TN); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = - tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); - auto tv4c = tv4->cacheBefore(); - - mma_builder.accumulatorTv(tv3); - - // [M,N,K] -> [N,M,K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - - // mma output := A x B - tv3->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - // bias result cache := (A x B) + bias - tv4c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - // bias result := (A x B) + bias - tv4->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({16, 16}, options); - auto t1 = at::randn({8, 16}, options); - auto t2 = at::randn({16}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1, t2}); - - auto t3 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - auto t4 = atBiasEpilogue(t3, t2); - - NVF_CHECK(cg_outputs[0].allclose(t4, 0.0001, 0.0001)); -} - -// Strided batch gemm with MMA unit test, for Ampere TN -TEST_F(NVFuserTest, FusionAmpereMMATNSplitKLikeStridedBatch_CUDA) { - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - const int64_t M = 16, N = 8, K = 16, B = 2; - const auto layout = MmaOptions::MmaLayout::TN; - - // [M, B, K] - auto tv0 = makeConcreteTensor({M, B, K}, DataType::Half); - // [N, B, K] - auto tv1 = makeConcreteTensor({N, B, K}, DataType::Half); - - // Note: following lines are similar to TN handling in - // 'splitkLikeBatchedMatmul(..)' - // [M, B, K] -> [B, M, K] - auto tv0t = transpose(tv0, 0, 1); - // [N, B, K] -> [B, N, K] - auto tv1t = transpose(tv1, 0, 1); - - // [B, M, N, K] - auto tv0b = broadcast(tv0t, {false, false, true, false}); - auto tv1b = broadcast(tv1t, {false, true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {-1}); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(16, 8, 16); - gemm_tile.warp_tile = GemmTile(16, 8, 16); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(layout); - - auto mma_ops = ir_utils::getOpsOfType(&fusion); - NVF_CHECK( - 1 == mma_ops.size(), - "Invalid number of MmaOp instances in fusion definition, expected 1, got ", - mma_ops.size()); - mma_builder.configureMma(mma_ops.front()); - - auto tv0cw = tv0b->cacheAfter(); - auto tv0cr = - tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); - auto tv1cw = tv1b->cacheAfter(); - auto tv1cr = - tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); - auto tv2c = tv2->cacheBefore(); - - mma_builder.accumulatorTv(tv2c); - - // [B, M, N, K] -> [B, N, M, K] - tv0cr->reorder({{-2, -3}, {-3, -2}}); - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - - tv0cw->setMemoryType(MemoryType::Shared); - tv1cw->setMemoryType(MemoryType::Shared); - - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, B, K}, options); - auto t1 = at::randn({N, B, K}, options); - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); - auto cg_outputs = fe.runFusion({t0, t1}); - auto tref = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); - - NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - // Matmul test for Ampere MMA: across supported layouts TEST_F(NVFuserTest, FusionAmpereSplitKLikeStridedBatchedMatmul_CUDA) { // Keep multiples of 8 to keep vectorizable. int B = 2, M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(3, DataType::Half); @@ -3881,7 +2051,7 @@ TEST_F(NVFuserTest, FusionAmpereSplitKLikeStridedBatchedMatmul_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -3910,7 +2080,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { constexpr bool ignore_occupancy_drop = true; // Keep multiples of 8 to keep vectorizable. int M = 4096, N = 4096, K = 4096; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -3934,7 +2104,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -4042,7 +2212,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueCast_CUDA) { constexpr bool ignore_occupancy_drop = true; // Keep multiples of 8 to keep vectorizable. int M = 4096, N = 4096, K = 4096; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -4062,7 +2232,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueCast_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -4129,7 +2299,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueRelu_CUDA) { constexpr bool ignore_occupancy_drop = true; // Keep multiples of 8 to keep vectorizable. int M = 4096, N = 4096, K = 4096; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -4149,7 +2319,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueRelu_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; @@ -4222,7 +2392,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSplitK_CUDA) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 8096; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -4241,7 +2411,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSplitK_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; params.splitk_factor = 2; scheduleMatmul(&fusion, params); @@ -4271,7 +2441,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSplitKBias_CUDA) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 8096; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -4294,7 +2464,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSplitKBias_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; params.splitk_factor = 2; scheduleMatmul(&fusion, params); @@ -4327,7 +2497,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBatchSplitK_CUDA) { // Keep multiples of 8 to keep vectorizable. int B = 2, M = 504, N = 136, K = 2048; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(3, DataType::Half); @@ -4346,7 +2516,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBatchSplitK_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; params.splitk_factor = 2; scheduleMatmul(&fusion, params); @@ -4380,7 +2550,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBatchSplitKBias_CUDA) { // Keep multiples of 8 to keep vectorizable. int B = 2, M = 504, N = 136, K = 2048; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(3, DataType::Half); @@ -4403,7 +2573,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBatchSplitKBias_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); MatmulParams params; - params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16; + params.mma_macro = MmaMacro::Ampere_16_8_16; params.tile_sizes = gemm_tile; params.splitk_factor = 2; scheduleMatmul(&fusion, params); diff --git a/test/test_gpu_utils.cpp b/test/test_gpu_utils.cpp index 0c207226c24..0b71103dc32 100644 --- a/test/test_gpu_utils.cpp +++ b/test/test_gpu_utils.cpp @@ -1077,7 +1077,7 @@ TEST_F(NVFuserTest, FusionSASSDumpError_CUDA) { ::testing::HasSubstr("I am fake"))); auto cg_outputs = fe.runFusion({t0}); - testValidate(fe.kernel(), cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(fe.kernel(), cg_outputs, {t0}, __LINE__, __FILE__); } } // namespace nvfuser diff --git a/test/test_gpu_view.cpp b/test/test_gpu_view.cpp index 0f6573d68d5..f4b334c769e 100644 --- a/test/test_gpu_view.cpp +++ b/test/test_gpu_view.cpp @@ -78,10 +78,7 @@ TEST_F(GpuViewTest, FusionViewDtypeSameSizeOutput) { fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); - auto at_x_add_bias = at_x + at_bias; - auto at_x_view = at_x_add_bias.view(at::ScalarType::Int); - - testValidate(&fusion, outputs, aten_inputs, {at_x_view}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionViewDtypeFailMismatchSize) { @@ -144,12 +141,7 @@ TEST_F(GpuViewTest, FusionViewAsRealOutput) { fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); - auto at_x_add_bias = at_x + at_bias; - auto at_x_view = at::view_as_real(at_x_add_bias); - auto at_y_plus_1 = at_y + 1.0; - auto at_out = at_y_plus_1 + at_x_view; - - testValidate(&fusion, outputs, aten_inputs, {at_out}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeRfactorExtentReplacement) { @@ -174,10 +166,8 @@ TEST_F(GpuViewTest, FusionReshapeRfactorExtentReplacement) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1}); - auto ref = at::native::view(t0, {4, 3, 8}).sum({-1}) + 1 + t1; - testValidate( - executor_cache.fusion(), cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); + executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeOutput) { @@ -207,11 +197,7 @@ TEST_F(GpuViewTest, FusionReshapeOutput) { fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); - auto at_x_add_bias = at_x + at_bias; - auto at_x_reshape = at::native::view(at_x_add_bias, output_shape); - - testValidate( - &fusion, outputs, aten_inputs, {at_x_reshape}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeFailMismatchSize) { @@ -303,14 +289,7 @@ void reductionViewAddFusion( FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr)); auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); - auto at_tv1 = (reshape_before_reduction) ? (at_x + at_bias) - : at::sum(at_x, kReductionAxis); - auto at_x_reshape = at::native::view(at_tv1, output_shape); - auto at_y = (reshape_before_reduction) - ? at::sum(at_x_reshape, kReductionAxis) - : at::add(at_x_reshape, at_bias); - - testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } } @@ -443,15 +422,7 @@ void persistentViewAddFusion( FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr)); auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); - auto at_tv1 = (reshape_before_persistent) - ? (at_x + at_bias) - : at::_softmax(at_x, kAxis, false /* half_to_float */); - auto at_x_reshape = at::native::view(at_tv1, inferred_output); - auto at_y = (reshape_before_persistent) - ? at::_softmax(at_x_reshape, kAxis, false /* half_to_float */) - : at::add(at_x_reshape, at_bias); - - testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } } @@ -505,11 +476,7 @@ void addViewGeluFusion( fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); - auto at_x_add_bias = at_x + at_bias; - auto at_x_reshape = at::native::view(at_x_add_bias, output_shape); - auto at_y = at::gelu(at_x_reshape); - - testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } } @@ -579,11 +546,7 @@ void geluViewAddFusion( fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); - auto at_x_gelu = at::gelu(at_x); - auto at_x_reshape = at::native::view(at_x_gelu, inferred_output); - auto at_y = at_x_reshape + at_bias; - - testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } } @@ -627,12 +590,7 @@ void geluViewBinaryAddFusion( fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); - auto at_x_gelu = at::gelu(at_x); - auto at_x_reshape = at::native::view(at_x_gelu, output_shape); - auto at_bias_reshape = at::native::view(at_bias, output_shape); - auto at_y = at_x_reshape + at_bias_reshape; - - testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } } @@ -669,9 +627,7 @@ TEST_F(GpuViewTest, FusionReshapeConcreteDomain) { fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); - auto ref = (at::native::view(t0, {6}) + 1).unsqueeze(0) + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t1}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeConcreteDomain2) { @@ -701,11 +657,7 @@ TEST_F(GpuViewTest, FusionReshapeConcreteDomain2) { FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr)); auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); - auto at_tv1 = at::_softmax(at_x, kAxis, false /* half_to_float */); - auto at_x_reshape = at::native::view(at_tv1, output_shape); - auto at_y = at::add(at_x_reshape, at_bias); - - testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } // Repro of issue #1608 @@ -741,12 +693,7 @@ TEST_F(GpuViewTest, FusionReshapeConcreteDomain3) { FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr)); auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); - auto at_tv1 = at::add(at_x, at_y); - auto at_tv2 = at::native::view(at_tv1, output_shape); - auto at_tv3 = at::native::view(at_z, output_shape); - auto at_output = at::add(at_tv2, at_tv3); - - testValidate(&fusion, outputs, aten_inputs, {at_output}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeConcreteDomain4) { @@ -893,11 +840,7 @@ TEST_F(GpuViewTest, FusionFlattenAfterUnsqueezeOutput) { fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); - auto at_x_add_bias = at_x + at_bias; - auto at_x_reshape = at_x_add_bias.unsqueeze(-1).flatten(); - - testValidate( - &fusion, outputs, aten_inputs, {at_x_reshape}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionComputeAtRootDomainMapWithView) { @@ -962,13 +905,11 @@ TEST_F(GpuViewTest, FusionExpandRepro) { LaunchParams l_params; auto outputs = fe.runFusion(aten_inputs, {}, l_params, {}); - auto out = at_x.expand_as(at_y); - - testValidate(&fusion, outputs, aten_inputs, {out}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); // second run to verify cached output allocation outputs = fe.runFusion(aten_inputs, {}, l_params, {}); - testValidate(&fusion, outputs, aten_inputs, {out}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionExpandView1) { @@ -998,10 +939,8 @@ TEST_F(GpuViewTest, FusionExpandView1) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1}); - auto ref = at::reshape(t0.expand({4, 3, 8}), {12, 8}) + t1; - testValidate( - executor_cache.fusion(), cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); + executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionExpandView2) { @@ -1028,10 +967,8 @@ TEST_F(GpuViewTest, FusionExpandView2) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1}); - auto ref = at::reshape(t0.expand({12, 8}), {3, 4, 8}) + t1; - testValidate( - executor_cache.fusion(), cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); + executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeTransformCache) { @@ -1253,6 +1190,12 @@ TEST_F(GpuViewTest, FusionReshapeVectorize) { fusion.addOutput(tv4); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // This test allocates about 1GB of memory, so in order to avoid an OOM during + // this test, we manually clear the allocator after it's reached a certain + // threshold. + maybeClearAllocator(); + at::Tensor input = at::randn({256, 1024, 1024}, options); auto lparams = schedulePointwise(&fusion, {input}); @@ -1279,9 +1222,7 @@ TEST_F(GpuViewTest, FusionReshapeVectorize) { fe.compileFusion(&fusion, {input}, lparams); auto outputs = fe.runFusion({input}, lparams); - auto tv_ref = input.flatten(1, 2).sin(); - - testValidate(&fusion, outputs, {input}, {tv_ref, tv_ref}, __LINE__, __FILE__); + testValidate(&fusion, outputs, {input}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionExpandFlatten) { @@ -1308,15 +1249,8 @@ TEST_F(GpuViewTest, FusionExpandFlatten) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs({input}); - auto aten_out = input.expand({256, 1024, 8}).flatten(1, 2).sum(1); - testValidate( - executor_cache.fusion(), - cg_outputs, - {input}, - {aten_out}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, {input}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionIllegalReductionFlatten) { @@ -1349,13 +1283,11 @@ TEST_F(GpuViewTest, FusionReductionFlatten1) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn({2, 3, 5}, options); - auto ref = t0.sum({1}).flatten(0, 1); FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__); + testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionPwiseViewSchedule) { @@ -1412,16 +1344,12 @@ TEST_F(GpuViewTest, FusionPwiseViewSchedule) { at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t3 = at::randn({x, y, z}, options); - auto t1 = sin(t0); - auto t2 = at::native::view(t1, {x, y * z}); - auto t4 = at::native::view(t3, {x, y * z}); - auto t5 = t0 + t3; FusionExecutor fe; fe.compileFusion(&fusion, {t0, t3}); auto cg_outputs = fe.runFusion({t0, t3}); - testValidate(&fusion, cg_outputs, {t0, t3}, {t2, t4, t5}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionSumViewSchedule) { @@ -1521,16 +1449,12 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule1) { at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t3 = at::randn({x, y, z}, options); - auto t1 = sin(t0); - auto t2 = at::native::view(t1, {x, y * z}); - auto t4 = at::native::view(t3, {x, y * z}); - auto t5 = t0 + t3; FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3}); NVF_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); - testValidate(&fusion, cg_outputs, {t0, t3}, {t2, t4, t5}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3}, __LINE__, __FILE__); } // Make sure reshapes of reshapes are correct @@ -1555,7 +1479,6 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule2) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x, y, z}, options); - auto aten_out = sin(t0); // For now pointwise scheduler only accepts a single reshape at a time, so // this will be broken up into multiple kernels. This is due to the reference @@ -1564,7 +1487,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule2) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - testValidate(&fusion, cg_outputs, {t0}, {aten_out}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Make sure broadcasts not on the reshape path that don't interfere with @@ -1612,12 +1535,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule3) { at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t3 = at::randn({x, y, z}, options); - auto t1 = sin(t0); - auto t2 = at::native::view(t1, {x, y * z}); - auto t4 = at::native::view(t3, {x, y * z}); - auto t5 = t0 + t3; at::Tensor t6 = at::randn({w, x, y, z}, options); - auto t8 = t6.add(t0); FusionExecutorCache executor_cache(std::move(fusion_ptr)); // Collect the heuristic params @@ -1631,8 +1549,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule3) { executor_cache.getMostRecentExecutorInfo().params->as(); NVF_CHECK(pparams->break_point == 1); - testValidate( - &fusion, cg_outputs, {t0, t3, t6}, {t2, t4, t5, t8}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3, t6}, __LINE__, __FILE__); } // Make sure broadcasts through reshapes when not conflicting with reshape are @@ -1673,11 +1590,6 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule4) { at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t3 = at::randn({x, y, z}, options); at::Tensor t4 = at::randn({x, 1, 1}, options); - auto t1 = sin(t0); - auto t2 = at::native::view(t1, {x, y * z}); - auto t5 = t4 + t3; - auto t6 = at::native::view(t5, {x, y * z}); - auto t7 = t0 + t3; FusionExecutorCache executor_cache(std::move(fusion_ptr)); // Collect the heuristic params @@ -1691,8 +1603,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule4) { executor_cache.getMostRecentExecutorInfo().params->as(); NVF_CHECK(pparams->break_point == 1); - testValidate( - &fusion, cg_outputs, {t0, t3, t4}, {t2, t6, t7}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3, t4}, __LINE__, __FILE__); } // Make sure different reshapes that are consumed by the reference are segmented @@ -1720,12 +1631,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule5) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({w, x, y * z}, options); - auto t1 = sin(t0); - auto t2 = at::native::view(t1, {z, y, x, w}); at::Tensor t3 = at::randn({w, x * y, z}, options); - auto t4 = cos(t3); - auto t5 = at::native::view(t4, {z, y, x, w}); - auto t6 = add(t2, t5); FusionExecutorCache executor_cache(std::move(fusion_ptr)); // Collect the heuristic params @@ -1736,7 +1642,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule5) { NVF_CHECK(executor_cache.getMostRecentExecutorInfo() .params->isA()); - testValidate(&fusion, cg_outputs, {t0, t3}, {t6}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3}, __LINE__, __FILE__); } // Test reshape/transpose and its impact on vectorization @@ -1763,11 +1669,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule6) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x, y}, options); - auto t1 = at::native::view(t0, {x, y / 2, 2}); - - auto t2 = t1.transpose(0, 1); at::Tensor t3 = at::randn({y / 2, x, 2}, options); - auto t4 = add(t2, t3); FusionExecutorCache executor_cache(std::move(fusion_ptr)); // Collect the heuristic params @@ -1785,7 +1687,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule6) { .params->as() ->unroll_factor); - testValidate(&fusion, cg_outputs, {t0, t3}, {t4}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3}, __LINE__, __FILE__); } // View with 3D reduction scheduling @@ -1813,12 +1715,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule7) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({w, v, x, y, z}, options); - auto t1 = sin(t0); - auto t2 = at::native::view(t1, {v * w, x, y * z}); at::Tensor t3 = at::randn({v, w, x, z, y}, options); - auto t4 = cos(t3); - auto t5 = at::native::view(t4, {v * w, x, y * z}); - auto t7 = add(t2, t5).sum(2).sum(0); FusionExecutorCache executor_cache(std::move(fusion_ptr)); // Collect the heuristic params @@ -1829,7 +1726,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule7) { NVF_CHECK(executor_cache.getMostRecentExecutorInfo() .params->isA()); - testValidate(&fusion, cg_outputs, {t0, t3}, {t7}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3}, __LINE__, __FILE__); } // View with 3D normalization scheduling @@ -1859,16 +1756,8 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule8) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({w, v, x, y, z}, options); - auto t1 = sin(t0); - auto t2 = at::native::view(t1, {v * w, x, y * z}); // This might trigger transpose kernel. at::Tensor t3 = at::randn({v, w, x, z, y}, options); - auto t4 = cos(t3); - auto t5 = at::native::view(t4, {v * w, x, y * z}); - auto t6 = add(t2, t5); - auto t7 = t6.sum(2).sum(0); - auto t8 = t7.unsqueeze(-1).unsqueeze(0); - auto t9 = t6 + t8; FusionExecutorCache executor_cache(std::move(fusion_ptr)); // Collect the heuristic params @@ -1879,7 +1768,7 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule8) { NVF_CHECK(executor_cache.getMostRecentExecutorInfo() .params->isA()); - testValidate(&fusion, cg_outputs, {t0, t3}, {t9}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3}, __LINE__, __FILE__); } // AlbertForMaskedLM repro https://github.com/csarofeen/pytorch/issues/2066 @@ -1929,32 +1818,10 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule9) { auto t3 = at::randn({2, 512}, options); auto t4 = at::randn({2, 512, 128}, options); - auto t5 = t0.unsqueeze(0).unsqueeze(0); - auto t6 = t1.unsqueeze(-1); - auto t7 = t2.unsqueeze(0).unsqueeze(0); - auto t8 = t3.unsqueeze(-1); - auto t9 = t6; - - auto t11 = t8.abs().add(1.e-12); - auto t12 = t4.sub(t9); - auto t13 = t11.rsqrt(); - auto t14 = t13; - auto t15 = t12.mul(t14); - auto t16 = t15.mul(t5); - auto t17 = t16.add(t7); - auto t18 = t17.to(at::kFloat); - auto t19 = at::native::view(t18, {x * y, z}); - FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1, t2, t3, t4}); - testValidate( - &fusion, - cg_outputs, - {t0, t1, t2, t3, t4}, - {t6, t13, t19}, - __LINE__, - __FILE__); + testValidate(&fusion, cg_outputs, {t0, t1, t2, t3, t4}, __LINE__, __FILE__); } // Simpler version of FusionReshapeMagicSchedule9_CUDA @@ -2006,13 +1873,11 @@ TEST_F(GpuViewTest, FusionReshapeMagicSchedule11) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn({1, x, y, z}, options); - auto t2 = at::native::view(t0, {1, x, y * z}); - auto t3 = at::native::view(t2, {x, y * z}); FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t3}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Make sure different reshapes that are consumed by the reference are segmented @@ -2057,18 +1922,13 @@ TEST_F(GpuViewTest, FusionReshapeMapping) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({w, x, y * z}, options); - auto t1 = sin(t0); - auto t2 = at::native::view(t1, {z, y, x, w}); at::Tensor t3 = at::randn({w, x * y, z}, options); - auto t4 = cos(t3); - auto t5 = at::native::view(t4, {z, y, x, w}); - auto t6 = add(t2, t5); FusionExecutor fe; fe.compileFusion(&fusion, {t0, t3}); auto cg_outputs = fe.runFusion({t0, t3}); - testValidate(&fusion, cg_outputs, {t0, t3}, {t6}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t3}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionLowerDivisibleSplits) { @@ -2289,16 +2149,9 @@ TEST_F(GpuViewTest, FusionIssue2076_v2) { at::Tensor t1 = at::randn({48, 128}, options); at::Tensor t2 = at::randn({4, 1, 128}, options); - auto t3 = t1.reshape({4, 12, 128}); - - // [4, 1, 128] - auto t4 = t0.add(t2); - auto t5 = t3.add(t4); - auto t6 = t5.reshape({48, 128}); - FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); - testValidate(&fusion, cg_outputs, {t0, t1, t2}, {t4, t6}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t1, t2}, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeZeroDimInput) { @@ -2331,9 +2184,7 @@ TEST_F(GpuViewTest, FusionReshapeZeroDimInput) { fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); - auto at_prod = at_x * at_y; - - testValidate(&fusion, outputs, aten_inputs, {at_prod}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeZeroDimOutput) { @@ -2372,9 +2223,7 @@ TEST_F(GpuViewTest, FusionReshapeZeroDimOutput) { fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); - auto at_prod = (at_x.squeeze() + at_y.squeeze()) * at_z; - - testValidate(&fusion, outputs, aten_inputs, {at_prod}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, FusionReshapeZeroDimInputOutput) { @@ -2409,9 +2258,7 @@ TEST_F(GpuViewTest, FusionReshapeZeroDimInputOutput) { fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); - auto at_prod = at_x * at_y; - - testValidate(&fusion, outputs, aten_inputs, {at_prod}, __LINE__, __FILE__); + testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(GpuViewTest, ReshapeOfReshape) { @@ -2462,20 +2309,12 @@ TEST_F(GpuViewTest, ExpandedBroadcast) { at::Tensor in_tensor = at::randn({4, 5}, at::dtype(at::kFloat).device(at::kCUDA, 0)); - at::Tensor expected_out_tensor = - in_tensor.unsqueeze(-1).expand({-1, -1, 6}).reshape({40, 3}); FusionExecutor fe; fe.compileFusion(&fusion, {in_tensor}); at::Tensor actual_out_tensor = fe.runFusion({in_tensor})[0]; - testValidate( - &fusion, - {actual_out_tensor}, - {in_tensor}, - {expected_out_tensor}, - __LINE__, - __FILE__); + testValidate(&fusion, {actual_out_tensor}, {in_tensor}, __LINE__, __FILE__); } } // namespace nvfuser diff --git a/test/test_linked_hash_map.cpp b/test/test_linked_hash_map.cpp index f81fefc090e..8261a358d7c 100644 --- a/test/test_linked_hash_map.cpp +++ b/test/test_linked_hash_map.cpp @@ -12,9 +12,55 @@ #include +namespace { +class CopyableKey { + public: + explicit CopyableKey(std::string data) : data_(std::move(data)) {} + + size_t hash() const { + return std::hash()(data_); + } + + bool operator==(const CopyableKey& other) const { + return data_ == other.data_; + } + + private: + std::string data_; +}; + +class MovableValue { + public: + explicit MovableValue(int data) : data_(data) {} + + MovableValue(const MovableValue&) = delete; + MovableValue& operator=(const MovableValue&) = delete; + + MovableValue(MovableValue&&) = default; + MovableValue& operator=(MovableValue&&) = default; + + int data() const { + return data_; + } + + private: + int data_; +}; +} // namespace + +namespace std { +template <> +struct hash { + size_t operator()(const CopyableKey& key) const { + return key.hash(); + } +}; +} // namespace std + namespace nvfuser { using testing::ElementsAre; +using testing::Eq; using testing::Pair; TEST(LinkedHashMapTest, PushBack) { @@ -72,4 +118,19 @@ TEST(LinkedHashMapTest, EraseThenPushBack) { EXPECT_THAT(map, ElementsAre(Pair("a", 1), Pair("b", 4))); } +namespace { +MATCHER_P(DataIs, data, "") { + return arg.data() == data; +} +} // namespace + +TEST(LinkedHashMapTest, MovableValue) { + LinkedHashMap map; + map.pushBack(CopyableKey("a"), MovableValue(1)); + map.pushBack(CopyableKey("b"), MovableValue(2)); + map.erase(CopyableKey("b")); + + EXPECT_THAT(map, ElementsAre(Pair(CopyableKey("a"), DataIs(1)))); +} + } // namespace nvfuser diff --git a/test/test_loop_rotation.cpp b/test/test_loop_rotation.cpp index 7b3e42d2f74..bdba541142b 100644 --- a/test/test_loop_rotation.cpp +++ b/test/test_loop_rotation.cpp @@ -79,7 +79,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor FusionExecutor fe; fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } } @@ -172,7 +172,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor FusionExecutor fe; fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } } @@ -281,7 +281,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor FusionExecutor fe; fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } } @@ -392,7 +392,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor FusionExecutor fe; fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } } @@ -529,7 +529,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor FusionExecutor fe; fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } } @@ -578,9 +578,22 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor i3 = toSmem(T4) + (12LL * i1); bool b4; b4 = (i1 + nvfuser_zero) < T0.logical_size[0LL]; + bool b5; + b5 = !b4; #pragma unroll - for(nvfuser_index_t i5 = 0; i5 < 3LL; ++i5) { - Ampere::cpAsyncCa((i3 + (4LL * i5)), (ptr2 + (T0.alloc_stride[1LL] * (i5 + nvfuser_zero))), b4); + for(nvfuser_index_t i6 = 0; i6 < 3LL; ++i6) { + asm volatile( + "{\n" + " .reg .pred p0; \n" + " setp.ne.b32 p0, %3, 0;\n" + " cp.async.ca.shared.global [%0], [%1], %2, p0;\n" + "}\n" + : + :"r"((uint32_t)((i3 + (4LL * i6)))), + "l"((ptr2 + (T0.alloc_stride[1LL] * (i6 + nvfuser_zero)))), + "n"(4LL), + "r"((uint32_t)(b5)) + ); } asm volatile("cp.async.commit_group;\n"); } @@ -590,45 +603,58 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1[0LL] = T4[0LL]; #pragma unroll 1 - for(nvfuser_index_t i6 = 0; i6 < T0.logical_size[0LL]; ++i6) { - float* ptr7; - ptr7 = ptr0 + (T0.alloc_stride[0LL] * i6); - nvfuser_index_t i8; - i8 = 4LL + i6; - unsigned i9; - i9 = toSmem(T4) + (12LL * (i8 % 5LL)); - nvfuser_index_t i10; - i10 = 1LL + (3LL * (i6 % 5LL)); + for(nvfuser_index_t i7 = 0; i7 < T0.logical_size[0LL]; ++i7) { + float* ptr8; + ptr8 = ptr0 + (T0.alloc_stride[0LL] * i7); + nvfuser_index_t i9; + i9 = 4LL + i7; + unsigned i10; + i10 = toSmem(T4) + (12LL * (i9 % 5LL)); nvfuser_index_t i11; - i11 = 3LL * i6; - bool b12; - b12 = i8 < T0.logical_size[0LL]; + i11 = 1LL + (3LL * (i7 % 5LL)); + nvfuser_index_t i12; + i12 = 3LL * i7; + bool b13; + b13 = i9 < T0.logical_size[0LL]; + bool b14; + b14 = !b13; #pragma unroll - for(nvfuser_index_t i5 = 0; i5 < 3LL; ++i5) { - Ampere::cpAsyncCa((i9 + (4LL * i5)), (ptr7 + (T0.alloc_stride[1LL] * (i5 + nvfuser_zero))), b12); + for(nvfuser_index_t i6 = 0; i6 < 3LL; ++i6) { + asm volatile( + "{\n" + " .reg .pred p0; \n" + " setp.ne.b32 p0, %3, 0;\n" + " cp.async.ca.shared.global [%0], [%1], %2, p0;\n" + "}\n" + : + :"r"((uint32_t)((i10 + (4LL * i6)))), + "l"((ptr8 + (T0.alloc_stride[1LL] * (i6 + nvfuser_zero)))), + "n"(4LL), + "r"((uint32_t)(b14)) + ); } NVFUSER_UPDATE_MAGIC_ZERO; asm volatile("cp.async.commit_group;\n"); #pragma unroll - for(nvfuser_index_t i13 = 0; i13 < 2LL; ++i13) { - T1[((1LL + i13) % 2LL)] - = T4[(i10 + i13)]; + for(nvfuser_index_t i15 = 0; i15 < 2LL; ++i15) { + T1[((1LL + i15) % 2LL)] + = T4[(i11 + i15)]; float T2[1LL]; T2[0LL] - = T1[(i13 % 2LL)]; - T3[(i11 + (i13 + nvfuser_zero))] + = T1[(i15 % 2LL)]; + T3[(i12 + (i15 + nvfuser_zero))] = T2[0LL]; } NVFUSER_UPDATE_MAGIC_ZERO; float T2[1LL]; T2[0LL] = T1[0LL]; - T3[(2LL + i11)] + T3[(2LL + i12)] = T2[0LL]; NVFUSER_UPDATE_MAGIC_ZERO; asm volatile("cp.async.wait_group %0;\n"::"n"(3LL)); T1[0LL] - = T4[(3LL * ((1LL + i6) % 5LL))]; + = T4[(3LL * ((1LL + i7) % 5LL))]; } } )"; @@ -640,7 +666,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor FusionExecutor fe; fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } } } // namespace nvfuser diff --git a/test/test_matmul_sass.cpp b/test/test_matmul_sass.cpp index bf8c6e530fb..27a52288394 100644 --- a/test/test_matmul_sass.cpp +++ b/test/test_matmul_sass.cpp @@ -38,11 +38,11 @@ class MatmulSASSTest : public NVFuserTest {}; namespace { sass::Container getSASSFor( - MatmulLayout layout, + MmaLayout layout, GemmTile cta_tile, GemmTile warp_tile, GemmTile instruction_tile, - MmaOptions::MacroType macro, + MmaMacro macro, int M, int N, int K, @@ -94,11 +94,11 @@ sass::Container getSASSFor( // A fusion with epilogue made of binary op (scalar multiplication) sass::Container getBinaryOpMulEpilogueSASSFor( - MatmulLayout layout, + MmaLayout layout, GemmTile cta_tile, GemmTile warp_tile, GemmTile instruction_tile, - MmaOptions::MacroType macro, + MmaMacro macro, int M, int N, int K) { @@ -166,7 +166,7 @@ TEST_F(MatmulSASSTest, AmpereSanity_CUDA) { bool found_LDSM = false; bool found_HMMA = false; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { sass::Container sass; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, @@ -176,7 +176,7 @@ TEST_F(MatmulSASSTest, AmpereSanity_CUDA) { GemmTile(128, 128, 32), GemmTile(64, 64, 32), GemmTile(16, 8, 16), - MmaOptions::MacroType::Ampere_16_8_16, + MmaMacro::Ampere_16_8_16, M, N, K)); @@ -215,8 +215,8 @@ TEST_F(MatmulSASSTest, AmpereModifiers_CUDA) { bool found_HMMA = false; bool found_LDGDEPBAR = false; bool found_BAR = false; - bool found_DEPBAR = false; // kAllSupportedMatmulLayout; - for (auto layout : {MatmulLayout::TT}) { + bool found_DEPBAR = false; // kAllSupportedMmaLayout; + for (auto layout : {MmaLayout::TT}) { sass::Container sass; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, @@ -226,7 +226,7 @@ TEST_F(MatmulSASSTest, AmpereModifiers_CUDA) { GemmTile(128, 128, 32), GemmTile(64, 64, 32), GemmTile(16, 8, 16), - MmaOptions::MacroType::Ampere_16_8_16, + MmaMacro::Ampere_16_8_16, M, N, K)); @@ -341,12 +341,12 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { } // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - for (auto layout : {MatmulLayout::TT}) { + for (auto layout : {MmaLayout::TT}) { bool found_LDGSTS = false; bool found_LDSM = false; bool found_HMMA = false; bool found_LDGDEPBAR = false; - bool found_DEPBAR = false; // kAllSupportedMatmulLayout; + bool found_DEPBAR = false; // kAllSupportedMmaLayout; int BAR_COUNT = 0; // we have at least three shared memory barriers in the kernel if // use_shared_epilogue. If promote_prologue_smem_reuse, then 4 @@ -360,7 +360,7 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue_CUDA) { gemm_tile.cta_tile, gemm_tile.warp_tile, gemm_tile.instruction_tile, - MmaOptions::MacroType::Ampere_16_8_16, + MmaMacro::Ampere_16_8_16, M, N, K, @@ -469,7 +469,7 @@ TEST_F(MatmulSASSTest, AmpereEpilogueBinaryOpMul_CUDA) { bool found_LDGDEPBAR = false; bool found_BAR = false; bool found_DEPBAR = false; - for (auto layout : {MatmulLayout::TT}) { + for (auto layout : {MmaLayout::TT}) { sass::Container sass; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, @@ -479,7 +479,7 @@ TEST_F(MatmulSASSTest, AmpereEpilogueBinaryOpMul_CUDA) { GemmTile(128, 128, 32), GemmTile(64, 64, 32), GemmTile(16, 8, 16), - MmaOptions::MacroType::Ampere_16_8_16, + MmaMacro::Ampere_16_8_16, M, N, K)); @@ -597,7 +597,7 @@ TEST_F(MatmulSASSTest, AmpereRegisterUsageLDSM_CUDA) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { std::unordered_map> base_offsets; sass::Container sass; @@ -609,7 +609,7 @@ TEST_F(MatmulSASSTest, AmpereRegisterUsageLDSM_CUDA) { GemmTile(128, 128, 32), GemmTile(64, 64, 32), GemmTile(16, 8, 16), - MmaOptions::MacroType::Ampere_16_8_16, + MmaMacro::Ampere_16_8_16, M, N, K)); diff --git a/test/test_matmul_scheduler.cpp b/test/test_matmul_scheduler.cpp index c8ec3be821a..f9c8a5a065f 100644 --- a/test/test_matmul_scheduler.cpp +++ b/test/test_matmul_scheduler.cpp @@ -20,14 +20,990 @@ namespace nvfuser { namespace { class MatmulSchedulerTest : public NVFuserTest {}; + +using PrecisionsDesc = std::tuple; + +using AbsoluteError = double; +using RelariveError = double; +using ErrorThresholds = std::pair; +using TestCaseErrorThresholds = std::map; +class PrecisionParametrizedTest + : public NVFuserFixtureParamTest {}; + +[[nodiscard]] auto get_type_letter(const PrimDataType& type) { + switch (type) { + case PrimDataType::Half: + return "H"; + case PrimDataType::Float: + return "S"; + case PrimDataType::BFloat16: + return "T"; + default: + break; + } + NVF_ERROR(false, "Unsupported conversion of PrimDataType"); + return "*"; +} + +static const PrecisionsDesc HSH = std::make_tuple( + PrimDataType::Half, + PrimDataType::Float, + PrimDataType::Half); +static const PrecisionsDesc HSS = std::make_tuple( + PrimDataType::Half, + PrimDataType::Float, + PrimDataType::Float); +static const PrecisionsDesc TST = std::make_tuple( + PrimDataType::BFloat16, + PrimDataType::Float, + PrimDataType::BFloat16); +static const PrecisionsDesc TSS = std::make_tuple( + PrimDataType::BFloat16, + PrimDataType::Float, + PrimDataType::Float); + +// Matmul test that uses segmenter for fusion: +// D = (A x B) + bias +// Target architectures: Turing, Ampere +TEST_P(PrecisionParametrizedTest, EpilogueBias) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.0001, 0.0001)}, + {HSH, std::make_pair(0.001, 0.001)}, + {TSS, std::make_pair(0.0001, 0.0001)}, + {TST, std::make_pair(0.01, 0.01)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto accu_type = DataType(accu_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_accu_type = data_type_to_aten(accu_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1, C - tv2 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + auto tv2 = makeContigTensor(1, out_type); + + // tv3 := A x B + auto tv3 = matmul(tv0, tv1, layout, true); + // tv4 := cast(bias) + auto tv4 = maybeCastOp(accu_type, tv2); + + // tv5 := (A x B) + bias + auto tv5 = biasEpilogue(tv3, tv4); + // tv6 := cast(tv5) + auto tv6 = maybeCastOp(out_type, tv5); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + fusion->addOutput(tv6); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + + auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t4 = t2.to(at_accu_type); + + auto t5 = atBiasEpilogue(t3, t4); + auto t6 = t5.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + // NOTE: increasted absolute tolerance to silence false negative verification + // caused by different way of calculating reference + NVF_CHECK(outputs[0].allclose(t6, abs_err_thr, rel_err_thr)); +} + +// Matmul test that uses segmenter for fusion: +// D = relu(A x B) +// Target architectures: Turing, Ampere +TEST_P(PrecisionParametrizedTest, EpilogueRelu) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.0001, 0.0001)}, + {HSH, std::make_pair(0.001, 0.001)}, + {TSS, std::make_pair(0.0001, 0.0001)}, + {TST, std::make_pair(0.01, 0.01)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + + auto tv2 = matmul(tv0, tv1, layout, true); + auto tv3 = relu(tv2); + auto tv4 = maybeCastOp(out_type, tv3); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv4); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t3 = at::relu(t2); + auto t4 = t3.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + NVF_CHECK(outputs[0].allclose(t4, abs_err_thr, rel_err_thr)); +} + +// Matmul test that uses segmenter for fusion: +// D = relu((A x B) + bias) +// Target architectures: Ampere +TEST_P(PrecisionParametrizedTest, EpilogueBiasRelu) { + // NOTE: test skips Turing arch, the relative error was too big + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.001, 0.001)}, + {HSH, std::make_pair(0.001, 0.001)}, + {TSS, std::make_pair(0.001, 0.001)}, + {TST, std::make_pair(0.01, 0.001)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto accu_type = DataType(accu_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_accu_type = data_type_to_aten(accu_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1, C - tv2 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + auto tv2 = makeContigTensor(1, out_type); + + // tv3 := A x B + auto tv3 = matmul(tv0, tv1, layout, true); + + // tv4 := cast(bias) + auto tv4 = maybeCastOp(accu_type, tv2); + + // tv5 := (A x B) + bias + auto tv5 = biasEpilogue(tv3, tv4); + + // tv6 := relu((A x B) + bias) + auto tv6 = relu(tv5); + auto tv7 = maybeCastOp(out_type, tv6); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + fusion->addOutput(tv7); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + + auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t4 = t2.to(at_accu_type); + auto t5 = atBiasEpilogue(t3, t4); + auto t6 = at::relu(t5); + auto t7 = t6.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + // NOTE: increasted absolute tolerance to silence false negative verification + // caused by different way of calculating reference D tensor results + NVF_CHECK(outputs[0].allclose(t7, abs_err_thr, rel_err_thr)); +} + +// Matmul test that uses segmenter for fusion: +// D = A x B; +// Aux = relu(D) +// Target architectures: Turing, Ampere +TEST_P(PrecisionParametrizedTest, EpilogueReluAux) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.001, 0.001)}, + {HSH, std::make_pair(0.001, 0.001)}, + {TSS, std::make_pair(0.001, 0.001)}, + {TST, std::make_pair(0.01, 0.001)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + + auto tv2 = matmul(tv0, tv1, layout, true); + auto tv3 = maybeCastOp(out_type, tv2); + auto tv4 = relu(tv2); + auto tv5 = maybeCastOp(out_type, tv4); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv3); + fusion->addOutput(tv5); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t3 = t2.to(at_out_type); + auto t4 = at::relu(t2); + auto t5 = t4.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + // D tensor results + NVF_CHECK(outputs[0].allclose(t3, abs_err_thr, rel_err_thr)); + // Aux tensor results + NVF_CHECK(outputs[1].allclose(t5, abs_err_thr, rel_err_thr)); +} + +// Matmul test that uses segmenter for fusion: +// D = (A x B) + bias +// Aux = relu(D) +// Target architectures: Ampere +TEST_P(PrecisionParametrizedTest, EpilogueBiasReluAux) { + // NOTE: test skips Turing arch, the relative error was too big + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.001, 0.001)}, + {HSH, std::make_pair(0.001, 0.001)}, + {TSS, std::make_pair(0.001, 0.001)}, + {TST, std::make_pair(0.01, 0.001)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto accu_type = DataType(accu_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_accu_type = data_type_to_aten(accu_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1, C - tv2 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + auto tv2 = makeContigTensor(1, out_type); + + // tv3 := A x B + auto tv3 = matmul(tv0, tv1, layout, true); + // tv4 := cast(bias) + auto tv4 = maybeCastOp(accu_type, tv2); + + // tv5 := (A x B) + bias + auto tv5 = biasEpilogue(tv3, tv4); + + // tv6 := cast((A x B) + bias) + auto tv6 = maybeCastOp(out_type, tv5); + + // tv7 := relu((A x B) + bias) + auto tv7 = relu(tv5); + auto tv8 = maybeCastOp(out_type, tv7); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + fusion->addOutput(tv6); + fusion->addOutput(tv8); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + + auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t4 = t2.to(at_accu_type); + auto t5 = atBiasEpilogue(t3, t4); + auto t6 = t5.to(at_out_type); + auto t7 = at::relu(t5); + auto t8 = t7.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + // NOTE: increasted absolute tolerance to silence false negative verification + // caused by different way of calculating reference D tensor results + NVF_CHECK(outputs[0].allclose(t6, abs_err_thr, rel_err_thr)); + // Aux tensor results + NVF_CHECK(outputs[1].allclose(t8, abs_err_thr, rel_err_thr)); +} + +// Matmul test that uses segmenter for fusion: +// D = gelu(A x B) +// Target architectures: Turing, Ampere +TEST_P(PrecisionParametrizedTest, EpilogueGelu) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.001, 0.001)}, + {HSH, std::make_pair(0.001, 0.001)}, + {TSS, std::make_pair(0.001, 0.001)}, + {TST, std::make_pair(0.01, 0.001)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + + auto tv2 = matmul(tv0, tv1, layout, true); + auto tv3 = gelu(tv2); + auto tv4 = maybeCastOp(out_type, tv3); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv4); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t3 = at::gelu(t2); + auto t4 = t3.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + NVF_CHECK(outputs[0].allclose(t4, abs_err_thr, rel_err_thr)); +} + +// Matmul test that uses segmenter for fusion: +// D = A x B +// Aux = gelu(D) +// Target architectures: Turing, Ampere +TEST_P(PrecisionParametrizedTest, EpilogueGeluAux) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.001, 0.001)}, + {HSH, std::make_pair(0.001, 0.001)}, + {TSS, std::make_pair(0.001, 0.001)}, + {TST, std::make_pair(0.01, 0.001)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + + auto tv2 = matmul(tv0, tv1, layout, true); + auto tv3 = maybeCastOp(out_type, tv2); + auto tv4 = gelu(tv2); + auto tv5 = maybeCastOp(out_type, tv4); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv3); + fusion->addOutput(tv5); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t3 = t2.to(at_out_type); + auto t4 = at::gelu(t2); + auto t5 = t4.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + // D tensor results + NVF_CHECK(outputs[0].allclose(t3, abs_err_thr, rel_err_thr)); + // Aux tensor results + NVF_CHECK(outputs[1].allclose(t5, abs_err_thr, rel_err_thr)); +} + +// Matmul test that uses segmenter for fusion for Ampere: +// D = gelu((A x B) + bias) +// Target architectures: Turing, Ampere +TEST_P(PrecisionParametrizedTest, EpilogueBiasGelu) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.001, 0.001)}, + {HSH, std::make_pair(0.01, 0.001)}, + {TSS, std::make_pair(0.001, 0.001)}, + {TST, std::make_pair(0.01, 0.01)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto accu_type = DataType(accu_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_accu_type = data_type_to_aten(accu_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1, C - tv2 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + auto tv2 = makeContigTensor(1, out_type); + + // tv3 := A x B + auto tv3 = matmul(tv0, tv1, layout, true); + // tv4 := cast(bias) + auto tv4 = maybeCastOp(accu_type, tv2); + + // tv5 := (A x B) + bias + auto tv5 = biasEpilogue(tv3, tv4); + + // tv6 := gelu((A x B) + bias) + auto tv6 = gelu(tv5); + auto tv7 = maybeCastOp(out_type, tv6); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + fusion->addOutput(tv7); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + + auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t4 = t2.to(at_accu_type); + auto t5 = atBiasEpilogue(t3, t4); + auto t6 = at::gelu(t5); + auto t7 = t6.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + // NOTE: increasted absolute tolerance to silence false negative verification + // caused by different way of calculating reference + NVF_CHECK(outputs[0].allclose(t7, abs_err_thr, rel_err_thr)); +} + +// Matmul test that uses segmenter for fusion: +// D = (A x B) + bias +// Aux = gelu(D) +// Target architectures: Ampere +TEST_P(PrecisionParametrizedTest, EpilogueBiasGeluAux) { + // NOTE: test skips Turing arch, the relative error was too big + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const auto layout = MmaLayout::TT; + + static TestCaseErrorThresholds errs = { + {HSS, std::make_pair(0.001, 0.001)}, + {HSH, std::make_pair(0.01, 0.001)}, + {TSS, std::make_pair(0.001, 0.001)}, + {TST, std::make_pair(0.01, 0.001)}, + }; + + NVF_CHECK( + errs.count(GetParam()) != 0, + "Undefined error thresholds for requested precisions"); + + const auto [in_prim_type, accu_prim_type, out_prim_type] = GetParam(); + const auto [abs_err_thr, rel_err_thr] = errs[GetParam()]; + + const auto in_type = DataType(in_prim_type); + const auto accu_type = DataType(accu_prim_type); + const auto out_type = DataType(out_prim_type); + const auto at_in_type = data_type_to_aten(in_prim_type); + const auto at_accu_type = data_type_to_aten(accu_prim_type); + const auto at_out_type = data_type_to_aten(out_prim_type); + + // NOTE: bfloat16 is not supported on pre-Ampere archs + if (DataType::BFloat16 == in_type || DataType::BFloat16 == out_type) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1, C - tv2 + auto tv0 = makeContigTensor(2, in_type); + auto tv1 = makeContigTensor(2, in_type); + auto tv2 = makeContigTensor(1, out_type); + + // tv3 := A x B + auto tv3 = matmul(tv0, tv1, layout, true); + // tv4 := cast(bias) + auto tv4 = maybeCastOp(accu_type, tv2); + + // tv5 := (A x B) + bias + auto tv5 = biasEpilogue(tv3, tv4); + // tv6 := cast((A x B) + bias) + auto tv6 = maybeCastOp(out_type, tv5); + + // tv7 := gelu((A x B) + bias) + auto tv7 = gelu(tv5); + auto tv8 = maybeCastOp(out_type, tv7); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + fusion->addOutput(tv6); + fusion->addOutput(tv8); + + NVF_CHECK( + 1 == ir_utils::getOpsOfType(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), + "input layout has not be set for MmaOp"); + NVF_CHECK( + MmaLayout::TN == + ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), + "the MmaOp layout of Ampere MMA must always be TN"); + + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + NVF_CHECK( + fusion_layout.isValid(), + "failed to get decide matmul layout through fusion definition"); + NVF_CHECK( + fusion_layout.getData() == layout, + "mismatch between test layout (", + toString(layout), + ") and layout inferred from fusion definition (", + toString(fusion_layout.getData()), + ")"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + + auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + auto t4 = t2.to(at_accu_type); + auto t5 = atBiasEpilogue(t3, t4); + auto t6 = t5.to(at_out_type); + auto t7 = at::gelu(t5); + auto t8 = t7.to(at_out_type); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation did happen"); + + // NOTE: increasted absolute tolerance to silence false negative verification + // caused by different way of calculating reference D tensor results + NVF_CHECK(outputs[0].allclose(t6, abs_err_thr, rel_err_thr)); + // Aux tensor results + NVF_CHECK(outputs[1].allclose(t8, abs_err_thr, rel_err_thr)); +} + } // namespace -// Matmul test that relies on segmenter for 'C = A x B' fusion, +INSTANTIATE_TEST_SUITE_P( + MatmulSchedulerTest, + PrecisionParametrizedTest, + ::testing::Values(HSS, HSH, TSS, TST), + [](const testing::TestParamInfo& info) { + std::ostringstream os; + os << get_type_letter(std::get<0>(info.param)); + os << get_type_letter(std::get<1>(info.param)); + os << get_type_letter(std::get<2>(info.param)); + return os.str(); + }); + +// Matmul test that uses segmenter for 'C = A x B' fusion, // for Ampere with strict ref check, hence single layout check -TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT_CUDA) { +TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 8, 9); const int M = 128, N = 256, K = 512; - const auto layout = MatmulLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -46,11 +1022,11 @@ TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -85,11 +1061,11 @@ TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT_CUDA) { } // Matmul test that reslies on segmenter for 'C = A x B' fusion, for Ampere -TEST_F(MatmulSchedulerTest, BasicMatmulRelaxedCheck_CUDA) { +TEST_F(MatmulSchedulerTest, BasicMatmulRelaxedCheck) { // skip until we have Hopper support NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); const int M = 504, N = 136, K = 2048; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -111,14 +1087,14 @@ TEST_F(MatmulSchedulerTest, BasicMatmulRelaxedCheck_CUDA) { .has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()) .front() ->layout() .value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -155,11 +1131,11 @@ TEST_F(MatmulSchedulerTest, BasicMatmulRelaxedCheck_CUDA) { // Matmul test that reslies on segmenter for 'C = A x B' fusion, for Ampere // MMA first input is passed as second fusion parameter. // MMA second input is passed as first fusion parameter. -TEST_F(MatmulSchedulerTest, BasicMatmulInputShuffledTT_CUDA) { +TEST_F(MatmulSchedulerTest, BasicMatmulInputShuffledTT) { // skip until we have Hopper support NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); const int M = 504, N = 136, K = 2048; - const auto layout = MmaOptions::MmaLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -178,11 +1154,11 @@ TEST_F(MatmulSchedulerTest, BasicMatmulInputShuffledTT_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -215,15 +1191,15 @@ TEST_F(MatmulSchedulerTest, BasicMatmulInputShuffledTT_CUDA) { NVF_CHECK(outputs[0].allclose(tref, 0.001, 0.001)); } -// Matmul test that relies on segmenter for 'C = float2half(A x B)' fusion, for +// Matmul test that uses segmenter for 'C = float2half(A x B)' fusion, for // Ampere -TEST_F(MatmulSchedulerTest, EpilogueOutputCast_CUDA) { +TEST_F(MatmulSchedulerTest, EpilogueOutputCast) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); - const auto layout = MatmulLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // A - tv 0, B - tv1 + // A - tv0, B - tv1 auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); @@ -241,11 +1217,11 @@ TEST_F(MatmulSchedulerTest, EpilogueOutputCast_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -276,15 +1252,15 @@ TEST_F(MatmulSchedulerTest, EpilogueOutputCast_CUDA) { NVF_CHECK(outputs[0].allclose(tref, 0.001, 0.001)); } -// Matmul test that relies on segmenter for 'C = alpha * (A x B)' fusion, for +// Matmul test that uses segmenter for 'C = alpha * (A x B)' fusion, for // Ampere -TEST_F(MatmulSchedulerTest, EpilogueAlpha_CUDA) { +TEST_F(MatmulSchedulerTest, EpilogueAlpha) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); - const auto layout = MatmulLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // alpha - s0, A - tv 0, B - tv1 + // alpha - s0, A - tv0, B - tv1 auto s0 = IrBuilder::create(DataType::Double); auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); @@ -304,11 +1280,11 @@ TEST_F(MatmulSchedulerTest, EpilogueAlpha_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -340,15 +1316,15 @@ TEST_F(MatmulSchedulerTest, EpilogueAlpha_CUDA) { NVF_CHECK(outputs[0].allclose(tref, 0.001, 0.001)); } -// Matmul test that relies on segmenter for 'C = float2half(alpha * (A x B))' +// Matmul test that uses segmenter for 'C = float2half(alpha * (A x B))' // fusion, for Ampere -TEST_F(MatmulSchedulerTest, EpilogueAlphaOutputCast_CUDA) { +TEST_F(MatmulSchedulerTest, EpilogueAlphaOutputCast) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); - const auto layout = MatmulLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // alpha - s0, A - tv 0, B - tv1 + // alpha - s0, A - tv0, B - tv1 auto s0 = IrBuilder::create(DataType::Double); auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); @@ -369,11 +1345,11 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaOutputCast_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -406,140 +1382,18 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaOutputCast_CUDA) { NVF_CHECK(outputs[0].allclose(tref, 0.001, 0.001)); } -// Matmul test that relies on segmenter for 'C = relu(A x B)' fusion, for -// Ampere -TEST_F(MatmulSchedulerTest, EpilogueRelu_CUDA) { - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); - const auto layout = MatmulLayout::TT; - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - // A - tv 0, B - tv1 - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - - auto tv2 = matmul(tv0, tv1, layout, true); - auto tv3 = relu(tv2); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addOutput(tv3); - - NVF_CHECK( - 1 == ir_utils::getOpsOfType(fusion.get()).size(), - "matmul fusion must have at least one MmaOp"); - NVF_CHECK( - ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), - "input layout has not be set for MmaOp"); - NVF_CHECK( - MatmulLayout::TN == - ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), - "the MmaOp layout of Ampere MMA must always be TN"); - - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); - NVF_CHECK( - fusion_layout.isValid(), - "failed to get decide matmul layout through fusion definition"); - NVF_CHECK( - fusion_layout.getData() == layout, - "mismatch between test layout (", - toString(layout), - ") and layout inferred from fusion definition (", - toString(fusion_layout.getData()), - ")"); - - FusionExecutorCache executor_cache(std::move(fusion)); - - const int M = 504, N = 136, K = 1024; - - at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); - auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); - auto tref = at::relu(t2).to(at::kFloat); - - auto outputs = executor_cache.runFusionWithInputs({t0, t1}); - - NVF_CHECK( - !executor_cache.getMostRecentKernelRuntime()->isSegmented(), - "segmentation did happen"); - - NVF_CHECK(outputs[0].allclose(tref, 0.001, 0.001)); -} - -// Matmul test that relies on segmenter for 'C = gelu(A x B)' fusion, for -// Ampere -TEST_F(MatmulSchedulerTest, EpilogueGelu_CUDA) { - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); - const auto layout = MatmulLayout::TT; - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - // A - tv 0, B - tv1 - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - - auto tv2 = matmul(tv0, tv1, layout, true); - auto tv3 = gelu(tv2); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addOutput(tv3); - - NVF_CHECK( - 1 == ir_utils::getOpsOfType(fusion.get()).size(), - "matmul fusion must have at least one MmaOp"); - NVF_CHECK( - ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), - "input layout has not be set for MmaOp"); - NVF_CHECK( - MatmulLayout::TN == - ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), - "the MmaOp layout of Ampere MMA must always be TN"); - - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); - NVF_CHECK( - fusion_layout.isValid(), - "failed to get decide matmul layout through fusion definition"); - NVF_CHECK( - fusion_layout.getData() == layout, - "mismatch between test layout (", - toString(layout), - ") and layout inferred from fusion definition (", - toString(fusion_layout.getData()), - ")"); - - FusionExecutorCache executor_cache(std::move(fusion)); - - const int M = 504, N = 136, K = 1024; - - at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); - auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); - auto tref = at::gelu(t2).to(at::kFloat); - - auto outputs = executor_cache.runFusionWithInputs({t0, t1}); - - NVF_CHECK( - !executor_cache.getMostRecentKernelRuntime()->isSegmented(), - "segmentation did happen"); - - NVF_CHECK(outputs[0].allclose(tref, 0.001, 0.001)); -} - -// Matmul test that relies on segmenter for fusion for Ampere: +// Matmul test that uses segmenter for fusion for Ampere: // D = (A x B) + beta * C -TEST_F(MatmulSchedulerTest, EpilogueBeta_CUDA) { +TEST_F(MatmulSchedulerTest, EpilogueBeta) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); - const auto layout = MatmulLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); // beta - s0 auto s0 = IrBuilder::create(DataType::Double); - // A - tv 0, B - tv1, C - tv2 + // A - tv0, B - tv1, C - tv2 auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); auto tv2 = makeContigTensor(2, DataType::Half); @@ -565,11 +1419,11 @@ TEST_F(MatmulSchedulerTest, EpilogueBeta_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -607,11 +1461,11 @@ TEST_F(MatmulSchedulerTest, EpilogueBeta_CUDA) { NVF_CHECK(outputs[0].allclose(t5, 0.01, 0.04)); } -// Matmul test that relies on segmenter for fusion for Ampere: +// Matmul test that uses segmenter for fusion for Ampere: // D = alpha * (A x B) + beta * C -TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta_CUDA) { +TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); - const auto layout = MatmulLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -619,7 +1473,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta_CUDA) { auto s0 = IrBuilder::create(DataType::Double); auto s1 = IrBuilder::create(DataType::Double); - // A - tv 0, B - tv1, C - tv2 + // A - tv0, B - tv1, C - tv2 auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); auto tv2 = makeContigTensor(2, DataType::Half); @@ -647,11 +1501,11 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -691,11 +1545,11 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta_CUDA) { NVF_CHECK(outputs[0].allclose(t6, 0.001, 0.004)); } -// Matmul test that relies on segmenter for fusion for Ampere: +// Matmul test that uses segmenter for fusion for Ampere: // D = gelu(alpha * (A x B) + beta * C) -TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast_CUDA) { +TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); - const auto layout = MatmulLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -703,7 +1557,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast_CUDA) { auto s0 = IrBuilder::create(DataType::Double); auto s1 = IrBuilder::create(DataType::Double); - // A - tv 0, B - tv1, C - tv2 + // A - tv0, B - tv1, C - tv2 auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); auto tv2 = makeContigTensor(2, DataType::Half); @@ -735,11 +1589,11 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -782,84 +1636,12 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast_CUDA) { NVF_CHECK(outputs[0].allclose(t8, 0.01, 0.06)); } -// Matmul test that relies on segmenter for fusion for Ampere: -// D = (A x B) + bias -TEST_F(MatmulSchedulerTest, EpilogueBias_CUDA) { - // NOTE: test skips Turing arch, the relative error was too big - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); - const auto layout = MatmulLayout::TT; - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - // A - tv0, B - tv1, C - tv2 - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - auto tv2 = makeContigTensor(1, DataType::Float); - - // tv3 := A x B - auto tv3 = matmul(tv0, tv1, layout, true); - - // tv4 := (A x B) + bias - auto tv4 = biasEpilogue(tv3, tv2); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addInput(tv2); - fusion->addOutput(tv4); - - NVF_CHECK( - 1 == ir_utils::getOpsOfType(fusion.get()).size(), - "matmul fusion must have at least one MmaOp"); - NVF_CHECK( - ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), - "input layout has not be set for MmaOp"); - NVF_CHECK( - MatmulLayout::TN == - ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), - "the MmaOp layout of Ampere MMA must always be TN"); - - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); - NVF_CHECK( - fusion_layout.isValid(), - "failed to get decide matmul layout through fusion definition"); - NVF_CHECK( - fusion_layout.getData() == layout, - "mismatch between test layout (", - toString(layout), - ") and layout inferred from fusion definition (", - toString(fusion_layout.getData()), - ")"); - - FusionExecutorCache executor_cache(std::move(fusion)); - - const int M = 504, N = 136, K = 1024; - - at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at::kFloat, M, N, K); - - auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); - - auto t4 = atBiasEpilogue(t3, t2).to(at::kFloat); - - auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); - - NVF_CHECK( - !executor_cache.getMostRecentKernelRuntime()->isSegmented(), - "segmentation did happen"); - - // NOTE: increasted absolute tolerance to silence false negative verification - // caused by different way of calculating reference - NVF_CHECK(outputs[0].allclose(t4, 0.001, 0.001)); -} - -// Matmul test that relies on segmenter for fusion for Ampere: +// Matmul test that uses segmenter for fusion for Ampere: // D = alpha * ((A x B) + bias) + beta * C -TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaBias_CUDA) { +TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaBias) { // NOTE: test skips Turing arch, the relative error was too big NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); - const auto layout = MatmulLayout::TT; + const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -899,11 +1681,11 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaBias_CUDA) { ir_utils::getOpsOfType(fusion.get()).front()->layout().has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -951,10 +1733,10 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaBias_CUDA) { // Strided batch gemm test taht uses matmul scheduler, for Ampere: // D = (A x B) -TEST_F(MatmulSchedulerTest, StridedBatch_CUDA) { +TEST_F(MatmulSchedulerTest, StridedBatch) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); const int M = 504, N = 136, K = 248, B = 2; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -979,14 +1761,14 @@ TEST_F(MatmulSchedulerTest, StridedBatch_CUDA) { .has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()) .front() ->layout() .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1021,11 +1803,11 @@ TEST_F(MatmulSchedulerTest, StridedBatch_CUDA) { // Strided batch gemm test with alpha and beta that uses matmul scheduler, // for Ampere architecture: // D = alpha * (A x B) + beta * C -TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaBeta_CUDA) { +TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaBeta) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); const int M = 504, N = 136, K = 248, B = 2; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1063,14 +1845,14 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaBeta_CUDA) { .has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()) .front() ->layout() .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1114,11 +1896,11 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaBeta_CUDA) { // scheduler, // there is only single C tensor for whole batch; test for Ampere architecture: // D = alpha * (A x B) + beta * C -TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaSingleBeta_CUDA) { +TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaSingleBeta) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); const int M = 504, N = 136, K = 248, B = 2; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1159,14 +1941,14 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaSingleBeta_CUDA) { .has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()) .front() ->layout() .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1211,11 +1993,11 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaSingleBeta_CUDA) { // Strided batch gemm test with bias that uses matmul scheduler, for Ampere: // D = (A x B) + bias -TEST_F(MatmulSchedulerTest, StridedBatchEpilogueBias_CUDA) { +TEST_F(MatmulSchedulerTest, StridedBatchEpilogueBias) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); const int M = 504, N = 136, K = 248, B = 2; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1244,14 +2026,14 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueBias_CUDA) { .has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()) .front() ->layout() .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1289,11 +2071,11 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueBias_CUDA) { // Strided batch gemm test with single bias vector that uses matmul // scheduler, for Ampere: // D = (A x B) + bias -TEST_F(MatmulSchedulerTest, StridedBatchEpilogueSingleBias_CUDA) { +TEST_F(MatmulSchedulerTest, StridedBatchEpilogueSingleBias) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); const int M = 504, N = 136, K = 248, B = 2; - for (auto layout : kAllSupportedMatmulLayout) { + for (auto layout : kAllSupportedMmaLayout) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1322,14 +2104,14 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueSingleBias_CUDA) { .has_value(), "input layout has not be set for MmaOp"); NVF_CHECK( - MatmulLayout::TN == + MmaLayout::TN == ir_utils::getOpsOfType(fusion.get()) .front() ->layout() .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMatmulLayout(fusion.get()); + const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); diff --git a/test/test_memory.cpp b/test/test_memory.cpp index fcc00e406f9..1c765a34808 100644 --- a/test/test_memory.cpp +++ b/test/test_memory.cpp @@ -26,9 +26,9 @@ namespace nvfuser { -class MemoryTest - : public NVFuserTest, - public testing::WithParamInterface> { +using MemoryTestParams = std::tuple; + +class MemoryTest : public NVFuserFixtureParamTest { protected: void expectMatchCount( const std::string& text, @@ -107,7 +107,7 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(CacheOp::AllLevels, "ca"), std::make_tuple(CacheOp::Global, "cg"), std::make_tuple(CacheOp::Streaming, "cs")), - [](const testing::TestParamInfo>& info) { + [](const testing::TestParamInfo& info) { std::ostringstream os; os << std::get<0>(info.param); return os.str(); @@ -472,4 +472,115 @@ TEST_F(TMATest, DisableIndexHoisting) { testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); } +using LdMatrixTestParam = std::tuple; + +class LdMatrixTest : public NVFuserFixtureParamTest { + protected: + void SetUp() override { + // requires Turing or newer + if (cudaArchGuardShouldSkip(7, 5)) { + GTEST_SKIP() << "skipping tests on pre-Turing GPUs"; + } + NVFuserTest::SetUp(); + } +}; + +TEST_P(LdMatrixTest, Regular) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto macro = std::get<0>(GetParam()); + auto operand = std::get<1>(GetParam()); + + bool is_a = operand == MmaOperand::A; + + int size1 = (is_a ? getM(macro) : getN(macro)); + + auto tv0 = makeConcreteTensor({size1, getK(macro)}, DataType::Half); + fusion.addInput(tv0); + auto tv1 = set(tv0); + tv1->setMemoryType(MemoryType::Shared); + auto tv2 = set(tv1); + tv2->definition()->as()->setOpType(LoadStoreOpType::LdMatrix); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv2->applyMmaSwizzle(operand); + tv3->applyMmaSwizzle(operand); + + tv3->merge(0); + if (is_a) { + tv3->merge(0); + } + tv3->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({size1, getK(macro)}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}, LaunchParams(), matmul_cparams); + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); +} + +TEST_P(LdMatrixTest, Transpose) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto macro = std::get<0>(GetParam()); + auto operand = std::get<1>(GetParam()); + + bool is_a = operand == MmaOperand::A; + + int size2 = (is_a ? getM(macro) : getN(macro)); + + auto tv0 = makeConcreteTensor({getK(macro), size2}, DataType::Half); + fusion.addInput(tv0); + auto tv1 = set(tv0); + tv1->setMemoryType(MemoryType::Shared); + auto tv2 = transpose(tv1, 0, 1); + tv2->definition()->as()->setOpType( + LoadStoreOpType::LdMatrixTranspose); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv2->applyMmaSwizzle(operand); + tv3->applyMmaSwizzle(operand); + + tv3->merge(0); + if (is_a) { + tv3->merge(0); + } + tv3->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({getK(macro), size2}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}, LaunchParams(), matmul_cparams); + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + CopyUsingLdMatrix, + LdMatrixTest, + testing::Values( + std::make_tuple(MmaMacro::Turing_16_8_8, MmaOperand::A), + std::make_tuple(MmaMacro::Turing_16_8_16, MmaOperand::A), + std::make_tuple(MmaMacro::Turing_16_8_8, MmaOperand::B), + std::make_tuple(MmaMacro::Turing_16_8_16, MmaOperand::B), + std::make_tuple(MmaMacro::Turing_16_16_16, MmaOperand::B), + std::make_tuple(MmaMacro::Hopper_64_8_16, MmaOperand::A)), + [](const testing::TestParamInfo& info) { + std::ostringstream os; + auto macro = std::get<0>(info.param); + bool is_a = std::get<1>(info.param) == MmaOperand::A; + os << (is_a ? "A" : "B") << "_" << (is_a ? getM(macro) : getN(macro)) + << "x" << getK(macro); + return os.str(); + }); + } // namespace nvfuser diff --git a/test/test_mma.cpp b/test/test_mma.cpp new file mode 100644 index 00000000000..ea646a559b0 --- /dev/null +++ b/test/test_mma.cpp @@ -0,0 +1,173 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace nvfuser { + +using MmaTestParams = std::tuple; + +class MmaTest : public NVFuserFixtureParamTest { + protected: + MmaLayout layout; + MmaMacro macro; + PrimDataType dtype; + + void SetUp() override { + macro = std::get<0>(GetParam()); + dtype = std::get<1>(GetParam()); + layout = std::get<2>(GetParam()); + + if (isTuring(macro) && cudaArchGuardShouldSkip(7, 5)) { + GTEST_SKIP() << "skipping tests on pre-Turing GPUs"; + } + + if (isAmpere(macro) && cudaArchGuardShouldSkip(8, 0)) { + GTEST_SKIP() << "skipping tests on pre-Ampere GPUs"; + } + + NVFuserTest::SetUp(); + } +}; + +TEST_P(MmaTest, SingleTile) { + Fusion fusion; + FusionGuard fg(&fusion); + + bool transpose_a = (layout == MmaLayout::NT || layout == MmaLayout::NN); + bool transpose_b = (layout == MmaLayout::TT || layout == MmaLayout::NT); + + std::vector A_shape{getM(macro), getK(macro)}, + B_shape{getN(macro), getK(macro)}; + + if (transpose_a) { + std::swap(A_shape[0], A_shape[1]); + } + + if (transpose_b) { + std::swap(B_shape[0], B_shape[1]); + } + + auto tv0 = makeConcreteTensor(A_shape, dtype); + auto tv1 = makeConcreteTensor(B_shape, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M, K] + if (transpose_a) { + tv0 = transpose(tv0, 0, 1); + } + + // [N, K] + if (transpose_b) { + tv1 = transpose(tv1, 0, 1); + } + + // [M, N, K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + auto mma_ops = ir_utils::getOpsOfType(&fusion); + NVF_CHECK( + 1 == mma_ops.size(), + "Invalid number of MmaOp instances in fusion definition, expected 1, got ", + mma_ops.size()); + mma_ops.front()->setMacro(macro); + + auto tv2c = tv2->cacheBefore(); + + // [M, N, K] -> [N, M, K] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(MmaOperand::A); + tv1b->applyMmaSwizzle(MmaOperand::B); + + tv0b->merge(1); + tv0b->merge(1); + tv0b->axis(1)->parallelize(ParallelType::TIDx); + tv1b->merge(1); + tv1b->axis(1)->parallelize(ParallelType::TIDx); + + tv2c->applyMmaSwizzle(MmaOperand::Accumulator); + tv2->applyMmaSwizzle(MmaOperand::Accumulator); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t0 = at::randn(A_shape, options); + auto t1 = at::randn(B_shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams); + + auto cg_outputs = fe.runFusion({t0, t1}); + + at::Tensor t0t = t0, t1t = t1; + + if (transpose_a) { + t0t = t0.t(); + } + + if (!transpose_b) { + t1t = t1.t(); + } + + auto tref = t0t.to(at::kFloat).matmul(t1t.to(at::kFloat)); + + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +auto all_mma_layouts = + testing::Values(MmaLayout::TT, MmaLayout::TN, MmaLayout::NT, MmaLayout::NN); + +auto all_dtypes = testing::Values(DataType::Half, DataType::BFloat16); + +std::string testName(const testing::TestParamInfo& info) { + std::ostringstream os; + auto macro = std::get<0>(info.param); + auto dtype = std::get<1>(info.param); + auto layout = std::get<2>(info.param); + os << getM(macro) << "_" << getN(macro) << "_" << getK(macro) << "_" + << toString(layout) << dtype; + return os.str(); +} + +INSTANTIATE_TEST_SUITE_P( + Turing, + MmaTest, + testing::Combine( + testing::Values( + MmaMacro::Turing_16_8_8, + MmaMacro::Turing_16_8_16, + MmaMacro::Turing_16_16_16), + testing::Values(DataType::Half), + all_mma_layouts), + testName); + +INSTANTIATE_TEST_SUITE_P( + Ampere, + MmaTest, + testing::Combine( + testing::Values(MmaMacro::Ampere_16_8_16, MmaMacro::Ampere_16_16_16), + all_dtypes, + all_mma_layouts), + testName); + +} // namespace nvfuser diff --git a/test/test_multidevice_communications.cpp b/test/test_multidevice_communications.cpp index a1ce37c7dc2..4bdc63dcc3f 100644 --- a/test/test_multidevice_communications.cpp +++ b/test/test_multidevice_communications.cpp @@ -16,7 +16,7 @@ namespace nvfuser { -TEST_F(CommunicationTest, Communication_Gather) { +TEST_P(CommunicationTest, Communication_Gather) { params.root = root; params.team = all_ranks; params.src_bufs = {at::empty(tensor_size, tensor_options)}; @@ -33,7 +33,7 @@ TEST_F(CommunicationTest, Communication_Gather) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = communication.post(*communicator); + auto work = communication.post(*communicator, GetParam()); work->wait(); if (communicator->deviceId() == root) { @@ -46,7 +46,7 @@ TEST_F(CommunicationTest, Communication_Gather) { } } -TEST_F(CommunicationTest, Communication_Allgather) { +TEST_P(CommunicationTest, Communication_Allgather) { params.team = all_ranks; params.src_bufs = { at::empty(tensor_size, tensor_options) * communicator->deviceId()}; @@ -61,7 +61,7 @@ TEST_F(CommunicationTest, Communication_Allgather) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = communication.post(*communicator); + auto work = communication.post(*communicator, GetParam()); work->wait(); for (int i : c10::irange(communicator->size())) { @@ -72,7 +72,7 @@ TEST_F(CommunicationTest, Communication_Allgather) { } } -TEST_F(CommunicationTest, Communication_Scatter) { +TEST_P(CommunicationTest, Communication_Scatter) { params.root = root; params.team = all_ranks; if (communicator->deviceId() == root) { @@ -91,7 +91,7 @@ TEST_F(CommunicationTest, Communication_Scatter) { at::arange(tensor_size, tensor_options) + (i + 1) * j); } - auto work = communication.post(*communicator); + auto work = communication.post(*communicator, GetParam()); work->wait(); auto obtained = params.dst_bufs.at(0); @@ -101,7 +101,7 @@ TEST_F(CommunicationTest, Communication_Scatter) { } } -TEST_F(CommunicationTest, Communication_Broadcast) { +TEST_P(CommunicationTest, Communication_Broadcast) { params.root = root; params.team = all_ranks; if (communicator->deviceId() == root) { @@ -117,7 +117,7 @@ TEST_F(CommunicationTest, Communication_Broadcast) { params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); } - auto work = communication.post(*communicator); + auto work = communication.post(*communicator, GetParam()); if (communicator->size() > 1) { work->wait(); } @@ -128,7 +128,7 @@ TEST_F(CommunicationTest, Communication_Broadcast) { } } -TEST_F(CommunicationTest, Communication_SendRecv) { +TEST_P(CommunicationTest, Communication_SendRecv) { DeviceIdxType sender = 0; DeviceIdxType receiver = 1; if (communicator->deviceId() > 1) { // only devices 0 and 1 participate @@ -150,7 +150,7 @@ TEST_F(CommunicationTest, Communication_SendRecv) { params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); } - auto work = communication.post(*communicator); + auto work = communication.post(*communicator, GetParam()); work->wait(); if (communicator->deviceId() == receiver) { @@ -161,7 +161,7 @@ TEST_F(CommunicationTest, Communication_SendRecv) { } } -TEST_F(CommunicationTest, Communication_SendRecvToSelf) { +TEST_P(CommunicationTest, Communication_SendRecvToSelf) { DeviceIdxType sender = 0; if (communicator->deviceId() > 0) { // only device 0 participates return; @@ -177,7 +177,7 @@ TEST_F(CommunicationTest, Communication_SendRecvToSelf) { resetDstBuffers(); params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); - communication.post(*communicator); + communication.post(*communicator, GetParam()); auto obtained = params.dst_bufs.at(0); auto ref = at::arange(tensor_size, tensor_options) + j; @@ -185,6 +185,95 @@ TEST_F(CommunicationTest, Communication_SendRecvToSelf) { } } +TEST_P(CommunicationTest, Communication_Reduce) { + params.redOp = red_op; + params.root = root; + params.team = all_ranks; + params.src_bufs = {at::empty(tensor_size, tensor_options)}; + if (communicator->deviceId() == root) { + params.dst_bufs = {at::empty(tensor_size, tensor_options)}; + } + auto communication = Reduce(params); + + for (int j : c10::irange(number_of_repetitions)) { + resetDstBuffers(); + params.src_bufs.at(0).copy_( + at::arange(tensor_size, tensor_options) + + (communicator->deviceId() + 1) * j); + + auto work = communication.post(*communicator, GetParam()); + work->wait(); + + if (communicator->deviceId() == root) { + auto obtained = params.dst_bufs.at(0); + int S = communicator->size(); + auto ref = + at::arange(tensor_size, tensor_options) * S + S * (S + 1) / 2 * j; + validate(obtained, ref); + } + } +} + +TEST_P(CommunicationTest, Communication_Allreduce) { + params.redOp = red_op; + params.team = all_ranks; + params.src_bufs = {at::empty(tensor_size, tensor_options)}; + params.dst_bufs = {at::empty(tensor_size, tensor_options)}; + auto communication = Allreduce(params); + + for (int j : c10::irange(number_of_repetitions)) { + resetDstBuffers(); + params.src_bufs.at(0).copy_( + at::arange(tensor_size, tensor_options) + + (communicator->deviceId() + 1) * j); + + auto work = communication.post(*communicator, GetParam()); + work->wait(); + + auto obtained = params.dst_bufs.at(0); + int S = communicator->size(); + auto ref = + at::arange(tensor_size, tensor_options) * S + S * (S + 1) / 2 * j; + validate(obtained, ref); + } +} + +TEST_P(CommunicationTest, Communication_ReduceScatter) { + params.redOp = red_op; + params.root = root; + params.team = all_ranks; + for (int64_t i = 0; i < communicator->size(); i++) { + params.src_bufs.push_back(at::empty(tensor_size, tensor_options)); + } + params.dst_bufs = {at::empty(tensor_size, tensor_options)}; + auto communication = ReduceScatter(params); + + for (int j : c10::irange(number_of_repetitions)) { + resetDstBuffers(); + for (int i : c10::irange(communicator->size())) { + params.src_bufs.at(i).copy_( + at::arange(tensor_size, tensor_options) + + (communicator->deviceId() + 1) * (i + j)); + } + + auto work = communication.post(*communicator, GetParam()); + work->wait(); + + auto obtained = params.dst_bufs.at(0); + int S = communicator->size(); + auto ref = at::arange(tensor_size, tensor_options) * S + + S * (S + 1) / 2 * (communicator->deviceId() + j); + validate(obtained, ref); + } +} + +INSTANTIATE_TEST_SUITE_P( + CommunicatorBackend, + CommunicationTest, + ::testing::Values(CommunicatorBackend::nccl, CommunicatorBackend::ucc) + +); + } // namespace nvfuser #endif diff --git a/test/test_multidevice_pipeline.cpp b/test/test_multidevice_pipeline.cpp index 4c20cc14030..9089a599ad0 100644 --- a/test/test_multidevice_pipeline.cpp +++ b/test/test_multidevice_pipeline.cpp @@ -136,6 +136,9 @@ TEST_F(PipelineTest, Pipeline) { validate(); } +auto all_backends = + ::testing::Values(CommunicatorBackend::nccl, CommunicatorBackend::ucc); + DeviceMesh mesh0({0}); DeviceMesh mesh1({1}); DeviceMesh mesh2({0, 1, 2, 3}); @@ -149,6 +152,7 @@ INSTANTIATE_TEST_SUITE_P( Gather, PipelineTestTwoStages, ::testing::Combine( + all_backends, all_meshes, all_meshes, ::testing::Values(true), @@ -158,6 +162,7 @@ INSTANTIATE_TEST_SUITE_P( Scatter, PipelineTestTwoStages, ::testing::Combine( + all_backends, all_meshes, all_meshes, ::testing::Values(false), @@ -167,6 +172,7 @@ INSTANTIATE_TEST_SUITE_P( Bcast, PipelineTestTwoStages, ::testing::Combine( + all_backends, all_meshes, all_meshes, ::testing::Values(false), @@ -176,6 +182,7 @@ INSTANTIATE_TEST_SUITE_P( Bcast_sharded, PipelineTestTwoStages, ::testing::Combine( + all_backends, ::testing::Values(mesh3), ::testing::Values(mesh4), ::testing::Values(true), diff --git a/test/test_no_op.cpp b/test/test_no_op.cpp index 5c8826aa119..ca99f0b7389 100644 --- a/test/test_no_op.cpp +++ b/test/test_no_op.cpp @@ -44,8 +44,7 @@ TEST_F(NoOpTest, FusionNullScheduler) { std::cerr << cg_outputs[0].sizes() << std::endl; std::cerr << t1.sizes() << std::endl; - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__); auto groups = executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(); @@ -76,10 +75,7 @@ TEST_F(NoOpTest, FusionNullScheduler2) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t1 = t0.sum({1, 2}); - - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__); auto groups = executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(); @@ -112,12 +108,7 @@ TEST_F(NoOpTest, FusionNullScheduler3) { auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); testValidate( - executor_cache.fusion(), - cg_outputs, - {t0, t1}, - {t0 + t1}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); auto groups = executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(); @@ -147,10 +138,7 @@ TEST_F(NoOpTest, FusionReducingZeroElements) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t1 = t0.sum({0, 1, 2}); - - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__); } TEST_F(NoOpTest, FusionEmpty) { @@ -174,12 +162,7 @@ TEST_F(NoOpTest, FusionEmpty) { auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); testValidate( - executor_cache.fusion(), - cg_outputs, - {t0, t1}, - {t0, t1}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); auto groups = executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(); diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index f56740391f2..e81e676f911 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -428,13 +428,7 @@ TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { runtime.compileFusionParallel(args); auto outputs = runtime.runWithInputs(args); - testValidate( - preseg_fusion, - outputs, - aten_inputs, - {at::sum(at0, {0})}, - __LINE__, - __FILE__); + testValidate(preseg_fusion, outputs, aten_inputs, __LINE__, __FILE__); } // In this test, a reduction over a non-empty axis occurs first, followed by a @@ -467,13 +461,7 @@ TEST_F(NVFuserTest, FusionRemoveEmptyReductionWithNonReduction_CUDA) { runtime.compileFusionParallel(args); auto outputs = runtime.runWithInputs(args); - testValidate( - preseg_fusion, - outputs, - aten_inputs, - {at::sum(at::sum(at0, 1), 0)}, - __LINE__, - __FILE__); + testValidate(preseg_fusion, outputs, aten_inputs, __LINE__, __FILE__); } // Test that we replace empty Welford with full @@ -565,13 +553,7 @@ TEST_F(NVFuserTest, FusionRemoveEmptyCat_CUDA) { runtime.compileFusionParallel(args); auto outputs = runtime.runWithInputs(args); - testValidate( - preseg_fusion, - outputs, - aten_inputs, - {at::cat({at1, at2}, 0), at1}, - __LINE__, - __FILE__); + testValidate(preseg_fusion, outputs, aten_inputs, __LINE__, __FILE__); } // Test that we replace empty tensors in pad properly @@ -609,13 +591,7 @@ TEST_F(NVFuserTest, FusionRemoveEmptyPad_CUDA) { runtime.compileFusionParallel(args); auto outputs = runtime.runWithInputs(args); - testValidate( - preseg_fusion, - outputs, - aten_inputs, - {at::pad(at0, {1, 1}, "constant", 3.14)}, - __LINE__, - __FILE__); + testValidate(preseg_fusion, outputs, aten_inputs, __LINE__, __FILE__); } // Test that we replace empty tensors in matmuls properly @@ -660,13 +636,7 @@ TEST_F(NVFuserTest, FusionRemoveEmptyMatmul_CUDA) { runtime.compileFusionParallel(args); auto outputs = runtime.runWithInputs(args); - testValidate( - preseg_fusion, - outputs, - aten_inputs, - {at::zeros({16, 8}, options)}, - __LINE__, - __FILE__); + testValidate(preseg_fusion, outputs, aten_inputs, __LINE__, __FILE__); } } // namespace nvfuser::optimization diff --git a/test/test_pointwise.cpp b/test/test_pointwise.cpp new file mode 100644 index 00000000000..7b1b10400dc --- /dev/null +++ b/test/test_pointwise.cpp @@ -0,0 +1,179 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace nvfuser { + +using PointwiseTest = NVFuserTest; + +namespace { + +size_t getVecSizeForPointwise(FusionExecutorCache& fec) { + auto most_recent_params = + fec.getMostRecentKernelRuntime()->getMostRecentExecutorLog().params; + const auto* params = dynamic_cast(most_recent_params.get()); + NVF_ERROR( + params != nullptr, + "`fec`'s contained fusion didn't trigger the pointwise scheduler."); + if (params->vectorize) { + return params->unroll_factor; + } + return 1; +} + +} // namespace + +TEST_F(PointwiseTest, VectorizeStrideContiguity2D) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = + TensorViewBuilder().ndims(2).contiguity({false, true}).build(); + fusion->addInput(tv0); + auto tv1 = add(tv0, tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + std::vector> size_and_vec{{17, 1}, {18, 2}, {32, 4}}; + + for (auto pair : size_and_vec) { + auto size = pair.first; + auto vec = pair.second; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1000000, size}, options).narrow(1, 0, 16); + auto cg_outputs = fec.runFusionWithInputs({t0}); + + EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + + testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); + } +} + +TEST_F(PointwiseTest, VectorizeStrideContiguity3D) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = + TensorViewBuilder().ndims(3).contiguity({false, true, true}).build(); + fusion->addInput(tv0); + auto tv1 = add(tv0, tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + std::vector> size_and_vec{{17, 1}, {10, 2}, {16, 4}}; + + for (auto pair : size_and_vec) { + auto size = pair.first; + auto vec = pair.second; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1000000, size, 3}, options).narrow(1, 0, 8); + auto cg_outputs = fec.runFusionWithInputs({t0}); + + EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + + testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); + } +} + +TEST_F(PointwiseTest, VectorizeStrideContiguity5D) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(5) + .contiguity({false, true, false, true, true}) + .build(); + fusion->addInput(tv0); + auto tv1 = add(tv0, tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + std::vector> sizes_and_vec{ + {9, 17, 1}, {9, 10, 2}, {9, 16, 4}}; + + for (auto tup : sizes_and_vec) { + auto size1 = std::get<0>(tup); + auto size2 = std::get<1>(tup); + auto vec = std::get<2>(tup); + at::Tensor t0 = at::randn({4, size1, 12345, size2, 3}, options) + .narrow(1, 0, 8) + .narrow(3, 0, 4); + auto cg_outputs = fec.runFusionWithInputs({t0}); + + EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + + testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); + } +} + +TEST_F(PointwiseTest, VectorizeStrideContiguitySelfOverlapping) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(5) + .contiguity({false, true, false, true, true}) + .build(); + fusion->addInput(tv0); + auto tv1 = add(tv0, tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + std::vector> sizes_strides_and_vec{ + {4, 4, 4, 4}, + {4, 4, 2, 2}, + {4, 2, 4, 2}, + {2, 4, 4, 2}, + {4, 4, 1, 1}, + {4, 1, 4, 1}, + {1, 4, 4, 1}, + {2, 2, 2, 2}, + {2, 2, 1, 1}, + {2, 1, 2, 1}, + {1, 2, 2, 1}}; + + for (auto tup : sizes_strides_and_vec) { + auto size = std::get<0>(tup); + auto stride1 = std::get<1>(tup); + auto stride2 = std::get<2>(tup); + auto vec = std::get<3>(tup); + std::vector shape = {4, 4, 12345, size, 3}; + std::vector stride = { + stride1, (int64_t)stride2 * 12345, (int64_t)stride2, 3, 1}; + at::Tensor t0 = at::empty_strided(shape, stride, options); + t0.random_(); + auto cg_outputs = fec.runFusionWithInputs({t0}); + EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); + } +} + +} // namespace nvfuser diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 36213e760ec..ebbd68eb10e 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -454,17 +454,8 @@ TEST_F(ResizeTest, FusionResizePadScheduler4) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t0_double = t0.to(at::kDouble); - auto t2 = at::pad(t0_double, {1, 1}).sum({0}); - auto t4 = at::pad(t0_double, {1, 1}).sum({1}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {t2, t4}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } // Pad a broadcast @@ -1324,15 +1315,8 @@ TEST_F(ResizeTest, FusionResizePadReduceScheduler1) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto ref = at::pad(t0, pad_extents).sum({1}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {ref}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(ResizeTest, FusionResizeSliceReduceScheduler1) { @@ -1371,18 +1355,8 @@ TEST_F(ResizeTest, FusionResizeSliceReduceScheduler1) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t1 = t0.index( - {at::indexing::Slice(slice_inputs[0], slice_inputs[1]), - at::indexing::Slice(slice_inputs[2], slice_inputs[3])}); - auto ref = t1.sum({1}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {ref}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } // Multiple slice+reduction. Different slices. @@ -1425,22 +1399,8 @@ TEST_F(ResizeTest, FusionResizeSliceReduceScheduler2) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t1 = t0.index( - {at::indexing::Slice(0, at::indexing::None), - at::indexing::Slice(slice_inputs[0], slice_inputs[1])}); - auto t2 = t1.sum({1}); - auto t3 = t0.index( - {at::indexing::Slice(0, at::indexing::None), - at::indexing::Slice(slice_inputs[2], slice_inputs[3])}); - auto t4 = t3.sum({1}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {t2, t4}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } // Multiple slice+reduction. Same slices. Should be segmented at the moment. @@ -1479,22 +1439,8 @@ TEST_F(ResizeTest, FusionSliceReduceScheduler3) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t1 = t0.index( - {at::indexing::Slice(0, at::indexing::None), - at::indexing::Slice(slice_inputs[0], slice_inputs[1])}); - auto t2 = t1.to(at::kDouble).sum({1}); - auto t3 = t0.index( - {at::indexing::Slice(0, at::indexing::None), - at::indexing::Slice(slice_inputs[0], slice_inputs[1])}); - auto t4 = t3.to(at::kDouble).sum({1}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {t2, t4}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(ResizeTest, FusionResizeCatReduceScheduler1) { @@ -1523,15 +1469,8 @@ TEST_F(ResizeTest, FusionResizeCatReduceScheduler1) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto ref = at::cat({t0, t1}, 1).sum({1}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {ref}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(ResizeTest, FusionResizeCatSoftmaxScheduler1) { @@ -1560,16 +1499,8 @@ TEST_F(ResizeTest, FusionResizeCatSoftmaxScheduler1) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t2 = at::cat({t0, t1}, 1); - auto ref = at::_softmax(t2.to(at::kDouble), -1, false); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {ref}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } TEST_F(ResizeTest, FusionResizeReductionSliceScheduler1) { @@ -1597,16 +1528,8 @@ TEST_F(ResizeTest, FusionResizeReductionSliceScheduler1) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t1 = t0.to(at::kDouble).sum({1}); - auto t2 = t1.index({at::indexing::Slice(1, shape0[0] - 2)}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {t2}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } // Softmax followed by slicing of a non-normalized dimension @@ -1636,18 +1559,8 @@ TEST_F(ResizeTest, FusionResizeSoftmaxSliceScheduler1) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t1 = at::_softmax(t0.to(at::kDouble), -1, false); - auto t2 = t1.index( - {at::indexing::Slice(1, shape0[0] - 2), - at::indexing::Slice(0, at::indexing::None)}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {t2}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } // Softmax followed by slicing of a normalized dimension @@ -1677,18 +1590,8 @@ TEST_F(ResizeTest, FusionResizeSoftmaxSliceScheduler2) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto t1 = at::_softmax(t0.to(at::kDouble), -1, false); - auto t2 = t1.index( - {at::indexing::Slice(0, at::indexing::None), - at::indexing::Slice(1, shape0[1] - 2)}); - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {t2}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } // Same as Pad1 but pad by specified value @@ -2087,19 +1990,8 @@ TEST_F(ResizeTest, ResizePermuteAndSlice) { "Unexpected heuristic: ", heuristic); - auto ref_t2 = (t0 + 1).index( - {at::indexing::Slice(1, shape.at(0) - 1), - at::indexing::Slice(2, shape.at(1) - 2)}); - auto ref_t3 = ref_t2.transpose(0, 1); - auto ref_t4 = ref_t2 + 1; - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {ref_t3, ref_t4}, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } // When scheduling this test, the pointwise scheduler attempt to replay a Split @@ -2360,7 +2252,7 @@ TEST_F(ResizeTest, SliceVectorization) { // testValidate does not check that dtypes match EXPECT_EQ(cg_outputs[0].dtype(), ref.dtype()); - testValidate(&fusion, cg_outputs, inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, inputs, __LINE__, __FILE__); } // Concretize a symbolic pad that results in a broadcast (static pads) @@ -2429,16 +2321,7 @@ TEST_F(NVFuserTest, ResizePadToBroadcastStatic_CUDA) { EXPECT_EQ(conc_t2->axis(i)->getIterType(), expected_itertypes.at(i)); } - auto t2_padded = at::pad(t0, pad_widths); - auto ref_t2 = t1 * t2_padded; - - testValidate( - concretized_fusion, - cg_outputs, - aten_inputs, - {ref_t2}, - __LINE__, - __FILE__); + testValidate(concretized_fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // Concretize a symbolic pad that results in a broadcast (dynamic pads) @@ -2505,16 +2388,7 @@ TEST_F(NVFuserTest, ResizePadToBroadcastDynamic_CUDA) { EXPECT_EQ(conc_t2->axis(3)->getIterType(), IterType::Broadcast); EXPECT_EQ(conc_t2->axis(4)->getIterType(), IterType::Iteration); - auto t2_padded = at::pad(t0, pad_widths); - auto ref_t2 = t1 * t2_padded; - - testValidate( - concretized_fusion, - cg_outputs, - aten_inputs, - {ref_t2}, - __LINE__, - __FILE__); + testValidate(concretized_fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // See https://github.com/NVIDIA/Fuser/issues/596 @@ -2545,14 +2419,10 @@ TEST_F(NVFuserTest, ResizePadToBroadcastIssue596_CUDA) { runtime.compileFusionParallel(args); auto cg_outputs = runtime.runWithInputs(args); - auto t2_padded = at::pad(t0, {0, -1}); - auto ref_t2 = t1 * t2_padded; - testValidate( runtime.fusionSegments()->completeFusion(), cg_outputs, aten_inputs, - {ref_t2}, __LINE__, __FILE__); } @@ -3144,9 +3014,6 @@ TEST_F(ResizeTest, PadExpandedEmpty) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - std::cout << t0 << std::endl; - std::cout << t0.strides() << std::endl; - testValidate( executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } @@ -3174,9 +3041,7 @@ TEST_F(ResizeTest, PadOfBroadcast) { fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - auto ref = at::pad(t0, {1, 1}); - - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // Test that we can cat along broadcast dims that have been expanded @@ -3205,9 +3070,7 @@ TEST_F(ResizeTest, PadOfExpandedBroadcast) { fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - auto ref = at::pad(at::expand_copy(t0, shape0e), {1, 1}); - - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } } // namespace nvfuser diff --git a/test/test_scalar_hoisting.cpp b/test/test_scalar_hoisting.cpp index 96fc10a8b02..372e4f1677a 100644 --- a/test/test_scalar_hoisting.cpp +++ b/test/test_scalar_hoisting.cpp @@ -217,7 +217,7 @@ TEST_F(ScalarHoistTest, IndexHoist1) { fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); - testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } // Hoist indices for vectorized tensors @@ -261,9 +261,7 @@ TEST_F(ScalarHoistTest, IndexHoist2) { fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, {t0, t1}, __LINE__, __FILE__); } TEST_F(ScalarHoistTest, IndexHoist3) { @@ -291,7 +289,6 @@ TEST_F(ScalarHoistTest, IndexHoist3) { const auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::arange(10000, options).view({100, 100}); - at::Tensor t1 = t0.sin() + 10000; FusionExecutor fe; fe.compileFusion(fusion.get(), {t0}); @@ -334,7 +331,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor assertCUDAKernel(fusion.get(), expected_kernel); - testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate(fusion.get(), cg_outputs, {t0}, __LINE__, __FILE__); } TEST_F(ScalarHoistTest, ARange) { @@ -361,11 +358,6 @@ TEST_F(ScalarHoistTest, ARange) { fe.compileFusion(fusion.get(), {start, end, step}); auto cg_outputs = fe.runFusion({start, end, step}); - const auto options = - at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::Tensor t0 = at::arange(start, end, step, options); - at::Tensor t1 = at::full_like(t0, end - start, options); - const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(int64_t i0, int64_t i1, int64_t i2, Tensor T0, Tensor T1) { int64_t i3; @@ -394,12 +386,7 @@ __global__ void CUDAGeneratedKernel(int64_t i0, int64_t i1, int64_t i2, Tensor expect; - expect.reserve(dtypes.size()); - for (auto dtype : dtypes) { - if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { - continue; - } - const auto options = - at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - expect.emplace_back(at::full({size}, 11, options)); - expect.emplace_back(at::full({size, size}, 12, options)); - expect.emplace_back(at::full({size, size}, 13, options)); - } auto cg_outputs = executor_cache.runFusionWithInputs({size, 11, 12, 13}); testValidate( executor_cache.fusion(), cg_outputs, {size, 11, 12, 13}, - expect, __LINE__, __FILE__); } @@ -119,27 +106,10 @@ TEST_F(TensorFactoryTest, StandaloneZeros) { FusionExecutorCache executor_cache(std::move(fusion)); for (auto size : sizes) { - std::vector expect; - expect.reserve(dtypes.size()); - for (auto dtype : dtypes) { - if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { - continue; - } - const auto options = - at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - expect.emplace_back(at::zeros({size}, options)); - expect.emplace_back(at::zeros({size, size}, options)); - expect.emplace_back(at::zeros({size, size}, options)); - } auto cg_outputs = executor_cache.runFusionWithInputs({size}); testValidate( - executor_cache.fusion(), - cg_outputs, - {size}, - expect, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, {size}, __LINE__, __FILE__); } } @@ -176,27 +146,10 @@ TEST_F(TensorFactoryTest, StandaloneOnes) { FusionExecutorCache executor_cache(std::move(fusion)); for (auto size : sizes) { - std::vector expect; - expect.reserve(dtypes.size()); - for (auto dtype : dtypes) { - if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { - continue; - } - const auto options = - at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - expect.emplace_back(at::ones({size}, options)); - expect.emplace_back(at::ones({size, size}, options)); - expect.emplace_back(at::ones({size, size}, options)); - } auto cg_outputs = executor_cache.runFusionWithInputs({size}); testValidate( - executor_cache.fusion(), - cg_outputs, - {size}, - expect, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, {size}, __LINE__, __FILE__); } } @@ -228,8 +181,6 @@ TEST_F(TensorFactoryTest, StandaloneIota) { FusionExecutorCache executor_cache(std::move(fusion)); - const auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - switch (dtype) { case at::kInt: case at::kLong: { @@ -238,9 +189,6 @@ TEST_F(TensorFactoryTest, StandaloneIota) { for (auto step : steps) { int64_t start_ = (int64_t)start; int64_t step_ = (int64_t)step; - int64_t end_ = start_ + step_ * length; - auto a = at::arange(start_, end_, step_, options); - auto cg_outputs = executor_cache.runFusionWithInputs({length, start_, step_}); @@ -248,7 +196,6 @@ TEST_F(TensorFactoryTest, StandaloneIota) { executor_cache.fusion(), cg_outputs, {length, start_, step_}, - {a}, __LINE__, __FILE__); } @@ -263,14 +210,6 @@ TEST_F(TensorFactoryTest, StandaloneIota) { for (auto step : steps) { double start_ = (double)start; double step_ = (double)step; - - // Due to rounding error, it can be hard to guarantee the size of - // the output of arange to be exactly length, so we generate a - // larger tensor and truncate it to length. - double end_ = start_ + step_ * (length + 1); - auto a = - at::arange(start_, end_, step_, options).narrow(0, 0, length); - auto cg_outputs = executor_cache.runFusionWithInputs({length, start_, step_}); @@ -278,7 +217,6 @@ TEST_F(TensorFactoryTest, StandaloneIota) { executor_cache.fusion(), cg_outputs, {length, start_, step_}, - {a}, __LINE__, __FILE__); } @@ -327,8 +265,6 @@ TEST_F(TensorFactoryTest, StandaloneARange) { FusionExecutorCache executor_cache(std::move(fusion)); - const auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - for (auto start : starts_ends) { for (auto end : starts_ends) { for (auto step : steps) { @@ -336,15 +272,6 @@ TEST_F(TensorFactoryTest, StandaloneARange) { continue; } - at::Tensor a = - at::arange((int64_t)start, (int64_t)end, (int64_t)step, options); - at::Tensor b = - at::arange((double)start, (double)end, (double)step, options); - at::Tensor c = - at::arange((int64_t)start, (double)end, (double)step, options); - at::Tensor d = - at::arange((double)start, (double)end, (int64_t)step, options); - auto cg_outputs = executor_cache.runFusionWithInputs( {(int64_t)start, (int64_t)end, @@ -362,7 +289,6 @@ TEST_F(TensorFactoryTest, StandaloneARange) { (double)start, (double)end, (double)step}, - {a, b, c, d}, __LINE__, __FILE__); } @@ -404,26 +330,10 @@ TEST_F(TensorFactoryTest, StandaloneEye) { FusionExecutorCache executor_cache(std::move(fusion)); for (auto size : sizes) { - std::vector expect; - expect.reserve(dtypes.size()); - for (auto dtype : dtypes) { - if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { - continue; - } - const auto options = - at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - expect.emplace_back(at::eye(size, options)); - expect.emplace_back(at::eye(size, 15, options)); - } auto cg_outputs = executor_cache.runFusionWithInputs({size, 15}); testValidate( - executor_cache.fusion(), - cg_outputs, - {size, 15}, - expect, - __LINE__, - __FILE__); + executor_cache.fusion(), cg_outputs, {size, 15}, __LINE__, __FILE__); } } diff --git a/test/utils.cpp b/test/utils.cpp index a6c51f4af34..010396a2e3e 100644 --- a/test/utils.cpp +++ b/test/utils.cpp @@ -257,49 +257,6 @@ Container parse(const std::string& nvdisasm_output) { } // namespace sass -TensorView* matmulVolta(TensorView* a, TensorView* b, MatmulLayout layout) { - NVF_CHECK( - a->nDims() == 2 && b->nDims() == 2, "only pure matmuls for these tests"); - // Here, we canonicalize the mma output as M, N, K, but the position of K does - // not really matter. So the implicit transpose is only required for NN. - TensorView *tv2 = nullptr, *tv0b = nullptr, *tv1b = nullptr; - switch (layout) { - case MatmulLayout::TT: - tv0b = broadcast(a, {false, false, true}); - tv1b = broadcast(b, {true, false, false}); - tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - // M, K, N -> M, N, K - tv2->reorder({{1, -1}}); - tv2->commitLeafToRFactor(); - break; - case MatmulLayout::TN: - tv0b = broadcast(a, {false, true, false}); - tv1b = broadcast(b, {true, false, false}); - tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - // M, N, K - break; - case MatmulLayout::NT: - tv0b = broadcast(a, {false, false, true}); - tv1b = broadcast(b, {false, true, false}); - tv2 = fusedMultiplySum(tv0b, tv1b, {0}); - // K, M, N -> M, N, K - tv2->reorder({{0, -1}}); - tv2->commitLeafToRFactor(); - break; - case MatmulLayout::NN: - tv0b = broadcast(a, {true, false, false}); - tv1b = broadcast(b, {false, false, true}); - tv2 = fusedMultiplySum(tv0b, tv1b, {1}); - // N, K, M -> M, N, K - tv2->reorder({{-1, 0}}); - tv2->commitLeafToRFactor(); - break; - default: - NVF_CHECK(false, "unsupported data layout."); - } - return tv2; -} - // matmulAtInput provides batched inputs in a splitk-like ordering. It provides // contiguous tensors with these shapes // TT: [M, B, K] [B, K, N] @@ -311,7 +268,7 @@ TensorView* matmulVolta(TensorView* a, TensorView* b, MatmulLayout layout) { TensorView* matmulTuringOrLater( TensorView* a, TensorView* b, - MatmulLayout layout) { + MmaLayout layout) { NVF_CHECK(a->nDims() == b->nDims()); NVF_CHECK(a->nDims() == 2 || a->nDims() == 3); TensorView *tv2 = nullptr, *tv0t = nullptr, *tv1t = nullptr, *tv0b = nullptr, @@ -319,19 +276,19 @@ TensorView* matmulTuringOrLater( if (a->nDims() == 3) { // bmm switch (layout) { // Canonicalize all inputs to [B, M, K] and [B, N, K] - case MatmulLayout::TT: + case MmaLayout::TT: tv0t = transpose(a, 0, 1); tv1t = transpose(b, 1, 2); break; - case MatmulLayout::TN: + case MmaLayout::TN: tv0t = transpose(a, 0, 1); tv1t = transpose(b, 0, 1); break; - case MatmulLayout::NT: + case MmaLayout::NT: tv0t = transpose(a, 1, 2); tv1t = transpose(b, 1, 2); break; - case MatmulLayout::NN: + case MmaLayout::NN: tv0t = transpose(a, 1, 2); tv1t = transpose(b, 0, 1); break; @@ -341,19 +298,19 @@ TensorView* matmulTuringOrLater( } else { switch (layout) { // Canonicalize all inputs to [M, K] and [N, K] - case MatmulLayout::TT: + case MmaLayout::TT: tv0t = a; tv1t = transpose(b, 0, 1); break; - case MatmulLayout::TN: + case MmaLayout::TN: tv0t = a; tv1t = b; break; - case MatmulLayout::NT: + case MmaLayout::NT: tv0t = transpose(a, 0, 1); tv1t = transpose(b, 0, 1); break; - case MatmulLayout::NN: + case MmaLayout::NN: tv0t = transpose(a, 0, 1); tv1t = b; break; @@ -374,20 +331,17 @@ TensorView* matmulTuringOrLater( TensorView* matmul( TensorView* a, TensorView* b, - MatmulLayout layout, + MmaLayout layout, bool turing_or_later // TODO: This is a temporary solution. Remove this! ) { - if (turing_or_later) { - return matmulTuringOrLater(a, b, layout); - } else { - return matmulVolta(a, b, layout); - } + NVF_ERROR(turing_or_later, "Only Turing or later is supported for now."); + return matmulTuringOrLater(a, b, layout); } TensorView* splitkLikeBatchedMatmul( TensorView* a, TensorView* b, - MatmulLayout layout) { + MmaLayout layout) { NVF_CHECK( a->nDims() == 3 && b->nDims() == 3, "only splitk-like batched matmuls for these tests"); @@ -395,25 +349,25 @@ TensorView* splitkLikeBatchedMatmul( *tv1b = nullptr; switch (layout) { // Canonicalize all inputs to [B, M, K] and [B, N, K] - case MatmulLayout::TT: + case MmaLayout::TT: // [M, B, K] -> [B, M, K] tv0t = transpose(a, 0, 1); // [B, K, N] -> [B, N, K] tv1t = transpose(b, 1, 2); break; - case MatmulLayout::TN: + case MmaLayout::TN: // [M, B, K] -> [B, M, K] tv0t = transpose(a, 0, 1); // [N, B, K] -> [B, N, K] tv1t = transpose(b, 0, 1); break; - case MatmulLayout::NT: + case MmaLayout::NT: // [B, K, M] -> [B, M, K] tv0t = transpose(a, 1, 2); // [B, K, N] -> [B, N, K] tv1t = transpose(b, 1, 2); break; - case MatmulLayout::NN: + case MmaLayout::NN: // [B, K, M] -> [B, M, K] tv0t = transpose(a, 1, 2); // [N, B, K] -> [B, N, K] @@ -435,7 +389,7 @@ TensorView* splitkLikeBatchedMatmul( // NT: [B, K, M] [B, K, N] // NN: [B, K, M] [N, B, K] // ATen matmul assumes [B, M, K] [B, K, N] so here we transpose into that order -at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) { +at::Tensor atMatmul(at::Tensor a, at::Tensor b, MmaLayout layout) { NVF_CHECK( a.dim() == b.dim(), "Either both or none of A and B should be batch"); NVF_CHECK( @@ -443,26 +397,26 @@ at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) { "Must have either zero or one batch dimensions"); if (a.dim() == 3) { // bmm switch (layout) { - case MatmulLayout::TT: + case MmaLayout::TT: return a.transpose(0, 1).matmul(b); - case MatmulLayout::TN: + case MmaLayout::TN: return a.transpose(0, 1).matmul(b.transpose(0, 1).transpose(1, 2)); - case MatmulLayout::NT: + case MmaLayout::NT: return a.transpose(1, 2).matmul(b); - case MatmulLayout::NN: + case MmaLayout::NN: return a.transpose(1, 2).matmul(b.transpose(0, 1).transpose(1, 2)); default: NVF_CHECK(false, "unsupported data layout."); } } else { switch (layout) { - case MatmulLayout::TT: + case MmaLayout::TT: return a.matmul(b); - case MatmulLayout::TN: + case MmaLayout::TN: return a.matmul(b.t()); - case MatmulLayout::NT: + case MmaLayout::NT: return a.t().matmul(b); - case MatmulLayout::NN: + case MmaLayout::NN: return a.t().matmul(b.t()); default: NVF_CHECK(false, "unsupported data layout."); @@ -471,18 +425,18 @@ at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) { return at::Tensor(); } -at::Tensor splitkLikeAtMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) { +at::Tensor splitkLikeAtMatmul(at::Tensor a, at::Tensor b, MmaLayout layout) { switch (layout) { - case MatmulLayout::TT: + case MmaLayout::TT: // [M, B, K] @ [B, K, N] -> [B, M, N] return a.transpose(0, 1).matmul(b); - case MatmulLayout::TN: + case MmaLayout::TN: // [M, B, K] @ [N, B, K] -> [B, M, N] return a.transpose(0, 1).matmul(b.permute({1, 2, 0})); - case MatmulLayout::NT: + case MmaLayout::NT: // [B, K, M] @ [B, K, N] -> [B, M, N] return a.transpose(1, 2).matmul(b); - case MatmulLayout::NN: + case MmaLayout::NN: // [B, K, M] @ [N, B, K] -> [B, M, N] return a.transpose(1, 2).matmul(b.permute({1, 2, 0})); default: @@ -495,21 +449,21 @@ std::pair matmulAtInput( int M, int N, int K, - MatmulLayout layout, + MmaLayout layout, c10::ScalarType dtype) { auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); switch (layout) { - case MatmulLayout::TT: + case MmaLayout::TT: return std::make_pair( at::randn({M, K}, options), at::randn({K, N}, options)); - case MatmulLayout::TN: + case MmaLayout::TN: return std::make_pair( at::randn({M, K}, options), at::randn({N, K}, options)); - case MatmulLayout::NT: + case MmaLayout::NT: return std::make_pair( at::randn({K, M}, options), at::randn({K, N}, options)); - case MatmulLayout::NN: + case MmaLayout::NN: return std::make_pair( at::randn({K, M}, options), at::randn({N, K}, options)); default: @@ -519,7 +473,7 @@ std::pair matmulAtInput( } at::Tensor matmulAtInput( - const MatmulLayout layout, + const MmaLayout layout, const TensorMatmulPos tensor, const c10::ScalarType dtype, const int M, @@ -544,7 +498,7 @@ at::Tensor matmulAtInput( } switch (layout) { - case MatmulLayout::TT: + case MmaLayout::TT: switch (tensor) { case TensorMatmulPos::A: return is_batch ? at::randn({M, B, K}, options) @@ -556,7 +510,7 @@ at::Tensor matmulAtInput( break; } break; - case MatmulLayout::TN: + case MmaLayout::TN: switch (tensor) { case TensorMatmulPos::A: return is_batch ? at::randn({M, B, K}, options) @@ -568,7 +522,7 @@ at::Tensor matmulAtInput( break; } break; - case MatmulLayout::NT: + case MmaLayout::NT: switch (tensor) { case TensorMatmulPos::A: return is_batch ? at::randn({B, K, M}, options) @@ -580,7 +534,7 @@ at::Tensor matmulAtInput( break; } break; - case MatmulLayout::NN: + case MmaLayout::NN: switch (tensor) { case TensorMatmulPos::A: return is_batch ? at::randn({B, K, M}, options) diff --git a/test/utils.h b/test/utils.h index 4f8b9505df5..9fc038b39d9 100644 --- a/test/utils.h +++ b/test/utils.h @@ -483,6 +483,12 @@ class NVFuserTest : public ::testing::Test { bool capturing_ = false; }; +// Fixture with param class must be uniquely identified, i.e., can't be in an +// anonymous namespace +template +class NVFuserFixtureParamTest : public NVFuserTest, + public ::testing::WithParamInterface {}; + // assert that the given fusion lowers to the given CUDA kernel void assertCUDAKernel(Fusion* fusion, const std::string& expected_kernel); @@ -576,20 +582,17 @@ inline bool cudaArchGuardShouldSkip( COMPILE_FUSION; \ } -// util to track support matmul operand layout. -using MatmulLayout = MmaOptions::MmaLayout; - -static constexpr std::array kAllSupportedMatmulLayout = { - MatmulLayout::TT, - MatmulLayout::NT, - MatmulLayout::TN, - MatmulLayout::NN}; +static constexpr std::array kAllSupportedMmaLayout = { + MmaLayout::TT, + MmaLayout::NT, + MmaLayout::TN, + MmaLayout::NN}; // Generic interface to get matmul op with the given layout. TensorView* matmul( TensorView* a, TensorView* b, - MatmulLayout layout, + MmaLayout layout, bool turing_or_later // TODO: This is a temporary solution. Remove this! ); @@ -600,20 +603,20 @@ TensorView* matmul( TensorView* splitkLikeBatchedMatmul( TensorView* a, TensorView* b, - MatmulLayout layout); + MmaLayout layout); // Utility to generate matmul input tensors based on given layout -at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout); +at::Tensor atMatmul(at::Tensor a, at::Tensor b, MmaLayout layout); // Utility to generate matmul input tensors based on given layout -at::Tensor splitkLikeAtMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout); +at::Tensor splitkLikeAtMatmul(at::Tensor a, at::Tensor b, MmaLayout layout); // Utility to generate inputs based on given layout std::pair matmulAtInput( int M, int N, int K, - MatmulLayout layout, + MmaLayout layout, c10::ScalarType dtype = at::kHalf); // Labels to describe tensor position in matmul: @@ -626,7 +629,7 @@ enum class TensorMatmulPos { A, B, C, D, Bias }; // Utility to generate buffers based on given problem, layout and tensor // position in matmul with support for matmul and strided batch matmul at::Tensor matmulAtInput( - const MatmulLayout layout, + const MmaLayout layout, const TensorMatmulPos tensor, const c10::ScalarType dtype, const int M, diff --git a/test/validator.h b/test/validator.h index 8b638725c59..3d6c1a20bdd 100644 --- a/test/validator.h +++ b/test/validator.h @@ -65,30 +65,25 @@ void testValidate( } } - const auto& io_alias = fusion->ioAlias(); - auto should_remove = [&io_alias](Val* out_val) -> bool { - if (auto alias_it = io_alias.find(out_val); alias_it != io_alias.end()) { - return alias_it->second.second.hide_output; - } - return false; - }; - - for (size_t i = 0, j = 0; i < fusion->outputs().size(); i++) { - NVF_ERROR(fusion->outputs()[i]->isA()); - if (should_remove(fusion->outputs()[i])) { + size_t j = 0; + for (Val* fusion_output : fusion->outputs()) { + const AliasInfo* alias_info = fusion->getOutputAlias(fusion_output).second; + if (alias_info != nullptr && alias_info->hide_output) { // This is an aliased output that's hidden from integration. // Let's not check this. continue; } + NVF_ERROR(fusion_output->isA()); + TensorView* fusion_output_tv = fusion_output->as(); + auto fusion_output_tensor = fusion_outputs[j]; - auto fusion_output_tv = fusion->outputs()[i]->as(); auto aten_output_tensor = aten_outputs[j]; NVF_ERROR( reduction_sizes.count(fusion_output_tv), - "Missed reduction size count on fusion output at index: ", - i); + "Missed reduction size count on fusion output: ", + fusion_output_tv->toString()); int64_t reduction_size = reduction_sizes.at(fusion_output_tv); diff --git a/tools/codediff/diff_report.py b/tools/codediff/diff_report.py index 2d1c2b79362..2bd6258394c 100644 --- a/tools/codediff/diff_report.py +++ b/tools/codediff/diff_report.py @@ -545,7 +545,7 @@ def find_preamble(self): # we set nvfuser_index_t in the preamble. We ignore that change for the purposes of this diff if line[:8] == "typedef " and line[-17:] == " nvfuser_index_t;": line = "typedef int nvfuser_index_t; // NOTE: index type hard-coded as int for display only" - if re.search(r"void kernel\d+\b", line) is not None: + if re.search(r"void (nvfuser|kernel)_?\d+\b", line) is not None: # we arrived at the kernel definition break if first: @@ -579,7 +579,8 @@ def get_kernel( kern.index_type = m.groups()[0] if not strip_preamble or i >= self.preamble_size_lines: # replace kernel934 with kernel1 to facilitate diffing - kern.code += re.sub(r"\bkernel\d+\b", "kernelN", line) + # also match kernel_43 to handle new-style naming with static fusion count + kern.code += re.sub(r"\bnvfuser_\d+\b", "nvfuser_N", line) kern.code = kern.code.rstrip() if strip_preamble and kern.code[-1] == "}": # trailing curly brace is close of namespace. This will clean it up so that we have just the kernel @@ -642,13 +643,28 @@ def sanitize_ptx_lines(lines: list[str]) -> list[str]: for l in lines: # Replace mangled kernel names like # _ZN76_GLOBAL__N__00000000_37___tmp_kernel_pointwise_f0_c1_r0_g0_cu_8995cef2_3255329nvfuser_pointwise_f0_c1_r0_g0ENS_6TensorIfLi2ELi2EEES1_S1_ + # or + # _ZN76_GLOBAL__N__00000000_37___tmp_kernel_4_cu_8995cef2_3255329nvfuser_4ENS_6TensorIfLi2ELi2EEES1_S1_ # with - # _ZN76_GLOBAL__N__00000000_37___tmp_kernel_pointwise_cu_8995cef2_3255329nvfuser_pointwiseENS_6TensorIfLi2ELi2EEES1_S1_ - l = re.sub( - r"_tmp_kernel_[a-z0-9_]+nvfuser_[a-z]+_f\d+_c\d*_r\d+_g\d+", "kernel", l - ) + # _ZN11kernelscope6kernelENS_6TensorIfLi2ELi2EEES1_S1_ - # Remove comments. This is important for + # demangle first two parts after _ZN and replace with "kernelscope" and "kernel" + m = re.match(r"^(?P^.*\b_Z?ZN)(?P\d+)_", l) + if m is not None: + d = m.groupdict() + scopenamelen = int(d["scopenamelen"]) + # demangle second part in remainder after scope name + remainder = l[(len(d["prefix"]) + len(d["scopenamelen"]) + scopenamelen) :] + mrem = re.match(r"^(?P\d+)", remainder) + if mrem is not None: + drem = mrem.groupdict() + varnamelen = int(drem["varnamelen"]) + remainder = ( + "6kernel" + remainder[len(drem["varnamelen"]) + varnamelen :] + ) + l = d["prefix"] + "11kernelscope" + remainder + + # Remove comments. This fixes mismatches in PTX "callseq" comments, which appear to be non-repeatable. l = re.sub(r"//.*$", "", l) sanitary_lines.append(l) return sanitary_lines @@ -857,7 +873,10 @@ def generate_html(self, omit_preamble: bool, max_diffs: bool) -> str: ) parser.add_argument("--html", action="store_true", help="Write HTML file?") parser.add_argument( - "--hide-diffs", action="store_true", help="Print diffs to STDOUT?" + "--hide-diffs", + "--no-print-diff", + action="store_true", + help="Print diffs to STDOUT?", ) parser.add_argument( "--kernel-inclusion-criterion", diff --git a/tools/codediff/run_command.sh b/tools/codediff/run_command.sh index a5af6d0cc60..db341747261 100755 --- a/tools/codediff/run_command.sh +++ b/tools/codediff/run_command.sh @@ -165,6 +165,8 @@ ensure_in_list() { # ensure some NVFUSER_DUMP options are enabled appended_dump=$(ensure_in_list "$NVFUSER_DUMP" cuda_to_file ptxas_verbose ptx) export NVFUSER_DUMP=$appended_dump +appended_enable=$(ensure_in_list "$NVFUSER_ENABLE" static_fusion_count) +export NVFUSER_ENABLE=$appended_enable # Allow command to fail, but record exit code set +e diff --git a/tools/codediff/templates/command_env_info.html b/tools/codediff/templates/command_env_info.html index 87648514ab7..83d0aff8a88 100644 --- a/tools/codediff/templates/command_env_info.html +++ b/tools/codediff/templates/command_env_info.html @@ -2,10 +2,18 @@ Command: {{ run1.command|e }} {%- else -%} {{ run1.name|e }} command: {{ run1.command|e }} -
- {{ run2.name|e }} command: {{ run2.command|e }} - {%- endif -%} - {%- if run1.gpu_names is not none %} +
+ {{ run2.name|e }} command: {{ run2.command|e }} +{%- endif -%} +{%- if run1.exit_code != 0 -%} +
+ {{ run1.name|e }} command failed with exit code {{ run1.exit_code }}. +{%- endif -%} +{%- if run2.exit_code != 0 -%} +
+ {{ run2.name|e }} command failed with exit code {{ run2.exit_code }}. +{%- endif -%} +{%- if run1.gpu_names is not none %}
{%- if run1.gpu_names | length > 1 %} {% if run1.gpu_names != run2.gpu_names %}{{ run1.name|e }}{% endif %} diff --git a/version.txt b/version.txt index d917d3e26ad..b1e80bb2480 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.1.2 +0.1.3 From 6ba29ede6c2037c482aaf21e9f21c2958594dc53 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 29 Nov 2023 17:22:38 -0800 Subject: [PATCH 083/178] Merge lower2device --- csrc/device_lower/lower2device.cpp | 53 +++++++++++++----------------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index dac936eb583..84588dc1585 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -238,7 +238,6 @@ void assignRNGOffset(Fusion* fusion) { void dumpExprsIfEnabled( const std::vector& exprs, std::string pass_name, - bool force_expr_disable = true, bool force_enable = false) { auto enabled_by_env = [&pass_name]() { if (!isDebugDumpEnabled(DebugDumpOption::LowerVerbose)) { @@ -249,12 +248,8 @@ void dumpExprsIfEnabled( args.empty() || std::find(args.begin(), args.end(), pass_name) != args.end()); }; - bool name_only = isDebugDumpEnabled(DebugDumpOption::LowerNameOnly); - if (name_only || force_enable || enabled_by_env()) { - std::cout << "After " << pass_name << ":" << std::endl; - if (name_only || force_expr_disable) { - return; - } + if (force_enable || enabled_by_env()) { + debug() << "After " << pass_name << ":" << std::endl; for (auto exp : exprs) { debug() << exp->toString() << std::endl; } @@ -366,19 +361,17 @@ void GpuLower::analysis(Fusion* fusion) { // prepare for lowering validateIr(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateIr", true); + dumpExprsIfEnabled(fusion_->exprs(), "validateIr"); // Checks if any TIDx dim is marked as padded to a warp. Also checks if we can // determine the padding is explicitly a single warp. collectPaddedParallelDims(); - dumpExprsIfEnabled(fusion_->exprs(), "collectPaddedParallelDims", true); + dumpExprsIfEnabled(fusion_->exprs(), "collectPaddedParallelDims"); // Replaces integers that are tensor sizes by named scalars as "T0.size[0]" replaceSymbolicSizes(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "replaceSymbolicSizes"); - IdModel test(fusion_); - // Build what's refered to as the compute at map. This map contains the // mappings of all iteration domains across the fusion. There are three types // of mappings Permissive, Exact, and Loop, see compute_at_map.h/cpp for more @@ -397,7 +390,7 @@ void GpuLower::analysis(Fusion* fusion) { } resolveComputeWith(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "resolveComputeWith", true); + dumpExprsIfEnabled(fusion_->exprs(), "resolveComputeWith"); if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) { debug() << compute_at_map_->toString() << std::endl; @@ -407,35 +400,34 @@ void GpuLower::analysis(Fusion* fusion) { // Uses compute_at_map, find all splits that are enforced to be divisible divisible_splits_ = getAllDivisibleSplits(fusion_, compute_at_map_.get()); - dumpExprsIfEnabled(fusion_->exprs(), "getAllDivisibleSplits", true); + dumpExprsIfEnabled(fusion_->exprs(), "getAllDivisibleSplits"); // Used in parallel dimension map concretized_broadcast_domains_ = std::make_shared(fusion_); - dumpExprsIfEnabled( - fusion_->exprs(), "build ConcretizedBroadcastDomains", true); + dumpExprsIfEnabled(fusion_->exprs(), "build ConcretizedBroadcastDomains"); parallelDimensionMap().build(fusion_); if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) { debug() << "Parallel dimension map:" << std::endl; debug() << parallel_dimension_map_.toString() << std::endl; } - dumpExprsIfEnabled(fusion_->exprs(), "build parallelDimensionMap", true); + dumpExprsIfEnabled(fusion_->exprs(), "build parallelDimensionMap"); // Validate mma data format and compatibility if any on the fusion. validateMma(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateMma", true); + dumpExprsIfEnabled(fusion_->exprs(), "validateMma"); // Validate swizzle usage on the fusion schedule. validateSwizzle(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateSwizzle", true); + dumpExprsIfEnabled(fusion_->exprs(), "validateSwizzle"); validateResize(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "validateResize"); // Compute thread predicates. Depends on parallel_dimension_map_ thread_pred_map_.build(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build thread_pred_map_", true); + dumpExprsIfEnabled(fusion_->exprs(), "build thread_pred_map_"); // Fuse cetain patterns of reductions, such as a grid reduction // followed by a grid broadcast. Only depends on parallelization and @@ -446,27 +438,26 @@ void GpuLower::analysis(Fusion* fusion) { // Scan the whole fusion and build mappings about halo extensions of // all IterDomains halo_info_ = std::make_shared(fusion_, compute_at_map_); - dumpExprsIfEnabled(fusion_->exprs(), "build HaloInfo", true); + dumpExprsIfEnabled(fusion_->exprs(), "build HaloInfo"); // Want to run this after parallel map and halo info map are // created. vectorized_accesses_ and vectorized_set_info_ are filled. validateAndCollectVectorizeInfo(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateAndCollectVectorizeInfo", true); + dumpExprsIfEnabled(fusion_->exprs(), "validateAndCollectVectorizeInfo"); // Depends on ComputeAtMap and HaloInfo. validateAndConvertIterDomainGrouping(fusion_); - dumpExprsIfEnabled( - fusion_->exprs(), "validateAndConvertIterDomainGrouping", true); + dumpExprsIfEnabled(fusion_->exprs(), "validateAndConvertIterDomainGrouping"); // Assumes all grouped reductions are convered to // GroupedReductionOp, which is done by // validateAndConvertIterDomainGrouping validateGroupedReductions(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateGroupedReductions", true); + dumpExprsIfEnabled(fusion_->exprs(), "validateGroupedReductions"); // all of the lookup TVs are fusion inputs validateLookupTV(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validateLookupTV", true); + dumpExprsIfEnabled(fusion_->exprs(), "validateLookupTV"); // Depends on thread_pred_map_, validates parallelization collects which // tensor views need WAR or RAW syncs @@ -474,24 +465,24 @@ void GpuLower::analysis(Fusion* fusion) { if (isDebugDumpEnabled(DebugDumpOption::SyncMap)) { debug() << sync_map_->toString() << std::endl; } - dumpExprsIfEnabled(fusion_->exprs(), "SyncMap", true); + dumpExprsIfEnabled(fusion_->exprs(), "SyncMap"); partialSplitMap().build(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build partialSplitMap", true); + dumpExprsIfEnabled(fusion_->exprs(), "build partialSplitMap"); validatePartialSplit(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "validatePartialSplit", true); + dumpExprsIfEnabled(fusion_->exprs(), "validatePartialSplit"); nonDivisibleSplitInfo().build(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build nonDivisibleSplitInfo", true); + dumpExprsIfEnabled(fusion_->exprs(), "build nonDivisibleSplitInfo"); // Detects all exprssions that don't need predicates. Depends on // nonDivisibleSplitInfo. pred_elimination_ = std::make_unique(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination", true); + dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); doubleBufferInfo().build(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build doubleBufferInfo", true); + dumpExprsIfEnabled(fusion_->exprs(), "build doubleBufferInfo"); compute_at_map_->allocateIndexVariables(); dumpExprsIfEnabled(fusion_->exprs(), "allocateIndexVariables"); From 9eacdf6e43dbc7b45fa1db77cd012388d3a3d8e0 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 29 Nov 2023 23:00:08 -0800 Subject: [PATCH 084/178] Merge from main --- csrc/options.cpp | 1 - csrc/options.h | 1 - 2 files changed, 2 deletions(-) diff --git a/csrc/options.cpp b/csrc/options.cpp index 1fd5708ebe3..15287fed070 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -126,7 +126,6 @@ std::unordered_map> Options< {"kernel_ir", DebugDumpOption::KernelIr}, {"launch_param", DebugDumpOption::LaunchParam}, {"loop_rotation", DebugDumpOption::LoopRotation}, - {"lower_name_only", DebugDumpOption::LowerNameOnly}, {"lower_verbose", DebugDumpOption::LowerVerbose}, {"occupancy", DebugDumpOption::Occupancy}, {"parallel_dimensions", DebugDumpOption::ParallelDimensions}, diff --git a/csrc/options.h b/csrc/options.h index a9913817f90..b506bd7a746 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -60,7 +60,6 @@ enum class DebugDumpOption { Ptx, //! Dump compiled PTX BankConflictInfo, //! Dump bank confliction info SyncMap, //! RAW dependency info - LowerNameOnly, //! Print all passes' names as they're run in GpuLower::lower LowerVerbose, //! Print all passes' transform in GpuLower::lower ExprSimplification, //! Print all passes' transform in simplifyExpr ExprSort, //! Print merging decisions on expression sorting From f8a958550ff05b0decf8cb76c9d2eccbaf97f474 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 1 Dec 2023 00:05:02 -0800 Subject: [PATCH 085/178] Fix validation Validation needs to be done immediately after an exact graph is built before replaying domains as new domains would be added the exact graph as well, which would not exist in the CA map. --- csrc/device_lower/lower2device.cpp | 6 +----- csrc/id_model/id_model.cpp | 21 ++++++++++++++++----- csrc/id_model/id_model.h | 10 +++++++--- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 84588dc1585..89b79c60d46 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -34,7 +34,6 @@ #include #include #include -#include #include #include #include @@ -383,10 +382,7 @@ void GpuLower::analysis(Fusion* fusion) { // so it is expected that generated code may use diffrent variable // names if (isOptionEnabled(EnableOption::IdModel)) { - IdModel id_model(fusion_); - // Only the exact graph is genereated at this moment - IdModelValidator::checkExactGraphEquivalence( - id_model.idGraph(IdMappingMode::EXACT)); + IdModel id_model(fusion_, false, true); } resolveComputeWith(fusion_); diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index c99f7325ec2..6988a6a45ad 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -54,7 +55,7 @@ IdModel::IdModel( IdModel::IdModel(const std::vector& exprs, bool allow_self_mapping) : IdModel(exprs, {}, allow_self_mapping) {} -IdModel::IdModel(Fusion* fusion, bool allow_self_mapping) { +IdModel::IdModel(Fusion* fusion, bool allow_self_mapping, bool validate) { std::vector inputs_and_outputs; { auto inp_tvs = ir_utils::filterByType(fusion->inputs()); @@ -67,7 +68,7 @@ IdModel::IdModel(Fusion* fusion, bool allow_self_mapping) { inputs_and_outputs.begin(), out_tvs.begin(), out_tvs.end()); } - build(fusion->exprs(), inputs_and_outputs); + build(fusion->exprs(), inputs_and_outputs, validate); if (!allow_self_mapping) { assertNoSelfMapping(); @@ -641,7 +642,7 @@ ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { return id_graph; } -void IdModel::buildExactMap(const std::vector& exprs) { +void IdModel::buildExactGraph(const std::vector& exprs) { for (auto expr : exprs) { TensorView* c_tv = ir_utils::getTvOutput(expr); @@ -938,7 +939,8 @@ StatefulLoweringInfo buildInfo( void IdModel::build( const std::vector& exprs, - const std::vector& additional_tvs) { + const std::vector& additional_tvs, + bool validate) { // Initialize the required sets as if a permissive relationship is never // found, then querying an empty permissive map will fail later. // Initialize disjoint sets @@ -977,7 +979,16 @@ void IdModel::build( // expressions. idGraph(IdMappingMode::EXACT) = initializeIdGraph(); - buildExactMap(tv_exprs); + buildExactGraph(tv_exprs); + + if (validate) { + IdModelValidator::checkExactGraphEquivalence(idGraph(IdMappingMode::EXACT)); + } + + if (getenv("EXACT_ONLY")) { + return; + } + buildAlmostExactMap(); buildPermissiveMap(tv_exprs); diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 1315c4fe125..9fc213c0254 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -94,7 +94,10 @@ class IdModel : public PolymorphicBase { // Same as the above constructor with fusion->exprs() excpet fusion may have // some dangling inputs/outputs that are expected to have IterDomain entries // even though there's no possible connections from them. - IdModel(Fusion* fusion, bool allow_self_mapping = false); + IdModel( + Fusion* fusion, + bool allow_self_mapping = false, + bool validate = false); // Returns iter domain graph of provided mode. const ValGraph& idGraph(IdMappingMode mode) const; @@ -176,7 +179,8 @@ class IdModel : public PolymorphicBase { // the Fusion that don't have expressions associated with them. void build( const std::vector& exprs, - const std::vector& additional_tvs); + const std::vector& additional_tvs, + bool validate = false); // ======= START Iteration domain build process in order called ======= @@ -191,7 +195,7 @@ class IdModel : public PolymorphicBase { // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs // and first output of expr - void buildExactMap(const std::vector& exprs); + void buildExactGraph(const std::vector& exprs); // Fills disjoint_ids_[IdMappingMode::ALMOSTEXACT]. Initialize AlmostExact as // Exact entries, then map anything that's either merged with a size-1 or From 9733ab61ec054f7f173e8891a9370e5c0bdcae89 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 1 Dec 2023 00:17:20 -0800 Subject: [PATCH 086/178] Merge from main --- csrc/id_model/id_model.cpp | 151 ++++++++++++++++++---------------- csrc/id_model/id_model.h | 12 ++- csrc/id_model/to_string.cpp | 35 +++----- csrc/id_model/to_string.h | 3 +- csrc/id_model/visitor.cpp | 13 +-- csrc/val_graph.cpp | 158 ++++++++++++++++-------------------- csrc/val_graph.h | 111 ++++++++++++++----------- 7 files changed, 240 insertions(+), 243 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 6988a6a45ad..87ae40c8eb0 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -22,9 +22,37 @@ #include #include +#include namespace nvfuser { +namespace { + +// Map through loop swizzles, as input/output IterDomains are exact, only the +// order they're traversed differs. +void mapThroughLoopSwizzles(ValGraph& graph) { + std::vector all_swizzles; + + for (const auto& expr_set : + std::as_const(graph).disjointExprSets().disjointSets()) { + auto swizzles_in_expr_set = ir_utils::filterByType( + expr_set->vector().begin(), expr_set->vector().end()); + all_swizzles.insert( + all_swizzles.end(), + swizzles_in_expr_set.begin(), + swizzles_in_expr_set.end()); + } + + for (auto swizzle : all_swizzles) { + if (swizzle->swizzleMode() == SwizzleMode::Loop) { + graph.mapVals(swizzle->inX(), swizzle->outX()); + graph.mapVals(swizzle->inY(), swizzle->outY()); + } + } +} + +} // namespace + void IdModel::assertNoSelfMapping() { if (hasSelfMapping()) { NVF_ERROR( @@ -52,9 +80,6 @@ IdModel::IdModel( } } -IdModel::IdModel(const std::vector& exprs, bool allow_self_mapping) - : IdModel(exprs, {}, allow_self_mapping) {} - IdModel::IdModel(Fusion* fusion, bool allow_self_mapping, bool validate) { std::vector inputs_and_outputs; { @@ -65,7 +90,7 @@ IdModel::IdModel(Fusion* fusion, bool allow_self_mapping, bool validate) { { auto out_tvs = ir_utils::filterByType(fusion->outputs()); inputs_and_outputs.insert( - inputs_and_outputs.begin(), out_tvs.begin(), out_tvs.end()); + inputs_and_outputs.end(), out_tvs.begin(), out_tvs.end()); } build(fusion->exprs(), inputs_and_outputs, validate); @@ -77,7 +102,11 @@ IdModel::IdModel(Fusion* fusion, bool allow_self_mapping, bool validate) { const ValGraph& IdModel::idGraph(IdMappingMode mode) const { auto graph_it = id_graphs_.find(mode); - NVF_ERROR(graph_it != id_graphs_.end()); + NVF_ERROR( + graph_it != id_graphs_.end(), + "Failed to find an IdGraph with the ", + mode, + " mode"); return graph_it->second; } @@ -214,14 +243,14 @@ findFirstSelfMapping( void IdModel::buildIterDomainDefinitionsAndUses( const std::vector& all_tvs) { - for (auto tv : all_tvs) { + for (const auto tv : all_tvs) { VectorOfUniqueEntries root_domain_ids{ tv->getRootDomain().begin(), tv->getRootDomain().end()}; - auto all_ids = ir_utils::allIDsOf(tv); + std::vector all_ids = ir_utils::allIDsOf(tv); - // Check is this domain is a consumer of a view-like operation - bool view_like_domain = tv->domain()->hasViewLikeRFactor(); + // Check if this domain is a consumer of a view-like operation + const bool view_like_domain = tv->domain()->hasViewLikeRFactor(); for (auto id : all_ids) { // Check if this id is a view like rfactor id @@ -237,64 +266,47 @@ void IdModel::buildIterDomainDefinitionsAndUses( } if (id_definitions_.find(id) == id_definitions_.end()) { - id_definitions_[id] = {}; + id_definitions_.emplace(id, VectorOfUniqueEntries{}); } if (id_uses_.find(id) == id_uses_.end()) { - id_uses_[id] = {}; + id_uses_.emplace(id, VectorOfUniqueEntries{}); } - auto def = id->definition(); + Expr* def = id->definition(); if (def == nullptr || root_domain_ids.has(id)) { continue; } - if (id_definitions_.find(id) == id_definitions_.end()) { - id_definitions_[id] = {}; - } - id_definitions_.at(id).pushBack(def); + id_definitions_[id].pushBack(def); auto inp_ids = ir_utils::filterByType(def->inputs()); for (auto inp_id : inp_ids) { - if (id_uses_.find(inp_id) == id_uses_.end()) { - id_uses_[inp_id] = {}; - } - id_uses_.at(inp_id).pushBack(def); + id_uses_[inp_id].pushBack(def); } } } } std::string IdModel::toString() const { - // Figure out which graphs are already initialized to make sure we add the new - // expression to them. - std::vector initialized_modes; + std::stringstream ss; + ss << "IterDomainGraphs { \n"; + // Only print initialized graphs for (auto mode : kIdMappingModes) { auto graph_it = id_graphs_.find(mode); if (graph_it == id_graphs_.end()) { continue; } - auto& graph = graph_it->second; - if (graph.disjointValSets().disjointSetMap().empty()) { - continue; - } - - initialized_modes.push_back(mode); - } - - std::stringstream ss; - ss << "IterDomainGraphs { \n"; - for (auto mode : initialized_modes) { - std::stringstream ss; + // graph may be empty, but then just print it as an empty graph, + // which might be useful for debugging ss << " IdGraph " << mode << "{ \n"; ss << " Disjoint Ids:\n" << idGroupsString(idGraph(mode), 2) << "\n Disjoint Expression groups:\n" << exprGroupsString(idGraph(mode), 2) << std::endl; ss << " } IdGraph\n" << std::endl; - return ss.str(); } ss << " } IterDomainGraphs\n" << std::endl; return ss.str(); @@ -342,7 +354,7 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { id_uses_[new_inputs.back()]; for (auto mode : initialized_modes) { idGraph(mode).initializeVal(new_inputs.back(), {}, {}); - idGraph(mode).mapIds(new_inputs.back(), tmp_input); + idGraph(mode).mapVals(new_inputs.back(), tmp_input); } } } @@ -408,9 +420,8 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { // Gather all use expressions from inputs VectorOfUniqueEntries representative_uses; for (IterDomain* inp : new_inputs) { - auto uses_pair = graph.getUses(graph.toGroup(inp)); - if (uses_pair.second) { - for (const ExprGroup& use_group : uses_pair.first) { + if (const ExprGroups* uses = graph.getUses(graph.toGroup(inp)); uses) { + for (const ExprGroup& use_group : *uses) { representative_uses.pushBack(use_group->front()); } } @@ -557,9 +568,8 @@ Expr* IdModel::addExprWithReplacement( // Forward VectorOfUniqueEntries representative_uses; for (auto in : ir_utils::filterByType(replay->inputs())) { - auto uses_pair = graph.getUses(graph.toGroup(in)); - if (uses_pair.second) { - for (const ExprGroup& use_group : uses_pair.first) { + if (const ExprGroups* uses = graph.getUses(graph.toGroup(in)); uses) { + for (const ExprGroup& use_group : *uses) { if (use_group == replay_group) { continue; } @@ -575,9 +585,9 @@ Expr* IdModel::addExprWithReplacement( // Backwards VectorOfUniqueEntries representative_defs; for (auto out : ir_utils::filterByType(replay->outputs())) { - auto defs_pair = graph.getDefinitions(graph.toGroup(out)); - if (defs_pair.second) { - for (const ExprGroup& def_group : defs_pair.first) { + if (auto definition = graph.getDefinitions(graph.toGroup(out)); + definition) { + for (const ExprGroup& def_group : *definition) { if (def_group == replay_group) { continue; } @@ -620,7 +630,7 @@ IterDomain* IdModel::cloneIterDomain(IterDomain* id) { for (auto mode : initialized_modes) { idGraph(mode).initializeVal(id_copy, {}, {}); - idGraph(mode).mapIds(id, id_copy); + idGraph(mode).mapVals(id, id_copy); } return id_copy; @@ -666,7 +676,7 @@ void IdModel::buildExactGraph(const std::vector& exprs) { for (auto domain_i : c10::irange(c_tv->getRootDomain().size())) { auto c_id = c_tv->getRootDomain()[domain_i]; auto o_id = other_tv_output->getRootDomain()[domain_i]; - idGraph(IdMappingMode::EXACT).mapIds(o_id, c_id); + idGraph(IdMappingMode::EXACT).mapVals(o_id, c_id); } } @@ -682,11 +692,12 @@ void IdModel::buildExactGraph(const std::vector& exprs) { for (auto c_id : getSortedKeys(exact_c2p_root_map, Statement::lessThan)) { auto p_id = exact_c2p_root_map.at(c_id); - idGraph(IdMappingMode::EXACT).mapIds(c_id, p_id); + idGraph(IdMappingMode::EXACT).mapVals(c_id, p_id); } } - idGraph(IdMappingMode::EXACT).mapThroughLoopSwizzles(); + // TODO: Revisit if we really should map domains in the exact map + mapThroughLoopSwizzles(idGraph(IdMappingMode::EXACT)); } } @@ -713,32 +724,32 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { ForwardingInfo permissive_forwarding(p_tv, c_tv); for (auto entry : permissive_forwarding.producer_forwarding_map) { - idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } // TODO: Should this just get rolled up in the forwarding map now? for (const auto& entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { - idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); } } for (auto entry : permissive_forwarding.consumer_forwarding_map) { - idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } // TODO: Should this just get rolled up in the forwarding map now? // TODO: Why should IDs be mapped to their compliments? Is this right? for (const auto& entry : permissive_forwarding.consumer_compliment_map) { for (auto entry_2 : entry.second) { - idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry_2); + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); } } auto permissive_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv); for (auto entry : permissive_c2p_root_map.mapConsumerToProducer()) { - idGraph(IdMappingMode::PERMISSIVE).mapIds(entry.first, entry.second); + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } } } @@ -957,14 +968,9 @@ void IdModel::build( }); auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs); - if (!additional_tvs.empty()) { - std::unordered_set all_added_tvs( - all_tvs.begin(), all_tvs.end()); - for (auto additional_tv : additional_tvs) { - if (all_added_tvs.find(additional_tv) == all_added_tvs.end()) { - all_tvs.pushBack(additional_tv); - } - } + + for (auto additional_tv : additional_tvs) { + all_tvs.pushBack(additional_tv); } if (all_tvs.empty()) { @@ -1100,7 +1106,7 @@ ValGraph IdModel::buildIntersection( // id0 and id1 map in group0. If they also map in the group1, // add the mapping to the inersection. if (graph1.disjointValSets().strictAreMapped(id0, id1)) { - intersection.mapIds(id0, id1); + intersection.mapVals(id0, id1); } } } @@ -1120,7 +1126,7 @@ void IdModel::initializeLoopMap(StatefulLoweringInfo& info) { if (entry_it != info.p2c_ca_permissive_maps.end()) { const VectorOfUniqueEntries& c_ids = entry_it->second; for (Val* c_id : c_ids) { - idGraph(IdMappingMode::LOOP).mapIds(p_id, c_id); + idGraph(IdMappingMode::LOOP).mapVals(p_id, c_id); } } } @@ -1337,8 +1343,9 @@ std::unordered_map IdModel::buildInlinePromotions( ExprGroups non_promoted_input_uses; for (const ValGroup& iel_group : promoted_input_groups.computeIntersect(input_groups)) { - non_promoted_input_uses.pushBack( - intersection_exact_loop_graph.getUniqueUses(iel_group)); + const ExprGroups* uses = intersection_exact_loop_graph.getUses(iel_group); + NVF_ERROR(uses); + non_promoted_input_uses.pushBack(*uses); } Expr* replay = nullptr; @@ -1390,7 +1397,7 @@ std::unordered_map IdModel::buildInlinePromotions( iel_promotion_map[out_groups[i]] = replay_out_ids[i]; // Explicitly map loop map since expr propagation doesn't happen if (replayed) { - idGraph(IdMappingMode::LOOP).mapIds(replay_out_ids[i], ref_out_ids[i]); + idGraph(IdMappingMode::LOOP).mapVals(replay_out_ids[i], ref_out_ids[i]); } } } @@ -1441,7 +1448,9 @@ std::unordered_map computeCoveredGroups( for (const ValGroup& id_group : exact_graph.disjointValSets().disjointSets()) { // Initialize inputs - if (exact_graph.getUniqueDefinitions(id_group).empty()) { + const ExprGroups* id_group_defs = exact_graph.getDefinitions(id_group); + NVF_ERROR(id_group_defs); + if (id_group_defs->empty()) { covered_ids[id_group] = {id_group}; } @@ -1739,7 +1748,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( intersection_exact_loop_graph.toGroup(inp_id); promoted_input_groups.push_back(inp_exact_group); promoted_input_uses.pushBack( - intersection_exact_loop_graph.getUniqueUses(inp_exact_group)); + *intersection_exact_loop_graph.getUses(inp_exact_group)); } // Check every use to see if it matches @@ -1815,7 +1824,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // If we built new iter domains because we generated a new expression, // link the outputs in the loop graph. idGraph(IdMappingMode::LOOP) - .mapIds(replay_out_ids[i], ref_out_ids[i]); + .mapVals(replay_out_ids[i], ref_out_ids[i]); } } } diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 9fc213c0254..66db8b0ab85 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -27,7 +27,7 @@ namespace { struct StatefulLoweringInfo; } // namespace -// A collection of IterDomainGraphs that are built from a fusion or series of +// A collection of ValGraphs that are built from a fusion or series of // expressions. These graphs are related, but have some distinct features based // on the IdMappingMode. // @@ -63,6 +63,9 @@ struct StatefulLoweringInfo; // producer's broadcast is inlined (in total or partially). Then the producer's // iter domain will be "promoted" to the size of the consumers iter domain. // +// IdMappingMode::EXACT +// Don't map any broadcast axes to non-broadcast axes +// Do not forward through any broadcast IDs // IdMappingMode::LOOP // Forward broadcast axes in replay // Denotes groups of IterDomains that are considered promoted to a common iter @@ -72,9 +75,6 @@ struct StatefulLoweringInfo; // Map all iteration domains // Always contain root mappings (otherwise they could have been forwarded in // broadcast) -// IdMappingMode::EXACT -// Don't map any broadcast axes to non-broadcast axes -// Do not forward through any broadcast IDs // IdMappingMode::AlmostExact // Forward through broadcast axes, but not through to a non-broadcast axis // i.e. id{b1*i0}, id{i0} are mapped @@ -86,11 +86,9 @@ class IdModel : public PolymorphicBase { public: IdModel( const std::vector& exprs, - const std::vector& additional_tvs, + const std::vector& additional_tvs = {}, bool allow_self_mapping = false); - IdModel(const std::vector& exprs, bool allow_self_mapping = false); - // Same as the above constructor with fusion->exprs() excpet fusion may have // some dangling inputs/outputs that are expected to have IterDomain entries // even though there's no possible connections from them. diff --git a/csrc/id_model/to_string.cpp b/csrc/id_model/to_string.cpp index 87057559f6c..95037ecfb04 100644 --- a/csrc/id_model/to_string.cpp +++ b/csrc/id_model/to_string.cpp @@ -53,16 +53,9 @@ std::string toString(const std::vector& id_group, int indent_size) { std::string toString( const std::vector& id_group, int indent_size) { - std::vector names; - names.reserve(id_group.size()); - for (auto id : id_group) { - names.push_back(id->name()); - } - std::sort(names.begin(), names.end()); - - std::stringstream ss; - ss << indent(indent_size) << "{" << names << "}"; - return ss.str(); + std::vector val_group; + std::copy(id_group.begin(), id_group.end(), std::back_inserter(val_group)); + return toString(val_group, indent_size); } std::string toString(const ValGroup& id_group, int indent_size, bool with_ptr) { @@ -313,32 +306,30 @@ std::string definitionsString( const ValGraph& id_graph, int indent_size, bool with_ptr) { - ExprGroups defs; + ExprGroups all_defs; for (const ValGroup& id_group : id_graph.disjointValSets().disjointSets()) { - auto definition_pair = id_graph.getDefinitions(id_group); - if (definition_pair.second) { - for (const ExprGroup& expr_group : definition_pair.first) { - defs.pushBack(expr_group); + if (auto definition = id_graph.getDefinitions(id_group); definition) { + for (const ExprGroup& expr_group : *definition) { + all_defs.pushBack(expr_group); } } } - return toString(id_graph, defs, indent_size, with_ptr); + return toString(id_graph, all_defs, indent_size, with_ptr); } std::string usesString( const ValGraph& id_graph, int indent_size, bool with_ptr) { - ExprGroups uses; + ExprGroups all_uses; for (const ValGroup& id_group : id_graph.disjointValSets().disjointSets()) { - auto definition_pair = id_graph.getUses(id_group); - if (definition_pair.second) { - for (const ExprGroup& expr_group : definition_pair.first) { - uses.pushBack(expr_group); + if (const ExprGroups* uses = id_graph.getUses(id_group); uses) { + for (const ExprGroup& expr_group : *uses) { + all_uses.pushBack(expr_group); } } } - return toString(id_graph, uses, indent_size, with_ptr); + return toString(id_graph, all_uses, indent_size, with_ptr); } } // namespace nvfuser diff --git a/csrc/id_model/to_string.h b/csrc/id_model/to_string.h index d64d92a8fed..d1ad38d1a7b 100644 --- a/csrc/id_model/to_string.h +++ b/csrc/id_model/to_string.h @@ -15,11 +15,12 @@ namespace nvfuser { -std::string toString(const std::vector& id_group, int indent_size = 0); +std::string toString(const std::vector& val_group, int indent_size = 0); std::string toString( const std::vector& id_group, int indent_size = 0); + std::string toString( const ValGroup& id_group, int indent_size = 0, diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index 9feb9e0fcff..0c397548539 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -37,7 +37,7 @@ void IdGraphVisitor::traverse() { graph().disjointExprSets().disjointSets().end()); } else { for (const ValGroup& id_group : all_ids) { - for (const ExprGroup& def : graph().getUniqueDefinitions(id_group)) { + for (const ExprGroup& def : *(graph().getDefinitions(id_group))) { if (all_exprs.has(def)) { continue; } @@ -91,9 +91,10 @@ void IdGraphVisitor::traverse() { }; auto is_id_ready = [&](const ValGroup& id_group) { - auto unique_defs = graph().getUniqueDefinitions(id_group); + auto unique_defs = graph().getDefinitions(id_group); + NVF_ERROR(unique_defs); return std::all_of( - unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { + unique_defs->begin(), unique_defs->end(), [&](ExprGroup expr_group) { return expr_group->empty() || visited_exprs.has(expr_group) || graph().isTrivialExprGroup(expr_group); }); @@ -144,9 +145,9 @@ void IdGraphVisitor::traverse() { visited_ids.pushBack(current_id_group); if (!terminating_outputs.has(current_id_group)) { - auto uses_pair = graph().getUses(current_id_group); - if (uses_pair.second) { - to_visit_exprs.pushBack(uses_pair.first); + if (const ExprGroups* uses = graph().getUses(current_id_group); + uses) { + to_visit_exprs.pushBack(*uses); } } } else { diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 1d7bf1c5431..52c8e6b3b65 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -120,8 +120,8 @@ std::vector ValGraph::inputGroups(const ExprGroup& expr) const { ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { DequeOfExprGroup to_visit; for (const ValGroup& of_id_group : of) { - if (const auto& [group_uses, found] = getUses(of_id_group); found) { - to_visit.insert(to_visit.end(), group_uses.begin(), group_uses.end()); + if (const ExprGroups* uses = getUses(of_id_group); uses) { + to_visit.insert(to_visit.end(), uses->begin(), uses->end()); } } @@ -131,8 +131,8 @@ ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { to_visit.pop_front(); visited.emplace(current_expr); for (const ValGroup& output_id : outputGroups(current_expr)) { - if (const auto& [group_uses, found] = getUses(output_id); found) { - for (const ExprGroup& group_use : group_uses) { + if (const ExprGroups* uses = getUses(output_id); uses) { + for (const ExprGroup& group_use : *uses) { if (visited.count(group_use)) { continue; } @@ -147,9 +147,10 @@ ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { ExprGroups ValGraph::allDefinitionsOf(const ValGroups& of) const { DequeOfExprGroup to_visit; - for (const ValGroup& of_id_group : of) { - if (const auto& [group_defs, found] = getDefinitions(of_id_group); found) { - to_visit.insert(to_visit.end(), group_defs.begin(), group_defs.end()); + for (const ValGroup& of_val_group : of) { + if (const ExprGroups* group_defs = getDefinitions(of_val_group); + group_defs != nullptr) { + to_visit.insert(to_visit.end(), group_defs->begin(), group_defs->end()); } } @@ -159,8 +160,9 @@ ExprGroups ValGraph::allDefinitionsOf(const ValGroups& of) const { to_visit.pop_front(); visited.emplace(current_expr); for (const ValGroup& input_id : inputGroups(current_expr)) { - if (const auto& [group_defs, found] = getDefinitions(input_id); found) { - for (const ExprGroup& group_def : group_defs) { + if (const ExprGroups* group_defs = getDefinitions(input_id); + group_defs != nullptr) { + for (const ExprGroup& group_def : *group_defs) { if (visited.count(group_def)) { continue; } @@ -263,8 +265,9 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) // domain coming back from any of its uses. ExprGroups min_groups; - std::pair uses_pair = getUses(id_group); - if (!uses_pair.second) { + const ExprGroups* uses = getUses(id_group); + + if (!uses) { // No expressions required for this iter domain, it must be a // terminating output. required_ind_exprs_ids[id_group] = min_groups; @@ -273,8 +276,7 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) // Only worry about expressions between inputs and outputs we're // looking at. - for (const ExprGroup& use_group : - uses_pair.first.computeIntersect(all_exprs)) { + for (const ExprGroup& use_group : uses->computeIntersect(all_exprs)) { auto use_required_ind_exprs_it = required_ind_exprs_exprs.find(use_group); if (use_required_ind_exprs_it == required_ind_exprs_exprs.end()) { // If there isn't an entry for the use expression it wasn't @@ -341,9 +343,9 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) if (processValGroup(currently_visiting_ids)) { something_was_processed = true; - auto definitions_pair = getDefinitions(currently_visiting_ids); - if (definitions_pair.second) { - for (const ExprGroup& def : definitions_pair.first) { + if (const auto definitions = getDefinitions(currently_visiting_ids); + definitions) { + for (const ExprGroup& def : *definitions) { if (!all_exprs.has(def)) { continue; } @@ -370,8 +372,8 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) for (const auto& entry : required_ind_exprs_ids) { const ValGroup& id = entry.first; const ExprGroups& traverse_exprs = entry.second; - if (auto all_uses = getUses(id); all_uses.second) { - uses_path[id] = traverse_exprs.computeIntersect(all_uses.first); + if (auto all_uses = getUses(id); all_uses) { + uses_path[id] = traverse_exprs.computeIntersect(*all_uses); } else { uses_path[id] = {}; continue; @@ -411,11 +413,9 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) auto outputs = outputGroups(currently_visiting); for (const ValGroup& out_id : outputs) { visited.pushBack(out_id); - auto use_pair = getUses(out_id); - if (!use_pair.second) { - continue; + if (const auto uses = getUses(out_id); uses) { + still_to_visit.pushBack(uses->computeIntersect(all_exprs)); } - still_to_visit.pushBack(use_pair.first.computeIntersect(all_exprs)); } } else { still_to_visit.pushBack(currently_visiting); @@ -481,33 +481,6 @@ std::unordered_map> ValGraph::buildMapBetween( return buildMapBetween(from.vector(), to.vector()); } -std::pair ValGraph::getDefinitions( - const ValGroup& id_group) const { - if (!id_group) { - return {{}, false}; - } - - if (auto definitions_it = unique_definitions_.find(id_group); - definitions_it != unique_definitions_.end()) { - return std::make_pair(definitions_it->second, true); - } else { - return {{}, false}; - } -} - -std::pair ValGraph::getUses(const ValGroup& id_group) const { - if (!id_group) { - return {{}, false}; - } - - if (auto uses_it = unique_uses_.find(id_group); - uses_it != unique_uses_.end()) { - return std::make_pair(uses_it->second, true); - } else { - return {{}, false}; - } -} - bool ValGraph::hasUses(const ValGroup& id_group) const { NVF_ERROR(id_group); return unique_uses_.find(id_group) != unique_uses_.end(); @@ -698,55 +671,60 @@ bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { return true; } -const ExprGroups& ValGraph::getUniqueDefinitions(const ValGroup& group) const { - auto unique_defs_it = unique_definitions_.find(group); - NVF_ERROR( - unique_defs_it != unique_definitions_.end(), - "Definition not found for ValGroup: ", - group->toString()); - return unique_defs_it->second; +const ExprGroups* ValGraph::getDefinitions(const ValGroup& val_group) const { + NVF_ERROR(val_group, "Nullptr not allowed"); + if (auto it = unique_definitions_.find(val_group); + it != unique_definitions_.end()) { + return &(it->second); + } else { + return nullptr; + } } -const ExprGroups& ValGraph::getUniqueUses(const ValGroup& group) const { - auto unique_uses_it = unique_uses_.find(group); - NVF_ERROR( - unique_uses_it != unique_uses_.end(), - "Uses not found for ValGroup: ", - group->toString()); - return unique_uses_it->second; +const ExprGroups* ValGraph::getUses(const ValGroup& val_group) const { + NVF_ERROR(val_group, "Nullptr not allowed"); + if (auto it = unique_uses_.find(val_group); it != unique_uses_.end()) { + return &(it->second); + } else { + return nullptr; + } } -void ValGraph::mapIds(Val* id0, Val* id1) { - if (id0 == id1) { +void ValGraph::mapVals(Val* val0, Val* val1) { + if (val0 == val1) { return; } - if (disjointValSets().strictAreMapped(id0, id1)) { + if (disjointValSets().strictAreMapped(val0, val1)) { return; } // Definitions and uses are based on the groups of id0 and id1, don't merge // them into a single group until we grab all definitions and uses for later // processing. - ValGroup orig_id_group0 = toGroup(id0); - ValGroup orig_id_group1 = toGroup(id1); - const ExprGroups& orig_defs0 = getUniqueDefinitions(orig_id_group0); - const ExprGroups& orig_defs1 = getUniqueDefinitions(orig_id_group1); - const ExprGroups& orig_uses0 = getUniqueUses(orig_id_group0); - const ExprGroups& orig_uses1 = getUniqueUses(orig_id_group1); + ValGroup orig_val_group0 = toGroup(val0); + ValGroup orig_val_group1 = toGroup(val1); + const ExprGroups* orig_defs0 = getDefinitions(orig_val_group0); + NVF_ERROR(orig_defs0); + const ExprGroups* orig_defs1 = getDefinitions(orig_val_group1); + NVF_ERROR(orig_defs1); + const ExprGroups* orig_uses0 = getUses(orig_val_group0); + NVF_ERROR(orig_uses0); + const ExprGroups* orig_uses1 = getUses(orig_val_group1); + NVF_ERROR(orig_uses1); // Map the iter domains together before we traverse across definitions and // uses. Traversing definitions and uses could use the new property of id0 and // id1 being mapped. - disjointValSets().mapEntries(id0, id1); - auto new_id_group = toGroup(id0); + disjointValSets().mapEntries(val0, val1); + auto new_val_group = toGroup(val0); - unique_definitions_[new_id_group] = orig_defs0.computeUnion(orig_defs1); - unique_uses_[new_id_group] = orig_uses0.computeUnion(orig_uses1); + unique_definitions_[new_val_group] = orig_defs0->computeUnion(*orig_defs1); + unique_uses_[new_val_group] = orig_uses0->computeUnion(*orig_uses1); // Propagate on uses - if (!orig_uses0.empty() && !orig_uses1.empty()) { - for (const ExprGroup& use_group_1 : orig_uses1) { - for (const ExprGroup& use_group_0 : orig_uses0) { + if (!orig_uses0->empty() && !orig_uses1->empty()) { + for (const ExprGroup& use_group_1 : *orig_uses1) { + for (const ExprGroup& use_group_0 : *orig_uses0) { if (use_group_0 == use_group_1) { continue; } @@ -758,9 +736,9 @@ void ValGraph::mapIds(Val* id0, Val* id1) { } // Propagate on definitions - if (!orig_defs0.empty() && !orig_defs1.empty()) { - for (const ExprGroup& def_group_1 : orig_defs1) { - for (const ExprGroup& def_group_0 : orig_defs0) { + if (!orig_defs0->empty() && !orig_defs1->empty()) { + for (const ExprGroup& def_group_1 : *orig_defs1) { + for (const ExprGroup& def_group_0 : *orig_defs0) { if (def_group_0 == def_group_1) { continue; } @@ -771,10 +749,10 @@ void ValGraph::mapIds(Val* id0, Val* id1) { } } - unique_definitions_.erase(orig_id_group0); - unique_definitions_.erase(orig_id_group1); - unique_uses_.erase(orig_id_group0); - unique_uses_.erase(orig_id_group1); + unique_definitions_.erase(orig_val_group0); + unique_definitions_.erase(orig_val_group1); + unique_uses_.erase(orig_val_group0); + unique_uses_.erase(orig_val_group1); } void ValGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { @@ -865,7 +843,7 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { "\nand\n", second->toString()); for (auto out_i : c10::irange(first_ids.size())) { - mapIds(first_ids[out_i], second_ids[out_i]); + mapVals(first_ids[out_i], second_ids[out_i]); } return true; @@ -885,8 +863,8 @@ void ValGraph::mapThroughLoopSwizzles() { for (auto swizzle : all_swizzles) { if (swizzle->swizzleMode() == SwizzleMode::Loop) { - mapIds(swizzle->inX(), swizzle->outX()); - mapIds(swizzle->inY(), swizzle->outY()); + mapVals(swizzle->inX(), swizzle->outX()); + mapVals(swizzle->inY(), swizzle->outY()); } } } @@ -911,7 +889,7 @@ void ValGraph::mapThroughTrivialExprs() { // Map through trivial expressions for (auto mapped_id_group : mapped_ids) { for (auto id : mapped_id_group) { - mapIds(mapped_id_group.front(), id); + mapVals(mapped_id_group.front(), id); } } } diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 4740ed54f07..22c5edc1e7a 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -16,6 +16,39 @@ namespace nvfuser { +// ValGraph is a DAG of Vals and Exprs connected by their input and +// output dependencies. Each graph node is a collection of +// either Vals or Exprs that are grouped together through mapVals and +// mapExprs, respectively. +// +// The primary use case of ValGraph is for representing groupings and +// dependencies of iteration domains. For example, given a fusion as +// shown below: +// +// T1 = set(T0); +// T2 = set(T1); +// +// T0: root [I0, I1], leaf [I0, I1] +// T1: root [I2, I3], leaf [I2*I3/4, 4] +// T2: root [I4, I5], leaf [I4*I5/4, 4] +// +// The Exact ValGraph consists of ValGroups of: +// +// - {I0, I2, I4} +// - {I1, I3, I5} +// - {I2*I3, I4*I5} +// - {I2*I3/4, I4*I5/4} +// - {4, 4} +// +// and ExprGroups of: +// +// - {merge of I2 and I3, merge of I4 and I5} +// - {split of I2*I3, split of I4*I5} +// +// ValGraph can be used with any Val types, however, it's currenty +// only tested with IterDomain. Some of the routines might need to be +// extended for other Val types. + using ValGroup = std::shared_ptr>; using ValGroups = VectorOfUniqueEntries; using ExprGroup = std::shared_ptr>; @@ -34,7 +67,7 @@ class ValGraph { ValGraph(bool propagate_through_exprs) : propagate_through_exprs_(propagate_through_exprs) {} - // Returns the disjoint IterDomain set. + // Returns the disjoint val set. const DisjointSets& disjointValSets() const { return disjoint_vals_; } @@ -55,14 +88,14 @@ class ValGraph { // Return if there's a group entry in the graph for this expr bool hasGroup(Expr* expr) const; - // Return if there's a group entry in the graph for this id - bool hasGroup(Val* id) const; + // Return if there's a group entry in the graph for this val + bool hasGroup(Val* val) const; // Convert expr to its exprGroup, assert that it exists. const ExprGroup& toGroup(Expr* expr) const; - // Convert iter domain to its ValGroup, assert that it exists. - const ValGroup& toGroup(Val* id) const; + // Convert Val to its ValGroup, assert that it exists. + const ValGroup& toGroup(Val* val) const; // Convert unique vector of expressions to unique vector of its groups ExprGroups toGroups(const VectorOfUniqueEntries& exprs) const; @@ -79,20 +112,33 @@ class ValGraph { return val_groups; } - // Return output/input iter domain groups of provided expr - // Note that the same IdGroup can show up multiple times, so the + // Return output/input Val groups of provided expr + // Note that the same ValGroup can show up multiple times, so the // output type cannot be VectorOfUniqueEntries std::vector outputGroups(const ExprGroup& expr) const; std::vector inputGroups(const ExprGroup& expr) const; - // Recursively traverses uses of the ValGroups in 'of' and returns all - // ExprGroups that have a use in their definition of provided of ValGroups. + // Recursively traverses uses of the IdGroups in 'of' and returns all + // ExprGroups that have a use in their definition of provided of IdGroups. ExprGroups allUsesOf(const ValGroups& of) const; - // Recursively traverses definitions of the ValGroups in 'of' and returns all - // ExprGroups used in this history of defining the 'of' ValGroups. + // Recursively traverses definitions of the IdGroups in 'of' and returns all + // ExprGroups used in this history of defining the 'of' IdGroups. ExprGroups allDefinitionsOf(const ValGroups& of) const; + //! Returns the pointer to expressions associated with the + //! definitions of the provided ValGroup. Nullptr is returned + //! otherwise. + //! The returned pointer is to a vector of vector of expressions. The + //! inner vector is proven to be equivalent. The + //! outer vector are expression groups that are not equivalent, but + //! produce one of the ValGroups within the same disjoint Val set. + const ExprGroups* getDefinitions(const ValGroup& val_group) const; + + //! Same as getDefinitions but for uses instead of + //! definitions + const ExprGroups* getUses(const ValGroup& val_group) const; + // Return sorted expressions to go from the provided IterDomains in from to // the provided IterDomains in to with provided mode. Minimal expressions to // get from 'from' to 'to' returned. @@ -111,28 +157,6 @@ class ValGraph { const VectorOfUniqueEntries& from, const VectorOfUniqueEntries& to) const; - //! Returns - //! (1) The expressions associated with the definitions of the provided - //! IterDomain group in the provided mapping mode (if it exists). - //! (2) If there is a definitions entry of the provided IterDomain group in - //! the provided mapping mode. - //! First entry in the returned pair is a vector of vector of expressions. The - //! inner vector is proven to be equivalent based on the provided mode. The - //! outer vector are expression groups that are not equivalent based on the - //! provided mode, but produce one of the IterDomains within the same disjoint - //! Iter Domain set based on the provided mode. - //! - //! TODO-NM: ExprGroups is a real container. Consider returning a reference - std::pair getDefinitions(const ValGroup& id_group) const; - - //! Same as iterDomainGroupDefinitions but for uses instead of - //! definitions - //! - //! TODO-NM: ExprGroups is a real container. Consider returning a - //! reference - //! TODO-NM: Rename to getMaybeUses. See getUses - std::pair getUses(const ValGroup& id_group) const; - bool hasUses(const ValGroup& id_group) const; std::string toString() const; @@ -161,14 +185,6 @@ class ValGraph { // , std::vector second_input_or_output_override ) const; - // Returns entry in unique_definitions_ for provided group in provided mode, - // otherwise errors if no entry is found. - const ExprGroups& getUniqueDefinitions(const ValGroup& group) const; - - // Returns entry in unique_uses_ for provided group in provided mode, - // otherwise errors if no entry is found. - const ExprGroups& getUniqueUses(const ValGroup& group) const; - public: void addUniqueUses(const ValGroup& id_group, const ExprGroup& uses) { unique_uses_.at(id_group).pushBack(uses); @@ -178,13 +194,16 @@ class ValGraph { unique_definitions_.at(id_group).pushBack(defs); } - // Set id0 and id1 to mapped in disjointIdsSet[mode], attempt to propagate - // new mapping through id0/id1 definitions/uses. - void mapIds(Val* id0, Val* id1); + // Set val0 and val1 to mapped in this graph, attempt to propagate + // new mapping through val0/val1 definitions/uses. + void mapVals(Val* val0, Val* val1); // Checks if expr0 and expr1 should map together, maps them together, and if - // expression propagation is on, propagates mapping through them. This should - // be the only call in IdGraph to mapThroughExpr + // expression propagation is on, propagates mapping through + // them. The forward parameter determines the direction of the + // propagation. The expressions are mapped if the inputs are mapped + // when the forward parameter is true. This should + // be the only call in ValGraph to mapThroughExpr. void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); // Map through loop swizzles, as input/output IterDomains are exact, only the From c07402e40de9d8a3cd2906eed362ae7535f73e75 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 1 Dec 2023 01:32:28 -0800 Subject: [PATCH 087/178] WIP --- csrc/id_model/id_model.cpp | 2 +- csrc/val_graph.cpp | 116 +++++++++++++++++-------------------- csrc/val_graph.h | 71 ++++++++++------------- 3 files changed, 86 insertions(+), 103 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 87ae40c8eb0..5834aa896a1 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -753,7 +753,7 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { } } } - idGraph(IdMappingMode::PERMISSIVE).mapThroughLoopSwizzles(); + mapThroughLoopSwizzles(idGraph(IdMappingMode::PERMISSIVE)); } void IdModel::buildAlmostExactMap() { diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 52c8e6b3b65..6bc8fd8a558 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -175,6 +175,16 @@ ExprGroups ValGraph::allDefinitionsOf(const ValGroups& of) const { return visited; } +bool ValGraph::hasDefinitions(const ValGroup& val_group) const { + NVF_ERROR(val_group); + return unique_definitions_.find(val_group) != unique_definitions_.end(); +} + +bool ValGraph::hasUses(const ValGroup& val_group) const { + NVF_ERROR(val_group); + return unique_uses_.find(val_group) != unique_uses_.end(); +} + ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) const { ExprGroups all_uses_of_from = allUsesOf(from); @@ -481,11 +491,6 @@ std::unordered_map> ValGraph::buildMapBetween( return buildMapBetween(from.vector(), to.vector()); } -bool ValGraph::hasUses(const ValGroup& id_group) const { - NVF_ERROR(id_group); - return unique_uses_.find(id_group) != unique_uses_.end(); -} - std::string ValGraph::toString() const { std::stringstream ss; ss << "IdGraph { \n"; @@ -568,7 +573,7 @@ void ValGraph::initializeVal( Val* val, const VectorOfUniqueEntries& definitions, const VectorOfUniqueEntries& uses) { - const ValGroup& id_disjoint_set = + const ValGroup& val_disjoint_set = disjointValSets().initializeSet(val).first->second; ExprGroups def_groups; @@ -579,7 +584,7 @@ void ValGraph::initializeVal( } // TODO-NM: def_groups can be empty. Should it be still mapped? // TODO-NM: Can this be overwritten? - NVF_ERROR(unique_definitions_.emplace(id_disjoint_set, def_groups).second); + NVF_ERROR(unique_definitions_.emplace(val_disjoint_set, def_groups).second); ExprGroups use_groups; for (auto use : uses) { @@ -589,38 +594,45 @@ void ValGraph::initializeVal( } // TODO-NM: use_groups can be empty. Should it be still mapped? // TODO-NM: Can this be overwritten? - NVF_ERROR(unique_uses_.emplace(id_disjoint_set, use_groups).second); + NVF_ERROR(unique_uses_.emplace(val_disjoint_set, use_groups).second); +} + +void ValGraph::initializeVal(Val* val) { + VectorOfUniqueEntries defs; + if (val->definition()) { + defs.pushBack(val->definition()); + } + VectorOfUniqueEntries uses; + for (Expr* use : val->uses()) { + uses.pushBack(use); + } + initializeVal(val, defs, uses); } bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { - if (!transformAtributesMatch(first, second)) { + NVF_ERROR(first); + NVF_ERROR(second); + + if (!first->sameOp(second)) { return false; } - auto first_ids = - ir_utils::filterByType(forward ? first->inputs() : first->outputs()) - .vector(); - - auto second_ids = ir_utils::filterByType( - forward ? second->inputs() : second->outputs()) - .vector(); + std::vector first_vals = forward ? first->inputs() : first->outputs(); + std::vector second_vals = + forward ? second->inputs() : second->outputs(); NVF_ERROR( - first_ids.size() == second_ids.size(), + first_vals.size() == second_vals.size(), "Expected number of ", (forward ? "inputs" : "outputs"), " to match for\n", first->toString(), second->toString()); - // TODO-MN: Is this equivalent as - // inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) ? - { - for (const auto i : c10::irange(first_ids.size())) { - if (!disjointValSets().permissiveAreMapped( - first_ids.at(i), second_ids.at(i))) { - return false; - } + for (const auto i : c10::irange(first_vals.size())) { + if (!disjointValSets().permissiveAreMapped( + first_vals.at(i), second_vals.at(i))) { + return false; } } @@ -655,19 +667,6 @@ bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { // exactly the same given the exact map. We might want to pipe that // information through to here. - // TODO-NM: Should this be transformAtributesMatch? - if (first->isA()) { - if (!first->as()->leftExpand()->sameAs( - second->as()->leftExpand())) { - return false; - } - - if (!first->as()->rightExpand()->sameAs( - second->as()->rightExpand())) { - return false; - } - } - return true; } @@ -756,18 +755,31 @@ void ValGraph::mapVals(Val* val0, Val* val1) { } void ValGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) { + // By default, expressions are mapped only when everything is + // matched, i.e., inputs, outputs and attributes are all mapped or + // equal. When the propagation is allowed, as long as the inputs are + // mapped and the attributes are equal, we propagate the mappings to + // the outputs and the expressions. + // In either case, it should be always true that when two + // expressions are mapped, their inputs and outputs are also mapped, + // respectively, and vice versa. + if (!exprsMap(expr0, expr1, forward)) { return; } - // Expr inputs are mapped. If propagate_exprs_ is true, map the - // exprs and outputs + // Expr inputs are mapped. If propagate_through_exprs_ is true, map the + // exprs and outputs. If not, map the exprs only when both inputs + // and outputs are mapped. Since exprsMap makes sure inputs or + // outputs are mapped, only outputs or inputs need to be checked if (propagate_through_exprs_) { mapExprs(expr0, expr1); mapThroughExpr(expr0, expr1, forward); } else if ( - inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)) && - outputGroups(toGroup(expr0)) == outputGroups(toGroup(expr1))) { + (forward && + outputGroups(toGroup(expr0)) == outputGroups(toGroup(expr1))) || + (!forward && + inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)))) { mapExprs(expr0, expr1); } } @@ -849,26 +861,6 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return true; } -void ValGraph::mapThroughLoopSwizzles() { - std::vector all_swizzles; - - for (const auto& expr_set : disjointExprSets().disjointSets()) { - auto swizzles_in_expr_set = ir_utils::filterByType( - expr_set->vector().begin(), expr_set->vector().end()); - all_swizzles.insert( - all_swizzles.end(), - swizzles_in_expr_set.begin(), - swizzles_in_expr_set.end()); - } - - for (auto swizzle : all_swizzles) { - if (swizzle->swizzleMode() == SwizzleMode::Loop) { - mapVals(swizzle->inX(), swizzle->outX()); - mapVals(swizzle->inY(), swizzle->outY()); - } - } -} - void ValGraph::mapThroughTrivialExprs() { // Grab all expressions std::vector exprs; diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 22c5edc1e7a..b29653c2ad1 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -127,8 +127,8 @@ class ValGraph { ExprGroups allDefinitionsOf(const ValGroups& of) const; //! Returns the pointer to expressions associated with the - //! definitions of the provided ValGroup. Nullptr is returned - //! otherwise. + //! definitions of the provided ValGroup. Nullptr is returned otherwise. + //! //! The returned pointer is to a vector of vector of expressions. The //! inner vector is proven to be equivalent. The //! outer vector are expression groups that are not equivalent, but @@ -139,6 +139,10 @@ class ValGraph { //! definitions const ExprGroups* getUses(const ValGroup& val_group) const; + bool hasDefinitions(const ValGroup& val_group) const; + + bool hasUses(const ValGroup& val_group) const; + // Return sorted expressions to go from the provided IterDomains in from to // the provided IterDomains in to with provided mode. Minimal expressions to // get from 'from' to 'to' returned. @@ -157,8 +161,6 @@ class ValGraph { const VectorOfUniqueEntries& from, const VectorOfUniqueEntries& to) const; - bool hasUses(const ValGroup& id_group) const; - std::string toString() const; // Checks if the expression is a trivial operation where an input is simply an @@ -168,22 +170,22 @@ class ValGraph { // Returns if all atributes of the ID transforms first and second are the same static bool transformAtributesMatch(Expr* first, Expr* second); - // Initializes entries for the provided IterDomain in the IterDomainGraphs + // Initializes entries for the provided Val with its definitions and + // uses. void initializeVal( Val* val, const VectorOfUniqueEntries& definitions, const VectorOfUniqueEntries& uses); - // Returns if first and second are expressions through which the provided - // id_map have matching inputs (if forward), or outputs (if not forward). - // Returning true means the expressions are "the same", in terms they modify - // matching original extents, by the same amount. - bool exprsMap( - Expr* first, - Expr* second, - bool forward - // , std::vector second_input_or_output_override - ) const; + // Same as the above exept val->definition() and val->uses() are + // used + void initializeVal(Val* val); + + // Returns true if first and second are expressions through which + // this ValGraph has matching inputs (if forward), or outputs (if not + // forward). Returning true means the expressions are "the same", in terms + // they modify matching original inputs by the same amount. + bool exprsMap(Expr* first, Expr* second, bool forward) const; public: void addUniqueUses(const ValGroup& id_group, const ExprGroup& uses) { @@ -206,10 +208,6 @@ class ValGraph { // be the only call in ValGraph to mapThroughExpr. void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); - // Map through loop swizzles, as input/output IterDomains are exact, only the - // order they're traversed differs. - void mapThroughLoopSwizzles(); - // Maps iter domain pairs returned by calling that return mappings from // IdGraph::isTrivialExpr on every expression in the graph. void mapThroughTrivialExprs(); @@ -232,49 +230,42 @@ class ValGraph { } private: - // Map expr0 and expr1 with eachother, update unique_definitions_ unique_uses_ + // Map expr0 and expr1 with each other, update unique_definitions_ + // unique_uses_ // TODO: Make this variant hidden? void mapExprs(Expr* expr0, Expr* expr1); - // Checks if expr's are considered "the same" where sameness inputs and - // outputs in the same position across expressions map with provided - // MappingMode. If the expressions are determined the same then + // Checks if expr's are considered "the same" where sameness is + // defined as inputs and outputs in the same position across + // expressions are mapped. If the expressions are determined the + // same then + // // if forward // will map outputs // else // will map inputs - // in the provided mode. - // Returns if expressions were mapped through. // + // Returns true if expressions were mapped through. bool mapThroughExpr(Expr* first, Expr* second, bool forward); private: // If propagate_through_exprs_ = false, then mapThroughExpr will not be called - // as a consequence of calling mapIds. As well as mapThroughExpr will not be + // as a consequence of calling mapVals. As well as mapThroughExpr will not be // called (again) as a result of calling mapThroughExpr. // - // Note: For the second sentence of above... mapThroughExpr can call mapIds - // which could in return call mapThoughExpr again, but propagate_exprs_ as - // mentioned above prevents that from happening. + // Note: For the second sentence of above... mapThroughExpr can call mapVals + // which could in return call mapThoughExpr again, but + // propagate_through_exprs_ as mentioned above prevents that from happening. bool propagate_through_exprs_ = true; - // Keeps a disjoint set entry for all IterDomain for all mapping mode types. - // - // Using an array here might be nice, but it seems hard to use an enum as an - // array key - // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum + // Keeps a disjoint set entry for all Vals. DisjointSets disjoint_vals_; - // Keeps a disjoint set entry for all Expressions for all mapping mode types. + // Keeps a disjoint set entry for all Exprs. DisjointSets disjoint_exprs_; // Definitions of ValGroup. There can be multiple definitions due to // replays. - // TODO-NM: ValGroup by a new definition ExprGroup would not be used - // by existing uses. Does it make sense to represent uses and defs - // this way? In other words, there is a traversal path from a - // definition ExprGroup to an ValGroup and its use ExprGroup, but - // that does't guarantee the path actually exist std::unordered_map unique_definitions_; std::unordered_map unique_uses_; From dbff25c5532e064a0c0f29d0264a01d6bcacee63 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 1 Dec 2023 11:44:03 -0800 Subject: [PATCH 088/178] remove non-const disjointExprs and disjointVals --- csrc/id_model/id_model.cpp | 8 ++++---- csrc/val_graph.cpp | 18 +++++++++++++----- csrc/val_graph.h | 13 +++++-------- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 5834aa896a1..51bcf5ac192 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -400,7 +400,7 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { // Initialize output iter domains in the graphs for (auto mode : initialized_modes) { - idGraph(mode).disjointExprSets().initializeSet(replay); + idGraph(mode).registerExpr(replay); auto replay_group = idGraph(mode).toGroup(replay); // Initialize output ids in map @@ -533,10 +533,10 @@ Expr* IdModel::addExprWithReplacement( for (auto mode : initialized_modes) { auto& graph = idGraph(mode); - graph.disjointExprSets().initializeSet(replay); + graph.registerExpr(replay); auto replay_group = graph.toGroup(replay); - // Initialize any non-existant input ids, update existing ones + // Initialize any non-existent input ids, update existing ones for (auto inp_id : ir_utils::filterByType(replay->inputs())) { if (!graph.disjointValSets().mappingExists(inp_id)) { // inp_id is not initialized in the map, initialize it @@ -548,7 +548,7 @@ Expr* IdModel::addExprWithReplacement( } } - // Initialize any non-existant output ids, update existing ones + // Initialize any non-existent output ids, update existing ones for (auto out_id : ir_utils::filterByType(replay->outputs())) { if (!graph.disjointValSets().mappingExists(out_id)) { // out_id is not initialized in the map, initialize it diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 6bc8fd8a558..d1626a66124 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -574,12 +574,12 @@ void ValGraph::initializeVal( const VectorOfUniqueEntries& definitions, const VectorOfUniqueEntries& uses) { const ValGroup& val_disjoint_set = - disjointValSets().initializeSet(val).first->second; + disjoint_vals_.initializeSet(val).first->second; ExprGroups def_groups; for (auto def : definitions) { const ExprGroup& expr_set = - disjointExprSets().initializeSet(def).first->second; + disjoint_exprs_.initializeSet(def).first->second; def_groups.pushBack(expr_set); } // TODO-NM: def_groups can be empty. Should it be still mapped? @@ -589,7 +589,7 @@ void ValGraph::initializeVal( ExprGroups use_groups; for (auto use : uses) { const ExprGroup& expr_set = - disjointExprSets().initializeSet(use).first->second; + disjoint_exprs_.initializeSet(use).first->second; use_groups.pushBack(expr_set); } // TODO-NM: use_groups can be empty. Should it be still mapped? @@ -609,6 +609,14 @@ void ValGraph::initializeVal(Val* val) { initializeVal(val, defs, uses); } +void ValGraph::registerExpr(Expr* expr) { + NVF_ERROR( + !disjoint_exprs_.mappingExists(expr), + "Already in the disjoint sets: ", + expr->toString()); + disjoint_exprs_.initializeSet(expr); +} + bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { NVF_ERROR(first); NVF_ERROR(second); @@ -714,7 +722,7 @@ void ValGraph::mapVals(Val* val0, Val* val1) { // Map the iter domains together before we traverse across definitions and // uses. Traversing definitions and uses could use the new property of id0 and // id1 being mapped. - disjointValSets().mapEntries(val0, val1); + disjoint_vals_.mapEntries(val0, val1); auto new_val_group = toGroup(val0); unique_definitions_[new_val_group] = orig_defs0->computeUnion(*orig_defs1); @@ -796,7 +804,7 @@ void ValGraph::mapExprs(Expr* expr0, Expr* expr1) { ExprGroup expr0_orig_group = toGroup(expr0); ExprGroup expr1_orig_group = toGroup(expr1); - disjointExprSets().mapEntries(expr0, expr1); + disjoint_exprs_.mapEntries(expr0, expr1); auto expr_new_group = toGroup(expr0); diff --git a/csrc/val_graph.h b/csrc/val_graph.h index b29653c2ad1..a24064cafc9 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -72,19 +72,11 @@ class ValGraph { return disjoint_vals_; } - DisjointSets& disjointValSets() { - return disjoint_vals_; - } - // Returns the disjoint Expr set. const DisjointSets& disjointExprSets() const { return disjoint_exprs_; } - DisjointSets& disjointExprSets() { - return disjoint_exprs_; - } - // Return if there's a group entry in the graph for this expr bool hasGroup(Expr* expr) const; @@ -181,6 +173,11 @@ class ValGraph { // used void initializeVal(Val* val); + // Add expr to the disjoint sets as a sole group. Used for + // registering replayed domains and exprs. Error if the expr is + // already registered. + void registerExpr(Expr* expr); + // Returns true if first and second are expressions through which // this ValGraph has matching inputs (if forward), or outputs (if not // forward). Returning true means the expressions are "the same", in terms From f1b5f63d1f46f93ac02cb12ea8cdad01d5c608d1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 1 Dec 2023 12:47:59 -0800 Subject: [PATCH 089/178] Clean up self mapping --- CMakeLists.txt | 1 + csrc/id_model/id_model.cpp | 21 +++++++++++++-------- csrc/id_model/id_model.h | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d6dfdc73c4..692d1ea8ab4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -431,6 +431,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/test/test_no_op.cpp ${NVFUSER_ROOT}/test/test_linked_hash_map.cpp ${NVFUSER_ROOT}/test/test_pointwise.cpp + ${NVFUSER_ROOT}/test/test_id_model.cpp ) # We don't link CUPTI for MSVC diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 51bcf5ac192..dac66d76aaf 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -134,14 +134,14 @@ Expr* IdModel::idDef(IterDomain* id) const { namespace { -// Returns the first pair of id's in ids detected to match eachother on the -// permissive map of the ID graph. TODO: what this is really looking for is if +// Returns the first pair of id's in ids detected to match each other on the +// exact ID graph. TODO: what this is really looking for is if // there's any overlapping between the iter domains in the provided set. // // i.e. if we have: -// tv0 = arange(6).view({3, 2}) +// tv0 = arange(6).reshape({3, 2}) // tv1 = tv0[3, 2].t() -// tv2 = tv0[3, 2].view({2, 3}) +// tv2 = tv0[3, 2].reshape({2, 3}) // tv3 = tv1 + tv2 // // Then we can see this overlap in the tv3 expression as: @@ -165,7 +165,12 @@ namespace { // will assume tv2 can be trivially inlined/parallelized. Instead we'd need to // take into consideration the effective communication going on here, so that // we pull multiple values of tv0 to compute tv3. -c10::optional> detectMappablePair( +// +// Note, however, that the above example is not detectable at this +// moment as the self mapping is partial through reshape. The analysis +// below would need to be extended to consider producer and consumers +// of domains as well rather than just root, rfactor and leaf domains. +std::optional> detectMappablePair( const std::vector& ids, const IdModel& id_graph, IdMappingMode mode) { @@ -181,14 +186,14 @@ c10::optional> detectMappablePair( } } - return {}; + return std::nullopt; } // It is assumed that for any tensor represented by a list of domains, // those domains should never be mapped with each other. It may be // possible to lift this assumption, but it's unclear if it could // matter in practice. -c10::optional> +std::optional> findFirstSelfMapping( const std::vector& all_tvs, const IdModel& id_graph) { @@ -236,7 +241,7 @@ findFirstSelfMapping( "Leaf"); } } - return c10::nullopt; + return std::nullopt; } } // namespace diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 66db8b0ab85..4c1da0d572a 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -289,8 +289,8 @@ class IdModel : public PolymorphicBase { std::unordered_map> id_definitions_; // Debug information to hold if a self mapping in a TensorView is found. - c10::optional> - self_mapping_info_ = c10::nullopt; + std::optional> + self_mapping_info_ = std::nullopt; // Promotion domain for each loop group std::unordered_map loop_promotion_map_; From ad5debf77518c8c602305ef8f0c8e095540f4637 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 5 Dec 2023 13:15:19 -0800 Subject: [PATCH 090/178] IdModel: merge main (#1453) --- CMakeLists.txt | 1 + csrc/compute_at_map.cpp | 3 - csrc/contiguity.cpp | 12 +- .../device_lower/analysis/divisible_split.cpp | 2 +- .../analysis/predicate_elimination.cpp | 2 +- .../analysis/sync_information.cpp | 4 +- .../analysis/thread_predicate.cpp | 5 +- csrc/device_lower/pass/alias_memory.cpp | 3 +- csrc/device_lower/pass/allocation.cpp | 2 +- csrc/device_lower/pass/expr_sort.cpp | 10 +- csrc/device_lower/pass/inline_ptx.cpp | 30 ++- csrc/device_lower/pass/warp_reduce.cpp | 3 +- csrc/device_lower/validation.cpp | 3 +- csrc/dynamic_transform.cpp | 5 +- csrc/executor.cpp | 69 +++--- csrc/executor.h | 20 ++ csrc/executor_utils.cpp | 12 +- csrc/executor_utils.h | 1 + csrc/fusion.cpp | 11 +- csrc/fusion_segmenter.cpp | 4 +- csrc/index_compute.cpp | 36 +--- csrc/ir/builder.cpp | 23 ++ csrc/ir/builder.h | 6 + csrc/ir/cloner.cpp | 8 +- csrc/ir/cloner.h | 2 - csrc/ir/iostream.cpp | 1 + csrc/ir/nodes.cpp | 8 +- csrc/ir/utils.cpp | 6 +- csrc/iter_visitor.cpp | 71 +++--- csrc/iter_visitor.h | 17 +- csrc/kernel_ir.cpp | 7 +- csrc/kernel_ir.h | 7 +- csrc/mma_type.cpp | 16 ++ csrc/mma_type.h | 12 ++ csrc/multidevice/executor.cpp | 2 +- csrc/multidevice/pipeline.cpp | 2 +- csrc/non_divisible_split.cpp | 2 +- csrc/ops/arith.cpp | 8 +- csrc/optimization/alias_analysis.cpp | 13 +- csrc/optimization/alias_analysis.h | 2 + csrc/optimization/mark_alias.cpp | 5 + csrc/partial_split_map.cpp | 2 +- csrc/root_domain_map.cpp | 6 +- csrc/scheduler/reduction_utils.cpp | 10 +- csrc/scheduler/transpose.cpp | 6 +- csrc/scheduler/utils.cpp | 38 +++- csrc/scheduler/utils.h | 7 + csrc/scheduler/vectorize_helper.cpp | 8 +- csrc/tensor_metadata.cpp | 8 +- csrc/tensor_view.cpp | 2 +- csrc/tma.cpp | 33 +-- csrc/tma.h | 5 +- csrc/transform_iter.cpp | 23 +- csrc/type.cpp | 6 +- nvfuser/__init__.py | 5 +- python_tests/test_python_frontend.py | 56 +++++ test/test_alias.cpp | 59 +++-- test/test_allocation_domain.cpp | 4 + test/test_gpu2.cpp | 5 +- test/test_gpu3.cpp | 203 +++++++++--------- test/test_gpu_utils.cpp | 2 +- test/test_gpu_view.cpp | 7 +- test/test_id_model.cpp | 40 ++++ test/test_iter_visitor.cpp | 121 +++++++++++ test/test_mma.cpp | 35 ++- test/test_optimization_pass.cpp | 6 +- test/test_resize.cpp | 42 ++++ test/test_swizzle.cpp | 1 - 68 files changed, 767 insertions(+), 429 deletions(-) create mode 100644 test/test_id_model.cpp create mode 100644 test/test_iter_visitor.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 692d1ea8ab4..3b77b817b93 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -431,6 +431,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/test/test_no_op.cpp ${NVFUSER_ROOT}/test/test_linked_hash_map.cpp ${NVFUSER_ROOT}/test/test_pointwise.cpp + ${NVFUSER_ROOT}/test/test_iter_visitor.cpp ${NVFUSER_ROOT}/test/test_id_model.cpp ) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 6955309c072..b6c6088a5d1 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -605,7 +605,6 @@ void IterDomainGraph::build(Fusion* fusion) { // Grab all the rfactor ids. for (auto consumer_tv : all_consumer_tvs) { auto exprs = StmtSort::getExprsTo( - fusion, {consumer_tv->getMaybeRFactorDomain().begin(), consumer_tv->getMaybeRFactorDomain().end()}); for (auto expr : exprs) { @@ -1437,8 +1436,6 @@ std::string ComputeAtMap::toString() const { << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE); ss << "Permissive-Resize map:\n" << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE_RESIZE); - ss << "Innermost map:\n" - << idGraphNodesToString(*this, IdMappingMode::INNERMOST); ss << "Consumer maps:\n"; for (auto key : getSortedKeys(id_graph_.consumers(), Statement::lessThan)) { auto consumers = id_graph_.consumers().at(key); diff --git a/csrc/contiguity.cpp b/csrc/contiguity.cpp index a1df7807845..c41cf793644 100644 --- a/csrc/contiguity.cpp +++ b/csrc/contiguity.cpp @@ -39,9 +39,7 @@ OrderedIdInformation::OrderedIdInformation( // consistently_ordered_ids_, id_to_alloc_ids_, and // exclusively_consumes_allocs_ for all the IDs auto exprs = StmtSort::getExprsBetween( - ids[0]->fusion(), - {alloc_domain.begin(), alloc_domain.end()}, - {ids.begin(), ids.end()}); + {alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()}); for (auto expr : exprs) { OptInDispatch::dispatch(expr); @@ -386,9 +384,7 @@ NonDivisibleSplitDependencies::NonDivisibleSplitDependencies( return; } auto transforms = StmtSort::getExprsBetween( - ids[0]->fusion(), - {alloc_domain.begin(), alloc_domain.end()}, - {ids.begin(), ids.end()}); + {alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()}); for (auto transform : transforms) { auto inp_ids = ir_utils::filterByType(transform->inputs()); for (auto inp_id : inp_ids) { @@ -545,9 +541,7 @@ void ContigIDs::build(const std::vector& ids) { if (!contig_ids_.empty()) { auto exprs = StmtSort::getExprsBetween( - ids.at(0)->fusion(), - {alloc_domain_.begin(), alloc_domain_.end()}, - {ids.begin(), ids.end()}); + {alloc_domain_.begin(), alloc_domain_.end()}, {ids.begin(), ids.end()}); for (auto expr : exprs) { if (auto resize = dynamic_cast(expr)) { resize_deps_.insert(resize->out()); diff --git a/csrc/device_lower/analysis/divisible_split.cpp b/csrc/device_lower/analysis/divisible_split.cpp index 75d344a4dbd..5b62bf5de2d 100644 --- a/csrc/device_lower/analysis/divisible_split.cpp +++ b/csrc/device_lower/analysis/divisible_split.cpp @@ -38,7 +38,7 @@ std::unordered_set getAllDivisibleSplits( // Take the view transformations and add all the splits. Those splits are // the only divisible splits. auto view_exprs = - StmtSort::getExprsTo(fusion, {rfactor_dom.begin(), rfactor_dom.end()}); + StmtSort::getExprsTo({rfactor_dom.begin(), rfactor_dom.end()}); auto split_exprs = ir_utils::filterByType(view_exprs); all_divisible_splits.insert(split_exprs.begin(), split_exprs.end()); } diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 676c5f7e695..e3078b326d4 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -863,7 +863,7 @@ class PredicateChcker : public IterVisitor { } // namespace PredicateElimination::PredicateElimination(Fusion* fusion) { - traverseTo(fusion, fusion->outputs()); + traverseTo(fusion->outputs()); } bool PredicateElimination::needsPredicate(Expr* expr) const { diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index f7c922ad5e8..36a5c891314 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -106,7 +106,6 @@ struct ProducerConsumerIndexingInfoCache { const auto& consumer_leaf_ids_shared_with_producer = getConsumerLeafIDsSharedWithProducer(); consumer_root_ids_shared_with_producer_ = InputsOf::outputs( - producer_tv_->fusion(), {consumer_leaf_ids_shared_with_producer.begin(), consumer_leaf_ids_shared_with_producer.end()}); } @@ -261,10 +260,9 @@ bool useSameIndex( // consumer_id. The goal of the analysis below is to find out if all // of the root IDs are indexed in the same way between the producer // and consumer tensors. - auto consumer_root_ids = InputsOf::output(consumer_id->fusion(), consumer_id); + auto consumer_root_ids = InputsOf::output(consumer_id); auto producer_root_vals = StmtSort::getStmtsBetween( - producer_id->fusion(), {producer_tv->getMaybeRFactorDomain().begin(), producer_tv->getMaybeRFactorDomain().end()}, {producer_id}); diff --git a/csrc/device_lower/analysis/thread_predicate.cpp b/csrc/device_lower/analysis/thread_predicate.cpp index 0258d7e2da1..fdb0b0806aa 100644 --- a/csrc/device_lower/analysis/thread_predicate.cpp +++ b/csrc/device_lower/analysis/thread_predicate.cpp @@ -239,8 +239,9 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) { // Run through inputs and update bitsets for (const auto* inp : expr->inputs()) { - if (!ir_utils::isTV(inp)) + if (!ir_utils::isTV(inp)) { continue; + } auto tv_inp = inp->as(); @@ -365,7 +366,7 @@ class RedundantUseAnalysis : BackwardVisitor { public: RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map) : fusion_(fusion), pred_map_(pred_map) { - traverseTo(fusion, fusion->terminatingMathVals()); + traverseTo(fusion->terminatingMathVals()); } //! Returns a bit map signifying the parallel dimensions diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 080967fb1aa..8e59da72dd6 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -122,8 +122,7 @@ bool isSerialBroadcastResolution( // traverse across view boundaries as we do in indexing. This // should not result in false aliasing but may miss safe aliasing // opportunities. - auto serial_loop_roots = - InputsOf::outputs(FusionGuard::getCurFusion(), serial_loop_concrete_ids); + auto serial_loop_roots = InputsOf::outputs(serial_loop_concrete_ids); // Collect exact concrete id's in producer's root domain std::unordered_set producer_exact_concrete_root_ids; diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 69e7cb69ce8..8c3052e2751 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -218,7 +218,7 @@ class AllocationInserter : public kir::ExprMutator { [](IterDomain* dom) { return dom->as(); }); // Get all exprs involved in generating the allocation IDs - auto exprs = StmtSort::getExprsTo(tv->fusion(), start_vals); + auto exprs = StmtSort::getExprsTo(start_vals); // Get the halo extent if found auto getExtent = [this](IterDomain* id) { diff --git a/csrc/device_lower/pass/expr_sort.cpp b/csrc/device_lower/pass/expr_sort.cpp index 386d8ead16c..52fa723fdff 100644 --- a/csrc/device_lower/pass/expr_sort.cpp +++ b/csrc/device_lower/pass/expr_sort.cpp @@ -406,14 +406,16 @@ std::string ExprGroup::toString() const { os << " ca_ids {"; for (size_t i = 0; i < payload()->ca_domains.size(); i++) { os << payload()->ca_domains[i]; - if (i + 1 != payload()->ca_domains.size()) + if (i + 1 != payload()->ca_domains.size()) { os << ", "; + } } os << "} pa_ids {"; for (size_t i = 0; i < payload()->pa_domains.size(); i++) { os << payload()->pa_domains[i]; - if (i + 1 != payload()->pa_domains.size()) + if (i + 1 != payload()->pa_domains.size()) { os << ", "; + } } os << "}"; os << "\nExprs {\n"; @@ -1507,9 +1509,7 @@ void ExprSegmentationSorter::sort() { // Not putting the exprs between allKnownVals() and fusion inputs here // because they are computed using the expr evaluator. auto all_exprs = StmtSort::getExprsBetween( - fusion_, - GpuLower::current()->allKnownVals(), - fusion_->getTerminatingOutputs()); + GpuLower::current()->allKnownVals(), fusion_->getTerminatingOutputs()); // Figure out all the values used as inputs to the expressions we're sorting // (to find terminating expressions). There could be branches of expressions diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 134d172e6c4..0c97b9aa7de 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -27,7 +27,7 @@ class LowerToInlinePtx : public kir::ExprMutator { "cp.async.commit_group", std::vector{}, std::vector{}, - kir::Asm::Options{true})); + kir::Asm::Options{/*volatile=*/true})); } void handle(kir::CpAsyncWait* wait) override { @@ -38,13 +38,13 @@ class LowerToInlinePtx : public kir::ExprMutator { "cp.async.wait_group", std::vector{}, std::vector{IrBuilder::create(stages)}, - kir::Asm::Options{true}); + kir::Asm::Options{/*volatile=*/true}); } else { replace = IrBuilder::create( "cp.async.wait_all", std::vector{}, std::vector{}, - kir::Asm::Options{true}); + kir::Asm::Options{/*volatile=*/true}); } registerReplace(wait, replace); @@ -57,7 +57,7 @@ class LowerToInlinePtx : public kir::ExprMutator { "cp.async.bulk.commit_group", std::vector{}, std::vector{}, - kir::Asm::Options{true})); + kir::Asm::Options{/*volatile=*/true})); } void handle(kir::CpAsyncBulkS2GWait* wait) override { @@ -68,7 +68,7 @@ class LowerToInlinePtx : public kir::ExprMutator { "cp.async.bulk.wait_group.read", std::vector{}, std::vector{IrBuilder::create(stages)}, - kir::Asm::Options{true, true})); + kir::Asm::Options{/*volatile=*/true, /*memory=*/true})); } void handle(LoadStoreOp* ldst) override { @@ -87,7 +87,7 @@ class LowerToInlinePtx : public kir::ExprMutator { ss.str(), std::vector{ldst->out()}, std::vector{ldst->in()}, - kir::Asm::Options{true})); + kir::Asm::Options{/*volatile=*/true})); return; } else if (ir_utils::isCpAsyncOp(ldst)) { auto out_tv = ldst->out()->as()->view(); @@ -113,11 +113,11 @@ class LowerToInlinePtx : public kir::ExprMutator { ldst->in(), IrBuilder::create(vec_size), ldst->predicate()}, - kir::Asm::Options{true})); + kir::Asm::Options{/*volatile=*/true})); } } - void handle(MmaOp* mma) override { + void handleTuringOrAmpereMma(MmaOp* mma) { // Constants definitions based on MMA PTX instruction documentation: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma const int m = 16; @@ -186,6 +186,20 @@ class LowerToInlinePtx : public kir::ExprMutator { } registerRemove(mma); } + + void handleHopperMma(MmaOp* mma) { + NVF_ERROR(false, "Hopper MMA not supported yet"); + } + + void handle(MmaOp* mma) override { + if (mma->isTuring() || mma->isAmpere()) { + handleTuringOrAmpereMma(mma); + } else if (mma->isHopper()) { + handleHopperMma(mma); + } else { + NVF_ERROR(false, "Unsupported MMA architecture"); + } + } }; std::vector lowerToInlinePtx(const std::vector& exprs) { diff --git a/csrc/device_lower/pass/warp_reduce.cpp b/csrc/device_lower/pass/warp_reduce.cpp index c41e72a690e..8a27af271f7 100644 --- a/csrc/device_lower/pass/warp_reduce.cpp +++ b/csrc/device_lower/pass/warp_reduce.cpp @@ -104,8 +104,7 @@ class EliminateDeadBroadcastAndAllocate { // Also find any TVs used in index expressions. // These expressions will likely not be in the Expr tree we are // provided, so we need to traverse to find them. - auto all_index_roots = - InputsOf::outputs(FusionGuard::getCurFusion(), {ti->index()}); + auto all_index_roots = InputsOf::outputs({ti->index()}); auto index_root_tis = ir_utils::filterByType(all_index_roots); for (auto rootti : index_root_tis) { diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 9e2bc68e413..9e5df2bc26a 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -836,7 +836,7 @@ void validatePartialSplit(Fusion* fusion) { for (auto tv : ir_utils::allTvs(fusion)) { auto exprs = StmtSort::getExprsTo( - tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); + {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // When the start and stop offsets are not zero, make sure the // range defined by the split includes the required range to @@ -1276,7 +1276,6 @@ void validateResize(Fusion* fusion) { for (auto tv : ir_utils::filterByType(fusion_vals)) { // Make sure resize is only used as part of rfactor transformations auto rf_to_leaf_exprs = StmtSort::getExprsBetween( - fusion, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()}, {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index fb0c091632e..87550c6561c 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -90,7 +90,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { !fusion->isA(), "Invalid container. Kernel container not allowed.\n"); - traverseTo(fusion, fusion->getTerminatingOutputs(), false, false); + traverseTo(fusion->getTerminatingOutputs(), false, false); finalizeDynamicVals(); @@ -147,7 +147,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { //! Process vector of leaf dynamic values by finding inputs and recording the //! result into info_ void finalizeDynamicVals() { - const auto inputs = InputsOf::outputs(info_.fusion(), leaf_dynamic_vals_); + const auto inputs = InputsOf::outputs(leaf_dynamic_vals_); info_.root_dynamic_vals_.insert(inputs.begin(), inputs.end()); // initial_info_ provides a set of Vals that are used for concretization. @@ -621,7 +621,6 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Note that it is assumed that theres's no further expression // beyond the rfactor domain as asserted above auto all_id_exprs = StmtSort::getExprsBetween( - tv->fusion(), {tv->getRootDomain().begin(), tv->getRootDomain().end()}, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()}); diff --git a/csrc/executor.cpp b/csrc/executor.cpp index 6232dc5b03a..ed85e0da9b1 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -275,7 +275,7 @@ void FusionExecutor::compileFusion( } output_extents.emplace_back(extent); } - auto dependencies = InputsOf::outputs(fusion, output_extents); + auto dependencies = InputsOf::outputs(output_extents); if (std::any_of(dependencies.begin(), dependencies.end(), [](Val* val) { return val->isFusionInput(); })) { @@ -607,7 +607,6 @@ std::pair, std::vector> inferShapeOfOutput( class ForwardTraverseFromAllocToRFactor { at::Tensor tensor_; - TensorView* tv_; ExpressionEvaluator& ee_; std::list& frontier_; @@ -725,18 +724,15 @@ class ForwardTraverseFromAllocToRFactor { public: ForwardTraverseFromAllocToRFactor( at::Tensor tensor, - TensorView* tv, ExpressionEvaluator& ee, std::list& frontier) - : tensor_(std::move(tensor)), tv_(tv), ee_(ee), frontier_(frontier) {} + : tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {} at::Tensor run( const std::vector& rfactor, const std::vector& alloc) { auto forward_exprs = StmtSort::getExprsBetween( - tv_->fusion(), - {alloc.begin(), alloc.end()}, - {rfactor.begin(), rfactor.end()}); + {alloc.begin(), alloc.end()}, {rfactor.begin(), rfactor.end()}); for (auto expr : forward_exprs) { handle(expr); } @@ -748,7 +744,6 @@ class ForwardTraverseFromAllocToRFactor { // transformations. class BackwardTraverseFromAllocToRFactor { at::Tensor tensor_; - TensorView* tv_; ExpressionEvaluator& ee_; std::list& frontier_; @@ -853,18 +848,15 @@ class BackwardTraverseFromAllocToRFactor { public: BackwardTraverseFromAllocToRFactor( at::Tensor tensor, - TensorView* tv, ExpressionEvaluator& ee, std::list& frontier) - : tensor_(std::move(tensor)), tv_(tv), ee_(ee), frontier_(frontier) {} + : tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {} at::Tensor run( const std::vector& rfactor, const std::vector& alloc) { auto backward_exprs = StmtSort::getExprsBetween( - tv_->fusion(), - {rfactor.begin(), rfactor.end()}, - {alloc.begin(), alloc.end()}); + {rfactor.begin(), rfactor.end()}, {alloc.begin(), alloc.end()}); std::reverse(backward_exprs.begin(), backward_exprs.end()); for (auto expr : backward_exprs) { handle(expr); @@ -894,9 +886,9 @@ at::Tensor transformOutputFromAllocationToRFactor( // forward and a backward traverse. std::list frontier(alloc.begin(), alloc.end()); NVF_ERROR(tensor.dim() == (int64_t)frontier.size()); - tensor = ForwardTraverseFromAllocToRFactor(tensor, tv, ee, frontier) + tensor = ForwardTraverseFromAllocToRFactor(tensor, ee, frontier) .run(rfactor, alloc); - tensor = BackwardTraverseFromAllocToRFactor(tensor, tv, ee, frontier) + tensor = BackwardTraverseFromAllocToRFactor(tensor, ee, frontier) .run(rfactor, alloc); NVF_ERROR(frontier.size() == rfactor.size()); // Now that all affine transformations are handled, and frontiers should @@ -987,26 +979,38 @@ std::vector allocateOutputs( std::vector outputs; outputs.reserve(output_info.size()); + + std::unordered_map outputs_map; + for (const auto output_idx : c10::irange(output_info.size())) { Val* out = kernel->outputs()[output_idx]; - auto [aliased_in, alias_info] = kernel->getOutputAlias(out); - at::Tensor aliased_in_tensor; - if (aliased_in != nullptr) { - const PolymorphicValue& aliased_in_val = - *inputs[IndexOfFusionInput(aliased_in, kernel)]; - NVF_ERROR( - aliased_in_val.is(), - "Alias io only supports tensor. Found ", - PolymorphicValue_functions::toString(aliased_in_val)); - aliased_in_tensor = aliased_in_val.as(); + // TODO: remove the else block and outputs_map when output aliasing is + // handled properly + auto iter = outputs_map.find(out); + if (iter == outputs_map.end()) { + auto [aliased_in, alias_info] = kernel->getOutputAlias(out); + at::Tensor aliased_in_tensor; + if (aliased_in != nullptr) { + const PolymorphicValue& aliased_in_val = + *inputs[IndexOfFusionInput(aliased_in, kernel)]; + NVF_ERROR( + aliased_in_val.is(), + "Alias io only supports tensor. Found ", + PolymorphicValue_functions::toString(aliased_in_val)); + aliased_in_tensor = aliased_in_val.as(); + } + auto output = allocateOutput( + output_info[output_idx], + aliased_in, + alias_info, + aliased_in_tensor, + device, + ee); + outputs_map[out] = output; + outputs.push_back(output); + } else { + outputs.push_back(iter->second); } - outputs.push_back(allocateOutput( - output_info[output_idx], - aliased_in, - alias_info, - aliased_in_tensor, - device, - ee)); } return outputs; } @@ -1781,6 +1785,7 @@ std::vector FusionExecutor::runFusion( const int hw_max_warps = prop->maxThreadsPerMultiProcessor / prop->warpSize; const float occupancy = (float)warps_per_sm / (float)hw_max_warps * 100.f; + setKernelOccupancy(occupancy); std::ostringstream oss; oss << std::fixed << std::setprecision(2) << occupancy << "%"; debug() << "blocks_per_sm= " << blocks_per_sm diff --git a/csrc/executor.h b/csrc/executor.h index 580b1b988a1..a9fd6c9ed51 100644 --- a/csrc/executor.h +++ b/csrc/executor.h @@ -205,6 +205,22 @@ class FusionExecutor : public NonCopyable { return measure_kernel_time_ ? kernel_time_ms_ : 0; } + //! get occupancy of the last kernel execution + float getKernelOccupancy() const { + NVF_ERROR( + kernel_occupancy_ > 0, + "Occupancy unknown, should run with dump occupancy or perf_debug_verbose"); + return kernel_occupancy_; + } + + void setKernelOccupancy(float occupancy) { + kernel_occupancy_ = occupancy; + } + + //! get register spills (load + store) of the compiled kernel + int getKernelRegisterSpills() const { + return compiled_kernel_->register_spills; + } //! Returns the input bytes accessed for a kernel //! \note It is important to sample the args struct prior to adding the // 1 output to the args struct @@ -549,6 +565,10 @@ class FusionExecutor : public NonCopyable { // is true float kernel_time_ms_ = 0; + // Heuristic tuning support: the last kernel occupancy, if + // DebugDumpOption::Occupancy is true + float kernel_occupancy_ = -1.0f; + // Profiling support: last kernel bytes processed in each input std::optional> bytes_processed_per_input_ = std::nullopt; diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index 549b647d11d..144e9a8350b 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -979,9 +979,13 @@ void fillCompileOptions( // Meanwhile, for forward compatibility (future device with // `unsupported_arch==True`), since SASS are not necessarily compatible, // we fallback to PTX instead. - const std::string compute = std::string("--gpu-architecture=") + + std::string compute = std::string("--gpu-architecture=") + (compile_to_sass ? "sm_" : "compute_") + std::to_string(major) + std::to_string(minor); + if (major == 9) { + // Hopper MMAs require 90a instead of 90 + compute += "a"; + } nvrtc_compile_driver.setOption(compute); nvrtc_compile_driver.setOption("-default-device"); @@ -1058,7 +1062,7 @@ void fillCompileOptions( } // Dump ptxas output if register spill is detected -void warnRegisterSpill(const std::string& compile_log) { +int warnRegisterSpill(const std::string& compile_log) { auto getRegisterSpillInfo = [](const std::string& log, const char* subStr) { auto it_end = std::search(log.begin(), log.end(), subStr, subStr + strlen(subStr)) - @@ -1093,6 +1097,7 @@ void warnRegisterSpill(const std::string& compile_log) { load_count > allowed_spill) { debug() << "WARNING: Register spill detected\n" << compile_log << std::endl; } + return store_count + load_count; } void createNvrtcProgram( @@ -1266,7 +1271,8 @@ std::unique_ptr getCompiledKernel( if (isOptionEnabled(EnableOption::WarnRegisterSpill) || compile_params.enable_ptxas_verbose) { - warnRegisterSpill(compiled_kernel->compile_log); + compiled_kernel->register_spills = + warnRegisterSpill(compiled_kernel->compile_log); } NVFUSER_CUDA_SAFE_CALL(cuModuleGetFunction( diff --git a/csrc/executor_utils.h b/csrc/executor_utils.h index c217bf28d7d..3ab0852fb1d 100644 --- a/csrc/executor_utils.h +++ b/csrc/executor_utils.h @@ -58,6 +58,7 @@ struct CompiledKernel : public NonCopyable { std::string kernel_name; std::string compile_args; long block_size = -1; + int register_spills = -1; }; // Returns executable function and the ptxas log from compilation diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index ee623065499..bea5bf0aa75 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -242,6 +242,11 @@ void Fusion::addInput(Val* input) { "Immediate scalar value cannot be added as an input. It is not necessary to pass it as an input."); } + NVF_CHECK( + !input->isFusionInput(), + "Val: ", + input->toString(), + " is already registered as input, duplicated inputs is not allowed"); inputs_.push_back(input); input->setIsFusionInput(true); @@ -375,7 +380,7 @@ bool Fusion::isNoOp() { } std::vector Fusion::inputsOf(Val* val) { - return InputsOf::output(this, val); + return InputsOf::output(val); } void Fusion::validateInputs() { @@ -528,7 +533,7 @@ void Fusion::printMath(bool from_outputs_only) { leaf_vals.push_back(val); } } - exprs_for_print = StmtSort::getExprsTo(this, leaf_vals); + exprs_for_print = StmtSort::getExprsTo(leaf_vals); } debug() << "\n%kernel_math {\n"; @@ -649,7 +654,7 @@ std::vector Fusion::usedMathVals() { // there can be vals that are created inside a fusion without using // anything from inputs. See, for example, tv0 in the // FusionOuterSplit test. - const auto inputs = InputsOf::outputs(this, outputs()); + const auto inputs = InputsOf::outputs(outputs()); auto used_math_vals = DependencyCheck::getAllValsBetween( {inputs.begin(), inputs.end()}, outputs()); // When an expre has multiple outputs and only some of them are diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 06d57c28de6..8bacea97f75 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -3703,7 +3703,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) { group->input_vals = IterVisitor::getInputsTo(group->inputs()); // Grab all expressions needed to produce to_visit - auto input_exprs = StmtSort::getExprsTo(completeFusion(), to_visit); + auto input_exprs = StmtSort::getExprsTo(to_visit); // Insert those expressions at the beginning of the group group->exprs_.insert( @@ -3963,7 +3963,7 @@ class ForceHalfAnnotation : public IterVisitor { val->getDataType().value() == DataType::BFloat16); }); - annotation.traverseTo(fusion, fp16_outputs); + annotation.traverseTo(fp16_outputs); return annotation.force_fp16_tv_set_; } diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 314ee914e5e..de1c4423686 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -333,22 +333,6 @@ Val* getProducerIndexWithPartialSplit( SimplifyingIrBuilder::create(diff->evaluate(), DataType::Index)); } -Val* getTensorBaseAddress(TensorView* tv) { - auto metadata = IrBuilder::metadataExpr(tv); - switch (auto memtype = tv->getMemoryType()) { - case MemoryType::Global: - return IrBuilder::getAttrExpr(metadata, "data"); - case MemoryType::Shared: { - auto output = IrBuilder::create(DataType::SMemAddress); - IrBuilder::create( - UnaryOpType::ToUnsignedSmemAddr, output, metadata); - return output; - } - default: - NVF_CHECK(false, "Unsupported memory type ", memtype); - } -} - } // namespace bool IndexCompute::hasUnswitchedDependentDomains(IterDomain* id) const { @@ -928,7 +912,7 @@ void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) { void IndexCompute::run() { const std::vector domain_vals(td_->leaf().begin(), td_->leaf().end()); - traverseTo(td_->fusion(), domain_vals, false); + traverseTo(domain_vals, false); } IterDomain* IndexCompute::maybeGetExactMapConcreteID(IterDomain* id) const { @@ -1035,7 +1019,7 @@ class UpdateLeafIndices : public IterVisitor { extent_map_(std::move(extent_map)) { const std::vector domain_vals(td_->leaf().begin(), td_->leaf().end()); - traverseTo(td_->fusion(), domain_vals, false); + traverseTo(domain_vals, false); } const std::unordered_map& indexMap() const { @@ -2353,7 +2337,7 @@ Val* Index::getProducerStridedIndices( FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices"); if (producer->domain()->noReductions().empty()) { if (generate_pointer) { - return getTensorBaseAddress(producer); + return IrBuilder::baseAddressExpr(producer); } else { return GpuLower::current()->kernel()->zeroVal(); } @@ -2364,7 +2348,7 @@ Val* Index::getProducerStridedIndices( producer, consumer, loops, rotated_loops, override_index)); if (generate_pointer) { return SimplifyingIrBuilder::addExpr( - getTensorBaseAddress(producer), index); + IrBuilder::baseAddressExpr(producer), index); } else { return index; } @@ -2376,7 +2360,8 @@ Val* Index::getProducerStridedIndices( index, IrBuilder::create( dataTypeSize(*producer->getDataType()), *index->getDataType())); - return IrBuilder::addExpr(getTensorBaseAddress(producer), index_bytes); + return IrBuilder::addExpr( + IrBuilder::baseAddressExpr(producer), index_bytes); } else { return index; } @@ -2422,7 +2407,7 @@ Val* Index::getConsumerStridedIndices( FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices"); if (consumer->domain()->noReductions().empty()) { if (generate_pointer) { - return getTensorBaseAddress(consumer); + return IrBuilder::baseAddressExpr(consumer); } else { return GpuLower::current()->kernel()->zeroVal(); } @@ -2433,7 +2418,7 @@ Val* Index::getConsumerStridedIndices( consumer, loops, rotated_loops, override_index)); if (generate_pointer) { return SimplifyingIrBuilder::addExpr( - getTensorBaseAddress(consumer), index); + IrBuilder::baseAddressExpr(consumer), index); } else { return index; } @@ -2445,7 +2430,8 @@ Val* Index::getConsumerStridedIndices( index, IrBuilder::create( dataTypeSize(*consumer->getDataType()), *index->getDataType())); - return IrBuilder::addExpr(getTensorBaseAddress(consumer), index_bytes); + return IrBuilder::addExpr( + IrBuilder::baseAddressExpr(consumer), index_bytes); } else { return index; } @@ -3285,7 +3271,7 @@ Val* Index::cpAsyncBulkIndex( box_dim, element_strides, TensorMapInterleave::NoInterleave, - TensorMapSwizzle::NoSwizzle, + MmaInputSmemSwizzle::None, TensorMapL2Promotion::NoL2Promotion, TensorMapFloatOOBFill::NoOOBFill); diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index d8e9b777b37..3833ed054d4 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -172,6 +172,14 @@ Val* IrBuilder::bitwiseOrExpr(Val* lhs, Val* rhs) { return newArithmeticExpr(BinaryOpType::BitwiseOr, lhs, rhs); } +Val* IrBuilder::lShiftExpr(Val* lhs, Val* rhs) { + return newArithmeticExpr(BinaryOpType::Lshift, lhs, rhs); +} + +Val* IrBuilder::rShiftExpr(Val* lhs, Val* rhs) { + return newArithmeticExpr(BinaryOpType::Rshift, lhs, rhs); +} + Val* IrBuilder::eqExpr(Val* lhs, Val* rhs) { return newLogicExpr(BinaryOpType::Eq, lhs, rhs); } @@ -265,6 +273,21 @@ Val* IrBuilder::metadataExpr(TensorView* tv) { return tv->fusion()->metadataOf(tv); } +Val* IrBuilder::baseAddressExpr(TensorView* tv) { + auto metadata = metadataExpr(tv); + switch (auto memtype = tv->getMemoryType()) { + case MemoryType::Global: + return getAttrExpr(metadata, "data"); + case MemoryType::Shared: { + auto output = create(DataType::SMemAddress); + create(UnaryOpType::ToUnsignedSmemAddr, output, metadata); + return output; + } + default: + NVF_CHECK(false, "Unsupported memory type ", memtype); + } +} + Val* SimplifyingIrBuilder::negExpr(Val* val) { if (val->isZeroInt()) { return val->container()->zeroVal(val->dtype()); diff --git a/csrc/ir/builder.h b/csrc/ir/builder.h index 21df7433e65..5630d312451 100644 --- a/csrc/ir/builder.h +++ b/csrc/ir/builder.h @@ -72,6 +72,8 @@ class IrBuilder { static Val* logicalOrExpr(Val* lhs, Val* rhs); static Val* bitwiseAndExpr(Val* lhs, Val* rhs); static Val* bitwiseOrExpr(Val* lhs, Val* rhs); + static Val* lShiftExpr(Val* lhs, Val* rhs); + static Val* rShiftExpr(Val* lhs, Val* rhs); static Val* eqExpr(Val* lhs, Val* rhs); static Val* neExpr(Val* lhs, Val* rhs); static Val* gtExpr(Val* lhs, Val* rhs); @@ -100,6 +102,10 @@ class IrBuilder { // Get tensor metadata static Val* metadataExpr(TensorView* tv); + // Get tensor base address, for gmem tensor, it is something like + // `T1.data`. For smem tensor, it is something like `toSmem(T1)`. + static Val* baseAddressExpr(TensorView* tv); + // Construct an array of values, or nested arrays of values. template static Val* arrayExpr(std::vector members) { diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index 1fdac07954e..87898b3e8e2 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -59,8 +59,7 @@ TensorView* RecomputeTv::recompute( "Cannot recompute buffers that are inputs of the fusion."); // Grab all the expressions used to generate the TensorView - auto exprs = - StmtSort::getExprsBetween(tv->fusion(), from, {tv}, false, false); + auto exprs = StmtSort::getExprsBetween(from, {tv}, false, false); // Run the replicator RecomputeTv replicator(tv->fusion()); @@ -91,7 +90,7 @@ TensorView* RecomputeTv::recompute( return cloned_val->as(); } -RecomputeTv::RecomputeTv(Fusion* fusion) : IrCloner(fusion), fusion_(fusion) { +RecomputeTv::RecomputeTv(Fusion* fusion) : IrCloner(fusion) { // Add inputs to the clones map to prevent cloning them. for (const auto inp : fusion->inputs()) { clones_map_[inp] = inp; @@ -115,8 +114,7 @@ Statement* RecomputeTv::handle(const Statement* s) { Statement* RecomputeTv::handle(const TensorDomain* td) { // Make sure to recompute the history of the iteration domains, explicitly go // through the expressions and send them to IrCloner. - auto exprs = - StmtSort::getExprsTo(fusion_, {td->leaf().begin(), td->leaf().end()}); + auto exprs = StmtSort::getExprsTo({td->leaf().begin(), td->leaf().end()}); for (auto expr : exprs) { IrCloner::handle(expr); diff --git a/csrc/ir/cloner.h b/csrc/ir/cloner.h index 7997a239e94..9a3e4ec95cd 100644 --- a/csrc/ir/cloner.h +++ b/csrc/ir/cloner.h @@ -128,8 +128,6 @@ class RecomputeTv : private IrCloner { RecomputeTv(Fusion* fusion); Statement* handle(const Statement* s) override; Statement* handle(const TensorDomain*); - - Fusion* fusion_; }; //! Clone an IR node, forwarding the arguments to the IrCloner constructor. diff --git a/csrc/ir/iostream.cpp b/csrc/ir/iostream.cpp index 824e6cb16f1..b670d196c4f 100644 --- a/csrc/ir/iostream.cpp +++ b/csrc/ir/iostream.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 1e781b6b899..e0961bb8da6 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -625,6 +625,12 @@ std::vector BinaryOp::evaluate( case BinaryOpType::Gcd: return {gcd(lhs, rhs)}; break; + case BinaryOpType::Lshift: + return {lhs << rhs}; + break; + case BinaryOpType::Rshift: + return {lhs >> rhs}; + break; default: NVF_CHECK( false, @@ -3020,7 +3026,7 @@ std::pair IterDomain::swizzle( !in_x->isReduction() && !in_y->isReduction(), "swizzled reduction not yet supported"); - for (auto input : InputsOf::outputs(in_x->fusion(), {in_x, in_y})) { + for (auto input : InputsOf::outputs({in_x, in_y})) { NVF_CHECK( !input->as()->isBroadcast(), "swizzling broadcast axes not yet supported"); diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index ccbe0679b03..50b11ff0896 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -460,7 +460,7 @@ class ValReplacementMutator : private OptOutMutator { // typically not used by anything else. If we don't grab that count, then it // would be a tensorview that doesn't get updated extents. Therefore, first // grab all leaves towards outputs and grab stmts from there. - auto stmts = StmtSort::getStmtsTo(fusion, allLeafOuts(fusion), true, true); + auto stmts = StmtSort::getStmtsTo(allLeafOuts(fusion), true, true); // Some fusions, such as standalone rand_like, can have disconnected DAG, so // we need some mechanism to make sure our replacement set is as complete as @@ -478,7 +478,7 @@ class ValReplacementMutator : private OptOutMutator { more.emplace_back(v); } } - auto more_stmts = StmtSort::getStmtsTo(fusion, more, true, true); + auto more_stmts = StmtSort::getStmtsTo(more, true, true); more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end()); for (auto stmt : more_stmts) { @@ -797,7 +797,6 @@ bool hasResizedRfactor(const TensorView* tv) { return false; } auto root_to_rf_exprs = StmtSort::getExprsBetween( - tv->fusion(), {tv->getRootDomain().begin(), tv->getRootDomain().end()}, {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}); return std::any_of( @@ -840,7 +839,6 @@ class ValidateDomainEquivalence : private IterVisitor { toDelimitedString(derived_domain)); traverseBetween( - initial_domain.at(0)->fusion(), {initial_domain.begin(), initial_domain.end()}, {derived_domain.begin(), derived_domain.end()}); diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 0c3be5a2c7e..68f52787933 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -126,7 +126,7 @@ void IterVisitor::dispatch(Val* v) { // Implementation details: // We start with an entry in stmt_stack that is the outputs we want to -// process. We cannot process these outputs untill all Stmts in their history +// process. We cannot process these outputs until all Stmts in their history // have been processed, as those Stmts contain all dependencies to produce // these values. What we will do is traverse towards inputs until we hit a // leaf node. Once we hit a leaf node that node will be visited, then we will @@ -138,13 +138,16 @@ void IterVisitor::dispatch(Val* v) { // function to remove visited nodes from being re-added to the stack // (remove_visited). void IterVisitor::traverseBetween( - Fusion* fusion, const std::unordered_set& from, const std::vector& to, bool traverse_all_paths, bool traverse_into_members, bool traverse_attributes, bool traverse_siblings) { + if (to.empty()) { + return; + } + Fusion* fusion = to.front()->fusion(); FusionGuard fg(fusion); std::unordered_set visited; @@ -287,14 +290,12 @@ void IterVisitor::traverseBetween( } void IterVisitor::traverseTo( - Fusion* fusion, const std::vector& to, bool traverse_all_paths, bool traverse_into_members, bool traverse_attributes, bool traverse_siblings) { traverseBetween( - fusion, {}, to, traverse_all_paths, @@ -308,7 +309,7 @@ void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) { auto term_val_outs = fusion->getTerminatingOutputs(); if (!term_val_outs.empty()) { - traverseTo(fusion, term_val_outs, traverse_all_paths); + traverseTo(term_val_outs, traverse_all_paths); } } @@ -364,7 +365,7 @@ class Inputs : public IterVisitor { return {}; } Inputs inps(all_inputs); - inps.traverseTo(of[0]->fusion(), of); + inps.traverseTo(of); return inps.inputs_; } }; @@ -393,7 +394,7 @@ class AllVals : public IterVisitor { Fusion* fusion, const std::vector& from) { AllVals av; - av.traverseTo(fusion, from, false); + av.traverseTo(from, false); return av.vals; } }; @@ -451,21 +452,20 @@ void BackwardVisitor::dispatch(Val* val) { } void BackwardVisitor::traverseTo( - Fusion* fusion, const std::vector& from, bool traverseAllPaths) { + if (from.empty()) { + return; + } + Fusion* fusion = from.front()->fusion(); FusionGuard fg(fusion); // Reset members stmt_stack_.clear(); traversal_exprs_.clear(); - if (from.empty()) { - return; - } - auto vals = AllVals::get(fusion, from); - auto exprs = StmtSort::getExprsTo(fusion, from); + auto exprs = StmtSort::getExprsTo(from); { size_t pos = 0; @@ -603,7 +603,7 @@ struct Dependencies : public IterVisitor { std::unordered_set _dependencies, const std::vector& of) : dependencies_(std::move(_dependencies)) { - traverseTo(of[0]->fusion(), of, false); + traverseTo(of, false); }; public: @@ -650,7 +650,7 @@ struct FindOutputs : public IterVisitor { // tracing all paths like this. FindOutputs(const std::unordered_set& _of) : of_(_of) { auto fusion = (*of_.begin())->fusion(); - traverseTo(fusion, fusion->outputs(), true); + traverseTo(fusion->outputs(), true); }; static std::unordered_set getAllOutputsOf( @@ -719,7 +719,7 @@ class DependentVals : public IterVisitor { DependentVals(const std::unordered_set& _of) : of_(_of) { createBoundary(); auto fusion = (*of_.begin())->fusion(); - traverseTo(fusion, fusion->outputs(), false); + traverseTo(fusion->outputs(), false); }; public: @@ -755,7 +755,7 @@ class DependencyChains : public IterVisitor { DependencyChains(Val* _dependency, Val* _of, bool all_chains_ = false) : dependencies_({_dependency}) { - traverseTo(_of->fusion(), {_of}, all_chains_); + traverseTo({_of}, all_chains_); } DependencyChains(Val* _dependency, bool all_chains_ = false) @@ -882,7 +882,6 @@ std::vector StmtSort::getExprs( bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); return StmtSort::getExprsTo( - fusion, terminating_outputs, traverse_members, traverse_attributes, @@ -890,32 +889,25 @@ std::vector StmtSort::getExprs( } std::vector StmtSort::getExprsTo( - Fusion* fusion, const std::vector& to, bool traverse_members, bool traverse_attributes, bool traverse_siblings) { auto stmts = StmtSort::getStmtsTo( - fusion, to, traverse_members, traverse_attributes, traverse_siblings); + to, traverse_members, traverse_attributes, traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; } std::vector StmtSort::getExprsBetween( - Fusion* fusion, const std::vector& from, const std::vector& to, bool traverse_members, bool traverse_attributes, bool traverse_siblings) { auto stmts = StmtSort::getStmtsBetween( - fusion, - from, - to, - traverse_members, - traverse_attributes, - traverse_siblings); + from, to, traverse_members, traverse_attributes, traverse_siblings); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; @@ -928,7 +920,6 @@ std::vector StmtSort::getStmts( bool traverse_siblings) { auto terminating_outputs = fusion->getTerminatingOutputs(); return StmtSort::getStmtsTo( - fusion, terminating_outputs, traverse_members, traverse_attributes, @@ -936,24 +927,17 @@ std::vector StmtSort::getStmts( } std::vector StmtSort::getStmtsTo( - Fusion* fusion, const std::vector& to, bool traverse_members, bool traverse_attributes, bool traverse_siblings) { StmtSort es; es.traverseTo( - fusion, - to, - false, - traverse_members, - traverse_attributes, - traverse_siblings); + to, false, traverse_members, traverse_attributes, traverse_siblings); return es.stmts; } std::vector StmtSort::getStmtsBetween( - Fusion* fusion, const std::vector& from, const std::vector& to, bool traverse_members, @@ -961,7 +945,6 @@ std::vector StmtSort::getStmtsBetween( bool traverse_siblings) { StmtSort es; es.traverseBetween( - fusion, {from.begin(), from.end()}, to, false, @@ -979,15 +962,13 @@ void InputsOf::dispatch(Val* v) { } } -std::vector InputsOf::output(Fusion* fusion, Val* output_) { - return outputs(fusion, {output_}); +std::vector InputsOf::output(Val* output_) { + return outputs({output_}); } -std::vector InputsOf::outputs( - Fusion* fusion, - const std::vector& outputs_) { +std::vector InputsOf::outputs(const std::vector& outputs_) { InputsOf io; - io.traverseTo(fusion, outputs_, false); + io.traverseTo(outputs_, false); return io.ordered_inputs; } @@ -995,14 +976,14 @@ std::vector InputsOf::outputs( bool DeadCodeRemover::run() { // First we build a set of all live Statements so that we can detect dead // branches. - for (auto stmt : StmtSort::getStmtsTo(fusion_, fusion_->outputs())) { + for (auto stmt : StmtSort::getStmtsTo(fusion_->outputs())) { markLive(stmt); } // Note that StmtSort::getStmtsTo() is also run in traverseTo. In the future, // we could potentially refactor this so that derived classes from // BackwardVisitor can make use of that traversal instead of repeating it. - traverseTo(fusion_, fusion_->outputs(), false); + traverseTo(fusion_->outputs(), false); // We do not remove Statements from the Fusion while traversing, to avoid // dereferencing invalid pointers. Instead, we wait until this point to do the diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 2a0e5b92188..f1f6610b854 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -99,7 +99,6 @@ class IterVisitor : public OptOutDispatch { //! active multi-output expressions, even if those Expr outputs are not used //! in paths to Fusion outputs. void traverseTo( - Fusion* fusion, const std::vector& to, bool traverse_all_paths = false, bool traverse_into_members = false, @@ -126,7 +125,6 @@ class IterVisitor : public OptOutDispatch { //! active multi-output expressions, even if those Expr outputs are not used //! in paths to Fusion outputs. void traverseBetween( - Fusion* fusion, const std::unordered_set& from, const std::vector& to, bool traverse_all_paths = false, @@ -238,10 +236,7 @@ class BackwardVisitor : public OptOutDispatch { // traverseAllPaths = false only call handle on each Statement* once // traverseAllPaths = true traverses all paths from nodes in from to inputs. // Handle on a Statement* for every path from "from" nodes, to inputs. - void traverseTo( - Fusion* fusion, - const std::vector& from, - bool traverseAllPaths = false); + void traverseTo(const std::vector& from, bool traverseAllPaths = false); bool must_cover_all_expr_outputs_ = true; }; @@ -313,7 +308,6 @@ class StmtSort : public IterVisitor { // Returns ordered Statements required to produce 'to', including 'to'. static std::vector getStmtsTo( - Fusion* fusion, const std::vector& to, bool traverse_members = false, bool traverse_attributes = false, @@ -337,7 +331,6 @@ class StmtSort : public IterVisitor { // If traverse_members it will also extract all member nodes in the sorted // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc static std::vector getStmtsBetween( - Fusion* fusion, const std::vector& from, const std::vector& to, bool traverse_members = false, @@ -353,7 +346,6 @@ class StmtSort : public IterVisitor { // Same as getStmts version but filters to only return the Expr*s static std::vector getExprsTo( - Fusion* fusion, const std::vector& to, bool traverse_members = false, bool traverse_attributes = false, @@ -361,7 +353,6 @@ class StmtSort : public IterVisitor { // Same as getStmts version but filters to only return the Expr*s static std::vector getExprsBetween( - Fusion* fusion, const std::vector& from, const std::vector& to, bool traverse_members = false, @@ -379,10 +370,8 @@ class InputsOf : public IterVisitor { void dispatch(Val* v) final; public: - static std::vector output(Fusion* fusion, Val* output_); - static std::vector outputs( - Fusion* fusion, - const std::vector& outputs_); + static std::vector output(Val* output_); + static std::vector outputs(const std::vector& outputs_); }; //! This is a generic traversal class that is used to modify a Fusion graph by diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index 4081f3c181f..097e372c261 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -1579,7 +1579,7 @@ EncodeTensorMapTiled::EncodeTensorMapTiled( Val* box_dim, Val* element_strides, tma::TensorMapInterleave interleave, - tma::TensorMapSwizzle swizzle, + MmaInputSmemSwizzle swizzle, tma::TensorMapL2Promotion l2_promotion, tma::TensorMapFloatOOBFill oob_fill) : Expr(passkey) { @@ -1633,7 +1633,7 @@ std::string EncodeTensorMapTiled::toString(int indent_size) const { << ", element_strides=" << elementStrides()->toString() << ", interleave=" << interleave() - << ", swizzle=" << swizzle() + << ", swizzle=" << nvfuser::toString(swizzle()) << ", l2_promotion=" << l2Promotion() << ", oob_fill=" << oobFill() << ")\n"; return ss.str(); @@ -1647,7 +1647,8 @@ std::string EncodeTensorMapTiled::toInlineString(int indent_size) const { << ", global_strides=" << globalStrides()->toInlineString() << ", box_dim=" << boxDim()->toInlineString() << ", element_strides=" << elementStrides()->toInlineString() - << ", interleave=" << interleave() << ", swizzle=" << swizzle() + << ", interleave=" << interleave() + << ", swizzle=" << nvfuser::toString(swizzle()) << ", l2_promotion=" << l2Promotion() << ", oob_fill=" << oobFill() << ")"; return ss.str(); } diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index f3728b73ace..211272466d9 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -1392,7 +1393,7 @@ class EncodeTensorMapTiled : public Expr { Val* box_dim, Val* element_strides, tma::TensorMapInterleave interleave, - tma::TensorMapSwizzle swizzle, + MmaInputSmemSwizzle swizzle, tma::TensorMapL2Promotion l2_promotion, tma::TensorMapFloatOOBFill oob_fill); @@ -1437,8 +1438,8 @@ class EncodeTensorMapTiled : public Expr { return attribute(2); } - const tma::TensorMapSwizzle& swizzle() const { - return attribute(3); + const MmaInputSmemSwizzle& swizzle() const { + return attribute(3); } const tma::TensorMapL2Promotion& l2Promotion() const { diff --git a/csrc/mma_type.cpp b/csrc/mma_type.cpp index ebc9c1bb62f..41b74abca7e 100644 --- a/csrc/mma_type.cpp +++ b/csrc/mma_type.cpp @@ -75,6 +75,22 @@ std::string toString(MmaMacro macro) { return ss.str(); } +std::string toString(MmaInputSmemSwizzle swizzle) { + switch (swizzle) { + case MmaInputSmemSwizzle::None: + return "NoSwizzle"; + case MmaInputSmemSwizzle::B32: + return "32B"; + case MmaInputSmemSwizzle::B64: + return "64B"; + case MmaInputSmemSwizzle::B128: + return "128B"; + default: + NVF_CHECK(false, "Unknown tensor map swizzle type!"); + break; + } +} + size_t hash(MmaMacro macro) { return std::hash{}(static_cast(macro)); } diff --git a/csrc/mma_type.h b/csrc/mma_type.h index bee5c54b7b2..4dbdb00eafa 100644 --- a/csrc/mma_type.h +++ b/csrc/mma_type.h @@ -243,11 +243,23 @@ int getInputBRegisterSize(MmaMacro macro); // Unpack MMA op shape GemmTile getMmaOpShape(MmaMacro macro); +// Warning: The values of the enum class must match the matrix descriptor as +// specified in: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor +// Do not edit the values of the enum class unless you know what you are doing. +enum class MmaInputSmemSwizzle { + None = 0, + B128 = 1, + B64 = 2, + B32 = 3, +}; + // MMA stringify utils std::string toString(MmaLayout input_layout); std::string toString(const GemmTile& tile); std::string toString(const MatMulTileOptions& opts); std::string toString(MmaMacro macro); +std::string toString(MmaInputSmemSwizzle swizzle); // MMA hash utils size_t hash(MmaMacro macro); diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 3d8ae73a430..3a06355c77c 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -116,7 +116,7 @@ std::vector PipelineExecutor::runWithInput( } // Run through the stages to launch kernel - traverseTo(runtime_.pipeline_, runtime_.pipeline_->outputs()); + traverseTo(runtime_.pipeline_->outputs()); // Collect global outputs from context std::vector outputs; diff --git a/csrc/multidevice/pipeline.cpp b/csrc/multidevice/pipeline.cpp index ffe208b7250..e5a62cbeba2 100644 --- a/csrc/multidevice/pipeline.cpp +++ b/csrc/multidevice/pipeline.cpp @@ -282,7 +282,7 @@ class PipelinePrinter : public IterVisitor { string_ << "}\n"; string_ << "Pipeline's Traversal inputs --> outputs {\n"; - traverseTo(pipeline_, pipeline_->outputs()); + traverseTo(pipeline_->outputs()); string_ << "}\n"; string_ << "Pipeline's outputs:{\n"; diff --git a/csrc/non_divisible_split.cpp b/csrc/non_divisible_split.cpp index 406cb5525f4..ad741004573 100644 --- a/csrc/non_divisible_split.cpp +++ b/csrc/non_divisible_split.cpp @@ -26,7 +26,7 @@ void NonDivisibleSplitInfo::build(Fusion* fusion) { tv->getLeafDomain().begin(), tv->getLeafDomain().end()); current_tv_ = tv; clearReachability(); - traverseTo(fusion, domain_vals); + traverseTo(domain_vals); current_tv_ = nullptr; } diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 438ad5dff0b..d7ef8d13a01 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -1478,13 +1478,19 @@ TensorView* expand(TensorView* inp, const std::vector& expanded_sizes) { // already done when constructing out_id_builder. out_id_builder.extent(inp_id->extent()); } else if ( - inp_id->isBroadcast() && + (inp_id->isBroadcast() || + // special patch for Symbolic IterDomain with a static size-1 extent + // See Issue: https://github.com/NVIDIA/Fuser/pull/1393 + (inp_id->isSymbolic() && inp_id->extent()->isConstInt() && + inp_id->extent()->evaluate() == 1)) && (!expanded_size_int.hasValue() || expanded_size_int != 1)) { // When input id is a broadcast, expand the extent to the given // size, which can be concrete or symbolic. expanded = true; auto expanded_extent = maybeCastOp(DataType::Index, expanded_sizes[i]); out_id_builder.expanded_extent(expanded_extent); + // need to mark iter type as Broadcast for Symbolic input domains + out_id_builder.iter_type(IterType::Broadcast); maybe_expanded_sizes[i] = expanded_extent; } else if (!inp_id->extent()->isConstInt()) { // Input id is non-broadcast and its extent is symbolic. Promote diff --git a/csrc/optimization/alias_analysis.cpp b/csrc/optimization/alias_analysis.cpp index 8916a0e3aae..e632c84342b 100644 --- a/csrc/optimization/alias_analysis.cpp +++ b/csrc/optimization/alias_analysis.cpp @@ -210,7 +210,7 @@ void AliasFinder::handle(const LoadStoreOp* permute) { // For example, // // in: rfactor=[i0,i1,i2], allocation=[i2,i0,i1] - // out = permute(in, {2, 0, 1}) + // out = permute(in, {1, 0, 2}) // out: root=[i3,i4,i5], rfactor=[i4,i3,i5] // // `out`'s preferred allocation domain is [i5,i3,i4]. This allocation domain @@ -359,6 +359,17 @@ Layout AliasAnalysisResult::preferredLayout(const Val* v) const { return {tv->getMaybeAllocationDomain(), tv->getContiguity()}; } +std::string AliasAnalysisResult::toString(const int indent_size) const { + std::stringstream ss; + for (const auto& [alias, source_and_layout] : alias_to_source_) { + const auto& [source, layout] = source_and_layout; + indent(ss, indent_size) + << alias->toString() << " is an alias of " << source->toString() + << " if its layout is " << layout.toString() << std::endl; + } + return ss.str(); +} + AliasAnalysisResult findAliases(Fusion* fusion) { AliasAnalysisResult analysis; AliasFinder finder(analysis); diff --git a/csrc/optimization/alias_analysis.h b/csrc/optimization/alias_analysis.h index cdafa743c6c..8a302e84a29 100644 --- a/csrc/optimization/alias_analysis.h +++ b/csrc/optimization/alias_analysis.h @@ -40,6 +40,8 @@ class AliasAnalysisResult { // preferred layout. void add(const TensorView* alias, const TensorView* source, Layout&& layout); + std::string toString(int indent_size) const; + private: // Maps aliases (e.g. the output of a View) to their direct sources (e.g. the // input of the same View). Also stores the preferred output layout for the diff --git a/csrc/optimization/mark_alias.cpp b/csrc/optimization/mark_alias.cpp index 3fe5cbf3003..2e14e2c2eb1 100644 --- a/csrc/optimization/mark_alias.cpp +++ b/csrc/optimization/mark_alias.cpp @@ -15,6 +15,11 @@ namespace nvfuser::optimization { void MarkAliasPass::runPass(Fusion* fusion) { const AliasAnalysisResult alias_analysis = findAliases(fusion); + if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { + debug() << "Alias analysis result:" << std::endl; + debug() << alias_analysis.toString(/*indent_size=*/1) << std::endl; + } + for (TensorView* out : ir_utils::filterByType(fusion->outputs())) { // Lazy move: we could check compatibility and only give up when diff --git a/csrc/partial_split_map.cpp b/csrc/partial_split_map.cpp index 9540c6b08a2..3eb23222e3a 100644 --- a/csrc/partial_split_map.cpp +++ b/csrc/partial_split_map.cpp @@ -16,7 +16,7 @@ void PartialSplitMap::build(Fusion* fusion) { for (auto tv : ir_utils::filterByType(used_vals)) { auto exprs = StmtSort::getExprsTo( - fusion, {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); + {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // Only needs to check root domains as partial split is only // allowed with root domains diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 18d1269f7e0..05eff3f7f78 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -322,7 +322,7 @@ class FindInputDomains : BackwardVisitor { } DomainKeySet find() { - traverseTo(tv_->fusion(), {tv_}); + traverseTo({tv_}); return input_keys_; } @@ -782,7 +782,7 @@ ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder( map_through_reduction_(map_through_reduction) { Fusion* fusion = FusionGuard::getCurFusion(); NVF_ERROR(fusion != nullptr); - traverseTo(fusion, fusion->outputs(), false); + traverseTo(fusion->outputs(), false); if (!pending_map_.empty()) { std::stringstream ss; ss << "pending map:\n"; @@ -1241,7 +1241,7 @@ class ExactRootDomainMapBuilder : private IterVisitor { Fusion* fusion, DisjointSets& eq_sets) : eq_sets_(eq_sets) { - traverseTo(fusion, fusion->outputs()); + traverseTo(fusion->outputs()); } private: diff --git a/csrc/scheduler/reduction_utils.cpp b/csrc/scheduler/reduction_utils.cpp index eb545b1a7dc..3a01d89790d 100644 --- a/csrc/scheduler/reduction_utils.cpp +++ b/csrc/scheduler/reduction_utils.cpp @@ -740,15 +740,9 @@ class PersistentBufferProjector { persistent_buffers.begin(), persistent_buffers.end()); for (auto buffer_i : c10::irange(persistent_buffers.size())) { auto buffer = persistent_buffers[buffer_i]; - // skip reduction buffers - if (buffer->hasReduction()) { - continue; - } const auto& producers = ir_utils::producerTvsOf(buffer); - if (!producers.empty() && - std::all_of(producers.begin(), producers.end(), [&](auto producer) { - return persistent_buffer_set.count(producer) > 0; - })) { + if (scheduler_utils::canProjectToPersistentProducer( + buffer, producers, persistent_buffer_set)) { projectToInputOrImmediatePersistentProducer( (int)buffer_i, std::vector(producers.begin(), producers.end())); diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 07eb09a1723..fe3a5b6f457 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -142,7 +142,7 @@ struct TransposeViewPropagator : public MaxInfoSpanningTree::Propagator { // propagation travelling across view op. Note this is a conservative check, // since view does NOT necessarily always introduce incoherent transform // that would break the propagation. - auto chain_exprs = StmtSort::getExprsBetween(from->fusion(), {from}, {to}); + auto chain_exprs = StmtSort::getExprsBetween({from}, {to}); if (!ir_utils::filterByType(chain_exprs).empty()) { should_reject = true; }; @@ -239,9 +239,7 @@ class DomainMap : public pointwise_utils::DomainMap { " in tensor ", tv); auto replay_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {mapped_id}, - {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); + {mapped_id}, {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); // Project the root id to leaf id. Similar to projectIdToRFactor. for (auto expr : replay_exprs) { if (expr->isA()) { diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 10dd0d51618..70775a55939 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -781,6 +781,23 @@ getScopePersistenceFactors( } // namespace +// Returns true if a persistent tv can be projected to its persistent producers. +bool canProjectToPersistentProducer( + TensorView* buffer, + const std::vector& producers, + const std::unordered_set& persistent_buffer_set) { + if (buffer->hasReduction() || producers.empty()) { + return false; + } + if (std::all_of(producers.begin(), producers.end(), [&](auto producer) { + return persistent_buffer_set.count(producer) > 0; + })) { + return true; + } else { + return false; + } +} + PersistentBufferSizeReturn persistentBufferSize( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, @@ -847,14 +864,19 @@ PersistentBufferSizeReturn persistentBufferSize( // Buffers involved in normal persistence std::vector persistent_mask(all_buffers.size(), false); - + std::unordered_set persistent_buffer_set( + persistent_buffers.begin(), persistent_buffers.end()); for (auto buffer_i : c10::irange(persistent_buffers.size())) { - persistent_mask[buffer_i] = true; + auto buffer = persistent_buffers[buffer_i]; + const auto& producers = ir_utils::producerTvsOf(buffer); + if (!canProjectToPersistentProducer( + buffer, producers, persistent_buffer_set)) { + persistent_mask[buffer_i] = true; + } } // Buffers involved in projected to inputs std::vector projected_mask(all_buffers.size(), true); - for (auto buffer_i : c10::irange(persistent_buffers.size())) { auto buffer = persistent_buffers[buffer_i]; // Not a projectable buffer, or an input of a projectable buffer @@ -1088,7 +1110,7 @@ IterDomain* projectIdToRoot( return reference_id; } - auto replay_exprs = StmtSort::getExprsTo(tv->fusion(), {reference_id}); + auto replay_exprs = StmtSort::getExprsTo({reference_id}); if (replay_exprs.empty()) { return reference_id; } @@ -1154,9 +1176,7 @@ IterDomain* projectIdToRFactor( } auto replay_exprs = StmtSort::getExprsTo( - tv->fusion(), - {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}, - false); + {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}, false); if (replay_exprs.empty()) { return reference_id; } @@ -1831,7 +1851,6 @@ DisjointSets disjointRFactorSets(Fusion* fusion) { // rfactor domains they should be considered "contaminated". for (auto tv : ir_utils::allTvs(fusion)) { for (auto expr : StmtSort::getExprsTo( - fusion, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()})) { if (expr->isA()) { @@ -1882,7 +1901,7 @@ bool breakIsDisjoint(std::vector group_ids, int pos) { std::unordered_map domainReorderAsRfactorMap(TensorView* tv) { FusionGuard fg(tv->fusion()); auto transform_exprs = StmtSort::getExprsTo( - tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); + {tv->getLeafDomain().begin(), tv->getLeafDomain().end()}); // simply update this vector of id's as progressing through the transformation // expressions. We'll always insert the result of split in the location of the // input, and insert the merge result in the position of the inner dimension. @@ -1965,7 +1984,6 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { // rfactor domains they should be considered "contaminated". for (auto tv : ir_utils::allTvs(fusion)) { for (auto expr : StmtSort::getExprsBetween( - fusion, {tv->getRootDomain().begin(), tv->getRootDomain().end()}, {tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()})) { diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index a85c1029d1c..879157c6d65 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -190,6 +190,13 @@ struct PersistentBufferInfo { // can simply be read multiple times from GMEM in the same kernel. PersistentBufferInfo persistentBuffers(Fusion* fusion); +// A persistent tv can be projected to its producers when all the producers are +// persistent tvs and there is no reduction op. +bool canProjectToPersistentProducer( + TensorView* buffer, + const std::vector& producers, + const std::unordered_set& persistent_buffer_set); + struct ReductionTvProperties { // How many elements in tensor view are there to reduce. int64_t total_reduction_numel = 1; diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index be18e29ffb2..6d4ceb54ec5 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -377,9 +377,7 @@ std::vector ContiguousInnerDimensionsMapper::projectId( // empty backward exprs, vice versa. auto backward_exprs = StmtSort::getExprsBetween( - frontier.front()->fusion(), - {to.begin(), to.end()}, - {frontier.begin(), frontier.end()}); + {to.begin(), to.end()}, {frontier.begin(), frontier.end()}); // Mapping from rfactor to root, reverse expressions std::reverse(backward_exprs.begin(), backward_exprs.end()); @@ -407,9 +405,7 @@ std::vector ContiguousInnerDimensionsMapper::projectId( } auto forward_exprs = StmtSort::getExprsBetween( - frontier.front()->fusion(), - {frontier.begin(), frontier.end()}, - {to.begin(), to.end()}); + {frontier.begin(), frontier.end()}, {to.begin(), to.end()}); // Map forward through transforms since we're going from root to rfactor for (auto* expr : forward_exprs) { diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 86406e01034..34d2a86930d 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -106,9 +106,7 @@ class ForwardTraverseFromRFactorToAlloc { const std::vector& rfactor, const std::vector& alloc) { auto forward_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {rfactor.begin(), rfactor.end()}, - {alloc.begin(), alloc.end()}); + {rfactor.begin(), rfactor.end()}, {alloc.begin(), alloc.end()}); for (auto expr : forward_exprs) { handle(expr); } @@ -201,9 +199,7 @@ class BackwardTraverseFromRFactorToAlloc { const std::vector& rfactor, const std::vector& alloc) { auto backward_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {alloc.begin(), alloc.end()}, - {rfactor.begin(), rfactor.end()}); + {alloc.begin(), alloc.end()}, {rfactor.begin(), rfactor.end()}); std::reverse(backward_exprs.begin(), backward_exprs.end()); for (auto expr : backward_exprs) { handle(expr); diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 6c9657b2f80..ae57acd474c 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -820,7 +820,7 @@ TensorView* TensorView::swizzle( // Disable unsupported use cases at the current step. // Currently do not support reducing or broadcasting // swizzled dimensions. - auto all_inputs = InputsOf::outputs(fusion(), {axis(x), axis(y)}); + auto all_inputs = InputsOf::outputs({axis(x), axis(y)}); for (auto id : ir_utils::filterByType(all_inputs)) { NVF_ERROR( !id->isBroadcast() && !id->isReduction(), diff --git a/csrc/tma.cpp b/csrc/tma.cpp index b95bad474ef..906209ca493 100644 --- a/csrc/tma.cpp +++ b/csrc/tma.cpp @@ -58,15 +58,15 @@ inline CUtensorMapInterleave getCUtensorMapInterleave( } } -inline CUtensorMapSwizzle getCUtensorMapSwizzle(TensorMapSwizzle swizzle) { +inline CUtensorMapSwizzle getCUtensorMapSwizzle(MmaInputSmemSwizzle swizzle) { switch (swizzle) { - case TensorMapSwizzle::NoSwizzle: + case MmaInputSmemSwizzle::None: return CU_TENSOR_MAP_SWIZZLE_NONE; - case TensorMapSwizzle::B32: + case MmaInputSmemSwizzle::B32: return CU_TENSOR_MAP_SWIZZLE_32B; - case TensorMapSwizzle::B64: + case MmaInputSmemSwizzle::B64: return CU_TENSOR_MAP_SWIZZLE_64B; - case TensorMapSwizzle::B128: + case MmaInputSmemSwizzle::B128: return CU_TENSOR_MAP_SWIZZLE_128B; default: NVF_ERROR(false, "Unknown tensor map swizzle type!"); @@ -130,27 +130,6 @@ std::ostream& operator<<(std::ostream& os, TensorMapInterleave interleave) { return os; } -std::ostream& operator<<(std::ostream& os, TensorMapSwizzle swizzle) { - switch (swizzle) { - case TensorMapSwizzle::NoSwizzle: - os << "NoSwizzle"; - break; - case TensorMapSwizzle::B32: - os << "32B"; - break; - case TensorMapSwizzle::B64: - os << "64B"; - break; - case TensorMapSwizzle::B128: - os << "128B"; - break; - default: - NVF_CHECK(false, "Unknown tensor map swizzle type!"); - break; - } - return os; -} - std::ostream& operator<<(std::ostream& os, TensorMapL2Promotion l2_promotion) { switch (l2_promotion) { case TensorMapL2Promotion::NoL2Promotion: @@ -195,7 +174,7 @@ Val* encodeTensorMapTiled( Val* box_dim, Val* element_strides, TensorMapInterleave interleave, - TensorMapSwizzle swizzle, + MmaInputSmemSwizzle swizzle, TensorMapL2Promotion l2_promotion, TensorMapFloatOOBFill oob_fill) { auto output = IrBuilder::create( diff --git a/csrc/tma.h b/csrc/tma.h index c36f5a78e5c..adbf73f15d2 100644 --- a/csrc/tma.h +++ b/csrc/tma.h @@ -10,6 +10,7 @@ #include +#include #include // Note: [TMA support in nvFuser] @@ -84,12 +85,10 @@ namespace nvfuser { namespace tma { enum class TensorMapInterleave { NoInterleave, B16, B32 }; -enum class TensorMapSwizzle { NoSwizzle, B32, B64, B128 }; enum class TensorMapL2Promotion { NoL2Promotion, B64, B128, B256 }; enum class TensorMapFloatOOBFill { NoOOBFill, NaN_Request_Zero_FMA }; std::ostream& operator<<(std::ostream& os, TensorMapInterleave interleave); -std::ostream& operator<<(std::ostream& os, TensorMapSwizzle swizzle); std::ostream& operator<<(std::ostream& os, TensorMapL2Promotion l2_promotion); std::ostream& operator<<(std::ostream& os, TensorMapFloatOOBFill oob_fill); @@ -117,7 +116,7 @@ Val* encodeTensorMapTiled( Val* box_dim, Val* element_strides, TensorMapInterleave interleave, - TensorMapSwizzle swizzle, + MmaInputSmemSwizzle swizzle, TensorMapL2Promotion l2_promotion, TensorMapFloatOOBFill oob_fill); diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 16d08a9dcdf..164ce205e40 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -257,7 +257,7 @@ void ReplayTransformations::runReplay() { // Switch outDomain to a vector to start the traversal std::vector traversal_vals( target_domain_.begin(), target_domain_.end()); - traverseTo(traversal_vals[0]->fusion(), traversal_vals); + traverseTo(traversal_vals); if (error_on_failure_) { NVF_ERROR( @@ -321,9 +321,8 @@ BestEffortReplay::BestEffortReplay( } // Grab expr history of iter domains in target_domain - std::vector target_exprs = StmtSort::getExprsTo( - FusionGuard::getCurFusion(), - std::vector(target_domain.begin(), target_domain.end())); + std::vector target_exprs = + StmtSort::getExprsTo({target_domain.begin(), target_domain.end()}); // If we check how an IterDomain was generated, it should only use an // IterDomain in an expression once. We pull a map from the input @@ -332,9 +331,8 @@ BestEffortReplay::BestEffortReplay( // replay_domain map. // Map replay domain's IterDomains to the Exprs they're used in - std::vector replay_exprs = StmtSort::getExprsTo( - FusionGuard::getCurFusion(), - std::vector(replay_domain.begin(), replay_domain.end())); + std::vector replay_exprs = + StmtSort::getExprsTo({replay_domain.begin(), replay_domain.end()}); // Track which id's in replay have to be replayed to guarantee rfactor // transformations. The iteration domains in the rfactor axes don't have @@ -720,11 +718,8 @@ ForwardingInfo::ForwardingInfo( // We have root axes in active_tv that don't exist in the inactive tensor, // now forward those to include all id's in active_tv comprised of only axes // not in the inactive tensor. - auto active_tv_history = StmtSort::getExprsTo( - FusionGuard::getCurFusion(), - std::vector( - active_tv->domain()->leaf().begin(), - active_tv->domain()->leaf().end())); + auto active_tv_history = StmtSort::getExprsTo(std::vector( + active_tv->domain()->leaf().begin(), active_tv->domain()->leaf().end())); auto isInForwardIdSet = [&forwarded_ids](IterDomain* input_id) { return forwarded_ids.count(input_id) > 0; @@ -870,8 +865,8 @@ void BestEffortReplay::addComplimentLeafIDs( } // Grab all exprs used to make the forwarded compliments - auto compliment_exprs = StmtSort::getExprsTo( - FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()}); + auto compliment_exprs = + StmtSort::getExprsTo({compliments.begin(), compliments.end()}); // Figure out if there are any leaves in compliment_exprs that aren't // the forwarded id diff --git a/csrc/type.cpp b/csrc/type.cpp index 189c6d18717..65b09ccdc3c 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -843,8 +843,10 @@ constexpr unsigned int supported_switch_pair(PrimDataType t1, PrimDataType t2) { return ((unsigned int)t1 << _WORD_SHIFT) + (unsigned int)t2; } -static const char* supported_casts2string( - const std::pair& t) { +static const char* supported_casts2string(std::pair t) { + if (t.first == DataType::SMemAddress) { + t.first = DataType::UInt32; + } switch (supported_switch_pair( std::get(t.first.type), std::get(t.second.type))) { diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 9c1a3473581..6afe6b6cb60 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -134,7 +134,7 @@ def execute( ) msg += ( f"Here's a script to reproduce the error:\n" - "```\n" + "```python\n" "import torch\n" "from nvfuser import FusionDefinition, DataType\n" f"{self}" @@ -158,8 +158,9 @@ def execute( f".as_strided({tuple(i.size())}, {tuple(i.stride())}),\n" ) else: + upper_bound = 2 if i.dtype == torch.bool else 10 msg += ( - f" torch.randint(0, 10, ({sz},), dtype={i.dtype}, device='{i.device}')" + f" torch.randint(0, {upper_bound}, ({sz},), dtype={i.dtype}, device='{i.device}')" f".as_strided({tuple(i.size())}, {tuple(i.stride())}),\n" ) else: diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index 868dec315f5..435402b6890 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -2930,6 +2930,62 @@ def fusion_func(fd: FusionDefinition) -> None: self.assertEqual(nvf_out[1], t16) # T16 == T20 self.assertEqual(nvf_out[2], t31) + def test_issue1393(self): + inputs = [ + torch.randn((5,), dtype=torch.float16, device="cuda:0").as_strided( + (3, 4, 5), (0, 0, 1) + ), + torch.randn((3,), dtype=torch.float16, device="cuda:0").as_strided( + (3, 4), (1, 0) + ), + torch.randn((4,), dtype=torch.float16, device="cuda:0").as_strided( + (3, 4), (0, 1) + ), + ] + + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1, -1], + contiguity=[None, None, True], + dtype=DataType.Half, + is_cpu=False, + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, None], + dtype=DataType.Half, + is_cpu=False, + ) + T2 = fd.define_tensor( + shape=[-1, -1], + contiguity=[None, True], + dtype=DataType.Half, + is_cpu=False, + ) + T3 = fd.ops.cast(T1, dtype=DataType.Float) + T4 = fd.ops.cast(T2, dtype=DataType.Float) + T5 = fd.ops.mul(T3, T4) + T6 = fd.ops.cast(T5, dtype=DataType.Half) + S7 = fd.define_scalar(3, dtype=DataType.Int) + S8 = fd.define_scalar(4, dtype=DataType.Int) + S9 = fd.define_scalar(1, dtype=DataType.Int) + V10 = fd.define_vector([S7, S8, S9], dtype=DataType.Int) + T11 = fd.ops.reshape(T6, new_shape=V10) + S12 = fd.define_scalar(3, dtype=DataType.Int) + S13 = fd.define_scalar(4, dtype=DataType.Int) + S14 = fd.define_scalar(5, dtype=DataType.Int) + V15 = fd.define_vector([S12, S13, S14], dtype=DataType.Int) + T16 = fd.ops.broadcast_in_dim(T11, shape=V15, broadcast_dims=[0, 1, 2]) + T17 = fd.ops.cast(T16, dtype=DataType.Float) + T18 = fd.ops.cast(T0, dtype=DataType.Float) + T19 = fd.ops.mul(T17, T18) + T20 = fd.ops.cast(T19, dtype=DataType.Half) + fd.add_output(T20) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + torch_ref = inputs[0] * (inputs[1] * inputs[2]).unsqueeze(-1) + self.assertEqual(nvf_out[0], torch_ref) + if __name__ == "__main__": run_tests() diff --git a/test/test_alias.cpp b/test/test_alias.cpp index a92850bb092..f99a91ab806 100644 --- a/test/test_alias.cpp +++ b/test/test_alias.cpp @@ -382,12 +382,7 @@ TEST_F(AliasTest, DuplicateOutputs) { at::Tensor expected_out_tensor = in_tensor.add(3.141); // Verify output values. - testValidate( - fec.fusion(), - {expected_out_tensor, expected_out_tensor}, - {in_tensor}, - __LINE__, - __FILE__); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); } TEST_F(AliasTest, SliceToSizeOne_Issue1353) { @@ -529,12 +524,7 @@ TEST_F(AliasTest, DuplicateOutputsSegmentedFusion) { at::Tensor intermediate_tensor = in_tensor.add(3.141); at::Tensor out_tensor = intermediate_tensor.mul(2.0); // Verify output values. - testValidate( - fec.fusion(), - {intermediate_tensor, intermediate_tensor, out_tensor, out_tensor}, - {in_tensor}, - __LINE__, - __FILE__); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); } TEST_F(AliasTest, NotAllOutputsAlias) { @@ -588,4 +578,49 @@ TEST_F(AliasTest, Set_NoAliasForIncompatibleLayout) { EXPECT_FALSE(out_tensor.is_alias_of(in_tensor)); } +// Verifying that duplicated outputs are properly alised +TEST_F(AliasTest, DuplicateOutputsComplex) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({2, 3, 5}); + fusion->addInput(in); + TensorView* out = add(in, IrBuilder::create(5.0)); + fusion->addOutput(out); + // duplicated output + fusion->addOutput(out); + TensorView* out1 = add(in, IrBuilder::create(1.0)); + fusion->addOutput(out1); + // duplicated output + fusion->addOutput(out); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({2, 3, 5}).cuda(); + + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + ASSERT_EQ(out_tensors.size(), 4); + + // Verify aliases among outputs. + EXPECT_TRUE(out_tensors[0].is_alias_of(out_tensors[1])); + EXPECT_FALSE(out_tensors[0].is_alias_of(out_tensors[2])); + EXPECT_TRUE(out_tensors[0].is_alias_of(out_tensors[3])); + + // Verify output values. + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); +} + +// test verifying that duplicated input is not allowed in nvfuser +TEST_F(AliasTest, DuplicateInputs) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* in = makeContigConcreteTensor({2, 3, 5}); + fusion->addInput(in); + + // duplicated input is not allowed + EXPECT_THAT( + [&]() { fusion->addInput(in); }, + testing::ThrowsMessage( + testing::HasSubstr("duplicated inputs is not allowed"))); +} + } // namespace nvfuser diff --git a/test/test_allocation_domain.cpp b/test/test_allocation_domain.cpp index b758d9b7f9d..5493ca7f067 100644 --- a/test/test_allocation_domain.cpp +++ b/test/test_allocation_domain.cpp @@ -1299,12 +1299,16 @@ TEST_F(AllocationDomainTest, Issue1290_ReplayCasPFailedDueToDifferentRanks) { EXPECT_THAT(out_tensor.sizes(), ElementsAre(2)); } +// This test is meant to verify that trivial stride order is dropped by +// TensorViewBuilder. See issue: https://github.com/NVIDIA/Fuser/issues/1399 TEST_F(AllocationDomainTest, TrivialStrideOrderTensorViewBuilder) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = TensorViewBuilder().ndims(2).strideOrder({0, 1}).build(); EXPECT_TRUE(tv0->hasAllocation()); + // trivial stride order would be dropped by TensorViewbuilder tv0 = TensorViewBuilder().ndims(2).strideOrder({1, 0}).build(); + // confirming that stride order is dropped and allocation domain is empty EXPECT_TRUE(!tv0->hasAllocation()); } diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 678b7dc1bc4..b1988e88198 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -9152,10 +9152,11 @@ TEST_F(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { auto persistent_buffer_size = persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); + // T1 and T2 are persistent buffers, but T2 can be projected to T1. + // So, the actual buffer size is just the size to save T1. NVF_ERROR( persistent_buffer_size.persistent_buffer_size == - static_cast( - aten_t0.size(1) * dataTypeSize(DataType::Float) * 2)); + static_cast(aten_t0.size(1) * dataTypeSize(DataType::Float))); NVF_ERROR( persistent_buffer_size.projected_persistent_buffer_size == diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index a2677aa0f78..cc1ad6603d6 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -5991,7 +5991,7 @@ TEST_F(NVFuserTest, FusionPropagateVectorizePredicate_CUDA) { // Make sure the index of the inner loop isn't used in the predicate NVF_ERROR(!for_loops_.empty()); auto loop_index = for_loops_.back()->index(); - auto cond_inputs = InputsOf::output(cond->fusion(), cond); + auto cond_inputs = InputsOf::output(cond); auto index_it = std::find(cond_inputs.begin(), cond_inputs.end(), loop_index); auto vec_factor_it = @@ -7241,35 +7241,6 @@ TEST_F(NVFuserTest, FusionBFloat16Scalars_CUDA) { } #endif -// Quick test of traversing attributes with IterVisitor -TEST_F(NVFuserTest, IterVisitorTraverseAttributes_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = slice( - tv0, - {{IrBuilder::create(1L), - sub(tv0->axis(0)->extent(), IrBuilder::create(1L))}}); - fusion.addOutput(tv1); - - auto tv1_resize = tv1->axis(0)->definition()->as(); - - auto stmts = StmtSort::getStmts(&fusion, true, true); - - // Make sure the expansion parameters of tv1_resize are visited - NVF_CHECK( - std::find(stmts.begin(), stmts.end(), tv1_resize->leftExpand()) != - stmts.end(), - "Resize left expand parameter not found"); - NVF_CHECK( - std::find(stmts.begin(), stmts.end(), tv1_resize->rightExpand()) != - stmts.end(), - "Resize right expand parameter not found"); -} - TEST_F(NVFuserTest, FusionManagedData_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8088,8 +8059,8 @@ TEST_F(NVFuserTest, FusionLayerNormFusedOpsRedundantCast_CUDA) { const float kEps = 1e-5; const int batch_size = 2048 * 8; const int hidden_size = 20480; + DataType dtype = DataType::Half; { - DataType dtype = DataType::Half; auto tv0 = makeContigTensor(1, dtype); auto tv1 = makeContigTensor(2, dtype); auto tv2 = makeContigTensor(1, dtype); @@ -8171,16 +8142,20 @@ TEST_F(NVFuserTest, FusionLayerNormFusedOpsRedundantCast_CUDA) { outputs.emplace_back(t33); } - auto persistent_buffer_info1 = scheduler_utils::persistentBuffers(fusion); + auto persistent_buffer_info = scheduler_utils::persistentBuffers(fusion); NVF_CHECK( - persistent_buffer_info1.persistent_buffers.size() == 2, + persistent_buffer_info.persistent_buffers.size() == 2, "Before project to other buffers, should have two persistent buffers!"); - reduction_scheduler_utils::projectPersistentBuffers(fusion, false); - auto persistent_buffer_info2 = scheduler_utils::persistentBuffers(fusion); + // The buffer size should only count 1 buffer because the other one is + // projected to its producer. + SchedulerRuntimeInfo runtime_info(fusion, inputs); + auto persistent_buffer_size = + persistentBufferSize(fusion, runtime_info, persistent_buffer_info); NVF_CHECK( - persistent_buffer_info2.persistent_buffers.size() == 1, - "After project to other buffers, should have one persistent buffer!"); + persistent_buffer_size.persistent_buffer_size == + hidden_size * dataTypeSize(dtype), + "Persistent buffer size is not correct!"); FusionExecutorCache fec(std::move(fusion_ptr)); auto cg_outputs = fec.runFusionWithInputs(inputs); @@ -8578,50 +8553,6 @@ TEST_F(NVFuserTest, FusionDanglingUnaryOp_CUDA) { __FILE__); } -// Test that traversing siblings with IterVisitor visits "orphans", i.e. unused -// outputs of multi-output Exprs. -TEST_F(NVFuserTest, IterVisitorTraverseSiblings_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto wf = Welford(tv0, {0}); - // wf.var_sum is used, but wf.avg and wf.n are orphaned - auto tv1 = neg(wf.var_sum); - fusion.addOutput(tv1); - - auto stmts = StmtSort::getStmts( - &fusion, - /*traverse_all_paths*/ false, - /*traverse_attributes*/ false, - /*traverse_siblings*/ true); - - // Make sure the expansion parameters of tv1_resize are visited - NVF_CHECK( - std::find(stmts.begin(), stmts.end(), wf.avg) != stmts.end(), - "Welford avg not traversed"); - NVF_CHECK( - std::find(stmts.begin(), stmts.end(), wf.n) != stmts.end(), - "Welford n not traversed"); - - // Test getting statements "to" a tensor with siblings - stmts = StmtSort::getStmtsTo( - &fusion, - {wf.n}, - /*traverse_all_paths*/ false, - /*traverse_attributes*/ false, - /*traverse_siblings*/ true); - // Make sure the expansion parameters of tv1_resize are visited - NVF_CHECK( - std::find(stmts.begin(), stmts.end(), wf.avg) != stmts.end(), - "Welford avg not traversed in getStmtsTo({n})"); - NVF_CHECK( - std::find(stmts.begin(), stmts.end(), wf.var_sum) != stmts.end(), - "Welford var_sum not traversed in getStmtsTo({n})"); -} - TEST_F(NVFuserTest, FusionLayerNormSharedMemoryBuffer_CUDA) { auto test = [](const int64_t hidden_size, DataType dtype) { std::unique_ptr fusion_ptr = std::make_unique(); @@ -8694,31 +8625,6 @@ TEST_F(NVFuserTest, FusionLayerNormSharedMemoryBuffer_CUDA) { } } -TEST_F(NVFuserTest, IterVisitorGetInputsTo) { - // Test that IterVisitor::getInputsTo() will stop further traverse when - // reaching the target tensors - Fusion fusion; - FusionGuard fg(&fusion); - - auto a = makeSymbolicTensor(1); - auto b = makeSymbolicTensor(1); - auto c = makeSymbolicTensor(1); - - fusion.addInput(a); - fusion.addInput(b); - fusion.addInput(c); - - auto d = add(b, c); - auto e = add(a, d); - - fusion.addOutput(e); - - auto inputs = IterVisitor::getInputsTo({e}, {a, d}); - std::unordered_set inputs_set(inputs.begin(), inputs.end()); - - EXPECT_EQ(inputs_set, std::unordered_set({a, d})); -} - // converted from https://github.com/NVIDIA/Fuser/issues/443 TEST_F(NVFuserTest, FusionInstanceNormNHWC_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); @@ -9374,6 +9280,91 @@ TEST_F(NVFuserTest, LoweringHook) { EXPECT_TRUE(executed); } +TEST_F(NVFuserTest, ProjectPersistentBufferMultiScopes) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + const int batch_size = 2048; + const int hidden_size = 10240; + DataType input_dtype = DataType::Float; + auto tv0 = makeContigTensor(2, input_dtype); + auto tv1 = makeContigTensor(2, input_dtype); + auto tv2 = makeContigTensor(2, input_dtype); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + auto tv3 = add(tv0, tv0); + auto tv4 = sum(tv3, {1}); + auto tv5 = broadcast(tv4, {false, true}); + auto tv6 = add(tv3, tv5); + + auto tv7 = add(tv3, tv3); + auto tv8 = sum(tv7, {1}); + auto tv9 = broadcast(tv8, {false, true}); + auto tv10 = add(tv7, tv9); + + auto tv11 = add(tv0, tv1); + auto tv12 = mul(tv11, tv11); + auto tv13 = sum(tv12, {1}); + auto tv14 = broadcast(tv13, {false, true}); + auto tv15 = add(tv12, tv14); + + auto tv16 = add(tv12, tv2); + auto tv17 = mul(tv16, tv16); + auto tv18 = sum(tv17, {1}); + auto tv19 = broadcast(tv18, {false, true}); + auto tv20 = add(tv17, tv19); + + fusion->addOutput(tv6); + fusion->addOutput(tv10); + fusion->addOutput(tv15); + fusion->addOutput(tv20); + + auto options = at::TensorOptions() + .dtype(data_type_to_aten(input_dtype)) + .device(at::kCUDA, 0); + auto t0 = at::randn({batch_size, hidden_size}, options); + auto t1 = at::randn({batch_size, hidden_size}, options); + auto t2 = at::randn({batch_size, hidden_size}, options); + std::vector inputs{t0, t1, t2}; + + // The persistent buffers in this fusion are: tv3, tv7, tv12, and tv17. Note + // that tv7 can be projected back to its producer, tv3. When calculating the + // total size of persistent buffers ([persistent_buffer_size]), it's important + // to consider the active scopes of these buffers. Simply subtracting the + // buffer size of tv7 from the max buffer size may lead to an underestimation. + // This is because there are two distinct scopes in this computation: (1) + // During the calculation of tv10, the active persistent buffers are tv3 and + // tv7. (2) For the calculation of tv20, the active persistent buffers are + // tv12 and tv17. The max buffer size is based on tv12 and tv17. There is no + // projectable buffer needs to be deducted in this scope. + auto persistent_info = scheduler_utils::persistentBuffers(fusion); + SchedulerRuntimeInfo runtime_info(fusion, inputs); + auto persistent_buffer_size = + persistentBufferSize(fusion, runtime_info, persistent_info); + auto calculated_size = persistent_buffer_size.persistent_buffer_size; + auto expected_size = + static_cast(hidden_size * 2 * dataTypeSize(input_dtype)); + NVF_CHECK( + calculated_size == expected_size, + "Buffer size calculation failure. Expected size: ", + expected_size, + ". Actual: ", + calculated_size); + auto persistent_params = getInnerPersistentHeuristics(fusion, inputs); + NVF_CHECK(persistent_params, "Reduction schedule was not generated!"); + NVF_CHECK( + !persistent_params->project_persistent_buffers, + "Shouldn't project persistent buffers to inputs!"); + + scheduleInnerPersistentKernel(fusion, *persistent_params); + FusionExecutor fe; + fe.compileFusion(fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); +} // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser diff --git a/test/test_gpu_utils.cpp b/test/test_gpu_utils.cpp index 0b71103dc32..0a89af5ce75 100644 --- a/test/test_gpu_utils.cpp +++ b/test/test_gpu_utils.cpp @@ -985,7 +985,7 @@ TEST_F(VectorizeHelperTest, SpanningTree_CUDA) { } auto fusion_inps = fusion.inputs(); for (auto inp : fusion_inps) { - fusion.removeOutput(inp); + fusion.removeInput(inp); } } diff --git a/test/test_gpu_view.cpp b/test/test_gpu_view.cpp index f4b334c769e..188868b316b 100644 --- a/test/test_gpu_view.cpp +++ b/test/test_gpu_view.cpp @@ -1191,12 +1191,7 @@ TEST_F(GpuViewTest, FusionReshapeVectorize) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - // This test allocates about 1GB of memory, so in order to avoid an OOM during - // this test, we manually clear the allocator after it's reached a certain - // threshold. - maybeClearAllocator(); - - at::Tensor input = at::randn({256, 1024, 1024}, options); + at::Tensor input = at::randn({256, 256, 256}, options); auto lparams = schedulePointwise(&fusion, {input}); diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp new file mode 100644 index 00000000000..fc823316564 --- /dev/null +++ b/test/test_id_model.cpp @@ -0,0 +1,40 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace nvfuser { + +class IdModelTest : public NVFuserTest {}; + +TEST_F(IdModelTest, DetectSelfMapping) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 2}); + fusion.addInput(tv0); + auto tv1 = transpose(tv0, 0, 1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + EXPECT_THAT( + [&]() { IdModel id_model(&fusion); }, + ::testing::ThrowsMessage( + ::testing::HasSubstr("!hasSelfMapping"))); +} + +} // namespace nvfuser diff --git a/test/test_iter_visitor.cpp b/test/test_iter_visitor.cpp new file mode 100644 index 00000000000..a0e5eacd2c4 --- /dev/null +++ b/test/test_iter_visitor.cpp @@ -0,0 +1,121 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include + +#include +#include +#include +#include + +namespace nvfuser { + +using IterVisitorTest = NVFuserTest; +using testing::Contains; +using testing::IsSupersetOf; +using testing::UnorderedElementsAre; + +// Quick test of traversing attributes with IterVisitor +TEST_F(IterVisitorTest, IterVisitorTraverseAttributes) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = slice( + tv0, + {{IrBuilder::create(1L), + sub(tv0->axis(0)->extent(), IrBuilder::create(1L))}}); + fusion.addOutput(tv1); + + auto tv1_resize = tv1->axis(0)->definition()->as(); + + auto stmts = StmtSort::getStmts(&fusion, true, true); + + // Make sure the expansion parameters of tv1_resize are visited + EXPECT_THAT(stmts, Contains(tv1_resize->leftExpand())) << "Resize left expand parameter not found"; + EXPECT_THAT(stmts, Contains(tv1_resize->rightExpand())) << "Resize right expand parameter not found"; +} + +// Test that traversing siblings with IterVisitor visits "orphans", i.e. unused +// outputs of multi-output Exprs. +TEST_F(IterVisitorTest, IterVisitorTraverseSiblings) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto wf = Welford(tv0, {0}); + // wf.var_sum is used, but wf.avg and wf.n are orphaned + auto tv1 = neg(wf.var_sum); + fusion.addOutput(tv1); + + auto stmts = StmtSort::getStmts( + &fusion, + /*traverse_all_paths*/ false, + /*traverse_attributes*/ false, + /*traverse_siblings*/ true); + + EXPECT_THAT(stmts, Contains(wf.avg)) << "Welford avg not traversed"; + EXPECT_THAT(stmts, Contains(wf.n)) << "Welford n not traversed"; + + // Test getting statements "to" a tensor with siblings + stmts = StmtSort::getStmtsTo( + {wf.n}, + /*traverse_all_paths=*/false, + /*traverse_attributes=*/false, + /*traverse_siblings=*/true); + EXPECT_THAT(stmts, Contains(wf.avg)) << "Welford avg not traversed in getStmtsTo({n})"; + EXPECT_THAT(stmts, Contains(wf.var_sum)) << "Welford var_sum not traversed in getStmtsTo({n})"; +} + +TEST_F(IterVisitorTest, IterVisitorGetInputsTo) { + // Test that IterVisitor::getInputsTo() will stop further traverse when + // reaching the target tensors + Fusion fusion; + FusionGuard fg(&fusion); + + auto a = makeSymbolicTensor(1); + auto b = makeSymbolicTensor(1); + auto c = makeSymbolicTensor(1); + + fusion.addInput(a); + fusion.addInput(b); + fusion.addInput(c); + + auto d = add(b, c); + auto e = add(a, d); + + fusion.addOutput(e); + + std::vector inputs = IterVisitor::getInputsTo({e}, {a, d}); + EXPECT_THAT(inputs, UnorderedElementsAre(a, d)); +} + +TEST_F(IterVisitorTest, NonTerminatingOutput) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* a = makeSymbolicTensor(1); + TensorView* b = set(a); + TensorView* c = set(b); + TensorView* d = set(c); + TensorView* e = set(d); + + fusion.addInput(a); + fusion.addOutput(c); + fusion.addOutput(e); + + // Even though `c` is a non-terminating output, `d` and `e` should still be + // considered in between. This is because `StmtSort::getExprsBetween` + // traverses from `to` along use-def chains until it hits `from`. + EXPECT_THAT(StmtSort::getExprsBetween({a}, {c, e}), IsSupersetOf({d->definition(), e->definition()})); +} + +} // namespace nvfuser diff --git a/test/test_mma.cpp b/test/test_mma.cpp index ea646a559b0..4b6e3e12425 100644 --- a/test/test_mma.cpp +++ b/test/test_mma.cpp @@ -109,29 +109,22 @@ TEST_P(MmaTest, SingleTile) { tv2c->applyMmaSwizzle(MmaOperand::Accumulator); tv2->applyMmaSwizzle(MmaOperand::Accumulator); - auto options = - at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); - auto t0 = at::randn(A_shape, options); - auto t1 = at::randn(B_shape, options); + auto inputs = matmulAtInput( + getM(macro), getN(macro), getK(macro), layout, data_type_to_aten(dtype)); FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams); - - auto cg_outputs = fe.runFusion({t0, t1}); - - at::Tensor t0t = t0, t1t = t1; - - if (transpose_a) { - t0t = t0.t(); - } - - if (!transpose_b) { - t1t = t1.t(); - } - - auto tref = t0t.to(at::kFloat).matmul(t1t.to(at::kFloat)); - - testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); + fe.compileFusion( + &fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + testValidate( + &fusion, + cg_outputs, + {inputs.first, inputs.second}, + {tref}, + __LINE__, + __FILE__); } auto all_mma_layouts = diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index e81e676f911..b4c961e1f76 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -75,9 +75,7 @@ TEST_F(NVFuserTest, FusionCyclicGraph_CUDA) { ir_utils::checkCycle(fusion.get()).size() == 6, "cycle of size 6 should be detected in fusion"); EXPECT_THAT( - [&]() { - StmtSort::getStmtsBetween(fusion.get(), {}, fusion->outputs()); - }, + [&]() { StmtSort::getStmtsBetween({}, fusion->outputs()); }, ::testing::ThrowsMessage( ::testing::HasSubstr("cycle detected"))); } @@ -115,7 +113,7 @@ TEST_F(NVFuserTest, FusionCyclicGraph_CUDA) { to.push_back(tv1); // cycle should be detected, since dead branch is in our check path EXPECT_THAT( - [&]() { StmtSort::getStmtsBetween(fusion.get(), {}, to); }, + [&]() { StmtSort::getStmtsBetween({}, to); }, ::testing::ThrowsMessage( ::testing::HasSubstr("cycle detected"))); diff --git a/test/test_resize.cpp b/test/test_resize.cpp index ebbd68eb10e..5f0b782c8d6 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -3073,4 +3073,46 @@ TEST_F(ResizeTest, PadOfExpandedBroadcast) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } +TEST_F(NVFuserTest, dynamicReshapeIssue1393) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = TensorViewBuilder() + .ndims(2) + .shape({-1, -1}) + .contiguity({true, std::nullopt}) + .expanded({false, true}) + .build(); + auto tv1 = TensorViewBuilder() + .ndims(2) + .shape({-1, -1}) + .contiguity({std::nullopt, true}) + .expanded({true, false}) + .build(); + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = add(tv0, tv1); + auto s0 = IrBuilder::create(3); + auto s1 = IrBuilder::create(4); + auto s2 = IrBuilder::create(1); + auto s3 = IrBuilder::create(5); + auto tv3 = reshape(tv2, {s0, s1, s2}); + auto tv4 = expand(tv3, {s0, s1, s3}); + fusion->addOutput(tv4); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({3}, options).as_strided({3, 4}, {1, 0}); + at::Tensor t1 = at::randn({4}, options).as_strided({3, 4}, {0, 1}); + auto ref = t0.add(t1).as_strided({3, 4, 5}, {4, 1, 0}); + + std::vector aten_inputs({t0, t1}); + auto outputs = fec.runFusionWithInputs(aten_inputs); + + testValidate(fusion, outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/test/test_swizzle.cpp b/test/test_swizzle.cpp index a6d70bd26bb..5a1c1281f9f 100644 --- a/test/test_swizzle.cpp +++ b/test/test_swizzle.cpp @@ -644,7 +644,6 @@ TEST_F(SwizzleTest, TransformPropagatorSkipSwizzleOnTarget) { MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); auto exprs = StmtSort::getExprsBetween( - tv1->fusion(), {tv1->getRootDomain().begin(), tv1->getRootDomain().end()}, {tv1->getLeafDomain().begin(), tv1->getLeafDomain().end()}); EXPECT_TRUE(std::any_of(exprs.begin(), exprs.end(), [](Expr* expr) { From 14ab237a0f13ee51cb609d711b19adcabd7ecbb1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 6 Dec 2023 12:37:15 -0800 Subject: [PATCH 091/178] IdModel: cleanup almost exact (#1457) Refactoring the almost exact graph construction. - The main logic is moved from ValGraph to IdModel - Extended the validator to support the almost exact graph --- csrc/device_lower/lower2device.cpp | 2 +- csrc/id_model/id_model.cpp | 88 +++++++++++- csrc/id_model/validation_utils.cpp | 207 +++++++++++++++++++---------- csrc/id_model/validation_utils.h | 15 ++- csrc/val_graph.cpp | 54 -------- csrc/val_graph.h | 8 -- 6 files changed, 236 insertions(+), 138 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 89b79c60d46..87d5a32c772 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -381,7 +381,7 @@ void GpuLower::analysis(Fusion* fusion) { // functionality should be affected. New IterDomains may be created, // so it is expected that generated code may use diffrent variable // names - if (isOptionEnabled(EnableOption::IdModel)) { + if (true || isOptionEnabled(EnableOption::IdModel)) { IdModel id_model(fusion_, false, true); } diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index dac66d76aaf..ef651694ef9 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -761,10 +762,75 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { mapThroughLoopSwizzles(idGraph(IdMappingMode::PERMISSIVE)); } +namespace { + +// Checks if the expression is a trivial operation where an input is simply an +// output of the transformation. Returns the mapped iter domains if found. +std::vector> isTrivialExpr(Expr* expr) { + std::vector> mapped_ids; + if (auto merge = dynamic_cast(expr)) { + if (merge->inner()->extent()->isOneInt()) { + mapped_ids.push_back({merge->outer(), merge->out()}); + } + if (merge->outer()->extent()->isOneInt()) { + mapped_ids.push_back({merge->inner(), merge->out()}); + } + } else if (auto split = dynamic_cast(expr)) { + if (split->factor()->isOneInt() && split->startOffset()->isZeroInt() && + split->stopOffset()->isZeroInt()) { + if (split->innerSplit()) { + mapped_ids.push_back({split->in(), split->outer()}); + } else { + mapped_ids.push_back({split->in(), split->inner()}); + } + } + } else if (auto swizzle = dynamic_cast(expr)) { + if (swizzle->swizzleType() == Swizzle2DType::NoSwizzle || + swizzle->swizzleMode() == SwizzleMode::NoSwizzle) { + mapped_ids.push_back({swizzle->inX(), swizzle->outX()}); + mapped_ids.push_back({swizzle->inY(), swizzle->outY()}); + } + } + return mapped_ids; +} + +} // namespace + void IdModel::buildAlmostExactMap() { // Build almost exact map by forwarding through broadcast axes idGraph(IdMappingMode::ALMOSTEXACT) = idGraph(IdMappingMode::EXACT); - idGraph(IdMappingMode::ALMOSTEXACT).mapThroughTrivialExprs(); + + auto& almost_exact_graph = idGraph(IdMappingMode::ALMOSTEXACT); + + // Maps iter domain pairs returned by calling that return mappings from + // isTrivialExpr on every expression in the graph. + + // Don't traverse the graph and at the same time add more mappings + // as the traversal would be invalidated + std::vector> ids_to_map; + + for (const auto& expr_group : + almost_exact_graph.disjointExprSets().disjointSets()) { + for (auto expr : *expr_group) { + // If not trivial continue + auto mapped_ids = isTrivialExpr(expr); + if (mapped_ids.empty()) { + continue; + } + + // Map through trivial expressions + for (auto mapped_id_group : mapped_ids) { + for (auto id : mapped_id_group) { + // almost_exact_graph.mapVals(mapped_id_group.front(), id); + ids_to_map.emplace_back(mapped_id_group.front(), id); + } + } + } + } + + for (const auto& [id1, id2] : ids_to_map) { + almost_exact_graph.mapVals(id1, id2); + } } // TODO: Reenable after reenabling parallel propagation. @@ -982,6 +1048,16 @@ void IdModel::build( return; } + std::unique_ptr validator; + + // A ComputeAtMap will be built inside the constructor of + // IdModelValidator, which may fail for some fusions that are not + // supported currently (but work with IdModel). Make sure the + // validator is only created when it is indeed requested + if (validate) { + validator = std::make_unique(all_tvs.front()->fusion()); + } + FusionGuard fg(all_tvs.front()->fusion()); // Add uses and definitions to all iter domains. buildIterDomainDefinitionsAndUses(all_tvs.vector()); @@ -991,16 +1067,16 @@ void IdModel::build( idGraph(IdMappingMode::EXACT) = initializeIdGraph(); buildExactGraph(tv_exprs); - if (validate) { - IdModelValidator::checkExactGraphEquivalence(idGraph(IdMappingMode::EXACT)); + validator->checkExactGraphEquivalence(idGraph(IdMappingMode::EXACT)); } - if (getenv("EXACT_ONLY")) { - return; + buildAlmostExactMap(); + if (validate) { + validator->checkAlmostExactGraphEquivalence( + idGraph(IdMappingMode::ALMOSTEXACT)); } - buildAlmostExactMap(); buildPermissiveMap(tv_exprs); // Permissive graph needs the trivial exprs from the almost exact graph to diff --git a/csrc/id_model/validation_utils.cpp b/csrc/id_model/validation_utils.cpp index f996180591c..9fd53aeac78 100644 --- a/csrc/id_model/validation_utils.cpp +++ b/csrc/id_model/validation_utils.cpp @@ -8,106 +8,132 @@ #include #include #include +#include +#include #include namespace nvfuser { -void IdModelValidator::checkExactGraphEquivalence(const ValGraph& exact_graph) { - // Empty graph - if (exact_graph.disjointValSets().disjointSets().empty()) { - return; - } - - auto all_exprs = exact_graph.disjointExprSets().getAllElements(); - if (std::find_if(all_exprs.begin(), all_exprs.end(), [](Expr* expr) { - return expr->isA(); - }) != all_exprs.end()) { - // Ignoring a fusion with swizzle - return; +IdModelValidator::IdModelValidator(Fusion* fusion) : ca_map_(fusion) { + for (auto tv : ir_utils::allTvs(fusion)) { + for (auto id : ir_utils::allIDsOf(tv)) { + if (id->definition() && id->definition()->isA()) { + has_swizzle_ = true; + break; + } + } } +} - Fusion* fusion = exact_graph.disjointValSets() - .disjointSets() - .at(0) - ->vector() - .at(0) - ->fusion(); - ComputeAtMap ca_map(fusion); - - DisjointSets& ca_map_exact_sets = ca_map.id_graph_.exact_nodes_; - - // Propgate mappings through expressions in ComputeAtMap. Since we - // want to traverse and update ca_map_exact_sets, once updated, the - // traversal of the ID groups cannot continue and needs to be - // restarted. The algorithm seems terriblly inefficient, but - // shuldn't matter as this is just for transitory validations - bool updated = true; - while (updated) { - updated = false; - for (const auto& set : ca_map_exact_sets.disjointSets()) { - auto uses = ca_map.uniqueExactUses(set->vector().front()); - auto use_count = uses.size(); - // Note that it should be fine to continue updating the map with - // the loop below as it should only modify output domain groups - for (size_t i = 0; i < use_count; ++i) { - auto use_i = uses.at(i); - for (size_t j = i + 1; j < use_count; ++j) { - auto use_j = uses.at(j); - if (!IterDomainGraph::exprsMap( - use_i, use_j, true, ca_map_exact_sets)) { - continue; +void IdModelValidator::fullyPropagateMappings( + DisjointSets& id_sets) { + // This algorithm seems terriblly inefficient but shuldn't matter as + // this is just for transitory validations + while (true) { + // Grab all pairs of domains to map + std::vector> ids_to_map; + for (const auto& set : id_sets.disjointSets()) { + // Propagate both forward and backward + for (bool is_forward : {true, false}) { + // Grab all use exprs of this ID set + std::vector all_exprs; + for (auto id : *set) { + // In the case of forward propagation, the uses() exprs may + // not be actually used for IterDomain. Make sure to pick + // only those whose outputs are in the map + if (is_forward) { + for (auto use : id->uses()) { + if (std::all_of( + use->outputs().begin(), + use->outputs().end(), + [&](Val* output) { + return output->isA() && + id_sets.mappingExists(output->as()); + })) { + all_exprs.push_back(use); + } + } + } else { + all_exprs.push_back(id->definition()); } - auto num_outputs = use_i->outputs().size(); - NVF_ERROR(use_j->outputs().size() == num_outputs); - for (size_t output_i = 0; output_i < num_outputs; ++output_i) { - auto out_i = use_i->output(output_i)->as(); - auto out_j = use_j->output(output_i)->as(); - if (!ca_map_exact_sets.strictAreMapped(out_i, out_j)) { - ca_map_exact_sets.mapEntries(out_i, out_j); - updated = true; + } + + // Look at all combinatorial pairs of the uses of + // definitions. If they are mapped, i.e., their input or + // output domains are mapped and the expr + // properties are equivalent, map the outputs or inputs as well + auto count = all_exprs.size(); + for (size_t i = 0; i < count; ++i) { + auto expr_i = all_exprs.at(i); + for (size_t j = i + 1; j < count; ++j) { + auto expr_j = all_exprs.at(j); + if (!IterDomainGraph::exprsMap( + expr_i, expr_j, is_forward, id_sets)) { + continue; + } + const auto& prop_target_i = + is_forward ? expr_i->outputs() : expr_i->inputs(); + const auto& prop_target_j = + is_forward ? expr_j->outputs() : expr_j->inputs(); + auto num_target = prop_target_i.size(); + NVF_ERROR(num_target == prop_target_j.size()); + for (size_t target_i = 0; target_i < num_target; ++target_i) { + auto id_i = prop_target_i.at(target_i)->as(); + auto id_j = prop_target_j.at(target_i)->as(); + if (!id_sets.strictAreMapped(id_i, id_j)) { + // Don't actually map them yet as it would invalidate + // the loop over id_sets + ids_to_map.emplace_back(id_i, id_j); + } } } } } - // If updated, the previous sets returned by - // ca_map_exact_sets.disjointSets() may contain stale sets - if (updated) { - ca_map.build(fusion); - break; - } + } + + // No additional domains to map. Nothing to do further + if (ids_to_map.empty()) { + return; + } + + for (const auto& [id1, id2] : ids_to_map) { + id_sets.mapEntries(id1, id2); } } +} - const DisjointSets& id_model_exact_sets = exact_graph.disjointValSets(); +namespace { - if (id_model_exact_sets.size() != ca_map_exact_sets.size()) { +void compareDisjointSets( + const DisjointSets& ca_map_sets, + const DisjointSets& id_model_sets) { + if (id_model_sets.size() != ca_map_sets.size()) { std::stringstream ss; - ss << "Mismatched number of groups: " << id_model_exact_sets.size() << ", " - << ca_map_exact_sets.size() << "\n"; + ss << "Mismatched number of groups: " << id_model_sets.size() << ", " + << ca_map_sets.size() << "\n"; - ss << "IdModel exact sets:\n"; - for (const auto& id_set : id_model_exact_sets.disjointSets()) { + ss << "IdModel sets:\n"; + for (const auto& id_set : id_model_sets.disjointSets()) { ss << "\t" << nvfuser::toString(id_set->vector()) << "\n"; } - ss << "ComputeAtMap exact sets:\n"; - for (const auto& id_set : ca_map_exact_sets.disjointSets()) { + ss << "ComputeAtMap sets:\n"; + for (const auto& id_set : ca_map_sets.disjointSets()) { ss << "\t" << nvfuser::toString(id_set->vector()) << "\n"; } NVF_ERROR(false, ss.str()); } - for (const auto& id_model_id_set : id_model_exact_sets.disjointSets()) { + for (const auto& id_model_id_set : id_model_sets.disjointSets()) { NVF_ERROR(!id_model_id_set->empty()); NVF_ERROR( - ca_map_exact_sets.mappingExists( - id_model_id_set->front()->as()), + ca_map_sets.mappingExists(id_model_id_set->front()->as()), "Not found in ComputeAtMap: ", id_model_id_set->front()->toString()); - const auto& ca_map_id_set = ca_map_exact_sets.getDisjointSetOf( + const auto& ca_map_id_set = ca_map_sets.getDisjointSetOf( id_model_id_set->front()->as()); std::unordered_set ca_map_id_set_cast; @@ -125,4 +151,49 @@ void IdModelValidator::checkExactGraphEquivalence(const ValGraph& exact_graph) { } } +} // namespace + +void IdModelValidator::checkExactGraphEquivalence(const ValGraph& exact_graph) { + if (has_swizzle_) { + // Ignoring a fusion with swizzle + return; + } + + // Empty graph + if (exact_graph.disjointValSets().disjointSets().empty()) { + return; + } + + DisjointSets ca_map_sets = ca_map_.id_graph_.exact_nodes_; + + // IdModel propagates mappings forward and backward more + // consistently, which is not the case with ComputeAt. To compare + // the two mappings, augment the ComputeAt mappings with the same + // propagation. This might potentially hide some subtle differences + // between the two mappings, but I think this is still a reasonable + // way to validate IdModel + fullyPropagateMappings(ca_map_sets); + + compareDisjointSets(ca_map_sets, exact_graph.disjointValSets()); +} + +void IdModelValidator::checkAlmostExactGraphEquivalence( + const ValGraph& almost_exact_graph) { + if (has_swizzle_) { + // Ignoring a fusion with swizzle + return; + } + + // Empty graph + if (almost_exact_graph.disjointValSets().disjointSets().empty()) { + return; + } + + DisjointSets ca_map_sets = ca_map_.id_graph_.almost_exact_nodes_; + + fullyPropagateMappings(ca_map_sets); + + compareDisjointSets(ca_map_sets, almost_exact_graph.disjointValSets()); +} + } // namespace nvfuser diff --git a/csrc/id_model/validation_utils.h b/csrc/id_model/validation_utils.h index 647fe1ba925..530ab3f1b01 100644 --- a/csrc/id_model/validation_utils.h +++ b/csrc/id_model/validation_utils.h @@ -17,6 +17,8 @@ namespace nvfuser { // have private access class IdModelValidator { public: + IdModelValidator(Fusion* fusion); + // Validate a given exact graph of IdModel by comparing it with // ComputeAtMap. Their maps should // be almost the same but there are some differences. @@ -34,7 +36,18 @@ class IdModelValidator { // swizzle is used we give up validating the exact graph. The second // difference is whether mappings are propagated, which can be // accounted for by updating the ComputeAtMap as is done in IdModel. - static void checkExactGraphEquivalence(const ValGraph& exact_graph); + void checkExactGraphEquivalence(const ValGraph& exact_graph); + + void checkAlmostExactGraphEquivalence(const ValGraph& almost_exact_graph); + + private: + // Propagate mappings in a ComputeAtMap as is done in IdModel + static void fullyPropagateMappings(DisjointSets& id_sets); + + private: + ComputeAtMap ca_map_; + // Validation is not enabled if swizzle is found. See the comment above + bool has_swizzle_ = false; }; } // namespace nvfuser diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index d1626a66124..1a33dbef61d 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -501,34 +501,6 @@ std::string ValGraph::toString() const { return ss.str(); } -std::vector> ValGraph::isTrivialExpr(Expr* expr) { - std::vector> mapped_ids; - if (auto merge = dynamic_cast(expr)) { - if (merge->inner()->extent()->isOneInt()) { - mapped_ids.push_back({merge->outer(), merge->out()}); - } - if (merge->outer()->extent()->isOneInt()) { - mapped_ids.push_back({merge->inner(), merge->out()}); - } - } else if (auto split = dynamic_cast(expr)) { - if (split->factor()->isOneInt() && split->startOffset()->isZeroInt() && - split->stopOffset()->isZeroInt()) { - if (split->innerSplit()) { - mapped_ids.push_back({split->in(), split->outer()}); - } else { - mapped_ids.push_back({split->in(), split->inner()}); - } - } - } else if (auto swizzle = dynamic_cast(expr)) { - if (swizzle->swizzleType() == Swizzle2DType::NoSwizzle || - swizzle->swizzleMode() == SwizzleMode::NoSwizzle) { - mapped_ids.push_back({swizzle->inX(), swizzle->outX()}); - mapped_ids.push_back({swizzle->inY(), swizzle->outY()}); - } - } - return mapped_ids; -} - bool ValGraph::transformAtributesMatch(Expr* first, Expr* second) { if (first == nullptr || second == nullptr) { return false; @@ -869,32 +841,6 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return true; } -void ValGraph::mapThroughTrivialExprs() { - // Grab all expressions - std::vector exprs; - - for (const auto& expr_group : disjointExprSets().disjointSets()) { - for (auto expr : *expr_group) { - exprs.push_back(expr); - } - } - - for (auto expr : exprs) { - // If not trivial continue - auto mapped_ids = ValGraph::isTrivialExpr(expr); - if (mapped_ids.empty()) { - continue; - } - - // Map through trivial expressions - for (auto mapped_id_group : mapped_ids) { - for (auto id : mapped_id_group) { - mapVals(mapped_id_group.front(), id); - } - } - } -} - void ValGraph::removeTrivialExprs() { ExprGroups trivial_expr_groups; // This seems like it shouls just be a copy if. diff --git a/csrc/val_graph.h b/csrc/val_graph.h index a24064cafc9..e3199ae0c50 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -155,10 +155,6 @@ class ValGraph { std::string toString() const; - // Checks if the expression is a trivial operation where an input is simply an - // output of the transformation. Returns the mapped iter domains if found. - static std::vector> isTrivialExpr(Expr* expr); - // Returns if all atributes of the ID transforms first and second are the same static bool transformAtributesMatch(Expr* first, Expr* second); @@ -205,10 +201,6 @@ class ValGraph { // be the only call in ValGraph to mapThroughExpr. void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); - // Maps iter domain pairs returned by calling that return mappings from - // IdGraph::isTrivialExpr on every expression in the graph. - void mapThroughTrivialExprs(); - // Removes expressions from unique_definitions_ and unique_uses_ that return // mappings from IdGraph::isTrivialExpr void removeTrivialExprs(); From f6e68486e6c4e89902c83004b1166a675d1e2a95 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 7 Dec 2023 11:44:39 -0800 Subject: [PATCH 092/178] cleanup transform_iter --- csrc/transform_iter.cpp | 25 +++++++++++++++++-------- csrc/transform_iter.h | 3 +++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 164ce205e40..5ce00978687 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -670,6 +670,12 @@ int BestEffortReplay::findFirstMismatchedID( ForwardingInfo::ForwardingInfo( const TensorView* producer, const TensorView* consumer) { + // No forwarding unless this is broadcast or squeeze + if (!dynamic_cast(consumer->definition()) && + !dynamic_cast(consumer->definition())) { + return; + } + // Active indicates the TV that has axes the other TV does not. For // broadcast this is the consumer squeeze the producer. // @@ -747,24 +753,27 @@ ForwardingInfo::ForwardingInfo( // For the sake of BestEffortReplay we can forward the input mapping // to both the active and inactive tensor to the output of the // expression - std::vector forwarded_ids_vec; - std::vector compliment_ids; + IterDomain* forwarded_id = nullptr; + IterDomain* compliment_id = nullptr; for (auto input_id : input_ids) { if (!isInForwardIdSet(input_id)) { - forwarded_ids_vec.emplace_back(input_id); + NVF_ERROR(forwarded_id == nullptr); + forwarded_id = input_id; active_forwarding_map->emplace( std::make_pair(input_id, merge_expr->out())); } else { - compliment_ids.push_back(input_id); + NVF_ERROR(compliment_id == nullptr); + compliment_id = input_id; } } + NVF_ERROR(forwarded_id != nullptr); + NVF_ERROR(compliment_id != nullptr); + // Set up compliment map - for (auto forwarded_id : forwarded_ids_vec) { - active_compliment_map->emplace( - std::make_pair(forwarded_id, compliment_ids)); - } + active_compliment_map->emplace( + forwarded_id, std::vector{compliment_id}); } } } diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 086c47c19ac..4e0f08bfe5a 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -171,6 +171,9 @@ class ReplayTransformations : public IterVisitor { // nodes we may have after the forwarding process is finished. Leaf nodes are // only important for replayCasP, so look there to see how this is done. Forward // map is used for replayCasP and replayPasC. +// +// The producer forwarding map is filled when producer broadcast +// domains are squeezed. class ForwardingInfo { public: // Map IterDomain* axes that can safely be forwarded to their output. From a8070a9cf063d66a046f64864e9314cdba84bb3c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 7 Dec 2023 19:16:54 -0800 Subject: [PATCH 093/178] cleanup --- csrc/compute_at_map.cpp | 3 +- csrc/val_graph.cpp | 96 ++++++++++++++++++++--------------------- 2 files changed, 49 insertions(+), 50 deletions(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index b6c6088a5d1..f6a7755fdc8 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -1499,9 +1499,8 @@ const DisjointSets& ComputeAtMap::getIdSets( return id_graph_.permissiveResizeNodes(); case IdMappingMode::INNERMOST: return id_graph_.innermostNodes(); - default: - NVF_ERROR(false, "Error with mapping mode provided: ", mode); } + NVF_ERROR(false, "Error with mapping mode provided."); } bool ComputeAtMap::idExistsInMap(IterDomain* id) const { diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 1a33dbef61d..54e4790d240 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -22,35 +22,32 @@ ValGraph::ValGraph(const ValGraph& other) disjoint_exprs_(other.disjoint_exprs_), unique_definitions_(), unique_uses_() { - for (const auto& [orig_id_group, orig_expr_groups] : + for (const auto& [orig_val_group, orig_expr_groups] : other.unique_definitions_) { - auto new_id_group = toGroup(orig_id_group->front()); + auto new_val_group = toGroup(orig_val_group->front()); ExprGroups new_expr_groups; for (const ExprGroup& orig_expr_group : orig_expr_groups) { new_expr_groups.pushBack(toGroup(orig_expr_group->front())); } - unique_definitions_[new_id_group] = new_expr_groups; + unique_definitions_[new_val_group] = new_expr_groups; } - for (const auto& [orig_id_group, orig_expr_groups] : other.unique_uses_) { - auto new_id_group = toGroup(orig_id_group->front()); + for (const auto& [orig_val_group, orig_expr_groups] : other.unique_uses_) { + auto new_val_group = toGroup(orig_val_group->front()); ExprGroups new_expr_groups; for (const ExprGroup& orig_expr_group : orig_expr_groups) { new_expr_groups.pushBack(toGroup(orig_expr_group->front())); } - unique_uses_[new_id_group] = new_expr_groups; + NVF_ERROR( + unique_uses_.emplace(new_val_group, std::move(new_expr_groups)).second); } } ValGraph& ValGraph::operator=(const ValGraph& other) { - disjoint_vals_.clear(); - disjoint_exprs_.clear(); - unique_definitions_.clear(); - unique_uses_.clear(); ValGraph copy(other); std::swap(*this, copy); return *this; @@ -61,9 +58,9 @@ bool ValGraph::hasGroup(Expr* expr) const { return disjoint_exprs_.mappingExists(expr); } -// Return if there's a group entry in the graph for this id -bool ValGraph::hasGroup(Val* id) const { - return disjoint_vals_.mappingExists(id); +// Return if there's a group entry in the graph for this val +bool ValGraph::hasGroup(Val* val) const { + return disjoint_vals_.mappingExists(val); } const ExprGroup& ValGraph::toGroup(Expr* expr) const { @@ -75,12 +72,12 @@ const ExprGroup& ValGraph::toGroup(Expr* expr) const { return disjoint_set_it->second; } -const ValGroup& ValGraph::toGroup(Val* id) const { - auto disjoint_set_it = disjoint_vals_.disjointSetMap().find(id); +const ValGroup& ValGraph::toGroup(Val* val) const { + auto disjoint_set_it = disjoint_vals_.disjointSetMap().find(val); NVF_ERROR( disjoint_set_it != disjoint_vals_.disjointSetMap().end(), "\nId group could not be found in graph associated with: ", - id->toString(), + val->toString(), "\n"); return disjoint_set_it->second; } @@ -93,35 +90,36 @@ ExprGroups ValGraph::toGroups(const VectorOfUniqueEntries& exprs) const { return expr_groups; } -ValGroups ValGraph::toGroups(const VectorOfUniqueEntries& ids) const { - ValGroups id_groups; - for (auto id : ids) { - id_groups.pushBack(toGroup(id)); +ValGroups ValGraph::toGroups(const VectorOfUniqueEntries& vals) const { + ValGroups val_groups; + for (auto val : vals) { + val_groups.pushBack(toGroup(val)); } - return id_groups; + return val_groups; } std::vector ValGraph::outputGroups(const ExprGroup& expr) const { std::vector output_groups; - for (auto id_output : ir_utils::filterByType(expr->front()->outputs())) { - output_groups.push_back(toGroup(id_output)); + for (auto output : expr->front()->outputs()) { + output_groups.push_back(toGroup(output)); } return output_groups; } std::vector ValGraph::inputGroups(const ExprGroup& expr) const { std::vector input_groups; - for (auto id_input : ir_utils::filterByType(expr->front()->inputs())) { - input_groups.push_back(toGroup(id_input)); + for (auto input : expr->front()->inputs()) { + input_groups.push_back(toGroup(input)); } return input_groups; } ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { DequeOfExprGroup to_visit; - for (const ValGroup& of_id_group : of) { - if (const ExprGroups* uses = getUses(of_id_group); uses) { - to_visit.insert(to_visit.end(), uses->begin(), uses->end()); + for (const ValGroup& of_val_group : of) { + if (const ExprGroups* group_uses = getUses(of_val_group); + group_uses != nullptr) { + to_visit.insert(to_visit.end(), group_uses->begin(), group_uses->end()); } } @@ -130,9 +128,10 @@ ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { ExprGroup current_expr = to_visit.front(); to_visit.pop_front(); visited.emplace(current_expr); - for (const ValGroup& output_id : outputGroups(current_expr)) { - if (const ExprGroups* uses = getUses(output_id); uses) { - for (const ExprGroup& group_use : *uses) { + for (const ValGroup& output_group : outputGroups(current_expr)) { + if (const ExprGroups* group_uses = getUses(output_group); + group_uses != nullptr) { + for (const ExprGroup& group_use : *group_uses) { if (visited.count(group_use)) { continue; } @@ -555,8 +554,10 @@ void ValGraph::initializeVal( def_groups.pushBack(expr_set); } // TODO-NM: def_groups can be empty. Should it be still mapped? - // TODO-NM: Can this be overwritten? - NVF_ERROR(unique_definitions_.emplace(val_disjoint_set, def_groups).second); + NVF_ERROR( + unique_definitions_.emplace(val_disjoint_set, def_groups).second, + "Multiple defining groups for ", + nvfuser::toString(val_disjoint_set)); ExprGroups use_groups; for (auto use : uses) { @@ -565,8 +566,10 @@ void ValGraph::initializeVal( use_groups.pushBack(expr_set); } // TODO-NM: use_groups can be empty. Should it be still mapped? - // TODO-NM: Can this be overwritten? - NVF_ERROR(unique_uses_.emplace(val_disjoint_set, use_groups).second); + NVF_ERROR( + unique_uses_.emplace(val_disjoint_set, use_groups).second, + "Multiple use groups for ", + nvfuser::toString(val_disjoint_set)); } void ValGraph::initializeVal(Val* val) { @@ -773,18 +776,18 @@ void ValGraph::mapExprs(Expr* expr0, Expr* expr1) { return; } - ExprGroup expr0_orig_group = toGroup(expr0); - ExprGroup expr1_orig_group = toGroup(expr1); + const ExprGroup& expr0_orig_group = toGroup(expr0); + const ExprGroup& expr1_orig_group = toGroup(expr1); disjoint_exprs_.mapEntries(expr0, expr1); - auto expr_new_group = toGroup(expr0); + const ExprGroup& expr_new_group = toGroup(expr0); // Update unique uses of producers ValGroups producers; for (auto expr : std::vector{expr0, expr1}) { - for (auto input_id : ir_utils::filterByType(expr->inputs())) { - producers.pushBack(toGroup(input_id)); + for (auto input : expr->inputs()) { + producers.pushBack(toGroup(input)); } } @@ -797,8 +800,8 @@ void ValGraph::mapExprs(Expr* expr0, Expr* expr1) { // Update unique definitinos of consumers ValGroups consumers; for (auto expr : std::vector{expr0, expr1}) { - for (auto output_id : ir_utils::filterByType(expr->outputs())) { - consumers.pushBack(toGroup(output_id)); + for (auto output : expr->outputs()) { + consumers.pushBack(toGroup(output)); } } @@ -822,12 +825,9 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { propagate_through_exprs_, "Asked to propagate expression mappings on a graph that has propagate_exprs_ disabled."); - auto first_ids = - ir_utils::filterByType(forward ? first->outputs() : first->inputs()) - .vector(); - auto second_ids = ir_utils::filterByType( - forward ? second->outputs() : second->inputs()) - .vector(); + const auto& first_ids = forward ? first->outputs() : first->inputs(); + const auto& second_ids = forward ? second->outputs() : second->inputs(); + NVF_ERROR( first_ids.size() == second_ids.size(), "This should be unreachable, if transformation expressions match, their number of inputs and outputs should as well.\n However found:\n", From 6c4a5f302793a09dd76f15c0eab05cca0ab711c2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 7 Dec 2023 19:20:22 -0800 Subject: [PATCH 094/178] fix --- csrc/compute_at_map.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index f6a7755fdc8..b6c6088a5d1 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -1499,8 +1499,9 @@ const DisjointSets& ComputeAtMap::getIdSets( return id_graph_.permissiveResizeNodes(); case IdMappingMode::INNERMOST: return id_graph_.innermostNodes(); + default: + NVF_ERROR(false, "Error with mapping mode provided: ", mode); } - NVF_ERROR(false, "Error with mapping mode provided."); } bool ComputeAtMap::idExistsInMap(IterDomain* id) const { From ad83b72629f86940d1a84a7ac017ec4402e387fa Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 7 Dec 2023 19:35:05 -0800 Subject: [PATCH 095/178] cleanup --- csrc/id_model/id_model.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 2fd8d5dea3e..a5ca554e190 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -794,11 +794,6 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { auto tv_inputs = ir_utils::filterByType(expr->inputs()); for (auto p_tv : tv_inputs) { - auto p_ids_vec = ir_utils::allIDsOf(p_tv); - auto c_ids_vec = ir_utils::allIDsOf(c_tv); - std::unordered_set p_ids(p_ids_vec.begin(), p_ids_vec.end()); - std::unordered_set c_ids(c_ids_vec.begin(), c_ids_vec.end()); - ForwardingInfo permissive_forwarding(p_tv, c_tv); for (auto entry : permissive_forwarding.producer_forwarding_map) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); From 554bd3e834df4d0496a1690f344a2223f895b3e1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 7 Dec 2023 19:35:12 -0800 Subject: [PATCH 096/178] It should not be necessary to handle swizzles in the permissive map since the baseline exact map should map them --- csrc/id_model/id_model.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index a5ca554e190..90891d8e9da 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -825,7 +825,6 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { } } } - mapThroughLoopSwizzles(idGraph(IdMappingMode::PERMISSIVE)); } // TODO: Reenable after reenabling parallel propagation. From 4691a3a8b08c9b45db99eb1e2f2410af6902fcf2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 7 Dec 2023 20:28:44 -0800 Subject: [PATCH 097/178] Bug fix --- csrc/val_graph.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 54e4790d240..12fafe5949a 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -776,8 +776,10 @@ void ValGraph::mapExprs(Expr* expr0, Expr* expr1) { return; } - const ExprGroup& expr0_orig_group = toGroup(expr0); - const ExprGroup& expr1_orig_group = toGroup(expr1); + // Note that non-reference copies are required here as they may be + // removed by mapEntries + const ExprGroup expr0_orig_group = toGroup(expr0); + const ExprGroup expr1_orig_group = toGroup(expr1); disjoint_exprs_.mapEntries(expr0, expr1); From f5b39e013f07185b9bb08e7a919abf4ff8cf4038 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 13 Dec 2023 12:37:17 -0800 Subject: [PATCH 098/178] Fix mapping propagation through backward merge (#1511) While working on comparing the Permissive graph with the CA Permissive map, I found one thing that I think should be considered inconsistency. When backward propagating mappings through a Merge expr, we require at least one of the input domain pairs should be already mapped or have equal extents, which makes sense, but the current code only checks the extents of one ID out of each ID group, which is fine with the Exact graph but in the case of the Permissive graph, each ID group may consist of domains that have different extents, so I think we would also need to check all domains in the group, and if there's one matching pair, we should allow propagation. This PR also adds the validation of the Permissive graph. It does not pass yet as the CA map does not map the compliment IDs. --- csrc/id_model/id_model.cpp | 12 +++- csrc/id_model/validation_utils.cpp | 19 ++++++ csrc/id_model/validation_utils.h | 2 + csrc/val_graph.cpp | 96 +++++++++++++++++++++++------- 4 files changed, 106 insertions(+), 23 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 90891d8e9da..698fdf00bdd 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -802,7 +802,9 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { // TODO: Should this just get rolled up in the forwarding map now? for (const auto& entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { - idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); + if (getenv("COMP")) { + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); + } } } @@ -814,7 +816,9 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { // TODO: Why should IDs be mapped to their compliments? Is this right? for (const auto& entry : permissive_forwarding.consumer_compliment_map) { for (auto entry_2 : entry.second) { - idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); + if (getenv("COMP")) { + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); + } } } @@ -1072,6 +1076,10 @@ void IdModel::build( } buildPermissiveMap(tv_exprs); + if (validate) { + validator->checkPermissiveGraphEquivalence( + idGraph(IdMappingMode::PERMISSIVE)); + } // Permissive graph needs the trivial exprs from the almost exact graph to // build correctly. Once built though we can remove the trivial expressions diff --git a/csrc/id_model/validation_utils.cpp b/csrc/id_model/validation_utils.cpp index 0a2111f83c2..82a924c4d56 100644 --- a/csrc/id_model/validation_utils.cpp +++ b/csrc/id_model/validation_utils.cpp @@ -196,4 +196,23 @@ void IdModelValidator::checkAlmostExactGraphEquivalence( compareDisjointSets(ca_map_sets, almost_exact_graph.disjointValSets()); } +void IdModelValidator::checkPermissiveGraphEquivalence( + const ValGraph& permissive_graph) { + if (has_swizzle_) { + // Ignoring a fusion with swizzle + return; + } + + // Empty graph + if (permissive_graph.disjointValSets().disjointSets().empty()) { + return; + } + + DisjointSets ca_map_sets = ca_map_.id_graph_.permissive_nodes_; + + fullyPropagateMappings(ca_map_sets); + + compareDisjointSets(ca_map_sets, permissive_graph.disjointValSets()); +} + } // namespace nvfuser diff --git a/csrc/id_model/validation_utils.h b/csrc/id_model/validation_utils.h index 530ab3f1b01..a1252e25d7e 100644 --- a/csrc/id_model/validation_utils.h +++ b/csrc/id_model/validation_utils.h @@ -40,6 +40,8 @@ class IdModelValidator { void checkAlmostExactGraphEquivalence(const ValGraph& almost_exact_graph); + void checkPermissiveGraphEquivalence(const ValGraph& permissive_graph); + private: // Propagate mappings in a ComputeAtMap as is done in IdModel static void fullyPropagateMappings(DisjointSets& id_sets); diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 12fafe5949a..a297346ec7e 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -10,6 +10,8 @@ #include #include +#include + namespace nvfuser { namespace { @@ -592,6 +594,78 @@ void ValGraph::registerExpr(Expr* expr) { disjoint_exprs_.initializeSet(expr); } +namespace { + +// Can't back prop through merge without making sure one input actually +// matches. This can be done on a map or extent basis. +bool mapMergeBackward(Merge* merge0, Merge* merge1, const ValGraph& graph) { + auto extent_match = [](IterDomain* id0, IterDomain* id1) -> bool { + return id0->extent()->sameAs(id1->extent()) || + (id0->extent()->isConstInt() && id1->extent()->isConstInt() && + id0->extent()->evaluate() == id1->extent()->evaluate()); + }; + + // If one pair of the domains are mapped in the given graph, the + // backward merge is considered mapped + if (graph.disjointValSets().permissiveAreMapped( + merge0->outer(), merge1->outer()) || + graph.disjointValSets().permissiveAreMapped( + merge0->inner(), merge1->inner())) { + return true; + } + + // Considered mapped if the extents are equal + if (extent_match(merge0->outer(), merge1->outer()) || + extent_match(merge0->inner(), merge1->inner())) { + return true; + } + + // The mapped ID group may have different extents depending on the + // mapping conditions. For example, the Permissive graph may have a + // symbolic extent as well as an extent of 1 for broadcast + // domains. Those other mapped domains need to be checked as well. + + // First, the outer groups + ValGroup outer0_group = graph.hasGroup(merge0->outer()) + ? graph.toGroup(merge0->outer()) + : std::make_shared>( + VectorOfUniqueEntries{merge0->outer()}); + ValGroup outer1_group = graph.hasGroup(merge1->outer()) + ? graph.toGroup(merge1->outer()) + : std::make_shared>( + VectorOfUniqueEntries{merge1->outer()}); + + for (Val* outer0 : *outer0_group) { + for (Val* outer1 : *outer1_group) { + if (extent_match(outer0->as(), outer1->as())) { + return true; + } + } + } + + // Check the inner groups as well if not already matched + ValGroup inner0_group = graph.hasGroup(merge0->inner()) + ? graph.toGroup(merge0->inner()) + : std::make_shared>( + VectorOfUniqueEntries{merge0->inner()}); + ValGroup inner1_group = graph.hasGroup(merge1->inner()) + ? graph.toGroup(merge1->inner()) + : std::make_shared>( + VectorOfUniqueEntries{merge1->inner()}); + + for (Val* inner0 : *inner0_group) { + for (Val* inner1 : *inner1_group) { + if (extent_match(inner0->as(), inner1->as())) { + return true; + } + } + } + + return false; +} + +} // namespace + bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { NVF_ERROR(first); NVF_ERROR(second); @@ -621,27 +695,7 @@ bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { // Special handling for backprop of merge if (first->isA() && !forward) { - // Can't back prop through merge without making sure one input actually - // matches. This can be done on a map or extent basis. - auto merge0 = first->as(); - auto merge1 = second->as(); - - auto extent_0o = merge0->outer()->extent(); - auto extent_0i = merge0->inner()->extent(); - auto extent_1o = merge1->outer()->extent(); - auto extent_1i = merge1->inner()->extent(); - - auto extent_o_match = extent_0o->sameAs(extent_1o) || - (extent_0o->isConstInt() && extent_1o->isConstInt() && - extent_0o->evaluate() == extent_1o->evaluate()) || - disjointValSets().permissiveAreMapped(merge0->outer(), merge1->outer()); - - auto extent_i_match = extent_0i->sameAs(extent_1i) || - (extent_0i->isConstInt() && extent_1i->isConstInt() && - extent_0i->evaluate() == extent_1i->evaluate()) || - disjointValSets().permissiveAreMapped(merge0->inner(), merge1->inner()); - - if (!(extent_o_match || extent_i_match)) { + if (!mapMergeBackward(first->as(), second->as(), *this)) { return false; } } From bc8fc054f1c4d4227a3f572786a57ac4ab7d5145 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 13 Dec 2023 15:41:05 -0800 Subject: [PATCH 099/178] Temporarily remove IdMappingMode::INDEX as it doesn't exist yet --- csrc/type.cpp | 10 ++++------ csrc/type.h | 4 +--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/csrc/type.cpp b/csrc/type.cpp index c056143e789..eb76b0a53b5 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -763,16 +763,14 @@ static const char* id_map_mode_type2string(IdMappingMode t) { return "exact"; case IdMappingMode::ALMOSTEXACT: return "almost_exact"; - case IdMappingMode::INDEX: - return "index"; - case IdMappingMode::LOOP: - return "loop"; case IdMappingMode::PERMISSIVE: return "permissive"; - case IdMappingMode::PERMISSIVE_RESIZE: - return "permissive_resize"; + case IdMappingMode::LOOP: + return "loop"; case IdMappingMode::INNERMOST: return "innermost"; + case IdMappingMode::PERMISSIVE_RESIZE: + return "permissive_resize"; default: // Don't try to print t as it would recursively call this function NVF_ERROR(false, "Unexpected IdMappingMode Type."); diff --git a/csrc/type.h b/csrc/type.h index 970a625b5ad..d4585a86a71 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -711,16 +711,14 @@ enum class IdMappingMode { EXACT, ALMOSTEXACT, LOOP, - INDEX, PERMISSIVE, PERMISSIVE_RESIZE, INNERMOST }; -static constexpr std::array kIdMappingModes = { +static constexpr std::array kIdMappingModes = { IdMappingMode::EXACT, IdMappingMode::ALMOSTEXACT, - IdMappingMode::INDEX, IdMappingMode::LOOP, IdMappingMode::PERMISSIVE, IdMappingMode::PERMISSIVE_RESIZE, From 9b6d761ff464f7b01e8e94c32cd4f9195f0eb61b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 13 Dec 2023 15:52:37 -0800 Subject: [PATCH 100/178] Revert unnecessary change --- csrc/compute_at_map.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index ff645ef720b..f6f6c518ca9 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -1508,9 +1508,8 @@ const DisjointSets& ComputeAtMap::getIdSets( return id_graph_.permissiveResizeNodes(); case IdMappingMode::INNERMOST: return id_graph_.innermostNodes(); - default: - NVF_ERROR(false, "Error with mapping mode provided: ", mode); } + NVF_ERROR(false, "Error with mapping mode provided."); } bool ComputeAtMap::idExistsInMap(IterDomain* id) const { From 4d2dc50e9507d0904d3190c256f68cfba2cc9a5e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 13 Dec 2023 15:54:49 -0800 Subject: [PATCH 101/178] cleanup --- csrc/id_model/id_model.cpp | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 698fdf00bdd..4240b358ab5 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -55,19 +55,17 @@ void mapThroughLoopSwizzles(ValGraph& graph) { } // namespace void IdModel::assertNoSelfMapping() { - if (hasSelfMapping()) { - NVF_ERROR( - !hasSelfMapping(), - "Unsupported domain mapping detected in ", - std::get<0>(*self_mapping_info_)->toString(), - ". ", - std::get<3>(*self_mapping_info_), - " domains, ", - std::get<1>(*self_mapping_info_)->toString(), - " and ", - std::get<2>(*self_mapping_info_)->toString(), - ", are mapped with each other."); - } + NVF_ERROR( + !hasSelfMapping(), + "Unsupported domain mapping detected in ", + std::get<0>(*self_mapping_info_)->toString(), + ". ", + std::get<3>(*self_mapping_info_), + " domains, ", + std::get<1>(*self_mapping_info_)->toString(), + " and ", + std::get<2>(*self_mapping_info_)->toString(), + ", are mapped with each other."); } IdModel::IdModel( From 4ad9eebe47e93cd0e35bd95680acb423e356347d Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 13 Dec 2023 15:55:37 -0800 Subject: [PATCH 102/178] Make it explicit that broadcasts are mapped --- csrc/id_model/id_model.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 4240b358ab5..30aa4be9ce2 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -820,7 +820,8 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { } } - auto permissive_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv); + auto permissive_c2p_root_map = + PairwiseRootDomainMap(p_tv, c_tv).mapBroadcast(true); for (auto entry : permissive_c2p_root_map.mapConsumerToProducer()) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); From 151b0ef8bfec63ed8dfdab1911292493f403500f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 21 Dec 2023 15:56:43 -0800 Subject: [PATCH 103/178] [IdModel] Propagation fix for permissive graphs (#1538) Same fix as in #1522 --- csrc/id_model/validation_utils.cpp | 101 ++++++++++++++++++++++++++++- csrc/val_graph.cpp | 85 ++++-------------------- csrc/val_graph.h | 80 +++++++++++++++++++++++ 3 files changed, 189 insertions(+), 77 deletions(-) diff --git a/csrc/id_model/validation_utils.cpp b/csrc/id_model/validation_utils.cpp index 82a924c4d56..7608b2f91f1 100644 --- a/csrc/id_model/validation_utils.cpp +++ b/csrc/id_model/validation_utils.cpp @@ -15,6 +15,100 @@ namespace nvfuser { +namespace { + +// Same as IterDomain::exprsMap but uses +// ValGraph::mapMergeBackward. Copying the funciton here isn't ideal, +// but it doesn't make sense to change the ComputeAtMap code just for +// this validation. +bool exprsMap( + Expr* first, + Expr* second, + bool forward, + const DisjointSets& id_map) { + if (first == nullptr || second == nullptr) { + return false; + } + + if (typeid(*first) != typeid(*second)) { + return false; + } + + NVF_ERROR( + first->isA() || first->isA() || first->isA(), + "Merge, split and resize are the only expressions supported through rfactor operations in compute at map, but found:\n", + first->toString()); + + auto first_ids = ir_utils::filterByType( + forward ? first->inputs() : first->outputs()) + .vector(); + + auto second_ids = ir_utils::filterByType( + forward ? second->inputs() : second->outputs()) + .vector(); + + NVF_ERROR( + first_ids.size() == second_ids.size(), + "Expected number of ", + (forward ? "inputs" : "outputs"), + " to match for\n", + first->toString(), + second->toString()); + + { + std::vector> zipped_ids; + + std::transform( + first_ids.begin(), + first_ids.end(), + second_ids.begin(), + std::back_inserter(zipped_ids), + [](IterDomain* first, IterDomain* second) { + return std::make_pair(first, second); + }); + + if (std::any_of( + zipped_ids.begin(), + zipped_ids.end(), + [&](std::pair id_pair) { + return !id_map.strictAreMapped(id_pair.first, id_pair.second); + })) { + return false; + } + } + + if (first->isA() && !forward) { + if (!ValGraph::shouldMapMergeBackward( + first->as(), second->as(), id_map)) { + return false; + } + } + + if (first->isA()) { + auto first_split = first->as(); + auto second_split = second->as(); + if (!first_split->factor()->sameAs(second_split->factor()) || + first_split->innerSplit() != second_split->innerSplit() || + !first_split->startOffset()->sameAs(second_split->startOffset()) || + !first_split->stopOffset()->sameAs(second_split->stopOffset())) { + return false; + } + } + + if (first->isA()) { + auto first_resize = first->as(); + auto second_resize = second->as(); + if (!first_resize->leftExpand()->sameAs(second_resize->leftExpand()) || + !first_resize->rightExpand()->sameAs(second_resize->rightExpand())) { + return false; + } + } + + return true; +} + +} // namespace + IdModelValidator::IdModelValidator(Fusion* fusion) : ca_map_(fusion) { for (auto tv : ir_utils::allTvs(fusion)) { for (auto id : ir_utils::allIDsOf(tv)) { @@ -55,7 +149,9 @@ void IdModelValidator::fullyPropagateMappings( } } } else { - all_exprs.push_back(id->definition()); + if (id->definition()) { + all_exprs.push_back(id->definition()); + } } } @@ -68,8 +164,7 @@ void IdModelValidator::fullyPropagateMappings( auto expr_i = all_exprs.at(i); for (size_t j = i + 1; j < count; ++j) { auto expr_j = all_exprs.at(j); - if (!IterDomainGraph::exprsMap( - expr_i, expr_j, is_forward, id_sets)) { + if (!exprsMap(expr_i, expr_j, is_forward, id_sets)) { continue; } const auto& prop_target_i = diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index a297346ec7e..c06571ba98c 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -509,7 +509,7 @@ bool ValGraph::transformAtributesMatch(Expr* first, Expr* second) { NVF_ERROR( first->isA() || first->isA() || first->isA() || - first->isA(), + first->isA() || first->isA(), "Unsupported rfactor expressions in compute at map:\n", first->toString()); @@ -537,6 +537,14 @@ bool ValGraph::transformAtributesMatch(Expr* first, Expr* second) { } } + if (first->isA()) { + auto swizzle_1 = first->as(); + auto swizzle_2 = first->as(); + if (swizzle_1->swizzleType() != swizzle_2->swizzleType()) { + return false; + } + } + // TODO: Resize properties return true; @@ -594,78 +602,6 @@ void ValGraph::registerExpr(Expr* expr) { disjoint_exprs_.initializeSet(expr); } -namespace { - -// Can't back prop through merge without making sure one input actually -// matches. This can be done on a map or extent basis. -bool mapMergeBackward(Merge* merge0, Merge* merge1, const ValGraph& graph) { - auto extent_match = [](IterDomain* id0, IterDomain* id1) -> bool { - return id0->extent()->sameAs(id1->extent()) || - (id0->extent()->isConstInt() && id1->extent()->isConstInt() && - id0->extent()->evaluate() == id1->extent()->evaluate()); - }; - - // If one pair of the domains are mapped in the given graph, the - // backward merge is considered mapped - if (graph.disjointValSets().permissiveAreMapped( - merge0->outer(), merge1->outer()) || - graph.disjointValSets().permissiveAreMapped( - merge0->inner(), merge1->inner())) { - return true; - } - - // Considered mapped if the extents are equal - if (extent_match(merge0->outer(), merge1->outer()) || - extent_match(merge0->inner(), merge1->inner())) { - return true; - } - - // The mapped ID group may have different extents depending on the - // mapping conditions. For example, the Permissive graph may have a - // symbolic extent as well as an extent of 1 for broadcast - // domains. Those other mapped domains need to be checked as well. - - // First, the outer groups - ValGroup outer0_group = graph.hasGroup(merge0->outer()) - ? graph.toGroup(merge0->outer()) - : std::make_shared>( - VectorOfUniqueEntries{merge0->outer()}); - ValGroup outer1_group = graph.hasGroup(merge1->outer()) - ? graph.toGroup(merge1->outer()) - : std::make_shared>( - VectorOfUniqueEntries{merge1->outer()}); - - for (Val* outer0 : *outer0_group) { - for (Val* outer1 : *outer1_group) { - if (extent_match(outer0->as(), outer1->as())) { - return true; - } - } - } - - // Check the inner groups as well if not already matched - ValGroup inner0_group = graph.hasGroup(merge0->inner()) - ? graph.toGroup(merge0->inner()) - : std::make_shared>( - VectorOfUniqueEntries{merge0->inner()}); - ValGroup inner1_group = graph.hasGroup(merge1->inner()) - ? graph.toGroup(merge1->inner()) - : std::make_shared>( - VectorOfUniqueEntries{merge1->inner()}); - - for (Val* inner0 : *inner0_group) { - for (Val* inner1 : *inner1_group) { - if (extent_match(inner0->as(), inner1->as())) { - return true; - } - } - } - - return false; -} - -} // namespace - bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { NVF_ERROR(first); NVF_ERROR(second); @@ -695,7 +631,8 @@ bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { // Special handling for backprop of merge if (first->isA() && !forward) { - if (!mapMergeBackward(first->as(), second->as(), *this)) { + if (!shouldMapMergeBackward( + first->as(), second->as(), this->disjointValSets())) { return false; } } diff --git a/csrc/val_graph.h b/csrc/val_graph.h index e3199ae0c50..f84f01e61db 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -218,6 +218,86 @@ class ValGraph { propagate_through_exprs_ = b; } + // Can't back prop through merge without making sure one input actually + // matches. This can be done on a map or extent basis. + // TODO: Move this to val_graph.cpp once validation_utils.cpp is + // retired. + template + static bool shouldMapMergeBackward( + Merge* merge0, + Merge* merge1, + const DisjointSets& id_sets) { + auto extent_match = [](IterDomain* id0, IterDomain* id1) -> bool { + return id0->extent()->sameAs(id1->extent()) || + (id0->extent()->isConstInt() && id1->extent()->isConstInt() && + id0->extent()->evaluate() == id1->extent()->evaluate()); + }; + + // If one pair of the domains are mapped in the given graph, the + // backward merge is considered mapped + if (id_sets.permissiveAreMapped(merge0->outer(), merge1->outer()) || + id_sets.permissiveAreMapped(merge0->inner(), merge1->inner())) { + return true; + } + + // Considered mapped if the extents are equal + if (extent_match(merge0->outer(), merge1->outer()) || + extent_match(merge0->inner(), merge1->inner())) { + return true; + } + + // The mapped ID group may have different extents depending on the + // mapping conditions. For example, the Permissive graph may have a + // symbolic extent as well as an extent of 1 for broadcast + // domains. Those other mapped domains need to be checked as well. + + // First, the outer groups + auto outer0_group = id_sets.mappingExists(merge0->outer()) + ? id_sets.disjointSetMap().at(merge0->outer()) + : std::make_shared>( + VectorOfUniqueEntries{merge0->outer()}); + auto outer1_group = id_sets.mappingExists(merge1->outer()) + ? id_sets.disjointSetMap().at(merge1->outer()) + : std::make_shared>( + VectorOfUniqueEntries{merge1->outer()}); + + for (T* outer0 : *outer0_group) { + for (T* outer1 : *outer1_group) { + if (extent_match( + outer0->template as(), + outer1->template as())) { + // std::cerr << "outer are equal: " << outer0->name() << ", " << + // outer1->name() << std::endl; + return true; + } + } + } + + // Check the inner groups as well if not already matched + auto inner0_group = id_sets.mappingExists(merge0->inner()) + ? id_sets.disjointSetMap().at(merge0->inner()) + : std::make_shared>( + VectorOfUniqueEntries{merge0->inner()}); + auto inner1_group = id_sets.mappingExists(merge1->inner()) + ? id_sets.disjointSetMap().at(merge1->inner()) + : std::make_shared>( + VectorOfUniqueEntries{merge1->inner()}); + + for (T* inner0 : *inner0_group) { + for (T* inner1 : *inner1_group) { + if (extent_match( + inner0->template as(), + inner1->template as())) { + // std::cerr << "inner are equal: " << inner0->name() << ", " << + // inner1->name() << std::endl; + return true; + } + } + } + + return false; + } + private: // Map expr0 and expr1 with each other, update unique_definitions_ // unique_uses_ From e45b1bb128f3315869504dc05adba31f0f3e5ea7 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 27 Dec 2023 18:38:40 -0800 Subject: [PATCH 104/178] [IdModel] merge main 20231227 (#1564) Merging 101530ec09a7eab9937f277715be0e3ecc1478b6 --- CMakeLists.txt | 2 +- csrc/alias_analysis.h | 6 +++--- csrc/id_model/id_model.cpp | 7 +------ csrc/id_model/id_model.h | 12 ++++++------ csrc/id_model/validation_utils.cpp | 13 +++++++++++-- ...ptimize_layout.cpp => mark_aliases_prepare.cpp} | 8 ++++---- .../{optimize_layout.h => mark_aliases_prepare.h} | 9 +++++---- csrc/optimization/pre_segmenter.cpp | 4 ++-- csrc/val_graph.h | 4 ---- python_tests/test_python_frontend.py | 6 +++--- test/test_gpu_transpose.cpp | 14 +++++++------- 11 files changed, 43 insertions(+), 42 deletions(-) rename csrc/optimization/{optimize_layout.cpp => mark_aliases_prepare.cpp} (93%) rename csrc/optimization/{optimize_layout.h => mark_aliases_prepare.h} (53%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6431692cd94..d1e25283578 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -198,7 +198,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp ${NVFUSER_SRCS_DIR}/optimization/add_axioms.cpp ${NVFUSER_SRCS_DIR}/optimization/consecutive_cast.cpp - ${NVFUSER_SRCS_DIR}/optimization/optimize_layout.cpp + ${NVFUSER_SRCS_DIR}/optimization/mark_aliases_prepare.cpp ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp ${NVFUSER_SRCS_DIR}/val_graph.cpp diff --git a/csrc/alias_analysis.h b/csrc/alias_analysis.h index fda37e0f464..4eff46e0d2e 100644 --- a/csrc/alias_analysis.h +++ b/csrc/alias_analysis.h @@ -104,9 +104,9 @@ class AliasAnalysisResult { // implement given the current infrastructure. // // Therefore, I chose to run alias analysis both before segmentation and in -// schedulers. The former, used by OptimizeLayoutPass, updates layouts to enable -// aliases; the latter, used by NoOpScheduler, calls Fusion::aliasOutputToInput -// to mark aliases. +// schedulers. The former, used by MarkAliasesPreparePass, updates layouts to +// enable aliases; the latter, used by NoOpScheduler, calls +// Fusion::aliasOutputToInput to mark aliases. AliasAnalysisResult findAliases( Fusion* fusion, bool can_override_empty_allocation_domain = true); diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 30aa4be9ce2..c046112d0d7 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -783,8 +783,6 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { idGraph(IdMappingMode::PERMISSIVE) = idGraph(IdMappingMode::EXACT); for (auto expr : exprs) { - // Multiple outputs are already mapped, we can ignore all but the first - // consumer given they have to be replayed in the same exact way // Multiple outputs are already mapped, we can ignore all but the first // consumer given they have to be replayed in the same exact way TensorView* c_tv = ir_utils::getTvOutput(expr); @@ -797,7 +795,6 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } - // TODO: Should this just get rolled up in the forwarding map now? for (const auto& entry : permissive_forwarding.producer_compliment_map) { for (auto entry_2 : entry.second) { if (getenv("COMP")) { @@ -810,8 +807,6 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } - // TODO: Should this just get rolled up in the forwarding map now? - // TODO: Why should IDs be mapped to their compliments? Is this right? for (const auto& entry : permissive_forwarding.consumer_compliment_map) { for (auto entry_2 : entry.second) { if (getenv("COMP")) { @@ -1127,7 +1122,7 @@ void IdModel::build( */ } - // Debug, make sure there's no self mapping in TensorView's during lowering + // Make sure there's no self mapping in TensorView's during lowering // that would invalidate lowering assumptions. self_mapping_info_ = findFirstSelfMapping(all_tvs.vector(), *this); } diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 7d6c1ee0e68..f9f39642c25 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -66,10 +66,6 @@ struct StatefulLoweringInfo; // IdMappingMode::EXACT // Don't map any broadcast axes to non-broadcast axes // Do not forward through any broadcast IDs -// IdMappingMode::LOOP -// Forward broadcast axes in replay -// Denotes groups of IterDomains that are considered promoted to a common iter -// domain size // IdMappingMode::PERMISSIVE // Forward broadcast axes in replay // Map all iteration domains @@ -81,6 +77,10 @@ struct StatefulLoweringInfo; // id{i1*i0}, id{i0} are not mapped (this part is the difference from // PERMISSIVE) // Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped +// IdMappingMode::LOOP +// Forward broadcast axes in replay +// Denotes groups of IterDomains that are considered promoted to a common iter +// domain size // class IdModel : public PolymorphicBase { public: @@ -203,8 +203,8 @@ class IdModel : public PolymorphicBase { // split by a size-1 dimension. void buildAlmostExactMap(); - // Fills disjoint_ids_[IdMappingMode::PERMISSIVE]. Initialize PermissiveMap as - // AlmostExact entries, then map through broadcasts + // Fills disjoint_ids_[IdMappingMode::PERMISSIVE]. Initialize it as + // Exact entries, then map through broadcasts void buildPermissiveMap(const std::vector& exprs); // Make sure only leaf nodes of tensor views are parallelized diff --git a/csrc/id_model/validation_utils.cpp b/csrc/id_model/validation_utils.cpp index 7608b2f91f1..f9e80056535 100644 --- a/csrc/id_model/validation_utils.cpp +++ b/csrc/id_model/validation_utils.cpp @@ -35,8 +35,9 @@ bool exprsMap( } NVF_ERROR( - first->isA() || first->isA() || first->isA(), - "Merge, split and resize are the only expressions supported through rfactor operations in compute at map, but found:\n", + first->isA() || first->isA() || first->isA() || + first->isA(), + "Merge, split, resize and swizzle are the only expressions supported here, but found:\n", first->toString()); auto first_ids = ir_utils::filterByType( @@ -104,6 +105,14 @@ bool exprsMap( } } + if (first->isA()) { + auto swizzle_1 = first->as(); + auto swizzle_2 = first->as(); + if (swizzle_1->swizzleType() != swizzle_2->swizzleType()) { + return false; + } + } + return true; } diff --git a/csrc/optimization/optimize_layout.cpp b/csrc/optimization/mark_aliases_prepare.cpp similarity index 93% rename from csrc/optimization/optimize_layout.cpp rename to csrc/optimization/mark_aliases_prepare.cpp index d0dd256c8a8..9386a8830bc 100644 --- a/csrc/optimization/optimize_layout.cpp +++ b/csrc/optimization/mark_aliases_prepare.cpp @@ -8,16 +8,16 @@ #include #include #include -#include +#include #include namespace nvfuser::optimization { -void OptimizeLayoutPass::runPass(Fusion* fusion) { +void MarkAliasesPreparePass::runPass(Fusion* fusion) { const AliasAnalysisResult analysis = findAliases(fusion, /*can_override_empty_allocation_domain=*/true); if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { - debug() << "Fusion before OptimizeLayoutPass:" << std::endl; + debug() << "Fusion before MarkAliasesPreparePass:" << std::endl; fusion->printMath(); debug() << "Alias analysis result:" << std::endl; debug() << analysis.toString(/*indent_size=*/1) << std::endl; @@ -105,7 +105,7 @@ void OptimizeLayoutPass::runPass(Fusion* fusion) { } if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { - debug() << "Fusion after OptimizeLayoutPass:" << std::endl; + debug() << "Fusion after MarkAliasesPreparePass:" << std::endl; fusion->printMath(); fusion->printTransforms(); } diff --git a/csrc/optimization/optimize_layout.h b/csrc/optimization/mark_aliases_prepare.h similarity index 53% rename from csrc/optimization/optimize_layout.h rename to csrc/optimization/mark_aliases_prepare.h index fb5245ade2e..c4c4a82cc6e 100644 --- a/csrc/optimization/optimize_layout.h +++ b/csrc/optimization/mark_aliases_prepare.h @@ -9,10 +9,11 @@ namespace nvfuser::optimization { -// Updates layouts to enable aliases. -// TODO(wujingyue): Rename. It also inserts segment_set to help segmentation. -class OptimizeLayoutPass : public OptimizationPass { - friend class OptimizationPass; +// Prepares the input fusion for marking aliases. It currently updates layouts +// to enable aliases, and inserts `SegmenterSet`s so segmentation will separate +// out alias-only regions. +class MarkAliasesPreparePass : public OptimizationPass { + friend class OptimizationPass; protected: static void runPass(Fusion* fusion); diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index f124b3a1a6f..9d4c5c5347c 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include namespace nvfuser::optimization { @@ -20,7 +20,7 @@ void PreSegmenter::runPass(Fusion* fusion) { // removes consecutive cast operations OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); } } // namespace nvfuser::optimization diff --git a/csrc/val_graph.h b/csrc/val_graph.h index f84f01e61db..38d587abc46 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -266,8 +266,6 @@ class ValGraph { if (extent_match( outer0->template as(), outer1->template as())) { - // std::cerr << "outer are equal: " << outer0->name() << ", " << - // outer1->name() << std::endl; return true; } } @@ -288,8 +286,6 @@ class ValGraph { if (extent_match( inner0->template as(), inner1->template as())) { - // std::cerr << "inner are equal: " << inner0->name() << ", " << - // inner1->name() << std::endl; return true; } } diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index 298bf0cb3d4..9383a4c4746 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -2995,9 +2995,9 @@ def fusion_func(fd: FusionDefinition) -> None: torch_ref = inputs[0] * (inputs[1] * inputs[2]).unsqueeze(-1) self.assertEqual(nvf_out[0], torch_ref) - # This tests no dead code at definition does not cause a problem due to - # removal of empty tensors - # See https://github.com/NVIDIA/Fuser/pull/1270 + # Test that expand+pad does not cause indexing error, and that no scalars + # are lost during segmentation. + # See https://github.com/NVIDIA/Fuser/issues/1277 def test_issue1277(self): inputs = [ 0.5, diff --git a/test/test_gpu_transpose.cpp b/test/test_gpu_transpose.cpp index dbe98582be6..c54e2cbba74 100644 --- a/test/test_gpu_transpose.cpp +++ b/test/test_gpu_transpose.cpp @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include #include #include @@ -44,15 +44,15 @@ class TransposeTest : public NVFuserTest { protected: void SetUp() override { NVFuserTest::SetUp(); - previously_enabled_ = optimization::OptimizeLayoutPass::getEnabled(); - // For convenience, disable OptimizeLayoutPass. Many tests in this file run - // a fusion that consists of `transpose` only. OptimizeLayoutPass would turn - // those fusions into a no-op, skipping the transpose scheduler. - optimization::OptimizeLayoutPass::setEnabled(false); + previously_enabled_ = optimization::MarkAliasesPreparePass::getEnabled(); + // For convenience, disable MarkAliasesPreparePass. Many tests in this file + // run a fusion that consists of `transpose` only. MarkAliasesPreparePass + // would turn those fusions into a no-op, skipping the transpose scheduler. + optimization::MarkAliasesPreparePass::setEnabled(false); } void TearDown() override { - optimization::OptimizeLayoutPass::setEnabled(previously_enabled_); + optimization::MarkAliasesPreparePass::setEnabled(previously_enabled_); NVFuserTest::TearDown(); } From f9c9d37e96d9d562c7713be94132f80b098865fa Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 29 Dec 2023 12:45:14 -0800 Subject: [PATCH 105/178] cleanup --- csrc/id_model/id_model.cpp | 16 ---------------- csrc/id_model/id_model.h | 19 +++---------------- 2 files changed, 3 insertions(+), 32 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index c046112d0d7..4c000502aaa 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -115,22 +115,6 @@ ValGraph& IdModel::idGraph(IdMappingMode mode) { return graph_it->second; } -Expr* IdModel::idUse(IterDomain* id) const { - auto use_it = id_uses_.find(id); - if (use_it == id_uses_.end()) { - return nullptr; - } - return use_it->second.front(); -} - -Expr* IdModel::idDef(IterDomain* id) const { - auto def_it = id_definitions_.find(id); - if (def_it == id_definitions_.end()) { - return nullptr; - } - return def_it->second.front(); -} - namespace { // Returns the first pair of id's in ids detected to match each other on the diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index f9f39642c25..d64768dfa2f 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -104,19 +104,6 @@ class IdModel : public PolymorphicBase { const ValGraph& idGraph(IdMappingMode mode) const; ValGraph& idGraph(IdMappingMode mode); - // IterDomains from the original fusion are only allowed to be used once in - // the IterDomain graph, id->uses() are not directly used as there's no bounds - // check that would prevent a use from being defined that's not part of the - // actual fusion definition. - // - // Note, any iter domains used during something like loop or concrete id - // resolution could actually have multiple Expr* uses, and uses on disjoint id - // sets should be used, not this. - // - // TODO: Refactor or remove? - Expr* idUse(IterDomain* id) const; - Expr* idDef(IterDomain* id) const; - // TODO: Seems a bit unfortunate that this isn't IterDomain local information. const std::unordered_set& viewRfactorIds() const { return view_rfactor_ids_; @@ -190,8 +177,8 @@ class IdModel : public PolymorphicBase { void buildIterDomainDefinitionsAndUses( const std::vector& all_tvs); - // Iterates over all IterDomains in id_definitions_ and calls initializeID on - // a new IdGraph and returns it. + // Iterates over all IterDomains in id_definitions_ and calls initializeVal on + // a new ValGraph and returns it. ValGraph initializeIdGraph(bool propagate_through_exprs = true); // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs @@ -272,7 +259,7 @@ class IdModel : public PolymorphicBase { // Errors if self mapping occurs void assertNoSelfMapping(); - // Keeps a disjoint set entry for all IterDomain for all mapping mode types. + // Keeps ValGraphs containing all IterDomains for all mapping mode types. // // Using an array here might be nice, but it seems hard to use an enum as an // array key From 6fbebc9be2b7a0178fb89c42ec3ac7fcdd7715c9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 29 Dec 2023 21:56:39 -0800 Subject: [PATCH 106/178] cleanup --- csrc/id_model/id_model.cpp | 423 ++++++++++++++++++------------------- csrc/id_model/id_model.h | 72 ++++--- csrc/val_graph.cpp | 46 ++-- csrc/val_graph.h | 4 +- 4 files changed, 275 insertions(+), 270 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 4c000502aaa..46599390557 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -809,82 +809,16 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { } } -// TODO: Reenable after reenabling parallel propagation. -// propagateLoopPTypes -void IdModel::validatePTypes(const std::vector& all_tvs) const { - // VectorOfUniqueEntries leaf_ids; - // for (auto tv : all_tvs) { - // leaf_ids.pushBack(tv->domain()->leaf()); - // } - - // for (const auto& disjoint_set : - // idGraph(IdMappingMode::EXACT).disjointValSets().disjointSets()) { - // for (auto id : disjoint_set->vector()) { - // auto id_ptype = id->getParallelType(); - - // NVF_ERROR( - // leaf_ids.has(id) || id_ptype == ParallelType::Serial, - // "Invalid parallelization of non leaf iter domain: ", - // id->toString()); - // } - // } -} - -void IdModel::propagateLoopPTypes() const { - for (const auto& loop_disjoint_set : - idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { - ParallelType common_ptype = ParallelType::Serial; - for (auto id : loop_disjoint_set->vector()) { - auto id_ptype = id->as()->getParallelType(); - - NVF_ERROR( - id_ptype == common_ptype || id_ptype == ParallelType::Serial || - common_ptype == ParallelType::Serial, - "Issue validating parallel type disjoint ptype is, ", - common_ptype, - " but found in the set the id: ", - id->toString()); - - common_ptype = - common_ptype == ParallelType::Serial ? id_ptype : common_ptype; - } - - for (auto id : loop_disjoint_set->vector()) { - id->as()->parallelize(common_ptype); - } - } -} - namespace { -struct StatefulLoweringInfo { - // Tracks all p2c mappings in permissive maps even those not inlined between - // producer and consumer - std::unordered_map> - p2c_permissive_maps; - - // All consumer ids in a deterministic order (ignores fusion->inputs()) - VectorOfUniqueEntries ordered_c_ids; - - // p2c mappings through the fusion within (including dependencies of) inlined - // leaf domains. - std::unordered_map> - p2c_ca_permissive_maps; - - // All producer ids within (including dependencies of) inlined leaf domains, - // used for deterministic order - VectorOfUniqueEntries ordered_p_ca_ids; - - std::unordered_map> - p2c_root_broadcast_resolution_map; -}; // Returns the root producer iteration domains that are resolved by provided // consumer std::unordered_map resolvedRootBroadcasts( TensorView* producer, TensorView* consumer) { - auto p2c_map = - PairwiseRootDomainMap(producer, consumer).mapProducerToConsumer(); + auto p2c_map = PairwiseRootDomainMap(producer, consumer) + .mapBroadcast(true) + .mapProducerToConsumer(); std::unordered_map resolved_bcast_map; for (const auto& [p_id, c_id] : p2c_map) { @@ -908,84 +842,52 @@ std::unordered_map resolvedRootBroadcasts( return resolved_bcast_map; } -StatefulLoweringInfo buildInfo( +// Grab inlining relationships +StatefulInliningInfo buildStatefulInliningInfo( const std::vector& exprs, const ValGraph& exact_graph, const ValGraph& permissive_graph) { - StatefulLoweringInfo info; - // Grab inlining relationships + StatefulInliningInfo info; for (auto expr : exprs) { - for (auto producer : ir_utils::filterByType(expr->inputs())) { - auto producer_root = producer->getMaybeRFactorDomain(); - auto producer_domain = producer->domain()->leaf(); + for (auto producer_tv : + ir_utils::filterByType(expr->inputs())) { + const auto& producer_root = producer_tv->getMaybeRFactorDomain(); + const auto& producer_domain = producer_tv->domain()->leaf(); // Grab all iteration domains in producer that its compute at iter domains // depend on. - VectorOfUniqueEntries all_producer_ca_deps; - { - auto ca_dep_vals = DependencyCheck::getAllValsBetween( - {producer_root.begin(), producer_root.end()}, - {producer_domain.begin(), - producer_domain.begin() + producer->getComputeAtPosition()}); - auto ca_deps_filter = ir_utils::filterByType(ca_dep_vals); - - all_producer_ca_deps.insert( - ca_deps_filter.begin(), ca_deps_filter.end()); - } + auto ca_dep_vals = DependencyCheck::getAllValsBetween( + {producer_root.begin(), producer_root.end()}, + {producer_domain.begin(), + producer_domain.begin() + producer_tv->getComputeAtPosition()}); + auto ca_deps_filter = ir_utils::filterByType(ca_dep_vals); + VectorOfUniqueEntries all_producer_ca_deps( + ca_deps_filter.begin(), ca_deps_filter.end()); info.ordered_p_ca_ids.pushBack(all_producer_ca_deps); - for (auto consumer : - ir_utils::filterByType(expr->outputs())) { - auto resolved_bcast_map = resolvedRootBroadcasts(producer, consumer); - - for (const auto& [p_id, c_id] : resolved_bcast_map) { - info.p2c_root_broadcast_resolution_map[p_id].pushBack(c_id); - for (auto other_exact_bcast : *(exact_graph.toGroup(p_id))) { - if (p_id == other_exact_bcast) { - continue; - } - if (all_producer_ca_deps.has(other_exact_bcast->as())) { - // TODO-NM: Why is this here? Can be removed? - NVF_ERROR( - false, - "Can this happen? Adding other exact: ", - other_exact_bcast->name(), - " in addition to ", - p_id->name(), - " of ", - producer->toString()); - info.p2c_root_broadcast_resolution_map[other_exact_bcast - ->as()] - .pushBack(c_id); - } - } - } - auto all_producer_ids = ir_utils::allIDsOf(producer); - auto all_consumer_ids = ir_utils::allIDsOf(consumer); - info.ordered_c_ids.pushBack(all_consumer_ids); + // Gather info on and producer-consumer + // mappings of CA domains and broadcast resolution + for (auto consumer_tv : + ir_utils::filterByType(expr->outputs())) { + auto all_producer_ids = ir_utils::allIDsOf(producer_tv); + auto all_consumer_ids = ir_utils::allIDsOf(consumer_tv); auto p2c_permissive_map = permissive_graph.buildMapBetween( all_producer_ids, all_consumer_ids); - for (const auto& entry : p2c_permissive_map) { - if (entry.second.empty()) { - continue; - } - if (all_producer_ca_deps.has(entry.first->as())) { - info.p2c_ca_permissive_maps[entry.first->as()].pushBack( - entry.second); + for (const auto& [p_id, c_ids] : p2c_permissive_map) { + if (!c_ids.empty() && + all_producer_ca_deps.has(p_id->as())) { + info.p2c_ca_permissive_maps[p_id->as()].pushBack(c_ids); } - info.p2c_permissive_maps[entry.first->as()].pushBack( - entry.second); } - for (const auto& entry : p2c_permissive_map) { - if (entry.second.empty()) { - continue; - } - info.p2c_permissive_maps[entry.first->as()].pushBack( - entry.second); + std::unordered_map resolved_bcast_map = + resolvedRootBroadcasts(producer_tv, consumer_tv); + + for (const auto& [p_root_id, c_root_id] : resolved_bcast_map) { + info.p2c_root_broadcast_resolution_map[p_root_id].pushBack(c_root_id); } } } @@ -995,6 +897,87 @@ StatefulLoweringInfo buildInfo( } // namespace +void IdModel::buildLoopMap(const std::vector& exprs) { + const StatefulInliningInfo info = buildStatefulInliningInfo( + exprs, idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::PERMISSIVE)); + + initializeLoopMap(info); + + // Initial propagation of parallel types for inlined iter domains. Each time + // new expressions are replayed this needs to be run. The disjoint sets in + // the loop graph can only be joined after this point. + // propagateLoopPTypes(); + + auto iel_promotion_map = buildInlinePromotions(info); + // propagateLoopPTypes(); + + // Find loops that need to be promoted because of broadcast resolution, + // figure out what that resolution should look like, compute IDs for it if + // necessary. + iel_promotion_map = buildLoopPromotionMap(exprs, info, iel_promotion_map); + // Loop map potentialy changed changed, as we could have replayed + // expressions. Re-propagate parallel types. + // propagateLoopPTypes(); + + // This pass still doesn't work, disable for now in case it's disruptive to + // tests. + /* + // Find loops that need to be promoted because of broadcast resolution, + // figure out what that resolution should look like, compute IDs for it if + // necessary. + auto leaf_id_promo_map = + buildIndexGraph(tv_exprs, all_tvs, info, iel_promotion_map); + // Make sure we update ptypes onto the index leaf iter domains + propagateLoopPTypes(); + */ +} + +// TODO: Reenable after reenabling parallel propagation. +// propagateLoopPTypes +void IdModel::validatePTypes(const std::vector& all_tvs) const { + // VectorOfUniqueEntries leaf_ids; + // for (auto tv : all_tvs) { + // leaf_ids.pushBack(tv->domain()->leaf()); + // } + + // for (const auto& disjoint_set : + // idGraph(IdMappingMode::EXACT).disjointValSets().disjointSets()) { + // for (auto id : disjoint_set->vector()) { + // auto id_ptype = id->getParallelType(); + + // NVF_ERROR( + // leaf_ids.has(id) || id_ptype == ParallelType::Serial, + // "Invalid parallelization of non leaf iter domain: ", + // id->toString()); + // } + // } +} + +void IdModel::propagateLoopPTypes() const { + for (const auto& loop_disjoint_set : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { + ParallelType common_ptype = ParallelType::Serial; + for (auto id : loop_disjoint_set->vector()) { + auto id_ptype = id->as()->getParallelType(); + + NVF_ERROR( + id_ptype == common_ptype || id_ptype == ParallelType::Serial || + common_ptype == ParallelType::Serial, + "Issue validating parallel type disjoint ptype is, ", + common_ptype, + " but found in the set the id: ", + id->toString()); + + common_ptype = + common_ptype == ParallelType::Serial ? id_ptype : common_ptype; + } + + for (auto id : loop_disjoint_set->vector()) { + id->as()->parallelize(common_ptype); + } + } +} + void IdModel::build( const std::vector& exprs, const std::vector& additional_tvs, @@ -1064,47 +1047,7 @@ void IdModel::build( // from the almost exact graph. idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); - // Only build loop map during lowering - // TODO: make this configurable - if (true || FusionGuard::getCurFusion()->isA()) { - validatePTypes(all_tvs.vector()); - - StatefulLoweringInfo info = buildInfo( - tv_exprs, - idGraph(IdMappingMode::EXACT), - idGraph(IdMappingMode::PERMISSIVE)); - - initializeLoopMap(info); - - // Initial propagation of parallel types for inlined iter domains. Each time - // new expressions are replayed this needs to be run. The disjoint sets in - // the loop graph can only be joined after this point. - // propagateLoopPTypes(); - - auto iel_promotion_map = buildInlinePromotions(info); - // propagateLoopPTypes(); - - // Find loops that need to be promoted because of broadcast resolution, - // figure out what that resolution should look like, compute IDs for it if - // necessary. - iel_promotion_map = - buildLoopPromotionMap(tv_exprs, info, iel_promotion_map); - // Loop map potentialy changed changed, as we could have replayed - // expressions. Re-propagate parallel types. - // propagateLoopPTypes(); - - // This pass still doesn't work, disable for now in case it's disruptive to - // tests. - /* - // Find loops that need to be promoted because of broadcast resolution, - // figure out what that resolution should look like, compute IDs for it if - // necessary. - auto leaf_id_promo_map = - buildIndexGraph(tv_exprs, all_tvs, info, iel_promotion_map); - // Make sure we update ptypes onto the index leaf iter domains - propagateLoopPTypes(); - */ - } + buildLoopMap(tv_exprs); // Make sure there's no self mapping in TensorView's during lowering // that would invalidate lowering assumptions. @@ -1112,7 +1055,7 @@ void IdModel::build( } VectorOfUniqueEntries IdModel::computeTerminalLoopIds( - const StatefulLoweringInfo info) { + const StatefulInliningInfo info) { VectorOfUniqueEntries terminal_loop_ids; for (const ValGroup& group : idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { @@ -1157,15 +1100,15 @@ ValGraph IdModel::buildIntersection( const ValGraph& graph0, const ValGraph& graph1, bool propagate_exprs) { - auto intersection = initializeIdGraph(propagate_exprs); - for (const auto& group0 : graph0.disjointValSets().disjointSets()) { + ValGraph intersection = initializeIdGraph(propagate_exprs); + for (const ValGroup& group0 : graph0.disjointValSets().disjointSets()) { auto set_size = group0->size(); for (auto id0_i : c10::irange(set_size)) { - auto id0 = group0->vector()[id0_i]; + Val* id0 = group0->vector()[id0_i]; for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { - auto id1 = group0->vector()[id1_i]; + Val* id1 = group0->vector()[id1_i]; // id0 and id1 map in group0. If they also map in the group1, - // add the mapping to the inersection. + // add the mapping to the intersection. if (graph1.disjointValSets().strictAreMapped(id0, id1)) { intersection.mapVals(id0, id1); } @@ -1175,7 +1118,7 @@ ValGraph IdModel::buildIntersection( return intersection; } -void IdModel::initializeLoopMap(StatefulLoweringInfo& info) { +void IdModel::initializeLoopMap(const StatefulInliningInfo& info) { // See Indexing20 example for why we shouldn't propagate when generating loop // groups idGraph(IdMappingMode::LOOP) = initializeIdGraph(false); @@ -1193,8 +1136,9 @@ void IdModel::initializeLoopMap(StatefulLoweringInfo& info) { } } -std::unordered_map IdModel::buildInlinePromotions( - StatefulLoweringInfo& info) { +std::unordered_map IdModel::buildInlineRootPromotions( + const ValGraph& iel_graph, + const StatefulInliningInfo& info) { // Make an intersection of the exact and loop map. This will group together // entries in each loop group that are exact with each other. This provides a // better graph to do promotion and replays. @@ -1219,9 +1163,6 @@ std::unordered_map IdModel::buildInlinePromotions( // smaller groups and this algorithm scales with the number of groups * // (number of entries in groups ^ 2) - ValGraph intersection_exact_loop_graph = buildIntersection( - idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); - // Promotion logic is going to be on the intersection of the exact and loop // graph. We will generate a map on the entries of this graph so it's // important to not modify this graph moving forward, as that would invalidate @@ -1255,11 +1196,12 @@ std::unordered_map IdModel::buildInlinePromotions( // Note again this process is only done for root domains. Once we // find promotion relationships for root domains, we propagate the // mappings to derived domains - for (const ValGroup& iel_group : - intersection_exact_loop_graph.disjointValSets().disjointSets()) { + for (const ValGroup& iel_group : iel_graph.disjointValSets().disjointSets()) { NVF_ERROR(!iel_group->empty()); - if (!iel_group->front()->as()->isBroadcast()) { + IterDomain* iel_group_id = iel_group->front()->as(); + + if (!iel_group_id->isBroadcast()) { continue; } @@ -1277,54 +1219,64 @@ std::unordered_map IdModel::buildInlinePromotions( } } + if (resolved_exact_groups.empty()) { + // No resolution + continue; + } + // Collect all the exact groups in the loop set containing this iel_group - auto loop_group = idGraph(IdMappingMode::LOOP).toGroup(iel_group->front()); - auto loop_covered_exact_groups = + const ValGroup& loop_group = + idGraph(IdMappingMode::LOOP).toGroup(iel_group_id); + ValGroups loop_covered_exact_groups = idGraph(IdMappingMode::EXACT).toGroups(*loop_group); // The intersection of the exact groups that the broadcast domains can be // broadcasted to, and those that exist within the same loop groop are is - // the promotion needed for this iel_group. + // the promotion needed for this iel_group. The promotion should + // be none or unique. ValGroups loop_exact_resolved_intersection = resolved_exact_groups.computeIntersect(loop_covered_exact_groups); if (loop_exact_resolved_intersection.empty()) { - // No resolution + // No promotion continue; } if (loop_exact_resolved_intersection.size() > 1) { + // Ambiguous promotion. This should not happen. std::stringstream err_msg; - err_msg << "Invalid multiple broadcast resolution within shared loops detected, group:\n " << iel_group->toString() << "\nIs being broadcasted to:"; - for (const ValGroup& entry : loop_exact_resolved_intersection) { err_msg << "\n " << entry->toString(); } NVF_ERROR(false, err_msg.str()); } - // loop_exact_resolved_intersection.size() must be 1 at this point - ValGroup exact_resolution_group = loop_exact_resolved_intersection.front(); + const ValGroup& exact_resolution_group = + loop_exact_resolved_intersection.front(); + // Within the loop group, find the IDs that the broadcast IDs are + // resolved to VectorOfUniqueEntries resolved_ids = exact_resolution_group->computeIntersect(*loop_group); - auto promoted_iel_groups = - intersection_exact_loop_graph.toGroups(resolved_ids); - if (promoted_iel_groups.empty()) { - continue; - } + NVF_ERROR(!resolved_ids.empty()); + + // All the IDs in resolved_ids are mapped with both of the exact + // and loop graphs, so any of them can be used as an IEL promotion + // ID. Just to make it extra clear, look for corresponding + // groups in the IEL graph and make sure there's only one such group. + ValGroups promoted_iel_groups = iel_graph.toGroups(resolved_ids); + + NVF_ERROR(!promoted_iel_groups.empty()); if (promoted_iel_groups.size() > 1) { std::stringstream err_msg; - err_msg << "Invalid multiple broadcast resolution within shared loops detected, group:\n " << iel_group->toString() << "\nIs being broadcasted to:"; - for (const ValGroup& entry : promoted_iel_groups) { err_msg << "\n " << entry->toString(); } @@ -1335,6 +1287,54 @@ std::unordered_map IdModel::buildInlinePromotions( promoted_iel_groups.front()->front()->as(); } + return iel_promotion_map; +} + +std::unordered_map IdModel::buildInlinePromotions( + const StatefulInliningInfo& info) { + // Make an intersection of the exact and loop map. This will group together + // entries in each loop group that are exact with each other. This provides a + // better graph to do promotion and replays. + + // It's tempting to use the intersection of the almost exact and loop, but we + // need to model broadcast promotion, and if we have two tensors like: + // + // T1[i0, b1] = T0[i0] + // T2[i0, b2] = T0[i0] + // Then resolution of: + // T4 = T1[i0, b1] + T3[i0, i1] + // T6 = T2[i0, b2] + T5[i0, i2] + // + // Then merge(0, 1) with all tensors except for T0 + // + // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 + // are being resolved to i1 and i2 respectively. So we want to have separate + // entries so we can have an easy to process promotion map. + // + // Loop is a permissive like map, it could have many entries, use the exact + // map as the one we iterate on to reduce complexity as it hopefully has + // smaller groups and this algorithm scales with the number of groups * + // (number of entries in groups ^ 2) + + ValGraph iel_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + + // Promotion logic is going to be on the intersection of the exact and loop + // graph. We will generate a map on the entries of this graph so it's + // important to not modify this graph moving forward, as that would invalidate + // the map. + // + // iel stands for Intersection of the Exact and Loop graphs. + std::unordered_map iel_promotion_map = + buildInlineRootPromotions(iel_graph, info); + + // This should probably work just on terminating inputs, as we shouldn't be + // able to modify a broadcast domain between root and rfactor which would be + // required to resolve a non input broadcast domain. But for now leaving it as + // traversal on all broadcast groups. + // + // TODO-NM: The ordering appears to be non-deterministic + // Propagate promotion mappings from root domains to derived domains // by traversing IEL exprs. For each expr, if an input is promoted, // the output needs to be promoted too. If there's already a domain @@ -1345,12 +1345,11 @@ std::unordered_map IdModel::buildInlinePromotions( // In order to make // this traversal work, the traversal order must be toplogically // sorted. - IdGraphStmtSort iel_stmt_sort(intersection_exact_loop_graph); + IdGraphStmtSort iel_stmt_sort(iel_graph); for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) { NVF_ERROR(!iel_expr->empty()); - std::vector input_groups = - intersection_exact_loop_graph.inputGroups(iel_expr); + std::vector input_groups = iel_graph.inputGroups(iel_expr); // Check if any inputs need promotion indicating this expr group needs to // be replayed with promoted inputs @@ -1374,9 +1373,8 @@ std::unordered_map IdModel::buildInlinePromotions( ValGroups promoted_input_groups; for (auto inp_id : promoted_inputs) { - if (intersection_exact_loop_graph.hasGroup(inp_id)) { - promoted_input_groups.pushBack( - intersection_exact_loop_graph.toGroup(inp_id)); + if (iel_graph.hasGroup(inp_id)) { + promoted_input_groups.pushBack(iel_graph.toGroup(inp_id)); } } @@ -1404,7 +1402,7 @@ std::unordered_map IdModel::buildInlinePromotions( ExprGroups non_promoted_input_uses; for (const ValGroup& iel_group : promoted_input_groups.computeIntersect(input_groups)) { - const ExprGroups* uses = intersection_exact_loop_graph.getUses(iel_group); + const ExprGroups* uses = iel_graph.getUses(iel_group); NVF_ERROR(uses); non_promoted_input_uses.pushBack(*uses); } @@ -1427,7 +1425,7 @@ std::unordered_map IdModel::buildInlinePromotions( bool inps_match = true; for (auto inp_i : c10::irange(use_inps.size())) { inps_match = inps_match && - intersection_exact_loop_graph.disjointValSets().strictAreMapped( + iel_graph.disjointValSets().strictAreMapped( use_inps[inp_i], promoted_inputs[inp_i]); } if (inps_match) { @@ -1442,8 +1440,7 @@ std::unordered_map IdModel::buildInlinePromotions( replay = addReplayAs(promoted_inputs, iel_expr->front()); } - std::vector out_groups = - intersection_exact_loop_graph.outputGroups(iel_expr); + std::vector out_groups = iel_graph.outputGroups(iel_expr); // Mark outputs as having a promoted iter domain auto replay_out_ids = @@ -1556,7 +1553,7 @@ std::unordered_map computeCoveredGroups( std::unordered_map IdModel::buildLoopPromotionMap( const std::vector& exprs, - StatefulLoweringInfo& info, + const StatefulInliningInfo& info, const std::unordered_map& stale_promotion_map) { // Non-ca domains may also need to be promoted if parent domains are // promoted. @@ -2006,7 +2003,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( std::unordered_map IdModel::buildIndexGraph( const std::vector& exprs, const std::vector& all_tvs, - StatefulLoweringInfo& info, + StatefulInliningInfo& info, std::unordered_map stale_promotion_map) { NVF_ERROR(false, "Not implemented yet."); } diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index d64768dfa2f..2c66d0b3284 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -21,11 +21,20 @@ namespace nvfuser { class ValGraph; -namespace { -// Convenience to store some intermediate data across a few lowering build -// passes. -struct StatefulLoweringInfo; -} // namespace +struct StatefulInliningInfo { + // All producer ids within (including dependencies of) inlined leaf domains, + // used for deterministic order + VectorOfUniqueEntries ordered_p_ca_ids; + + // Broadcast resolution map for root domains + std::unordered_map> + p2c_root_broadcast_resolution_map; + + // p2c mappings through the fusion within (including dependencies of) inlined + // leaf domains. + std::unordered_map> + p2c_ca_permissive_maps; +}; // A collection of ValGraphs that are built from a fusion or series of // expressions. These graphs are related, but have some distinct features based @@ -160,8 +169,7 @@ class IdModel : public PolymorphicBase { return loop_promotion_map_; } - // TODO: Should this not be private? - protected: + private: // Sometimes fusion inputs or outputs are disconnected from expressions, in // those cases we still may want to send in some additional tensor views from // the Fusion that don't have expressions associated with them. @@ -194,6 +202,31 @@ class IdModel : public PolymorphicBase { // Exact entries, then map through broadcasts void buildPermissiveMap(const std::vector& exprs); + // Fills disjoint_ids_[IdMappingMode::LOOP]. Map only inlined + // domains that are mapped in the permissive graph + void buildLoopMap(const std::vector& exprs); + + // Start loop map by grouping inlined iter domains + void initializeLoopMap(const StatefulInliningInfo& info); + + std::unordered_map buildInlineRootPromotions( + const ValGraph& iel_graph, + const StatefulInliningInfo& info); + + // Returns map of ValGroups in the loop map to a representative IterDomain + // that contains all resolved transformations that the terminal IterDomains + // should be promoted to. The returned promotions are valid only for inlined + // iter domains. + std::unordered_map buildInlinePromotions( + const StatefulInliningInfo& info); + + // Returns a similar thing to buildInlinePromotions but also includes iter + // domains that are not inlined. + std::unordered_map buildLoopPromotionMap( + const std::vector& exprs, + const StatefulInliningInfo& info, + const std::unordered_map& stale_promotion_map); + // Make sure only leaf nodes of tensor views are parallelized void validatePTypes(const std::vector& all_tvs) const; @@ -210,7 +243,7 @@ class IdModel : public PolymorphicBase { // is also in the same loop group // 2) Don't have a direct IterDomain consumer within the group VectorOfUniqueEntries computeTerminalLoopIds( - const StatefulLoweringInfo info); + const StatefulInliningInfo info); // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and // graph1. @@ -221,30 +254,13 @@ class IdModel : public PolymorphicBase { // !! END Helper functions to build loop promotion and index map!! - // Start loop map by grouping inlined iter domains - void initializeLoopMap(StatefulLoweringInfo& info); - - // Returns map of ValGroups in the loop map to a representative IterDomain - // that contains all resolved transformations that the terminal IterDomains - // should be promoted to. The returned promotions are valid only for inlined - // iter domains. - std::unordered_map buildInlinePromotions( - StatefulLoweringInfo& info); - - // Returns a similar thing to buildInlinePromotions but also includes iter - // domains that are not inlined. - std::unordered_map buildLoopPromotionMap( - const std::vector& exprs, - StatefulLoweringInfo& info, - const std::unordered_map& stale_promotion_map); - // Builds idGraph(IdMappingMode::INDEX) and returns the iter domain promotion // map to go from leaf domains of each (consumer only?) tensor to their // corresponding leaf domain in the index graph. std::unordered_map buildIndexGraph( const std::vector& exprs, const std::vector& all_tvs, - StatefulLoweringInfo& info, + StatefulInliningInfo& info, std::unordered_map stale_promotion_map); // Returns the terminal rfactor or input iter domains each group in the almost @@ -259,6 +275,7 @@ class IdModel : public PolymorphicBase { // Errors if self mapping occurs void assertNoSelfMapping(); + private: // Keeps ValGraphs containing all IterDomains for all mapping mode types. // // Using an array here might be nice, but it seems hard to use an enum as an @@ -282,6 +299,9 @@ class IdModel : public PolymorphicBase { std::optional> self_mapping_info_ = std::nullopt; + // Loop promotion map for inlined root broadcast domains + std::unordered_map iel_root_promotion_map_; + // Promotion domain for each loop group std::unordered_map loop_promotion_map_; diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index c06571ba98c..666568067a7 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -442,48 +442,36 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) std::unordered_map> ValGraph::buildMapBetween( const std::vector& from, const std::vector& to) const { - std::unordered_map from_ids2set; + // Map from the sets associated with the Vals in to, to those Vals + std::unordered_map> set2to_vals; - for (auto from_id : from) { - if (!hasGroup(from_id)) { + for (auto to_val : to) { + if (!hasGroup(to_val)) { continue; } - from_ids2set[from_id] = toGroup(from_id); + const auto& to_set = toGroup(to_val); + set2to_vals[to_set].pushBack(to_val); } - // Map from the sets associated with the IterDomains in to, to those iter - // domains - std::unordered_map> set2to_ids; + std::unordered_map> from_vals2to_vals; + for (auto from_val : from) { + // Initialize in case no to val is mapped + from_vals2to_vals[from_val] = VectorOfUniqueEntries(); - for (auto to_id : to) { - if (!hasGroup(to_id)) { + if (!hasGroup(from_val)) { continue; } - auto to_set = toGroup(to_id); - auto set2to_ids_it = set2to_ids.find(to_set); - if (set2to_ids_it == set2to_ids.end()) { - set2to_ids[to_set] = {to_id}; - } else { - set2to_ids[to_set].pushBack(to_id); - } - } + const ValGroup& from_set = toGroup(from_val); - std::unordered_map> from_ids2to_ids; - for (auto from_id : from) { - from_ids2to_ids[from_id] = VectorOfUniqueEntries(); - - auto from_it = from_ids2set.find(from_id); - NVF_ERROR(from_it != from_ids2set.end()); - - auto from_set = from_it->second; - auto to_entry_it = set2to_ids.find(from_set); - if (to_entry_it == set2to_ids.end()) { + auto to_entry_it = set2to_vals.find(from_set); + if (to_entry_it == set2to_vals.end()) { continue; } - from_ids2to_ids[from_id] = to_entry_it->second; + + from_vals2to_vals[from_val] = to_entry_it->second; } - return from_ids2to_ids; + return from_vals2to_vals; } std::unordered_map> ValGraph::buildMapBetween( diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 38d587abc46..0f08574f2b1 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -141,8 +141,8 @@ class ValGraph { ExprGroups getExprsBetween(const ValGroups& from, const ValGroups& to) const; // Supports one to many mappings, uses the disjoint sets of the provided mode - // to produce mappings between from and to. If multiple IterDomains in to map - // to a single iter domain in from, the order of the IterDomains in value of + // to produce mappings between from and to. If multiple Vals in to map + // to a single Val in from, the order of the Vals in value of // the map is preserved to be the order provided in to. std::unordered_map> buildMapBetween( const std::vector& from, From b364f289f4bbf10bda908b4fadffd459a336cef1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sun, 31 Dec 2023 12:56:53 -0800 Subject: [PATCH 107/178] cleanup --- csrc/id_model/id_model.cpp | 5 +-- csrc/id_model/id_model.h | 73 +++++++++++++++++++------------------- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 46599390557..3a1cf4771f5 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -1119,8 +1119,9 @@ ValGraph IdModel::buildIntersection( } void IdModel::initializeLoopMap(const StatefulInliningInfo& info) { - // See Indexing20 example for why we shouldn't propagate when generating loop - // groups + // In the case of the Loop graph, we do not propagate mappings but + // explicitly set which domains to map based on the permissive graph + // and the CA positions. idGraph(IdMappingMode::LOOP) = initializeIdGraph(false); // Make sure this is called in a deterministic order. Build all inlined diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 2c66d0b3284..ab1d452c040 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -127,44 +127,8 @@ class IdModel : public PolymorphicBase { return self_mapping_info_.has_value(); } - // Update the LOOP ID disjoint sets with resolved computeWith - void updateComputeWith(TensorView* compute_with_tv); - std::string toString() const; - // Replay Expr but with the inputs provided. IterDomainGraphss will be updated - // for all maps that have entries, adding the output iter domains of the - // replayed expression and adding potential mappings through the expression. - Expr* addReplayAs(std::vector new_inputs, Expr* expr); - - // Similar to addReplayAs, but clones the expr exactly instead of replaying it - // forward. It's up to the calling code to make sure the replacements are - // valid for the provided expr. It's generally recommended that the - // IterDomains exactly match those in the expr. - // - // "forward" dictates the same argument for mapThroughExpr. If forward the - // function will apply mapThroughExpr forward if inputs map in each - // initialized map. Else does the same but backwards through the expression - // from outputs. - Expr* addExprWithReplacement( - const std::unordered_map& old_2_new_ids, - Expr* old_expr); - - // Make a new expr matching that provided but using the outputs provided. - // IterDomainGraphss will be updated for all maps that have entries. Adding - // the input iter domains of the replayed expression and adding potential - // mappings through the expressions. Input domains will match exactly in all - // properties as those in expr. This is unlike addReplayAs which will produce - // new outputs using transformations directly. - Expr* addBackwardsReplayAs( - const std::vector& new_outputs, - Expr* expr); - - // Make an exact copy of provided IterDomain (without rfactor set), and map - // the copy to the original in all registered IdModel. IterDomain copy will - // not have any registered uses or definitions. - IterDomain* cloneIterDomain(IterDomain* id); - const std::unordered_map loopPromotionMap() const { return loop_promotion_map_; } @@ -275,6 +239,43 @@ class IdModel : public PolymorphicBase { // Errors if self mapping occurs void assertNoSelfMapping(); + // TODO: + // Update the LOOP ID disjoint sets with resolved computeWith + void updateComputeWith(TensorView* compute_with_tv); + + // Replay Expr but with the inputs provided. IterDomainGraphss will be updated + // for all maps that have entries, adding the output iter domains of the + // replayed expression and adding potential mappings through the expression. + Expr* addReplayAs(std::vector new_inputs, Expr* expr); + + // Similar to addReplayAs, but clones the expr exactly instead of replaying it + // forward. It's up to the calling code to make sure the replacements are + // valid for the provided expr. It's generally recommended that the + // IterDomains exactly match those in the expr. + // + // "forward" dictates the same argument for mapThroughExpr. If forward the + // function will apply mapThroughExpr forward if inputs map in each + // initialized map. Else does the same but backwards through the expression + // from outputs. + Expr* addExprWithReplacement( + const std::unordered_map& old_2_new_ids, + Expr* old_expr); + + // Make a new expr matching that provided but using the outputs provided. + // IterDomainGraphss will be updated for all maps that have entries. Adding + // the input iter domains of the replayed expression and adding potential + // mappings through the expressions. Input domains will match exactly in all + // properties as those in expr. This is unlike addReplayAs which will produce + // new outputs using transformations directly. + Expr* addBackwardsReplayAs( + const std::vector& new_outputs, + Expr* expr); + + // Make an exact copy of provided IterDomain (without rfactor set), and map + // the copy to the original in all registered IdModel. IterDomain copy will + // not have any registered uses or definitions. + IterDomain* cloneIterDomain(IterDomain* id); + private: // Keeps ValGraphs containing all IterDomains for all mapping mode types. // From 3b9457478b163612e147b763fc8e5ce3563c86e2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 6 Jan 2024 10:36:02 -0800 Subject: [PATCH 108/178] IdModel: refactoring loop promotion (#1589) --- csrc/id_model/id_model.cpp | 914 +++++++++++++---------------- csrc/id_model/id_model.h | 49 +- csrc/id_model/transform_replay.cpp | 10 + csrc/id_model/transform_replay.h | 2 + test/test_gpu_indexing.cpp | 3 +- 5 files changed, 468 insertions(+), 510 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 3a1cf4771f5..ff5abf05422 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -898,11 +898,25 @@ StatefulInliningInfo buildStatefulInliningInfo( } // namespace void IdModel::buildLoopMap(const std::vector& exprs) { + if (exprs.empty()) { + return; + } + const StatefulInliningInfo info = buildStatefulInliningInfo( exprs, idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::PERMISSIVE)); + std::stringstream ss; + exprs.at(0)->fusion()->print(ss); + VERBOSE() << ss.str(); + initializeLoopMap(info); + VERBOSE() << "Initial loop graph:\n"; + for (const auto& group : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { + VERBOSE() << nvfuser::toString(group) << std::endl; + } + // Initial propagation of parallel types for inlined iter domains. Each time // new expressions are replayed this needs to be run. The disjoint sets in // the loop graph can only be joined after this point. @@ -1055,7 +1069,7 @@ void IdModel::build( } VectorOfUniqueEntries IdModel::computeTerminalLoopIds( - const StatefulInliningInfo info) { + const StatefulInliningInfo& info) { VectorOfUniqueEntries terminal_loop_ids; for (const ValGroup& group : idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { @@ -1286,85 +1300,136 @@ std::unordered_map IdModel::buildInlineRootPromotions( iel_promotion_map[iel_group] = promoted_iel_groups.front()->front()->as(); + + VERBOSE() << "Root promotion: " << nvfuser::toString(iel_group) << " -> " + << promoted_iel_groups.front()->front()->as()->name() + << std::endl; } return iel_promotion_map; } -std::unordered_map IdModel::buildInlinePromotions( - const StatefulInliningInfo& info) { - // Make an intersection of the exact and loop map. This will group together - // entries in each loop group that are exact with each other. This provides a - // better graph to do promotion and replays. +namespace { - // It's tempting to use the intersection of the almost exact and loop, but we - // need to model broadcast promotion, and if we have two tensors like: - // - // T1[i0, b1] = T0[i0] - // T2[i0, b2] = T0[i0] - // Then resolution of: - // T4 = T1[i0, b1] + T3[i0, i1] - // T6 = T2[i0, b2] + T5[i0, i2] - // - // Then merge(0, 1) with all tensors except for T0 - // - // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 - // are being resolved to i1 and i2 respectively. So we want to have separate - // entries so we can have an easy to process promotion map. - // - // Loop is a permissive like map, it could have many entries, use the exact - // map as the one we iterate on to reduce complexity as it hopefully has - // smaller groups and this algorithm scales with the number of groups * - // (number of entries in groups ^ 2) +// When replaying the transformations we can't blindly apply loop promotion +// to all iter domains within a loop group as it would replay the +// transformations within that loop group on the promoted id of that loop +// group. +// +// i.e. if we have the inlined domains from: +// T2[i0*i1] pa(1) = T0[i0*b1]ca(1) + T1[i0*i1]ca(1) +// The inlined loop group would be: +// +// i0, i1, b1, i0*i1, b0*i1 +// Then if we replayed the iel transformations they would be: +// merge(i0, i1) +// merge(i0, b1) +// +// So if we replayed them with loop promotion, then i0, i1, b1 would be +// promoted to i0*i1, and the merges would be replayed. +// +// Therefore only promote i0*b1 to i0*i1, or i0*i1 to i0*i1 (i.e. don't +// promote an input to any transformation within the loop group). +// +// So if we have an iel_expr make sure it's inputs and outputs are not in +// the same loop group. +bool hasUniqueOutputLoopGroups( + const ExprGroup& iel_expr, + const ValGraph& iel_graph, + const ValGraph& loop_graph) { + const std::vector iel_inp_groups = iel_graph.inputGroups(iel_expr); - ValGraph iel_graph = buildIntersection( - idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + const std::vector iel_out_groups = iel_graph.outputGroups(iel_expr); - // Promotion logic is going to be on the intersection of the exact and loop - // graph. We will generate a map on the entries of this graph so it's - // important to not modify this graph moving forward, as that would invalidate - // the map. - // - // iel stands for Intersection of the Exact and Loop graphs. - std::unordered_map iel_promotion_map = - buildInlineRootPromotions(iel_graph, info); + ValGroups inp_loop_groups; + for (const ValGroup& iel_inp_group : iel_inp_groups) { + inp_loop_groups.pushBack(loop_graph.toGroup(iel_inp_group->front())); + } + ValGroups out_loop_groups; + for (const ValGroup& iel_out_group : iel_out_groups) { + out_loop_groups.pushBack(loop_graph.toGroup(iel_out_group->front())); + } - // This should probably work just on terminating inputs, as we shouldn't be - // able to modify a broadcast domain between root and rfactor which would be - // required to resolve a non input broadcast domain. But for now leaving it as - // traversal on all broadcast groups. - // - // TODO-NM: The ordering appears to be non-deterministic + // Check if output groups that are not included in the input group set + return !inp_loop_groups.computeSubtract(out_loop_groups).empty(); +} - // Propagate promotion mappings from root domains to derived domains - // by traversing IEL exprs. For each expr, if an input is promoted, - // the output needs to be promoted too. If there's already a domain - // that the output domain should be promoted to, create a mapping to it from - // the promoted output domain. If not, a new domain is created by - // replaying the expr with the promoted inputs. +} // namespace - // In order to make - // this traversal work, the traversal order must be toplogically - // sorted. +// Propagate promotion mappings from root domains to derived domains +// by traversing IEL exprs. For each expr, if an input is promoted, +// the output needs to be promoted too. If there's already a domain +// that the output domain should be promoted to, create a mapping to it from +// the promoted output domain. If not, a new domain is created by +// replaying the expr with the promoted inputs. +// +// This is used twice when building the promotion map. The first time +// it is used there's no loop graph promotion yet, so only the IEL +// promotions are propagated. In that case, loop_graph_promotion_map +// should be just empty. +void IdModel::propagatePromotions( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const std::unordered_map& loop_graph_promotion_map, + bool require_loop_mapped_promotion) { + // In order to make this traversal work, the traversal order must be + // topologically sorted. IdGraphStmtSort iel_stmt_sort(iel_graph); + // TODO-NM: The ordering might be non-deterministic + for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) { NVF_ERROR(!iel_expr->empty()); - std::vector input_groups = iel_graph.inputGroups(iel_expr); + const std::vector iel_inp_groups = + iel_graph.inputGroups(iel_expr); + + // Propagate loop graph promotion only when the inputs and outputs are + // not in the same loop group. + const bool loop_promote_inputs = !loop_graph_promotion_map.empty() && + hasUniqueOutputLoopGroups(iel_expr, iel_graph, loop_graph); // Check if any inputs need promotion indicating this expr group needs to // be replayed with promoted inputs - std::vector promoted_inputs; bool an_input_was_promoted = false; + std::vector maybe_promoted_inputs; + maybe_promoted_inputs.reserve(iel_inp_groups.size()); - for (const ValGroup& inp : input_groups) { - auto inp_promo_it = iel_promotion_map.find(inp); - if (inp_promo_it == iel_promotion_map.end()) { - promoted_inputs.push_back(inp->front()->as()); - } else { - promoted_inputs.push_back(inp_promo_it->second); + for (const ValGroup& iel_inp_group : iel_inp_groups) { + // Assumed all inputs are IterDomains + NVF_ERROR(iel_inp_group->front()->isA()); + + // Promote loops based on the loop promotion map. If the loop promotion + // map should be used and has an entry we should use that promotion. This + // happen when an iel expression is across a loop group boundary. + // Signifying and capturing instances when we traverse across an inlined + // loop group to a non-inlined loop group boundary (think of the iel graph + // projected onto the loop graph). + if (loop_promote_inputs) { + const ValGroup& loop_copy_group = + loop_graph.toGroup(iel_inp_group->front()); + auto inp_loop_promo_it = loop_graph_promotion_map.find(loop_copy_group); + if (inp_loop_promo_it != loop_graph_promotion_map.end()) { + maybe_promoted_inputs.push_back(inp_loop_promo_it->second); + an_input_was_promoted = true; + continue; + } + } + + // Even when loop promotions are given, We still could require + // an input promotion. We could be traversing across non-inlined + // groups. Meaning we have inputs that were promoted in an + // inlined loop group traversing through the non-inlined + // portions of the iel graph. + if (auto inp_promo_it = iel_promotion_map.find(iel_inp_group); + inp_promo_it != iel_promotion_map.end()) { + maybe_promoted_inputs.push_back(inp_promo_it->second); an_input_was_promoted = true; + continue; } + + // No promotion found. Just use the non-promoted domain + maybe_promoted_inputs.push_back(iel_inp_group->front()->as()); } if (!an_input_was_promoted) { @@ -1372,94 +1437,191 @@ std::unordered_map IdModel::buildInlinePromotions( continue; } - ValGroups promoted_input_groups; - for (auto inp_id : promoted_inputs) { - if (iel_graph.hasGroup(inp_id)) { - promoted_input_groups.pushBack(iel_graph.toGroup(inp_id)); - } - } + VERBOSE() << "IEL expr: " << iel_expr->front()->toString(); // Before replaying, check if there's already an expression like this, if so // use that for promotion. We would need the iel entries for non-promoted // inputs to match exactly to reuse the expression. - // - // Unfortunately this doesn't actually seem to save any replays because - // we're not adding the replayed expression to the iel graph since we're - // traversing the iel graph. - // - // TODO: Can we reduce the number of new expressions generated - // here? - // - // TODO-NM: This won't work for any single-input expr, e.g., - // split, as there's no other non-promoted input. Can't we just - // look at the use expr of the promoted IDGroup? - // - // TODO-NM: Why can't we just also use the promoted IDs and their - // uses? E.g., test Indexing5, t3 has a merge of iS11 and bS7, - // both of them are promoted to iS17 and iS45, respectively. Since - // there's no promoted input, there would be no reuse, but it - // seems perfectly fine to reuse the merge of iS17 and iS45. - - ExprGroups non_promoted_input_uses; - for (const ValGroup& iel_group : - promoted_input_groups.computeIntersect(input_groups)) { - const ExprGroups* uses = iel_graph.getUses(iel_group); - NVF_ERROR(uses); - non_promoted_input_uses.pushBack(*uses); - } - - Expr* replay = nullptr; - - // Look for exprs that have inputs that are mapped in the IEL - // graph with the (promoted) inputs of iel_expr. If found, no need - // to create a new expr to produce promoted outputs - for (const ExprGroup& iel_use_group : non_promoted_input_uses) { - // No need to check itself - if (iel_expr == iel_use_group) { - continue; + auto findMatchingExpr = + [this, &require_loop_mapped_promotion]( + const ExprGroup& iel_expr, + const ValGraph& iel_graph, + const std::vector& maybe_promoted_inputs) -> Expr* { + ExprGroups maybe_promoted_input_uses; + + for (auto inp_id : maybe_promoted_inputs) { + // inp_id may have been just replayed, in which case it should + // not exist in the IEL graph. It should be just ignored as it + // should not have any use yet. + if (!iel_graph.hasGroup(inp_id)) { + continue; + } + const auto& inp_exact_group = iel_graph.toGroup(inp_id); + const ExprGroups* uses = iel_graph.getUses(inp_exact_group); + NVF_ERROR(uses); + maybe_promoted_input_uses.pushBack(*uses); } - if (ValGraph::transformAtributesMatch( - iel_expr->front(), iel_use_group->front())) { - auto use_inps = - ir_utils::filterByType(iel_use_group->front()->inputs()) - .vector(); + + // Look for exprs that have inputs that are mapped in the IEL + // graph with the (promoted) inputs of iel_expr. If found, no need + // to create a new expr to produce promoted outputs + for (const ExprGroup& maybe_promoted_input_use_group : + maybe_promoted_input_uses) { + VERBOSE() << "Checking other use: " + << nvfuser::toString(maybe_promoted_input_use_group) + << std::endl; + NVF_ERROR(!maybe_promoted_input_use_group->empty()); + // No need to check itself + if (iel_expr == maybe_promoted_input_use_group) { + continue; + } + Expr* maybe_promoted_input_use = + maybe_promoted_input_use_group->front(); + // TODO-NM: Use isSameOp instead + if (!ValGraph::transformAtributesMatch( + iel_expr->front(), maybe_promoted_input_use)) { + continue; + } + // Check if all inputs are mapped + NVF_ERROR( + maybe_promoted_inputs.size() == + maybe_promoted_input_use->inputs().size()); bool inps_match = true; - for (auto inp_i : c10::irange(use_inps.size())) { + for (const auto inp_i : c10::irange(maybe_promoted_inputs.size())) { + // Here, new promoted ids are not added to iel_graph, so + // once promoted, this should not return true anymore. Also, + // strictAreMapped doesn't work as promoted domains are not + // in the graph inps_match = inps_match && - iel_graph.disjointValSets().strictAreMapped( - use_inps[inp_i], promoted_inputs[inp_i]); + iel_graph.disjointValSets().permissiveAreMapped( + maybe_promoted_inputs[inp_i], + maybe_promoted_input_use->inputs().at(inp_i)); } - if (inps_match) { - replay = iel_use_group->front(); - break; + if (!inps_match) { + continue; } - } - } - bool replayed = replay == nullptr; - if (replay == nullptr) { - replay = addReplayAs(promoted_inputs, iel_expr->front()); - } + // For the final loop promotion map, we want to find + // promotions within the same loop groups. Note that that's + // guaranteed when replayed. + if (require_loop_mapped_promotion) { + if (!idGraph(IdMappingMode::LOOP) + .disjointExprSets() + .permissiveAreMapped( + iel_expr->front(), + maybe_promoted_input_use_group->front())) { + continue; + } + // This is just an extra sanity check. Make sure all exprs in + // the use group are mapped + NVF_ERROR( + std::all_of( + maybe_promoted_input_use_group->vector().begin(), + maybe_promoted_input_use_group->vector().end(), + [&](Expr* iel_use) { + return idGraph(IdMappingMode::LOOP) + .disjointExprSets() + .permissiveAreMapped(iel_expr->front(), iel_use); + }), + "Not all mapped: ", + nvfuser::toString(iel_expr), + "\n", + nvfuser::toString(maybe_promoted_input_use_group)); + } + return maybe_promoted_input_use; + } - std::vector out_groups = iel_graph.outputGroups(iel_expr); + return nullptr; + }; - // Mark outputs as having a promoted iter domain - auto replay_out_ids = - ir_utils::filterByType(replay->outputs()).vector(); - auto ref_out_ids = - ir_utils::filterByType(iel_expr->front()->outputs()) - .vector(); + bool replayed = false; + Expr* promoted_expr = + findMatchingExpr(iel_expr, iel_graph, maybe_promoted_inputs); - NVF_ERROR(replay_out_ids.size() == out_groups.size()); + if (!promoted_expr) { + promoted_expr = addReplayAs(maybe_promoted_inputs, iel_expr->front()); + replayed = true; + VERBOSE() << "Replayed: " << promoted_expr->toString(); + } else { + VERBOSE() << "Reusing: " << promoted_expr->toString(); + } - for (auto i : c10::irange(replay_out_ids.size())) { - iel_promotion_map[out_groups[i]] = replay_out_ids[i]; + // Mark outputs as having a promoted iter domain + std::vector out_groups = iel_graph.outputGroups(iel_expr); + NVF_ERROR(promoted_expr->outputs().size() == out_groups.size()); + NVF_ERROR( + ir_utils::filterByType(promoted_expr->outputs()).size() == + out_groups.size(), + "Unexpected non IterDomain outputs found: ", + promoted_expr->toString()); + + for (const auto i : c10::irange(out_groups.size())) { + // Promote if necessary, if the output is already in the same exact map + // it doesn't need a promotion. + if (idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped( + promoted_expr->output(i), out_groups[i]->front())) { + continue; + } + iel_promotion_map[out_groups[i]] = + promoted_expr->output(i)->as(); // Explicitly map loop map since expr propagation doesn't happen if (replayed) { - idGraph(IdMappingMode::LOOP).mapVals(replay_out_ids[i], ref_out_ids[i]); + idGraph(IdMappingMode::LOOP) + .mapVals(iel_expr->front()->output(i), promoted_expr->output(i)); } } } +} + +void IdModel::propagatePromotions( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map) { + propagatePromotions( + iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), {}, false); +} + +std::unordered_map IdModel::buildInlinePromotions( + const StatefulInliningInfo& info) { + // Make an intersection of the exact and loop map. This will group together + // entries in each loop group that are exact with each other. This provides a + // better graph to do promotion and replays. + + // It's tempting to use the intersection of the almost exact and loop, but we + // need to model broadcast promotion, and if we have two tensors like: + // + // T1[i0, b1] = T0[i0] + // T2[i0, b2] = T0[i0] + // Then resolution of: + // T4 = T1[i0, b1] + T3[i0, i1] + // T6 = T2[i0, b2] + T5[i0, i2] + // + // Then merge(0, 1) with all tensors except for T0 + // + // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 + // are being resolved to i1 and i2 respectively. So we want to have separate + // entries so we can have an easy to process promotion map. + // + // Loop is a permissive like map, it could have many entries, use the exact + // map as the one we iterate on to reduce complexity as it hopefully has + // smaller groups and this algorithm scales with the number of groups * + // (number of entries in groups ^ 2) + + // Promotion logic is going to be on the intersection of the exact and loop + // graph. We will generate a map on the entries of this graph so it's + // important to not modify this graph moving forward, as that would invalidate + // the map. + // + // iel stands for Intersection of the Exact and Loop graphs. + ValGraph iel_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + + // First, identify promotions of root broadcast domains only + std::unordered_map iel_promotion_map = + buildInlineRootPromotions(iel_graph, info); + + propagatePromotions(iel_graph, iel_promotion_map); std::stringstream ss; ss << "Inline promotion map\n"; @@ -1498,16 +1660,15 @@ std::unordered_map updateMap( // traversing on definitions. Ignoring broadcast ValGroups and resetting inputs // at RFactor ValGroups. std::unordered_map computeCoveredGroups( - const ValGraph& exact_graph, + const ValGraph& graph, const std::unordered_set& view_rfactor_ids) { // Map from an exact iter domain group, to all the exact iter domain groups it // covers std::unordered_map covered_ids; - for (const ValGroup& id_group : - exact_graph.disjointValSets().disjointSets()) { + for (const ValGroup& id_group : graph.disjointValSets().disjointSets()) { // Initialize inputs - const ExprGroups* id_group_defs = exact_graph.getDefinitions(id_group); + const ExprGroups* id_group_defs = graph.getDefinitions(id_group); NVF_ERROR(id_group_defs); if (id_group_defs->empty()) { covered_ids[id_group] = {id_group}; @@ -1530,17 +1691,17 @@ std::unordered_map computeCoveredGroups( } } - IdGraphStmtSort exact_stmt_sort(exact_graph); + IdGraphStmtSort exact_stmt_sort(graph); for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) { - std::vector input_groups = exact_graph.inputGroups(exact_expr); + std::vector input_groups = graph.inputGroups(exact_expr); ValGroups covered; for (const ValGroup& inp_group : input_groups) { covered.pushBack(covered_ids.at(inp_group)); } - for (const ValGroup& output_group : exact_graph.outputGroups(exact_expr)) { + for (const ValGroup& output_group : graph.outputGroups(exact_expr)) { // Don't overwrite initialized cases due to rfactor markings. if (covered_ids.find(output_group) == covered_ids.end()) { covered_ids[output_group] = covered; @@ -1552,30 +1713,93 @@ std::unordered_map computeCoveredGroups( } }; // namespace +IterDomain* IdModel::findPromotionOfLoopGroup( + const ValGroup& loop_group, + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const std::unordered_map& loop_graph_promotion_map, + const std::unordered_map& exact_covered_ids, + const VectorOfUniqueEntries& terminal_loop_ids) { + const ValGraph& exact_graph = idGraph(IdMappingMode::EXACT); + + std::unordered_map promotion_map; + + // Grab all the (potentially promoted) terminal iter domains in this group. + // Save the exact group and the iter domain in this vector. + std::vector> exact_promoted_terminal_ids; + for (auto loop_id : *loop_group) { + // If not a terminal id in the group skip + if (!terminal_loop_ids.has(loop_id->as())) { + continue; + } + + // Grab the iel entry + const ValGroup& iel_group = iel_graph.toGroup(loop_id); + + auto iel_promo_it = iel_promotion_map.find(iel_group); + if (iel_promo_it == iel_promotion_map.end()) { + // If this terminal ID doesn't have a promotion associated with it, save + // the terminal ID. + exact_promoted_terminal_ids.emplace_back( + exact_graph.toGroup(loop_id), loop_id->as()); + } else { + // If this terminal ID has a promotion, grab the promoted ID. + exact_promoted_terminal_ids.emplace_back( + exact_graph.toGroup(iel_promo_it->second), iel_promo_it->second); + } + + if (auto loop_graph_promotion_map_it = + loop_graph_promotion_map.find(loop_group); + loop_graph_promotion_map_it != loop_graph_promotion_map.end()) { + VERBOSE() << "Found in loop promotion: " << nvfuser::toString(loop_group) + << std::endl; + exact_promoted_terminal_ids.emplace_back( + exact_graph.toGroup(loop_graph_promotion_map_it->second), + loop_graph_promotion_map_it->second); + } + } + + // All the exact groups of the iter domains in the loop group + ValGroups exact_groups = exact_graph.toGroups(*loop_group); + + // All exact groups covered by all iter domains in this loop group + ValGroups loop_group_covered_ids; + for (const ValGroup& exact_group : exact_groups) { + auto covered_it = exact_covered_ids.find(exact_group); + NVF_ERROR(covered_it != exact_covered_ids.end()); + loop_group_covered_ids.pushBack(covered_it->second); + } + + // Check if any of the candidate Iter Domains we collected cover all the + // exact groups of loop_group_covered_ids. If so, that's the correct + // promoted iter domain of this group. + for (const auto& entry : exact_promoted_terminal_ids) { + const ValGroup& terminal_id_group = entry.first; + IterDomain* terminal_id = entry.second; + auto covered_it = exact_covered_ids.find(terminal_id_group); + NVF_ERROR(covered_it != exact_covered_ids.end()); + if (loop_group_covered_ids.computeSubtract(covered_it->second).empty()) { + return terminal_id; + } + } + + return nullptr; +} + std::unordered_map IdModel::buildLoopPromotionMap( const std::vector& exprs, - const StatefulInliningInfo& info, + const StatefulInliningInfo& inlining_info, const std::unordered_map& stale_promotion_map) { // Non-ca domains may also need to be promoted if parent domains are // promoted. // Need to use the intersection of exact and loop map again, it needs to be // recomputed. - auto intersection_exact_loop_graph = buildIntersection( + auto iel_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); // Update the promotion map - auto iel_promotion_map = - updateMap(stale_promotion_map, intersection_exact_loop_graph); - - // Map from an exact iter domain group, to all the exact iter domain groups it - // covers; needs to be recomputed. - std::unordered_map exact_covered_ids = - computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); - - // Grab terminal iter domain in the loop groups. - VectorOfUniqueEntries terminal_loop_ids = - computeTerminalLoopIds(info); + auto iel_promotion_map = updateMap(stale_promotion_map, iel_graph); // Loop promotion map is to prepare for IterDomain replays to resolve // non-inlined loop groups. Since these replays will modify the loop map as @@ -1583,100 +1807,26 @@ std::unordered_map IdModel::buildLoopPromotionMap( // the original one. auto loop_graph_copy = idGraph(IdMappingMode::LOOP); - // Build a map from loop iter domain group to a promoted iter domain (doesn't - // have to be in the loop group) that covers all the exact groups - // representative of the resolved transformations within the loop group. Only - // the inlined loop groups will be covered here. std::unordered_map loop_graph_copy_promotion_map; - // TODO: I'm uncertain if we can simply use the iel_promotion_map. Once this - // system is in use we should test not recomputing the "concrete ids". + std::unordered_map exact_covered_ids = + computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); + + // Grab terminal iter domain in the loop groups. + const VectorOfUniqueEntries terminal_loop_ids = + computeTerminalLoopIds(inlining_info); for (const ValGroup& loop_group : loop_graph_copy.disjointValSets().disjointSets()) { - if (loop_group->size() == 1) { - loop_graph_copy_promotion_map[loop_group] = - loop_group->front()->as(); - continue; - } - - // Grab all the (potentially promoted) terminal iter domains in this group. - // Save the exact group and the iter domain in this vector. - std::vector> exact_promoted_terminal_ids; - for (auto loop_id : *loop_group) { - // If not a terminal id in the group skip - if (!terminal_loop_ids.has(loop_id->as())) { - continue; - } - - // Grab the iel entry - const ValGroup& iel_group = - intersection_exact_loop_graph.toGroup(loop_id); - - auto iel_promo_it = iel_promotion_map.find(iel_group); - if (iel_promo_it == iel_promotion_map.end()) { - // If this terminal ID doesn't have a promotion associated with it, save - // the terminal ID. - exact_promoted_terminal_ids.emplace_back( - idGraph(IdMappingMode::EXACT).toGroup(loop_id), - loop_id->as()); - } else { - // If this terminal ID has a promotion, grab the promoted ID. - exact_promoted_terminal_ids.emplace_back( - idGraph(IdMappingMode::EXACT).toGroup(iel_promo_it->second), - iel_promo_it->second); - } - } - - // All the exact groups of the iter domains in the loop group - ValGroups exact_groups = - idGraph(IdMappingMode::EXACT).toGroups(*loop_group); - - // All exact groups covered by all iter domains in this loop group - ValGroups loop_group_covered_ids; - for (const ValGroup& exact_group : exact_groups) { - auto covered_it = exact_covered_ids.find(exact_group); - NVF_ERROR(covered_it != exact_covered_ids.end()); - loop_group_covered_ids.pushBack(covered_it->second); - } - - IterDomain* loop_promotion_id = nullptr; - - // Check if any of the candidate Iter Domains we collected cover all the - // exact groups of loop_group_covered_ids. If so, that's the correct - // promoted iter domain of this group. - for (const auto& entry : exact_promoted_terminal_ids) { - const ValGroup& terminal_id_group = entry.first; - IterDomain* terminal_id = entry.second; - auto covered_it = exact_covered_ids.find(terminal_id_group); - NVF_ERROR(covered_it != exact_covered_ids.end()); - if (loop_group_covered_ids.computeSubtract(covered_it->second).empty()) { - loop_promotion_id = terminal_id; - break; - } - } - - if (loop_promotion_id == nullptr) { - std::stringstream err_msg; - err_msg - << "\n ERROR Loop promotion map build. Could not find promotion for loop group:\n "; - err_msg << nvfuser::toString(loop_group, 0, true); - err_msg << "\nnone of the terminal iter domains of this group:\n "; - for (const auto& entry : exact_promoted_terminal_ids) { - const ValGroup& terminal_id_group = entry.first; - const ValGroups& covered_id_groups = - exact_covered_ids.at(terminal_id_group); - err_msg << " " << nvfuser::toString(terminal_id_group, 0, true) - << " -(covers)-> " << nvfuser::toString(covered_id_groups) - << std::endl; - } - err_msg << "iter domains in this group cover all id groups:\n"; - for (const ValGroup& covered_group : loop_group_covered_ids) { - err_msg << " " << nvfuser::toString(covered_group, 0, true); - } - // NVF_ERROR(false, err_msg.str()); - } else { - loop_graph_copy_promotion_map[loop_group] = loop_promotion_id; + IterDomain* promotion_id = findPromotionOfLoopGroup( + loop_group, + iel_graph, + iel_promotion_map, + {}, + exact_covered_ids, + terminal_loop_ids); + if (promotion_id) { + loop_graph_copy_promotion_map[loop_group] = promotion_id; } } @@ -1686,6 +1836,12 @@ std::unordered_map IdModel::buildLoopPromotionMap( // Indexing19. Its parent ID loop group is promoted, but the loop // group of iS50 is not found yet. + VERBOSE() << "Promotion projected to loop groups:\n"; + for (const auto& [loop_group, id] : loop_graph_copy_promotion_map) { + VERBOSE() << nvfuser::toString(loop_group) << " -> " << id->name() + << std::endl; + } + // Reset the promotion map for the second pass. // TODO: Unclear if we could simply update the iel_promotion_map from // buildInlinePromotions, instead of manually building it. @@ -1693,278 +1849,32 @@ std::unordered_map IdModel::buildLoopPromotionMap( // Need to run a replay for the loop groups that are dependent on inlined loop // groups, but themselves are not inlined loop groups. - - for (const ExprGroup& iel_expr : - IdGraphStmtSort(intersection_exact_loop_graph).exprs()) { - NVF_ERROR(!iel_expr->empty()); - - std::vector iel_inp_groups = - intersection_exact_loop_graph.inputGroups(iel_expr); - - std::vector iel_out_groups = - intersection_exact_loop_graph.outputGroups(iel_expr); - - // When replaying the transformations we can't blindly apply loop promotion - // to all iter domains within a loop group as it would replay the - // transformations within that loop group on the promoted id of that loop - // group. - // - // i.e. if we have the inlined domains from: - // T2[i0*i1] pa(1) = T0[i0*b1]ca(1) + T1[i0*i1]ca(1) - // The inlined loop group would be: - // - // i0, i1, b1, i0*i1, b0*i1 - // Then if we replayed the iel transformations they would be: - // merge(i0, i1) - // merge(i0, b1) - // - // So if we replayed them with loop promotion, then i0, i1, b1 would be - // promoted to i0*i1, and the merges would be replayed. - // - // Therefore only promote i0*b1 to i0*i1, or i0*i1 to i0*i1 (i.e. don't - // promote an input to any transformation within the loop group). - // - // So if we have an iel_expr make sure it's inputs and outputs are not in - // the same loop group. - - ValGroups inp_loop_groups; - for (const ValGroup& iel_inp_group : iel_inp_groups) { - inp_loop_groups.pushBack(loop_graph_copy.toGroup(iel_inp_group->front())); - } - - ValGroups out_loop_groups; - for (const ValGroup& iel_out_group : iel_out_groups) { - out_loop_groups.pushBack(loop_graph_copy.toGroup(iel_out_group->front())); - } - - // The inputs should be promoted based on the loop promotion map. - bool loop_promote_inputs = - !inp_loop_groups.computeSubtract(out_loop_groups).empty(); - - std::vector promoted_inputs; - - bool an_input_was_promoted = false; - - // Promote inputs for replay - for (const ValGroup& iel_inp_group : iel_inp_groups) { - // Promote loops based on the loop promotion map. If the loop promotion - // map should be used and has an entry we should use that promotion. This - // happen when an iel expression is across a loop group boundary. - // Signifying and capturing instances when we traverse across an inlined - // loop group to a non-inlined loop group boundary (think of the iel graph - // projected onto the loop graph). - const ValGroup& loop_copy_group = - loop_graph_copy.toGroup(iel_inp_group->front()); - auto inp_loop_promo_it = - loop_graph_copy_promotion_map.find(loop_copy_group); - if (loop_promote_inputs && - inp_loop_promo_it != loop_graph_copy_promotion_map.end()) { - promoted_inputs.push_back(inp_loop_promo_it->second); - an_input_was_promoted = true; - } else { - // We still could require an input promotion. We could be traversing - // across non-inlined groups. Meaning we have inputs that were promoted - // in an inlined loop group traversing through the non-inlined portions - // of the iel graph. - auto inp_promo_it = iel_promotion_map.find(iel_inp_group); - if (inp_promo_it == iel_promotion_map.end()) { - promoted_inputs.push_back(iel_inp_group->front()->as()); - } else { - promoted_inputs.push_back(inp_promo_it->second); - an_input_was_promoted = true; - } - } - } - - if (!an_input_was_promoted) { - continue; - } - - Expr* replay = nullptr; - - // Before replaying, check if there's already an expression like this, if so - // use that for promotion. We're still only looking for representative iter - // domains, so if there's already an expression that would produce something - // representative (matching in the IEL graph) of what the new inputs would - // generate, just promote to that expressions outputs, don't bother - // generating a new one. - // - // Check all uses of the IEL map the inputs are in, and look for one that - // would match. Grab all uses of the promoted inputs' groups in the IEL - // map. Note that promotion should be to loop-mapped domains, so - // the IEL graph is used rather than the exact graph - std::vector promoted_input_groups; - - ExprGroups promoted_input_uses; - for (auto inp_id : promoted_inputs) { - // inp_id may have been just replayed, in which case it should - // not exist in the IEL graph. It should be just ignored as it - // should not have any use yet. - if (!intersection_exact_loop_graph.hasGroup(inp_id)) { - continue; - } - const auto& inp_exact_group = - intersection_exact_loop_graph.toGroup(inp_id); - promoted_input_groups.push_back(inp_exact_group); - promoted_input_uses.pushBack( - *intersection_exact_loop_graph.getUses(inp_exact_group)); - } - - // Check every use to see if it matches - for (const ExprGroup& iel_use_group : promoted_input_uses) { - NVF_ERROR(!iel_use_group->empty()); - // Check if all the attributes (including type) of the transform match - if (!ValGraph::transformAtributesMatch( - iel_expr->front(), iel_use_group->front())) { - continue; - } - // Check if inputs all match - if (promoted_input_groups != - intersection_exact_loop_graph.inputGroups(iel_use_group)) { - continue; - } - // Input mapping doesn't always mean expr and output - // mappings. Make sure the exprs are mapped, which automatically - // means the outputs are mapped in the case of the LOOP map - if (!idGraph(IdMappingMode::LOOP) - .disjointExprSets() - .permissiveAreMapped( - iel_expr->front(), iel_use_group->front())) { - continue; - } - // This is just an extra sanity check. Make sure all exprs in - // the use group are mapped - NVF_ERROR( - std::all_of( - iel_use_group->vector().begin(), - iel_use_group->vector().end(), - [&](Expr* iel_use) { - return idGraph(IdMappingMode::LOOP) - .disjointExprSets() - .permissiveAreMapped(iel_expr->front(), iel_use); - }), - "Not all mapped: ", - nvfuser::toString(iel_expr), - "\n", - nvfuser::toString(iel_use_group)); - - replay = iel_use_group->front(); - break; - } - - bool replayed = replay == nullptr; - if (replay == nullptr) { - replay = addReplayAs(promoted_inputs, iel_expr->front()); - } - - auto output_groups = intersection_exact_loop_graph.outputGroups(iel_expr); - - // Match or replay, mark promotion for output groups. - auto replay_out_ids = - ir_utils::filterByType(replay->outputs()).vector(); - auto ref_out_ids = - ir_utils::filterByType(iel_expr->front()->outputs()) - .vector(); - - NVF_ERROR(replay_out_ids.size() == output_groups.size()); - - for (auto i : c10::irange(replay_out_ids.size())) { - if (!idGraph(IdMappingMode::EXACT) - .disjointValSets() - .strictAreMapped(replay_out_ids[i], output_groups[i]->front())) { - // Promote if necessary, if the output is already in the same exact map - // it doesn't need a promotion. - iel_promotion_map[output_groups[i]] = replay_out_ids[i]; - // Explicitly map loop map since expr propagation doesn't happen on the - // loop map and the replayed outputs are brand new so we can map them - // without joining disjoint loop groups (other than the new loop groups - // the outputs of the replay are in) - if (replayed) { - // If we built new iter domains because we generated a new expression, - // link the outputs in the loop graph. - idGraph(IdMappingMode::LOOP) - .mapVals(replay_out_ids[i], ref_out_ids[i]); - } - } - } - } - - // Update the coverage map + propagatePromotions( + iel_graph, + iel_promotion_map, + loop_graph_copy, + loop_graph_copy_promotion_map, + true); + + // Update the coverage map as new IDs were added to the exact graph exact_covered_ids = computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); // Set up the loop promotion map of loop groups to promotion - // IDs. Note that the IEL promotion map is still incomplete in the - // sense that: - // - // - Not all loop graphs have promotions set at this point. - // - Multiple domains that are loop-mapped may have different - // promotions, one of which should cover the rest. - // - // Fill the gap, here we traverse the loop graph and for each loop - // group we examine each IEL group. If an IEL group has a promotion, - // we consider it as a candidate of the promotion of this loop - // group. If not, we include a domain of the IEL group as a - // candidate too. We also look at the inline promotion map since - // that may also contain the promotion the loop should be associated - // with. Once all candidates are obtained, we pick one that covers - // all the exact domains (cf. concrete domains in ComputeAtMap) + // IDs. for (const ValGroup& loop_group : loop_graph_copy.disjointValSets().disjointSets()) { - ValGroups iel_groups = intersection_exact_loop_graph.toGroups(*loop_group); - // All exact groups covered by all iter domains in this loop group - ValGroups loop_group_covered_ids; - for (const ValGroup& iel_group : iel_groups) { - auto exact_group = - idGraph(IdMappingMode::EXACT).toGroup(iel_group->front()); - auto covered_it = exact_covered_ids.find(exact_group); - NVF_ERROR( - covered_it != exact_covered_ids.end(), - "Exact covered id not found for ", - nvfuser::toString(exact_group)); - loop_group_covered_ids.pushBack(covered_it->second); - } - - VectorOfUniqueEntries representative_id_candidates; - - for (const ValGroup& iel_group : iel_groups) { - if (auto iel_promotion_map_it = iel_promotion_map.find(iel_group); - iel_promotion_map_it != iel_promotion_map.end()) { - IterDomain* iel_promotion_id = iel_promotion_map_it->second; - representative_id_candidates.pushBack(iel_promotion_id); - } else { - representative_id_candidates.pushBack( - iel_group->front()->as()); - } - } - - if (auto loop_graph_copy_promotion_map_it = - loop_graph_copy_promotion_map.find( - loop_graph_copy.toGroup(loop_group->vector().at(0))); - loop_graph_copy_promotion_map_it != - loop_graph_copy_promotion_map.end()) { - VERBOSE() << "Found in loop promotion: " << nvfuser::toString(loop_group) - << std::endl; - representative_id_candidates.pushBack( - loop_graph_copy_promotion_map_it->second); - } - - VERBOSE() << "Loop promotion candidates: " << std::endl; - - // All candidates gathered - for (IterDomain* candidate_id : representative_id_candidates) { - auto covered_it = exact_covered_ids.find( - idGraph(IdMappingMode::EXACT).toGroup(candidate_id)); - NVF_ERROR(covered_it != exact_covered_ids.end()); - if (loop_group_covered_ids.computeSubtract(covered_it->second).empty()) { - // Found - VERBOSE() << "Representative found: " << candidate_id->toString() - << std::endl; - const ValGroup& current_loop_group = - idGraph(IdMappingMode::LOOP).toGroup(loop_group->front()); - loop_promotion_map_.emplace(current_loop_group, candidate_id); - break; - } + IterDomain* promotion_id = findPromotionOfLoopGroup( + loop_group, + iel_graph, + iel_promotion_map, + loop_graph_copy_promotion_map, + exact_covered_ids, + terminal_loop_ids); + if (promotion_id) { + const ValGroup& current_loop_group = + idGraph(IdMappingMode::LOOP).toGroup(loop_group->front()); + loop_promotion_map_.emplace(current_loop_group, promotion_id); } } diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index ab1d452c040..ed839a7667f 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -26,7 +26,8 @@ struct StatefulInliningInfo { // used for deterministic order VectorOfUniqueEntries ordered_p_ca_ids; - // Broadcast resolution map for root domains + // Broadcast resolution map for root domains, including non-inlined + // root domains std::unordered_map> p2c_root_broadcast_resolution_map; @@ -173,17 +174,32 @@ class IdModel : public PolymorphicBase { // Start loop map by grouping inlined iter domains void initializeLoopMap(const StatefulInliningInfo& info); - std::unordered_map buildInlineRootPromotions( - const ValGraph& iel_graph, - const StatefulInliningInfo& info); - - // Returns map of ValGroups in the loop map to a representative IterDomain + // Returns map of ValGroups in the IEL graph to a representative IterDomain // that contains all resolved transformations that the terminal IterDomains // should be promoted to. The returned promotions are valid only for inlined // iter domains. std::unordered_map buildInlinePromotions( const StatefulInliningInfo& info); + // Helper function for buildInlinePromotions. Only build mappings of + // root broadcast domains + std::unordered_map buildInlineRootPromotions( + const ValGraph& iel_graph, + const StatefulInliningInfo& info); + + // Helper function for buildInlinePromotions. Propagate root + // promotions to intermediate and leaf domains + void propagatePromotions( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map); + + void propagatePromotions( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const std::unordered_map& loop_graph_promotion_map, + bool require_loop_mapped_promotion); + // Returns a similar thing to buildInlinePromotions but also includes iter // domains that are not inlined. std::unordered_map buildLoopPromotionMap( @@ -191,6 +207,25 @@ class IdModel : public PolymorphicBase { const StatefulInliningInfo& info, const std::unordered_map& stale_promotion_map); + // Find a promoted iter domain of a given loop group that covers all + // the exact groups representative of the resolved transformations + // within the loop group. It doesn't have to be in the loop + // group. Specifically, we examine each IEL group of the loop graph, + // and if an IEL group has a promotion, we consider it as a + // candidate of the promotion of this loop group. If not, we include a + // domain of the IEL group as a candidate too. We also look at the + // inline promotion map since that may also contain the promotion the + // loop should be associated with. Once all candidates are obtained, + // we pick one that covers all the exact domains (cf. concrete domains + // in ComputeAtMap) + IterDomain* findPromotionOfLoopGroup( + const ValGroup& loop_group, + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const std::unordered_map& loop_graph_promotion_map, + const std::unordered_map& exact_covered_ids, + const VectorOfUniqueEntries& terminal_loop_ids); + // Make sure only leaf nodes of tensor views are parallelized void validatePTypes(const std::vector& all_tvs) const; @@ -207,7 +242,7 @@ class IdModel : public PolymorphicBase { // is also in the same loop group // 2) Don't have a direct IterDomain consumer within the group VectorOfUniqueEntries computeTerminalLoopIds( - const StatefulInliningInfo info); + const StatefulInliningInfo& info); // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and // graph1. diff --git a/csrc/id_model/transform_replay.cpp b/csrc/id_model/transform_replay.cpp index 2561bb0a7e7..3dbea53c588 100644 --- a/csrc/id_model/transform_replay.cpp +++ b/csrc/id_model/transform_replay.cpp @@ -65,6 +65,16 @@ void ReplayTransform::handle(const Swizzle2D* swizzle_2d) { .first->definition(); } +void ReplayTransform::handle(const Swizzle* swizzle) { + NVF_ERROR( + input_ids_.size() == 2, + "Expected two inputs to match swizzle: ", + swizzle->toString()); + replayed_expr_ = + IterDomain::swizzle(swizzle->swizzleType(), input_ids_[0], input_ids_[1]) + .first->definition(); +} + void ReplayTransform::handle(const Resize* resize) { NVF_ERROR( input_ids_.size() == 1, diff --git a/csrc/id_model/transform_replay.h b/csrc/id_model/transform_replay.h index ff549db65e5..276c79848a5 100644 --- a/csrc/id_model/transform_replay.h +++ b/csrc/id_model/transform_replay.h @@ -42,6 +42,8 @@ class ReplayTransform : OptInConstDispatch { // if replaying swizzle is enabled. void handle(const Swizzle2D* swizzle_2d) final; + void handle(const Swizzle* swizzle) final; + // We're going to replay this resize operation on the corresponding IDs // if replaying resize is enabled. void handle(const Resize* resize) final; diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index b9ea7213c19..25b03c52e88 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -848,7 +848,8 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { auto promotion_map_it = promotion_map.find(merge_loop_group); ASSERT_TRUE(promotion_map_it != promotion_map.end()) - << "Loop promotion not found for merge loop group"; + << "Loop promotion not found for merge loop group: " + << nvfuser::toString(merge_loop_group); auto merge_out_promotion_id = promotion_map_it->second; ASSERT_EQ( id_model.idGraph(IdMappingMode::EXACT).toGroup(merge_out_promotion_id), From 1904eff59e1562b07b69c3bc1a767f136e646d6b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 6 Jan 2024 23:28:10 -0800 Subject: [PATCH 109/178] further refactoring of loop promotion --- csrc/id_model/id_model.cpp | 389 ++++++++++++++++--------------------- csrc/id_model/id_model.h | 39 ++-- 2 files changed, 192 insertions(+), 236 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index ff5abf05422..0929875e2e6 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -895,21 +895,45 @@ StatefulInliningInfo buildStatefulInliningInfo( return info; } +// Update a map of ValGroups to ID from an old Valgraph to a new +// ValGraph. The new graph must be a superset of the old graph. +std::unordered_map updateMap( + const std::unordered_map& stale_map, + ValGraph& new_graph) { + std::unordered_map new_map; + + for (const auto& [stale_group, mapped_id] : stale_map) { + const ValGroups& new_groups = new_graph.toGroups(*stale_group); + NVF_ERROR( + new_groups.size() == 1, + "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", + "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", + "old:", + nvfuser::toString(stale_group), + "new: ", + nvfuser::toString(new_groups)); + NVF_ERROR( + new_map.emplace(new_groups.front(), mapped_id).second, + "Expected only a single mapping but multiple entries detected for ", + nvfuser::toString(new_groups.front())); + } + return new_map; +} + } // namespace void IdModel::buildLoopMap(const std::vector& exprs) { - if (exprs.empty()) { - return; + if (!exprs.empty()) { + std::stringstream ss; + exprs.at(0)->fusion()->print(ss); + VERBOSE() << ss.str(); } - const StatefulInliningInfo info = buildStatefulInliningInfo( + // Gather broadcast resolution and inlining information + const StatefulInliningInfo inlining_info = buildStatefulInliningInfo( exprs, idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::PERMISSIVE)); - std::stringstream ss; - exprs.at(0)->fusion()->print(ss); - VERBOSE() << ss.str(); - - initializeLoopMap(info); + initializeLoopMap(inlining_info); VERBOSE() << "Initial loop graph:\n"; for (const auto& group : @@ -917,33 +941,116 @@ void IdModel::buildLoopMap(const std::vector& exprs) { VERBOSE() << nvfuser::toString(group) << std::endl; } - // Initial propagation of parallel types for inlined iter domains. Each time - // new expressions are replayed this needs to be run. The disjoint sets in - // the loop graph can only be joined after this point. - // propagateLoopPTypes(); - - auto iel_promotion_map = buildInlinePromotions(info); - // propagateLoopPTypes(); - - // Find loops that need to be promoted because of broadcast resolution, - // figure out what that resolution should look like, compute IDs for it if - // necessary. - iel_promotion_map = buildLoopPromotionMap(exprs, info, iel_promotion_map); - // Loop map potentialy changed changed, as we could have replayed - // expressions. Re-propagate parallel types. - // propagateLoopPTypes(); - - // This pass still doesn't work, disable for now in case it's disruptive to - // tests. - /* - // Find loops that need to be promoted because of broadcast resolution, - // figure out what that resolution should look like, compute IDs for it if - // necessary. - auto leaf_id_promo_map = - buildIndexGraph(tv_exprs, all_tvs, info, iel_promotion_map); - // Make sure we update ptypes onto the index leaf iter domains - propagateLoopPTypes(); - */ + loop_promotion_map_ = buildLoopPromotionMap(inlining_info); +} + +std::unordered_map IdModel::buildLoopPromotionMap( + const StatefulInliningInfo& inlining_info) { + // Make an intersection of the exact and loop map. This will group together + // entries in each loop group that are exact with each other. This provides a + // better graph to do promotion and replays. + // + // It's tempting to use the intersection of the almost exact and loop, but we + // need to model broadcast promotion, and if we have two tensors like: + // + // T1[i0, b1] = T0[i0] + // T2[i0, b2] = T0[i0] + // Then resolution of: + // T4 = T1[i0, b1] + T3[i0, i1] + // T6 = T2[i0, b2] + T5[i0, i2] + // + // Then merge(0, 1) with all tensors except for T0 + // + // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 + // are being resolved to i1 and i2 respectively. So we want to have separate + // entries so we can have an easy to process promotion map. + // + // Loop is a permissive like map, it could have many entries, use the exact + // map as the one we iterate on to reduce complexity as it hopefully has + // smaller groups and this algorithm scales with the number of groups * + // (number of entries in groups ^ 2) + // + // iel stands for Intersection of the Exact and Loop graphs. + ValGraph iel_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + + // Loop promotion map is to prepare for IterDomain replays to resolve + // non-inlined loop groups. Since these replays will modify the loop map as + // we're iterating over the loop map, operate on a copy of the loop map, not + // the original one. + auto loop_graph_copy = idGraph(IdMappingMode::LOOP); + + // Step 1: Build a map of the IEL groups of root broadcast domains + // to resolving domains. + std::unordered_map iel_promotion_map = + buildInlineRootPromotionMap(iel_graph, inlining_info); + + // Step 2: Propagate the root promotions to intermediate and leaf groups. + // At this point, the promotion may not be final as the analysis is + // localized to IEL groups. The map is used in the next step to + // build mappings of the loop groups. + propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); + + std::stringstream ss; + ss << "Inline promotion map\n"; + for (const auto& [iel_group, promoted_id] : iel_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " << promoted_id->name() + << std::endl; + } + VERBOSE() << ss.str(); + + // Step 3: Determine the promotion of each loop graph based on the + // IEL promotion map. For each loop group, examine all the IEL + // promotions and find the most representative one that captures all + // the dependent input domains of the loop group + std::unordered_map loop_graph_copy_promotion_map = + projectIELPromotionToLoopGraph( + iel_graph, iel_promotion_map, loop_graph_copy, inlining_info); + + // At this point, most of loop groups should have correct promoted + // IDs. However, non-inlined loop groups may miss promotion that + // should be propagated from parent ID groups, e.g., iS50 of T2 in + // Indexing19. Its parent ID loop group is promoted, but the loop + // group of iS50 is not found yet. + + // Update the IEL graph as new domains may have been added + iel_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + + loop_graph_copy = idGraph(IdMappingMode::LOOP); + loop_graph_copy_promotion_map = + updateMap(loop_graph_copy_promotion_map, loop_graph_copy); + + // Step 4: In order to fully propagate the loop graph promotions, first + // propagate them to the IEL groups, which are then used to + // propagate back to the loop groups in Step 5 + std::unordered_map final_iel_promotion_map; + propagatePromotionsInIELGraph( + iel_graph, + final_iel_promotion_map, + loop_graph_copy, + loop_graph_copy_promotion_map, + true); + + // Step 5: Find the final promotion of each loop group based on the + // final IEL promotion map + auto final_loop_promotion_map = projectIELPromotionToLoopGraph( + iel_graph, final_iel_promotion_map, loop_graph_copy, inlining_info); + + // The loop map is built for loop_graph_copy. Update the map to the + // latest loop graph + final_loop_promotion_map = + updateMap(final_loop_promotion_map, idGraph(IdMappingMode::LOOP)); + + sanityCheckLoopPromotionMap(final_loop_promotion_map); + + VERBOSE() << "Final loop promotion map:" << std::endl; + for (const auto& [loop_group, id] : final_loop_promotion_map) { + VERBOSE() << nvfuser::toString(loop_group) << " -> " << id->name() + << std::endl; + } + + return final_loop_promotion_map; } // TODO: Reenable after reenabling parallel propagation. @@ -1151,39 +1258,9 @@ void IdModel::initializeLoopMap(const StatefulInliningInfo& info) { } } -std::unordered_map IdModel::buildInlineRootPromotions( +std::unordered_map IdModel::buildInlineRootPromotionMap( const ValGraph& iel_graph, const StatefulInliningInfo& info) { - // Make an intersection of the exact and loop map. This will group together - // entries in each loop group that are exact with each other. This provides a - // better graph to do promotion and replays. - - // It's tempting to use the intersection of the almost exact and loop, but we - // need to model broadcast promotion, and if we have two tensors like: - // - // T1[i0, b1] = T0[i0] - // T2[i0, b2] = T0[i0] - // Then resolution of: - // T4 = T1[i0, b1] + T3[i0, i1] - // T6 = T2[i0, b2] + T5[i0, i2] - // - // Then merge(0, 1) with all tensors except for T0 - // - // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 - // are being resolved to i1 and i2 respectively. So we want to have separate - // entries so we can have an easy to process promotion map. - // - // Loop is a permissive like map, it could have many entries, use the exact - // map as the one we iterate on to reduce complexity as it hopefully has - // smaller groups and this algorithm scales with the number of groups * - // (number of entries in groups ^ 2) - - // Promotion logic is going to be on the intersection of the exact and loop - // graph. We will generate a map on the entries of this graph so it's - // important to not modify this graph moving forward, as that would invalidate - // the map. - // - // iel stands for Intersection of the Exact and Loop graphs. std::unordered_map iel_promotion_map; // This should probably work just on terminating inputs, as we shouldn't be @@ -1191,7 +1268,6 @@ std::unordered_map IdModel::buildInlineRootPromotions( // required to resolve a non input broadcast domain. But for now leaving it as // traversal on all broadcast groups. // - // TODO-NM: The ordering appears to be non-deterministic // We first visit all broadcast root domains. If a broadcast is // resovled, see if it's promoted. Note that a domain be resolved to @@ -1239,6 +1315,11 @@ std::unordered_map IdModel::buildInlineRootPromotions( continue; } + // resolved_exact_groups is a list of IDs that resolves the + // broadcast. We only care those that are also in the same loop + // group, and there must be just one or none. Otherwise, the + // resolution is ambiguous. + // Collect all the exact groups in the loop set containing this iel_group const ValGroup& loop_group = idGraph(IdMappingMode::LOOP).toGroup(iel_group_id); @@ -1367,7 +1448,9 @@ bool hasUniqueOutputLoopGroups( // it is used there's no loop graph promotion yet, so only the IEL // promotions are propagated. In that case, loop_graph_promotion_map // should be just empty. -void IdModel::propagatePromotions( +// +// The loop_graph pamameter may not be up-to-date. +void IdModel::propagatePromotionsInIELGraph( const ValGraph& iel_graph, std::unordered_map& iel_promotion_map, const ValGraph& loop_graph, @@ -1575,87 +1658,15 @@ void IdModel::propagatePromotions( } } -void IdModel::propagatePromotions( +void IdModel::propagatePromotionsInIELGraph( const ValGraph& iel_graph, std::unordered_map& iel_promotion_map) { - propagatePromotions( + propagatePromotionsInIELGraph( iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), {}, false); } -std::unordered_map IdModel::buildInlinePromotions( - const StatefulInliningInfo& info) { - // Make an intersection of the exact and loop map. This will group together - // entries in each loop group that are exact with each other. This provides a - // better graph to do promotion and replays. - - // It's tempting to use the intersection of the almost exact and loop, but we - // need to model broadcast promotion, and if we have two tensors like: - // - // T1[i0, b1] = T0[i0] - // T2[i0, b2] = T0[i0] - // Then resolution of: - // T4 = T1[i0, b1] + T3[i0, i1] - // T6 = T2[i0, b2] + T5[i0, i2] - // - // Then merge(0, 1) with all tensors except for T0 - // - // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 - // are being resolved to i1 and i2 respectively. So we want to have separate - // entries so we can have an easy to process promotion map. - // - // Loop is a permissive like map, it could have many entries, use the exact - // map as the one we iterate on to reduce complexity as it hopefully has - // smaller groups and this algorithm scales with the number of groups * - // (number of entries in groups ^ 2) - - // Promotion logic is going to be on the intersection of the exact and loop - // graph. We will generate a map on the entries of this graph so it's - // important to not modify this graph moving forward, as that would invalidate - // the map. - // - // iel stands for Intersection of the Exact and Loop graphs. - ValGraph iel_graph = buildIntersection( - idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); - - // First, identify promotions of root broadcast domains only - std::unordered_map iel_promotion_map = - buildInlineRootPromotions(iel_graph, info); - - propagatePromotions(iel_graph, iel_promotion_map); - - std::stringstream ss; - ss << "Inline promotion map\n"; - for (const auto& [iel_group, promoted_id] : iel_promotion_map) { - ss << "\t" << nvfuser::toString(iel_group) << " -> " << promoted_id->name() - << std::endl; - } - VERBOSE() << ss.str(); - - return iel_promotion_map; -} - namespace { -std::unordered_map updateMap( - const std::unordered_map& stale_map, - ValGraph& new_graph) { - std::unordered_map new_map; - - for (const auto& [stale_key, mapped_id] : stale_map) { - const ValGroups& new_groups = new_graph.toGroups(*stale_key); - NVF_ERROR( - new_groups.size() == 1, - "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", - "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", - "old:", - nvfuser::toString(stale_key), - "new: ", - nvfuser::toString(new_groups)); - new_map[new_groups.front()] = mapped_id; - } - return new_map; -} - // Returns for each ValGroup in provided IdGraph what the input ValGroups are // traversing on definitions. Ignoring broadcast ValGroups and resetting inputs // at RFactor ValGroups. @@ -1786,28 +1797,13 @@ IterDomain* IdModel::findPromotionOfLoopGroup( return nullptr; } -std::unordered_map IdModel::buildLoopPromotionMap( - const std::vector& exprs, - const StatefulInliningInfo& inlining_info, - const std::unordered_map& stale_promotion_map) { - // Non-ca domains may also need to be promoted if parent domains are - // promoted. - - // Need to use the intersection of exact and loop map again, it needs to be - // recomputed. - auto iel_graph = buildIntersection( - idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); - - // Update the promotion map - auto iel_promotion_map = updateMap(stale_promotion_map, iel_graph); - - // Loop promotion map is to prepare for IterDomain replays to resolve - // non-inlined loop groups. Since these replays will modify the loop map as - // we're iterating over the loop map, operate on a copy of the loop map, not - // the original one. - auto loop_graph_copy = idGraph(IdMappingMode::LOOP); - - std::unordered_map loop_graph_copy_promotion_map; +std::unordered_map IdModel:: + projectIELPromotionToLoopGraph( + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const StatefulInliningInfo& inlining_info) { + std::unordered_map loop_promotion_map; std::unordered_map exact_covered_ids = computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); @@ -1817,7 +1813,8 @@ std::unordered_map IdModel::buildLoopPromotionMap( computeTerminalLoopIds(inlining_info); for (const ValGroup& loop_group : - loop_graph_copy.disjointValSets().disjointSets()) { + loop_graph.disjointValSets().disjointSets()) { + // Error happens here. Likely iel_graph is stale IterDomain* promotion_id = findPromotionOfLoopGroup( loop_group, iel_graph, @@ -1826,59 +1823,21 @@ std::unordered_map IdModel::buildLoopPromotionMap( exact_covered_ids, terminal_loop_ids); if (promotion_id) { - loop_graph_copy_promotion_map[loop_group] = promotion_id; + loop_promotion_map[loop_group] = promotion_id; } } - // At this point, most of loop groups should have correct promoted - // IDs. However, non-inlined loop groups may miss promotion that - // should be propagated from parent ID groups, e.g., iS50 of T2 in - // Indexing19. Its parent ID loop group is promoted, but the loop - // group of iS50 is not found yet. - VERBOSE() << "Promotion projected to loop groups:\n"; - for (const auto& [loop_group, id] : loop_graph_copy_promotion_map) { + for (const auto& [loop_group, id] : loop_promotion_map) { VERBOSE() << nvfuser::toString(loop_group) << " -> " << id->name() << std::endl; } - // Reset the promotion map for the second pass. - // TODO: Unclear if we could simply update the iel_promotion_map from - // buildInlinePromotions, instead of manually building it. - iel_promotion_map.clear(); - - // Need to run a replay for the loop groups that are dependent on inlined loop - // groups, but themselves are not inlined loop groups. - propagatePromotions( - iel_graph, - iel_promotion_map, - loop_graph_copy, - loop_graph_copy_promotion_map, - true); - - // Update the coverage map as new IDs were added to the exact graph - exact_covered_ids = - computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); - - // Set up the loop promotion map of loop groups to promotion - // IDs. - for (const ValGroup& loop_group : - loop_graph_copy.disjointValSets().disjointSets()) { - IterDomain* promotion_id = findPromotionOfLoopGroup( - loop_group, - iel_graph, - iel_promotion_map, - loop_graph_copy_promotion_map, - exact_covered_ids, - terminal_loop_ids); - if (promotion_id) { - const ValGroup& current_loop_group = - idGraph(IdMappingMode::LOOP).toGroup(loop_group->front()); - loop_promotion_map_.emplace(current_loop_group, promotion_id); - } - } + return loop_promotion_map; +} - // Sanity check of the loop promotion map +void IdModel::sanityCheckLoopPromotionMap( + const std::unordered_map& loop_promotion_map) { for (const ValGroup& loop_group : idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { // Non-leaf loop groups are not guaranteed to have valid @@ -1887,9 +1846,9 @@ std::unordered_map IdModel::buildLoopPromotionMap( if (idGraph(IdMappingMode::LOOP).hasUses(loop_group)) { continue; } - auto promotion_it = loop_promotion_map_.find(loop_group); + auto promotion_it = loop_promotion_map.find(loop_group); NVF_ERROR( - promotion_it != loop_promotion_map_.end(), + promotion_it != loop_promotion_map.end(), "Loop promotion not found for ", nvfuser::toString(loop_group)); IterDomain* promotion = promotion_it->second; @@ -1901,14 +1860,6 @@ std::unordered_map IdModel::buildLoopPromotionMap( ". Promotion domain: ", promotion->name()); } - - VERBOSE() << "Loop promotion map:" << std::endl; - for (const auto& [iel_group, id] : iel_promotion_map) { - VERBOSE() << nvfuser::toString(iel_group) << " -> " << id->name() - << std::endl; - } - - return iel_promotion_map; } std::unordered_map IdModel::buildIndexGraph( diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index ed839a7667f..f3b265da5a1 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -174,38 +174,43 @@ class IdModel : public PolymorphicBase { // Start loop map by grouping inlined iter domains void initializeLoopMap(const StatefulInliningInfo& info); - // Returns map of ValGroups in the IEL graph to a representative IterDomain - // that contains all resolved transformations that the terminal IterDomains - // should be promoted to. The returned promotions are valid only for inlined - // iter domains. - std::unordered_map buildInlinePromotions( + // Build a map of loop groups to IterDomains that represent actual + // loops. The map is built based on the broadcast resolution with + // root domains between inlined producer and consumer tensors. + std::unordered_map buildLoopPromotionMap( const StatefulInliningInfo& info); - // Helper function for buildInlinePromotions. Only build mappings of - // root broadcast domains - std::unordered_map buildInlineRootPromotions( + // Helper function for buildLoopPromotionMap. Returns a map of + // root broadcast ValGroups in the IEL graph to a representative IterDomain. + std::unordered_map buildInlineRootPromotionMap( const ValGraph& iel_graph, const StatefulInliningInfo& info); - // Helper function for buildInlinePromotions. Propagate root - // promotions to intermediate and leaf domains - void propagatePromotions( + // Helper function for building loop promotion map. Propagate + // promotions of root IEL groups to leaf IEL groups + void propagatePromotionsInIELGraph( const ValGraph& iel_graph, std::unordered_map& iel_promotion_map); - void propagatePromotions( + // Same as the other version but also propagates promotoins of loop + // groups as well + void propagatePromotionsInIELGraph( const ValGraph& iel_graph, std::unordered_map& iel_promotion_map, const ValGraph& loop_graph, - const std::unordered_map& loop_graph_promotion_map, + const std::unordered_map& loop_promotion_map, bool require_loop_mapped_promotion); // Returns a similar thing to buildInlinePromotions but also includes iter // domains that are not inlined. - std::unordered_map buildLoopPromotionMap( - const std::vector& exprs, - const StatefulInliningInfo& info, - const std::unordered_map& stale_promotion_map); + std::unordered_map projectIELPromotionToLoopGraph( + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const StatefulInliningInfo& inlining_info); + + void sanityCheckLoopPromotionMap( + const std::unordered_map& loop_promotion_map); // Find a promoted iter domain of a given loop group that covers all // the exact groups representative of the resolved transformations From 133613a4b529b942e5ee32b0707ab2b9f7695605 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 10 Jan 2024 17:39:36 -0800 Subject: [PATCH 110/178] IdModel: enable compliment mapping (#1611) See the design doc for more details --- csrc/id_model/id_model.cpp | 17 ++++++++++------- csrc/id_model/id_model.h | 4 ++++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 0929875e2e6..a9514c4128e 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -779,9 +779,10 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } - for (const auto& entry : permissive_forwarding.producer_compliment_map) { - for (auto entry_2 : entry.second) { - if (getenv("COMP")) { + if (permissive_graph_map_compliment_ids_) { + for (const auto& entry : + permissive_forwarding.producer_compliment_map) { + for (auto entry_2 : entry.second) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); } } @@ -791,9 +792,10 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } - for (const auto& entry : permissive_forwarding.consumer_compliment_map) { - for (auto entry_2 : entry.second) { - if (getenv("COMP")) { + if (permissive_graph_map_compliment_ids_) { + for (const auto& entry : + permissive_forwarding.consumer_compliment_map) { + for (auto entry_2 : entry.second) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); } } @@ -1158,7 +1160,8 @@ void IdModel::build( } buildPermissiveMap(tv_exprs); - if (validate) { + // Validation is not implemented when compliment mapping is enabled + if (validate && !permissive_graph_map_compliment_ids_) { validator->checkPermissiveGraphEquivalence( idGraph(IdMappingMode::PERMISSIVE)); } diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index f3b265da5a1..5b545bcfd2d 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -347,6 +347,10 @@ class IdModel : public PolymorphicBase { std::unordered_map loop_promotion_map_; std::unordered_set view_rfactor_ids_; + + // By default, the permissive graph should map compliment domains as + // well. See the design doc for more details + bool permissive_graph_map_compliment_ids_ = true; }; } // namespace nvfuser From c66617d3b2accc6020d9a5606f2225b5f4fe9d67 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 11 Jan 2024 16:14:22 -0800 Subject: [PATCH 111/178] test cleanup --- csrc/id_model/id_model.h | 2 +- test/test_gpu_indexing.cpp | 124 +++++++++++++++++++++++++++++++++++-- 2 files changed, 119 insertions(+), 7 deletions(-) diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 5b545bcfd2d..42807501700 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -130,7 +130,7 @@ class IdModel : public PolymorphicBase { std::string toString() const; - const std::unordered_map loopPromotionMap() const { + const std::unordered_map& loopPromotionMap() const { return loop_promotion_map_; } diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index 25b03c52e88..473f7dff06e 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -956,8 +956,6 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { // fusion.printKernel(); } -// TODO: Finish and enable test -// // Progressive loop promotion. producer gets promoted in consumer, consumer is // promoted in a different way to its consumer. TEST_F(NVFuserTest, FusionIndexing20_CUDA) { @@ -1004,7 +1002,111 @@ TEST_F(NVFuserTest, FusionIndexing20_CUDA) { // [2, 4, (3*5//2)*7//4] tv5->inlineAt(2); - fusion.printKernel(); + IdModel id_model(&fusion); + const auto& promotion_map = id_model.loopPromotionMap(); + + // For tv1, tv2, tv4, their first leaf domains should all be + // loop-mapped and promoted to a domain that is exaclty mapped with + // the first leaf domain of tv7. The second leaf domains should also + // be promoted to the domain at the same position in tv7 but since + // they are not inlined, they should not be loop-mapped + for (auto tv : {tv1, tv2, tv4}) { + // Validating the first leaf ID + { + auto leaf_id0 = tv->axis(0); + auto ref_id0 = tv7->axis(0); + ASSERT_TRUE(id_model.idGraph(IdMappingMode::LOOP) + .disjointValSets() + .strictAreMapped(leaf_id0, ref_id0)); + auto promotion_map_it = promotion_map.find( + id_model.idGraph(IdMappingMode::LOOP).toGroup(leaf_id0)); + ASSERT_NE(promotion_map_it, promotion_map.end()); + auto promoted_id = promotion_map_it->second; + ASSERT_TRUE(id_model.idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped(promoted_id, ref_id0)) + << "Expected exact mapping: " << promoted_id->toString() << " with " + << ref_id0->toString() << " of " << tv7->toString(); + } + + // Validating the second leaf ID + { + auto leaf_id1 = tv->axis(1); + // Should be promoted to a domain that is exactly mapped with iS31 + auto ref_id1 = tv7->axis(1) + ->definition() + ->as() + ->in() + ->definition() + ->as() + ->outer(); + auto promotion_map_it = promotion_map.find( + id_model.idGraph(IdMappingMode::LOOP).toGroup(leaf_id1)); + ASSERT_NE(promotion_map_it, promotion_map.end()); + auto promoted_id = promotion_map_it->second; + ASSERT_TRUE(id_model.idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped(promoted_id, ref_id1)) + << "Expected exact mapping: " << promoted_id->toString() << " with " + << ref_id1->toString() << " of " << tv7->toString(); + // While promoted ID should be exact-mapped with the reference ID, they + // should not be loop-mapped + ASSERT_FALSE(id_model.idGraph(IdMappingMode::LOOP) + .disjointValSets() + .strictAreMapped(promoted_id, ref_id1)) + << "Expected no loop mapping: " << promoted_id->toString() << " with " + << ref_id1->toString() << " of " << tv7->toString(); + + // In the case of tv1 and tv2, the promoted id is a newly replayed + // domain, whereas for the tv4, there should be no replay as + // there's no broadcast. So, the size of the loop group should be + // 2 for the former and 1 for the latter. + const auto& leaf_id1_loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(leaf_id1); + ASSERT_EQ(leaf_id1_loop_group->size(), tv == tv4 ? 1 : 2) + << "Unexpected loop group: " + << nvfuser::toString(leaf_id1_loop_group); + } + } + + // Validate tv5. The last leaf domain should be promoted to a domain + // that is exactly mapped with the last domain of tv7 + { + auto last_leaf = tv5->axis(-1); + auto promotion_map_it = promotion_map.find( + id_model.idGraph(IdMappingMode::LOOP).toGroup(last_leaf)); + ASSERT_NE(promotion_map_it, promotion_map.end()); + auto promoted_id = promotion_map_it->second; + ASSERT_TRUE(id_model.idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped(promoted_id, tv7->axis(-1))) + << "Expected exact mapping: " << promoted_id->toString() << " with " + << tv7->axis(-1)->toString() << " of " << tv7->toString(); + + // While promoted ID should be exact-mapped with the last ID, they + // should not be loop-mapped + ASSERT_FALSE(id_model.idGraph(IdMappingMode::LOOP) + .disjointValSets() + .strictAreMapped(promoted_id, tv7->axis(-1))) + << "Expected no loop maping: " << promoted_id->toString() << " with " + << tv7->axis(-1)->toString() << " of " << tv7->toString(); + } + + // Validation not enabled yet as incorrect code is generated. Need + // to use the loop promotion info to generate correct loop-nests +#if 0 + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({5}, options); + auto t3 = at::randn({3, 5}, options); + auto t6 = at::randn({3, 5, 7}, options); + std::vector aten_inputs = {t0, t3, t6}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +#endif } // Repro for issue #1873 @@ -1088,9 +1190,11 @@ TEST_F(NVFuserTest, FusionMultiPromotion_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -// TODO: Finish and enable test. // Broadcast and concretize same domain in two different ways and try to merge -// their loops remains unsupported. +// their loops. The inlining pattern is invalid but the current +// inlining check is not capable of flagging the inlining poistion as +// invalid. The loop promotion analysis should not find any promotion +// of the loop group where all the leaf domains are merged into. TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1125,11 +1229,19 @@ TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { tv->merge(0); } + // Since x and y are not proven to be the same, this inling position + // should not be allowed. for (auto tv : std::vector{tv3, tv4, tv6}) { tv->inlineAt(1); } - ASSERT_ANY_THROW(fusion.printKernel()); + // For now, just make sure there's no loop promotion for the merged + // loop group. + IdModel id_model(&fusion); + const auto& leaf_loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(tv7->axis(0)); + auto promotion_map_it = id_model.loopPromotionMap().find(leaf_loop_group); + ASSERT_EQ(promotion_map_it, id_model.loopPromotionMap().end()); } // TODO: All the above tests are merges followed by splits, we should make some From 15bfe6415264e089c8c033e6558151d89b12c16f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Jan 2024 16:37:59 -0800 Subject: [PATCH 112/178] [IdModle] Val graph cleanup (#1637) Simplify `getDefinitions()` and `getUses()` of `ValGraph` as they always have mappings for any valid ValGroups. --- csrc/id_model/id_model.cpp | 40 +++++------ csrc/id_model/to_string.cpp | 12 ++-- csrc/id_model/visitor.cpp | 13 ++-- csrc/val_graph.cpp | 131 +++++++++++++++--------------------- csrc/val_graph.h | 16 ++--- 5 files changed, 90 insertions(+), 122 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 7ec9365a3fc..bbe8ebde187 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -408,10 +408,9 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { // Gather all use expressions from inputs VectorOfUniqueEntries representative_uses; for (IterDomain* inp : new_inputs) { - if (const ExprGroups* uses = graph.getUses(graph.toGroup(inp)); uses) { - for (const ExprGroup& use_group : *uses) { - representative_uses.pushBack(use_group->front()); - } + for (const ExprGroup& use_group : graph.getUses(graph.toGroup(inp))) { + NVF_ERROR(!use_group->empty()); + representative_uses.pushBack(use_group->front()); } } @@ -556,13 +555,11 @@ Expr* IdModel::addExprWithReplacement( // Forward VectorOfUniqueEntries representative_uses; for (auto in : ir_utils::filterByType(replay->inputs())) { - if (const ExprGroups* uses = graph.getUses(graph.toGroup(in)); uses) { - for (const ExprGroup& use_group : *uses) { - if (use_group == replay_group) { - continue; - } - representative_uses.pushBack(use_group->front()); + for (const ExprGroup& use_group : graph.getUses(graph.toGroup(in))) { + if (use_group == replay_group) { + continue; } + representative_uses.pushBack(use_group->front()); } } @@ -573,14 +570,12 @@ Expr* IdModel::addExprWithReplacement( // Backwards VectorOfUniqueEntries representative_defs; for (auto out : ir_utils::filterByType(replay->outputs())) { - if (auto definition = graph.getDefinitions(graph.toGroup(out)); - definition) { - for (const ExprGroup& def_group : *definition) { - if (def_group == replay_group) { - continue; - } - representative_defs.pushBack(def_group->front()); + for (const ExprGroup& def_group : + graph.getDefinitions(graph.toGroup(out))) { + if (def_group == replay_group) { + continue; } + representative_defs.pushBack(def_group->front()); } } @@ -1124,6 +1119,8 @@ void IdModel::build( const std::vector& exprs, const std::vector& additional_tvs, bool validate) { + VERBOSE() << "*** Building all graphs ***"; + // Initialize the required sets as if a permissive relationship is never // found, then querying an empty permissive map will fail later. // Initialize disjoint sets @@ -1543,9 +1540,7 @@ void IdModel::propagatePromotionsInIELGraph( continue; } const auto& inp_exact_group = iel_graph.toGroup(inp_id); - const ExprGroups* uses = iel_graph.getUses(inp_exact_group); - NVF_ERROR(uses); - maybe_promoted_input_uses.pushBack(*uses); + maybe_promoted_input_uses.pushBack(iel_graph.getUses(inp_exact_group)); } // Look for exprs that have inputs that are mapped in the IEL @@ -1682,9 +1677,8 @@ std::unordered_map computeCoveredGroups( for (const ValGroup& id_group : graph.disjointValSets().disjointSets()) { // Initialize inputs - const ExprGroups* id_group_defs = graph.getDefinitions(id_group); - NVF_ERROR(id_group_defs); - if (id_group_defs->empty()) { + const ExprGroups& id_group_defs = graph.getDefinitions(id_group); + if (id_group_defs.empty()) { covered_ids[id_group] = {id_group}; } diff --git a/csrc/id_model/to_string.cpp b/csrc/id_model/to_string.cpp index 95037ecfb04..95d0c32ade8 100644 --- a/csrc/id_model/to_string.cpp +++ b/csrc/id_model/to_string.cpp @@ -308,10 +308,8 @@ std::string definitionsString( bool with_ptr) { ExprGroups all_defs; for (const ValGroup& id_group : id_graph.disjointValSets().disjointSets()) { - if (auto definition = id_graph.getDefinitions(id_group); definition) { - for (const ExprGroup& expr_group : *definition) { - all_defs.pushBack(expr_group); - } + for (const ExprGroup& expr_group : id_graph.getDefinitions(id_group)) { + all_defs.pushBack(expr_group); } } return toString(id_graph, all_defs, indent_size, with_ptr); @@ -323,10 +321,8 @@ std::string usesString( bool with_ptr) { ExprGroups all_uses; for (const ValGroup& id_group : id_graph.disjointValSets().disjointSets()) { - if (const ExprGroups* uses = id_graph.getUses(id_group); uses) { - for (const ExprGroup& expr_group : *uses) { - all_uses.pushBack(expr_group); - } + for (const ExprGroup& expr_group : id_graph.getUses(id_group)) { + all_uses.pushBack(expr_group); } } return toString(id_graph, all_uses, indent_size, with_ptr); diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index 0c397548539..f81ce4177d9 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -37,7 +37,7 @@ void IdGraphVisitor::traverse() { graph().disjointExprSets().disjointSets().end()); } else { for (const ValGroup& id_group : all_ids) { - for (const ExprGroup& def : *(graph().getDefinitions(id_group))) { + for (const ExprGroup& def : graph().getDefinitions(id_group)) { if (all_exprs.has(def)) { continue; } @@ -91,10 +91,9 @@ void IdGraphVisitor::traverse() { }; auto is_id_ready = [&](const ValGroup& id_group) { - auto unique_defs = graph().getDefinitions(id_group); - NVF_ERROR(unique_defs); + const ExprGroups& unique_defs = graph().getDefinitions(id_group); return std::all_of( - unique_defs->begin(), unique_defs->end(), [&](ExprGroup expr_group) { + unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { return expr_group->empty() || visited_exprs.has(expr_group) || graph().isTrivialExprGroup(expr_group); }); @@ -145,10 +144,8 @@ void IdGraphVisitor::traverse() { visited_ids.pushBack(current_id_group); if (!terminating_outputs.has(current_id_group)) { - if (const ExprGroups* uses = graph().getUses(current_id_group); - uses) { - to_visit_exprs.pushBack(*uses); - } + const ExprGroups& uses = graph().getUses(current_id_group); + to_visit_exprs.pushBack(uses); } } else { still_to_visit_ids.pushBack(current_id_group); diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 666568067a7..d66108468f0 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -119,10 +119,8 @@ std::vector ValGraph::inputGroups(const ExprGroup& expr) const { ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { DequeOfExprGroup to_visit; for (const ValGroup& of_val_group : of) { - if (const ExprGroups* group_uses = getUses(of_val_group); - group_uses != nullptr) { - to_visit.insert(to_visit.end(), group_uses->begin(), group_uses->end()); - } + const ExprGroups& group_uses = getUses(of_val_group); + to_visit.insert(to_visit.end(), group_uses.begin(), group_uses.end()); } UnorderedSetOfExprGroup visited; @@ -131,14 +129,11 @@ ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { to_visit.pop_front(); visited.emplace(current_expr); for (const ValGroup& output_group : outputGroups(current_expr)) { - if (const ExprGroups* group_uses = getUses(output_group); - group_uses != nullptr) { - for (const ExprGroup& group_use : *group_uses) { - if (visited.count(group_use)) { - continue; - } - to_visit.push_back(group_use); + for (const ExprGroup& group_use : getUses(output_group)) { + if (visited.count(group_use)) { + continue; } + to_visit.push_back(group_use); } } } @@ -149,10 +144,8 @@ ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { ExprGroups ValGraph::allDefinitionsOf(const ValGroups& of) const { DequeOfExprGroup to_visit; for (const ValGroup& of_val_group : of) { - if (const ExprGroups* group_defs = getDefinitions(of_val_group); - group_defs != nullptr) { - to_visit.insert(to_visit.end(), group_defs->begin(), group_defs->end()); - } + const ExprGroups& group_defs = getDefinitions(of_val_group); + to_visit.insert(to_visit.end(), group_defs.begin(), group_defs.end()); } UnorderedSetOfExprGroup visited; @@ -161,14 +154,11 @@ ExprGroups ValGraph::allDefinitionsOf(const ValGroups& of) const { to_visit.pop_front(); visited.emplace(current_expr); for (const ValGroup& input_id : inputGroups(current_expr)) { - if (const ExprGroups* group_defs = getDefinitions(input_id); - group_defs != nullptr) { - for (const ExprGroup& group_def : *group_defs) { - if (visited.count(group_def)) { - continue; - } - to_visit.push_back(group_def); + for (const ExprGroup& group_def : getDefinitions(input_id)) { + if (visited.count(group_def)) { + continue; } + to_visit.push_back(group_def); } } } @@ -276,9 +266,9 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) // domain coming back from any of its uses. ExprGroups min_groups; - const ExprGroups* uses = getUses(id_group); + const ExprGroups& uses = getUses(id_group); - if (!uses) { + if (uses.empty()) { // No expressions required for this iter domain, it must be a // terminating output. required_ind_exprs_ids[id_group] = min_groups; @@ -287,7 +277,7 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) // Only worry about expressions between inputs and outputs we're // looking at. - for (const ExprGroup& use_group : uses->computeIntersect(all_exprs)) { + for (const ExprGroup& use_group : uses.computeIntersect(all_exprs)) { auto use_required_ind_exprs_it = required_ind_exprs_exprs.find(use_group); if (use_required_ind_exprs_it == required_ind_exprs_exprs.end()) { // If there isn't an entry for the use expression it wasn't @@ -354,16 +344,13 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) if (processValGroup(currently_visiting_ids)) { something_was_processed = true; - if (const auto definitions = getDefinitions(currently_visiting_ids); - definitions) { - for (const ExprGroup& def : *definitions) { - if (!all_exprs.has(def)) { - continue; - } - if (required_ind_exprs_exprs.find(def) == - required_ind_exprs_exprs.end()) { - to_visit_exprs.pushBack(def); - } + for (const ExprGroup& def : getDefinitions(currently_visiting_ids)) { + if (!all_exprs.has(def)) { + continue; + } + if (required_ind_exprs_exprs.find(def) == + required_ind_exprs_exprs.end()) { + to_visit_exprs.pushBack(def); } } } else { @@ -383,12 +370,8 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) for (const auto& entry : required_ind_exprs_ids) { const ValGroup& id = entry.first; const ExprGroups& traverse_exprs = entry.second; - if (auto all_uses = getUses(id); all_uses) { - uses_path[id] = traverse_exprs.computeIntersect(*all_uses); - } else { - uses_path[id] = {}; - continue; - } + const ExprGroups& all_uses = getUses(id); + uses_path[id] = traverse_exprs.computeIntersect(all_uses); } // Topologically sort the uses_path. @@ -424,9 +407,8 @@ ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) auto outputs = outputGroups(currently_visiting); for (const ValGroup& out_id : outputs) { visited.pushBack(out_id); - if (const auto uses = getUses(out_id); uses) { - still_to_visit.pushBack(uses->computeIntersect(all_exprs)); - } + const ExprGroups& uses = getUses(out_id); + still_to_visit.pushBack(uses.computeIntersect(all_exprs)); } } else { still_to_visit.pushBack(currently_visiting); @@ -632,23 +614,24 @@ bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { return true; } -const ExprGroups* ValGraph::getDefinitions(const ValGroup& val_group) const { +const ExprGroups& ValGraph::getDefinitions(const ValGroup& val_group) const { NVF_ERROR(val_group, "Nullptr not allowed"); - if (auto it = unique_definitions_.find(val_group); - it != unique_definitions_.end()) { - return &(it->second); - } else { - return nullptr; - } + auto it = unique_definitions_.find(val_group); + NVF_ERROR( + it != unique_definitions_.end(), + "Definition group not found for ", + nvfuser::toString(val_group)); + return it->second; } -const ExprGroups* ValGraph::getUses(const ValGroup& val_group) const { +const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const { NVF_ERROR(val_group, "Nullptr not allowed"); - if (auto it = unique_uses_.find(val_group); it != unique_uses_.end()) { - return &(it->second); - } else { - return nullptr; - } + auto it = unique_uses_.find(val_group); + NVF_ERROR( + it != unique_uses_.end(), + "Use group not found for ", + nvfuser::toString(val_group)); + return it->second; } void ValGraph::mapVals(Val* val0, Val* val1) { @@ -659,19 +642,17 @@ void ValGraph::mapVals(Val* val0, Val* val1) { if (disjointValSets().strictAreMapped(val0, val1)) { return; } + // Definitions and uses are based on the groups of id0 and id1, don't merge // them into a single group until we grab all definitions and uses for later // processing. - ValGroup orig_val_group0 = toGroup(val0); - ValGroup orig_val_group1 = toGroup(val1); - const ExprGroups* orig_defs0 = getDefinitions(orig_val_group0); - NVF_ERROR(orig_defs0); - const ExprGroups* orig_defs1 = getDefinitions(orig_val_group1); - NVF_ERROR(orig_defs1); - const ExprGroups* orig_uses0 = getUses(orig_val_group0); - NVF_ERROR(orig_uses0); - const ExprGroups* orig_uses1 = getUses(orig_val_group1); - NVF_ERROR(orig_uses1); + const ValGroup orig_val_group0 = toGroup(val0); + const ValGroup orig_val_group1 = toGroup(val1); + + const ExprGroups& orig_defs0 = getDefinitions(orig_val_group0); + const ExprGroups& orig_defs1 = getDefinitions(orig_val_group1); + const ExprGroups& orig_uses0 = getUses(orig_val_group0); + const ExprGroups& orig_uses1 = getUses(orig_val_group1); // Map the iter domains together before we traverse across definitions and // uses. Traversing definitions and uses could use the new property of id0 and @@ -679,13 +660,13 @@ void ValGraph::mapVals(Val* val0, Val* val1) { disjoint_vals_.mapEntries(val0, val1); auto new_val_group = toGroup(val0); - unique_definitions_[new_val_group] = orig_defs0->computeUnion(*orig_defs1); - unique_uses_[new_val_group] = orig_uses0->computeUnion(*orig_uses1); + unique_definitions_[new_val_group] = orig_defs0.computeUnion(orig_defs1); + unique_uses_[new_val_group] = orig_uses0.computeUnion(orig_uses1); // Propagate on uses - if (!orig_uses0->empty() && !orig_uses1->empty()) { - for (const ExprGroup& use_group_1 : *orig_uses1) { - for (const ExprGroup& use_group_0 : *orig_uses0) { + if (!orig_uses0.empty() && !orig_uses1.empty()) { + for (const ExprGroup& use_group_1 : orig_uses1) { + for (const ExprGroup& use_group_0 : orig_uses0) { if (use_group_0 == use_group_1) { continue; } @@ -697,9 +678,9 @@ void ValGraph::mapVals(Val* val0, Val* val1) { } // Propagate on definitions - if (!orig_defs0->empty() && !orig_defs1->empty()) { - for (const ExprGroup& def_group_1 : *orig_defs1) { - for (const ExprGroup& def_group_0 : *orig_defs0) { + if (!orig_defs0.empty() && !orig_defs1.empty()) { + for (const ExprGroup& def_group_1 : orig_defs1) { + for (const ExprGroup& def_group_0 : orig_defs0) { if (def_group_0 == def_group_1) { continue; } diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 1a434d8fe47..57f186c681c 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -118,18 +118,18 @@ class ValGraph { // ExprGroups used in this history of defining the 'of' IdGroups. ExprGroups allDefinitionsOf(const ValGroups& of) const; - //! Returns the pointer to expressions associated with the - //! definitions of the provided ValGroup. Nullptr is returned otherwise. + //! Returns the expressions associated with the + //! definitions of the provided ValGroup. //! - //! The returned pointer is to a vector of vector of expressions. The - //! inner vector is proven to be equivalent. The - //! outer vector are expression groups that are not equivalent, but - //! produce one of the ValGroups within the same disjoint Val set. - const ExprGroups* getDefinitions(const ValGroup& val_group) const; + //! Each ExprGroup of the returned ExprGroup vector is proven to be + //! equivalent. The ExprGroup vector holds expression groups that are not + //! equivalent, but produce one of the ValGroups within the same disjoint Val + //! set. + const ExprGroups& getDefinitions(const ValGroup& val_group) const; //! Same as getDefinitions but for uses instead of //! definitions - const ExprGroups* getUses(const ValGroup& val_group) const; + const ExprGroups& getUses(const ValGroup& val_group) const; bool hasDefinitions(const ValGroup& val_group) const; From 64b409ef096d583fc298bd6afa4eaa23908435c9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 17 Jan 2024 17:22:02 -0800 Subject: [PATCH 113/178] fix memory usage --- csrc/val_graph.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index d66108468f0..7716f2b9d22 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -649,10 +649,13 @@ void ValGraph::mapVals(Val* val0, Val* val1) { const ValGroup orig_val_group0 = toGroup(val0); const ValGroup orig_val_group1 = toGroup(val1); - const ExprGroups& orig_defs0 = getDefinitions(orig_val_group0); - const ExprGroups& orig_defs1 = getDefinitions(orig_val_group1); - const ExprGroups& orig_uses0 = getUses(orig_val_group0); - const ExprGroups& orig_uses1 = getUses(orig_val_group1); + // Note that getDefinitions and getUses return references, which + // will be invalidated once unique_definitions_ and unique_uses_ are + // updated + const ExprGroups orig_defs0 = getDefinitions(orig_val_group0); + const ExprGroups orig_defs1 = getDefinitions(orig_val_group1); + const ExprGroups orig_uses0 = getUses(orig_val_group0); + const ExprGroups orig_uses1 = getUses(orig_val_group1); // Map the iter domains together before we traverse across definitions and // uses. Traversing definitions and uses could use the new property of id0 and @@ -666,7 +669,11 @@ void ValGraph::mapVals(Val* val0, Val* val1) { // Propagate on uses if (!orig_uses0.empty() && !orig_uses1.empty()) { for (const ExprGroup& use_group_1 : orig_uses1) { + NVF_ERROR(use_group_1.get() != nullptr); + NVF_ERROR(!use_group_1->empty()); for (const ExprGroup& use_group_0 : orig_uses0) { + NVF_ERROR(use_group_0.get() != nullptr); + NVF_ERROR(!use_group_0->empty()); if (use_group_0 == use_group_1) { continue; } @@ -680,7 +687,11 @@ void ValGraph::mapVals(Val* val0, Val* val1) { // Propagate on definitions if (!orig_defs0.empty() && !orig_defs1.empty()) { for (const ExprGroup& def_group_1 : orig_defs1) { + NVF_ERROR(def_group_1.get() != nullptr); + NVF_ERROR(!def_group_1->empty()); for (const ExprGroup& def_group_0 : orig_defs0) { + NVF_ERROR(def_group_0.get() != nullptr); + NVF_ERROR(!def_group_0->empty()); if (def_group_0 == def_group_1) { continue; } From a12514edc61cae9cedd63554047e6d397e9bc997 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Jan 2024 11:26:15 -0800 Subject: [PATCH 114/178] IdModel: Fix inconsistent graph issue (#1627) Fixes #1636 Previously when exprs are merged, their input and output groups are updated to point to the new merged group. However, this is not enough as there may be val groups that are used or defined by the pre-merged expr groups but they may not be inputs or outputs of the exprs. See #1636 for a concrete example. Here, we check all mapping entries of `unique_definitions_` and `unique_uses_` so all uses and definitions are correctly updated. --- csrc/disjoint_set.h | 4 +- csrc/id_model/id_model.cpp | 8 +++ csrc/val_graph.cpp | 120 ++++++++++++++++++++++++++++++------- csrc/val_graph.h | 4 ++ 4 files changed, 113 insertions(+), 23 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index eff54c62276..982e4923ee1 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -409,7 +409,9 @@ class DisjointSets { entry_it != disjointSetMap().end(), "Strict mapping failed on element: ", abstractToString(entry0), - " either an error occurred, or non strict mapping should have been used."); + " either an error occurred, or non strict mapping should have been used.", + " ", + entry0->name()); return entry_it->second->has(entry1); } diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index bbe8ebde187..2ef8343b044 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -682,6 +682,8 @@ void IdModel::buildExactGraph(const std::vector& exprs) { // TODO: Revisit if we really should map domains in the exact map mapThroughLoopSwizzles(idGraph(IdMappingMode::EXACT)); } + + idGraph(IdMappingMode::EXACT).validateConsistency(); } namespace { @@ -753,6 +755,8 @@ void IdModel::buildAlmostExactMap() { for (const auto& [id1, id2] : ids_to_map) { almost_exact_graph.mapVals(id1, id2); } + + almost_exact_graph.validateConsistency(); } void IdModel::buildPermissiveMap(const std::vector& exprs) { @@ -804,6 +808,8 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { } } } + + idGraph(IdMappingMode::PERMISSIVE).validateConsistency(); } namespace { @@ -958,6 +964,8 @@ void IdModel::buildLoopMap(const std::vector& exprs) { } loop_promotion_map_ = buildLoopPromotionMap(inlining_info); + + idGraph(IdMappingMode::LOOP).validateConsistency(); } std::unordered_map IdModel::buildLoopPromotionMap( diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 7716f2b9d22..979e4ac8182 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -756,33 +756,23 @@ void ValGraph::mapExprs(Expr* expr0, Expr* expr1) { const ExprGroup& expr_new_group = toGroup(expr0); - // Update unique uses of producers - ValGroups producers; - for (auto expr : std::vector{expr0, expr1}) { - for (auto input : expr->inputs()) { - producers.pushBack(toGroup(input)); + // Update unique uses + for (auto& [producer_group, use_groups] : unique_uses_) { + if (use_groups.has(expr0_orig_group) || use_groups.has(expr1_orig_group)) { + use_groups.erase(expr0_orig_group); + use_groups.erase(expr1_orig_group); + use_groups.pushBack(expr_new_group); } } - for (const ValGroup& producer_group : producers) { - unique_uses_.at(producer_group).erase(expr0_orig_group); - unique_uses_.at(producer_group).erase(expr1_orig_group); - unique_uses_.at(producer_group).pushBack(expr_new_group); - } - - // Update unique definitinos of consumers - ValGroups consumers; - for (auto expr : std::vector{expr0, expr1}) { - for (auto output : expr->outputs()) { - consumers.pushBack(toGroup(output)); + // Update unique definitions + for (auto& [consumer_group, def_groups] : unique_definitions_) { + if (def_groups.has(expr0_orig_group) || def_groups.has(expr1_orig_group)) { + def_groups.erase(expr0_orig_group); + def_groups.erase(expr1_orig_group); + def_groups.pushBack(expr_new_group); } } - - for (const ValGroup& consumer_group : consumers) { - unique_definitions_.at(consumer_group).erase(expr0_orig_group); - unique_definitions_.at(consumer_group).erase(expr1_orig_group); - unique_definitions_.at(consumer_group).pushBack(expr_new_group); - } } bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { @@ -864,4 +854,90 @@ bool ValGraph::isTrivialExprGroup(const ExprGroup& expr_group) const { .empty(); } +void ValGraph::validateConsistency() const { + // Check the consistency of the mapping information. Specifically: + // 1. All ValGroup and ExprGroup sets are not empty. This may not be + // strictly necessary but it's often implicitly assumed as we tend + // to use `front()`. + // 2. All Val groups in disjoint_vals_ are mapped in unique_definitions_ and + // unique_uses_ + // 3. All Expr groups in disjoint_exprs_ are mapped to in + // unique_definitions_ and unique_uses_ + // 4. Any val and expr groups in unique_definitions_ and + // unique_uses_ are found in disjoint_vals_ and disjoint_exprs_ + + // Check 1 + for (const ValGroup& valg : disjointValSets().disjointSets()) { + NVF_ERROR(valg.get() != nullptr); + NVF_ERROR(!valg->empty(), "Empty Val group is not allowed"); + } + + for (const ExprGroup& exprg : disjointExprSets().disjointSets()) { + NVF_ERROR(exprg.get() != nullptr); + NVF_ERROR(!exprg->empty(), "Empty Expr group is not allowed"); + } + + // Check 2 + for (const ValGroup& valg : disjointValSets().disjointSets()) { + NVF_ERROR( + unique_definitions_.find(valg) != unique_definitions_.end(), + "Definition exprs not found for ", + nvfuser::toString(valg)); + NVF_ERROR( + unique_uses_.find(valg) != unique_uses_.end(), + "Use exprs not found for ", + nvfuser::toString(valg)); + } + + // Check 3 + for (const ExprGroup& exprg : disjointExprSets().disjointSets()) { + for (const auto& use_def_map : {unique_definitions_, unique_uses_}) { + bool found = false; + for (const auto& [val_group, expr_groups] : use_def_map) { + if (expr_groups.has(exprg)) { + found = true; + continue; + } + } + NVF_ERROR( + found, + "ExprGroup not found in ", + (&use_def_map == &unique_definitions_) ? "unique_definitions_" + : "unique_uses_"); + } + } + + // Check 4 + for (const auto& use_def_map : {unique_definitions_, unique_uses_}) { + for (const auto& [val_group, expr_groups] : use_def_map) { + NVF_ERROR(val_group.get() != nullptr); + auto val_set_it = std::find( + disjointValSets().disjointSets().begin(), + disjointValSets().disjointSets().end(), + val_group); + NVF_ERROR( + val_set_it != disjointValSets().disjointSets().end(), + "Inconsistent ValGroup, ", + nvfuser::toString(val_group), + ", at addreess ", + val_group.get(), + ", not found in the disjoint Val sets."); + for (const ExprGroup& expr_group : expr_groups) { + NVF_ERROR(expr_group.get() != nullptr); + auto expr_set_it = std::find( + disjointExprSets().disjointSets().begin(), + disjointExprSets().disjointSets().end(), + expr_group); + NVF_ERROR( + expr_set_it != disjointExprSets().disjointSets().end(), + "Inconsistent ExprGroup, ", + nvfuser::toString(expr_group), + ", at addreess ", + expr_group.get(), + ", not found in the disjoint Expr sets."); + } + } + } +} + } // namespace nvfuser diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 57f186c681c..d796a2e0dbb 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -203,6 +203,10 @@ class ValGraph { // they modify matching original inputs by the same amount. bool exprsMap(Expr* first, Expr* second, bool forward) const; + // Check basic consistencies of val and expr groups and their + // mappings. + void validateConsistency() const; + public: void addUniqueUses(const ValGroup& id_group, const ExprGroup& uses) { unique_uses_.at(id_group).pushBack(uses); From 197f227e75124b282df56b1210f4cb1af415b61f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Jan 2024 13:50:29 -0800 Subject: [PATCH 115/178] Cherry pick #1627 --- csrc/id_model/id_model.cpp | 8 +++ csrc/val_graph.cpp | 120 ++++++++++++++++++++++++++++++------- csrc/val_graph.h | 4 ++ 3 files changed, 110 insertions(+), 22 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 4ddd20deccc..ac97bbeb3a7 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -360,6 +360,8 @@ void IdModel::buildExactGraph(const std::vector& exprs) { // TODO: Revisit if we really should map domains in the exact map mapThroughLoopSwizzles(idGraph(IdMappingMode::EXACT)); } + + idGraph(IdMappingMode::EXACT).validateConsistency(); } namespace { @@ -431,6 +433,8 @@ void IdModel::buildAlmostExactMap() { for (const auto& [id1, id2] : ids_to_map) { almost_exact_graph.mapVals(id1, id2); } + + almost_exact_graph.validateConsistency(); } void IdModel::buildPermissiveMap(const std::vector& exprs) { @@ -464,6 +468,8 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { } } } + + idGraph(IdMappingMode::PERMISSIVE).validateConsistency(); } namespace { @@ -540,6 +546,8 @@ void IdModel::buildLoopMap(const std::vector& exprs) { exprs, idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::PERMISSIVE)); initializeLoopMap(info); + + idGraph(IdMappingMode::LOOP).validateConsistency(); } void IdModel::build( diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index dbee82eb6a8..3e9324a0c19 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -442,33 +442,23 @@ void ValGraph::mapExprs(Expr* expr0, Expr* expr1) { const ExprGroup& expr_new_group = toGroup(expr0); - // Update unique uses of producers - ValGroups producers; - for (auto expr : std::vector{expr0, expr1}) { - for (auto input : expr->inputs()) { - producers.pushBack(toGroup(input)); + // Update unique uses + for (auto& [producer_group, use_groups] : unique_uses_) { + if (use_groups.has(expr0_orig_group) || use_groups.has(expr1_orig_group)) { + use_groups.erase(expr0_orig_group); + use_groups.erase(expr1_orig_group); + use_groups.pushBack(expr_new_group); } } - for (const ValGroup& producer_group : producers) { - unique_uses_.at(producer_group).erase(expr0_orig_group); - unique_uses_.at(producer_group).erase(expr1_orig_group); - unique_uses_.at(producer_group).pushBack(expr_new_group); - } - - // Update unique definitinos of consumers - ValGroups consumers; - for (auto expr : std::vector{expr0, expr1}) { - for (auto output : expr->outputs()) { - consumers.pushBack(toGroup(output)); + // Update unique definitions + for (auto& [consumer_group, def_groups] : unique_definitions_) { + if (def_groups.has(expr0_orig_group) || def_groups.has(expr1_orig_group)) { + def_groups.erase(expr0_orig_group); + def_groups.erase(expr1_orig_group); + def_groups.pushBack(expr_new_group); } } - - for (const ValGroup& consumer_group : consumers) { - unique_definitions_.at(consumer_group).erase(expr0_orig_group); - unique_definitions_.at(consumer_group).erase(expr1_orig_group); - unique_definitions_.at(consumer_group).pushBack(expr_new_group); - } } bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { @@ -500,4 +490,90 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return true; } +void ValGraph::validateConsistency() const { + // Check the consistency of the mapping information. Specifically: + // 1. All ValGroup and ExprGroup sets are not empty. This may not be + // strictly necessary but it's often implicitly assumed as we tend + // to use `front()`. + // 2. All Val groups in disjoint_vals_ are mapped in unique_definitions_ and + // unique_uses_ + // 3. All Expr groups in disjoint_exprs_ are mapped to in + // unique_definitions_ and unique_uses_ + // 4. Any val and expr groups in unique_definitions_ and + // unique_uses_ are found in disjoint_vals_ and disjoint_exprs_ + + // Check 1 + for (const ValGroup& valg : disjointValSets().disjointSets()) { + NVF_ERROR(valg.get() != nullptr); + NVF_ERROR(!valg->empty(), "Empty Val group is not allowed"); + } + + for (const ExprGroup& exprg : disjointExprSets().disjointSets()) { + NVF_ERROR(exprg.get() != nullptr); + NVF_ERROR(!exprg->empty(), "Empty Expr group is not allowed"); + } + + // Check 2 + for (const ValGroup& valg : disjointValSets().disjointSets()) { + NVF_ERROR( + unique_definitions_.find(valg) != unique_definitions_.end(), + "Definition exprs not found for ", + nvfuser::toString(valg)); + NVF_ERROR( + unique_uses_.find(valg) != unique_uses_.end(), + "Use exprs not found for ", + nvfuser::toString(valg)); + } + + // Check 3 + for (const ExprGroup& exprg : disjointExprSets().disjointSets()) { + for (const auto& use_def_map : {unique_definitions_, unique_uses_}) { + bool found = false; + for (const auto& [val_group, expr_groups] : use_def_map) { + if (expr_groups.has(exprg)) { + found = true; + continue; + } + } + NVF_ERROR( + found, + "ExprGroup not found in ", + (&use_def_map == &unique_definitions_) ? "unique_definitions_" + : "unique_uses_"); + } + } + + // Check 4 + for (const auto& use_def_map : {unique_definitions_, unique_uses_}) { + for (const auto& [val_group, expr_groups] : use_def_map) { + NVF_ERROR(val_group.get() != nullptr); + auto val_set_it = std::find( + disjointValSets().disjointSets().begin(), + disjointValSets().disjointSets().end(), + val_group); + NVF_ERROR( + val_set_it != disjointValSets().disjointSets().end(), + "Inconsistent ValGroup, ", + nvfuser::toString(val_group), + ", at addreess ", + val_group.get(), + ", not found in the disjoint Val sets."); + for (const ExprGroup& expr_group : expr_groups) { + NVF_ERROR(expr_group.get() != nullptr); + auto expr_set_it = std::find( + disjointExprSets().disjointSets().begin(), + disjointExprSets().disjointSets().end(), + expr_group); + NVF_ERROR( + expr_set_it != disjointExprSets().disjointSets().end(), + "Inconsistent ExprGroup, ", + nvfuser::toString(expr_group), + ", at addreess ", + expr_group.get(), + ", not found in the disjoint Expr sets."); + } + } + } +} + } // namespace nvfuser diff --git a/csrc/val_graph.h b/csrc/val_graph.h index fbb25621986..f7a20e46bdb 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -175,6 +175,10 @@ class ValGraph { // they modify matching original inputs by the same amount. bool exprsMap(Expr* first, Expr* second, bool forward) const; + // Check basic consistencies of val and expr groups and their + // mappings. + void validateConsistency() const; + void addUniqueUses(const ValGroup& id_group, const ExprGroup& uses) { unique_uses_.at(id_group).pushBack(uses); } From 5d5458eca86e149462ebec2b45d20d7a0dce81b1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Jan 2024 14:06:19 -0800 Subject: [PATCH 116/178] enable idmodel --- csrc/device_lower/lower2device.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index fa34738d69c..4bc6869d988 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -382,7 +382,7 @@ void GpuLower::analysis(Fusion* fusion) { // functionality should be affected. New IterDomains may be created, // so it is expected that generated code may use diffrent variable // names - if (isOptionEnabled(EnableOption::IdModel)) { + if (true|| isOptionEnabled(EnableOption::IdModel)) { IdModel id_model(fusion_, false, true); } From de20be4b00ebc43b02bd988069315843195b4c72 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Jan 2024 17:26:48 -0800 Subject: [PATCH 117/178] disable idmodel --- csrc/device_lower/lower2device.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 4bc6869d988..fa34738d69c 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -382,7 +382,7 @@ void GpuLower::analysis(Fusion* fusion) { // functionality should be affected. New IterDomains may be created, // so it is expected that generated code may use diffrent variable // names - if (true|| isOptionEnabled(EnableOption::IdModel)) { + if (isOptionEnabled(EnableOption::IdModel)) { IdModel id_model(fusion_, false, true); } From 9286489ddc1b1f46bb21f73cb1232732f4aa6c45 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Jan 2024 21:19:46 -0800 Subject: [PATCH 118/178] [IdModel] Refactoring for testing (#1624) --- csrc/device_lower/lower2device.cpp | 2 +- csrc/id_model/id_model.cpp | 281 +++++++++++++++---------- csrc/id_model/id_model.h | 81 ++++--- test/test_gpu_indexing.cpp | 7 +- test/test_id_model.cpp | 325 ++++++++++++++++++++++++++++- 5 files changed, 552 insertions(+), 144 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 5d91762eaa6..3fc0cea3e23 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -383,7 +383,7 @@ void GpuLower::analysis(Fusion* fusion) { // so it is expected that generated code may use diffrent variable // names if (true || isOptionEnabled(EnableOption::IdModel)) { - IdModel id_model(fusion_, false, true); + IdModel id_model(fusion_); } resolveComputeWith(fusion_); diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 2ef8343b044..c76a7329bb0 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -71,31 +71,65 @@ void IdModel::assertNoSelfMapping() { IdModel::IdModel( const std::vector& exprs, const std::vector& additional_tvs, - bool allow_self_mapping) { - build(exprs, additional_tvs); + bool build_graphs, + bool allow_self_mapping) + : allow_self_mapping_(allow_self_mapping) { + std::copy_if( + exprs.begin(), + exprs.end(), + std::back_inserter(tv_exprs_), + [](Expr* expr) { + NVF_ERROR(expr != nullptr); + return ir_utils::isTvOp(expr); + }); - if (!allow_self_mapping) { - assertNoSelfMapping(); + auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs_); + all_tvs.pushBack(additional_tvs.begin(), additional_tvs.end()); + + tvs_ = all_tvs.vector(); + + // Add uses and definitions to all iter domains. + buildIterDomainDefinitionsAndUses(); + + if (build_graphs) { + buildAllGraphs(); } } -IdModel::IdModel(Fusion* fusion, bool allow_self_mapping, bool validate) { - std::vector inputs_and_outputs; +IdModel::IdModel( + Fusion* fusion, + bool build_graphs, + bool allow_self_mapping, + bool validate) + : allow_self_mapping_(allow_self_mapping), validate_(validate) { + auto all_exprs = fusion->exprs(); + std::copy_if( + all_exprs.begin(), + all_exprs.end(), + std::back_inserter(tv_exprs_), + [](Expr* expr) { + NVF_ERROR(expr != nullptr); + return ir_utils::isTvOp(expr); + }); + + auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs_); + { auto inp_tvs = ir_utils::filterByType(fusion->inputs()); - inputs_and_outputs.insert( - inputs_and_outputs.begin(), inp_tvs.begin(), inp_tvs.end()); + all_tvs.pushBack(inp_tvs.begin(), inp_tvs.end()); } { auto out_tvs = ir_utils::filterByType(fusion->outputs()); - inputs_and_outputs.insert( - inputs_and_outputs.end(), out_tvs.begin(), out_tvs.end()); + all_tvs.pushBack(out_tvs.begin(), out_tvs.end()); } - build(fusion->exprs(), inputs_and_outputs, validate); + tvs_ = all_tvs.vector(); - if (!allow_self_mapping) { - assertNoSelfMapping(); + // Add uses and definitions to all iter domains. + buildIterDomainDefinitionsAndUses(); + + if (build_graphs) { + buildAllGraphs(); } } @@ -111,7 +145,11 @@ const ValGraph& IdModel::idGraph(IdMappingMode mode) const { ValGraph& IdModel::idGraph(IdMappingMode mode) { auto graph_it = id_graphs_.find(mode); - NVF_ERROR(graph_it != id_graphs_.end()); + NVF_ERROR( + graph_it != id_graphs_.end(), + "Failed to find an IdGraph with the ", + mode, + " mode"); return graph_it->second; } @@ -179,7 +217,7 @@ std::optional> detectMappablePair( std::optional> findFirstSelfMapping( const std::vector& all_tvs, - const IdModel& id_graph) { + const IdModel& id_model) { for (auto tv : all_tvs) { // For each tensor, make sure root, rfactor and leaf domains // should not include domains that are mapped with another domain @@ -188,7 +226,7 @@ findFirstSelfMapping( // Root domains auto self_mappped_root_pair = - detectMappablePair(tv->getRootDomain(), id_graph, IdMappingMode::EXACT); + detectMappablePair(tv->getRootDomain(), id_model, IdMappingMode::EXACT); if (self_mappped_root_pair.has_value()) { return std::make_tuple( tv, @@ -200,7 +238,7 @@ findFirstSelfMapping( // Rfactor domains if (tv->hasRFactor()) { auto self_mappped_rf_pair = detectMappablePair( - tv->getRFactorDomain(), id_graph, IdMappingMode::EXACT); + tv->getRFactorDomain(), id_model, IdMappingMode::EXACT); if (self_mappped_rf_pair.has_value()) { return std::make_tuple( tv, @@ -215,7 +253,7 @@ findFirstSelfMapping( // map. However, it should also be impossible for index map to generate a // case like this. auto self_mappped_leaf_pair = detectMappablePair( - tv->domain()->leaf(), id_graph, IdMappingMode::EXACT); + tv->domain()->leaf(), id_model, IdMappingMode::EXACT); if (self_mappped_leaf_pair.has_value()) { return std::make_tuple( tv, @@ -229,9 +267,8 @@ findFirstSelfMapping( } // namespace -void IdModel::buildIterDomainDefinitionsAndUses( - const std::vector& all_tvs) { - for (const auto tv : all_tvs) { +void IdModel::buildIterDomainDefinitionsAndUses() { + for (const auto tv : tvs_) { VectorOfUniqueEntries root_domain_ids{ tv->getRootDomain().begin(), tv->getRootDomain().end()}; @@ -635,8 +672,13 @@ ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { return id_graph; } -void IdModel::buildExactGraph(const std::vector& exprs) { - for (auto expr : exprs) { +void IdModel::buildExactGraph() { + // Initialize the maps with all the IterDomains used in the provded + // expressions. + NVF_ERROR( + id_graphs_.emplace(IdMappingMode::EXACT, initializeIdGraph()).second); + + for (auto expr : tv_exprs_) { TensorView* c_tv = ir_utils::getTvOutput(expr); auto all_tv_outputs = ir_utils::filterByType(expr->outputs()); @@ -720,9 +762,15 @@ std::vector> getTriviallyMappedIds(Expr* expr) { } // namespace -void IdModel::buildAlmostExactMap() { +void IdModel::buildAlmostExactGraph() { + // Make sure the exact graph is already built + maybeBuildGraph(IdMappingMode::EXACT); + // Build almost exact map by forwarding through broadcast axes - idGraph(IdMappingMode::ALMOSTEXACT) = idGraph(IdMappingMode::EXACT); + NVF_ERROR( + id_graphs_ + .emplace(IdMappingMode::ALMOSTEXACT, idGraph(IdMappingMode::EXACT)) + .second); auto& almost_exact_graph = idGraph(IdMappingMode::ALMOSTEXACT); @@ -759,13 +807,19 @@ void IdModel::buildAlmostExactMap() { almost_exact_graph.validateConsistency(); } -void IdModel::buildPermissiveMap(const std::vector& exprs) { +void IdModel::buildPermissiveGraph() { + // Make sure the exact graph is already built + maybeBuildGraph(IdMappingMode::EXACT); + // Use the exact map as the starting map rather than the // almost-exact map. Almost exact is useful for index hoisting but // not necessary for permissive and loop maps - idGraph(IdMappingMode::PERMISSIVE) = idGraph(IdMappingMode::EXACT); + NVF_ERROR( + id_graphs_ + .emplace(IdMappingMode::PERMISSIVE, idGraph(IdMappingMode::EXACT)) + .second); - for (auto expr : exprs) { + for (auto expr : tv_exprs_) { // Multiple outputs are already mapped, we can ignore all but the first // consumer given they have to be replayed in the same exact way TensorView* c_tv = ir_utils::getTvOutput(expr); @@ -845,6 +899,33 @@ std::unordered_map resolvedRootBroadcasts( return resolved_bcast_map; } +// Update a map of ValGroups to ID from an old Valgraph to a new +// ValGraph. The new graph must be a superset of the old graph. +std::unordered_map updateMap( + const std::unordered_map& stale_map, + ValGraph& new_graph) { + std::unordered_map new_map; + + for (const auto& [stale_group, mapped_id] : stale_map) { + const ValGroups& new_groups = new_graph.toGroups(*stale_group); + NVF_ERROR( + new_groups.size() == 1, + "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", + "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", + "old:", + nvfuser::toString(stale_group), + "new: ", + nvfuser::toString(new_groups)); + NVF_ERROR( + new_map.emplace(new_groups.front(), mapped_id).second, + "Expected only a single mapping but multiple entries detected for ", + nvfuser::toString(new_groups.front())); + } + return new_map; +} + +} // namespace + // Grab inlining relationships StatefulInliningInfo buildStatefulInliningInfo( const std::vector& exprs, @@ -898,38 +979,12 @@ StatefulInliningInfo buildStatefulInliningInfo( return info; } -// Update a map of ValGroups to ID from an old Valgraph to a new -// ValGraph. The new graph must be a superset of the old graph. -std::unordered_map updateMap( - const std::unordered_map& stale_map, - ValGraph& new_graph) { - std::unordered_map new_map; - - for (const auto& [stale_group, mapped_id] : stale_map) { - const ValGroups& new_groups = new_graph.toGroups(*stale_group); - NVF_ERROR( - new_groups.size() == 1, - "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", - "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", - "old:", - nvfuser::toString(stale_group), - "new: ", - nvfuser::toString(new_groups)); - NVF_ERROR( - new_map.emplace(new_groups.front(), mapped_id).second, - "Expected only a single mapping but multiple entries detected for ", - nvfuser::toString(new_groups.front())); - } - return new_map; -} - -} // namespace - -void IdModel::initializeLoopMap(const StatefulInliningInfo& info) { +void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) { // In the case of the Loop graph, we do not propagate mappings but // explicitly set which domains to map based on the permissive graph // and the CA positions. - idGraph(IdMappingMode::LOOP) = initializeIdGraph(false); + NVF_ERROR( + id_graphs_.emplace(IdMappingMode::LOOP, initializeIdGraph(false)).second); // Make sure this is called in a deterministic order. Build all inlined // relationships in loop graph. @@ -944,18 +999,24 @@ void IdModel::initializeLoopMap(const StatefulInliningInfo& info) { } } -void IdModel::buildLoopMap(const std::vector& exprs) { - if (!exprs.empty()) { +void IdModel::buildLoopGraph() { + // Make sure the depedent graphs are already built + maybeBuildGraph(IdMappingMode::EXACT); + maybeBuildGraph(IdMappingMode::PERMISSIVE); + + if (!tv_exprs_.empty()) { std::stringstream ss; - exprs.at(0)->fusion()->print(ss); + tv_exprs_.at(0)->fusion()->print(ss); VERBOSE() << ss.str(); } // Gather broadcast resolution and inlining information const StatefulInliningInfo inlining_info = buildStatefulInliningInfo( - exprs, idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::PERMISSIVE)); + tv_exprs_, + idGraph(IdMappingMode::EXACT), + idGraph(IdMappingMode::PERMISSIVE)); - initializeLoopMap(inlining_info); + initializeLoopGraph(inlining_info); VERBOSE() << "Initial loop graph:\n"; for (const auto& group : @@ -1007,7 +1068,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // Step 1: Build a map of the IEL groups of root broadcast domains // to resolving domains. std::unordered_map iel_promotion_map = - buildInlineRootPromotionMap(iel_graph, inlining_info); + buildInlineRootResolutionmap(iel_graph, inlining_info); // Step 2: Propagate the root promotions to intermediate and leaf groups. // At this point, the promotion may not be final as the analysis is @@ -1123,69 +1184,48 @@ void IdModel::propagateLoopPTypes() const { } } -void IdModel::build( - const std::vector& exprs, - const std::vector& additional_tvs, - bool validate) { +void IdModel::buildAllGraphs() { VERBOSE() << "*** Building all graphs ***"; - // Initialize the required sets as if a permissive relationship is never - // found, then querying an empty permissive map will fail later. - // Initialize disjoint sets - for (auto mode : kIdMappingModes) { - id_graphs_[mode] = ValGraph(); - } - - std::vector tv_exprs; - - std::copy_if( - exprs.begin(), exprs.end(), std::back_inserter(tv_exprs), [](Expr* expr) { - NVF_ERROR(expr != nullptr); - return ir_utils::isTvOp(expr); - }); - - auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs); - - for (auto additional_tv : additional_tvs) { - all_tvs.pushBack(additional_tv); - } - - if (all_tvs.empty()) { + if (tvs_.empty()) { return; } std::unique_ptr validator; + Fusion* fusion = tvs_.front()->fusion(); + // A ComputeAtMap will be built inside the constructor of // IdModelValidator, which may fail for some fusions that are not // supported currently (but work with IdModel). Make sure the // validator is only created when it is indeed requested - if (validate) { - validator = std::make_unique(all_tvs.front()->fusion()); + if (validate_) { + validator = std::make_unique(fusion); } - FusionGuard fg(all_tvs.front()->fusion()); - // Add uses and definitions to all iter domains. - buildIterDomainDefinitionsAndUses(all_tvs.vector()); - - // Initialize the maps with all the IterDomains used in the provded - // expressions. - idGraph(IdMappingMode::EXACT) = initializeIdGraph(); + FusionGuard fg(fusion); - buildExactGraph(tv_exprs); - if (validate) { + buildExactGraph(); + if (validate_) { validator->checkExactGraphEquivalence(idGraph(IdMappingMode::EXACT)); } - buildAlmostExactMap(); - if (validate) { + // Make sure there's no self mapping in TensorView's during lowering + // that would invalidate lowering assumptions. + self_mapping_info_ = findFirstSelfMapping(tvs_, *this); + if (!allow_self_mapping_) { + assertNoSelfMapping(); + } + + buildAlmostExactGraph(); + if (validate_) { validator->checkAlmostExactGraphEquivalence( idGraph(IdMappingMode::ALMOSTEXACT)); } - buildPermissiveMap(tv_exprs); + buildPermissiveGraph(); // Validation is not implemented when compliment mapping is enabled - if (validate && !permissive_graph_map_compliment_ids_) { + if (validate_ && !permissive_graph_map_compliment_ids_) { validator->checkPermissiveGraphEquivalence( idGraph(IdMappingMode::PERMISSIVE)); } @@ -1195,11 +1235,34 @@ void IdModel::build( // from the almost exact graph. idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); - buildLoopMap(tv_exprs); + buildLoopGraph(); +} - // Make sure there's no self mapping in TensorView's during lowering - // that would invalidate lowering assumptions. - self_mapping_info_ = findFirstSelfMapping(all_tvs.vector(), *this); +void IdModel::buildGraph(IdMappingMode mode) { + switch (mode) { + case IdMappingMode::EXACT: + buildExactGraph(); + break; + case IdMappingMode::ALMOSTEXACT: + buildAlmostExactGraph(); + break; + case IdMappingMode::PERMISSIVE: + buildPermissiveGraph(); + break; + case IdMappingMode::LOOP: + buildLoopGraph(); + break; + default: + NVF_ERROR(false, "Unsupported mode: ", mode); + } +} + +void IdModel::maybeBuildGraph(IdMappingMode mode) { + if (id_graphs_.find(mode) != id_graphs_.end()) { + return; + } else { + buildGraph(mode); + } } VectorOfUniqueEntries IdModel::computeTerminalLoopIds( @@ -1266,7 +1329,7 @@ ValGraph IdModel::buildIntersection( return intersection; } -std::unordered_map IdModel::buildInlineRootPromotionMap( +std::unordered_map IdModel::buildInlineRootResolutionmap( const ValGraph& iel_graph, const StatefulInliningInfo& info) { std::unordered_map iel_promotion_map; diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 42bbf8dc61b..4bc243cb7ff 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -37,6 +37,11 @@ struct StatefulInliningInfo { p2c_root_broadcast_resolution_map; }; +StatefulInliningInfo buildStatefulInliningInfo( + const std::vector& exprs, + const ValGraph& exact_graph, + const ValGraph& permissive_graph); + // A collection of ValGraphs that are built from a fusion or series of // expressions. These graphs are related, but have some distinct features based // on the IdMappingMode. @@ -93,9 +98,16 @@ struct StatefulInliningInfo { // considered promoted to a common iter domain class IdModel : public PolymorphicBase { public: + // Sometimes fusion inputs or outputs are disconnected from expressions, in + // those cases we still may want to send in some additional tensor views from + // the Fusion that don't have expressions associated with them. + // + // All graphs are built by default. It can be disabled with + // build_graphs=false. IdModel( const std::vector& exprs, const std::vector& additional_tvs = {}, + bool build_graphs = true, bool allow_self_mapping = false); // Same as the above constructor with fusion->exprs() excpet fusion may have @@ -106,8 +118,9 @@ class IdModel : public PolymorphicBase { // transition from the current ComputeAtMap. IdModel( Fusion* fusion, + bool build_graphs = true, bool allow_self_mapping = false, - bool validate = false); + bool validate = true); // Returns iter domain graph of provided mode. const ValGraph& idGraph(IdMappingMode mode) const; @@ -129,49 +142,47 @@ class IdModel : public PolymorphicBase { std::string toString() const; - const std::unordered_map& loopPromotionMap() const { - return loop_promotion_map_; - } - - private: - // Sometimes fusion inputs or outputs are disconnected from expressions, in - // those cases we still may want to send in some additional tensor views from - // the Fusion that don't have expressions associated with them. - void build( - const std::vector& exprs, - const std::vector& additional_tvs, - bool validate = false); - - // ======= START Iteration domain build process in order called ======= - - // Fills id_uses_ and id_definitions_ for all IterDomains active in the - // fusion. - void buildIterDomainDefinitionsAndUses( - const std::vector& all_tvs); - - // Iterates over all IterDomains in id_definitions_ and calls initializeVal on - // a new ValGraph and returns it. - ValGraph initializeIdGraph(bool propagate_through_exprs = true); + // Build all graphs. This is by default called from the constructor + void buildAllGraphs(); // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs // and first output of expr - void buildExactGraph(const std::vector& exprs); + void buildExactGraph(); // Fills disjoint_ids_[IdMappingMode::ALMOSTEXACT]. Initialize AlmostExact as // Exact entries, then map anything that's either merged with a size-1 or // split by a size-1 dimension. - void buildAlmostExactMap(); + void buildAlmostExactGraph(); // Fills disjoint_ids_[IdMappingMode::PERMISSIVE]. Initialize it as // Exact entries, then map through broadcasts - void buildPermissiveMap(const std::vector& exprs); + void buildPermissiveGraph(); // Fills disjoint_ids_[IdMappingMode::LOOP]. Map only inlined // domains that are mapped in the permissive graph - void buildLoopMap(const std::vector& exprs); + void buildLoopGraph(); + + // Build a graph + void buildGraph(IdMappingMode mode); + + // Build a graph if not already built + void maybeBuildGraph(IdMappingMode mode); + + // Iterates over all IterDomains in id_definitions_ and calls initializeVal on + // a new ValGraph and returns it. + ValGraph initializeIdGraph(bool propagate_through_exprs = true); + + const std::unordered_map& loopPromotionMap() const { + return loop_promotion_map_; + } + + protected: + // Fills id_uses_ and id_definitions_ for all IterDomains active in the + // fusion. + void buildIterDomainDefinitionsAndUses(); // Start loop map by grouping inlined iter domains - void initializeLoopMap(const StatefulInliningInfo& info); + void initializeLoopGraph(const StatefulInliningInfo& info); // Build a map of loop groups to IterDomains that represent actual // loops. The map is built based on the broadcast resolution with @@ -180,8 +191,9 @@ class IdModel : public PolymorphicBase { const StatefulInliningInfo& info); // Helper function for buildLoopPromotionMap. Returns a map of - // root broadcast ValGroups in the IEL graph to a representative IterDomain. - std::unordered_map buildInlineRootPromotionMap( + // root broadcast ValGroups in the IEL graph to a representative + // IterDomain picked from its IEL group. + std::unordered_map buildInlineRootResolutionmap( const ValGraph& iel_graph, const StatefulInliningInfo& info); @@ -315,7 +327,12 @@ class IdModel : public PolymorphicBase { // not have any registered uses or definitions. IterDomain* cloneIterDomain(IterDomain* id); - private: + protected: + std::vector tv_exprs_; + std::vector tvs_; + bool allow_self_mapping_ = false; + bool validate_ = false; + // Keeps ValGraphs containing all IterDomains for all mapping mode types. // // Using an array here might be nice, but it seems hard to use an enum as an diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index 473f7dff06e..9863864ed01 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -818,7 +818,12 @@ TEST_F(NVFuserTest, FusionIndexing19_CUDA) { tensor->inlineAt(1); } - IdModel id_model(&fusion); + // Validation needs to be disabled as ComputeAtMap would fail with this fusion + IdModel id_model( + &fusion, + /* build_graphs */ true, + /* allow_self_mapping */ false, + /* validate */ false); // All of the IDs that are generated with merge operations from the // root domains should be mapped to the single group. diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index fc823316564..9252f11a433 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include namespace nvfuser { @@ -32,9 +33,331 @@ TEST_F(IdModelTest, DetectSelfMapping) { fusion.addOutput(tv2); EXPECT_THAT( - [&]() { IdModel id_model(&fusion); }, + [&]() { + IdModel id_model(&fusion); + id_model.buildAllGraphs(); + }, ::testing::ThrowsMessage( ::testing::HasSubstr("!hasSelfMapping"))); } +namespace { + +// Helper class to test IdModel +class IdModelTester : public IdModel { + public: + // Do not automatically build the graphs + IdModelTester(Fusion* fusion) : IdModel(fusion, /* build_graphs */ false) {} + + std::pair> + getInlineRootResolutionMap() { + // Make sure the depedent graphs are already built + maybeBuildGraph(IdMappingMode::EXACT); + maybeBuildGraph(IdMappingMode::PERMISSIVE); + + // Gather broadcast resolution and inlining information + const StatefulInliningInfo inlining_info = buildStatefulInliningInfo( + tv_exprs_, + idGraph(IdMappingMode::EXACT), + idGraph(IdMappingMode::PERMISSIVE)); + + initializeLoopGraph(inlining_info); + + ValGraph iel_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + + std::unordered_map root_promotion_map = + buildInlineRootResolutionmap(iel_graph, inlining_info); + + return {std::move(iel_graph), std::move(root_promotion_map)}; + } +}; + +// Test if root_broadcast_id is resolved to ref_id. If ref_id is +// nullptr, test if root_broadcast_id has no resolution. +void validateResolution( + IterDomain* root_broadcast_id, + IterDomain* ref_id, + const ValGraph& iel_graph, + const std::unordered_map& root_resolution_map) { + ASSERT_TRUE(root_broadcast_id->isBroadcast()); + const auto& iel_group = iel_graph.toGroup(root_broadcast_id); + auto root_promotion_map_it = root_resolution_map.find(iel_group); + if (ref_id != nullptr) { + ASSERT_TRUE(root_promotion_map_it != root_resolution_map.end()) + << "Root resolution not found for: " << nvfuser::toString(iel_group); + ASSERT_FALSE(ref_id->isBroadcast()); + auto resolution_id = root_promotion_map_it->second; + ASSERT_TRUE( + iel_graph.disjointValSets().strictAreMapped(resolution_id, ref_id)) + << "Unexpected root resolution. " + << "Expected: " << ref_id->toString() + << ". Actual: " << resolution_id->toString(); + } else { + ASSERT_TRUE(root_promotion_map_it == root_resolution_map.end()) + << "Root resolution should not exist for: " + << nvfuser::toString(iel_group) + << ", but found: " << root_promotion_map_it->second->toString(); + } +} + +// Create a fusion where we're missing a valid concrete id so the compute at map +// processing will fail. We need to be able to create the concrete ID not just +// look for one. It is not yet possible to lower this fusion as the +// current indexing cannot generate correct indices. Also used in +// FusionIndeixing19 +std::unique_ptr createFusionWithMultipleResolutionPaths() { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({7}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = broadcast(tv1, {false, true}); + + auto tv3 = makeConcreteTensor({7, 11}); + fusion.addInput(tv3); + + auto tv4 = add(tv3, tv2); + auto tv5 = broadcast(tv4, {false, false, true}); + // tv4[7, 11, 1] + + auto tv6 = broadcast(tv1, {false, true}); + + auto tv7 = makeConcreteTensor({7, 13}); + fusion.addInput(tv7); + auto tv8 = add(tv7, tv6); + auto tv9 = broadcast(tv8, {false, true, false}); + // tv9[7, 1, 13] + + auto tv10 = add(tv5, tv9); + fusion.addOutput(tv10); + + // tv10[7, 11, 13] + tv10->merge(0)->merge(0); + // tv10[7*11*13] + tv10->split(0, 5)->split(0, 3); + // tv10[7*11*13//5//3, 3, 5] + + TransformPropagatorWithCheck propagator(tv10); + MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + + std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; + for (auto tensor : tensors_to_inline) { + tensor->inlineAt(1); + } + + return fusion_ptr; +} + +TensorView* findTensorByName( + const std::vector& tvs, + StmtNameType name) { + if (auto it = std::find_if( + tvs.begin(), + tvs.end(), + [&](TensorView* tv) { return tv->name() == name; }); + it != tvs.end()) { + return *it; + } else { + return nullptr; + } +} + +} // namespace + +// Testing root resolution with a simple broadcast pattern +TEST_F(IdModelTest, LoopGraphRootResolution1) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto t0 = makeSymbolicTensor(1); + fusion->addInput(t0); + auto t1 = makeSymbolicTensor(2); + fusion->addInput(t1); + auto t2 = broadcast(t0, {true, false}); + auto t3 = add(t2, t1); + fusion->addOutput(t3); + + { + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Nothing inlined. Should be no resolution + ASSERT_TRUE(root_resolution_map.empty()); + } + + t2->inlineAt(2); + ASSERT_EQ(t2->getComputeAtPosition(), 2); + + { + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // t2 is now fully inlined. Its root broadcast domain should be + // resoled with the corresponding domain of t3 + validateResolution( + t2->getRootDomain().at(0), + t3->getRootDomain().at(0), + iel_graph, + root_resolution_map); + } +} + +// Test with a fusion with progressive broadcasting +TEST_F(IdModelTest, LoopGraphRootResolution2) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto t0 = makeSymbolicTensor(1); + fusion->addInput(t0); + auto t1 = makeSymbolicTensor(3); + fusion->addInput(t1); + + auto t2 = broadcast(t0, {true, false}); + auto t3 = broadcast(t2, {true, false, false}); + auto t4 = add(t3, t1); + fusion->addOutput(t4); + + inlineMost(); + + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Validate t2 and t3 as they have root broadcast domains + validateResolution( + t2->getRootDomain().at(0), + t4->getRootDomain().at(1), + iel_graph, + root_resolution_map); + + validateResolution( + t3->getRootDomain().at(0), + t4->getRootDomain().at(0), + iel_graph, + root_resolution_map); +} + +// Multiple inlined and non-inlined broadcast domains +TEST_F(IdModelTest, LoopGraphRootResolution3) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + auto tv1 = makeSymbolicTensor(4); + fusion->addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true, false, true}); + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + // tv3: [i0, i1, i2, i3] -> [i0*i1, i2*i3] + tv3->merge(0); + tv3->merge(1); + + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv2->inlineAt(1); + + // tv2: [i0*b1, i2*b3] ca(1) + // tv3: [i0*i1, i2*i3] + + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // The b1 broadcast domain tv2 should be resolved as it's inlined, + // but b3 should not. + validateResolution( + tv2->getRootDomain().at(1), + tv3->getRootDomain().at(1), + iel_graph, + root_resolution_map); + + validateResolution( + tv2->getRootDomain().at(3), nullptr, iel_graph, root_resolution_map); +} + +TEST_F(IdModelTest, LoopGraphRootResolution4) { + auto fusion = createFusionWithMultipleResolutionPaths(); + auto all_tvs = ir_utils::allTvs(fusion.get()); + + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Verify all tensors with broadcast have correct resolution of root + // broadcast domains + for (auto tv : ir_utils::allTvs(fusion.get())) { + // Skip tensors with no broadcast + if (std::none_of( + tv->getRootDomain().begin(), + tv->getRootDomain().end(), + [](auto id) { return id->isBroadcast(); })) { + continue; + } + + switch (tv->name()) { + case 2: + // T2_l[ iS49{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS50{3}, + // iS48{5} ] ca_pos( 1 ) produce_pos( 1 ) + // root domain : (iS2{7}, bS3{1}) + // Resolution: Resolved by the immediate consumer (T4) + validateResolution( + tv->getRootDomain().at(1), + findTensorByName(all_tvs, 4)->getRootDomain().at(1), + iel_graph, + root_resolution_map); + break; + case 5: + // T5_l[ iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), 5) ), 3) )}, + // iS40{3}, iS38{5} ] produce_pos( 1 ) + // root domain : (iS8{7}, iS9{11}, bS10{1}) + // Resolution: T5 is not inlined to the immediate consumer, + // T10. Resolution is done with the other path from T1, such + // as T8 or T9. + validateResolution( + tv->getRootDomain().at(2), + findTensorByName(all_tvs, 9)->getRootDomain().at(2), + iel_graph, + root_resolution_map); + break; + case 6: + // T6_l[ iS64{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3}, + // iS63{5} ] ca_pos( 1 ) produce_pos( 1 ) + // root domain : (iS11{7}, bS12{1}) + // Resolution: Resolved by the immediate consumer (T8) + validateResolution( + tv->getRootDomain().at(1), + findTensorByName(all_tvs, 8)->getRootDomain().at(1), + iel_graph, + root_resolution_map); + break; + case 9: + // T9_l[ iS33{( ceilDiv(( ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )}, + // iS34{3}, iS32{5} ] produce_pos( 1 ) + // root domain : (iS17{7}, bS18{1}, iS19{13}) + // Resolution: T9 is not inlined to the immediate consumer, + // T10. Resolution is done with the other path from T1, such + // as T4 or T5 + validateResolution( + tv->getRootDomain().at(1), + findTensorByName(all_tvs, 5)->getRootDomain().at(1), + iel_graph, + root_resolution_map); + break; + default: + FAIL() << "Unexpected tensor: " << tv->toString(); + } + } +} + } // namespace nvfuser From a902c6eb2a2d894c98e6d90bf921a9d10d588bd6 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 20 Jan 2024 13:58:02 -0800 Subject: [PATCH 119/178] Refactoring to allow more flexible testing Previously IdModel constructor builds all graphs at once. This refactoring is done in preparation for loop promotion analysis to allow testing of each step of the analysis --- csrc/device_lower/lower2device.cpp | 2 +- csrc/id_model/id_model.cpp | 221 ++++++++++++++++++----------- csrc/id_model/id_model.h | 74 ++++++---- 3 files changed, 188 insertions(+), 109 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index fa34738d69c..f787f181d89 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -383,7 +383,7 @@ void GpuLower::analysis(Fusion* fusion) { // so it is expected that generated code may use diffrent variable // names if (isOptionEnabled(EnableOption::IdModel)) { - IdModel id_model(fusion_, false, true); + IdModel id_model(fusion_); } resolveComputeWith(fusion_); diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index ac97bbeb3a7..b709acdc80e 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -68,31 +68,64 @@ void IdModel::assertNoSelfMapping() { IdModel::IdModel( const std::vector& exprs, const std::vector& additional_tvs, + bool build_graphs, bool allow_self_mapping) { - build(exprs, additional_tvs); + std::copy_if( + exprs.begin(), + exprs.end(), + std::back_inserter(tv_exprs_), + [](Expr* expr) { + NVF_ERROR(expr != nullptr); + return ir_utils::isTvOp(expr); + }); - if (!allow_self_mapping) { - assertNoSelfMapping(); + auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs_); + all_tvs.pushBack(additional_tvs.begin(), additional_tvs.end()); + + tvs_ = all_tvs.vector(); + + // Add uses and definitions to all iter domains. + buildIterDomainDefinitionsAndUses(); + + if (build_graphs) { + buildAllGraphs(); } } -IdModel::IdModel(Fusion* fusion, bool allow_self_mapping, bool validate) { - std::vector inputs_and_outputs; +IdModel::IdModel( + Fusion* fusion, + bool build_graphs, + bool allow_self_mapping, + bool validate) + : allow_self_mapping_(allow_self_mapping), validate_(validate) { + auto all_exprs = fusion->exprs(); + std::copy_if( + all_exprs.begin(), + all_exprs.end(), + std::back_inserter(tv_exprs_), + [](Expr* expr) { + NVF_ERROR(expr != nullptr); + return ir_utils::isTvOp(expr); + }); + + auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs_); + { auto inp_tvs = ir_utils::filterByType(fusion->inputs()); - inputs_and_outputs.insert( - inputs_and_outputs.begin(), inp_tvs.begin(), inp_tvs.end()); + all_tvs.pushBack(inp_tvs.begin(), inp_tvs.end()); } { auto out_tvs = ir_utils::filterByType(fusion->outputs()); - inputs_and_outputs.insert( - inputs_and_outputs.end(), out_tvs.begin(), out_tvs.end()); + all_tvs.pushBack(out_tvs.begin(), out_tvs.end()); } - build(fusion->exprs(), inputs_and_outputs, validate); + tvs_ = all_tvs.vector(); - if (!allow_self_mapping) { - assertNoSelfMapping(); + // Add uses and definitions to all iter domains. + buildIterDomainDefinitionsAndUses(); + + if (build_graphs) { + buildAllGraphs(); } } @@ -108,7 +141,11 @@ const ValGraph& IdModel::idGraph(IdMappingMode mode) const { ValGraph& IdModel::idGraph(IdMappingMode mode) { auto graph_it = id_graphs_.find(mode); - NVF_ERROR(graph_it != id_graphs_.end()); + NVF_ERROR( + graph_it != id_graphs_.end(), + "Failed to find an IdGraph with the ", + mode, + " mode"); return graph_it->second; } @@ -176,7 +213,7 @@ std::optional> detectMappablePair( std::optional> findFirstSelfMapping( const std::vector& all_tvs, - const IdModel& id_graph) { + const IdModel& id_model) { for (auto tv : all_tvs) { // For each tensor, make sure root, rfactor and leaf domains // should not include domains that are mapped with another domain @@ -185,7 +222,7 @@ findFirstSelfMapping( // Root domains auto self_mappped_root_pair = - detectMappablePair(tv->getRootDomain(), id_graph, IdMappingMode::EXACT); + detectMappablePair(tv->getRootDomain(), id_model, IdMappingMode::EXACT); if (self_mappped_root_pair.has_value()) { return std::make_tuple( tv, @@ -197,7 +234,7 @@ findFirstSelfMapping( // Rfactor domains if (tv->hasRFactor()) { auto self_mappped_rf_pair = detectMappablePair( - tv->getRFactorDomain(), id_graph, IdMappingMode::EXACT); + tv->getRFactorDomain(), id_model, IdMappingMode::EXACT); if (self_mappped_rf_pair.has_value()) { return std::make_tuple( tv, @@ -212,7 +249,7 @@ findFirstSelfMapping( // map. However, it should also be impossible for index map to generate a // case like this. auto self_mappped_leaf_pair = detectMappablePair( - tv->domain()->leaf(), id_graph, IdMappingMode::EXACT); + tv->domain()->leaf(), id_model, IdMappingMode::EXACT); if (self_mappped_leaf_pair.has_value()) { return std::make_tuple( tv, @@ -226,9 +263,8 @@ findFirstSelfMapping( } // namespace -void IdModel::buildIterDomainDefinitionsAndUses( - const std::vector& all_tvs) { - for (const auto tv : all_tvs) { +void IdModel::buildIterDomainDefinitionsAndUses() { + for (const auto tv : tvs_) { VectorOfUniqueEntries root_domain_ids{ tv->getRootDomain().begin(), tv->getRootDomain().end()}; @@ -313,8 +349,13 @@ ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { return id_graph; } -void IdModel::buildExactGraph(const std::vector& exprs) { - for (auto expr : exprs) { +void IdModel::buildExactGraph() { + // Initialize the maps with all the IterDomains used in the provded + // expressions. + NVF_ERROR( + id_graphs_.emplace(IdMappingMode::EXACT, initializeIdGraph()).second); + + for (auto expr : tv_exprs_) { TensorView* c_tv = ir_utils::getTvOutput(expr); auto all_tv_outputs = ir_utils::filterByType(expr->outputs()); @@ -398,9 +439,15 @@ std::vector> getTriviallyMappedIds(Expr* expr) { } // namespace -void IdModel::buildAlmostExactMap() { +void IdModel::buildAlmostExactGraph() { + // Make sure the exact graph is already built + maybeBuildGraph(IdMappingMode::EXACT); + // Build almost exact map by forwarding through broadcast axes - idGraph(IdMappingMode::ALMOSTEXACT) = idGraph(IdMappingMode::EXACT); + NVF_ERROR( + id_graphs_ + .emplace(IdMappingMode::ALMOSTEXACT, idGraph(IdMappingMode::EXACT)) + .second); auto& almost_exact_graph = idGraph(IdMappingMode::ALMOSTEXACT); @@ -437,13 +484,19 @@ void IdModel::buildAlmostExactMap() { almost_exact_graph.validateConsistency(); } -void IdModel::buildPermissiveMap(const std::vector& exprs) { +void IdModel::buildPermissiveGraph() { + // Make sure the exact graph is already built + maybeBuildGraph(IdMappingMode::EXACT); + // Use the exact map as the starting map rather than the // almost-exact map. Almost exact is useful for index hoisting but // not necessary for permissive and loop maps - idGraph(IdMappingMode::PERMISSIVE) = idGraph(IdMappingMode::EXACT); + NVF_ERROR( + id_graphs_ + .emplace(IdMappingMode::PERMISSIVE, idGraph(IdMappingMode::EXACT)) + .second); - for (auto expr : exprs) { + for (auto expr : tv_exprs_) { // Multiple outputs are already mapped, we can ignore all but the first // consumer given they have to be replayed in the same exact way TensorView* c_tv = ir_utils::getTvOutput(expr); @@ -472,8 +525,6 @@ void IdModel::buildPermissiveMap(const std::vector& exprs) { idGraph(IdMappingMode::PERMISSIVE).validateConsistency(); } -namespace { - // Grab inlining relationships StatefulInliningInfo buildStatefulInliningInfo( const std::vector& exprs, @@ -520,13 +571,12 @@ StatefulInliningInfo buildStatefulInliningInfo( return info; } -} // namespace - -void IdModel::initializeLoopMap(const StatefulInliningInfo& info) { +void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) { // In the case of the Loop graph, we do not propagate mappings but // explicitly set which domains to map based on the permissive graph // and the CA positions. - idGraph(IdMappingMode::LOOP) = initializeIdGraph(false); + NVF_ERROR( + id_graphs_.emplace(IdMappingMode::LOOP, initializeIdGraph(false)).second); // Make sure this is called in a deterministic order. Build all inlined // relationships in loop graph. @@ -541,82 +591,93 @@ void IdModel::initializeLoopMap(const StatefulInliningInfo& info) { } } -void IdModel::buildLoopMap(const std::vector& exprs) { +void IdModel::buildLoopGraph() { + // Make sure the depedent graphs are already built + maybeBuildGraph(IdMappingMode::EXACT); + maybeBuildGraph(IdMappingMode::PERMISSIVE); + const StatefulInliningInfo info = buildStatefulInliningInfo( - exprs, idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::PERMISSIVE)); + tv_exprs_, + idGraph(IdMappingMode::EXACT), + idGraph(IdMappingMode::PERMISSIVE)); - initializeLoopMap(info); + initializeLoopGraph(info); idGraph(IdMappingMode::LOOP).validateConsistency(); } -void IdModel::build( - const std::vector& exprs, - const std::vector& additional_tvs, - bool validate) { - // Initialize the required sets as if a permissive relationship is never - // found, then querying an empty permissive map will fail later. - // Initialize disjoint sets - for (auto mode : kIdMappingModes) { - id_graphs_[mode] = ValGraph(); - } - - std::vector tv_exprs; - - std::copy_if( - exprs.begin(), exprs.end(), std::back_inserter(tv_exprs), [](Expr* expr) { - NVF_ERROR(expr != nullptr); - return ir_utils::isTvOp(expr); - }); - - auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs); - - for (auto additional_tv : additional_tvs) { - all_tvs.pushBack(additional_tv); - } - - if (all_tvs.empty()) { +void IdModel::buildAllGraphs() { + if (tvs_.empty()) { return; } std::unique_ptr validator; + Fusion* fusion = tvs_.front()->fusion(); + // A ComputeAtMap will be built inside the constructor of // IdModelValidator, which may fail for some fusions that are not // supported currently (but work with IdModel). Make sure the // validator is only created when it is indeed requested - if (validate) { - validator = std::make_unique(all_tvs.front()->fusion()); + if (validate_) { + validator = std::make_unique(fusion); } - FusionGuard fg(all_tvs.front()->fusion()); - // Add uses and definitions to all iter domains. - buildIterDomainDefinitionsAndUses(all_tvs.vector()); + FusionGuard fg(fusion); - // Initialize the maps with all the IterDomains used in the provded - // expressions. - idGraph(IdMappingMode::EXACT) = initializeIdGraph(); - - buildExactGraph(tv_exprs); - if (validate) { + buildExactGraph(); + if (validate_) { validator->checkExactGraphEquivalence(idGraph(IdMappingMode::EXACT)); } - buildAlmostExactMap(); - if (validate) { + // Make sure there's no self mapping in TensorView's during lowering + // that would invalidate lowering assumptions. + self_mapping_info_ = findFirstSelfMapping(tvs_, *this); + if (!allow_self_mapping_) { + assertNoSelfMapping(); + } + + buildAlmostExactGraph(); + if (validate_) { validator->checkAlmostExactGraphEquivalence( idGraph(IdMappingMode::ALMOSTEXACT)); } - buildPermissiveMap(tv_exprs); - if (validate) { + buildPermissiveGraph(); + // Validation is not implemented when compliment mapping is enabled + if (validate_) { validator->checkPermissiveGraphEquivalence( idGraph(IdMappingMode::PERMISSIVE)); } - // Make sure there's no self mapping in TensorView's during lowering - // that would invalidate lowering assumptions. - self_mapping_info_ = findFirstSelfMapping(all_tvs.vector(), *this); + buildLoopGraph(); +} + +void IdModel::buildGraph(IdMappingMode mode) { + switch (mode) { + case IdMappingMode::EXACT: + buildExactGraph(); + break; + case IdMappingMode::ALMOSTEXACT: + buildAlmostExactGraph(); + break; + case IdMappingMode::PERMISSIVE: + buildPermissiveGraph(); + break; + case IdMappingMode::LOOP: + buildLoopGraph(); + break; + default: + NVF_ERROR(false, "Unsupported mode: ", mode); + } +} + +void IdModel::maybeBuildGraph(IdMappingMode mode) { + if (id_graphs_.find(mode) != id_graphs_.end()) { + return; + } else { + buildGraph(mode); + } } } // namespace nvfuser diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 2e66c0a93cd..d05b8e4dad4 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -32,6 +32,11 @@ struct StatefulInliningInfo { p2c_ca_permissive_maps; }; +StatefulInliningInfo buildStatefulInliningInfo( + const std::vector& exprs, + const ValGraph& exact_graph, + const ValGraph& permissive_graph); + // A collection of ValGraphs that are built from a fusion or series of // expressions. These graphs are related, but have some distinct features based // on the IdMappingMode. @@ -74,9 +79,16 @@ struct StatefulInliningInfo { // class IdModel : public PolymorphicBase { public: + // Sometimes fusion inputs or outputs are disconnected from expressions, in + // those cases we still may want to send in some additional tensor views from + // the Fusion that don't have expressions associated with them. + // + // All graphs are built by default. It can be disabled with + // build_graphs=false. IdModel( const std::vector& exprs, const std::vector& additional_tvs = {}, + bool build_graphs = true, bool allow_self_mapping = false); // Same as the above constructor with fusion->exprs() excpet fusion may have @@ -87,8 +99,9 @@ class IdModel : public PolymorphicBase { // transition from the current ComputeAtMap. IdModel( Fusion* fusion, + bool build_graphs = true, bool allow_self_mapping = false, - bool validate = false); + bool validate = true); // Returns iter domain graph of provided mode. const ValGraph& idGraph(IdMappingMode mode) const; @@ -110,50 +123,55 @@ class IdModel : public PolymorphicBase { std::string toString() const; - // TODO: Should this not be private? - protected: - // Sometimes fusion inputs or outputs are disconnected from expressions, in - // those cases we still may want to send in some additional tensor views from - // the Fusion that don't have expressions associated with them. - void build( - const std::vector& exprs, - const std::vector& additional_tvs, - bool validate = false); - - // ======= START Iteration domain build process in order called ======= - - // Fills id_uses_ and id_definitions_ for all IterDomains active in the - // fusion. - void buildIterDomainDefinitionsAndUses( - const std::vector& all_tvs); - - // Iterates over all IterDomains in id_definitions_ and calls initializeVal on - // a new ValGraph and returns it. - ValGraph initializeIdGraph(bool propagate_through_exprs = true); + // Build all graphs. This is by default called from the constructor + void buildAllGraphs(); // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs // and first output of expr - void buildExactGraph(const std::vector& exprs); + void buildExactGraph(); // Fills disjoint_ids_[IdMappingMode::ALMOSTEXACT]. Initialize AlmostExact as // Exact entries, then map anything that's either merged with a size-1 or // split by a size-1 dimension. - void buildAlmostExactMap(); + void buildAlmostExactGraph(); // Fills disjoint_ids_[IdMappingMode::PERMISSIVE]. Initialize it as - // Exact entries, then map through broadcasts - void buildPermissiveMap(const std::vector& exprs); + // Exact entries, then map through broadcasts. Build the Exact graph + // as well if not yet done. + void buildPermissiveGraph(); // Fills disjoint_ids_[IdMappingMode::LOOP]. Map only inlined - // domains that are mapped in the permissive graph - void buildLoopMap(const std::vector& exprs); + // domains that are mapped in the permissive graph. Build the Exact + // and Permissive graphs as well if not yet done. + void buildLoopGraph(); + + // Build a graph. Dependent graphs are also built if not yet done. + void buildGraph(IdMappingMode mode); + + // Build a graph if not already built + void maybeBuildGraph(IdMappingMode mode); + + // Iterates over all IterDomains in id_definitions_ and calls initializeVal on + // a new ValGraph and returns it. + ValGraph initializeIdGraph(bool propagate_through_exprs = true); + + protected: + // Fills id_uses_ and id_definitions_ for all IterDomains active in the + // fusion. + void buildIterDomainDefinitionsAndUses(); /// Start loop map by grouping inlined iter domains - void initializeLoopMap(const StatefulInliningInfo& info); + void initializeLoopGraph(const StatefulInliningInfo& info); // Errors if self mapping occurs void assertNoSelfMapping(); + protected: + std::vector tv_exprs_; + std::vector tvs_; + bool allow_self_mapping_ = false; + bool validate_ = false; + // Keeps ValGraphs containing all IterDomains for all mapping mode types. // // Using an array here might be nice, but it seems hard to use an enum as an From 7ef52afe8805b331d2fda65044c688104984a66b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 20 Jan 2024 14:16:18 -0800 Subject: [PATCH 120/178] enable idmodel --- csrc/device_lower/lower2device.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index f787f181d89..3fc0cea3e23 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -382,7 +382,7 @@ void GpuLower::analysis(Fusion* fusion) { // functionality should be affected. New IterDomains may be created, // so it is expected that generated code may use diffrent variable // names - if (isOptionEnabled(EnableOption::IdModel)) { + if (true || isOptionEnabled(EnableOption::IdModel)) { IdModel id_model(fusion_); } From eced85a991b0e462b9aea840bcefc721ea7cadc2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 20 Jan 2024 14:29:13 -0800 Subject: [PATCH 121/178] comment --- csrc/id_model/id_model.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index d05b8e4dad4..cc8386b5d20 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -103,7 +103,8 @@ class IdModel : public PolymorphicBase { bool allow_self_mapping = false, bool validate = true); - // Returns iter domain graph of provided mode. + // Returns iter domain graph of provided mode. The graph must have + // been already built. const ValGraph& idGraph(IdMappingMode mode) const; ValGraph& idGraph(IdMappingMode mode); @@ -123,7 +124,8 @@ class IdModel : public PolymorphicBase { std::string toString() const; - // Build all graphs. This is by default called from the constructor + // Build all graphs, i.e., Exact, AlmostExact, Permissive and + // LOOP. This is by default called from the constructor void buildAllGraphs(); // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs @@ -167,9 +169,17 @@ class IdModel : public PolymorphicBase { void assertNoSelfMapping(); protected: + // All tensor expressions that this model analyzes std::vector tv_exprs_; + + // All tensors that this model analyzes std::vector tvs_; + + // Tensors should not have domains that are mapped with another + // domains of the same tensor. This flag disables the check bool allow_self_mapping_ = false; + + // If true, validate graphs by comparing them with ComputeAtMap bool validate_ = false; // Keeps ValGraphs containing all IterDomains for all mapping mode types. From cbaaf0ea0168c220313932fe694c99b0583d541e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 20 Jan 2024 21:45:22 -0800 Subject: [PATCH 122/178] WIP: Build a map of broadcast resolutions for root domains This is the first step of the loop promotion analysis. It may not be very clear what this analysis is intended to do at this point, but in short, the root resolution information is propagated to intermediate and leaf ID groups in the IEL graph, which is then projected back to the loop graph. I thought before diving into the full promotion analysis it would make reviewing easier to split this part off as the first sub-PR with unit tests. --- csrc/disjoint_set.h | 5 + csrc/id_model/id_model.cpp | 257 ++++++++++++++++++++++++++++- csrc/id_model/id_model.h | 36 ++++ csrc/val_graph.h | 26 +++ test/test_id_model.cpp | 327 ++++++++++++++++++++++++++++++++++++- 5 files changed, 648 insertions(+), 3 deletions(-) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 25f4c183af0..4e4f329b287 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -42,6 +42,11 @@ std::string abstractToString(T ref) { template > class VectorOfUniqueEntries { public: + // Naming not following our conventions but using the same name as + // std::vector makes it more convenient when we want to use this + // class as if it's like std::vector + using value_type = T; + VectorOfUniqueEntries() = default; VectorOfUniqueEntries(const std::initializer_list& initializer) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index b709acdc80e..970df7b89bf 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -509,10 +509,28 @@ void IdModel::buildPermissiveGraph() { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } + if (permissive_graph_map_compliment_ids_) { + for (const auto& entry : + permissive_forwarding.producer_compliment_map) { + for (auto entry_2 : entry.second) { + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); + } + } + } + for (auto entry : permissive_forwarding.consumer_forwarding_map) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); } + if (permissive_graph_map_compliment_ids_) { + for (const auto& entry : + permissive_forwarding.consumer_compliment_map) { + for (auto entry_2 : entry.second) { + idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry_2); + } + } + } + auto permissive_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv).mapBroadcast(true); @@ -525,6 +543,41 @@ void IdModel::buildPermissiveGraph() { idGraph(IdMappingMode::PERMISSIVE).validateConsistency(); } +namespace { + +// Returns the root producer iteration domains that are resolved by provided +// consumer +std::unordered_map resolvedRootBroadcasts( + TensorView* producer, + TensorView* consumer) { + auto p2c_map = PairwiseRootDomainMap(producer, consumer) + .mapBroadcast(true) + .mapProducerToConsumer(); + + std::unordered_map resolved_bcast_map; + for (const auto& [p_id, c_id] : p2c_map) { + // Look for a broadcast producer and non-broadcast consumer + + // Ignore non-broadcast producer and broadcast consumer dims + if (!p_id->isBroadcast() || c_id->isBroadcast()) { + continue; + } + + if (c_id->isReduction()) { + // This should only happen with expanded broadcast + // domains. Otherwise, squeeze should be used + NVF_ERROR( + p_id->hasExpandedExtent(), "Unexpected domain: ", c_id->toString()); + continue; + } + + resolved_bcast_map[p_id] = c_id; + } + return resolved_bcast_map; +} + +} // namespace + // Grab inlining relationships StatefulInliningInfo buildStatefulInliningInfo( const std::vector& exprs, @@ -565,6 +618,13 @@ StatefulInliningInfo buildStatefulInliningInfo( info.p2c_ca_permissive_maps[p_id->as()].pushBack(c_ids); } } + + std::unordered_map resolved_bcast_map = + resolvedRootBroadcasts(producer_tv, consumer_tv); + + for (const auto& [p_root_id, c_root_id] : resolved_bcast_map) { + info.p2c_root_broadcast_resolution_map[p_root_id].pushBack(c_root_id); + } } } } @@ -596,16 +656,187 @@ void IdModel::buildLoopGraph() { maybeBuildGraph(IdMappingMode::EXACT); maybeBuildGraph(IdMappingMode::PERMISSIVE); - const StatefulInliningInfo info = buildStatefulInliningInfo( + const StatefulInliningInfo inlining_info = buildStatefulInliningInfo( tv_exprs_, idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::PERMISSIVE)); - initializeLoopGraph(info); + initializeLoopGraph(inlining_info); + + loop_promotion_map_ = buildLoopPromotionMap(inlining_info); idGraph(IdMappingMode::LOOP).validateConsistency(); } +std::unordered_map IdModel::buildLoopPromotionMap( + const StatefulInliningInfo& inlining_info) { + // Make an intersection of the exact and loop map. This will group together + // entries in each loop group that are exact with each other. This provides a + // better graph to do promotion and replays. + // + // It's tempting to use the intersection of the almost exact and loop, but we + // need to model broadcast promotion, and if we have two tensors like: + // + // T1[i0, b1] = T0[i0] + // T2[i0, b2] = T0[i0] + // Then resolution of: + // T4 = T1[i0, b1] + T3[i0, i1] + // T6 = T2[i0, b2] + T5[i0, i2] + // + // Then merge(0, 1) with all tensors except for T0 + // + // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 + // are being resolved to i1 and i2 respectively. So we want to have separate + // entries so we can have an easy to process promotion map. + // + // Loop is a permissive like map, it could have many entries, use the exact + // map as the one we iterate on to reduce complexity as it hopefully has + // smaller groups and this algorithm scales with the number of groups * + // (number of entries in groups ^ 2) + // + // iel stands for Intersection of the Exact and Loop graphs. + ValGraph iel_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + + // Step 1: Build a map of the IEL groups of root broadcast domains + // to resolving domains. + std::unordered_map iel_promotion_map = + buildInlineRootResolutionmap(iel_graph, inlining_info); + + // This is not a right map to return but just a placeholder since + // the loop promotion map is not yet completely merged. It will be + // replaced by a proper map. + return iel_promotion_map; +} + +std::unordered_map IdModel::buildInlineRootResolutionmap( + const ValGraph& iel_graph, + const StatefulInliningInfo& info) { + std::unordered_map iel_promotion_map; + + // This should probably work just on terminating inputs, as we shouldn't be + // able to modify a broadcast domain between root and rfactor which would be + // required to resolve a non input broadcast domain. But for now leaving it as + // traversal on all broadcast groups. + // + + // We first visit all broadcast root domains. If a broadcast is + // resovled, see if it's promoted. Note that a domain be resolved to + // a domain that may not be loop mapped, yet it can still be + // promoted. In other words, there can be a domain that is exactly + // mapped with the resolving domain *and* is mapped with the + // broadcast domain by the loop map. The algorihm here is: + // + // 1. For a broadcast domain, find the domain that the broadcast is + // resolved to. + // 2. If the resolving domain is also loop-mapped with the + // broadcast, that is the promotion domain, but the resolving + // domain may not be loop mapped as mentioned above. Instead, + // find all loop-mapped domains with the broadcast domain and + // pick one that is exactly mapped with the resolving domain + // + // Note again this process is only done for root domains. Once we + // find promotion relationships for root domains, we propagate the + // mappings to derived domains + for (const ValGroup& iel_group : iel_graph.disjointValSets().disjointSets()) { + NVF_ERROR(!iel_group->empty()); + + IterDomain* iel_group_id = iel_group->front()->as(); + + if (!iel_group_id->isBroadcast()) { + continue; + } + + // Collect all the exact groups of the resolutions of the broadcast id's + ValGroups resolved_exact_groups; + for (Val* bcast_id : *iel_group) { + if (auto p2c_root_broadcast_resolution_map_it = + info.p2c_root_broadcast_resolution_map.find( + bcast_id->as()); + p2c_root_broadcast_resolution_map_it != + info.p2c_root_broadcast_resolution_map.end()) { + resolved_exact_groups.pushBack( + idGraph(IdMappingMode::EXACT) + .toGroups(p2c_root_broadcast_resolution_map_it->second)); + } + } + + if (resolved_exact_groups.empty()) { + // No resolution + continue; + } + + // resolved_exact_groups is a list of IDs that resolves the + // broadcast. We only care those that are also in the same loop + // group, and there must be just one or none. Otherwise, the + // resolution is ambiguous. + + // Collect all the exact groups in the loop set containing this iel_group + const ValGroup& loop_group = + idGraph(IdMappingMode::LOOP).toGroup(iel_group_id); + ValGroups loop_covered_exact_groups = + idGraph(IdMappingMode::EXACT).toGroups(*loop_group); + + // The intersection of the exact groups that the broadcast domains can be + // broadcasted to, and those that exist within the same loop groop are is + // the promotion needed for this iel_group. The promotion should + // be none or unique. + ValGroups loop_exact_resolved_intersection = + resolved_exact_groups.computeIntersect(loop_covered_exact_groups); + + if (loop_exact_resolved_intersection.empty()) { + // No promotion + continue; + } + + if (loop_exact_resolved_intersection.size() > 1) { + // Ambiguous promotion. This should not happen. + std::stringstream err_msg; + err_msg + << "Invalid multiple broadcast resolution within shared loops detected, group:\n " + << iel_group->toString() << "\nIs being broadcasted to:"; + for (const ValGroup& entry : loop_exact_resolved_intersection) { + err_msg << "\n " << entry->toString(); + } + NVF_ERROR(false, err_msg.str()); + } + + const ValGroup& exact_resolution_group = + loop_exact_resolved_intersection.front(); + + // Within the loop group, find the IDs that the broadcast IDs are + // resolved to + VectorOfUniqueEntries resolved_ids = + exact_resolution_group->computeIntersect(*loop_group); + + NVF_ERROR(!resolved_ids.empty()); + + // All the IDs in resolved_ids are mapped with both of the exact + // and loop graphs, so any of them can be used as an IEL promotion + // ID. Just to make it extra clear, look for corresponding + // groups in the IEL graph and make sure there's only one such group. + ValGroups promoted_iel_groups = iel_graph.toGroups(resolved_ids); + + NVF_ERROR(!promoted_iel_groups.empty()); + + if (promoted_iel_groups.size() > 1) { + std::stringstream err_msg; + err_msg + << "Invalid multiple broadcast resolution within shared loops detected, group:\n " + << iel_group->toString() << "\nIs being broadcasted to:"; + for (const ValGroup& entry : promoted_iel_groups) { + err_msg << "\n " << entry->toString(); + } + NVF_ERROR(false, err_msg.str()); + } + + iel_promotion_map[iel_group] = + promoted_iel_groups.front()->front()->as(); + } + + return iel_promotion_map; +} + void IdModel::buildAllGraphs() { if (tvs_.empty()) { return; @@ -680,4 +911,26 @@ void IdModel::maybeBuildGraph(IdMappingMode mode) { } } +ValGraph IdModel::buildIntersection( + const ValGraph& graph0, + const ValGraph& graph1, + bool propagate_exprs) { + ValGraph intersection = initializeIdGraph(propagate_exprs); + for (const ValGroup& group0 : graph0.disjointValSets().disjointSets()) { + auto set_size = group0->size(); + for (auto id0_i : c10::irange(set_size)) { + Val* id0 = group0->vector()[id0_i]; + for (auto id1_i = id0_i; id1_i < set_size; id1_i++) { + Val* id1 = group0->vector()[id1_i]; + // id0 and id1 map in group0. If they also map in the group1, + // add the mapping to the intersection. + if (graph1.disjointValSets().strictAreMapped(id0, id1)) { + intersection.mapVals(id0, id1); + } + } + } + } + return intersection; +} + } // namespace nvfuser diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index cc8386b5d20..d85fbf011de 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -30,6 +30,11 @@ struct StatefulInliningInfo { // leaf domains. std::unordered_map> p2c_ca_permissive_maps; + + // Broadcast resolution map for root domains, including non-inlined + // root domains + std::unordered_map> + p2c_root_broadcast_resolution_map; }; StatefulInliningInfo buildStatefulInliningInfo( @@ -157,6 +162,17 @@ class IdModel : public PolymorphicBase { // a new ValGraph and returns it. ValGraph initializeIdGraph(bool propagate_through_exprs = true); + // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and + // graph1. + ValGraph buildIntersection( + const ValGraph& graph0, + const ValGraph& graph1, + bool propagate_exprs = true); + + const std::unordered_map& loopPromotionMap() const { + return loop_promotion_map_; + } + protected: // Fills id_uses_ and id_definitions_ for all IterDomains active in the // fusion. @@ -165,6 +181,19 @@ class IdModel : public PolymorphicBase { /// Start loop map by grouping inlined iter domains void initializeLoopGraph(const StatefulInliningInfo& info); + // Build a map of loop groups to IterDomains that represent actual + // loops. The map is built based on the broadcast resolution with + // root domains between inlined producer and consumer tensors. + std::unordered_map buildLoopPromotionMap( + const StatefulInliningInfo& info); + + // Helper function for buildLoopPromotionMap. Returns a map of + // root broadcast ValGroups in the IEL graph to a representative + // IterDomain picked from its IEL group. + std::unordered_map buildInlineRootResolutionmap( + const ValGraph& iel_graph, + const StatefulInliningInfo& info); + // Errors if self mapping occurs void assertNoSelfMapping(); @@ -182,6 +211,10 @@ class IdModel : public PolymorphicBase { // If true, validate graphs by comparing them with ComputeAtMap bool validate_ = false; + // By default, the permissive graph should map compliment domains as + // well. See the design doc for more details + bool permissive_graph_map_compliment_ids_ = true; + // Keeps ValGraphs containing all IterDomains for all mapping mode types. // // Using an array here might be nice, but it seems hard to use an enum as an @@ -206,6 +239,9 @@ class IdModel : public PolymorphicBase { self_mapping_info_ = std::nullopt; std::unordered_set view_rfactor_ids_; + + // Promotion domain for each loop group + std::unordered_map loop_promotion_map_; }; } // namespace nvfuser diff --git a/csrc/val_graph.h b/csrc/val_graph.h index f7a20e46bdb..f89e4d52b69 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -89,6 +90,31 @@ class ValGraph { // Convert Val to its ValGroup, assert that it exists. const ValGroup& toGroup(Val* val) const; + // Convert a vector of Val* or Expr* to their ValGroups or + // ExprGroups, respectively + template < + typename ContainerType, + typename ElementType = typename std::remove_pointer< + typename ContainerType::value_type>::type, + typename = std::enable_if_t< + std::is_base_of::value || + std::is_base_of::value>> + typename std::conditional< + std::is_base_of::value, + ValGroups, + ExprGroups>::type + toGroups(const ContainerType& entries) const { + using RetType = typename std::conditional< + std::is_base_of::value, + ValGroups, + ExprGroups>::type; + RetType groups; + for (auto entry : entries) { + groups.pushBack(toGroup(entry)); + } + return groups; + } + // Return output/input Val groups of provided expr // Note that the same ValGroup can show up multiple times, so the // output type cannot be VectorOfUniqueEntries diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index fc823316564..89b4f645e80 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include namespace nvfuser { @@ -32,9 +33,333 @@ TEST_F(IdModelTest, DetectSelfMapping) { fusion.addOutput(tv2); EXPECT_THAT( - [&]() { IdModel id_model(&fusion); }, + [&]() { + IdModel id_model(&fusion); + id_model.buildAllGraphs(); + }, ::testing::ThrowsMessage( ::testing::HasSubstr("!hasSelfMapping"))); } +namespace { + +// Helper class to test IdModel +class IdModelTester : public IdModel { + public: + // Do not automatically build the graphs + IdModelTester(Fusion* fusion) : IdModel(fusion, /* build_graphs */ false) {} + + std::pair> + getInlineRootResolutionMap() { + // Make sure the depedent graphs are already built + maybeBuildGraph(IdMappingMode::EXACT); + maybeBuildGraph(IdMappingMode::PERMISSIVE); + + // Gather broadcast resolution and inlining information + const StatefulInliningInfo inlining_info = buildStatefulInliningInfo( + tv_exprs_, + idGraph(IdMappingMode::EXACT), + idGraph(IdMappingMode::PERMISSIVE)); + + initializeLoopGraph(inlining_info); + + ValGraph iel_graph = buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + + std::unordered_map root_promotion_map = + buildInlineRootResolutionmap(iel_graph, inlining_info); + + return {std::move(iel_graph), std::move(root_promotion_map)}; + } +}; + +// Test if root_broadcast_id is resolved to ref_id. If ref_id is +// nullptr, test if root_broadcast_id has no resolution. +void validateResolution( + IterDomain* root_broadcast_id, + IterDomain* ref_id, + const ValGraph& iel_graph, + const std::unordered_map& root_resolution_map) { + ASSERT_TRUE(root_broadcast_id->isBroadcast()); + const auto& iel_group = iel_graph.toGroup(root_broadcast_id); + auto root_promotion_map_it = root_resolution_map.find(iel_group); + if (ref_id != nullptr) { + ASSERT_TRUE(root_promotion_map_it != root_resolution_map.end()) + << "Root resolution not found for: " << nvfuser::toString(iel_group); + ASSERT_FALSE(ref_id->isBroadcast()); + auto resolution_id = root_promotion_map_it->second; + ASSERT_TRUE( + iel_graph.disjointValSets().strictAreMapped(resolution_id, ref_id)) + << "Unexpected root resolution. " + << "Expected: " << ref_id->toString() + << ". Actual: " << resolution_id->toString(); + } else { + ASSERT_TRUE(root_promotion_map_it == root_resolution_map.end()) + << "Root resolution should not exist for: " + << nvfuser::toString(iel_group) + << ", but found: " << root_promotion_map_it->second->toString(); + } +} + +// Create a fusion where we're missing a valid concrete id so the compute at map +// processing will fail. We need to be able to create the concrete ID not just +// look for one. It is not yet possible to lower this fusion as the +// current indexing cannot generate correct indices. Also used in +// FusionIndeixing19 +std::unique_ptr createFusionWithMultipleResolutionPaths() { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({7}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = broadcast(tv1, {false, true}); + + auto tv3 = makeConcreteTensor({7, 11}); + fusion.addInput(tv3); + + auto tv4 = add(tv3, tv2); + auto tv5 = broadcast(tv4, {false, false, true}); + // tv4[7, 11, 1] + + auto tv6 = broadcast(tv1, {false, true}); + + auto tv7 = makeConcreteTensor({7, 13}); + fusion.addInput(tv7); + auto tv8 = add(tv7, tv6); + auto tv9 = broadcast(tv8, {false, true, false}); + // tv9[7, 1, 13] + + auto tv10 = add(tv5, tv9); + fusion.addOutput(tv10); + + // tv10[7, 11, 13] + tv10->merge(0)->merge(0); + // tv10[7*11*13] + tv10->split(0, 5)->split(0, 3); + // tv10[7*11*13//5//3, 3, 5] + + TransformPropagatorWithCheck propagator(tv10); + MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + + std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; + for (auto tensor : tensors_to_inline) { + tensor->inlineAt(1); + } + + return fusion_ptr; +} + +TensorView* findTensorByName( + const std::vector& tvs, + StmtNameType name) { + if (auto it = std::find_if( + tvs.begin(), + tvs.end(), + [&](TensorView* tv) { return tv->name() == name; }); + it != tvs.end()) { + return *it; + } else { + return nullptr; + } +} + +} // namespace + +// Testing root resolution with a simple broadcast pattern +TEST_F(IdModelTest, LoopGraphRootResolution1) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto t0 = makeSymbolicTensor(1); + fusion->addInput(t0); + auto t1 = makeSymbolicTensor(2); + fusion->addInput(t1); + auto t2 = broadcast(t0, {true, false}); + auto t3 = add(t2, t1); + fusion->addOutput(t3); + + { + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Nothing inlined. Should be no resolution + ASSERT_TRUE(root_resolution_map.empty()); + } + + t2->inlineAt(2); + ASSERT_EQ(t2->getComputeAtPosition(), 2); + + { + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // t2 is now fully inlined. Its root broadcast domain should be + // resoled with the corresponding domain of t3 + validateResolution( + t2->getRootDomain().at(0), + t3->getRootDomain().at(0), + iel_graph, + root_resolution_map); + } +} + +// Test with a fusion with progressive broadcasting +TEST_F(IdModelTest, LoopGraphRootResolution2) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto t0 = makeSymbolicTensor(1); + fusion->addInput(t0); + auto t1 = makeSymbolicTensor(3); + fusion->addInput(t1); + + auto t2 = broadcast(t0, {true, false}); + auto t3 = broadcast(t2, {true, false, false}); + auto t4 = add(t3, t1); + fusion->addOutput(t4); + + inlineMost(); + + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Validate t2 and t3 as they have root broadcast domains + validateResolution( + t2->getRootDomain().at(0), + t4->getRootDomain().at(1), + iel_graph, + root_resolution_map); + + validateResolution( + t3->getRootDomain().at(0), + t4->getRootDomain().at(0), + iel_graph, + root_resolution_map); +} + +// Multiple inlined and non-inlined broadcast domains +TEST_F(IdModelTest, LoopGraphRootResolution3) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + auto tv1 = makeSymbolicTensor(4); + fusion->addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true, false, true}); + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + // tv3: [i0, i1, i2, i3] -> [i0*i1, i2*i3] + tv3->merge(0); + tv3->merge(1); + + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv2->inlineAt(1); + + // tv2: [i0*b1, i2*b3] ca(1) + // tv3: [i0*i1, i2*i3] + + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // The b1 broadcast domain tv2 should be resolved as it's inlined, + // but b3 should not. + validateResolution( + tv2->getRootDomain().at(1), + tv3->getRootDomain().at(1), + iel_graph, + root_resolution_map); + + validateResolution( + tv2->getRootDomain().at(3), nullptr, iel_graph, root_resolution_map); +} + +TEST_F(IdModelTest, LoopGraphRootResolution4) { + auto fusion = createFusionWithMultipleResolutionPaths(); + auto all_tvs = ir_utils::allTvs(fusion.get()); + + fusion->print(); + + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Verify all tensors with broadcast have correct resolution of root + // broadcast domains + for (auto tv : ir_utils::allTvs(fusion.get())) { + // Skip tensors with no broadcast + if (std::none_of( + tv->getRootDomain().begin(), + tv->getRootDomain().end(), + [](auto id) { return id->isBroadcast(); })) { + continue; + } + + switch (tv->name()) { + case 2: + // T2_l[ iS49{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS50{3}, + // iS48{5} ] ca_pos( 1 ) produce_pos( 1 ) + // root domain : (iS2{7}, bS3{1}) + // Resolution: Resolved by the immediate consumer (T4) + validateResolution( + tv->getRootDomain().at(1), + findTensorByName(all_tvs, 4)->getRootDomain().at(1), + iel_graph, + root_resolution_map); + break; + case 5: + // T5_l[ iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), 5) ), 3) )}, + // iS40{3}, iS38{5} ] produce_pos( 1 ) + // root domain : (iS8{7}, iS9{11}, bS10{1}) + // Resolution: T5 is not inlined to the immediate consumer, + // T10. Resolution is done with the other path from T1, such + // as T8 or T9. + validateResolution( + tv->getRootDomain().at(2), + findTensorByName(all_tvs, 9)->getRootDomain().at(2), + iel_graph, + root_resolution_map); + break; + case 6: + // T6_l[ iS64{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3}, + // iS63{5} ] ca_pos( 1 ) produce_pos( 1 ) + // root domain : (iS11{7}, bS12{1}) + // Resolution: Resolved by the immediate consumer (T8) + validateResolution( + tv->getRootDomain().at(1), + findTensorByName(all_tvs, 8)->getRootDomain().at(1), + iel_graph, + root_resolution_map); + break; + case 9: + // T9_l[ iS33{( ceilDiv(( ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )}, + // iS34{3}, iS32{5} ] produce_pos( 1 ) + // root domain : (iS17{7}, bS18{1}, iS19{13}) + // Resolution: T9 is not inlined to the immediate consumer, + // T10. Resolution is done with the other path from T1, such + // as T4 or T5 + validateResolution( + tv->getRootDomain().at(1), + findTensorByName(all_tvs, 5)->getRootDomain().at(1), + iel_graph, + root_resolution_map); + break; + default: + FAIL() << "Unexpected tensor: " << tv->toString(); + } + } +} + } // namespace nvfuser From 45dd4189fa5839a4b8ff62ae4342c3b93c5a958c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sun, 21 Jan 2024 12:51:14 -0800 Subject: [PATCH 123/178] fix --- csrc/id_model/id_model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 970df7b89bf..b8997fff0f4 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -876,7 +876,7 @@ void IdModel::buildAllGraphs() { buildPermissiveGraph(); // Validation is not implemented when compliment mapping is enabled - if (validate_) { + if (!permissive_graph_map_compliment_ids_ && validate_) { validator->checkPermissiveGraphEquivalence( idGraph(IdMappingMode::PERMISSIVE)); } From 534ac78608fb9386a8ac351bd3b06dd384b2bb60 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 25 Jan 2024 20:57:57 -0800 Subject: [PATCH 124/178] rename --- csrc/id_model/id_model.cpp | 4 ++-- csrc/id_model/id_model.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 29c5172db92..5ae8216c0d2 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -701,7 +701,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // Step 1: Build a map of the IEL groups of root broadcast domains // to resolving domains. std::unordered_map iel_promotion_map = - buildInlineRootResolutionmap(iel_graph, inlining_info); + buildInlineRootResolutionMap(iel_graph, inlining_info); // This is not a right map to return but just a placeholder since // the loop promotion map is not yet completely merged. It will be @@ -709,7 +709,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( return iel_promotion_map; } -std::unordered_map IdModel::buildInlineRootResolutionmap( +std::unordered_map IdModel::buildInlineRootResolutionMap( const ValGraph& iel_graph, const StatefulInliningInfo& info) { std::unordered_map iel_promotion_map; diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index d85fbf011de..794de4c41ae 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -190,7 +190,7 @@ class IdModel : public PolymorphicBase { // Helper function for buildLoopPromotionMap. Returns a map of // root broadcast ValGroups in the IEL graph to a representative // IterDomain picked from its IEL group. - std::unordered_map buildInlineRootResolutionmap( + std::unordered_map buildInlineRootResolutionMap( const ValGraph& iel_graph, const StatefulInliningInfo& info); From 2837e53ab54ef3fcb08779e3333704260e3752e8 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 25 Jan 2024 21:21:48 -0800 Subject: [PATCH 125/178] Add tests with example fusions used in the design doc --- test/test_id_model.cpp | 89 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 5 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 89b4f645e80..5e7100ba168 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -67,7 +68,7 @@ class IdModelTester : public IdModel { idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); std::unordered_map root_promotion_map = - buildInlineRootResolutionmap(iel_graph, inlining_info); + buildInlineRootResolutionMap(iel_graph, inlining_info); return {std::move(iel_graph), std::move(root_promotion_map)}; } @@ -101,11 +102,50 @@ void validateResolution( } } +// Create a simple fusion with outer split. Currently invalid code +// will be generated. +// +// Used as Example 1 in the design doc about Loop +// Promotion Analysis. +std::unique_ptr createFusionWithInlinedOuterSplit() { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({1, 4}); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor({3, 4}); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + fusion.printMath(); + + // [i0, i1] + tv4->merge(0); + // [i0*i1] + tv4->split(0, 4, false); // outer split + // [4, i0*i1/4] + + TransformPropagator propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + for (auto tv: ir_utils::allTvs(&fusion)) { + tv->inlineAt(-2); + } + + return fusion_ptr; +} + // Create a fusion where we're missing a valid concrete id so the compute at map // processing will fail. We need to be able to create the concrete ID not just // look for one. It is not yet possible to lower this fusion as the // current indexing cannot generate correct indices. Also used in -// FusionIndeixing19 +// FusionIndeixing19 as well as Example 2 in the design doc about Loop +// Promotion Analysis. std::unique_ptr createFusionWithMultipleResolutionPaths() { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); @@ -286,8 +326,9 @@ TEST_F(IdModelTest, LoopGraphRootResolution3) { tv2->getRootDomain().at(3), nullptr, iel_graph, root_resolution_map); } +// Test root resolution with a fusion with outer split TEST_F(IdModelTest, LoopGraphRootResolution4) { - auto fusion = createFusionWithMultipleResolutionPaths(); + auto fusion = createFusionWithInlinedOuterSplit(); auto all_tvs = ir_utils::allTvs(fusion.get()); fusion->print(); @@ -299,14 +340,52 @@ TEST_F(IdModelTest, LoopGraphRootResolution4) { // Verify all tensors with broadcast have correct resolution of root // broadcast domains for (auto tv : ir_utils::allTvs(fusion.get())) { - // Skip tensors with no broadcast + // Skip tensors with no broadcast or non-inlined if (std::none_of( tv->getRootDomain().begin(), tv->getRootDomain().end(), - [](auto id) { return id->isBroadcast(); })) { + [](auto id) { return id->isBroadcast(); }) || + tv->getComputeAtPosition() == 0) { continue; } + switch (tv->name()) { + case 2: + // T2_l[ iS20{4}, iS21{( ceilDiv(( 1 * 4 ), 4) )} ] ca_pos( 1 ) + // root domain : (bS4{1}, iS5{4}) + validateResolution( + tv->getRootDomain().at(0), + findTensorByName(all_tvs, 4)->getRootDomain().at(0), + iel_graph, + root_resolution_map); + break; + default: + FAIL() << "Unexpected tensor: " << tv->toString(); + } + } +} + +TEST_F(IdModelTest, LoopGraphRootResolution5) { + auto fusion = createFusionWithMultipleResolutionPaths(); + auto all_tvs = ir_utils::allTvs(fusion.get()); + + IdModelTester tester(fusion.get()); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Verify all tensors with broadcast have correct resolution of root + // broadcast domains + for (auto tv : ir_utils::allTvs(fusion.get())) { + // Skip tensors with no broadcast or non-inlined + if (std::none_of( + tv->getRootDomain().begin(), + tv->getRootDomain().end(), + [](auto id) { return id->isBroadcast(); }) || + tv->getComputeAtPosition() == 0) { + continue; + } + + switch (tv->name()) { case 2: // T2_l[ iS49{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS50{3}, From 45b8be94b5e69370c7c801acf6cc58b308b0e98c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 26 Jan 2024 10:54:48 -0800 Subject: [PATCH 126/178] Clean up tests --- test/test_id_model.cpp | 80 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 5e7100ba168..93e3b65faaa 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -133,7 +133,7 @@ std::unique_ptr createFusionWithInlinedOuterSplit() { TransformPropagator propagator(tv4); MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); - for (auto tv: ir_utils::allTvs(&fusion)) { + for (auto tv : ir_utils::allTvs(&fusion)) { tv->inlineAt(-2); } @@ -331,15 +331,13 @@ TEST_F(IdModelTest, LoopGraphRootResolution4) { auto fusion = createFusionWithInlinedOuterSplit(); auto all_tvs = ir_utils::allTvs(fusion.get()); - fusion->print(); - IdModelTester tester(fusion.get()); const auto& [iel_graph, root_resolution_map] = tester.getInlineRootResolutionMap(); // Verify all tensors with broadcast have correct resolution of root // broadcast domains - for (auto tv : ir_utils::allTvs(fusion.get())) { + for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( tv->getRootDomain().begin(), @@ -365,7 +363,78 @@ TEST_F(IdModelTest, LoopGraphRootResolution4) { } } +// Test root resolution with the same fusion as Indexing1 TEST_F(IdModelTest, LoopGraphRootResolution5) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1.0)); + auto tv3 = broadcast(tv2, {true, false, false, false}); + auto tv4 = add(tv3, tv1); + + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->merge(0); + tv4->merge(0); + + tv4->split(0, 128); + tv4->split(0, 4); + + tv2->computeAt(tv4, 1); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unroll); + tv4->axis(2)->parallelize(ParallelType::TIDx); + + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(2)->parallelize(ParallelType::TIDx); + + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(2)->parallelize(ParallelType::TIDx); + + auto all_tvs = ir_utils::allTvs(&fusion); + + IdModelTester tester(&fusion); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Verify all tensors with broadcast have correct resolution of root + // broadcast domains + for (auto tv : all_tvs) { + // Skip tensors with no broadcast or non-inlined + if (std::none_of( + tv->getRootDomain().begin(), + tv->getRootDomain().end(), + [](auto id) { return id->isBroadcast(); }) || + tv->getComputeAtPosition() == 0) { + continue; + } + + switch (tv->name()) { + case 3: + // T3_l[ iS30{( ceilDiv(( ceilDiv(( ( ( 1 * i0 ) * i2 ) * i3 ), 128) ), + // 4) )}, iUR31{4}, ithreadIdx.x29{128} ] ca_pos( 1 ) produce_pos( 1 ) + // root domain : (bS10{1}, iS11{i0}, iS12{i2}, iS13{i3}) + validateResolution( + tv->getRootDomain().at(0), + findTensorByName(all_tvs, 4)->getRootDomain().at(0), + iel_graph, + root_resolution_map); + break; + default: + FAIL() << "Unexpected tensor: " << tv->toString(); + } + } +} + +// Test root resolution with the same fusion as Indexing19 +TEST_F(IdModelTest, LoopGraphRootResolution6) { auto fusion = createFusionWithMultipleResolutionPaths(); auto all_tvs = ir_utils::allTvs(fusion.get()); @@ -375,7 +444,7 @@ TEST_F(IdModelTest, LoopGraphRootResolution5) { // Verify all tensors with broadcast have correct resolution of root // broadcast domains - for (auto tv : ir_utils::allTvs(fusion.get())) { + for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( tv->getRootDomain().begin(), @@ -385,7 +454,6 @@ TEST_F(IdModelTest, LoopGraphRootResolution5) { continue; } - switch (tv->name()) { case 2: // T2_l[ iS49{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS50{3}, From ec8a2f5dd145a2809f995457937eb84ed13544d2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 26 Jan 2024 14:26:28 -0800 Subject: [PATCH 127/178] bug fix --- csrc/id_model/id_model.cpp | 96 +++++++++++++++++++++++++++++++------- 1 file changed, 78 insertions(+), 18 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index c76a7329bb0..dd732fcd64a 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -1070,19 +1070,31 @@ std::unordered_map IdModel::buildLoopPromotionMap( std::unordered_map iel_promotion_map = buildInlineRootResolutionmap(iel_graph, inlining_info); + { + std::stringstream ss; + ss << "Step 1: Root promotion map\n"; + for (const auto& [iel_group, promoted_id] : iel_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " + << promoted_id->name() << std::endl; + } + VERBOSE() << ss.str(); + } + // Step 2: Propagate the root promotions to intermediate and leaf groups. // At this point, the promotion may not be final as the analysis is // localized to IEL groups. The map is used in the next step to // build mappings of the loop groups. propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); - std::stringstream ss; - ss << "Inline promotion map\n"; - for (const auto& [iel_group, promoted_id] : iel_promotion_map) { - ss << "\t" << nvfuser::toString(iel_group) << " -> " << promoted_id->name() - << std::endl; + { + std::stringstream ss; + ss << "Step 2: IEL promotion map\n"; + for (const auto& [iel_group, promoted_id] : iel_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " + << promoted_id->name() << std::endl; + } + VERBOSE() << ss.str(); } - VERBOSE() << ss.str(); // Step 3: Determine the promotion of each loop graph based on the // IEL promotion map. For each loop group, examine all the IEL @@ -1092,6 +1104,23 @@ std::unordered_map IdModel::buildLoopPromotionMap( projectIELPromotionToLoopGraph( iel_graph, iel_promotion_map, loop_graph_copy, inlining_info); + for (const auto& loop_group : + loop_graph_copy.disjointValSets().disjointSets()) { + auto it = loop_graph_copy_promotion_map.find(loop_group); + if (it == loop_graph_copy_promotion_map.end()) { + VERBOSE() << "No promotion found yet for loop group of " + << nvfuser::toString(loop_group) << std::endl; + } + } + + { + VERBOSE() << "Step 3: initial loop promotion map:" << std::endl; + for (const auto& [loop_group, id] : loop_graph_copy_promotion_map) { + VERBOSE() << nvfuser::toString(loop_group) << " -> " << id->name() + << std::endl; + } + } + // At this point, most of loop groups should have correct promoted // IDs. However, non-inlined loop groups may miss promotion that // should be propagated from parent ID groups, e.g., iS50 of T2 in @@ -1117,6 +1146,16 @@ std::unordered_map IdModel::buildLoopPromotionMap( loop_graph_copy_promotion_map, true); + { + std::stringstream ss; + ss << "Step 4: IEL promotion map\n"; + for (const auto& [iel_group, promoted_id] : final_iel_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " + << promoted_id->name() << std::endl; + } + VERBOSE() << ss.str(); + } + // Step 5: Find the final promotion of each loop group based on the // final IEL promotion map auto final_loop_promotion_map = projectIELPromotionToLoopGraph( @@ -1520,6 +1559,14 @@ bool hasUniqueOutputLoopGroups( // promotions are propagated. In that case, loop_graph_promotion_map // should be just empty. // +// Propagation uses iel_promotion_map and +// loop_graph_promotion_map. If both are available for an IEL group, +// the former has the precedence. This is because when this function +// is used for step 4, the given iel_promotion_map is empty and gets +// populated during this propagation, whereas the loop promotion map +// is not guaranteed to have the correct mappings for partially +// inlined domains. +// // The loop_graph pamameter may not be up-to-date. void IdModel::propagatePromotionsInIELGraph( const ValGraph& iel_graph, @@ -1553,6 +1600,21 @@ void IdModel::propagatePromotionsInIELGraph( // Assumed all inputs are IterDomains NVF_ERROR(iel_inp_group->front()->isA()); + // Even when loop promotions are given, We still could require + // an input promotion. We could be traversing across non-inlined + // groups. Meaning we have inputs that were promoted in an + // inlined loop group traversing through the non-inlined + // portions of the iel graph. + if (auto inp_promo_it = iel_promotion_map.find(iel_inp_group); + inp_promo_it != iel_promotion_map.end()) { + maybe_promoted_inputs.push_back(inp_promo_it->second); + an_input_was_promoted = true; + VERBOSE() << "Promoted input by IEL promotion: " + << nvfuser::toString(iel_inp_group) << " -> " + << inp_promo_it->second->name() << std::endl; + continue; + } + // Promote loops based on the loop promotion map. If the loop promotion // map should be used and has an entry we should use that promotion. This // happen when an iel expression is across a loop group boundary. @@ -1566,22 +1628,13 @@ void IdModel::propagatePromotionsInIELGraph( if (inp_loop_promo_it != loop_graph_promotion_map.end()) { maybe_promoted_inputs.push_back(inp_loop_promo_it->second); an_input_was_promoted = true; + VERBOSE() << "Promoted input by loop promotion: " + << nvfuser::toString(iel_inp_group) << " -> " + << inp_loop_promo_it->second->name() << std::endl; continue; } } - // Even when loop promotions are given, We still could require - // an input promotion. We could be traversing across non-inlined - // groups. Meaning we have inputs that were promoted in an - // inlined loop group traversing through the non-inlined - // portions of the iel graph. - if (auto inp_promo_it = iel_promotion_map.find(iel_inp_group); - inp_promo_it != iel_promotion_map.end()) { - maybe_promoted_inputs.push_back(inp_promo_it->second); - an_input_was_promoted = true; - continue; - } - // No promotion found. Just use the non-promoted domain maybe_promoted_inputs.push_back(iel_inp_group->front()->as()); } @@ -1693,6 +1746,9 @@ void IdModel::propagatePromotionsInIELGraph( if (!promoted_expr) { promoted_expr = addReplayAs(maybe_promoted_inputs, iel_expr->front()); replayed = true; + for (auto id : maybe_promoted_inputs) { + VERBOSE() << "Maybe promoted input: " << id->name() << std::endl; + } VERBOSE() << "Replayed: " << promoted_expr->toString(); } else { VERBOSE() << "Reusing: " << promoted_expr->toString(); @@ -1718,6 +1774,8 @@ void IdModel::propagatePromotionsInIELGraph( } iel_promotion_map[out_groups[i]] = promoted_expr->output(i)->as(); + VERBOSE() << "IEL promotion: " << nvfuser::toString(out_groups[i]) + << " -> " << promoted_expr->output(i)->name() << std::endl; // Explicitly map loop map since expr propagation doesn't happen if (replayed) { idGraph(IdMappingMode::LOOP) @@ -1815,6 +1873,8 @@ IterDomain* IdModel::findPromotionOfLoopGroup( // Grab the iel entry const ValGroup& iel_group = iel_graph.toGroup(loop_id); + // Does it still need iel_promotion_map? The loop group already has + // the replayed domains, so we should be able to find it. auto iel_promo_it = iel_promotion_map.find(iel_group); if (iel_promo_it == iel_promotion_map.end()) { // If this terminal ID doesn't have a promotion associated with it, save From 19604322c2d112e69f17e113882a738681759143 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 26 Jan 2024 14:50:11 -0800 Subject: [PATCH 128/178] Add test --- test/test_id_model.cpp | 59 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 93e3b65faaa..e5ee1188cd3 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -509,4 +509,63 @@ TEST_F(IdModelTest, LoopGraphRootResolution6) { } } +// Same fusion as NvFuserTest.FusionInlineBroadcastIndexing0 +TEST_F(IdModelTest, LoopGraphRootResolution7) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + auto tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = set(tv0); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->split(0, 32); + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv2->inlineAt(1); + tv3->inlineAt(1); + + tv2->split(-1, 8); + + auto all_tvs = ir_utils::allTvs(&fusion); + + IdModelTester tester(&fusion); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Verify all tensors with broadcast have correct resolution of root + // broadcast domains + for (auto tv : all_tvs) { + // Skip tensors with no broadcast or non-inlined + if (std::none_of( + tv->getRootDomain().begin(), + tv->getRootDomain().end(), + [](auto id) { return id->isBroadcast(); }) || + tv->getComputeAtPosition() == 0) { + continue; + } + + switch (tv->name()) { + case 3: + // T3_l[ iS15{( ceilDiv(( 1 * i0 ), 32) )}, iS16{32} ] ca_pos( 1 ) + // produce_pos( 1 ) root domain : (bS4{1}, iS5{i0}) + validateResolution( + tv->getRootDomain().at(0), + findTensorByName(all_tvs, 4)->getRootDomain().at(0), + iel_graph, + root_resolution_map); + break; + default: + FAIL() << "Unexpected tensor: " << tv->toString(); + } + } +} + } // namespace nvfuser From 74c98e2699c4fabcd089683b5c74810cb82107ec Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 26 Jan 2024 16:23:48 -0800 Subject: [PATCH 129/178] Add a test --- test/test_id_model.cpp | 87 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index e5ee1188cd3..68fa4d1f058 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -568,4 +568,91 @@ TEST_F(IdModelTest, LoopGraphRootResolution7) { } } +// Same fusion as NvFuserTest.FusionIndexing20 +TEST_F(IdModelTest, LoopGraphRootResolution8) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({5}); + fusion.addInput(tv0); + + // [5] + auto tv1 = set(tv0); + auto tv2 = broadcast(tv1, {true, false}); + // [1, 5] + auto tv3 = makeConcreteTensor({3, 5}); + fusion.addInput(tv3); + auto tv4 = add(tv3, tv2); + // [3, 5] + + auto tv5 = broadcast(tv4, {false, false, true}); + // [3, 5, 1] + auto tv6 = makeConcreteTensor({3, 5, 7}); + fusion.addInput(tv6); + auto tv7 = add(tv5, tv6); + // [3, 5, 7] + fusion.addOutput(tv7); + + tv4->merge(0)->split(0, 2, false); + // [3, 5] + // [3, 3*5//2] + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv1->inlineAt(1); + tv2->inlineAt(1); + tv4->inlineAt(1); + + // [2, 3*5//2] + tv5->merge(1)->split(1, 4, false); + // [2, 4, (3*5//2)*1//4] + tv7->merge(1)->split(1, 4, false); + // [2, 4, (3*5//2)*7//4] + tv5->inlineAt(2); + + auto all_tvs = ir_utils::allTvs(&fusion); + + IdModelTester tester(&fusion); + const auto& [iel_graph, root_resolution_map] = + tester.getInlineRootResolutionMap(); + + // Verify all tensors with broadcast have correct resolution of root + // broadcast domains + for (auto tv : all_tvs) { + // Skip tensors with no broadcast or non-inlined + if (std::none_of( + tv->getRootDomain().begin(), + tv->getRootDomain().end(), + [](auto id) { return id->isBroadcast(); }) || + tv->getComputeAtPosition() == 0) { + continue; + } + + switch (tv->name()) { + case 2: + // T2_l[ iS21{2}, iS22{( ceilDiv(( 1 * 5 ), 2) )} ] ca_pos( 1 ) + // produce_pos( 1 ) root domain : (bS2{1}, iS3{5}) + validateResolution( + tv->getRootDomain().at(0), + findTensorByName(all_tvs, 7)->getRootDomain().at(0), + iel_graph, + root_resolution_map); + break; + case 5: + // T5_l[ iS27{2}, iS40{4}, iS41{( ceilDiv(( ( ceilDiv(( 3 * 5 ), 2) ) * + // 1 ), 4) )} ] ca_pos( 2 ) produce_pos( 1 ) root domain : (iS8{3}, + // iS9{5}, bS10{1}) + validateResolution( + tv->getRootDomain().at(2), + findTensorByName(all_tvs, 7)->getRootDomain().at(2), + iel_graph, + root_resolution_map); + break; + default: + FAIL() << "Unexpected tensor: " << tv->toString(); + } + } +} + } // namespace nvfuser From ccd7bab5be675195bffe90c19858ac9f951e38ea Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 26 Jan 2024 16:29:46 -0800 Subject: [PATCH 130/178] comment --- csrc/val_graph.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/val_graph.h b/csrc/val_graph.h index f89e4d52b69..66f3892393f 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -90,8 +90,9 @@ class ValGraph { // Convert Val to its ValGroup, assert that it exists. const ValGroup& toGroup(Val* val) const; - // Convert a vector of Val* or Expr* to their ValGroups or - // ExprGroups, respectively + // Convert a vector-like container of Val* or Expr* to their + // ValGroups or ExprGroups. The vector-like container type must + // define the element type as value_type template < typename ContainerType, typename ElementType = typename std::remove_pointer< From 6505c63da865e29870a40176a1ba8d13a95dbab5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 26 Jan 2024 16:32:03 -0800 Subject: [PATCH 131/178] test cleanup --- test/test_gpu_indexing.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_gpu_indexing.cpp b/test/test_gpu_indexing.cpp index 9863864ed01..fda89b2b916 100644 --- a/test/test_gpu_indexing.cpp +++ b/test/test_gpu_indexing.cpp @@ -1131,7 +1131,11 @@ TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { tv4->merge(0); tv4->split(0, 32); - tv0->computeAt(tv4, 1); + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv2->inlineAt(1); + tv3->inlineAt(1); tv2->split(-1, 8); From 1c111823fe638ff3e1184de0aeb6fc1f7afd4502 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 26 Jan 2024 17:12:50 -0800 Subject: [PATCH 132/178] IdModel: merge c8c21fd087737dd8fb863867e5a012116a69a731 (#1688) --- CMakeLists.txt | 2 + benchmark/matmul.cpp | 6 +- csrc/alias_analysis.cpp | 24 +-- csrc/codegen.cpp | 10 +- csrc/device_lower/lower2device.cpp | 5 + csrc/device_lower/pass/alias_memory.cpp | 4 +- csrc/device_lower/pass/grid_serialization.cpp | 182 ++++++++++++++++++ csrc/device_lower/pass/grid_serialization.h | 26 +++ csrc/device_lower/pass/index.cpp | 108 +++++++++++ csrc/device_lower/pass/index.h | 5 + csrc/device_lower/pass/insert_syncs.cpp | 30 +-- csrc/device_lower/pass/replace_size.cpp | 11 ++ csrc/device_lower/utils.cpp | 22 +++ csrc/device_lower/utils.h | 8 +- csrc/device_lower/validation.cpp | 24 +++ csrc/device_lower/validation.h | 3 + csrc/dynamic_transform.cpp | 90 +++++++++ csrc/executor_utils.cpp | 9 +- csrc/fusion_segmenter.cpp | 2 + csrc/id_model/id_model.cpp | 5 +- csrc/id_model/id_model.h | 14 +- csrc/ir/interface_nodes.h | 4 - csrc/ir/internal_nodes.h | 16 ++ csrc/ir/nodes.cpp | 1 + csrc/kernel_ir.cpp | 2 +- csrc/kernel_ir.h | 2 +- csrc/multidevice/utils.cpp | 8 +- csrc/ops/arith.cpp | 28 +-- .../exact_mapped_extent_substitution.cpp | 100 ++++++++++ .../exact_mapped_extent_substitution.h | 27 +++ csrc/optimization/mark_aliases_prepare.cpp | 11 +- csrc/optimization/pre_segmenter.cpp | 2 + csrc/scheduler/matmul.cpp | 2 + csrc/scheduler/mma_utils.cpp | 3 +- csrc/scheduler/registry_utils.cpp | 7 - csrc/tensor_view.cpp | 9 - .../test_dropout_layernorm_bwd.py | 146 ++++++++++++++ .../test_dropout_layernorm_fwd.py | 102 ++++++++++ python_benchmarks/test_dropout_rmsnorm_bwd.py | 132 +++++++++++++ python_benchmarks/test_dropout_rmsnorm_fwd.py | 104 ++++++++++ .../test_huggingface_attn_bwd.py | 1 - .../test_huggingface_attn_fwd.py | 1 - python_benchmarks/test_nanogpt_attn_bwd.py | 1 - python_benchmarks/test_nanogpt_attn_fwd.py | 1 - test/test_alias.cpp | 71 +++++++ test/test_dynamic_transform.cpp | 75 ++++++++ test/test_gpu3.cpp | 42 ++++ test/test_pipeline.cpp | 90 +++++---- test/test_serial_gridreduce.cpp | 93 +++++++++ 49 files changed, 1527 insertions(+), 144 deletions(-) create mode 100644 csrc/device_lower/pass/grid_serialization.cpp create mode 100644 csrc/device_lower/pass/grid_serialization.h create mode 100644 csrc/optimization/exact_mapped_extent_substitution.cpp create mode 100644 csrc/optimization/exact_mapped_extent_substitution.h create mode 100644 python_benchmarks/test_dropout_layernorm_bwd.py create mode 100644 python_benchmarks/test_dropout_layernorm_fwd.py create mode 100644 python_benchmarks/test_dropout_rmsnorm_bwd.py create mode 100644 python_benchmarks/test_dropout_rmsnorm_fwd.py diff --git a/CMakeLists.txt b/CMakeLists.txt index dbf692d0850..403232a8343 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -117,6 +117,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/device_lower/pass/double_buffer.cpp ${NVFUSER_SRCS_DIR}/device_lower/pass/expr_sort.cpp ${NVFUSER_SRCS_DIR}/device_lower/pass/fusion_simplifier.cpp + ${NVFUSER_SRCS_DIR}/device_lower/pass/grid_serialization.cpp ${NVFUSER_SRCS_DIR}/device_lower/pass/index.cpp ${NVFUSER_SRCS_DIR}/device_lower/pass/scalar_hoist.cpp ${NVFUSER_SRCS_DIR}/device_lower/pass/insert_syncs.cpp @@ -198,6 +199,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp ${NVFUSER_SRCS_DIR}/optimization/add_axioms.cpp ${NVFUSER_SRCS_DIR}/optimization/consecutive_cast.cpp + ${NVFUSER_SRCS_DIR}/optimization/exact_mapped_extent_substitution.cpp ${NVFUSER_SRCS_DIR}/optimization/mark_aliases_prepare.cpp ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp diff --git a/benchmark/matmul.cpp b/benchmark/matmul.cpp index 3a4fb8ae16a..4a1cc265f85 100644 --- a/benchmark/matmul.cpp +++ b/benchmark/matmul.cpp @@ -685,11 +685,7 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { ForAllLayouts(EagerModeBenchmark); ForAllLayouts(NvfuserMatmulBenchmark); -// Disable split-K benchmarks due to slow compilation. -// See https://github.com/NVIDIA/Fuser/issues/1389. -// These benchmarks should be enabled again after merging -// https://github.com/NVIDIA/Fuser/pull/1510 -// ForAllLayouts(AutoSplitKBenchmark); +ForAllLayouts(AutoSplitKBenchmark); ForAllLayouts(AutoPartitionedKBenchmark); // Note: SplitK Reduction benchmarks are parametrized only by M, N. The splitk diff --git a/csrc/alias_analysis.cpp b/csrc/alias_analysis.cpp index 9bb37bc326e..5c42d07a81c 100644 --- a/csrc/alias_analysis.cpp +++ b/csrc/alias_analysis.cpp @@ -173,6 +173,14 @@ void AliasFinder::handle(const ViewOp* view) { LinkedHashMap> allocation_to_contiguity; for (const auto i : c10::irange(out_root_layout->size())) { + if (!out_root_layout->contiguity[i].has_value() && + !out_root_layout->allocation_domain[i]->isBroadcast()) { + // TODO(#1126): Due to #1126, `out_root` materializes an expanded + // broadcast IterDomain from `in_rfactor` when `view` splits or merges + // that IterDomain. We return no alias when this happen; otherwise + // AliasTest.MergeBroadcastsBetweenConcretes would fail. + return; + } allocation_to_contiguity.pushBack( out_root_layout->allocation_domain[i], out_root_layout->contiguity[i]); } @@ -181,18 +189,6 @@ void AliasFinder::handle(const ViewOp* view) { // `allocation_to_contiguity`. Stop when an `Expr` requires a data copy; // otherwise generate the allocation order of `out_rfactor` and the // corresponding contiguity flags. - std::unordered_map out_root_to_in_rfactor = - PairwiseRootDomainMap(in, out).mapConsumerToProducer(); - auto has_expanded_extent = [&out_root_to_in_rfactor](IterDomain* id) -> bool { - // TODO(#1174): Preserve expanded extents in `out_root` so we don't have to - // look for expanded extents in `in_rfactor`. - if (const auto i = out_root_to_in_rfactor.find(id); - i != out_root_to_in_rfactor.end()) { - id = i->second; - } - return id->hasExpandedExtent(); - }; - const std::vector& out_root = out->getRootDomain(); const std::vector& out_rfactor = out->getMaybeRFactorDomain(); for (Expr* transform : DependencyCheck::getAllExprsBetween( @@ -217,9 +213,9 @@ void AliasFinder::handle(const ViewOp* view) { const auto [inner_contiguity, merge_i] = allocation_to_contiguity.erase(merge->inner()); const auto [mergeable, contiguity] = mergeContiguity( - has_expanded_extent(merge->outer()), + merge->outer()->hasExpandedExtent(), outer_contiguity, - has_expanded_extent(merge->inner()), + merge->inner()->hasExpandedExtent(), inner_contiguity); if (!mergeable) { return; diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 30d92442b19..9b68cf4fdb5 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1604,17 +1604,17 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { const auto data_type = grop->out()->dtype(); const auto op_type = grop->getReductionOpType(); + if (grop->isSerial()) { + generateSerialGridReduction(grop); + return; + } + NVF_ERROR(grop->reduction_buffer()->buffer()->isA()); NVF_ERROR(grop->sync_buffer()->buffer()->isA()); const auto work_buffer = grop->reduction_buffer()->buffer()->as(); const auto sync_buffer = grop->sync_buffer()->buffer()->as(); - if (grop->isSerial()) { - generateSerialGridReduction(grop); - return; - } - if (grop->isAllreduce()) { generateGridAllreduce(grop); return; diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 3fc0cea3e23..cb1f65f958f 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -264,6 +265,7 @@ GpuLower::GpuLower(Fusion* fusion, const CompileParams& cparams) // const std::vector& and return a std::vector. {{"LoopNestGenerator", LoopNestGenerator::loweredExprs}, {"loadStoreOpInserter", loadStoreOpInserter}, + {"insertGridSerializationSyncs", insertGridSerializationSyncs}, {"insertAllocations", insertAllocations}, {"insertRawThreadSynchronization", insertRawThreadSynchronization}, {"reuseMemoryAllocations", reuseMemoryAllocations}, @@ -422,6 +424,9 @@ void GpuLower::analysis(Fusion* fusion) { validateResize(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "validateResize"); + validateReductions(fusion_); + dumpExprsIfEnabled(fusion_->exprs(), "validateReductions"); + // Compute thread predicates. Depends on parallel_dimension_map_ thread_pred_map_.build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build thread_pred_map_"); diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index 9ad459ac28f..cb1a2a880c3 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -876,10 +876,10 @@ class AllocationInfoMap : private kir::IrVisitor { alloc_info->outer_live_interval->markWrite(write_pos); } else if (auto inval = dynamic_cast(expr)) { auto alloc_info = getAllocInfoFromTV(inval->mbarrier()->as()); - alloc_info->inner_live_interval->markWrite(expr_pos); + alloc_info->inner_live_interval->markRead(expr_pos); auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info); auto write_pos = outer_loop_info ? outer_loop_info->start_pos : expr_pos; - alloc_info->outer_live_interval->markWrite(write_pos); + alloc_info->outer_live_interval->markRead(write_pos); } } diff --git a/csrc/device_lower/pass/grid_serialization.cpp b/csrc/device_lower/pass/grid_serialization.cpp new file mode 100644 index 00000000000..6508b030788 --- /dev/null +++ b/csrc/device_lower/pass/grid_serialization.cpp @@ -0,0 +1,182 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace nvfuser { + +namespace { + +//! Insert needed syncs in order to serialize blocks for serial grid reduction. +//! +//! We inspect the loop nest up to the point after which all outer loops are +//! trivial. This corresponds to the outer-most loop in the generated +//! kernel code. +//! +//! Conditions for using serial grid reduction: +//! - All reduction dimensions are either unparallelized or parallelized as +//! BID, not TID. Block and warp reductions could be allowed in the future, +//! but the current focus is on cases where all threads are doing separate +//! reductions simultaneously. +//! - rop is not an allreduce. Note that we could implement serial allreduce +//! but it would require inserting a separate grid sync after this outer +//! loop. +//! - There are no other reductions in this loop nest that are TID or BID +//! parallelized, unless they also satisfy the conditions above and their +//! reduction pattern matches this one. Otherwise our syncs will be +//! mismatched, and there is no good way to handle that yet. +class GridSerializationSyncInserter : kir::ExprMutator { + public: + GridSerializationSyncInserter(const std::vector& exprs) { + kir::ExprMutator::traverseAndInsert(exprs); + } + + static std::vector insert(const std::vector& exprs) { + GridSerializationSyncInserter inserter(exprs); + return inserter.exprs_; + } + + private: + using kir::ExprMutator::dispatch; + using kir::ExprMutator::handle; + + //! Record cur_expr_sync_pattern_ if this is a serial grid reduction + void handle(ReductionOp* rop) override { + if (rop->serialGridReductionRequested()) { + ParallelTypeBitmap sync_pattern; + auto out = rop->out()->as(); + NVF_ERROR(out != nullptr); + for (int i : c10::irange((int)out->nDims())) { + IterDomain* ax = out->axis(i); + if (!ax->isReduction()) { + continue; + } + NVF_ERROR( + !ax->isThreadDim(), + "Serial grid reduction cannot be applied with block reductions: ", + rop->toString()); + if (ax->isBlockDim()) { + sync_pattern.set(ax->getParallelType()); + } + } + + if (!sync_pattern.hasBID()) { + // Don't set cur_expr_sync_pattern_ since this is not actually a grid + // reduction + return; + } + + if (cur_expr_sync_pattern_.has_value()) { + NVF_ERROR( + cur_expr_sync_pattern_.value() == sync_pattern, + "Reduction op ", + rop->toString(), + " has requested serial grid reduction, but pattern ", + sync_pattern.toString(), + " conflicts with previous pattern: ", + cur_expr_sync_pattern_.value().toString()); + } else { + cur_expr_sync_pattern_ = sync_pattern; + } + } + } + + void dispatch(Expr* expr) override { + // We will detect top-level exprs here that require serialization and + // insert the required syncs before and after those exprs. + if (auto loop = dynamic_cast(expr); + cur_top_level_expr_ != nullptr || (loop && loop->isTrivial())) { + // Never sync around trivial loops since they do not appear in the + // generated CUDA code. Also avoid redefining cur_top_level_expr_ if it + // is already set, which indicates that this expression is contained in + // an outer non-trivial loop. + kir::ExprMutator::dispatch(expr); + return; + } + // Any other expr, i.e. non-trivial loops or regular Exprs, can be synced if + // it is top-level and either is or contains a serial grid reduction + cur_top_level_expr_ = expr; + // If a serial grid reduction was found when traversing expr, then + // cur_expr_sync_pattern_ will be set + cur_expr_sync_pattern_ = std::nullopt; + kir::ExprMutator::dispatch(expr); + if (cur_expr_sync_pattern_.has_value()) { + insertSyncs(); + } + // reset state variables + cur_top_level_expr_ = nullptr; + cur_expr_sync_pattern_ = std::nullopt; + } + + void insertSyncs() { + NVF_ERROR(cur_top_level_expr_ != nullptr); + NVF_ERROR(cur_expr_sync_pattern_.has_value()); + kir::Allocate* alloc = lower_utils::allocGlobalBufferForGridComm( + lower_utils::getGridSyncBufferSize(cur_expr_sync_pattern_.value()), + DataType::Int, + true); + auto wait = IrBuilder::create( + cur_expr_sync_pattern_.value(), alloc->buffer()); + registerInsertBefore(cur_top_level_expr_, alloc); + registerInsertBefore(cur_top_level_expr_, wait); + auto release = IrBuilder::create( + cur_expr_sync_pattern_.value(), alloc->buffer()); + registerInsertAfter(cur_top_level_expr_, release); + } + + private: + //! Which Expr* is the current top-level containing the current Expr in the + //! generated kernel. When serial reductions are encountered, this expression + //! determines where we will place syncs: they will be placed before and after + //! this expression. + //! + //! For example, if we have + //! + //! FOR iBlockIdx.x + //! FOR iS{32} + //! y = neg(x); + //! ENDFOR iS{32} + //! FOR iThreadIdx.x + //! z = add(y, x); + //! ENDFOR iThreadIdx.x + //! ENDFOR iBlockIdx.x + //! + //! then when we are processing the `neg` Expr, cur_top_level_expr_ will be + //! the FOR iS{32} loop. However, when processing the `add` expression, + //! cur_top_level_expr_ will be nullptr since that expression itself will + //! appear in the main scope of the generated kernel. + //! + //! IfThenElse are treated similar to unparallelized ForLoops; if an + //! IfThenElse is at the top level, or is contained in a fully parallelized + //! loop nest, it is treated as a top level expr here. Note that this pass + //! will likely be run before any IfThenElse are placed in the kernel anyway. + Expr* cur_top_level_expr_ = nullptr; + + //! If a serial grid reduction is found for the current expr, this indicates + //! parallel axes that are mapped to reduction domains in the serial + //! reduction. + std::optional cur_expr_sync_pattern_ = std::nullopt; +}; + +} // namespace + +std::vector insertGridSerializationSyncs( + const std::vector& exprs) { + FUSER_PERF_SCOPE("GpuLower::Lower::insertGridSerializationSyncs"); + return GridSerializationSyncInserter::insert(exprs); +} + +} // namespace nvfuser diff --git a/csrc/device_lower/pass/grid_serialization.h b/csrc/device_lower/pass/grid_serialization.h new file mode 100644 index 00000000000..725ce969848 --- /dev/null +++ b/csrc/device_lower/pass/grid_serialization.h @@ -0,0 +1,26 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +#include +#include + +#include + +namespace nvfuser { + +//! Detect ReductionOps that have serialGridReductionRequested() == true. When +//! found, confirm that no conflicting operations exist, then place sync nodes +//! before and after outer-most non-parallelized loop. +std::vector insertGridSerializationSyncs( + const std::vector& exprs); + +} // namespace nvfuser diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index b976bfa1c16..cc5dcef5424 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include @@ -611,10 +612,107 @@ void IndexLowering::handleBlockReduction( GpuLower::current()->propagateExprInfo(rop, back()); } +void IndexLowering::handleSerialGridReduction( + const ReductionOp* rop, + Val* out, + Val* in) { + const auto out_tv = out->as()->view(); + const auto out_domain = out_tv->domain(); + + // If we do a grid reduction we can't have a reduction axis that is not bound + // to a grid or block dim. + NVF_ERROR( + std::none_of( + out_domain->leaf().begin(), + out_domain->leaf().end(), + [](IterDomain* id) { + return !id->isThread() && id->isReduction() && + !id->extent()->isOneInt(); + }), + "Found a reduction stage that has both a non-parallelized ", + "reduction and a grid reduction. This is not supported, ", + "please use rfactor to do the serialized reduction first, ", + "then the grid reduction. ", + rop->toString()); + + NVF_ERROR(!rop->isAllreduce(), "Serial grid allReduce is not implemented"); + + // Allocate global work buffer TensorIndex. + // + // For convenience, the global work buffer is allocated like the leaf domain + // of the ReductionOp output. In the future, we may want the allocation + // domain to be different in order to enable re-use of global output buffers + // for in-place reduction. + std::vector work_buffer_root; + work_buffer_root.reserve(out_tv->nDims()); + for (IterDomain* id : out_tv->getLeafDomain()) { + work_buffer_root.push_back(IterDomainBuilder(id).build()); + } + auto work_buffer_domain = IrBuilder::create(work_buffer_root); + auto work_buffer_tv = IrBuilder::create( + work_buffer_domain, out_tv->dtype(), MemoryType::Global); + Val* work_buffer_idx_val = nullptr; + for (auto v : + Index::getGlobalConsumerStridedIndices(out_tv, for_loops_, {})) { + work_buffer_idx_val = SimplifyingIrBuilder::addExpr(work_buffer_idx_val, v); + } + + auto work_buffer_idx = IrBuilder::create( + work_buffer_tv, + GpuLower::current()->commonScalarMap().hoistScalar( + work_buffer_idx_val, for_loops_)); + + auto work_alloc = IrBuilder::create( + work_buffer_tv, work_buffer_tv->getMemoryType()); + pushBack(work_alloc); + + // The thread predicate for GridReduction needs to be set + // separately from the main predicate. Do not combine them like + // other expressions. + const auto& thread_pred = + GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv); + + auto serial_grid_reduction = IrBuilder::create( + rop->getReductionOpType(), + rop->init(), + out, + in, + // skip work_buffer, sync_buffer, entrance_ind, n_entrances for serial + // reduction node + nullptr, + nullptr, + nullptr, + nullptr, + false, + work_buffer_idx); + + serial_grid_reduction = + serial_grid_reduction->withThreadPredicate(thread_pred); + + if (rop->predicate()) { + serial_grid_reduction = + serial_grid_reduction->withPredicate(rop->predicate()) + ->as(); + } + if (rop->writePredicate()) { + serial_grid_reduction = + serial_grid_reduction->withWritePredicate(rop->writePredicate()) + ->as(); + } + + pushBack(serial_grid_reduction); + GpuLower::current()->propagateExprInfo(rop, back()); +} + void IndexLowering::handleGridReduction( const ReductionOp* rop, Val* out, Val* in) { + if (rop->serialGridReductionRequested()) { + handleSerialGridReduction(rop, out, in); + return; + } + const auto out_tv = out->as()->view(); const auto out_domain = out_tv->domain(); @@ -1630,6 +1728,16 @@ void IndexLowering::handle(const kir::AsyncCommit* commit) { pushBack(const_cast(commit)); // NOLINT } +void IndexLowering::handle(const kir::BlockSerializeWait* sync) { + // TODO(kir): remove the need for const_cast + pushBack(const_cast(sync)); // NOLINT +} + +void IndexLowering::handle(const kir::BlockSerializeRelease* sync) { + // TODO(kir): remove the need for const_cast + pushBack(const_cast(sync)); // NOLINT +} + void IndexLowering::generate(const std::vector& exprs) { for (auto expr : exprs) { OptOutConstDispatch::dispatch(expr); diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index f8e5b0fddbc..86f18bf64b7 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -82,6 +82,8 @@ class IndexLowering : private OptOutConstDispatch { void handle(const kir::MBarrierInvalidate*) final; void handle(const kir::AsyncWait*) final; void handle(const kir::AsyncCommit*) final; + void handle(const kir::BlockSerializeWait*) final; + void handle(const kir::BlockSerializeRelease*) final; void generate(const std::vector& exprs); @@ -120,6 +122,9 @@ class IndexLowering : private OptOutConstDispatch { void handleBlockReduction(const ReductionOp* rop, Val* out, Val* in); void handleGridReduction(const ReductionOp* rop, Val* out, Val* in); + //! Called by handleGridReduction, this returns true if rop is lowered as a + //! serial grid reduction. + void handleSerialGridReduction(const ReductionOp* rop, Val* out, Val* in); void handleBlockReduction( const GroupedReductionOp* rop, diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 5dc28e41ef5..866b73d8771 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -369,32 +369,6 @@ class ValidatePlacementAfterWrites : private kir::IrVisitor { const std::unordered_set& writes_; }; -namespace { - -Val* getGridSyncBufferSize(const ParallelTypeBitmap& ptb) { - // See the comment above for getGridCommWorkBufferSize. - NVF_ERROR( - ptb.hasBID(), - "Detected needing a grid sync but no grid bits set in bitmap."); - Val* buffer_size = GpuLower::current()->kernel()->oneVal(); - for (auto pt : kParallelTypeBIDs) { - // Synchronized within pt, so all blocks of this PT use the same - // sync buffer location, and thus no need to expand the sync - // buffer size. - if (ptb.get(pt)) { - continue; - } - auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); - if (pt_dim == nullptr || pt_dim->isOneInt()) { - continue; - } - buffer_size = IrBuilder::mulExpr(buffer_size, pt_dim); - } - return buffer_size; -} - -} // namespace - class ReadAfterWriteSyncs : public kir::ExprMutator { private: using kir::ExprMutator::handle; @@ -489,7 +463,9 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { kir::Allocate* maybe_alloc = nullptr; if (sync_bitmap.hasBID()) { maybe_alloc = lower_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize(sync_bitmap), DataType::Int, true); + lower_utils::getGridSyncBufferSize(sync_bitmap), + DataType::Int, + true); sync_expr = IrBuilder::create( sync_bitmap, maybe_alloc->buffer()); } else { diff --git a/csrc/device_lower/pass/replace_size.cpp b/csrc/device_lower/pass/replace_size.cpp index 405b6859545..cc4fd3302b2 100644 --- a/csrc/device_lower/pass/replace_size.cpp +++ b/csrc/device_lower/pass/replace_size.cpp @@ -186,6 +186,17 @@ void replaceSymbolicSizes(Fusion* fusion) { } } + // After ExactMappedExtentSubstitutionPass, different inputs and outputs may + // have same root domain extents e.g. T1[{i0}, {i2}], T2[{i2}]. When maping + // {i2}, we want to use the lower labeled tensor size "T1.size[1]", instead of + // "T2.size[0]". + std::sort( + inputs_and_outputs.begin(), + inputs_and_outputs.end(), + [](const TensorView* a, const TensorView* b) { + return a->name() < b->name(); + }); + // Generate map for all tensorview root domain values to map them to symbolic // values. i.e. T0->getRootDomain()[0] would map to a named scalar // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir. diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index b76c769af9d..140e2ddb0a4 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -830,6 +830,28 @@ Val* u32IndexScalarSmemTv(TensorView* smem_tv) { return u32addr; } +Val* getGridSyncBufferSize(const ParallelTypeBitmap& ptb) { + // See the comment above for getGridCommWorkBufferSize. + NVF_ERROR( + ptb.hasBID(), + "Detected needing a grid sync but no grid bits set in bitmap."); + Val* buffer_size = GpuLower::current()->kernel()->oneVal(); + for (auto pt : kParallelTypeBIDs) { + // Synchronized within pt, so all blocks of this PT use the same + // sync buffer location, and thus no need to expand the sync + // buffer size. + if (ptb.get(pt)) { + continue; + } + auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); + if (pt_dim == nullptr || pt_dim->isOneInt()) { + continue; + } + buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim); + } + return buffer_size; +} + } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index 39596fb77d2..749a6fbd4c5 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -298,10 +298,14 @@ bool isScalarExpr(Expr* expr); //! IterDomain object. bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id); -// Get the uint32_t index of a scalar TensorView. This is usually used for -// indexing special items in shared memory, like mbarrier. +//! Get the uint32_t index of a scalar TensorView. This is usually used for +//! indexing special items in shared memory, like mbarrier. Val* u32IndexScalarSmemTv(TensorView* tv); +//! Get the size of a global sync buffer needed to perform a grid reduction for +//! each axis in bitmap. +Val* getGridSyncBufferSize(const ParallelTypeBitmap& bitmap); + } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 9cb4cc1200a..d95aaeae4a8 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -1290,4 +1290,28 @@ void validateResize(Fusion* fusion) { } } +void validateReductions(Fusion* fusion) { + for (auto rop : ir_utils::getOpsOfType(fusion)) { + auto in = rop->in()->as(); + auto out = rop->out()->as(); + PairwiseRootDomainMap c2p_map(in, out); + c2p_map.mapBroadcast(true); + auto c2p = c2p_map.mapConsumerToProducer(); + for (auto out_id : out->getRootDomain()) { + if (out_id->isReduction()) { + auto in_it = c2p.find(out_id); + NVF_ERROR( + in_it != c2p.end(), + "Could not find producer IterDomain mapped to ", + out_id->toString()); + IterDomain* in_id = in_it->second; + NVF_ERROR( + !in_id->isBroadcast() || in_id->hasExpandedExtent(), + "Reductions of unexpanded broadcast domains should be ", + "converted to squeeze before lowering."); + } + } + } +} + } // namespace nvfuser diff --git a/csrc/device_lower/validation.h b/csrc/device_lower/validation.h index 643f924ec21..e9b068bf1d2 100644 --- a/csrc/device_lower/validation.h +++ b/csrc/device_lower/validation.h @@ -80,4 +80,7 @@ void validateLookupTV(Fusion* fusion); //! Validate resize usage void validateResize(Fusion* fusion); +//! Check that there are no reductions over unexpanded broadcasts +void validateReductions(Fusion* fusion); + } // namespace nvfuser diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 542227380ec..b23e5a781ef 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -539,6 +540,8 @@ class DynamicTransformConcretizer : public OptOutMutator { void mutate(TensorDomain* td) final; + void mutate(Expr* expr) final; + //! Concretizes the root domain of a symbolic consumer tensor from //! its producer domains. Returns true if any root ID is concretized. bool propagateFromProducerToConsumer(TensorView* consumer); @@ -597,6 +600,13 @@ void DynamicTransformConcretizer::concretize() { } OptOutMutator::dispatchMutate(stmt); } + + for (Val* outp : info_->fusion()->outputs()) { + Val* new_outp = maybeMutated(outp); + if (new_outp != outp) { + info_->fusion()->replaceOutput(outp, new_outp); + } + } } void DynamicTransformConcretizer::concretizeEmptyExtents() { @@ -943,6 +953,86 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { registerConcretization(td, mutated_val); } +//! Returns whether a reduction has any trivial partial reductions. Modifies +//! reduction_axes in place to insert indices of non-trivial reduction axes, +//! relative to squeezed input. +static bool hasTrivialReduction( + TensorView* in, + TensorView* out, + std::vector& reduction_axes) { + bool has_trivial_reduction = false; + PairwiseRootDomainMap p2c_map(in, out); + // We need to map broadcasts in order to detect reductions of broadcasts + p2c_map.mapBroadcast(true); + auto p2c = p2c_map.mapProducerToConsumer(); + int pos = -1; + for (IterDomain* in_id : + TensorDomain::noReductions(in->getMaybeRFactorDomain())) { + ++pos; + auto out_it = p2c.find(in_id); + if (out_it == p2c.end()) { + continue; + } + IterDomain* out_id = out_it->second; + if (out_id->isReduction()) { + reduction_axes.push_back(pos); + if (in_id->isBroadcast() && !in_id->hasExpandedExtent()) { + has_trivial_reduction = true; + } + } + } + return has_trivial_reduction; +} + +// Maybe insert SqueezeOps on inputs of ReductionOp, to simplify trivial +// reductions. +void DynamicTransformConcretizer::mutate(Expr* expr) { + if (ReductionOp* rop = dynamic_cast(expr); rop) { + auto* in = rop->in()->as(); + auto* orig_out = rop->out()->as(); + std::vector reduction_axes; + if (hasTrivialReduction(in, orig_out, reduction_axes)) { + // There is at least one trivial reduction that should be squeezed. Use + // binaryOp to ensure this is done exactly as it is in a non-dynamic + // fusion + // + // Note that keepdim=false always here, since that results in downstream + // broadcasts which will already have been inserted. + TensorView* new_out = reductionOp( + rop->getReductionOpType(), + reduction_axes, + rop->init(), + in, + /*keep_dim=*/false, + orig_out->dtype()); + registerConcretization(orig_out, new_out); + } + } else if (WelfordOp* wop = dynamic_cast(expr); wop) { + auto in = wop->in()->as(); + auto orig_avg = wop->outAvg()->as(); + + std::vector reduction_axes; + if (hasTrivialReduction(in, orig_avg, reduction_axes)) { + // Use Welford to ensure this is done exactly as it is in a non-dynamic + // fusion + WelfordResult new_result = Welford( + in, + reduction_axes, + // For avg and variance to be default initialized, they should be + // given as nullptr. In that case, this constructor actually sets them + // as a scalar 0. Here we use that to detect whether they are + // initialized or not. + dynamic_cast(wop->initAvg()), + dynamic_cast(wop->initVar()), + wop->initN()); + registerConcretization(orig_avg, new_result.avg); + registerConcretization(wop->outVar(), new_result.var_sum); + registerConcretization(wop->outN(), new_result.n); + } + } + OptOutMutator::mutate(expr); +} + bool DynamicTransformConcretizer::propagateFromProducerToConsumer( TensorView* consumer) { if (consumer->definition() == nullptr || diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index f1da61ea206..a8df8b1f70c 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -843,7 +843,14 @@ class NvrtcCompileDriver { char* log_buf = log_backing_buf.data(); NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(program, log_buf)); if (result != NVRTC_SUCCESS) { - NVF_ERROR(false, src, "\nCUDA NVRTC compile error: ", log_buf); + // Print CUDA starting at first global function + size_t kernel_start = src.find("__global__"); + NVF_ERROR( + false, + "\n", + src.substr(kernel_start), + "\nCUDA NVRTC compile error: ", + log_buf); } if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) { debug() << log_buf << std::endl; diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 870ce4fe15c..444f6772dd8 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1877,6 +1877,8 @@ std::unique_ptr SegmentCandidateFinder::segment( } } if (fusion) { + scheduler_debug_utils::canScheduleMessage( + "***Runtime***: Has segment hints, try to schedule fusion segmented:\n"); return SegmentCandidateFinder::segment(std::move(fusion), inputs); } else { NVF_ERROR(false, "unreachable!"); diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index dd732fcd64a..1df1972d228 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -1010,7 +1010,6 @@ void IdModel::buildLoopGraph() { VERBOSE() << ss.str(); } - // Gather broadcast resolution and inlining information const StatefulInliningInfo inlining_info = buildStatefulInliningInfo( tv_exprs_, idGraph(IdMappingMode::EXACT), @@ -1249,8 +1248,8 @@ void IdModel::buildAllGraphs() { validator->checkExactGraphEquivalence(idGraph(IdMappingMode::EXACT)); } - // Make sure there's no self mapping in TensorView's during lowering - // that would invalidate lowering assumptions. + // Make sure there's no self mapping in the Exact graph as that + // would invalidate lowering assumptions. self_mapping_info_ = findFirstSelfMapping(tvs_, *this); if (!allow_self_mapping_) { assertNoSelfMapping(); diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 4bc243cb7ff..05b2137ae04 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -122,7 +122,8 @@ class IdModel : public PolymorphicBase { bool allow_self_mapping = false, bool validate = true); - // Returns iter domain graph of provided mode. + // Returns iter domain graph of provided mode. The graph must have + // been already built. const ValGraph& idGraph(IdMappingMode mode) const; ValGraph& idGraph(IdMappingMode mode); @@ -142,7 +143,8 @@ class IdModel : public PolymorphicBase { std::string toString() const; - // Build all graphs. This is by default called from the constructor + // Build all graphs, i.e., Exact, AlmostExact, Permissive and + // LOOP. This is by default called from the constructor void buildAllGraphs(); // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs @@ -328,9 +330,17 @@ class IdModel : public PolymorphicBase { IterDomain* cloneIterDomain(IterDomain* id); protected: + // All tensor expressions that this model analyzes std::vector tv_exprs_; + + // All tensors that this model analyzes std::vector tvs_; + + // Tensors should not have domains that are mapped with another + // domains of the same tensor. This flag disables the check bool allow_self_mapping_ = false; + + // If true, validate graphs by comparing them with ComputeAtMap bool validate_ = false; // Keeps ValGraphs containing all IterDomains for all mapping mode types. diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index ab9d607a7d2..54cd2b3175b 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -109,10 +109,6 @@ class TensorView : public Val { IrBuilderPasskey passkey, const std::shared_ptr& tensor_type); - explicit TensorView( - IrBuilderPasskey passkey, - const std::shared_ptr& jit_value); - TensorView(const TensorView* src, IrCloner* ir_cloner); NVFUSER_DECLARE_CLONE diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index b360b392493..151fa036b99 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -913,6 +913,22 @@ class ReductionOp : public Expr { bool isAllreduce() const { return attribute(2); } + + //! Scheduling method to request that this reduction be performed as a + //! serial grid reduction. Note that it is an error to use this method on a + //! reduction whose output has any of its reduction axes parallelized with a + //! threadIdx, even if that parallelization occurs after this method call. + //! + //! Also note that this operation should not be inlined with other reductions + //! unless they use the same parallelization pattern and they are also serial + //! gridreductions. + void requestSerialGridReduction(bool value = true) { + attribute(3) = value; + } + + bool serialGridReductionRequested() const { + return attribute(3); + } }; //! Grouped reduction operation for horizontal fusions. It works like diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 165a756f18e..e9065474985 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -1443,6 +1443,7 @@ ReductionOp::ReductionOp( addAttribute(init); addDataAttribute(reduction_op_type); addDataAttribute(is_allreduce); + addDataAttribute(false); // serial reduction } std::string ReductionOp::toString(int indent_size) const { diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index d317c988760..dddd5b4d9be 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -149,7 +149,7 @@ Allocate::Allocate( NVF_ERROR(buffer->isA()); NVF_ERROR(buffer->as()->getMemoryType() == memory_type); const auto domain = buffer->as()->domain(); - for (auto axis : domain->noReductions()) { + for (auto axis : TensorDomain::noReductions(domain->maybeAllocation())) { shape.push_back(axis->extent()); } } diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 893d32dace5..b34f9ef1c74 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -948,7 +948,7 @@ class IfThenElse final : public Expr { //! This node provides FusionExecutor the information it needs to allocate the //! reduction and sync buffers. class GridReduction final : public ReductionOp { - static constexpr int num_reduction_op_attr = 3; + static constexpr int num_reduction_op_attr = 4; public: using ReductionOp::ReductionOp; diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index e20b2da1d50..254ee8720b4 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -20,7 +20,8 @@ bool isSharded(TensorView* tv) { for (IterDomain* id : TensorDomain::noReductions(tv->getLeafDomain())) { is_sharded.push_back(id->isDeviceDim()); } - // Currently, only the most external dim is allowed to be sharded + // Currently, only the most external dim is allowed to be sharded and we don't + // allow split/merge NVF_ERROR(tv->getMaybeRFactorDomain() == tv->getLeafDomain()); for (auto i : c10::irange(1, is_sharded.size())) { NVF_ERROR( @@ -48,7 +49,9 @@ std::unordered_set haveDifferentSharding( TensorView* ref, std::unordered_set tvs) { std::unordered_set ret; - + // isSharded asserts that there are no split/merge and that only the outmost + // dimension is possibly sharded + isSharded(ref); const auto& reference_dom = ref->getLeafDomain(); FusionGuard fg(ref->fusion()); auto ca_map = ComputeAtMap(FusionGuard::getCurFusion()); @@ -60,6 +63,7 @@ std::unordered_set haveDifferentSharding( } for (auto tv : tvs) { + isSharded(tv); if (!(ref->getDeviceMesh().vector() == tv->getDeviceMesh().vector())) { ret.insert(tv); continue; diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index e75365969c0..64988264764 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2559,19 +2559,6 @@ TensorView* fusedMultiplySum( TensorView* tv_b, const std::vector& axes, Val* init) { - if (init == nullptr) { - init = IrBuilder::create(0.0); - } - - // TODO: - // We will want to support initialize and rfactor with - // mma as well, for maybe fusing bias in prolog. - // TODO: check init type if given a tv, - // not supported currently though. - NVF_CHECK( - init->isConstScalar(), - "Cannot create a reduction operation where the initial value is not a const scalar."); - // TODO: // Validate axis relationships between a and b NVF_CHECK(tv_a->nDims() > 0, "Tried to reduce a 0-dim tensor"); @@ -2596,6 +2583,21 @@ TensorView* fusedMultiplySum( canonicalizeAxes(axes, tv_a->domain()->noReductions().size()); TensorView* out = newForMma(tv_a, tv_b, uint_axes); + + if (init == nullptr) { + init = IrBuilder::create(0.0, out->dtype()); + } + + // TODO: + // We will want to support initialize and rfactor with + // mma as well, for maybe fusing bias in prolog. + NVF_CHECK( + init->isConstScalar(), + "Cannot create a reduction operation where the initial value is not a const scalar."); + NVF_CHECK( + init->dtype() == out->dtype(), + "Init value dtype for fusedMultiplySum must match output."); + IrBuilder::create(out, tv_a, tv_b, init); return out; diff --git a/csrc/optimization/exact_mapped_extent_substitution.cpp b/csrc/optimization/exact_mapped_extent_substitution.cpp new file mode 100644 index 00000000000..6ef95644b01 --- /dev/null +++ b/csrc/optimization/exact_mapped_extent_substitution.cpp @@ -0,0 +1,100 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include +#include + +namespace nvfuser::optimization { + +namespace { +// Skip broadcast without expanded extent +// Skip derived domains, e.g. iS11{( i0 * i2 )}rf +// Skip domain whose extent is derived e.g. iS12{( i0 * i2 )} +// e.g. in this set { iS11{( i0 * i2 )}rf; iS12{( i0 * i2 )}; iS14{i3} } from +// NVFuserTest.SymbolicSqueeze, we can't substitute {i0 * i2} with {i3}, +// otherwise, ValidateDomainEquivalence fails. If we really want to substitute, +// we may need to skip or modify ValidateDomainEquivalence. +inline bool isNonSubstitutableID(const IterDomain* id) { + return (id->isBroadcast() && !id->hasExpandedExtent()) || id->definition() || + id->getMaybeExpandedExtent()->definition(); +} + +void exactMappedExtentSubstitution(Fusion* fusion) { + // map non-const extents to const extents + std::unordered_map replacement_map; + + const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets(); + // Loop over each exact root domain set + for (const auto& set_ptr : mapped_sets.disjointSets()) { + // (1) pick a const extent + // (2) if no const extent, pick the var with the lowest name() + Val* const_extent = nullptr; + Val* lowest_val = nullptr; + for (auto id : *set_ptr) { + if (isNonSubstitutableID(id)) { + continue; + } + // find the const extent, if already seen, check if they are the same + if (id->getMaybeExpandedExtent()->isConstScalar()) { + if (const_extent) { + NVF_CHECK( + const_extent->sameAs(id->getMaybeExpandedExtent()), + "Found two different const extents in the same set: ", + set_ptr->toString()); + } else { + const_extent = id->getMaybeExpandedExtent(); + } + } + // find the lowest name + if (!lowest_val || + id->getMaybeExpandedExtent()->name() < lowest_val->name()) { + lowest_val = id->getMaybeExpandedExtent(); + } + } + // replace with const extents. + // if no const extents, replace with the one with the lowest name. + for (auto id : *set_ptr) { + if (isNonSubstitutableID(id)) { + continue; + } + replacement_map.emplace( + id->getMaybeExpandedExtent(), + const_extent ? const_extent : lowest_val); + } + } + + // Replace non-const extents with const extents + ir_utils::replaceValue(fusion, replacement_map); +} +} // namespace + +void ExactMappedExtentSubstitutionPass::runPass(Fusion* fusion) { + if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { + debug() << "Fusion before exactMappedExtentSubstitutionPass:" << std::endl; + fusion->printMath(); + debug() << "ExactRootDomainMap before exactMappedExtentSubstitutionPass:" + << std::endl; + const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets(); + debug() << mapped_sets.toString() << std::endl; + } + + exactMappedExtentSubstitution(fusion); + + if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { + debug() << "Fusion after exactMappedExtentSubstitutionPass:" << std::endl; + fusion->printMath(); + debug() << "ExactRootDomainMap after exactMappedExtentSubstitutionPass:" + << std::endl; + const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets(); + debug() << mapped_sets.toString() << std::endl; + } +} + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/exact_mapped_extent_substitution.h b/csrc/optimization/exact_mapped_extent_substitution.h new file mode 100644 index 00000000000..37be6ed8191 --- /dev/null +++ b/csrc/optimization/exact_mapped_extent_substitution.h @@ -0,0 +1,27 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +namespace nvfuser::optimization { + +// Fusion may have tensors with const extents and symbolic extents. This pass +// replaces symbolic extents with const extents if they are mapped to the exact +// same root domain set. See https://github.com/NVIDIA/Fuser/issues/1590. +// Additionaly, if there is no const extent, it replaces all symbolic extents +// with the one with the lowest name. This could simplify some cases where we +// recompute expressions inside the kernel that are known to be equal, even if +// they are not constant. +class ExactMappedExtentSubstitutionPass + : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); +}; + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/mark_aliases_prepare.cpp b/csrc/optimization/mark_aliases_prepare.cpp index 5615f096967..783dd172bf9 100644 --- a/csrc/optimization/mark_aliases_prepare.cpp +++ b/csrc/optimization/mark_aliases_prepare.cpp @@ -23,10 +23,10 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) { debug() << analysis.toString(/*indent_size=*/1) << std::endl; } - // Fusion outputs that are (1) aliased by others, (2) not aliases - // themselves, and (3) not fusion inputs (yes, a fusion may trivially forward - // an input). Code will later add `segment_set` before them so aliases are - // separated from non-aliases and more likely to be accepted by the no-op + // Fusion outputs that are (1) aliased by another fusion output, (2) not + // aliases themselves, and (3) not fusion inputs (yes, a fusion may trivially + // forward an input). Code will later add `segment_set` before them so aliases + // are separated from non-aliases and more likely to be accepted by the no-op // scheduler. std::unordered_set aliased_outs; @@ -36,7 +36,8 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) { continue; } - if (aliased_io->isFusionOutput() && !aliased_io->isFusionInput() && + if (tv->isFusionOutput() && aliased_io->isFusionOutput() && + !aliased_io->isFusionInput() && analysis.getNearestAliasedIo(aliased_io) == nullptr) { aliased_outs.insert(aliased_io); } diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index 9d4c5c5347c..de1a401dd91 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -21,6 +22,7 @@ void PreSegmenter::runPass(Fusion* fusion) { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); } } // namespace nvfuser::optimization diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 366e0821b66..83b310073ff 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -899,6 +899,8 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { splitk_sum = mma_result; mma_result = splitk_sum->rFactor({-4, -1}); + splitk_sum->definition()->as()->requestSerialGridReduction(); + num_splitk_dims = 1; } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 67b44d78b5c..36f8da4a424 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1288,8 +1288,9 @@ RolesMapOpt getTensorsRoles(Fusion* fusion) { namespace { void addMMAOp(Fusion* fusion_, std::vector& props) { - auto* init = IrBuilder::create(0.0); for (auto prop : props) { + auto* init = + IrBuilder::create(0.0, prop.insouts.out->getDataType().value()); IrBuilder::create( prop.insouts.out, prop.insouts.a, prop.insouts.b, init); } diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index 70b7378ae43..aa710545d43 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -229,13 +229,6 @@ bool isConnectedFusionGraph(Fusion* fusion) { } } - // Map aliased outputs - for (Val* out : fusion->outputs()) { - if (Val* in = fusion->getOutputAlias(out).first; in != nullptr) { - component_sets.mapEntries(out, in); - } - } - // Check connected-ness: // If there is no independent compute flow // on this fusion graph, all outputs will be diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index db39b9b5f2a..70ebfb8fdd4 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -173,15 +173,6 @@ TensorView::TensorView( domain_ = IrBuilder::create(sizes, contig_info); } -TensorView::TensorView( - IrBuilderPasskey passkey, - const std::shared_ptr& jit_value) - : TensorView(passkey, jit_value->type()->cast()) { - NVF_ERROR( - !container()->isA(), - "Function invalid for kernel container."); -} - NVFUSER_DEFINE_CLONE(TensorView) std::string TensorView::toString(int indent_size) const { diff --git a/python_benchmarks/test_dropout_layernorm_bwd.py b/python_benchmarks/test_dropout_layernorm_bwd.py new file mode 100644 index 00000000000..5e8737d21fe --- /dev/null +++ b/python_benchmarks/test_dropout_layernorm_bwd.py @@ -0,0 +1,146 @@ +import pytest +from nvfuser import FusionDefinition, DataType +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype +from .core import run_benchmark, clear_cuda_cache +import torch +from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES + + +def dropout_layernorm_bwd_fusion( + fd: FusionDefinition, dtype: DataType, dropout_p: float +) -> None: + """ + Backward pass fusion definition for computing: + output = layernorm (input + dropout (input p=dropout_p)) + + Fusion inputs: input, dropout_mask, rms, grads, weights + Fusion outputs: grad_input, grad_weights, grad_bias + """ + T1 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=DataType.Bool, is_cpu=False + ) # mask + T2 = fd.define_tensor( + shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False + ) # mean + T3 = fd.define_tensor( + shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False + ) # invstd + T4 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) # grads + T5 = fd.define_tensor( + shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False + ) # weights + T6 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) # inputs + if dtype in PROMOTE_DTYPES: + T1 = fd.ops.cast(T1, dtype=DataType.Float) + T4 = fd.ops.cast(T4, dtype=DataType.Float) + T5 = fd.ops.cast(T5, dtype=DataType.Float) + T6 = fd.ops.cast(T6, dtype=DataType.Float) + + T9 = fd.ops.mul(T6, T1) + S10 = fd.define_scalar(1 / (1 - dropout_p), dtype=DataType.Double) + T11 = fd.ops.mul(T9, S10) + T12 = fd.ops.add(T6, T11) + + V15 = fd.define_vector([T6.size(0), 1], dtype=DataType.Int) + T16 = fd.ops.broadcast_in_dim(T2, shape=V15, broadcast_dims=[0]) + V19 = T6.shape() + T20 = fd.ops.broadcast_in_dim(T16, shape=V19, broadcast_dims=[0, 1]) + T21 = fd.ops.sub(T12, T20) + T25 = fd.ops.broadcast_in_dim(T3, shape=V19, broadcast_dims=[0, 1]) + T26 = fd.ops.mul(T21, T25) + T30 = fd.ops.broadcast_in_dim(T5, shape=V19, broadcast_dims=[1]) + T35 = fd.ops.sum(T4, axes=[0], keepdim=False, dtype=DataType.Null) + + T37 = fd.ops.mul(T4, T30) + T38 = fd.ops.mul(T4, T26) + T39 = fd.ops.sum(T38, axes=[0], keepdim=False, dtype=DataType.Null) + + T41 = fd.ops.mul(T37, T25) + T42 = fd.ops.mul(T37, T21) + T43 = fd.ops.sum(T42, axes=[1], keepdim=False, dtype=DataType.Null) + T47 = fd.ops.broadcast_in_dim(T43, shape=V15, broadcast_dims=[0]) + T48 = fd.ops.neg(T41) + T49 = fd.ops.sum(T48, axes=[1], keepdim=False, dtype=DataType.Null) + T53 = fd.ops.broadcast_in_dim(T49, shape=V15, broadcast_dims=[0]) + S54 = fd.define_scalar(-0.500000, dtype=DataType.Double) + T55 = fd.ops.mul(S54, T47) + S56 = fd.define_scalar(3.00000, dtype=DataType.Double) + T57 = fd.ops.pow(T3, S56) + T58 = fd.ops.mul(T55, T57) + T61 = fd.ops.sum(T53, axes=[1], keepdim=False, dtype=DataType.Null) + T62 = fd.ops.sum(T58, axes=[1], keepdim=False, dtype=DataType.Null) + T66 = fd.ops.broadcast_in_dim(T62, shape=V15, broadcast_dims=[0]) + T70 = fd.ops.broadcast_in_dim(T66, shape=V19, broadcast_dims=[0, 1]) + T74 = fd.ops.broadcast_in_dim(T2, shape=V15, broadcast_dims=[0]) + T78 = fd.ops.broadcast_in_dim(T74, shape=V19, broadcast_dims=[0, 1]) + S79 = fd.define_scalar(2.00000, dtype=DataType.Double) + T80 = fd.ops.mul(S79, T70) + T81 = fd.ops.sub(T12, T78) + T82 = fd.ops.mul(T80, T81) + S84 = fd.ops.reciprocal(T6.size(1)) + T85 = fd.ops.mul(T82, S84) + T89 = fd.ops.broadcast_in_dim(T61, shape=V15, broadcast_dims=[0]) + T93 = fd.ops.broadcast_in_dim(T89, shape=V19, broadcast_dims=[0, 1]) + T95 = fd.ops.mul(S84, T93) + T96 = fd.ops.add(T85, T95) + T97 = fd.ops.add(T41, T96) + + T100 = fd.ops.mul(T97, S10) + T101 = fd.ops.mul(T100, T1) + T102 = fd.ops.add(T97, T101) + if dtype in PROMOTE_DTYPES: + T35 = fd.ops.cast(T35, dtype=dtype) + T39 = fd.ops.cast(T39, dtype=dtype) + T102 = fd.ops.cast(T102, dtype=dtype) + fd.add_output(T102) + fd.add_output(T39) + fd.add_output(T35) + + +@pytest.mark.parametrize("size", generate_input_sizes(dims=2)) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_layernorm_bwd_benchmark( + benchmark, + size: tuple, + dtype: torch.dtype, + disable_validation: bool, + disable_benchmarking: bool, + eps: float = 1e-5, +): + clear_cuda_cache() + inputs = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) + grads = torch.randn(*size, device="cuda", dtype=dtype) + weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) + bias = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) + dropout_p = 0.1 + dropout_mask = torch.lt(torch.rand(*size, device="cuda"), 1 - dropout_p) + x = inputs + 1 / (1 - dropout_p) * dropout_mask * inputs + mean = x.to(torch.float).mean(dim=-1) + variance = x.to(torch.float).var(dim=-1, unbiased=False) + invstd = (1.0 / torch.sqrt(variance + eps)).unsqueeze(1) + + with FusionDefinition() as fd: + dropout_layernorm_bwd_fusion( + fd, torch_dtype_to_nvfuser_dtype(dtype), dropout_p=dropout_p + ) + if not disable_validation: + eager_output = torch.nn.functional.layer_norm( + x.to(torch.double), + inputs.shape[1:], + weight=weights.to(torch.double), + bias=bias.to(torch.double), + ) + + eager_output.backward(grads.to(torch.double)) + fd.validate( + [dropout_mask, mean, invstd, grads, weights, inputs], + [inputs.grad, weights.grad, bias.grad], + ) + if not disable_benchmarking: + run_benchmark( + benchmark, fd.execute, [dropout_mask, mean, invstd, grads, weights, inputs] + ) diff --git a/python_benchmarks/test_dropout_layernorm_fwd.py b/python_benchmarks/test_dropout_layernorm_fwd.py new file mode 100644 index 00000000000..4136b74f86e --- /dev/null +++ b/python_benchmarks/test_dropout_layernorm_fwd.py @@ -0,0 +1,102 @@ +import pytest +from nvfuser import FusionDefinition, DataType +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype +from .core import run_benchmark, clear_cuda_cache +import torch +from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES + + +def dropout_layernorm_fwd_fusion( + fd: FusionDefinition, dtype: DataType, dropout_p: float, eps: float = 1e-5 +) -> None: + """ + Forward pass fusion definition for computing: + output = layernorm (input + dropout (input, p=dropout_p)) + + Fusion inputs: input, weights, bias + Fusion outputs: output, mean, invstd, dropout_mask + """ + T2 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) + T1 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False) + T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False) + S3 = fd.define_scalar(0.00000, dtype=DataType.Double) + S4 = fd.define_scalar(1.00000, dtype=DataType.Double) + T8 = fd.ops.uniform(S3, S4, shape=T2.shape(), dtype=DataType.Float) + S9 = fd.define_scalar(1 - dropout_p, dtype=DataType.Double) + T10 = fd.ops.lt(T8, S9) + T11 = fd.ops.cast(T10, dtype=DataType.Float) + if dtype in PROMOTE_DTYPES: + T0 = fd.ops.cast(T0, dtype=DataType.Float) + T1 = fd.ops.cast(T1, dtype=DataType.Float) + T2 = fd.ops.cast(T2, dtype=DataType.Float) + + # Dropout + Add + T13 = fd.ops.mul(T2, T11) + S14 = fd.define_scalar(1 / (1 - dropout_p), dtype=DataType.Double) + T15 = fd.ops.mul(T13, S14) + T16 = fd.ops.add(T2, T15) + # Layernorm + T17, T18 = fd.ops.var_mean(T16, axes=[1], correction=0, keepdim=False) + V21 = fd.define_vector([T2.size(0), 1], dtype=DataType.Int) + T22 = fd.ops.broadcast_in_dim(T17, shape=V21, broadcast_dims=[0]) + T26 = fd.ops.broadcast_in_dim(T18, shape=V21, broadcast_dims=[0]) + S27 = fd.define_scalar(eps, dtype=DataType.Double) + T28 = fd.ops.add(T22, S27) + T29 = fd.ops.rsqrt(T28) + T33 = fd.ops.broadcast_in_dim(T26, shape=T2.shape(), broadcast_dims=[0, 1]) + T34 = fd.ops.sub(T16, T33) + T38 = fd.ops.broadcast_in_dim(T29, shape=T2.shape(), broadcast_dims=[0, 1]) + T39 = fd.ops.mul(T34, T38) + T43 = fd.ops.broadcast_in_dim(T1, shape=T2.shape(), broadcast_dims=[1]) + T45 = fd.ops.mul(T39, T43) + T49 = fd.ops.broadcast_in_dim(T0, shape=T2.shape(), broadcast_dims=[1]) + T51 = fd.ops.add(T45, T49) + if dtype in PROMOTE_DTYPES: + T51 = fd.ops.cast(T51, dtype=dtype) + + fd.add_output(T51) + fd.add_output(T18) + fd.add_output(T29) + fd.add_output(T10) + + +@pytest.mark.parametrize("size", generate_input_sizes(dims=2)) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_dropout_layernorm_fwd_benchmark( + benchmark, + size: tuple, + dtype: torch.dtype, + disable_validation: bool, + disable_benchmarking: bool, + eps: float = 1e-5, +): + clear_cuda_cache() + inputs = [ + torch.randn(*size, device="cuda", dtype=dtype), + torch.ones(size[1], device="cuda", dtype=dtype), + torch.zeros(size[1], device="cuda", dtype=dtype), + ] + # dropout_p = 0.0 in fwd benchmark for validating the dropout mask + dropout_p = 0.0 + dropout_mask = torch.lt(torch.rand(*size, device="cuda"), 1 - dropout_p) + with FusionDefinition() as fd: + dropout_layernorm_fwd_fusion(fd, torch_dtype_to_nvfuser_dtype(dtype), dropout_p) + if not disable_validation: + # dropout + add + x = inputs[0] + 1 / (1 - dropout_p) * dropout_mask * inputs[0] + # layernorm + eager_output = torch.nn.functional.layer_norm( + x.to(torch.float), + inputs[0].shape[1:], + weight=inputs[1].to(torch.float), + bias=inputs[2].to(torch.float), + ) + # mean and invstd are computed for the output of dropout + add + mean = x.to(torch.float).mean(dim=-1) + variance = x.to(torch.float).var(dim=-1, unbiased=False) + invstd = (1.0 / torch.sqrt(variance + eps)).unsqueeze(1) + fd.validate(inputs, [eager_output.to(dtype), mean, invstd, dropout_mask]) + if not disable_benchmarking: + run_benchmark(benchmark, fd.execute, inputs) diff --git a/python_benchmarks/test_dropout_rmsnorm_bwd.py b/python_benchmarks/test_dropout_rmsnorm_bwd.py new file mode 100644 index 00000000000..73d966b46a4 --- /dev/null +++ b/python_benchmarks/test_dropout_rmsnorm_bwd.py @@ -0,0 +1,132 @@ +import pytest +from nvfuser import FusionDefinition, DataType +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype +from .core import run_benchmark, clear_cuda_cache +import torch +from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES + + +def dropout_rmsnorm_bwd_fusion( + fd: FusionDefinition, + dtype: DataType, + dropout_p: float, +) -> None: + """ + Backward pass fusion definition for computing: + output = rmsnorm (input + dropout (input, p=dropout_p)) + + Fusion inputs: input, dropout_mask, rms, grad_output, weights + Fusion outputs: grad_input, grad_weights + """ + T5 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) # inputs + T6 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=DataType.Bool, is_cpu=False + ) # dropout_mask + T7 = fd.define_tensor( + shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False + ) # rms_eps + T8 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) # grads + T9 = fd.define_tensor( + shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False + ) # weights + + if dtype in PROMOTE_DTYPES: + T5 = fd.ops.cast(T5, dtype=DataType.Float) + T6 = fd.ops.cast(T6, dtype=DataType.Float) + T8 = fd.ops.cast(T8, dtype=DataType.Float) + T9 = fd.ops.cast(T9, dtype=DataType.Float) + + T12 = fd.ops.mul(T5, T6) + S13 = fd.define_scalar(1 / (1 - dropout_p), dtype=DataType.Double) + T14 = fd.ops.mul(T12, S13) + T15 = fd.ops.add(T5, T14) + + V19 = T5.shape() + T20 = fd.ops.broadcast_in_dim(T7, shape=V19, broadcast_dims=[0, 1]) + T22 = fd.ops.reciprocal(T20) + T23 = fd.ops.mul(T15, T22) + + T27 = fd.ops.broadcast_in_dim(T9, shape=V19, broadcast_dims=[1]) + + T30 = fd.ops.mul(T8, T23) + T31 = fd.ops.mul(T8, T27) + T32 = fd.ops.sum(T30, axes=[0], keepdim=False, dtype=DataType.Null) + + T35 = fd.ops.mul(T31, T22) + T36 = fd.ops.neg(T31) + T37 = fd.ops.mul(T36, T15) + S38 = fd.define_scalar(2.00000, dtype=DataType.Double) + T39 = fd.ops.pow(T20, S38) + T40 = fd.ops.reciprocal(T39) + T41 = fd.ops.mul(T37, T40) + T42 = fd.ops.sum(T41, axes=[1], keepdim=False, dtype=DataType.Null) + + V60 = fd.define_vector([T5.size(0), 1], dtype=DataType.Int) + T47 = fd.ops.broadcast_in_dim(T42, shape=V60, broadcast_dims=[0]) + + T50 = fd.ops.mul(S38, T7) + T51 = fd.ops.reciprocal(T50) + T52 = fd.ops.mul(T47, T51) + S55 = fd.ops.reciprocal(T5.size(1)) + T56 = fd.ops.mul(T52, S55) + T57 = fd.ops.sum(T56, axes=[1], keepdim=False, dtype=DataType.Null) + + T61 = fd.ops.broadcast_in_dim(T57, shape=V60, broadcast_dims=[0]) + T65 = fd.ops.broadcast_in_dim(T61, shape=V19, broadcast_dims=[0, 1]) + T66 = fd.ops.mul(T65, S38) + T69 = fd.ops.mul(T66, T15) + T70 = fd.ops.add(T35, T69) + + T73 = fd.ops.mul(T70, S13) + T74 = fd.ops.mul(T73, T6) + T75 = fd.ops.add(T70, T74) + + if dtype in PROMOTE_DTYPES: + T75 = fd.ops.cast(T75, dtype=dtype) + T32 = fd.ops.cast(T32, dtype=dtype) + + fd.add_output(T75) + fd.add_output(T32) + + +@pytest.mark.parametrize("size", generate_input_sizes(dims=2)) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_rmsnorm_bwd_benchmark( + benchmark, + size: tuple, + dtype: torch.dtype, + disable_validation: bool, + disable_benchmarking: bool, + eps: float = 1e-5, +): + clear_cuda_cache() + + inputs = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) + grads = torch.randn(*size, device="cuda", dtype=dtype) + weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) + + dropout_p = 0.1 + dropout_mask = torch.lt(torch.rand(*size, device="cuda"), 1 - dropout_p) + + x = inputs + 1 / (1 - dropout_p) * dropout_mask * inputs + squared_mean = (x.to(torch.float) ** 2).mean(1, keepdim=True) + rms_eps = torch.sqrt(squared_mean + eps) + + with FusionDefinition() as fd: + dropout_rmsnorm_bwd_fusion(fd, torch_dtype_to_nvfuser_dtype(dtype), dropout_p) + + if not disable_validation: + eager_output = weights * (x / rms_eps) + eager_output.backward(grads.to(torch.double)) + fd.validate( + [inputs, dropout_mask, rms_eps, grads, weights], [inputs.grad, weights.grad] + ) + + if not disable_benchmarking: + run_benchmark( + benchmark, fd.execute, [inputs, dropout_mask, rms_eps, grads, weights] + ) diff --git a/python_benchmarks/test_dropout_rmsnorm_fwd.py b/python_benchmarks/test_dropout_rmsnorm_fwd.py new file mode 100644 index 00000000000..337e8ff76ff --- /dev/null +++ b/python_benchmarks/test_dropout_rmsnorm_fwd.py @@ -0,0 +1,104 @@ +import pytest +from nvfuser import FusionDefinition, DataType +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype +from .core import run_benchmark, clear_cuda_cache +import torch +from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES + + +def dropout_rmsnorm_fwd_fusion( + fd: FusionDefinition, + dtype: DataType, + dropout_p: float, + eps: float = 1e-5, +) -> None: + """ + Forward pass fusion definition for computing: + output = rmsnorm (input + dropout (input, p=dropout_p)) + + Fusion inputs: input, weights + Fusion outputs: output, dropout_mask, rms + """ + T0 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) + T1 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False) + + S2 = fd.define_scalar(0.00000, dtype=DataType.Double) + S3 = fd.define_scalar(1.00000, dtype=DataType.Double) + + V6 = T0.shape() + T7 = fd.ops.uniform(S2, S3, shape=V6, dtype=DataType.Float) + S8 = fd.define_scalar(1 - dropout_p, dtype=DataType.Double) + T9 = fd.ops.lt(T7, S8) + T10 = fd.ops.cast(T9, dtype=DataType.Float) + + if dtype in PROMOTE_DTYPES: + T0 = fd.ops.cast(T0, dtype=DataType.Float) + T1 = fd.ops.cast(T1, dtype=DataType.Float) + + T12 = fd.ops.mul(T0, T10) + S13 = fd.define_scalar(1 / (1 - dropout_p), dtype=DataType.Double) + T14 = fd.ops.mul(T12, S13) + T15 = fd.ops.add(T0, T14) + S16 = fd.define_scalar(2.00000, dtype=DataType.Double) + T17 = fd.ops.pow(T15, S16) + T18 = fd.ops.sum(T17, axes=[1], keepdim=False, dtype=DataType.Null) + + V21 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) + T22 = fd.ops.broadcast_in_dim(T18, shape=V21, broadcast_dims=[0]) + + S24 = fd.ops.reciprocal(T0.size(1)) + T25 = fd.ops.mul(T22, S24) + S26 = fd.define_scalar(eps, dtype=DataType.Double) + T27 = fd.ops.add(T25, S26) + T28 = fd.ops.sqrt(T27) + + T33 = fd.ops.broadcast_in_dim(T28, shape=V6, broadcast_dims=[0, 1]) + + T35 = fd.ops.reciprocal(T33) + T36 = fd.ops.mul(T15, T35) + T40 = fd.ops.broadcast_in_dim(T1, shape=V6, broadcast_dims=[1]) + T42 = fd.ops.mul(T40, T36) + + if dtype in PROMOTE_DTYPES: + T42 = fd.ops.cast(T42, dtype=dtype) + + fd.add_output(T42) + fd.add_output(T9) + fd.add_output(T28) + + +@pytest.mark.parametrize("size", generate_input_sizes(dims=2)) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_rmsnorm_fwd_benchmark( + benchmark, + size: tuple, + dtype: torch.dtype, + disable_validation: bool, + disable_benchmarking: bool, + eps: float = 1e-5, +): + clear_cuda_cache() + + inputs = torch.randn(*size, device="cuda", dtype=dtype) + weights = torch.randn(size[1], device="cuda", dtype=dtype) + + # dropout_p = 0.0 in fwd benchmark for validating the dropout mask + dropout_p = 0.0 + dropout_mask = torch.lt(torch.rand(*size, device="cuda"), 1 - dropout_p) + + with FusionDefinition() as fd: + dropout_rmsnorm_fwd_fusion( + fd, torch_dtype_to_nvfuser_dtype(dtype), dropout_p, eps + ) + + if not disable_validation: + x = inputs + 1 / (1 - dropout_p) * dropout_mask * inputs + squared_mean = (x.to(torch.float) ** 2).mean(1, keepdim=True) + rms_eps = torch.sqrt(squared_mean + eps) + eager_output = weights * (x / rms_eps) + fd.validate([inputs, weights], [eager_output.to(dtype), dropout_mask, rms_eps]) + + if not disable_benchmarking: + run_benchmark(benchmark, fd.execute, [inputs, weights]) diff --git a/python_benchmarks/test_huggingface_attn_bwd.py b/python_benchmarks/test_huggingface_attn_bwd.py index 5be172be78d..108b9fd161e 100644 --- a/python_benchmarks/test_huggingface_attn_bwd.py +++ b/python_benchmarks/test_huggingface_attn_bwd.py @@ -8,7 +8,6 @@ # Fusion from huggingface attention implementation # The nvFuser defintion only includes the non-matmul computation (add + reshape + softmax + dropout) -# https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/tests/hf_bart_self_attn.py#L73-L83 def huggingface_attn_bwd_fusion( fd: FusionDefinition, dtype: DataType, diff --git a/python_benchmarks/test_huggingface_attn_fwd.py b/python_benchmarks/test_huggingface_attn_fwd.py index 392a9adb50d..6410ee3e21f 100644 --- a/python_benchmarks/test_huggingface_attn_fwd.py +++ b/python_benchmarks/test_huggingface_attn_fwd.py @@ -8,7 +8,6 @@ # Fusion from huggingface attention implementation. # The nvFuser defintion only includes the non-matmul computation (add + reshape + softmax + dropout) -# https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/tests/hf_bart_self_attn.py#L73-L83 def huggingface_attn_fwd_fusion( fd: FusionDefinition, dtype: DataType, diff --git a/python_benchmarks/test_nanogpt_attn_bwd.py b/python_benchmarks/test_nanogpt_attn_bwd.py index 3401b4d3890..3c22878adaf 100644 --- a/python_benchmarks/test_nanogpt_attn_bwd.py +++ b/python_benchmarks/test_nanogpt_attn_bwd.py @@ -8,7 +8,6 @@ # Fusion from nanogpt attention module # The nvFuser defintion only includes the non-matmul computation (masked_fill + softmax + dropout) -# https://github.com/Lightning-AI/lightning-thunder/blob/d3da8517bff02a913fd149b4d6559f6b5a4c6c7f/thunder/tests/nanogpt_model.py#L102-L106 def nanogpt_attn_bwd_fusion( fd: FusionDefinition, dtype: DataType, head_size: int, dropout_p: float ): diff --git a/python_benchmarks/test_nanogpt_attn_fwd.py b/python_benchmarks/test_nanogpt_attn_fwd.py index 82b0890792c..c0cb1964eaa 100644 --- a/python_benchmarks/test_nanogpt_attn_fwd.py +++ b/python_benchmarks/test_nanogpt_attn_fwd.py @@ -8,7 +8,6 @@ # Fusion from nanogpt attention module # The nvFuser defintion only includes the non-matmul computation (masked_fill + softmax + dropout) -# https://github.com/Lightning-AI/lightning-thunder/blob/d3da8517bff02a913fd149b4d6559f6b5a4c6c7f/thunder/tests/nanogpt_model.py#L102-L106 def nanogpt_attn_fwd_fusion( fd: FusionDefinition, dtype: DataType, head_size: int, dropout_p: float ): diff --git a/test/test_alias.cpp b/test/test_alias.cpp index 452b29c36f0..7120e30c5b7 100644 --- a/test/test_alias.cpp +++ b/test/test_alias.cpp @@ -853,6 +853,30 @@ TEST_F(AliasTest, OutputAliasesAnotherOutput) { EXPECT_TRUE(permute_out_tensor.is_alias_of(reshape_out_tensor)); } +TEST_F(AliasTest, OutputNotAliasedByAnotherOutputShouldNotBeSegmented) { + // Reproduces #1646. + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigConcreteTensor({2, 3, 5}); + TensorView* add_out = add(in, in); + TensorView* reshape_out = reshape(add_out, {2, 3, 5}, {6, 5}); + TensorView* permute_out = permute(reshape_out, {1, 0}); + TensorView* mul_out = mul(permute_out, permute_out); + + fusion->addInput(in); + fusion->addOutput(reshape_out); + fusion->addOutput(mul_out); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({2, 3, 5}).cuda(); + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); +} + TEST_F(AliasTest, ManyAliasesBetweenOutputs) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -908,6 +932,53 @@ TEST_F(AliasTest, Broadcast) { EXPECT_EQ(out_tensor.data_ptr(), in_tensor.data_ptr()); } +TEST_F(AliasTest, MergeTwoExpandedBroadcasts) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = TensorViewBuilder() + .ndims(3) + .dtype(DataType::Float) + .contiguity({std::nullopt, std::nullopt, std::nullopt}) + .shape({4, 5, 6}) + .expanded({true, true, true}) + .build(); + fusion->addInput(in); + TensorView* out = reshape(in, {4, 5, 6}, {20, -1}); + fusion->addOutput(out); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({1}).cuda().as_strided({4, 5, 6}, {0, 0, 0}); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); + + // TODO(#1126): This should become an alias when #1126 is fixed. + // EXPECT_TRUE(out_tensor.is_alias_of(in_tensor)); +} + +TEST_F(AliasTest, MergeBroadcastsBetweenConcretes) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = TensorViewBuilder() + .ndims(4) + .dtype(DataType::Float) + .contiguity({true, std::nullopt, std::nullopt, true}) + .shape({2, 3, 5, 7}) + .expanded({false, true, true, false}) + .build(); + fusion->addInput(in); + TensorView* out = reshape(in, {2, 3, 5, 7}, {2, -1, 7}); + out = reshape(out, {2, 15, 7}, {30, 7}); + fusion->addOutput(out); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = + at::randn({2 * 7}).cuda().as_strided({2, 3, 5, 7}, {7, 0, 0, 1}); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); +} + TEST_F(AliasTest, Squeeze) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index e54fb2b646b..7de7771fa71 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -1385,4 +1385,79 @@ TEST_F(NVFuserTest, ConcretizeConstantExtents) { testValidate(fusion, outputs, inputs, __LINE__, __FILE__); } +// Test that dynamic reductions that should result in squeezes are handled +// properly. +// See https://github.com/NVIDIA/Fuser/issues/1667 +TEST_F(NVFuserTest, DynamicSqueezeTrivialReduction) { + auto fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(3); + fusion->addInput(tv0); + + // Explicitly cast Int to Index, so that these extents are not immediate + // constants + auto tv1 = reshape( + tv0, + { + castOp(DataType::Index, IrBuilder::create(1, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(2, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(2, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(1, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(3, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(3, DataType::Int)), + }); + auto tv2 = sum(tv1, {0, 2, 3, 4}); + fusion->addOutput(tv2); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 9}, options); + std::vector inputs = {t0}; + + auto outputs = fec.runFusionWithInputs(inputs); + + testValidate(fusion, outputs, inputs, __LINE__, __FILE__); +} + +// Same as above but for Welford ops +// See https://github.com/NVIDIA/Fuser/issues/1667 +TEST_F(NVFuserTest, DynamicSqueezeTrivialWelford) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(3); + fusion->addInput(tv0); + + // Explicitly cast Int to Index, so that these extents are not immediate + // constants + auto tv1 = reshape( + tv0, + { + castOp(DataType::Index, IrBuilder::create(1, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(2, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(2, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(1, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(3, DataType::Int)), + castOp(DataType::Index, IrBuilder::create(3, DataType::Int)), + }); + auto res = + variance_mean(tv1, {0, 2, 3, 4}, /*unbiased=*/true, /*keepdim=*/false); + fusion->addOutput(res.mean); + fusion->addOutput(res.var); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 9}, options); + std::vector inputs = {t0}; + + auto outputs = fec.runFusionWithInputs(inputs); + + testValidate(fusion, outputs, inputs, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 881a8a74a46..06d625b5406 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8641,6 +8641,48 @@ TEST_F(NVFuserTest, ProjectToInputsAndBroadcastTvs3) { auto cg_outputs = fe.runFusion(inputs, persistent_params->lparams); } +// Test 3D reductions with constant domains. +// https://github.com/NVIDIA/Fuser/issues/1590 +TEST_F(NVFuserTest, Reduction3DConstantIterationDomain) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + long x = 2L, y = 8L, z = 8L, w = 16L, h = 512L; + auto tv0 = TensorViewBuilder() + .ndims(5) + .shape({-1, -1, -1, -1, -1}) + .contiguity({true, true, true, true, true}) + .strideOrder({4, 3, 2, 0, 1}) + .build(); + fusion->addInput(tv0); + auto tv1 = full( + {IrBuilder::create(x), + IrBuilder::create(y), + IrBuilder::create(z), + IrBuilder::create(w), + IrBuilder::create(h)}, + fusion->oneVal(), + DataType::Float); + auto tv2 = mul(tv0, tv1); + auto tv3 = sum(tv2, {2, 4}); + fusion->addOutput(tv3); + + // tv1 is a constant tensor, and its domains are constant. + // Its constant domains are used in ExactMappedExtentSubstitutionPass + // to substitute the domains of tv0. + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = + at::randn({x, y, z, w, h}, options) + .as_strided({x, y, z, w, h}, {w * h * z * y, w * h * z, w * h, 1, w}); + std::vector inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(inputs); + + auto ref = t0.to(at::kDouble).sum({2, 4}); + testValidate( + executor_cache.fusion(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser diff --git a/test/test_pipeline.cpp b/test/test_pipeline.cpp index 6c359ad8a4a..1d806c536a8 100644 --- a/test/test_pipeline.cpp +++ b/test/test_pipeline.cpp @@ -22,6 +22,26 @@ namespace nvfuser { using namespace at::indexing; +std::string sortByLine(const std::string& input) { + auto ss = std::stringstream(input); + std::vector lines; + std::string line; + while (std::getline(ss, line, '\n')) { + lines.push_back(line); + } + std::sort(lines.begin(), lines.end()); + std::stringstream output; + bool first = true; + for (auto line : lines) { + if (!first) { + output << std::endl; + } + first = false; + output << line; + } + return output.str(); +} + TEST_F(NVFuserTest, Pipeline_CUDA) { // Fusion definition Fusion fusion; @@ -82,7 +102,7 @@ TEST_F(NVFuserTest, Pipeline_CUDA) { " PipelineVal representing Val T0_g[ iS0{i0}, iS1{i2} ] on stage " + std::to_string(stage0.unique_id) + "\n" - " PipelineVal representing Val T4_g[ iS6{i13}, iS7{i14}, iS8{i15} ] on stage " + + " PipelineVal representing Val T4_g[ iS6{i15}, iS7{i16}, iS8{i17} ] on stage " + std::to_string(stage2.unique_id) + "\n" "}\n" @@ -106,98 +126,96 @@ TEST_F(NVFuserTest, Pipeline_CUDA) { ".Inputs={T2_l[ iS4{i2} ], }. Outputs={T3_g[ rS5{i2} ], }.\n" " PipelineStage representing Stage " + std::to_string(stage2.unique_id) + - ".Inputs={T4_g[ iS6{i13}, iS7{i14}, iS8{i15} ], }. Outputs={T5_l[ rS9{i13}, iS10{i14}, iS11{i15} ], }.\n" - " PipelineVal representing Val T5_l[ rS9{i13}, iS10{i14}, iS11{i15} ] on stage " + + ".Inputs={T4_g[ iS6{i15}, iS7{i16}, iS8{i17} ], }. Outputs={T5_l[ rS9{i15}, iS10{i16}, iS11{i17} ], }.\n" + " PipelineVal representing Val T5_l[ rS9{i15}, iS10{i16}, iS11{i17} ] on stage " + std::to_string(stage2.unique_id) + "\n" - " PipelineCommunication that transfers PipelineVal representing Val T5_l[ rS9{i13}, iS10{i14}, iS11{i15} ] on stage " + + " PipelineCommunication that transfers PipelineVal representing Val T5_l[ rS9{i15}, iS10{i16}, iS11{i17} ] on stage " + std::to_string(stage2.unique_id) + - " to PipelineVal representing Val T6_l[ iS12{i14}, iS13{i15} ] on stage " + + " to PipelineVal representing Val T6_l[ iS12{i16}, iS13{i17} ] on stage " + std::to_string(stage3.unique_id) + "\n" - " PipelineVal representing Val T6_l[ iS12{i14}, iS13{i15} ] on stage " + + " PipelineVal representing Val T6_l[ iS12{i16}, iS13{i17} ] on stage " + std::to_string(stage3.unique_id) + "\n" " PipelineStage representing Stage " + std::to_string(stage3.unique_id) + - ".Inputs={T6_l[ iS12{i14}, iS13{i15} ], }. Outputs={T7_l[ iS14{i14}, iS15{i15} ], T8_l[ rS16{i14}, iS17{i15} ], }.\n" - " PipelineVal representing Val T7_l[ iS14{i14}, iS15{i15} ] on stage " + + ".Inputs={T6_l[ iS12{i16}, iS13{i17} ], }. Outputs={T7_l[ iS14{i16}, iS15{i17} ], T8_l[ rS16{i16}, iS17{i17} ], }.\n" + " PipelineVal representing Val T7_l[ iS14{i16}, iS15{i17} ] on stage " + std::to_string(stage3.unique_id) + "\n" - " PipelineCommunication that transfers PipelineVal representing Val T7_l[ iS14{i14}, iS15{i15} ] on stage " + + " PipelineCommunication that transfers PipelineVal representing Val T7_l[ iS14{i16}, iS15{i17} ] on stage " + std::to_string(stage3.unique_id) + - " to PipelineVal representing Val T12_l[ iS24{i14}, iS25{i15} ] on stage " + + " to PipelineVal representing Val T12_l[ iS24{i16}, iS25{i17} ] on stage " + std::to_string(stage5.unique_id) + "\n" - " PipelineVal representing Val T12_l[ iS24{i14}, iS25{i15} ] on stage " + + " PipelineVal representing Val T12_l[ iS24{i16}, iS25{i17} ] on stage " + std::to_string(stage5.unique_id) + "\n" " PipelineStage representing Stage " + std::to_string(stage5.unique_id) + - ".Inputs={T12_l[ iS24{i14}, iS25{i15} ], }. Outputs={T13_g[ rS26{i14}, iS27{i15} ], }.\n" - " PipelineVal representing Val T8_l[ rS16{i14}, iS17{i15} ] on stage " + + ".Inputs={T12_l[ iS24{i16}, iS25{i17} ], }. Outputs={T13_g[ rS26{i16}, iS27{i17} ], }.\n" + " PipelineVal representing Val T8_l[ rS16{i16}, iS17{i17} ] on stage " + std::to_string(stage3.unique_id) + "\n" - " PipelineCommunication that transfers PipelineVal representing Val T8_l[ rS16{i14}, iS17{i15} ] on stage " + + " PipelineCommunication that transfers PipelineVal representing Val T8_l[ rS16{i16}, iS17{i17} ] on stage " + std::to_string(stage3.unique_id) + - " to PipelineVal representing Val T14_l[ iS28{i15} ] on stage " + + " to PipelineVal representing Val T14_l[ iS28{i17} ] on stage " + std::to_string(stage6.unique_id) + "\n" - " PipelineVal representing Val T14_l[ iS28{i15} ] on stage " + + " PipelineVal representing Val T14_l[ iS28{i17} ] on stage " + std::to_string(stage6.unique_id) + "\n" - " PipelineCommunication that transfers PipelineVal representing Val T5_l[ rS9{i13}, iS10{i14}, iS11{i15} ] on stage " + + " PipelineCommunication that transfers PipelineVal representing Val T5_l[ rS9{i15}, iS10{i16}, iS11{i17} ] on stage " + std::to_string(stage2.unique_id) + - " to PipelineVal representing Val T9_l[ iS18{i14}, iS19{i15} ] on stage " + + " to PipelineVal representing Val T9_l[ iS18{i16}, iS19{i17} ] on stage " + std::to_string(stage4.unique_id) + "\n" - " PipelineVal representing Val T9_l[ iS18{i14}, iS19{i15} ] on stage " + + " PipelineVal representing Val T9_l[ iS18{i16}, iS19{i17} ] on stage " + std::to_string(stage4.unique_id) + "\n" " PipelineStage representing Stage " + std::to_string(stage4.unique_id) + - ".Inputs={T9_l[ iS18{i14}, iS19{i15} ], }. Outputs={T11_l[ rS22{i14}, iS23{i15} ], }.\n" - " PipelineVal representing Val T11_l[ rS22{i14}, iS23{i15} ] on stage " + + ".Inputs={T9_l[ iS18{i16}, iS19{i17} ], }. Outputs={T11_l[ rS22{i16}, iS23{i17} ], }.\n" + " PipelineVal representing Val T11_l[ rS22{i16}, iS23{i17} ] on stage " + std::to_string(stage4.unique_id) + "\n" - " PipelineCommunication that transfers PipelineVal representing Val T11_l[ rS22{i14}, iS23{i15} ] on stage " + + " PipelineCommunication that transfers PipelineVal representing Val T11_l[ rS22{i16}, iS23{i17} ] on stage " + std::to_string(stage4.unique_id) + - " to PipelineVal representing Val T15_l[ iS29{i15} ] on stage " + + " to PipelineVal representing Val T15_l[ iS29{i17} ] on stage " + std::to_string(stage6.unique_id) + "\n" - " PipelineVal representing Val T15_l[ iS29{i15} ] on stage " + + " PipelineVal representing Val T15_l[ iS29{i17} ] on stage " + std::to_string(stage6.unique_id) + "\n" - " PipelineCommunication that transfers PipelineVal representing Val T13_g[ rS26{i14}, iS27{i15} ] on stage " + + " PipelineCommunication that transfers PipelineVal representing Val T13_g[ rS26{i16}, iS27{i17} ] on stage " + std::to_string(stage5.unique_id) + - " to PipelineVal representing Val T16_l[ iS30{i15} ] on stage " + + " to PipelineVal representing Val T16_l[ iS30{i17} ] on stage " + std::to_string(stage6.unique_id) + "\n" - " PipelineVal representing Val T16_l[ iS30{i15} ] on stage " + + " PipelineVal representing Val T16_l[ iS30{i17} ] on stage " + std::to_string(stage6.unique_id) + "\n" " PipelineStage representing Stage " + std::to_string(stage6.unique_id) + - ".Inputs={T14_l[ iS28{i15} ], T15_l[ iS29{i15} ], T16_l[ iS30{i15} ], }. Outputs={T19_g[ rS33{i15} ], }.\n" + ".Inputs={T14_l[ iS28{i17} ], T15_l[ iS29{i17} ], T16_l[ iS30{i17} ], }. Outputs={T19_g[ rS33{i17} ], }.\n" "}\n" "Pipeline's outputs:{\n" " PipelineVal representing Val T3_g[ rS5{i2} ] on stage " + std::to_string(stage1.unique_id) + "\n" - " PipelineVal representing Val T13_g[ rS26{i14}, iS27{i15} ] on stage " + + " PipelineVal representing Val T13_g[ rS26{i16}, iS27{i17} ] on stage " + std::to_string(stage5.unique_id) + "\n" - " PipelineVal representing Val T19_g[ rS33{i15} ] on stage " + + " PipelineVal representing Val T19_g[ rS33{i17} ] on stage " + std::to_string(stage6.unique_id) + "\n" "}"}; - // We sort the string so it doesn't depend on the order of the Pipeline's DAG - // traversal - - // TODO: we should sort on lines, not on characters - std::sort(obtained_string.begin(), obtained_string.end()); - std::sort(ref_string.begin(), ref_string.end()); + // We sort the string by line so it doesn't depend on the order of the + // Pipeline's DAG traversal + obtained_string = sortByLine(obtained_string); + ref_string = sortByLine(ref_string); EXPECT_EQ(obtained_string, ref_string); } diff --git a/test/test_serial_gridreduce.cpp b/test/test_serial_gridreduce.cpp index e7fc291ebbe..43675299b68 100644 --- a/test/test_serial_gridreduce.cpp +++ b/test/test_serial_gridreduce.cpp @@ -256,4 +256,97 @@ TEST_F(SerialGridReductionTest, CodegenNodes) { } } +TEST_F(SerialGridReductionTest, Scheduling) { + for (bool serial : {true, false}) { + for (int64_t num_warps : {4, 8}) { + // B is size of inner serial loop. Outer loop is hardcoded at A=4 + // Here we set B to a small value of 8 instead of 32 (i.e. 128 elements + // per thread), so that the non-serial compilation does not take too + // long. + for (int64_t B : {8}) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + int64_t blocks_x = 8; + int64_t blocks_y = 8; + int64_t blocks_z = 5; + int64_t A = 4; // Size of outer serial loop + int64_t H = blocks_z; + int64_t W = A * B * blocks_x * blocks_y * num_warps * 32; + + // Unreduced dimensions should be concrete. Reduced dimension could be + // symbolic, but is concrete here so that we can read tv0 to registers + TensorView* tv0 = TensorViewBuilder() + .shape({H, W}) + .dtype(DataType::Float) + .contiguity(true) + .build(); + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {0}); + fusion->addOutput(tv1); + + // Start with + // [ rS{H}, iS{W} ] + // We are grid reducing the H dimension and we want to coalesce + // accesses in the W dimension. So we first reorder to + // [ iS{W}, rS{H} ] + // then schedule as + // [ iBIDx{blocks_x}, iBIDy{blocks_y}, iS{A}, iS{B}, iTIDy{num_warps}, + // iTIDx{32}, rBIDz{blocks_z} ] + auto tv2 = tv0->cacheAfter(); + auto tv3 = tv1->cacheBefore(); + + tv3->reorder({{1, 0}, {0, 1}}); // blocks_x*blocks_y*A*B*num_warps*32, H + tv3->split(0, 32); // blocks_x*blocks_y*A*B*num_warps, 32, H + tv3->split(0, num_warps); // blocks_x*blocks_y*A*B, num_warps, 32, H + tv3->split(0, B); // blocks_x*blocks_y*A, B, num_warps, 32, H + tv3->split(0, A); // blocks_x*blocks_y, A, B, num_warps, 32, H + tv3->split(0, blocks_y); // blocks_x, blocks_y, A, B, num_warps, 32, H + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::BIDy); + tv3->axis(4)->parallelize(ParallelType::TIDy); + tv3->axis(5)->parallelize(ParallelType::TIDx); + tv3->axis(6)->parallelize(ParallelType::BIDz); + // Reorder to put parallel dims first for better inlining + tv3->reorder({ + {4, 2}, + {5, 3}, + {2, 4}, + {3, 5}, + }); + + TransformPropagator propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv3); + + // Here we just transpose A and B in tv2, so that it will be partially + // inlined with tv3, resulting in a separate loop to load tv0 into + // registers (tv2). + tv2->reorder({ + {-2, -3}, + {-3, -2}, + }); + + inlineMost(); + + FusionExecutor fe; + if (serial) { + tv3->definition()->as()->requestSerialGridReduction(); + } + fe.compileFusion(fusion); + + auto input = at::randn( + {H, W}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0)); + auto outputs = fe.runFusion({input}); + + if (serial) { + testValidate(fusion, outputs, {input}, __LINE__, __FILE__); + } + } + } + } +} + } // namespace nvfuser From 62346a90fe6fa972706449955a9ba2bc4536d892 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 26 Jan 2024 21:49:05 -0800 Subject: [PATCH 133/178] Add tests for Step 2 --- test/test_id_model.cpp | 335 ++++++++++++++++++++++++++++++++--------- 1 file changed, 261 insertions(+), 74 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 68fa4d1f058..e33f51ffd21 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -16,8 +16,10 @@ #include #include #include +#include #include #include +#include namespace nvfuser { @@ -50,8 +52,11 @@ class IdModelTester : public IdModel { // Do not automatically build the graphs IdModelTester(Fusion* fusion) : IdModel(fusion, /* build_graphs */ false) {} - std::pair> - getInlineRootResolutionMap() { + std::tuple< + ValGraph, + std::unordered_map, + std::unordered_map> + getInitialIELPromotionMap() { // Make sure the depedent graphs are already built maybeBuildGraph(IdMappingMode::EXACT); maybeBuildGraph(IdMappingMode::PERMISSIVE); @@ -64,41 +69,148 @@ class IdModelTester : public IdModel { initializeLoopGraph(inlining_info); + VERBOSE() << "Initial loop graph:\n"; + for (const auto& group : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { + VERBOSE() << nvfuser::toString(group) << std::endl; + } + ValGraph iel_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); std::unordered_map root_promotion_map = buildInlineRootResolutionMap(iel_graph, inlining_info); - return {std::move(iel_graph), std::move(root_promotion_map)}; + { + std::stringstream ss; + ss << "Step 1: Root promotion map\n"; + for (const auto& [iel_group, promoted_id] : root_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " + << promoted_id->name() << std::endl; + } + VERBOSE() << ss.str(); + } + + auto iel_promotion_map = root_promotion_map; + + propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); + + { + std::stringstream ss; + ss << "Step 2: IEL promotion map\n"; + for (const auto& [iel_group, promoted_id] : iel_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " + << promoted_id->name() << std::endl; + } + VERBOSE() << ss.str(); + } + + return { + std::move(iel_graph), + std::move(root_promotion_map), + std::move(iel_promotion_map)}; } }; -// Test if root_broadcast_id is resolved to ref_id. If ref_id is -// nullptr, test if root_broadcast_id has no resolution. -void validateResolution( - IterDomain* root_broadcast_id, +// Test if id is resolved to an ID that is exact mapped with +// ref_id. If ref_id is nullptr, test if root_broadcast_id has no +// resolution. +void validateIELResolution( + IterDomain* id, IterDomain* ref_id, const ValGraph& iel_graph, - const std::unordered_map& root_resolution_map) { - ASSERT_TRUE(root_broadcast_id->isBroadcast()); - const auto& iel_group = iel_graph.toGroup(root_broadcast_id); - auto root_promotion_map_it = root_resolution_map.find(iel_group); + const ValGraph& exact_graph, + const std::unordered_map& iel_promotion_map) { + const auto& iel_group = iel_graph.toGroup(id); + auto iel_promotion_map_it = iel_promotion_map.find(iel_group); if (ref_id != nullptr) { - ASSERT_TRUE(root_promotion_map_it != root_resolution_map.end()) - << "Root resolution not found for: " << nvfuser::toString(iel_group); + ASSERT_TRUE(iel_promotion_map_it != iel_promotion_map.end()) + << "IEL promotion not found for: " << nvfuser::toString(iel_group); ASSERT_FALSE(ref_id->isBroadcast()); - auto resolution_id = root_promotion_map_it->second; + auto promotion_id = iel_promotion_map_it->second; ASSERT_TRUE( - iel_graph.disjointValSets().strictAreMapped(resolution_id, ref_id)) - << "Unexpected root resolution. " + exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id)) + << "Unexpected promotion. " << "Expected: " << ref_id->toString() - << ". Actual: " << resolution_id->toString(); + << ". Actual: " << promotion_id->toString(); } else { - ASSERT_TRUE(root_promotion_map_it == root_resolution_map.end()) - << "Root resolution should not exist for: " - << nvfuser::toString(iel_group) - << ", but found: " << root_promotion_map_it->second->toString(); + ASSERT_TRUE(iel_promotion_map_it == iel_promotion_map.end()) + << "Promotion should not exist for: " << nvfuser::toString(iel_group) + << ", but found: " << iel_promotion_map_it->second->toString(); + } +} + +// Check if each domain gets promoted to a proper domain after the +// Step 2 IEL propagation. It is assumed that the proper promotion is +// the corresponding domain in the unique consumer tensor, which is +// the case with most of the test fusions. +void checkStep2Results( + Fusion* fusion, + const ValGraph& iel_graph, + const ValGraph& exact_graph, + const std::unordered_map& iel_promotion_map) { + auto getPromotedDomain = [&](IterDomain* id) -> IterDomain* { + if (auto it = iel_promotion_map.find(iel_graph.toGroup(id)); + it != iel_promotion_map.end()) { + return it->second; + } else { + return nullptr; + } + }; + + for (auto tv : ir_utils::allTvs(fusion)) { + // If there's no broadcast or it isn't inlined, there's no + // promotion + if (std::none_of( + tv->getRootDomain().begin(), + tv->getRootDomain().end(), + [](auto id) { return id->isBroadcast(); }) || + (tv->getComputeAtPosition() == 0 && + tv->getMaxProducerPosition() == 0)) { + // Make sure there's no promotion of any of the IDs of this tensor + for (auto id : ir_utils::allIDsOf(tv)) { + auto promoted_id = getPromotedDomain(id); + ASSERT_EQ(promoted_id, nullptr) + << "Expected no mapping for " << id->toString() + << " but found to be mapped to: " << promoted_id->toString(); + } + continue; + } + + auto consumers = ir_utils::consumerTvsOf(tv); + ASSERT_EQ(consumers.size(), 1) << "Assumed to have one consumer"; + TensorView* c_tv = consumers.at(0); + const auto p2c = BestEffortReplay::replayCasP( + c_tv, tv, -1, PairwiseRootDomainMap(tv, c_tv)) + .getReplay(); + + for (auto p_id : ir_utils::allIDsOf(tv)) { + // Root domains are already done at Step 1 + if (std::find( + tv->getRootDomain().begin(), tv->getRootDomain().end(), p_id) != + tv->getRootDomain().end()) { + continue; + } + + // If no broadcast is involved, nothing should be promoted + auto p_id_dep_vals = DependencyCheck::getAllValsBetween( + {tv->getRootDomain().begin(), tv->getRootDomain().end()}, {p_id}); + if (std::find_if( + p_id_dep_vals.begin(), p_id_dep_vals.end(), [](Val* dep_id) { + return dep_id->as()->isBroadcast(); + }) == p_id_dep_vals.end()) { + auto promoted_id = getPromotedDomain(p_id); + ASSERT_EQ(promoted_id, nullptr) + << "Expected no mapping for " << p_id->toString() + << " but found to be mapped to: " << promoted_id->toString(); + continue; + } + + // p_id should be promoted to c_id + auto c_id = p2c.at(p_id); + validateIELResolution( + p_id, c_id, iel_graph, exact_graph, iel_promotion_map); + } } } @@ -209,8 +321,8 @@ TensorView* findTensorByName( } // namespace -// Testing root resolution with a simple broadcast pattern -TEST_F(IdModelTest, LoopGraphRootResolution1) { +// Testing loop promotion with a simple broadcast pattern +TEST_F(IdModelTest, LoopPromotion1) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -224,8 +336,8 @@ TEST_F(IdModelTest, LoopGraphRootResolution1) { { IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); // Nothing inlined. Should be no resolution ASSERT_TRUE(root_resolution_map.empty()); @@ -236,21 +348,29 @@ TEST_F(IdModelTest, LoopGraphRootResolution1) { { IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); + // Check Step 1 results // t2 is now fully inlined. Its root broadcast domain should be // resoled with the corresponding domain of t3 - validateResolution( + validateIELResolution( t2->getRootDomain().at(0), t3->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); + + // Check Step 2 results + // Nothing to propagate in this fusion, so iel_promotion_map + // should be equivalent to root_resolution_map + ASSERT_EQ(root_resolution_map, iel_promotion_map) + << "Unexpected IEL promotion map"; } } // Test with a fusion with progressive broadcasting -TEST_F(IdModelTest, LoopGraphRootResolution2) { +TEST_F(IdModelTest, LoopPromotion2) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -267,25 +387,34 @@ TEST_F(IdModelTest, LoopGraphRootResolution2) { inlineMost(); IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); + // Check Step 1 results // Validate t2 and t3 as they have root broadcast domains - validateResolution( + validateIELResolution( t2->getRootDomain().at(0), t4->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); - validateResolution( + validateIELResolution( t3->getRootDomain().at(0), t4->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); + + // Check Step 2 results + // Nothing to propagate in this fusion, so iel_promotion_map + // should be equivalent to root_resolution_map + ASSERT_EQ(root_resolution_map, iel_promotion_map) + << "Unexpected IEL promotion map"; } // Multiple inlined and non-inlined broadcast domains -TEST_F(IdModelTest, LoopGraphRootResolution3) { +TEST_F(IdModelTest, LoopPromotion3) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -311,32 +440,52 @@ TEST_F(IdModelTest, LoopGraphRootResolution3) { // tv3: [i0*i1, i2*i3] IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); + // Check Step 1 results // The b1 broadcast domain tv2 should be resolved as it's inlined, // but b3 should not. - validateResolution( + validateIELResolution( tv2->getRootDomain().at(1), tv3->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), + root_resolution_map); + + validateIELResolution( + tv2->getRootDomain().at(3), + nullptr, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); - validateResolution( - tv2->getRootDomain().at(3), nullptr, iel_graph, root_resolution_map); + // Check Step 2 results + validateIELResolution( + tv2->axis(0), + tv3->axis(0), + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); + + validateIELResolution( + tv2->axis(1), + nullptr, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Test root resolution with a fusion with outer split -TEST_F(IdModelTest, LoopGraphRootResolution4) { +TEST_F(IdModelTest, LoopPromotion4) { auto fusion = createFusionWithInlinedOuterSplit(); auto all_tvs = ir_utils::allTvs(fusion.get()); IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -351,20 +500,28 @@ TEST_F(IdModelTest, LoopGraphRootResolution4) { case 2: // T2_l[ iS20{4}, iS21{( ceilDiv(( 1 * 4 ), 4) )} ] ca_pos( 1 ) // root domain : (bS4{1}, iS5{4}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(0), findTensorByName(all_tvs, 4)->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + // + checkStep2Results( + fusion.get(), + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Test root resolution with the same fusion as Indexing1 -TEST_F(IdModelTest, LoopGraphRootResolution5) { +TEST_F(IdModelTest, LoopPromotion5) { Fusion fusion; FusionGuard fg(&fusion); @@ -401,11 +558,10 @@ TEST_F(IdModelTest, LoopGraphRootResolution5) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Check Step 1 results for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -421,29 +577,37 @@ TEST_F(IdModelTest, LoopGraphRootResolution5) { // T3_l[ iS30{( ceilDiv(( ceilDiv(( ( ( 1 * i0 ) * i2 ) * i3 ), 128) ), // 4) )}, iUR31{4}, ithreadIdx.x29{128} ] ca_pos( 1 ) produce_pos( 1 ) // root domain : (bS10{1}, iS11{i0}, iS12{i2}, iS13{i3}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(0), - findTensorByName(all_tvs, 4)->getRootDomain().at(0), + tv4->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + // Check Step 2 results + checkStep2Results( + &fusion, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Test root resolution with the same fusion as Indexing19 -TEST_F(IdModelTest, LoopGraphRootResolution6) { +TEST_F(IdModelTest, LoopPromotion6) { auto fusion = createFusionWithMultipleResolutionPaths(); + FusionGuard fg(fusion.get()); auto all_tvs = ir_utils::allTvs(fusion.get()); IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Check Step 1 results for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -460,10 +624,11 @@ TEST_F(IdModelTest, LoopGraphRootResolution6) { // iS48{5} ] ca_pos( 1 ) produce_pos( 1 ) // root domain : (iS2{7}, bS3{1}) // Resolution: Resolved by the immediate consumer (T4) - validateResolution( + validateIELResolution( tv->getRootDomain().at(1), findTensorByName(all_tvs, 4)->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; case 5: @@ -473,10 +638,11 @@ TEST_F(IdModelTest, LoopGraphRootResolution6) { // Resolution: T5 is not inlined to the immediate consumer, // T10. Resolution is done with the other path from T1, such // as T8 or T9. - validateResolution( + validateIELResolution( tv->getRootDomain().at(2), findTensorByName(all_tvs, 9)->getRootDomain().at(2), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; case 6: @@ -484,10 +650,11 @@ TEST_F(IdModelTest, LoopGraphRootResolution6) { // iS63{5} ] ca_pos( 1 ) produce_pos( 1 ) // root domain : (iS11{7}, bS12{1}) // Resolution: Resolved by the immediate consumer (T8) - validateResolution( + validateIELResolution( tv->getRootDomain().at(1), findTensorByName(all_tvs, 8)->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; case 9: @@ -497,20 +664,27 @@ TEST_F(IdModelTest, LoopGraphRootResolution6) { // Resolution: T9 is not inlined to the immediate consumer, // T10. Resolution is done with the other path from T1, such // as T4 or T5 - validateResolution( + validateIELResolution( tv->getRootDomain().at(1), findTensorByName(all_tvs, 5)->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + checkStep2Results( + fusion.get(), + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Same fusion as NvFuserTest.FusionInlineBroadcastIndexing0 -TEST_F(IdModelTest, LoopGraphRootResolution7) { +TEST_F(IdModelTest, LoopPromotion7) { Fusion fusion; FusionGuard fg(&fusion); @@ -537,11 +711,10 @@ TEST_F(IdModelTest, LoopGraphRootResolution7) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -556,20 +729,27 @@ TEST_F(IdModelTest, LoopGraphRootResolution7) { case 3: // T3_l[ iS15{( ceilDiv(( 1 * i0 ), 32) )}, iS16{32} ] ca_pos( 1 ) // produce_pos( 1 ) root domain : (bS4{1}, iS5{i0}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(0), findTensorByName(all_tvs, 4)->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + checkStep2Results( + &fusion, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Same fusion as NvFuserTest.FusionIndexing20 -TEST_F(IdModelTest, LoopGraphRootResolution8) { +TEST_F(IdModelTest, LoopPromotion8) { Fusion fusion; FusionGuard fg(&fusion); @@ -614,11 +794,10 @@ TEST_F(IdModelTest, LoopGraphRootResolution8) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getInitialIELPromotionMap(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -633,26 +812,34 @@ TEST_F(IdModelTest, LoopGraphRootResolution8) { case 2: // T2_l[ iS21{2}, iS22{( ceilDiv(( 1 * 5 ), 2) )} ] ca_pos( 1 ) // produce_pos( 1 ) root domain : (bS2{1}, iS3{5}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(0), findTensorByName(all_tvs, 7)->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; case 5: // T5_l[ iS27{2}, iS40{4}, iS41{( ceilDiv(( ( ceilDiv(( 3 * 5 ), 2) ) * // 1 ), 4) )} ] ca_pos( 2 ) produce_pos( 1 ) root domain : (iS8{3}, // iS9{5}, bS10{1}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(2), findTensorByName(all_tvs, 7)->getRootDomain().at(2), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + checkStep2Results( + &fusion, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } } // namespace nvfuser From 04458cb062e9e8fa2e8c80cb9b59c59cf2b5350f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 31 Jan 2024 19:58:47 -0800 Subject: [PATCH 134/178] rename --- test/test_id_model.cpp | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index e33f51ffd21..8e116b9b24e 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -52,11 +52,12 @@ class IdModelTester : public IdModel { // Do not automatically build the graphs IdModelTester(Fusion* fusion) : IdModel(fusion, /* build_graphs */ false) {} + // Returns the IEL graph and the results of Steps 1 and 2 std::tuple< ValGraph, std::unordered_map, std::unordered_map> - getInitialIELPromotionMap() { + getLoopPromotionInfo() { // Make sure the depedent graphs are already built maybeBuildGraph(IdMappingMode::EXACT); maybeBuildGraph(IdMappingMode::PERMISSIVE); @@ -337,7 +338,7 @@ TEST_F(IdModelTest, LoopPromotion1) { { IdModelTester tester(fusion.get()); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Nothing inlined. Should be no resolution ASSERT_TRUE(root_resolution_map.empty()); @@ -349,7 +350,7 @@ TEST_F(IdModelTest, LoopPromotion1) { { IdModelTester tester(fusion.get()); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Check Step 1 results // t2 is now fully inlined. Its root broadcast domain should be @@ -388,7 +389,7 @@ TEST_F(IdModelTest, LoopPromotion2) { IdModelTester tester(fusion.get()); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Check Step 1 results // Validate t2 and t3 as they have root broadcast domains @@ -441,7 +442,7 @@ TEST_F(IdModelTest, LoopPromotion3) { IdModelTester tester(fusion.get()); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Check Step 1 results // The b1 broadcast domain tv2 should be resolved as it's inlined, @@ -483,7 +484,7 @@ TEST_F(IdModelTest, LoopPromotion4) { IdModelTester tester(fusion.get()); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { @@ -559,7 +560,7 @@ TEST_F(IdModelTest, LoopPromotion5) { IdModelTester tester(&fusion); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Check Step 1 results for (auto tv : all_tvs) { @@ -605,7 +606,7 @@ TEST_F(IdModelTest, LoopPromotion6) { IdModelTester tester(fusion.get()); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Check Step 1 results for (auto tv : all_tvs) { @@ -712,7 +713,7 @@ TEST_F(IdModelTest, LoopPromotion7) { IdModelTester tester(&fusion); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { @@ -795,7 +796,7 @@ TEST_F(IdModelTest, LoopPromotion8) { IdModelTester tester(&fusion); const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getInitialIELPromotionMap(); + tester.getLoopPromotionInfo(); // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { From 6a112a3e993539447e67836e178934c6c0a33eb7 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 10:06:27 -0800 Subject: [PATCH 135/178] Cleaning up visitor --- csrc/id_model/id_model.cpp | 4 +-- csrc/id_model/visitor.cpp | 45 ++++++++++++------------ csrc/id_model/visitor.h | 71 +++++++++++++++++++------------------- 3 files changed, 60 insertions(+), 60 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index ce58b6eae61..2e7c33ef95b 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -1575,7 +1575,7 @@ void IdModel::propagatePromotionsInIELGraph( bool require_loop_mapped_promotion) { // In order to make this traversal work, the traversal order must be // topologically sorted. - IdGraphStmtSort iel_stmt_sort(iel_graph); + ValGraphStmtSort iel_stmt_sort(iel_graph); // TODO-NM: The ordering might be non-deterministic @@ -1827,7 +1827,7 @@ std::unordered_map computeCoveredGroups( } } - IdGraphStmtSort exact_stmt_sort(graph); + ValGraphStmtSort exact_stmt_sort(graph); for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) { std::vector input_groups = graph.inputGroups(exact_expr); diff --git a/csrc/id_model/visitor.cpp b/csrc/id_model/visitor.cpp index f81ce4177d9..2491b20e878 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/id_model/visitor.cpp @@ -9,20 +9,20 @@ namespace nvfuser { -void IdGraphVisitor::traverse() { - ValGroups all_ids; +void ValGraphVisitor::traverse() { + ValGroups all_vals; ExprGroups all_exprs; { - // Initialize IDs to traverse. If sub_selection is provided, only - // traverse IDs that are included in the set are traversed. + // Initialize Vals to traverse. If sub_selection is provided, only + // traverse Vals that are included in the set are traversed. if (sub_selection_.empty()) { - all_ids = ValGroups( + all_vals = ValGroups( graph().disjointValSets().disjointSets().begin(), graph().disjointValSets().disjointSets().end()); } else { - for (auto id : sub_selection_) { - if (graph().hasGroup(id)) { - all_ids.pushBack(graph().toGroup(id)); + for (auto val : sub_selection_) { + if (graph().hasGroup(val)) { + all_vals.pushBack(graph().toGroup(val)); } } } @@ -36,22 +36,22 @@ void IdGraphVisitor::traverse() { graph().disjointExprSets().disjointSets().begin(), graph().disjointExprSets().disjointSets().end()); } else { - for (const ValGroup& id_group : all_ids) { - for (const ExprGroup& def : graph().getDefinitions(id_group)) { + for (const ValGroup& val_group : all_vals) { + for (const ExprGroup& def : graph().getDefinitions(val_group)) { if (all_exprs.has(def)) { continue; } auto inp_groups = ValGroups(graph().inputGroups(def)); auto out_groups = ValGroups(graph().outputGroups(def)); - if (inp_groups.computeSubtract(all_ids).empty() && - out_groups.computeSubtract(all_ids).empty()) { + if (inp_groups.computeSubtract(all_vals).empty() && + out_groups.computeSubtract(all_vals).empty()) { all_exprs.pushBack(def); } } } } } - // There could be IterDomains in from or to that are between other from and + // There could be Vals in from or to that are between other from and // to nodes. Make sure to clear those out. ValGroups terminating_inputs; ValGroups terminating_outputs; @@ -69,11 +69,9 @@ void IdGraphVisitor::traverse() { not_outputs.pushBack(graph().inputGroups(expr_group)); } - terminating_inputs = - ValGroups(all_ids.begin(), all_ids.end()).computeSubtract(not_inputs); + terminating_inputs = all_vals.computeSubtract(not_inputs); - terminating_outputs = - ValGroups(all_ids.begin(), all_ids.end()).computeSubtract(not_outputs); + terminating_outputs = all_vals.computeSubtract(not_outputs); } ValGroups to_visit_ids = terminating_inputs; @@ -82,7 +80,7 @@ void IdGraphVisitor::traverse() { ExprGroups to_visit_exprs; ExprGroups visited_exprs; - auto is_expr_ready = [&](const ExprGroup& expr_group) { + auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { auto inp_groups = graph().inputGroups(expr_group); return std::all_of( inp_groups.begin(), inp_groups.end(), [&](ValGroup id_group) { @@ -90,8 +88,8 @@ void IdGraphVisitor::traverse() { }); }; - auto is_id_ready = [&](const ValGroup& id_group) { - const ExprGroups& unique_defs = graph().getDefinitions(id_group); + auto is_val_ready = [&](const ValGroup& val_group) -> bool { + const ExprGroups& unique_defs = graph().getDefinitions(val_group); return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { return expr_group->empty() || visited_exprs.has(expr_group) || @@ -100,8 +98,8 @@ void IdGraphVisitor::traverse() { }; while (!to_visit_ids.empty() || !to_visit_exprs.empty()) { - // Process expressions first as all definitions of iter domains have to be - // processed before we can process that iter domain. + // Process expressions first as all definitions of vals have to be + // processed before we can process that val. // Detect if nothing has been processed which would put us in an infinite // loop @@ -137,7 +135,7 @@ void IdGraphVisitor::traverse() { continue; } - if (is_id_ready(current_id_group)) { + if (is_val_ready(current_id_group)) { handle(current_id_group); something_was_processed = true; @@ -159,4 +157,5 @@ void IdGraphVisitor::traverse() { "Infinite loop entered."); } } + } // namespace nvfuser diff --git a/csrc/id_model/visitor.h b/csrc/id_model/visitor.h index 2c13b9efae0..c10f70a192a 100644 --- a/csrc/id_model/visitor.h +++ b/csrc/id_model/visitor.h @@ -13,8 +13,8 @@ namespace nvfuser { -// Iterates through an IterDomain Graph in topological order, calling handle on -// all Id and all Expr groups in a forward topological order. +// Iterates through a Val Graph in topological order, calling handle on +// all Val and all Expr groups in a forward topological order. // // Warning: Expr groups that have an input and output in the same ValGroup are // ignored. @@ -22,50 +22,50 @@ namespace nvfuser { // Warning: This is not a great iterator if there's a desire to minimize paths // traveled to simply visit all ValGroups in order. See ExprsBetween to see how // we might minimize paths. -class IdGraphVisitor { +class ValGraphVisitor { public: - IdGraphVisitor() = delete; + ValGraphVisitor() = delete; - IdGraphVisitor& operator=(const IdGraphVisitor& other) = delete; + ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete; - IdGraphVisitor& operator=(IdGraphVisitor&& other) = delete; + ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete; - virtual ~IdGraphVisitor() = default; + virtual ~ValGraphVisitor() = default; protected: - // If sub_selection is assumed to be a set of iter domains by which form a - // sub-regrion of the IdGraph provided. Only that sub-region will be visited. - IdGraphVisitor( - const ValGraph& id_graph, - const VectorOfUniqueEntries sub_selection = {}) - : id_graph_(id_graph), sub_selection_(sub_selection) {} + // If sub_selection is assumed to be a set of vals by which form a + // sub-regrion of the ValGraph provided. Only that sub-region will be visited. + ValGraphVisitor( + const ValGraph& val_graph, + const VectorOfUniqueEntries sub_selection = {}) + : val_graph_(val_graph), sub_selection_(sub_selection) {} - IdGraphVisitor(const IdGraphVisitor& other) = default; + ValGraphVisitor(const ValGraphVisitor& other) = default; - IdGraphVisitor(IdGraphVisitor&& other) = default; + ValGraphVisitor(ValGraphVisitor&& other) = default; - virtual void handle(ValGroup id_group) = 0; - virtual void handle(ExprGroup expr_group) = 0; + virtual void handle(const ValGroup& id_group) = 0; + virtual void handle(const ExprGroup& expr_group) = 0; void traverse(); const ValGraph& graph() { - return id_graph_; + return val_graph_; }; private: - const ValGraph& id_graph_; - const VectorOfUniqueEntries sub_selection_; + const ValGraph& val_graph_; + const VectorOfUniqueEntries sub_selection_; }; -// Statement sorting based on IdGraphVisitor, see warnings to IdGraph Visitor. -class IdGraphStmtSort : public IdGraphVisitor { +// Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor. +class ValGraphStmtSort : public ValGraphVisitor { public: - IdGraphStmtSort( - const ValGraph& id_graph, - const VectorOfUniqueEntries sub_selection = {}) - : IdGraphVisitor(id_graph, sub_selection) { - IdGraphVisitor::traverse(); + ValGraphStmtSort( + const ValGraph& val_graph, + const VectorOfUniqueEntries sub_selection = {}) + : ValGraphVisitor(val_graph, sub_selection) { + ValGraphVisitor::traverse(); } // Return non-reference so that code like below can work @@ -74,24 +74,25 @@ class IdGraphStmtSort : public IdGraphVisitor { return sorted_exprs_; } - ValGroups ids() const { - return sorted_ids_; + ValGroups vals() const { + return sorted_vals_; } - ~IdGraphStmtSort() override = default; + ~ValGraphStmtSort() override = default; protected: - using IdGraphVisitor::handle; - void handle(ValGroup id_group) override { - sorted_ids_.pushBack(id_group); + using ValGraphVisitor::handle; + + void handle(const ValGroup& val_group) override { + sorted_vals_.pushBack(val_group); } - void handle(ExprGroup expr_group) override { + void handle(const ExprGroup& expr_group) override { sorted_exprs_.pushBack(expr_group); } ExprGroups sorted_exprs_; - ValGroups sorted_ids_; + ValGroups sorted_vals_; }; } // namespace nvfuser From f05599409add4bfd4cccd480d2dce0f0c2092c9c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 10:11:33 -0800 Subject: [PATCH 136/178] Move ValGraphVisitor out of id_model --- CMakeLists.txt | 2 +- csrc/id_model/id_model.cpp | 2 +- csrc/{id_model/visitor.cpp => val_graph_visitor.cpp} | 2 +- csrc/{id_model/visitor.h => val_graph_visitor.h} | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename csrc/{id_model/visitor.cpp => val_graph_visitor.cpp} (99%) rename csrc/{id_model/visitor.h => val_graph_visitor.h} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 403232a8343..015ba732eb8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,7 +85,6 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp ${NVFUSER_SRCS_DIR}/id_model/transform_replay.cpp ${NVFUSER_SRCS_DIR}/id_model/validation_utils.cpp - ${NVFUSER_SRCS_DIR}/id_model/visitor.cpp ${NVFUSER_SRCS_DIR}/index_compute.cpp ${NVFUSER_SRCS_DIR}/instrumentation.cpp ${NVFUSER_SRCS_DIR}/ir/base_nodes.cpp @@ -204,6 +203,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp ${NVFUSER_SRCS_DIR}/val_graph.cpp + ${NVFUSER_SRCS_DIR}/val_graph_visitor.cpp ) # We don't link CUPTI for MSVC diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 2e7c33ef95b..6083dde8426 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include @@ -19,6 +18,7 @@ #include #include #include +#include #include #include diff --git a/csrc/id_model/visitor.cpp b/csrc/val_graph_visitor.cpp similarity index 99% rename from csrc/id_model/visitor.cpp rename to csrc/val_graph_visitor.cpp index 2491b20e878..fe259dce27d 100644 --- a/csrc/id_model/visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include +#include namespace nvfuser { diff --git a/csrc/id_model/visitor.h b/csrc/val_graph_visitor.h similarity index 100% rename from csrc/id_model/visitor.h rename to csrc/val_graph_visitor.h From 4781d65377cacf3eea48f3d6136ea27c5b39c840 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 11:53:19 -0800 Subject: [PATCH 137/178] Tests for ValGraphStmtSort --- csrc/disjoint_set.h | 8 +++++ test/test_id_model.cpp | 69 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index de1c2b9ffd7..e7fd5df7ce0 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -275,6 +275,14 @@ class VectorOfUniqueEntries { return vector_.end(); } + T& at(size_t pos) { + return vector_.at(pos); + } + + const T& at(size_t pos) const { + return vector_.at(pos); + } + std::string toString() const { std::stringstream ss; ss << "{ "; diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 8e116b9b24e..1beff601559 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -20,6 +20,7 @@ #include #include #include +#include namespace nvfuser { @@ -843,4 +844,72 @@ TEST_F(IdModelTest, LoopPromotion8) { iel_promotion_map); } +namespace { + +// Check the results of ValGraphStmtSort +void checkSortingResults( + const ValGraph& graph, + const ExprGroups& sorted_expr_groups, + const ValGroups& sorted_val_groups, + const std::vector& ref_expr_order) { + ASSERT_EQ(sorted_expr_groups.size(), ref_expr_order.size()) + << "Expected " << ref_expr_order.size() << " expr group(s) but received " + << sorted_expr_groups.size() << " group(s)"; + + for (const auto i : c10::irange(sorted_expr_groups.size())) { + auto ref_expr = ref_expr_order.at(i); + const ExprGroup& eg = sorted_expr_groups.at(i); + ASSERT_TRUE(eg->has(ref_expr)) + << "Unexpected ordering of expr groups detected at " << i + << "-th group: " << nvfuser::toString(eg) << ": " + << eg->front()->toString(); + } + + // Checking the order of the expr groups should be likely just + // sufficient. Just make sure the sorted val groups cover all the + // val groups in the graph. + const std::unordered_set& ref_val_group_set{ + graph.disjointValSets().disjointSets().begin(), + graph.disjointValSets().disjointSets().end()}; + + std::unordered_set sorted_val_group_set{ + sorted_val_groups.begin(), sorted_val_groups.end()}; + + ASSERT_EQ(sorted_val_group_set, ref_val_group_set) << "Mismatched ValGroups."; +} + +} // namespace + +// Sorting test with a trivial fusion +TEST_F(IdModelTest, ValGraphStmtSort1) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + tv2->merge(0)->split(0, 4); + + TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + IdModel id_model(&fusion); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + const ExprGroups& sorted_exprs = vg_stmt_sort.exprs(); + + // Reference expr order: merge, split + std::vector ref_sorted_exprs{ + tv2->axis(0)->definition()->input(0)->definition(), + tv2->axis(0)->definition()}; + + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); +} + } // namespace nvfuser From 79623380c45781a4e913f52601f0a39813fceaf4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 13:18:46 -0800 Subject: [PATCH 138/178] WIP --- test/test_id_model.cpp | 198 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 176 insertions(+), 22 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 1beff601559..1d7d2e15b21 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -846,36 +846,72 @@ TEST_F(IdModelTest, LoopPromotion8) { namespace { -// Check the results of ValGraphStmtSort +// Check the results of ValGraphStmtSort. Ordering check is only +// implemented for ExprGroups for now as it's likely sufficient. +// +// ref_expr_orders: a list of expr pairs. Each pair indicates the +// first expr must show up before the second expr. void checkSortingResults( const ValGraph& graph, const ExprGroups& sorted_expr_groups, const ValGroups& sorted_val_groups, - const std::vector& ref_expr_order) { - ASSERT_EQ(sorted_expr_groups.size(), ref_expr_order.size()) - << "Expected " << ref_expr_order.size() << " expr group(s) but received " - << sorted_expr_groups.size() << " group(s)"; - - for (const auto i : c10::irange(sorted_expr_groups.size())) { - auto ref_expr = ref_expr_order.at(i); - const ExprGroup& eg = sorted_expr_groups.at(i); - ASSERT_TRUE(eg->has(ref_expr)) - << "Unexpected ordering of expr groups detected at " << i - << "-th group: " << nvfuser::toString(eg) << ": " - << eg->front()->toString(); + const std::vector>& ref_expr_orders) { + + { + std::cerr << "Sorted EG:\n"; + for (const auto& eg: sorted_expr_groups) { + std::cerr << nvfuser::toString(eg) << ": " << eg->front()->toString(); + } + std::cerr << "All EGs:\n"; + for (const auto& eg: graph.disjointExprSets().disjointSets()) { + std::cerr << nvfuser::toString(eg) << ": " << eg->front()->toString(); + } } - // Checking the order of the expr groups should be likely just - // sufficient. Just make sure the sorted val groups cover all the - // val groups in the graph. + // Make sure sorted_val_groups cover all Expr groups + const std::unordered_set& ref_expr_group_set{ + graph.disjointExprSets().disjointSets().begin(), + graph.disjointExprSets().disjointSets().end()}; + std::unordered_set sorted_expr_group_set{ + sorted_expr_groups.begin(), sorted_expr_groups.end()}; + ASSERT_EQ(sorted_expr_group_set, ref_expr_group_set) << "Mismatched ExprGroups."; + + // Make sure sorted_val_groups covers all Val groups const std::unordered_set& ref_val_group_set{ graph.disjointValSets().disjointSets().begin(), graph.disjointValSets().disjointSets().end()}; - std::unordered_set sorted_val_group_set{ sorted_val_groups.begin(), sorted_val_groups.end()}; - ASSERT_EQ(sorted_val_group_set, ref_val_group_set) << "Mismatched ValGroups."; + + // Convert the expr order to an ExprGroup order. Maps ExprGroup to + // its dependent ExprGroups that must show up before + std::unordered_map ref_dependencies; + + for (const auto& [expr1, expr2]: ref_expr_orders) { + // expr1 must show up before expr2 + const ExprGroup& eg1 = graph.toGroup(expr1); + const ExprGroup& eg2 = graph.toGroup(expr2); + ref_dependencies[eg2].pushBack(eg1); + } + + ExprGroups visited_expr_groups; + for (const ExprGroup& eg: sorted_expr_groups) { + std::cerr << "Visiting " << nvfuser::toString(eg) << std::endl; + if (auto it = ref_dependencies.find(eg); + it != ref_dependencies.end()) { + const ExprGroups& dep_expr_groups = it->second; + // Make sure all dep_expr_groups have been visited + for (const ExprGroup& dep_expr_group : dep_expr_groups) { + ASSERT_TRUE(visited_expr_groups.has(dep_expr_group)) + << "Invalid ordering detected at " + << nvfuser::toString(eg) + << ". Dependent expr group not visited yet: " + << nvfuser::toString(dep_expr_group); + } + } + visited_expr_groups.pushBack(eg); + } } } // namespace @@ -892,24 +928,142 @@ TEST_F(IdModelTest, ValGraphStmtSort1) { auto tv2 = add(tv0, tv1); fusion.addOutput(tv2); + // No ID expr yet + { + IdModel id_model(&fusion); + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); + } + tv2->merge(0)->split(0, 4); TransformPropagator propagator(tv2); MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + { + IdModel id_model(&fusion); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + // Reference expr order: merge, split + std::vector> ref_sorted_exprs{ + {tv2->axis(0)->definition()->input(0)->definition(), + tv2->axis(0)->definition()}}; + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); + } +} + +// Sorting test wth a disconnected graph +TEST_F(IdModelTest, ValGraphStmtSort2) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Note that the two groups of tensors, {tv0, tv1} and {tv2, tv3}, + // are not connected + + for (auto tv: ir_utils::allTvs(&fusion)) { + tv->merge(0)->split(0, 4); + } + IdModel id_model(&fusion); const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); ValGraphStmtSort vg_stmt_sort(vg); - const ExprGroups& sorted_exprs = vg_stmt_sort.exprs(); // Reference expr order: merge, split - std::vector ref_sorted_exprs{ - tv2->axis(0)->definition()->input(0)->definition(), - tv2->axis(0)->definition()}; + std::vector> ref_sorted_exprs{ + {tv3->axis(0)->definition()->input(0)->definition(), + tv3->axis(0)->definition()}, + {tv1->axis(0)->definition()->input(0)->definition(), + tv1->axis(0)->definition()} + }; + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); + + // Since there's no dependency between tv1 and tv3, the reverse + // should be valid too + std::reverse(ref_sorted_exprs.begin(), ref_sorted_exprs.end()); + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); +} + +// Sorting with trivial ExprGroup, i.e., ExprGroup whose input and +// output are mapped as the same ValGroup. It's effectively a cyclic +// dependency and the graph is no longer a DAG. +TEST_F(IdModelTest, ValGraphStmtSort3) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + auto tv3 = makeSymbolicTensor(1); + fusion.addInput(tv3); + auto tv4 = set(tv3); + fusion.addOutput(tv4); + + + // In addition to the same schedules as done in the prior tests, + // does a split by one and later map the split input and the outer + // output as they should have the same extent. This is in fact done + // in the AlmostExact graph + for (auto tv: {tv0, tv1, tv2}) { + tv->merge(0)->split(0, 4); + tv->split(0, 1); + } + + // Also test an isolated trivial expr. Note that tv3 and tv4 are not + // connected with tv0, tv1 and tv2. + tv4->split(0, 1); + + fusion.print(); + + IdModel id_model(&fusion); + ValGraph vg = id_model.idGraph(IdMappingMode::EXACT); + + // Map the split-by-1 input and output + vg.mapVals(tv2->axis(0), tv2->axis(0)->definition()->input(0)); + vg.mapVals(tv4->axis(0), tv4->axis(0)->definition()->input(0)); + + ValGraphStmtSort vg_stmt_sort(vg); + + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); + +#if 0 + // Reference expr order: merge, split + std::vector> ref_sorted_exprs{ + {tv3->axis(0)->definition()->input(0)->definition(), + tv3->axis(0)->definition()}, + {tv1->axis(0)->definition()->input(0)->definition(), + tv1->axis(0)->definition()} + }; + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); + // Since there's no dependency between tv1 and tv3, the reverse + // should be valid too + std::reverse(ref_sorted_exprs.begin(), ref_sorted_exprs.end()); checkSortingResults( vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); +#endif } } // namespace nvfuser From 1eebde4818bd82b83fbcd4f29403e8c7758e1e4f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 17:02:43 -0800 Subject: [PATCH 139/178] Fix determinism bug --- csrc/id_model/id_model.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 6083dde8426..0614134c188 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -659,14 +659,29 @@ IterDomain* IdModel::cloneIterDomain(IterDomain* id) { ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { ValGraph id_graph(propagate_through_exprs); + // To deterministically initialize the graph, the order of adding + // domains must be deterministic. Here, we sort all IDs by their + // names. + + std::vector all_ids; + all_ids.reserve(id_definitions_.size()); for (const auto& [id, defs] : id_definitions_) { + all_ids.push_back(id); + } + + std::sort( + all_ids.begin(), all_ids.end(), [](IterDomain* id1, IterDomain* id2) { + return id1->name() < id2->name(); + }); + + for (auto id : all_ids) { auto uses_it = id_uses_.find(id); NVF_ERROR( uses_it != id_uses_.end(), "Failed to initialize id: ", id->toString(), " as it's missing a definition entry."); - id_graph.initializeVal(id, defs, uses_it->second); + id_graph.initializeVal(id, id_definitions_.at(id), uses_it->second); } return id_graph; From 0e727087cf01063706a430261ef65108035624d6 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 19:57:52 -0800 Subject: [PATCH 140/178] refactoring --- csrc/val_graph.cpp | 77 ++++++++++++++++++++++++++++++++++++++ csrc/val_graph.h | 4 ++ csrc/val_graph_visitor. | 1 + csrc/val_graph_visitor.cpp | 71 +++-------------------------------- 4 files changed, 87 insertions(+), 66 deletions(-) create mode 100644 csrc/val_graph_visitor. diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index f9fb9c4487a..7f4535be334 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -100,6 +100,83 @@ std::vector ValGraph::inputGroups(const ExprGroup& expr) const { return input_groups; } +ValGroups ValGraph::getTerminatingInputs(const VectorOfUniqueEntries& sub_selection) const { + // Initialize vals to traverse. If sub_selection is provided, only + // include the Val groups of the Vals in sub_selection + ValGroups all_vals; + if (sub_selection.empty()) { + all_vals = ValGroups( + disjointValSets().disjointSets().begin(), + disjointValSets().disjointSets().end()); + } else { + all_vals = toGroups(sub_selection); + } + + // Initialize exprs to traverse. If sub_selection is provided, + // only traverse exprs that are strictly contained within the provided + // sub_selection. Exprs are excluded if any of inputs or outputs + // is not in sub_selection. + ExprGroups all_exprs; + if (sub_selection.empty()) { + all_exprs = ExprGroups( + disjointExprSets().disjointSets().begin(), + disjointExprSets().disjointSets().end()); + } else { + for (const ValGroup& val_group : all_vals) { + for (const ExprGroup& def : getDefinitions(val_group)) { + if (all_exprs.has(def)) { + continue; + } + auto inp_groups = ValGroups(inputGroups(def)); + auto out_groups = ValGroups(outputGroups(def)); + if (inp_groups.computeSubtract(all_vals).empty() && + out_groups.computeSubtract(all_vals).empty()) { + all_exprs.pushBack(def); + } + } + } + } + + // Grab all vals that are not input, i.e., having a defining expr + // within all_exprs. + // + // Note that an input Val group may be mapped with an output + // group. For example, the AlmostExact graph maps an input of split + // with the outer output if the split factor is one. Such a Val + // group is considered a terminating input as long as the input has + // no defining expression. This is for the use case of + // ValGraphVisitor. + // + // Example: + // + // [i0, i1] + // split by 1 + // [i0/1, 1, i1] + // merge + // [i0/1, 1*i1] + // + // Here, i0 and i0/1 would create a Val group of {i0, i0/1} in the + // AlmostExact graph. This group has a defining expression of the + // split, but since it's a cyclic dependency, we ignore the + // expression and consider the Val group a terminating input. + + ValGroups not_inputs; + for (const ExprGroup& expr_group : all_exprs) { + const std::vector input_groups = inputGroups(expr_group); + const std::vector output_groups = outputGroups(expr_group); + std::unordered_set input_set{input_groups.begin(), input_groups.end()}; + + for (const ValGroup& output_group: output_groups) { + if (input_set.count(output_group)) { + continue; + } + not_inputs.pushBack(output_group); + } + } + + return all_vals.computeSubtract(not_inputs); +} + ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { DequeOfExprGroup to_visit; for (const ValGroup& of_val_group : of) { diff --git a/csrc/val_graph.h b/csrc/val_graph.h index c9570fc4b07..2a42f184fc0 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -122,6 +122,10 @@ class ValGraph { std::vector outputGroups(const ExprGroup& expr) const; std::vector inputGroups(const ExprGroup& expr) const; + // Return Val groups that have no definition exprs. If a set of Vals + // are provided, only the Vals in the set are considered. + ValGroups getTerminatingInputs(const VectorOfUniqueEntries& sub_selection = {}) const; + // Recursively traverses uses of the IdGroups in 'of' and returns all // ExprGroups that have a use in their definition of provided of IdGroups. ExprGroups allUsesOf(const ValGroups& of) const; diff --git a/csrc/val_graph_visitor. b/csrc/val_graph_visitor. new file mode 100644 index 00000000000..fa33c4e24c7 --- /dev/null +++ b/csrc/val_graph_visitor. @@ -0,0 +1 @@ +h \ No newline at end of file diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index fe259dce27d..b96e35216af 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -7,74 +7,12 @@ // clang-format on #include +#include + namespace nvfuser { void ValGraphVisitor::traverse() { - ValGroups all_vals; - ExprGroups all_exprs; - { - // Initialize Vals to traverse. If sub_selection is provided, only - // traverse Vals that are included in the set are traversed. - if (sub_selection_.empty()) { - all_vals = ValGroups( - graph().disjointValSets().disjointSets().begin(), - graph().disjointValSets().disjointSets().end()); - } else { - for (auto val : sub_selection_) { - if (graph().hasGroup(val)) { - all_vals.pushBack(graph().toGroup(val)); - } - } - } - - // Initialize exprs to traverse. If sub_selection is provided, - // only traverse exprs that are strictly contained within the provided - // sub_selection. Exprs are excluded if any of inputs or outputs - // is not in sub_selection. - if (sub_selection_.empty()) { - all_exprs = ExprGroups( - graph().disjointExprSets().disjointSets().begin(), - graph().disjointExprSets().disjointSets().end()); - } else { - for (const ValGroup& val_group : all_vals) { - for (const ExprGroup& def : graph().getDefinitions(val_group)) { - if (all_exprs.has(def)) { - continue; - } - auto inp_groups = ValGroups(graph().inputGroups(def)); - auto out_groups = ValGroups(graph().outputGroups(def)); - if (inp_groups.computeSubtract(all_vals).empty() && - out_groups.computeSubtract(all_vals).empty()) { - all_exprs.pushBack(def); - } - } - } - } - } - // There could be Vals in from or to that are between other from and - // to nodes. Make sure to clear those out. - ValGroups terminating_inputs; - ValGroups terminating_outputs; - - { - ValGroups not_inputs; - ValGroups not_outputs; - for (const ExprGroup& expr_group : all_exprs) { - if (graph().isTrivialExprGroup(expr_group)) { - // Expression is just a loop to its current group, ignore - continue; - } - - not_inputs.pushBack(graph().outputGroups(expr_group)); - not_outputs.pushBack(graph().inputGroups(expr_group)); - } - - terminating_inputs = all_vals.computeSubtract(not_inputs); - - terminating_outputs = all_vals.computeSubtract(not_outputs); - } - - ValGroups to_visit_ids = terminating_inputs; + ValGroups to_visit_ids = graph().getTerminatingInputs(sub_selection_); ValGroups visited_ids; ExprGroups to_visit_exprs; @@ -141,7 +79,8 @@ void ValGraphVisitor::traverse() { something_was_processed = true; visited_ids.pushBack(current_id_group); - if (!terminating_outputs.has(current_id_group)) { + //if (true || !terminating_outputs.has(current_id_group)) { + if (true) { const ExprGroups& uses = graph().getUses(current_id_group); to_visit_exprs.pushBack(uses); } From 7561d19071ae6231caaa32db45769a43a26a3b96 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 18:21:54 -0800 Subject: [PATCH 141/178] test cleanup --- test/test_id_model.cpp | 216 ++++++++++++++++++++++++----------------- 1 file changed, 127 insertions(+), 89 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 1d7d2e15b21..135c963b02e 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -47,6 +47,17 @@ TEST_F(IdModelTest, DetectSelfMapping) { namespace { +// Get n-th parent expr traversing through the first input of each +// parent +Expr* getParentExpr(Val* val, int n) { + for (int i = 0; i < n - 1; ++i) { + NVF_ERROR(val->definition() != nullptr); + val = val->definition()->input(0); + } + NVF_ERROR(val->definition() != nullptr); + return val->definition(); +}; + // Helper class to test IdModel class IdModelTester : public IdModel { public: @@ -846,28 +857,18 @@ TEST_F(IdModelTest, LoopPromotion8) { namespace { -// Check the results of ValGraphStmtSort. Ordering check is only -// implemented for ExprGroups for now as it's likely sufficient. +// Check the results of ValGraphStmtSort. Only the ordering of +// ExprGroups is checked for now as it's likely sufficient. // -// ref_expr_orders: a list of expr pairs. Each pair indicates the -// first expr must show up before the second expr. +// ref_order: The order must be exactly the +// same as indicated by this list. While there can be different +// order that still satisfy the topologial ordering, we also need +// deterministic ordering, so the results should be always the same. void checkSortingResults( const ValGraph& graph, const ExprGroups& sorted_expr_groups, const ValGroups& sorted_val_groups, - const std::vector>& ref_expr_orders) { - - { - std::cerr << "Sorted EG:\n"; - for (const auto& eg: sorted_expr_groups) { - std::cerr << nvfuser::toString(eg) << ": " << eg->front()->toString(); - } - std::cerr << "All EGs:\n"; - for (const auto& eg: graph.disjointExprSets().disjointSets()) { - std::cerr << nvfuser::toString(eg) << ": " << eg->front()->toString(); - } - } - + const std::vector& ref_order) { // Make sure sorted_val_groups cover all Expr groups const std::unordered_set& ref_expr_group_set{ graph.disjointExprSets().disjointSets().begin(), @@ -884,33 +885,16 @@ void checkSortingResults( sorted_val_groups.begin(), sorted_val_groups.end()}; ASSERT_EQ(sorted_val_group_set, ref_val_group_set) << "Mismatched ValGroups."; - // Convert the expr order to an ExprGroup order. Maps ExprGroup to - // its dependent ExprGroups that must show up before - std::unordered_map ref_dependencies; - - for (const auto& [expr1, expr2]: ref_expr_orders) { - // expr1 must show up before expr2 - const ExprGroup& eg1 = graph.toGroup(expr1); - const ExprGroup& eg2 = graph.toGroup(expr2); - ref_dependencies[eg2].pushBack(eg1); - } - - ExprGroups visited_expr_groups; - for (const ExprGroup& eg: sorted_expr_groups) { - std::cerr << "Visiting " << nvfuser::toString(eg) << std::endl; - if (auto it = ref_dependencies.find(eg); - it != ref_dependencies.end()) { - const ExprGroups& dep_expr_groups = it->second; - // Make sure all dep_expr_groups have been visited - for (const ExprGroup& dep_expr_group : dep_expr_groups) { - ASSERT_TRUE(visited_expr_groups.has(dep_expr_group)) - << "Invalid ordering detected at " - << nvfuser::toString(eg) - << ". Dependent expr group not visited yet: " - << nvfuser::toString(dep_expr_group); - } - } - visited_expr_groups.pushBack(eg); + // Check the ordering + ASSERT_EQ(sorted_expr_groups.size(), ref_order.size()); + for (const auto i : c10::irange(ref_order.size())) { + Expr* ref_expr = ref_order.at(i); + const ExprGroup& eg = sorted_expr_groups.at(i); + ASSERT_TRUE(eg->has(ref_expr)) + << "Expected: " + << nvfuser::toString(graph.toGroup(ref_expr)) + << ". Actual: " + << nvfuser::toString(eg); } } @@ -947,12 +931,14 @@ TEST_F(IdModelTest, ValGraphStmtSort1) { const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); ValGraphStmtSort vg_stmt_sort(vg); + // Reference expr order: merge, split - std::vector> ref_sorted_exprs{ - {tv2->axis(0)->definition()->input(0)->definition(), - tv2->axis(0)->definition()}}; + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); } } @@ -983,21 +969,14 @@ TEST_F(IdModelTest, ValGraphStmtSort2) { const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); ValGraphStmtSort vg_stmt_sort(vg); - // Reference expr order: merge, split - std::vector> ref_sorted_exprs{ - {tv3->axis(0)->definition()->input(0)->definition(), - tv3->axis(0)->definition()}, - {tv1->axis(0)->definition()->input(0)->definition(), - tv1->axis(0)->definition()} - }; - checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); + std::vector ref_order; + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); + ref_order.push_back(getParentExpr(tv3->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + ref_order.push_back(getParentExpr(tv3->axis(0), 1)); - // Since there's no dependency between tv1 and tv3, the reverse - // should be valid too - std::reverse(ref_sorted_exprs.begin(), ref_sorted_exprs.end()); checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); } // Sorting with trivial ExprGroup, i.e., ExprGroup whose input and @@ -1014,26 +993,19 @@ TEST_F(IdModelTest, ValGraphStmtSort3) { auto tv2 = add(tv0, tv1); fusion.addOutput(tv2); - auto tv3 = makeSymbolicTensor(1); + auto tv3 = makeSymbolicTensor(2); fusion.addInput(tv3); auto tv4 = set(tv3); fusion.addOutput(tv4); - - // In addition to the same schedules as done in the prior tests, - // does a split by one and later map the split input and the outer - // output as they should have the same extent. This is in fact done - // in the AlmostExact graph + // Merge adn split by one. The split input and output will be mapped. for (auto tv: {tv0, tv1, tv2}) { - tv->merge(0)->split(0, 4); - tv->split(0, 1); + tv->merge(0)->split(0, 1); } // Also test an isolated trivial expr. Note that tv3 and tv4 are not // connected with tv0, tv1 and tv2. - tv4->split(0, 1); - - fusion.print(); + tv4->merge(0)->split(0, 1); IdModel id_model(&fusion); ValGraph vg = id_model.idGraph(IdMappingMode::EXACT); @@ -1044,26 +1016,92 @@ TEST_F(IdModelTest, ValGraphStmtSort3) { ValGraphStmtSort vg_stmt_sort(vg); + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv4->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + ref_order.push_back(getParentExpr(tv4->axis(0), 1)); + checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); - -#if 0 - // Reference expr order: merge, split - std::vector> ref_sorted_exprs{ - {tv3->axis(0)->definition()->input(0)->definition(), - tv3->axis(0)->definition()}, - {tv1->axis(0)->definition()->input(0)->definition(), - tv1->axis(0)->definition()} - }; - checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + +// Sorting test with the same fusion as Indexing19 +TEST_F(IdModelTest, ValGraphStmtSort4) { + auto fusion = createFusionWithMultipleResolutionPaths(); + FusionGuard fg(fusion.get()); + auto all_tvs = ir_utils::allTvs(fusion.get()); + + IdModel id_model(fusion.get(), true, false, false); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + + ValGraphStmtSort vg_stmt_sort(vg); + + auto tv1 = findTensorByName(all_tvs, 1); + auto tv2 = findTensorByName(all_tvs, 2); + auto tv4 = findTensorByName(all_tvs, 4); + auto tv5 = findTensorByName(all_tvs, 5); + auto tv6 = findTensorByName(all_tvs, 6); + auto tv8 = findTensorByName(all_tvs, 8); + auto tv9 = findTensorByName(all_tvs, 9); + auto tv10 = findTensorByName(all_tvs, 10); + + // Expected reference order: + // + // exprg{39}: Merge: iS2{7} and bS3{1} -> iS46{( 7 * 1 )} + // exprg{57}: Merge: iS11{7} and bS12{1} -> iS61{( 7 * 1 )} + // exprg{17}: Merge: iS17{7} and bS18{1} -> iS29{( 7 * 1 )} + // exprg{69 73 89}: Split: iS1{7} by factor 5 -> iS71{( ceilDiv(7, 5) )}, iS72{5}, start offset: 0, stop offset: 0 + // exprg{51 63 93}: Merge: iS15{7} and iS16{13} -> iS56{( 7 * 13 )} + // exprg{9 25 33 45 91 95}: Merge: iS20{7} and iS21{11} -> iS23{( 7 * 11 )} + // exprg{27}: Merge: iS35{( 7 * 11 )} and bS10{1} -> iS36{( ( 7 * 11 ) * 1 )} + // exprg{19}: Merge: iS29{( 7 * 1 )} and iS19{13} -> iS30{( ( 7 * 1 ) * 13 )} + // exprg{11 77 79 99}: Merge: iS23{( 7 * 11 )} and iS22{13} -> iS24{( ( 7 * 11 ) * 13 )} + // exprg{41}: Split: iS46{( 7 * 1 )} by factor 5 -> iS47{( ceilDiv(( 7 * 1 ), 5) )}, iS48{5}, start offset: 0, stop offset: 0 + // exprg{59}: Split: iS61{( 7 * 1 )} by factor 5 -> iS62{( ceilDiv(( 7 * 1 ), 5) )}, iS63{5}, start offset: 0, stop offset: 0 + // exprg{71 75 101}: Split: iS71{( ceilDiv(7, 5) )} by factor 3 -> iS73{( ceilDiv(( ceilDiv(7, 5) ), 3) )}, iS74{3}, start offset: 0, stop offset: 0 + // exprg{53 65 109}: Split: iS56{( 7 * 13 )} by factor 5 -> iS57{( ceilDiv(( 7 * 13 ), 5) )}, iS58{5}, start offset: 0, stop offset: 0 + // exprg{35 47 105}: Split: iS41{( 7 * 11 )} by factor 5 -> iS42{( ceilDiv(( 7 * 11 ), 5) )}, iS43{5}, start offset: 0, stop offset: 0 + // exprg{29}: Split: iS36{( ( 7 * 11 ) * 1 )} by factor 5 -> iS37{( ceilDiv(( ( 7 * 11 ) * 1 ), 5) )}, iS38{5}, start offset: 0, stop offset: 0 + // exprg{21}: Split: iS30{( ( 7 * 1 ) * 13 )} by factor 5 -> iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )}, iS32{5}, start offset: 0, stop offset: 0 + // exprg{13 81 83 97 103 107 111 115 117 119 121}: Split: iS24{( ( 7 * 11 ) * 13 )} by factor 5 -> iS25{( ceilDiv(( ( 7 * 11 ) * 13 ), 5) )}, iS26{5}, start offset: 0, stop offset: 0 + // exprg{43}: Split: iS47{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS49{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS50{3}, start offset: 0, stop offset: 0 + // exprg{61}: Split: iS62{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS64{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3}, start offset: 0, stop offset: 0 + // exprg{55 67 129}: Split: iS57{( ceilDiv(( 7 * 13 ), 5) )} by factor 3 -> iS59{( ceilDiv(( ceilDiv(( 7 * 13 ), 5) ), 3) )}, iS60{3}, start offset: 0, stop offset: 0 + // exprg{37 49 125}: Split: iS42{( ceilDiv(( 7 * 11 ), 5) )} by factor 3 -> iS44{( ceilDiv(( ceilDiv(( 7 * 11 ), 5) ), 3) )}, iS45{3}, start offset: 0, stop offset: 0 + // exprg{31}: Split: iS37{( ceilDiv(( ( 7 * 11 ) * 1 ), 5) )} by factor 3 -> iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), 5) ), 3) )}, iS40{3}, start offset: 0, stop offset: 0 + // exprg{23}: Split: iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )} by factor 3 -> iS33{( ceilDiv(( ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )}, iS34{3}, start offset: 0, stop offset: 0 + // exprg{15 85 87 113 123 127 131 133 135 137 139}: Split: iS25{( ceilDiv(( ( 7 * 11 ) * 13 ), 5) )} by factor 3 -> iS27{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 13 ), 5) ), 3) )}, iS28{3}, start offset: 0, stop offset: 0 + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 3)); + ref_order.push_back(getParentExpr(tv6->axis(0), 3)); + ref_order.push_back(getParentExpr(tv9->axis(0), 4)); + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); + ref_order.push_back(getParentExpr(tv8->axis(0), 3)); + ref_order.push_back(getParentExpr(tv10->axis(0), 4)); + ref_order.push_back(getParentExpr(tv5->axis(0), 3)); + ref_order.push_back(getParentExpr(tv9->axis(0), 3)); + ref_order.push_back(getParentExpr(tv10->axis(0), 3)); + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv6->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + ref_order.push_back(getParentExpr(tv8->axis(0), 2)); + ref_order.push_back(getParentExpr(tv4->axis(0), 2)); + ref_order.push_back(getParentExpr(tv5->axis(0), 2)); + ref_order.push_back(getParentExpr(tv9->axis(0), 2)); + ref_order.push_back(getParentExpr(tv10->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + ref_order.push_back(getParentExpr(tv6->axis(0), 1)); + ref_order.push_back(getParentExpr(tv8->axis(0), 1)); + ref_order.push_back(getParentExpr(tv4->axis(0), 1)); + ref_order.push_back(getParentExpr(tv5->axis(0), 1)); + ref_order.push_back(getParentExpr(tv9->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 1)); - // Since there's no dependency between tv1 and tv3, the reverse - // should be valid too - std::reverse(ref_sorted_exprs.begin(), ref_sorted_exprs.end()); checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_sorted_exprs); -#endif + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); } } // namespace nvfuser From 48e3019e882cb697cea426c95e76c6781dab4d5a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 20:55:30 -0800 Subject: [PATCH 142/178] cleanup --- csrc/val_graph_visitor.cpp | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index b96e35216af..5588ee6522f 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -19,19 +19,31 @@ void ValGraphVisitor::traverse() { ExprGroups visited_exprs; auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { - auto inp_groups = graph().inputGroups(expr_group); + const auto inp_groups = graph().inputGroups(expr_group); return std::all_of( inp_groups.begin(), inp_groups.end(), [&](ValGroup id_group) { return visited_ids.has(id_group) || id_group->empty(); }); }; + auto is_output_mapped_with_all_inputs = + [&](const ValGroup& output_group, const ExprGroup& expr_group) -> bool { + const auto inp_groups = graph().inputGroups(expr_group); + return std::all_of( + inp_groups.begin(), inp_groups.end(), [&](const ValGroup& inp_group) { + return inp_group == output_group; + }); + }; + auto is_val_ready = [&](const ValGroup& val_group) -> bool { const ExprGroups& unique_defs = graph().getDefinitions(val_group); + return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { + // If all the inputs of the def expr are mapped with the val + // group itself, it should be ready to visit. return expr_group->empty() || visited_exprs.has(expr_group) || - graph().isTrivialExprGroup(expr_group); + is_output_mapped_with_all_inputs(val_group, expr_group); }); }; @@ -79,15 +91,12 @@ void ValGraphVisitor::traverse() { something_was_processed = true; visited_ids.pushBack(current_id_group); - //if (true || !terminating_outputs.has(current_id_group)) { - if (true) { - const ExprGroups& uses = graph().getUses(current_id_group); - to_visit_exprs.pushBack(uses); - } + to_visit_exprs.pushBack(graph().getUses(current_id_group)); } else { still_to_visit_ids.pushBack(current_id_group); } } + std::swap(to_visit_ids, still_to_visit_ids); NVF_ERROR( From a828f6fbc4c406cc6f818853ded379887b77b50f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Feb 2024 21:49:57 -0800 Subject: [PATCH 143/178] Simplify by removing sub_selection as it's not used --- csrc/val_graph.cpp | 42 ++++++++------------------------------ csrc/val_graph.h | 5 ++--- csrc/val_graph_visitor.cpp | 14 ++++++++++++- csrc/val_graph_visitor.h | 16 ++------------- 4 files changed, 25 insertions(+), 52 deletions(-) diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 7f4535be334..5c2fc12c10e 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -100,42 +100,16 @@ std::vector ValGraph::inputGroups(const ExprGroup& expr) const { return input_groups; } -ValGroups ValGraph::getTerminatingInputs(const VectorOfUniqueEntries& sub_selection) const { - // Initialize vals to traverse. If sub_selection is provided, only - // include the Val groups of the Vals in sub_selection - ValGroups all_vals; - if (sub_selection.empty()) { - all_vals = ValGroups( +ValGroups ValGraph::getTerminatingInputs() const { + // Initialize vals to traverse + ValGroups all_vals{ disjointValSets().disjointSets().begin(), - disjointValSets().disjointSets().end()); - } else { - all_vals = toGroups(sub_selection); - } - - // Initialize exprs to traverse. If sub_selection is provided, - // only traverse exprs that are strictly contained within the provided - // sub_selection. Exprs are excluded if any of inputs or outputs - // is not in sub_selection. - ExprGroups all_exprs; - if (sub_selection.empty()) { - all_exprs = ExprGroups( + disjointValSets().disjointSets().end()}; + + // Initialize exprs to traverse + ExprGroups all_exprs{ disjointExprSets().disjointSets().begin(), - disjointExprSets().disjointSets().end()); - } else { - for (const ValGroup& val_group : all_vals) { - for (const ExprGroup& def : getDefinitions(val_group)) { - if (all_exprs.has(def)) { - continue; - } - auto inp_groups = ValGroups(inputGroups(def)); - auto out_groups = ValGroups(outputGroups(def)); - if (inp_groups.computeSubtract(all_vals).empty() && - out_groups.computeSubtract(all_vals).empty()) { - all_exprs.pushBack(def); - } - } - } - } + disjointExprSets().disjointSets().end()}; // Grab all vals that are not input, i.e., having a defining expr // within all_exprs. diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 2a42f184fc0..fd2e3695031 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -122,9 +122,8 @@ class ValGraph { std::vector outputGroups(const ExprGroup& expr) const; std::vector inputGroups(const ExprGroup& expr) const; - // Return Val groups that have no definition exprs. If a set of Vals - // are provided, only the Vals in the set are considered. - ValGroups getTerminatingInputs(const VectorOfUniqueEntries& sub_selection = {}) const; + // Return Val groups that have no definition exprs. + ValGroups getTerminatingInputs() const; // Recursively traverses uses of the IdGroups in 'of' and returns all // ExprGroups that have a use in their definition of provided of IdGroups. diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 5588ee6522f..41c49a06646 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -12,12 +12,16 @@ namespace nvfuser { void ValGraphVisitor::traverse() { - ValGroups to_visit_ids = graph().getTerminatingInputs(sub_selection_); + ValGroups to_visit_ids = graph().getTerminatingInputs(); ValGroups visited_ids; ExprGroups to_visit_exprs; ExprGroups visited_exprs; + for (const auto& idg : to_visit_ids) { + std::cerr << "Initial IDs: " << nvfuser::toString(idg) << std::endl; + } + auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { const auto inp_groups = graph().inputGroups(expr_group); return std::all_of( @@ -66,6 +70,9 @@ void ValGraphVisitor::traverse() { if (is_expr_ready(current_expr_group)) { handle(current_expr_group); + std::cerr << "EG: " << nvfuser::toString(current_expr_group) + << std::endl; + something_was_processed = true; visited_exprs.pushBack(current_expr_group); @@ -88,11 +95,16 @@ void ValGraphVisitor::traverse() { if (is_val_ready(current_id_group)) { handle(current_id_group); + std::cerr << "IDG: " << nvfuser::toString(current_id_group) + << std::endl; + something_was_processed = true; visited_ids.pushBack(current_id_group); to_visit_exprs.pushBack(graph().getUses(current_id_group)); } else { + std::cerr << "NOT READY IDG: " << nvfuser::toString(current_id_group) + << std::endl; still_to_visit_ids.pushBack(current_id_group); } } diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h index c10f70a192a..5b5d3df90d7 100644 --- a/csrc/val_graph_visitor.h +++ b/csrc/val_graph_visitor.h @@ -16,9 +16,6 @@ namespace nvfuser { // Iterates through a Val Graph in topological order, calling handle on // all Val and all Expr groups in a forward topological order. // -// Warning: Expr groups that have an input and output in the same ValGroup are -// ignored. -// // Warning: This is not a great iterator if there's a desire to minimize paths // traveled to simply visit all ValGroups in order. See ExprsBetween to see how // we might minimize paths. @@ -33,12 +30,7 @@ class ValGraphVisitor { virtual ~ValGraphVisitor() = default; protected: - // If sub_selection is assumed to be a set of vals by which form a - // sub-regrion of the ValGraph provided. Only that sub-region will be visited. - ValGraphVisitor( - const ValGraph& val_graph, - const VectorOfUniqueEntries sub_selection = {}) - : val_graph_(val_graph), sub_selection_(sub_selection) {} + ValGraphVisitor(const ValGraph& val_graph) : val_graph_(val_graph) {} ValGraphVisitor(const ValGraphVisitor& other) = default; @@ -55,16 +47,12 @@ class ValGraphVisitor { private: const ValGraph& val_graph_; - const VectorOfUniqueEntries sub_selection_; }; // Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor. class ValGraphStmtSort : public ValGraphVisitor { public: - ValGraphStmtSort( - const ValGraph& val_graph, - const VectorOfUniqueEntries sub_selection = {}) - : ValGraphVisitor(val_graph, sub_selection) { + ValGraphStmtSort(const ValGraph& val_graph) : ValGraphVisitor(val_graph) { ValGraphVisitor::traverse(); } From fa99fe2b6a90be570fbdd60f8da675206f1b69e8 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 2 Feb 2024 18:30:44 +0000 Subject: [PATCH 144/178] fix for trivial exprs --- csrc/val_graph_visitor.cpp | 49 ++++++++++++++++++-------------------- csrc/val_graph_visitor.h | 35 +++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 26 deletions(-) diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 41c49a06646..498a56029a3 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -12,16 +12,13 @@ namespace nvfuser { void ValGraphVisitor::traverse() { - ValGroups to_visit_ids = graph().getTerminatingInputs(); + const ValGroups terminating_inputs = graph().getTerminatingInputs(); + ValGroups to_visit_ids = terminating_inputs; ValGroups visited_ids; ExprGroups to_visit_exprs; ExprGroups visited_exprs; - for (const auto& idg : to_visit_ids) { - std::cerr << "Initial IDs: " << nvfuser::toString(idg) << std::endl; - } - auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { const auto inp_groups = graph().inputGroups(expr_group); return std::all_of( @@ -30,24 +27,32 @@ void ValGraphVisitor::traverse() { }); }; - auto is_output_mapped_with_all_inputs = - [&](const ValGroup& output_group, const ExprGroup& expr_group) -> bool { - const auto inp_groups = graph().inputGroups(expr_group); - return std::all_of( - inp_groups.begin(), inp_groups.end(), [&](const ValGroup& inp_group) { - return inp_group == output_group; - }); - }; - + // If any input of the def expr is mapped with the val + // group itself, i.e., a trivial expr, allow visiting the + // val group first. The trivial expr group will be visited + // after the val group. + // + // Example: + // + // [i0, 1] + // merge + // [i0*1] + // map i0 and i0*1 + // ValGroups: {{i0, i0*1}, {1}} + // + // Then, {i0, i0*1} and {1} would be visited first, then the merge + // expr group would be visited. {i0, i0*1} is also an output group + // of the merge but since it's already in the visited set, it would + // not be visited again. + // + // See also IdModelTest.ValGraphStmtSort3 for a concrete example. auto is_val_ready = [&](const ValGroup& val_group) -> bool { const ExprGroups& unique_defs = graph().getDefinitions(val_group); - return std::all_of( unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { - // If all the inputs of the def expr are mapped with the val - // group itself, it should be ready to visit. return expr_group->empty() || visited_exprs.has(expr_group) || - is_output_mapped_with_all_inputs(val_group, expr_group); + terminating_inputs.has(val_group) || + graph().isTrivialExprGroup(expr_group); }); }; @@ -70,9 +75,6 @@ void ValGraphVisitor::traverse() { if (is_expr_ready(current_expr_group)) { handle(current_expr_group); - std::cerr << "EG: " << nvfuser::toString(current_expr_group) - << std::endl; - something_was_processed = true; visited_exprs.pushBack(current_expr_group); @@ -95,16 +97,11 @@ void ValGraphVisitor::traverse() { if (is_val_ready(current_id_group)) { handle(current_id_group); - std::cerr << "IDG: " << nvfuser::toString(current_id_group) - << std::endl; - something_was_processed = true; visited_ids.pushBack(current_id_group); to_visit_exprs.pushBack(graph().getUses(current_id_group)); } else { - std::cerr << "NOT READY IDG: " << nvfuser::toString(current_id_group) - << std::endl; still_to_visit_ids.pushBack(current_id_group); } } diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h index 5b5d3df90d7..391c08902f3 100644 --- a/csrc/val_graph_visitor.h +++ b/csrc/val_graph_visitor.h @@ -16,6 +16,41 @@ namespace nvfuser { // Iterates through a Val Graph in topological order, calling handle on // all Val and all Expr groups in a forward topological order. // +// Warning: A ValGraph is not guaranteed to be a DAG. In fact, the +// AlmostExact and Permissive graphs would have cycles with a ValGroup +// and an ExprGroup. For example: +// +// [i0, 1] +// merge +// [i0*1] +// Current ValGroups: {{i0}, {1}, {i0*1}} +// map i0 and i0*1 as they effectively have the same extent +// Final ValGroups: {{i0, i0*1}, {1}} +// +// Here, the merge expr is the user of i0 and the definition of +// i0*1. Since i0 and i0*1 are mapped, the dependency chain looks +// like: +// +// {i0, i0*1} ----> {merge} ----> {i0, i0*1} +// use def +// +// These ExprGroups are called trivial ExprGroups (see also +// ValGraph::isTrivialExprGroup). +// +// Strictly speaking, these cycles mean there's no valid topological +// order anymore. In our use cases for IdModel, however, it's likely +// sufficient to return an ordering such as: +// +// {i0, i0*1} -> {merge} +// +// I.e., we visit {i0, i0*1} first even though {merge} is technically +// a definition. +// +// Another alternative may be simply giving up when such a cycle is +// detected, which may be more preferrable as it would be less +// confusing. At this moment, this visitor is only used with graphs +// with no such cycle. Should be revisited when necessary. +// // Warning: This is not a great iterator if there's a desire to minimize paths // traveled to simply visit all ValGroups in order. See ExprsBetween to see how // we might minimize paths. From 2321f3e30819f8c458e78bcf575248a7444f0796 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 2 Feb 2024 18:32:51 +0000 Subject: [PATCH 145/178] Accidentally added --- csrc/val_graph_visitor.h | 121 --------------------------------------- 1 file changed, 121 deletions(-) delete mode 100644 csrc/val_graph_visitor.h diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h deleted file mode 100644 index 391c08902f3..00000000000 --- a/csrc/val_graph_visitor.h +++ /dev/null @@ -1,121 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include -#include - -namespace nvfuser { - -// Iterates through a Val Graph in topological order, calling handle on -// all Val and all Expr groups in a forward topological order. -// -// Warning: A ValGraph is not guaranteed to be a DAG. In fact, the -// AlmostExact and Permissive graphs would have cycles with a ValGroup -// and an ExprGroup. For example: -// -// [i0, 1] -// merge -// [i0*1] -// Current ValGroups: {{i0}, {1}, {i0*1}} -// map i0 and i0*1 as they effectively have the same extent -// Final ValGroups: {{i0, i0*1}, {1}} -// -// Here, the merge expr is the user of i0 and the definition of -// i0*1. Since i0 and i0*1 are mapped, the dependency chain looks -// like: -// -// {i0, i0*1} ----> {merge} ----> {i0, i0*1} -// use def -// -// These ExprGroups are called trivial ExprGroups (see also -// ValGraph::isTrivialExprGroup). -// -// Strictly speaking, these cycles mean there's no valid topological -// order anymore. In our use cases for IdModel, however, it's likely -// sufficient to return an ordering such as: -// -// {i0, i0*1} -> {merge} -// -// I.e., we visit {i0, i0*1} first even though {merge} is technically -// a definition. -// -// Another alternative may be simply giving up when such a cycle is -// detected, which may be more preferrable as it would be less -// confusing. At this moment, this visitor is only used with graphs -// with no such cycle. Should be revisited when necessary. -// -// Warning: This is not a great iterator if there's a desire to minimize paths -// traveled to simply visit all ValGroups in order. See ExprsBetween to see how -// we might minimize paths. -class ValGraphVisitor { - public: - ValGraphVisitor() = delete; - - ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete; - - ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete; - - virtual ~ValGraphVisitor() = default; - - protected: - ValGraphVisitor(const ValGraph& val_graph) : val_graph_(val_graph) {} - - ValGraphVisitor(const ValGraphVisitor& other) = default; - - ValGraphVisitor(ValGraphVisitor&& other) = default; - - virtual void handle(const ValGroup& id_group) = 0; - virtual void handle(const ExprGroup& expr_group) = 0; - - void traverse(); - - const ValGraph& graph() { - return val_graph_; - }; - - private: - const ValGraph& val_graph_; -}; - -// Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor. -class ValGraphStmtSort : public ValGraphVisitor { - public: - ValGraphStmtSort(const ValGraph& val_graph) : ValGraphVisitor(val_graph) { - ValGraphVisitor::traverse(); - } - - // Return non-reference so that code like below can work - // for (auto expr_group: IdGraphStmtSort(graph).exprs()) - ExprGroups exprs() const { - return sorted_exprs_; - } - - ValGroups vals() const { - return sorted_vals_; - } - - ~ValGraphStmtSort() override = default; - - protected: - using ValGraphVisitor::handle; - - void handle(const ValGroup& val_group) override { - sorted_vals_.pushBack(val_group); - } - - void handle(const ExprGroup& expr_group) override { - sorted_exprs_.pushBack(expr_group); - } - - ExprGroups sorted_exprs_; - ValGroups sorted_vals_; -}; - -} // namespace nvfuser From ad3da55e70101f8ebd7c3a9bc6ab543cc5009674 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 2 Feb 2024 18:56:38 +0000 Subject: [PATCH 146/178] Revert "Accidentally added" This reverts commit 2321f3e30819f8c458e78bcf575248a7444f0796. --- csrc/val_graph_visitor.h | 121 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 csrc/val_graph_visitor.h diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h new file mode 100644 index 00000000000..391c08902f3 --- /dev/null +++ b/csrc/val_graph_visitor.h @@ -0,0 +1,121 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Iterates through a Val Graph in topological order, calling handle on +// all Val and all Expr groups in a forward topological order. +// +// Warning: A ValGraph is not guaranteed to be a DAG. In fact, the +// AlmostExact and Permissive graphs would have cycles with a ValGroup +// and an ExprGroup. For example: +// +// [i0, 1] +// merge +// [i0*1] +// Current ValGroups: {{i0}, {1}, {i0*1}} +// map i0 and i0*1 as they effectively have the same extent +// Final ValGroups: {{i0, i0*1}, {1}} +// +// Here, the merge expr is the user of i0 and the definition of +// i0*1. Since i0 and i0*1 are mapped, the dependency chain looks +// like: +// +// {i0, i0*1} ----> {merge} ----> {i0, i0*1} +// use def +// +// These ExprGroups are called trivial ExprGroups (see also +// ValGraph::isTrivialExprGroup). +// +// Strictly speaking, these cycles mean there's no valid topological +// order anymore. In our use cases for IdModel, however, it's likely +// sufficient to return an ordering such as: +// +// {i0, i0*1} -> {merge} +// +// I.e., we visit {i0, i0*1} first even though {merge} is technically +// a definition. +// +// Another alternative may be simply giving up when such a cycle is +// detected, which may be more preferrable as it would be less +// confusing. At this moment, this visitor is only used with graphs +// with no such cycle. Should be revisited when necessary. +// +// Warning: This is not a great iterator if there's a desire to minimize paths +// traveled to simply visit all ValGroups in order. See ExprsBetween to see how +// we might minimize paths. +class ValGraphVisitor { + public: + ValGraphVisitor() = delete; + + ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete; + + ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete; + + virtual ~ValGraphVisitor() = default; + + protected: + ValGraphVisitor(const ValGraph& val_graph) : val_graph_(val_graph) {} + + ValGraphVisitor(const ValGraphVisitor& other) = default; + + ValGraphVisitor(ValGraphVisitor&& other) = default; + + virtual void handle(const ValGroup& id_group) = 0; + virtual void handle(const ExprGroup& expr_group) = 0; + + void traverse(); + + const ValGraph& graph() { + return val_graph_; + }; + + private: + const ValGraph& val_graph_; +}; + +// Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor. +class ValGraphStmtSort : public ValGraphVisitor { + public: + ValGraphStmtSort(const ValGraph& val_graph) : ValGraphVisitor(val_graph) { + ValGraphVisitor::traverse(); + } + + // Return non-reference so that code like below can work + // for (auto expr_group: IdGraphStmtSort(graph).exprs()) + ExprGroups exprs() const { + return sorted_exprs_; + } + + ValGroups vals() const { + return sorted_vals_; + } + + ~ValGraphStmtSort() override = default; + + protected: + using ValGraphVisitor::handle; + + void handle(const ValGroup& val_group) override { + sorted_vals_.pushBack(val_group); + } + + void handle(const ExprGroup& expr_group) override { + sorted_exprs_.pushBack(expr_group); + } + + ExprGroups sorted_exprs_; + ValGroups sorted_vals_; +}; + +} // namespace nvfuser From e35fb699f0bcb7bddc39c90d79ccc77fdeff9a36 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 2 Feb 2024 18:57:08 +0000 Subject: [PATCH 147/178] Accidentally added --- csrc/val_graph_visitor. | 1 - 1 file changed, 1 deletion(-) delete mode 100644 csrc/val_graph_visitor. diff --git a/csrc/val_graph_visitor. b/csrc/val_graph_visitor. deleted file mode 100644 index fa33c4e24c7..00000000000 --- a/csrc/val_graph_visitor. +++ /dev/null @@ -1 +0,0 @@ -h \ No newline at end of file From 9727240d23d6397a42b131e299c01f4e9b9eedea Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 13 Feb 2024 12:52:43 -0800 Subject: [PATCH 148/178] Repro for the compliment mapping issue --- test/test_id_model.cpp | 81 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 135c963b02e..0c5cd3238f1 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -1100,8 +1100,85 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { ref_order.push_back(getParentExpr(tv9->axis(0), 1)); ref_order.push_back(getParentExpr(tv10->axis(0), 1)); - checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + +// A repro that produces an invalid loop graph due to the compliment +// mapping. This is not currently supported. +TEST_F(IdModelTest, ComplimentMappingRepro) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({7}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({7, 8}); + fusion.addInput(tv1); + auto tv2 = makeConcreteTensor({7, 9}); + fusion.addInput(tv2); + + auto tv3 = broadcast(tv0, {false, true}); + auto tv4 = add(tv1, tv3); + auto tv5 = broadcast(tv4, {false, false, true}); + + auto tv6 = broadcast(tv0, {false, true}); + auto tv7 = add(tv2, tv6); + auto tv8 = broadcast(tv7, {false, true, false}); + + auto tv9 = add(tv5, tv8); + + auto tv10 = set(tv9); + auto tv11 = set(tv10); + fusion.addOutput(tv11); + + // Merge all domains except for tv10 and tv11 + for (auto tv : ir_utils::allTvs(&fusion)) { + if (tv == tv10 || tv == tv11) { + continue; + } + while (tv->nDims() > 1) { + tv->merge(0); + } + } + + // Fully inline all tensors up until tv10 + for (auto tv : ir_utils::allTvs(&fusion)) { + if (tv == tv9 || tv == tv10 || tv == tv11) { + continue; + } + tv->inlineAt(1); + } + + // Fully inline tv10 to tv11 without merging + tv10->inlineAt(-1); + + IdModel id_model(&fusion, true, false, false); + + const ValGraph& loop_graph = id_model.idGraph(IdMappingMode::LOOP); + + // Due to the compliment mapping, the leaf domains of tv10 and tv11 + // are loop mapped, which is invalid. + // + // Specifically, here are the tv10 and tv11 tensors: + // + // T10_l[ iS22{7}, iS23{8}, iS24{9} ] ca_pos( 3 ) + // root domain : (iS22{7}, iS23{8}, iS24{9}) + // contiguity: t t t + // leaf domain : (iS22{7}, iS23{8}, iS24{9}) + // T11_g[ iS25{7}, iS26{8}, iS27{9} ] produce_pos( 3 ) + // root domain : (iS25{7}, iS26{8}, iS27{9}) + // contiguity: t t t + // leaf domain : (iS25{7}, iS26{8}, iS27{9}) + // + // Here's the loop graph for tv10 and tv11: + // idg{22 23 24 25 26 27} + + // These assertions should fail at this moment. + ASSERT_NE( + loop_graph.toGroup(tv10->axis(0)), loop_graph.toGroup(tv10->axis(1))); + ASSERT_NE( + loop_graph.toGroup(tv10->axis(0)), loop_graph.toGroup(tv10->axis(2))); + ASSERT_NE( + loop_graph.toGroup(tv10->axis(1)), loop_graph.toGroup(tv10->axis(2))); } } // namespace nvfuser From 4028dea1b6b43010762df0eb36071729558d4373 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 13 Feb 2024 12:53:03 -0800 Subject: [PATCH 149/178] clang-format --- test/test_id_model.cpp | 82 ++++++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 0c5cd3238f1..d3ba5c202ab 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -875,7 +875,8 @@ void checkSortingResults( graph.disjointExprSets().disjointSets().end()}; std::unordered_set sorted_expr_group_set{ sorted_expr_groups.begin(), sorted_expr_groups.end()}; - ASSERT_EQ(sorted_expr_group_set, ref_expr_group_set) << "Mismatched ExprGroups."; + ASSERT_EQ(sorted_expr_group_set, ref_expr_group_set) + << "Mismatched ExprGroups."; // Make sure sorted_val_groups covers all Val groups const std::unordered_set& ref_val_group_set{ @@ -891,10 +892,8 @@ void checkSortingResults( Expr* ref_expr = ref_order.at(i); const ExprGroup& eg = sorted_expr_groups.at(i); ASSERT_TRUE(eg->has(ref_expr)) - << "Expected: " - << nvfuser::toString(graph.toGroup(ref_expr)) - << ". Actual: " - << nvfuser::toString(eg); + << "Expected: " << nvfuser::toString(graph.toGroup(ref_expr)) + << ". Actual: " << nvfuser::toString(eg); } } @@ -917,8 +916,7 @@ TEST_F(IdModelTest, ValGraphStmtSort1) { IdModel id_model(&fusion); const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); ValGraphStmtSort vg_stmt_sort(vg); - checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); } tv2->merge(0)->split(0, 4); @@ -960,7 +958,7 @@ TEST_F(IdModelTest, ValGraphStmtSort2) { // Note that the two groups of tensors, {tv0, tv1} and {tv2, tv3}, // are not connected - for (auto tv: ir_utils::allTvs(&fusion)) { + for (auto tv : ir_utils::allTvs(&fusion)) { tv->merge(0)->split(0, 4); } @@ -975,8 +973,7 @@ TEST_F(IdModelTest, ValGraphStmtSort2) { ref_order.push_back(getParentExpr(tv1->axis(0), 1)); ref_order.push_back(getParentExpr(tv3->axis(0), 1)); - checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); } // Sorting with trivial ExprGroup, i.e., ExprGroup whose input and @@ -999,7 +996,7 @@ TEST_F(IdModelTest, ValGraphStmtSort3) { fusion.addOutput(tv4); // Merge adn split by one. The split input and output will be mapped. - for (auto tv: {tv0, tv1, tv2}) { + for (auto tv : {tv0, tv1, tv2}) { tv->merge(0)->split(0, 1); } @@ -1022,8 +1019,7 @@ TEST_F(IdModelTest, ValGraphStmtSort3) { ref_order.push_back(getParentExpr(tv2->axis(0), 1)); ref_order.push_back(getParentExpr(tv4->axis(0), 1)); - checkSortingResults( - vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); } // Sorting test with the same fusion as Indexing19 @@ -1052,27 +1048,45 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { // exprg{39}: Merge: iS2{7} and bS3{1} -> iS46{( 7 * 1 )} // exprg{57}: Merge: iS11{7} and bS12{1} -> iS61{( 7 * 1 )} // exprg{17}: Merge: iS17{7} and bS18{1} -> iS29{( 7 * 1 )} - // exprg{69 73 89}: Split: iS1{7} by factor 5 -> iS71{( ceilDiv(7, 5) )}, iS72{5}, start offset: 0, stop offset: 0 - // exprg{51 63 93}: Merge: iS15{7} and iS16{13} -> iS56{( 7 * 13 )} - // exprg{9 25 33 45 91 95}: Merge: iS20{7} and iS21{11} -> iS23{( 7 * 11 )} - // exprg{27}: Merge: iS35{( 7 * 11 )} and bS10{1} -> iS36{( ( 7 * 11 ) * 1 )} - // exprg{19}: Merge: iS29{( 7 * 1 )} and iS19{13} -> iS30{( ( 7 * 1 ) * 13 )} - // exprg{11 77 79 99}: Merge: iS23{( 7 * 11 )} and iS22{13} -> iS24{( ( 7 * 11 ) * 13 )} - // exprg{41}: Split: iS46{( 7 * 1 )} by factor 5 -> iS47{( ceilDiv(( 7 * 1 ), 5) )}, iS48{5}, start offset: 0, stop offset: 0 - // exprg{59}: Split: iS61{( 7 * 1 )} by factor 5 -> iS62{( ceilDiv(( 7 * 1 ), 5) )}, iS63{5}, start offset: 0, stop offset: 0 - // exprg{71 75 101}: Split: iS71{( ceilDiv(7, 5) )} by factor 3 -> iS73{( ceilDiv(( ceilDiv(7, 5) ), 3) )}, iS74{3}, start offset: 0, stop offset: 0 - // exprg{53 65 109}: Split: iS56{( 7 * 13 )} by factor 5 -> iS57{( ceilDiv(( 7 * 13 ), 5) )}, iS58{5}, start offset: 0, stop offset: 0 - // exprg{35 47 105}: Split: iS41{( 7 * 11 )} by factor 5 -> iS42{( ceilDiv(( 7 * 11 ), 5) )}, iS43{5}, start offset: 0, stop offset: 0 - // exprg{29}: Split: iS36{( ( 7 * 11 ) * 1 )} by factor 5 -> iS37{( ceilDiv(( ( 7 * 11 ) * 1 ), 5) )}, iS38{5}, start offset: 0, stop offset: 0 - // exprg{21}: Split: iS30{( ( 7 * 1 ) * 13 )} by factor 5 -> iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )}, iS32{5}, start offset: 0, stop offset: 0 - // exprg{13 81 83 97 103 107 111 115 117 119 121}: Split: iS24{( ( 7 * 11 ) * 13 )} by factor 5 -> iS25{( ceilDiv(( ( 7 * 11 ) * 13 ), 5) )}, iS26{5}, start offset: 0, stop offset: 0 - // exprg{43}: Split: iS47{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS49{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS50{3}, start offset: 0, stop offset: 0 - // exprg{61}: Split: iS62{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS64{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3}, start offset: 0, stop offset: 0 - // exprg{55 67 129}: Split: iS57{( ceilDiv(( 7 * 13 ), 5) )} by factor 3 -> iS59{( ceilDiv(( ceilDiv(( 7 * 13 ), 5) ), 3) )}, iS60{3}, start offset: 0, stop offset: 0 - // exprg{37 49 125}: Split: iS42{( ceilDiv(( 7 * 11 ), 5) )} by factor 3 -> iS44{( ceilDiv(( ceilDiv(( 7 * 11 ), 5) ), 3) )}, iS45{3}, start offset: 0, stop offset: 0 - // exprg{31}: Split: iS37{( ceilDiv(( ( 7 * 11 ) * 1 ), 5) )} by factor 3 -> iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), 5) ), 3) )}, iS40{3}, start offset: 0, stop offset: 0 - // exprg{23}: Split: iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )} by factor 3 -> iS33{( ceilDiv(( ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )}, iS34{3}, start offset: 0, stop offset: 0 - // exprg{15 85 87 113 123 127 131 133 135 137 139}: Split: iS25{( ceilDiv(( ( 7 * 11 ) * 13 ), 5) )} by factor 3 -> iS27{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 13 ), 5) ), 3) )}, iS28{3}, start offset: 0, stop offset: 0 + // exprg{69 73 89}: Split: iS1{7} by factor 5 -> iS71{( ceilDiv(7, 5) )}, + // iS72{5}, start offset: 0, stop offset: 0 exprg{51 63 93}: Merge: iS15{7} + // and iS16{13} -> iS56{( 7 * 13 )} exprg{9 25 33 45 91 95}: Merge: iS20{7} + // and iS21{11} -> iS23{( 7 * 11 )} exprg{27}: Merge: iS35{( 7 * 11 )} and + // bS10{1} -> iS36{( ( 7 * 11 ) * 1 )} exprg{19}: Merge: iS29{( 7 * 1 )} and + // iS19{13} -> iS30{( ( 7 * 1 ) * 13 )} exprg{11 77 79 99}: Merge: iS23{( 7 * + // 11 )} and iS22{13} -> iS24{( ( 7 * 11 ) * 13 )} exprg{41}: Split: iS46{( 7 + // * 1 )} by factor 5 -> iS47{( ceilDiv(( 7 * 1 ), 5) )}, iS48{5}, start + // offset: 0, stop offset: 0 exprg{59}: Split: iS61{( 7 * 1 )} by factor 5 -> + // iS62{( ceilDiv(( 7 * 1 ), 5) )}, iS63{5}, start offset: 0, stop offset: 0 + // exprg{71 75 101}: Split: iS71{( ceilDiv(7, 5) )} by factor 3 -> iS73{( + // ceilDiv(( ceilDiv(7, 5) ), 3) )}, iS74{3}, start offset: 0, stop offset: 0 + // exprg{53 65 109}: Split: iS56{( 7 * 13 )} by factor 5 -> iS57{( ceilDiv(( 7 + // * 13 ), 5) )}, iS58{5}, start offset: 0, stop offset: 0 exprg{35 47 105}: + // Split: iS41{( 7 * 11 )} by factor 5 -> iS42{( ceilDiv(( 7 * 11 ), 5) )}, + // iS43{5}, start offset: 0, stop offset: 0 exprg{29}: Split: iS36{( ( 7 * 11 + // ) * 1 )} by factor 5 -> iS37{( ceilDiv(( ( 7 * 11 ) * 1 ), 5) )}, iS38{5}, + // start offset: 0, stop offset: 0 exprg{21}: Split: iS30{( ( 7 * 1 ) * 13 )} + // by factor 5 -> iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )}, iS32{5}, start + // offset: 0, stop offset: 0 exprg{13 81 83 97 103 107 111 115 117 119 121}: + // Split: iS24{( ( 7 * 11 ) * 13 )} by factor 5 -> iS25{( ceilDiv(( ( 7 * 11 ) + // * 13 ), 5) )}, iS26{5}, start offset: 0, stop offset: 0 exprg{43}: Split: + // iS47{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS49{( ceilDiv(( ceilDiv(( 7 + // * 1 ), 5) ), 3) )}, iS50{3}, start offset: 0, stop offset: 0 exprg{61}: + // Split: iS62{( ceilDiv(( 7 * 1 ), 5) )} by factor 3 -> iS64{( ceilDiv(( + // ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3}, start offset: 0, stop offset: 0 + // exprg{55 67 129}: Split: iS57{( ceilDiv(( 7 * 13 ), 5) )} by factor 3 -> + // iS59{( ceilDiv(( ceilDiv(( 7 * 13 ), 5) ), 3) )}, iS60{3}, start offset: 0, + // stop offset: 0 exprg{37 49 125}: Split: iS42{( ceilDiv(( 7 * 11 ), 5) )} by + // factor 3 -> iS44{( ceilDiv(( ceilDiv(( 7 * 11 ), 5) ), 3) )}, iS45{3}, + // start offset: 0, stop offset: 0 exprg{31}: Split: iS37{( ceilDiv(( ( 7 * 11 + // ) * 1 ), 5) )} by factor 3 -> iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), + // 5) ), 3) )}, iS40{3}, start offset: 0, stop offset: 0 exprg{23}: Split: + // iS31{( ceilDiv(( ( 7 * 1 ) * 13 ), 5) )} by factor 3 -> iS33{( ceilDiv(( + // ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )}, iS34{3}, start offset: 0, stop + // offset: 0 exprg{15 85 87 113 123 127 131 133 135 137 139}: Split: iS25{( + // ceilDiv(( ( 7 * 11 ) * 13 ), 5) )} by factor 3 -> iS27{( ceilDiv(( + // ceilDiv(( ( 7 * 11 ) * 13 ), 5) ), 3) )}, iS28{3}, start offset: 0, stop + // offset: 0 std::vector ref_order; ref_order.push_back(getParentExpr(tv2->axis(0), 3)); From 3d8b582e33246d774fec1415cd835ff7b5a02e10 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 13 Feb 2024 14:53:53 -0800 Subject: [PATCH 150/178] rename --- test/test_id_model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 962b73d94f0..5e8c722f0b8 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -1121,7 +1121,7 @@ TEST_F(IdModelTest, LoopPromotion8) { // A repro that produces an invalid loop graph due to the compliment // mapping. This is not currently supported. -TEST_F(IdModelTest, ComplimentMappingRepro) { +TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { Fusion fusion; FusionGuard fg(&fusion); From 44da75ad9cd43e484f22151a8e2f0dc80691a3f4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 13 Feb 2024 15:32:31 -0800 Subject: [PATCH 151/178] Error check --- csrc/id_model/id_model.cpp | 30 ++++++++++++++++++++++++++++++ test/test_id_model.cpp | 26 ++++++++++++++++---------- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 0614134c188..62bf0b37c2c 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -1014,6 +1014,30 @@ void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) { } } +namespace { + +// Loop graph represents the loop structure of the given fusion, so +// there must not be any mapping between the leaf domains of each +// tensor. +void validateLoopGraphHasNoSelfMappedLeafDomains( + const std::vector& tvs, + const IdModel& id_model) { + for (auto tv : tvs) { + auto self_mappped_leaf_pair = + detectMappablePair(tv->domain()->leaf(), id_model, IdMappingMode::LOOP); + NVF_ERROR( + !self_mappped_leaf_pair.has_value(), + "Detected leaf domains are mapped in the loop graph. Tensor: ", + tv->toString(), + ". Mapped leaf domains: ", + self_mappped_leaf_pair->first->toString(), + " and ", + self_mappped_leaf_pair->second->toString()); + } +} + +} // namespace + void IdModel::buildLoopGraph() { // Make sure the depedent graphs are already built maybeBuildGraph(IdMappingMode::EXACT); @@ -1032,6 +1056,8 @@ void IdModel::buildLoopGraph() { initializeLoopGraph(inlining_info); + validateLoopGraphHasNoSelfMappedLeafDomains(tvs_, *this); + VERBOSE() << "Initial loop graph:\n"; for (const auto& group : idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { @@ -1040,6 +1066,10 @@ void IdModel::buildLoopGraph() { loop_promotion_map_ = buildLoopPromotionMap(inlining_info); + // New domains are added. Make sure there's still no self mapping in + // the leaf domains + validateLoopGraphHasNoSelfMappedLeafDomains(tvs_, *this); + idGraph(IdMappingMode::LOOP).validateConsistency(); } diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index 5e8c722f0b8..97dc677a16a 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -1167,10 +1167,6 @@ TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { // Fully inline tv10 to tv11 without merging tv10->inlineAt(-1); - IdModel id_model(&fusion, true, false, false); - - const ValGraph& loop_graph = id_model.idGraph(IdMappingMode::LOOP); - // Due to the compliment mapping, the leaf domains of tv10 and tv11 // are loop mapped, which is invalid. // @@ -1188,13 +1184,23 @@ TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { // Here's the loop graph for tv10 and tv11: // idg{22 23 24 25 26 27} + // Due to the invalid mapping, building IdModel should fail for now + EXPECT_THAT( + [&]() { IdModel id_model(&fusion, true, false, false); }, + ::testing::ThrowsMessage(::testing::HasSubstr( + "Detected leaf domains are mapped in the loop graph"))); + + // Enable the below validation once the above problem is resolved. + // + // const ValGraph& loop_graph = id_model.idGraph(IdMappingMode::LOOP); + // // These assertions should fail at this moment. - ASSERT_NE( - loop_graph.toGroup(tv10->axis(0)), loop_graph.toGroup(tv10->axis(1))); - ASSERT_NE( - loop_graph.toGroup(tv10->axis(0)), loop_graph.toGroup(tv10->axis(2))); - ASSERT_NE( - loop_graph.toGroup(tv10->axis(1)), loop_graph.toGroup(tv10->axis(2))); + // ASSERT_NE( + // loop_graph.toGroup(tv10->axis(0)), loop_graph.toGroup(tv10->axis(1))); + // ASSERT_NE( + // loop_graph.toGroup(tv10->axis(0)), loop_graph.toGroup(tv10->axis(2))); + // ASSERT_NE( + // loop_graph.toGroup(tv10->axis(1)), loop_graph.toGroup(tv10->axis(2))); } } // namespace nvfuser From d8f3ed42764a489e2b630084de28a26e44015e78 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 16 Feb 2024 02:01:15 -0800 Subject: [PATCH 152/178] WIP: Loop promotion analysis step 2 --- CMakeLists.txt | 1 + csrc/device_lower/lower2device.cpp | 2 +- csrc/id_model/id_model.cpp | 401 +++++++++++++++++++++++++++++ csrc/id_model/id_model.h | 20 ++ csrc/id_model/transform_replay.cpp | 89 +++++++ csrc/id_model/transform_replay.h | 55 ++++ csrc/id_model/utils.h | 55 ++++ csrc/val_graph.cpp | 8 + csrc/val_graph.h | 5 + test/test_id_model.cpp | 323 ++++++++++++++++++----- 10 files changed, 890 insertions(+), 69 deletions(-) create mode 100644 csrc/id_model/transform_replay.cpp create mode 100644 csrc/id_model/transform_replay.h create mode 100644 csrc/id_model/utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ac091c0dc59..86329885737 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -91,6 +91,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp + ${NVFUSER_SRCS_DIR}/id_model/transform_replay.cpp ${NVFUSER_SRCS_DIR}/id_model/validation_utils.cpp ${NVFUSER_SRCS_DIR}/index_compute.cpp ${NVFUSER_SRCS_DIR}/instrumentation.cpp diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 63c2948c97c..cb1f65f958f 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -384,7 +384,7 @@ void GpuLower::analysis(Fusion* fusion) { // functionality should be affected. New IterDomains may be created, // so it is expected that generated code may use diffrent variable // names - if (isOptionEnabled(EnableOption::IdModel)) { + if (true || isOptionEnabled(EnableOption::IdModel)) { IdModel id_model(fusion_); } diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 032bba2e8d5..b13e278ab3b 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include #include @@ -16,6 +17,7 @@ #include #include #include +#include #include #include @@ -719,6 +721,12 @@ std::unordered_map IdModel::buildLoopPromotionMap( std::unordered_map iel_promotion_map = buildInlineRootResolutionMap(iel_graph, inlining_info); + // Step 2: Propagate the root promotions to intermediate and leaf groups. + // At this point, the promotion may not be final as the analysis is + // localized to IEL groups. The map is used in the next step to + // build mappings of the loop groups. + propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); + // This is not a right map to return but just a placeholder since // the loop promotion map is not yet completely merged. It will be // replaced by a proper map. @@ -949,4 +957,397 @@ ValGraph IdModel::buildIntersection( return intersection; } +namespace { + +// When replaying the transformations we can't blindly apply loop promotion +// to all iter domains within a loop group as it would replay the +// transformations within that loop group on the promoted id of that loop +// group. +// +// i.e. if we have the inlined domains from: +// T2[i0*i1] pa(1) = T0[i0*b1]ca(1) + T1[i0*i1]ca(1) +// The inlined loop group would be: +// +// i0, i1, b1, i0*i1, b0*i1 +// Then if we replayed the iel transformations they would be: +// merge(i0, i1) +// merge(i0, b1) +// +// So if we replayed them with loop promotion, then i0, i1, b1 would be +// promoted to i0*i1, and the merges would be replayed. +// +// Therefore only promote i0*b1 to i0*i1, or i0*i1 to i0*i1 (i.e. don't +// promote an input to any transformation within the loop group). +// +// So if we have an iel_expr make sure it's inputs and outputs are not in +// the same loop group. +bool hasUniqueOutputLoopGroups( + const ExprGroup& iel_expr, + const ValGraph& iel_graph, + const ValGraph& loop_graph) { + const std::vector iel_inp_groups = iel_graph.inputGroups(iel_expr); + + const std::vector iel_out_groups = iel_graph.outputGroups(iel_expr); + + ValGroups inp_loop_groups; + for (const ValGroup& iel_inp_group : iel_inp_groups) { + inp_loop_groups.pushBack(loop_graph.toGroup(iel_inp_group->front())); + } + ValGroups out_loop_groups; + for (const ValGroup& iel_out_group : iel_out_groups) { + out_loop_groups.pushBack(loop_graph.toGroup(iel_out_group->front())); + } + + // Check if output groups that are not included in the input group set + return !inp_loop_groups.computeSubtract(out_loop_groups).empty(); +} + +} // namespace + +// Propagate promotion mappings from root domains to derived domains +// by traversing IEL exprs. For each expr, if an input is promoted, +// the output needs to be promoted too. If there's already a domain +// that the output domain should be promoted to, create a mapping to it from +// the promoted output domain. If not, a new domain is created by +// replaying the expr with the promoted inputs. +// +// This is used twice when building the promotion map. The first time +// it is used there's no loop graph promotion yet, so only the IEL +// promotions are propagated. In that case, loop_graph_promotion_map +// should be just empty. +// +// Propagation uses iel_promotion_map and +// loop_graph_promotion_map. If both are available for an IEL group, +// the former has the precedence. This is because when this function +// is used for step 4, the given iel_promotion_map is empty and gets +// populated during this propagation, whereas the loop promotion map +// is not guaranteed to have the correct mappings for partially +// inlined domains. +// +// The loop_graph pamameter may not be up-to-date. +void IdModel::propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const std::unordered_map& loop_graph_promotion_map, + bool require_loop_mapped_promotion) { + // In order to make this traversal work, the traversal order must be + // topologically sorted. + ValGraphStmtSort iel_stmt_sort(iel_graph); + + // TODO-NM: The ordering might be non-deterministic + + for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) { + NVF_ERROR(!iel_expr->empty()); + const std::vector iel_inp_groups = + iel_graph.inputGroups(iel_expr); + + // Propagate loop graph promotion only when the inputs and outputs are + // not in the same loop group. + const bool loop_promote_inputs = !loop_graph_promotion_map.empty() && + hasUniqueOutputLoopGroups(iel_expr, iel_graph, loop_graph); + + // Check if any inputs need promotion indicating this expr group needs to + // be replayed with promoted inputs + bool an_input_was_promoted = false; + std::vector maybe_promoted_inputs; + maybe_promoted_inputs.reserve(iel_inp_groups.size()); + + for (const ValGroup& iel_inp_group : iel_inp_groups) { + // Assumed all inputs are IterDomains + NVF_ERROR(iel_inp_group->front()->isA()); + + // Even when loop promotions are given, We still could require + // an input promotion. We could be traversing across non-inlined + // groups. Meaning we have inputs that were promoted in an + // inlined loop group traversing through the non-inlined + // portions of the iel graph. + if (auto inp_promo_it = iel_promotion_map.find(iel_inp_group); + inp_promo_it != iel_promotion_map.end()) { + maybe_promoted_inputs.push_back(inp_promo_it->second); + an_input_was_promoted = true; + continue; + } + + // Promote loops based on the loop promotion map. If the loop promotion + // map should be used and has an entry we should use that promotion. This + // happen when an iel expression is across a loop group boundary. + // Signifying and capturing instances when we traverse across an inlined + // loop group to a non-inlined loop group boundary (think of the iel graph + // projected onto the loop graph). + if (loop_promote_inputs) { + const ValGroup& loop_copy_group = + loop_graph.toGroup(iel_inp_group->front()); + auto inp_loop_promo_it = loop_graph_promotion_map.find(loop_copy_group); + if (inp_loop_promo_it != loop_graph_promotion_map.end()) { + maybe_promoted_inputs.push_back(inp_loop_promo_it->second); + an_input_was_promoted = true; + continue; + } + } + + // No promotion found. Just use the non-promoted domain + maybe_promoted_inputs.push_back(iel_inp_group->front()->as()); + } + + if (!an_input_was_promoted) { + // No inputs need promotion so just continue + continue; + } + + // Before replaying, check if there's already an expression like this, if so + // use that for promotion. We would need the iel entries for non-promoted + // inputs to match exactly to reuse the expression. + auto findMatchingExpr = + [this, &require_loop_mapped_promotion]( + const ExprGroup& iel_expr, + const ValGraph& iel_graph, + const std::vector& maybe_promoted_inputs) -> Expr* { + ExprGroups maybe_promoted_input_uses; + + for (auto inp_id : maybe_promoted_inputs) { + // inp_id may have been just replayed, in which case it should + // not exist in the IEL graph. It should be just ignored as it + // should not have any use yet. + if (!iel_graph.hasGroup(inp_id)) { + continue; + } + const auto& inp_exact_group = iel_graph.toGroup(inp_id); + maybe_promoted_input_uses.pushBack(iel_graph.getUses(inp_exact_group)); + } + + // Look for exprs that have inputs that are mapped in the IEL + // graph with the (promoted) inputs of iel_expr. If found, no need + // to create a new expr to produce promoted outputs + for (const ExprGroup& maybe_promoted_input_use_group : + maybe_promoted_input_uses) { + NVF_ERROR(!maybe_promoted_input_use_group->empty()); + // No need to check itself + if (iel_expr == maybe_promoted_input_use_group) { + continue; + } + Expr* maybe_promoted_input_use = + maybe_promoted_input_use_group->front(); + if (!iel_expr->front()->sameOp(maybe_promoted_input_use)) { + continue; + } + // Check if all inputs are mapped + NVF_ERROR( + maybe_promoted_inputs.size() == + maybe_promoted_input_use->inputs().size()); + bool inps_match = true; + for (const auto inp_i : c10::irange(maybe_promoted_inputs.size())) { + // Here, new promoted ids are not added to iel_graph, so + // once promoted, this should not return true anymore. Also, + // strictAreMapped doesn't work as promoted domains are not + // in the graph + inps_match = inps_match && + iel_graph.disjointValSets().permissiveAreMapped( + maybe_promoted_inputs[inp_i], + maybe_promoted_input_use->inputs().at(inp_i)); + } + if (!inps_match) { + continue; + } + + // For the final loop promotion map, we want to find + // promotions within the same loop groups. Note that that's + // guaranteed when replayed. + if (require_loop_mapped_promotion) { + if (!idGraph(IdMappingMode::LOOP) + .disjointExprSets() + .permissiveAreMapped( + iel_expr->front(), + maybe_promoted_input_use_group->front())) { + continue; + } + // This is just an extra sanity check. Make sure all exprs in + // the use group are mapped + NVF_ERROR( + std::all_of( + maybe_promoted_input_use_group->vector().begin(), + maybe_promoted_input_use_group->vector().end(), + [&](Expr* iel_use) { + return idGraph(IdMappingMode::LOOP) + .disjointExprSets() + .permissiveAreMapped(iel_expr->front(), iel_use); + }), + "Not all mapped: ", + nvfuser::toString(iel_expr), + "\n", + nvfuser::toString(maybe_promoted_input_use_group)); + } + return maybe_promoted_input_use; + } + + return nullptr; + }; + + bool replayed = false; + Expr* promoted_expr = + findMatchingExpr(iel_expr, iel_graph, maybe_promoted_inputs); + + if (!promoted_expr) { + promoted_expr = addReplayAs(maybe_promoted_inputs, iel_expr->front()); + replayed = true; + } + + // Mark outputs as having a promoted iter domain + std::vector out_groups = iel_graph.outputGroups(iel_expr); + NVF_ERROR(promoted_expr->outputs().size() == out_groups.size()); + NVF_ERROR( + ir_utils::filterByType(promoted_expr->outputs()).size() == + out_groups.size(), + "Unexpected non IterDomain outputs found: ", + promoted_expr->toString()); + + for (const auto i : c10::irange(out_groups.size())) { + // Promote if necessary, if the output is already in the same exact map + // it doesn't need a promotion. + if (idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped( + promoted_expr->output(i), out_groups[i]->front())) { + continue; + } + iel_promotion_map[out_groups[i]] = + promoted_expr->output(i)->as(); + // Explicitly map loop map since expr propagation doesn't happen + if (replayed) { + idGraph(IdMappingMode::LOOP) + .mapVals(iel_expr->front()->output(i), promoted_expr->output(i)); + } + } + } +} + +void IdModel::propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map) { + propagatePromotionsInIELGraph( + iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), {}, false); +} + +// Replay Expr but with the inputs provided. +Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { + // Figure out which graphs are already initialized to make sure we add the new + // expression to them. + std::vector initialized_modes; + for (auto mode : kIdMappingModes) { + auto graph_it = id_graphs_.find(mode); + if (graph_it == id_graphs_.end()) { + continue; + } + + auto& graph = graph_it->second; + if (graph.disjointValSets().disjointSetMap().empty()) { + continue; + } + + initialized_modes.push_back(mode); + } + + auto orig_inputs = ir_utils::filterByType(expr->inputs()); + std::vector orig_input_ids( + orig_inputs.begin(), orig_inputs.end()); + + // Replace the provided inputs with IterType::Iteration domains as + // reduction domains cannot be merged with non-reduction domains. + if (std::any_of( + new_inputs.begin(), + new_inputs.end(), + [](IterDomain* id) { return id->isReduction(); }) && + std::any_of(new_inputs.begin(), new_inputs.end(), [](IterDomain* id) { + return !id->isReduction(); + })) { + // Inputs have mismatched type, replace new_inputs + decltype(new_inputs) tmp_inputs; + std::swap(tmp_inputs, new_inputs); + for (auto tmp_input : tmp_inputs) { + new_inputs.push_back( + IterDomainBuilder(tmp_input).iter_type(IterType::Iteration).build()); + id_definitions_[new_inputs.back()]; + id_uses_[new_inputs.back()]; + for (auto mode : initialized_modes) { + idGraph(mode).initializeVal(new_inputs.back(), {}, {}); + idGraph(mode).mapVals(new_inputs.back(), tmp_input); + } + } + } + + { + NVF_ERROR( + new_inputs.size() == orig_input_ids.size(), + "Invalid number of inputs: ", + new_inputs.size(), + " does not match number of iter domain inputs for ", + expr->toString()); + + VectorOfUniqueEntries all_inputs{ + orig_input_ids.begin(), orig_input_ids.end()}; + + all_inputs.pushBack(new_inputs); + + for (auto mode : initialized_modes) { + for (auto inp : all_inputs) { + NVF_ERROR( + idGraph(mode).hasGroup(inp), + "All inputs for replay need to be initialized in all graphs, ", + inp->toString(), + " was not found in mode: ", + mode); + } + } + } + + // Create the new expression with provided inputs + auto replay = ReplayTransform::replayAs(new_inputs, expr); + + for (auto out_id : ir_utils::filterByType(replay->outputs())) { + id_definitions_[out_id].pushBack(replay); + id_uses_[out_id]; + } + + // Add the expression to the uses of the inputs + for (auto inp_id : ir_utils::filterByType(replay->inputs())) { + id_definitions_[inp_id]; + id_uses_[inp_id].pushBack(replay); + } + + // Initialize output iter domains in the graphs + for (auto mode : initialized_modes) { + idGraph(mode).registerExpr(replay); + auto replay_group = idGraph(mode).toGroup(replay); + + // Initialize output ids in map + for (auto out_id : ir_utils::filterByType(replay->outputs())) { + idGraph(mode).initializeVal(out_id, {replay}, {}); + } + + // Update uses of the inputs in the graphs + for (auto inp_id : ir_utils::filterByType(replay->inputs())) { + auto inp_group = idGraph(mode).toGroup(inp_id); + idGraph(mode).addUniqueUses(inp_group, replay_group); + } + + // Propagate through all the uses of the iter domain groups of the inputs + // with the new expression. + auto& graph = idGraph(mode); + // Gather all use expressions from inputs + VectorOfUniqueEntries representative_uses; + for (IterDomain* inp : new_inputs) { + for (const ExprGroup& use_group : graph.getUses(graph.toGroup(inp))) { + NVF_ERROR(!use_group->empty()); + representative_uses.pushBack(use_group->front()); + } + } + + for (auto rep_use : representative_uses) { + graph.maybeMapThroughExprs(rep_use, replay, true); + } + } + + return replay; +} + } // namespace nvfuser diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 794de4c41ae..4a322d1bbcb 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -194,9 +194,29 @@ class IdModel : public PolymorphicBase { const ValGraph& iel_graph, const StatefulInliningInfo& info); + // Helper function for building loop promotion map. Propagate + // promotions of root IEL groups to leaf IEL groups + void propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map); + + // Same as the other version but also propagates promotoins of loop + // groups as well + void propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const std::unordered_map& loop_promotion_map, + bool require_loop_mapped_promotion); + // Errors if self mapping occurs void assertNoSelfMapping(); + // Replay Expr but with the inputs provided. IterDomainGraphss will be updated + // for all maps that have entries, adding the output iter domains of the + // replayed expression and adding potential mappings through the expression. + Expr* addReplayAs(std::vector new_inputs, Expr* expr); + protected: // All tensor expressions that this model analyzes std::vector tv_exprs_; diff --git a/csrc/id_model/transform_replay.cpp b/csrc/id_model/transform_replay.cpp new file mode 100644 index 00000000000..5f8a8f7af3e --- /dev/null +++ b/csrc/id_model/transform_replay.cpp @@ -0,0 +1,89 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include + +namespace nvfuser { + +Expr* ReplayTransform::replayAs( + const std::vector& ordered_inputs, + const Expr* expression_to_match) { + ReplayTransform replay(ordered_inputs, expression_to_match); + return replay.replayed_expr_; +} + +ReplayTransform::ReplayTransform( + const std::vector& ordered_inputs, + const Expr* expression_to_match) + : input_ids_(ordered_inputs) { + OptOutConstDispatch::dispatch(expression_to_match); +} + +// We're going to replay this split operation on the corresponding ID +void ReplayTransform::handle(const Split* split) { + NVF_ERROR( + input_ids_.size() == 1, + "Expected one input to match split: ", + split->toString()); + replayed_expr_ = IterDomain::split( + input_ids_[0], + split->factor(), + split->innerSplit(), + split->startOffset(), + split->stopOffset()) + .first->definition(); +} + +// We're going to replay this merge operation on the corresponding IDs +void ReplayTransform::handle(const Merge* merge) { + NVF_ERROR( + input_ids_.size() == 2, + "Expected two inputs to match merge: ", + merge->toString()); + replayed_expr_ = + IterDomain::merge(input_ids_[0], input_ids_[1])->definition(); +} + +// We're going to replay this swizzle operation on the corresponding IDs +// if replaying swizzle is enabled. +void ReplayTransform::handle(const Swizzle2D* swizzle_2d) { + NVF_ERROR( + input_ids_.size() == 2, + "Expected two inputs to match swizzle: ", + swizzle_2d->toString()); + replayed_expr_ = IterDomain::swizzle( + swizzle_2d->swizzleType(), + input_ids_[0], + input_ids_[1], + swizzle_2d->swizzleMode()) + .first->definition(); +} + +void ReplayTransform::handle(const Swizzle* swizzle) { + NVF_ERROR( + input_ids_.size() == 2, + "Expected two inputs to match swizzle: ", + swizzle->toString()); + replayed_expr_ = + IterDomain::swizzle(swizzle->swizzleType(), input_ids_[0], input_ids_[1]) + .first->definition(); +} + +void ReplayTransform::handle(const Resize* resize) { + NVF_ERROR( + input_ids_.size() == 1, + "Expected one input to match resize: ", + resize->toString()); + replayed_expr_ = + IterDomain::resize( + input_ids_[0], resize->leftExpand(), resize->rightExpand()) + ->definition(); +} + +} // namespace nvfuser diff --git a/csrc/id_model/transform_replay.h b/csrc/id_model/transform_replay.h new file mode 100644 index 00000000000..a37e9ab4aa0 --- /dev/null +++ b/csrc/id_model/transform_replay.h @@ -0,0 +1,55 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include + +#include +#include + +namespace nvfuser { + +class ReplayTransform : OptInConstDispatch { + public: + // Replays expression_to_match with the provided ordered_inputs. Inputs should + // be ordered as they would be used in provided expression. Returns new + // replayed expression. + static Expr* replayAs( + const std::vector& ordered_inputs, + const Expr* expression_to_match); + + private: + ReplayTransform( + const std::vector& ordered_inputs, + const Expr* expression_to_match); + + using OptInConstDispatch::handle; + + // We're going to replay this split operation on the corresponding ID + void handle(const Split* split) final; + + // We're going to replay this merge operation on the corresponding IDs + void handle(const Merge* merge) final; + + // We're going to replay this swizzle operation on the corresponding IDs + // if replaying swizzle is enabled. + void handle(const Swizzle2D* swizzle_2d) final; + + void handle(const Swizzle* swizzle) final; + + // We're going to replay this resize operation on the corresponding IDs + // if replaying resize is enabled. + void handle(const Resize* resize) final; + + Expr* replayed_expr_ = nullptr; + const std::vector& input_ids_; +}; + +} // namespace nvfuser diff --git a/csrc/id_model/utils.h b/csrc/id_model/utils.h new file mode 100644 index 00000000000..2d6327bf586 --- /dev/null +++ b/csrc/id_model/utils.h @@ -0,0 +1,55 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include +#include +#include + +#define VERBOSE() verbose(__LINE__) +#define WARN() warn(__LINE__) + +namespace nvfuser { + +// Temporary logging utility +class DebugStream { + public: + DebugStream() + : enabled_(getNvFuserEnv("ID_MODEL_VERBOSE")), out_(std::cerr) {} + + template + DebugStream& operator<<(const T& v) { + if (enabled_) { + out_ << v; + } + return *this; + } + + DebugStream& operator<<(std::ostream& (*endl)(std::ostream&)) { + if (enabled_) { + out_ << endl; + } + return *this; + } + + private: + bool enabled_ = false; + std::ostream& out_; +}; + +inline DebugStream verbose(int line) { + return DebugStream() << "[DEBUG@" << line << "] "; +} + +inline DebugStream warn(int line) { + return DebugStream() << "[WARN@" << line << "] "; +} + +} // namespace nvfuser diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index f5578a76648..81798bb5cb7 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -310,6 +310,14 @@ void ValGraph::initializeVal(Val* val) { initializeVal(val, defs, uses); } +void ValGraph::registerExpr(Expr* expr) { + NVF_ERROR( + !disjoint_exprs_.mappingExists(expr), + "Already in the disjoint sets: ", + expr->toString()); + disjoint_exprs_.initializeSet(expr); +} + bool ValGraph::exprsMap(Expr* first, Expr* second, bool forward) const { NVF_ERROR(first); NVF_ERROR(second); diff --git a/csrc/val_graph.h b/csrc/val_graph.h index f80a610ac5d..121768593b0 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -195,6 +195,11 @@ class ValGraph { // used void initializeVal(Val* val); + // Add expr to the disjoint sets as a sole group. Used for + // registering replayed domains and exprs. Error if the expr is + // already registered. + void registerExpr(Expr* expr); + // Returns true if first and second are expressions through which // this ValGraph has matching inputs (if forward), or outputs (if not // forward). Returning true means the expressions are "the same", in terms diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index af56ad89982..790837a8256 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -16,8 +16,10 @@ #include #include #include +#include #include #include +#include #include namespace nvfuser { @@ -76,8 +78,12 @@ class IdModelTester : public IdModel { // Do not automatically build the graphs IdModelTester(Fusion* fusion) : IdModel(fusion, /* build_graphs */ false) {} - std::pair> - getInlineRootResolutionMap() { + // Returns the IEL graph and the results of Steps 1 and 2 + std::tuple< + ValGraph, + std::unordered_map, + std::unordered_map> + getLoopPromotionInfo() { // Make sure the depedent graphs are already built maybeBuildGraph(IdMappingMode::EXACT); maybeBuildGraph(IdMappingMode::PERMISSIVE); @@ -90,41 +96,148 @@ class IdModelTester : public IdModel { initializeLoopGraph(inlining_info); + VERBOSE() << "Initial loop graph:\n"; + for (const auto& group : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { + VERBOSE() << nvfuser::toString(group) << std::endl; + } + ValGraph iel_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); std::unordered_map root_promotion_map = buildInlineRootResolutionMap(iel_graph, inlining_info); - return {std::move(iel_graph), std::move(root_promotion_map)}; + { + std::stringstream ss; + ss << "Step 1: Root promotion map\n"; + for (const auto& [iel_group, promoted_id] : root_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " + << promoted_id->name() << std::endl; + } + VERBOSE() << ss.str(); + } + + auto iel_promotion_map = root_promotion_map; + + propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); + + { + std::stringstream ss; + ss << "Step 2: IEL promotion map\n"; + for (const auto& [iel_group, promoted_id] : iel_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " + << promoted_id->name() << std::endl; + } + VERBOSE() << ss.str(); + } + + return { + std::move(iel_graph), + std::move(root_promotion_map), + std::move(iel_promotion_map)}; } }; -// Test if root_broadcast_id is resolved to ref_id. If ref_id is -// nullptr, test if root_broadcast_id has no resolution. -void validateResolution( - IterDomain* root_broadcast_id, +// Test if id is resolved to an ID that is exact mapped with +// ref_id. If ref_id is nullptr, test if root_broadcast_id has no +// resolution. +void validateIELResolution( + IterDomain* id, IterDomain* ref_id, const ValGraph& iel_graph, - const std::unordered_map& root_resolution_map) { - ASSERT_TRUE(root_broadcast_id->isBroadcast()); - const auto& iel_group = iel_graph.toGroup(root_broadcast_id); - auto root_promotion_map_it = root_resolution_map.find(iel_group); + const ValGraph& exact_graph, + const std::unordered_map& iel_promotion_map) { + const auto& iel_group = iel_graph.toGroup(id); + auto iel_promotion_map_it = iel_promotion_map.find(iel_group); if (ref_id != nullptr) { - ASSERT_TRUE(root_promotion_map_it != root_resolution_map.end()) - << "Root resolution not found for: " << nvfuser::toString(iel_group); + ASSERT_TRUE(iel_promotion_map_it != iel_promotion_map.end()) + << "IEL promotion not found for: " << nvfuser::toString(iel_group); ASSERT_FALSE(ref_id->isBroadcast()); - auto resolution_id = root_promotion_map_it->second; + auto promotion_id = iel_promotion_map_it->second; ASSERT_TRUE( - iel_graph.disjointValSets().strictAreMapped(resolution_id, ref_id)) - << "Unexpected root resolution. " + exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id)) + << "Unexpected promotion. " << "Expected: " << ref_id->toString() - << ". Actual: " << resolution_id->toString(); + << ". Actual: " << promotion_id->toString(); } else { - ASSERT_TRUE(root_promotion_map_it == root_resolution_map.end()) - << "Root resolution should not exist for: " - << nvfuser::toString(iel_group) - << ", but found: " << root_promotion_map_it->second->toString(); + ASSERT_TRUE(iel_promotion_map_it == iel_promotion_map.end()) + << "Promotion should not exist for: " << nvfuser::toString(iel_group) + << ", but found: " << iel_promotion_map_it->second->toString(); + } +} + +// Check if each domain gets promoted to a proper domain after the +// Step 2 IEL propagation. It is assumed that the proper promotion is +// the corresponding domain in the unique consumer tensor, which is +// the case with most of the test fusions. +void checkStep2Results( + Fusion* fusion, + const ValGraph& iel_graph, + const ValGraph& exact_graph, + const std::unordered_map& iel_promotion_map) { + auto getPromotedDomain = [&](IterDomain* id) -> IterDomain* { + if (auto it = iel_promotion_map.find(iel_graph.toGroup(id)); + it != iel_promotion_map.end()) { + return it->second; + } else { + return nullptr; + } + }; + + for (auto tv : ir_utils::allTvs(fusion)) { + // If there's no broadcast or it isn't inlined, there's no + // promotion + if (std::none_of( + tv->getRootDomain().begin(), + tv->getRootDomain().end(), + [](auto id) { return id->isBroadcast(); }) || + (tv->getComputeAtPosition() == 0 && + tv->getMaxProducerPosition() == 0)) { + // Make sure there's no promotion of any of the IDs of this tensor + for (auto id : ir_utils::allIDsOf(tv)) { + auto promoted_id = getPromotedDomain(id); + ASSERT_EQ(promoted_id, nullptr) + << "Expected no mapping for " << id->toString() + << " but found to be mapped to: " << promoted_id->toString(); + } + continue; + } + + auto consumers = ir_utils::consumerTvsOf(tv); + ASSERT_EQ(consumers.size(), 1) << "Assumed to have one consumer"; + TensorView* c_tv = consumers.at(0); + const auto p2c = BestEffortReplay::replayCasP( + c_tv, tv, -1, PairwiseRootDomainMap(tv, c_tv)) + .getReplay(); + + for (auto p_id : ir_utils::allIDsOf(tv)) { + // Root domains are already done at Step 1 + if (std::find( + tv->getRootDomain().begin(), tv->getRootDomain().end(), p_id) != + tv->getRootDomain().end()) { + continue; + } + + // If no broadcast is involved, nothing should be promoted + auto p_id_dep_vals = DependencyCheck::getAllValsBetween( + {tv->getRootDomain().begin(), tv->getRootDomain().end()}, {p_id}); + if (std::find_if( + p_id_dep_vals.begin(), p_id_dep_vals.end(), [](Val* dep_id) { + return dep_id->as()->isBroadcast(); + }) == p_id_dep_vals.end()) { + auto promoted_id = getPromotedDomain(p_id); + ASSERT_EQ(promoted_id, nullptr) + << "Expected no mapping for " << p_id->toString() + << " but found to be mapped to: " << promoted_id->toString(); + continue; + } + + // p_id should be promoted to c_id + auto c_id = p2c.at(p_id); + validateIELResolution( + p_id, c_id, iel_graph, exact_graph, iel_promotion_map); + } } } @@ -447,7 +560,7 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); } -// Testing root resolution with a simple broadcast pattern +// Testing loop promotion with a simple broadcast pattern TEST_F(IdModelTest, LoopPromotion1) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -462,8 +575,8 @@ TEST_F(IdModelTest, LoopPromotion1) { { IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); // Nothing inlined. Should be no resolution ASSERT_TRUE(root_resolution_map.empty()); @@ -474,16 +587,24 @@ TEST_F(IdModelTest, LoopPromotion1) { { IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); + // Check Step 1 results // t2 is now fully inlined. Its root broadcast domain should be // resoled with the corresponding domain of t3 - validateResolution( + validateIELResolution( t2->getRootDomain().at(0), t3->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); + + // Check Step 2 results + // Nothing to propagate in this fusion, so iel_promotion_map + // should be equivalent to root_resolution_map + ASSERT_EQ(root_resolution_map, iel_promotion_map) + << "Unexpected IEL promotion map"; } } @@ -505,21 +626,30 @@ TEST_F(IdModelTest, LoopPromotion2) { inlineMost(); IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); + // Check Step 1 results // Validate t2 and t3 as they have root broadcast domains - validateResolution( + validateIELResolution( t2->getRootDomain().at(0), t4->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); - validateResolution( + validateIELResolution( t3->getRootDomain().at(0), t4->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); + + // Check Step 2 results + // Nothing to propagate in this fusion, so iel_promotion_map + // should be equivalent to root_resolution_map + ASSERT_EQ(root_resolution_map, iel_promotion_map) + << "Unexpected IEL promotion map"; } // Multiple inlined and non-inlined broadcast domains @@ -549,19 +679,40 @@ TEST_F(IdModelTest, LoopPromotion3) { // tv3: [i0*i1, i2*i3] IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); + // Check Step 1 results // The b1 broadcast domain tv2 should be resolved as it's inlined, // but b3 should not. - validateResolution( + validateIELResolution( tv2->getRootDomain().at(1), tv3->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), + root_resolution_map); + + validateIELResolution( + tv2->getRootDomain().at(3), + nullptr, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); - validateResolution( - tv2->getRootDomain().at(3), nullptr, iel_graph, root_resolution_map); + // Check Step 2 results + validateIELResolution( + tv2->axis(0), + tv3->axis(0), + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); + + validateIELResolution( + tv2->axis(1), + nullptr, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Test root resolution with a fusion with outer split. @@ -597,11 +748,10 @@ TEST_F(IdModelTest, LoopPromotion4) { } IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Verify all tensors with root broadcast have correct resolutions for (auto tv : ir_utils::allTvs(&fusion)) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -616,16 +766,23 @@ TEST_F(IdModelTest, LoopPromotion4) { case 2: // T2_l[ iS20{4}, iS21{( ceilDiv(( 1 * 4 ), 4) )} ] ca_pos( 1 ) // root domain : (bS4{1}, iS5{4}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(0), tv4->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + checkStep2Results( + &fusion, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Test root resolution with the same fusion as Indexing1 @@ -666,11 +823,10 @@ TEST_F(IdModelTest, LoopPromotion5) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Check Step 1 results for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -686,29 +842,37 @@ TEST_F(IdModelTest, LoopPromotion5) { // T3_l[ iS30{( ceilDiv(( ceilDiv(( ( ( 1 * i0 ) * i2 ) * i3 ), 128) ), // 4) )}, iUR31{4}, ithreadIdx.x29{128} ] ca_pos( 1 ) produce_pos( 1 ) // root domain : (bS10{1}, iS11{i0}, iS12{i2}, iS13{i3}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(0), tv4->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + // Check Step 2 results + checkStep2Results( + &fusion, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Test root resolution with the same fusion as Indexing19 TEST_F(IdModelTest, LoopPromotion6) { auto fusion = createFusionWithMultipleResolutionPaths(); + FusionGuard fg(fusion.get()); auto all_tvs = ir_utils::allTvs(fusion.get()); IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Check Step 1 results for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -725,10 +889,11 @@ TEST_F(IdModelTest, LoopPromotion6) { // iS48{5} ] ca_pos( 1 ) produce_pos( 1 ) // root domain : (iS2{7}, bS3{1}) // Resolution: Resolved by the immediate consumer (T4) - validateResolution( + validateIELResolution( tv->getRootDomain().at(1), getTensorByName(all_tvs, 4)->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; case 5: @@ -738,10 +903,11 @@ TEST_F(IdModelTest, LoopPromotion6) { // Resolution: T5 is not inlined to the immediate consumer, // T10. Resolution is done with the other path from T1, such // as T8 or T9. - validateResolution( + validateIELResolution( tv->getRootDomain().at(2), getTensorByName(all_tvs, 9)->getRootDomain().at(2), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; case 6: @@ -749,10 +915,11 @@ TEST_F(IdModelTest, LoopPromotion6) { // iS63{5} ] ca_pos( 1 ) produce_pos( 1 ) // root domain : (iS11{7}, bS12{1}) // Resolution: Resolved by the immediate consumer (T8) - validateResolution( + validateIELResolution( tv->getRootDomain().at(1), getTensorByName(all_tvs, 8)->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; case 9: @@ -762,16 +929,23 @@ TEST_F(IdModelTest, LoopPromotion6) { // Resolution: T9 is not inlined to the immediate consumer, // T10. Resolution is done with the other path from T1, such // as T4 or T5 - validateResolution( + validateIELResolution( tv->getRootDomain().at(1), getTensorByName(all_tvs, 5)->getRootDomain().at(1), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + checkStep2Results( + fusion.get(), + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Same fusion as NvFuserTest.FusionInlineBroadcastIndexing0 @@ -802,11 +976,10 @@ TEST_F(IdModelTest, LoopPromotion7) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -821,16 +994,23 @@ TEST_F(IdModelTest, LoopPromotion7) { case 3: // T3_l[ iS15{( ceilDiv(( 1 * i0 ), 32) )}, iS16{32} ] ca_pos( 1 ) // produce_pos( 1 ) root domain : (bS4{1}, iS5{i0}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(0), - tv4->getRootDomain().at(0), + getTensorByName(all_tvs, 4)->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + checkStep2Results( + &fusion, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } // Same fusion as NvFuserTest.FusionIndexing20 @@ -879,11 +1059,10 @@ TEST_F(IdModelTest, LoopPromotion8) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map] = - tester.getInlineRootResolutionMap(); + const auto& [iel_graph, root_resolution_map, iel_promotion_map] = + tester.getLoopPromotionInfo(); - // Verify all tensors with broadcast have correct resolution of root - // broadcast domains + // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { // Skip tensors with no broadcast or non-inlined if (std::none_of( @@ -898,26 +1077,34 @@ TEST_F(IdModelTest, LoopPromotion8) { case 2: // T2_l[ iS21{2}, iS22{( ceilDiv(( 1 * 5 ), 2) )} ] ca_pos( 1 ) // produce_pos( 1 ) root domain : (bS2{1}, iS3{5}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(0), - tv7->getRootDomain().at(0), + getTensorByName(all_tvs, 7)->getRootDomain().at(0), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; case 5: // T5_l[ iS27{2}, iS40{4}, iS41{( ceilDiv(( ( ceilDiv(( 3 * 5 ), 2) ) * // 1 ), 4) )} ] ca_pos( 2 ) produce_pos( 1 ) root domain : (iS8{3}, // iS9{5}, bS10{1}) - validateResolution( + validateIELResolution( tv->getRootDomain().at(2), - tv7->getRootDomain().at(2), + getTensorByName(all_tvs, 7)->getRootDomain().at(2), iel_graph, + tester.idGraph(IdMappingMode::EXACT), root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); } } + + checkStep2Results( + &fusion, + iel_graph, + tester.idGraph(IdMappingMode::EXACT), + iel_promotion_map); } } // namespace nvfuser From c14dd723a7a5bd0d243de7858b33175a5ee14f0e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 22 Feb 2024 18:45:14 -0800 Subject: [PATCH 153/178] cleanup --- csrc/id_model/id_model.cpp | 22 ------------------ csrc/id_model/id_model.h | 40 +++++++++++++++++++++++++------- csrc/id_model/transform_replay.h | 3 +++ 3 files changed, 34 insertions(+), 31 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index c1c19a9894d..ac8076f973e 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -1003,26 +1003,6 @@ bool hasUniqueOutputLoopGroups( } // namespace -// Propagate promotion mappings from root domains to derived domains -// by traversing IEL exprs. For each expr, if an input is promoted, -// the output needs to be promoted too. If there's already a domain -// that the output domain should be promoted to, create a mapping to it from -// the promoted output domain. If not, a new domain is created by -// replaying the expr with the promoted inputs. -// -// This is used twice when building the promotion map. The first time -// it is used there's no loop graph promotion yet, so only the IEL -// promotions are propagated. In that case, loop_graph_promotion_map -// should be just empty. -// -// Propagation uses iel_promotion_map and -// loop_graph_promotion_map. If both are available for an IEL group, -// the former has the precedence. This is because when this function -// is used for step 4, the given iel_promotion_map is empty and gets -// populated during this propagation, whereas the loop promotion map -// is not guaranteed to have the correct mappings for partially -// inlined domains. -// // The loop_graph pamameter may not be up-to-date. void IdModel::propagatePromotionsInIELGraph( const ValGraph& iel_graph, @@ -1034,8 +1014,6 @@ void IdModel::propagatePromotionsInIELGraph( // topologically sorted. ValGraphStmtSort iel_stmt_sort(iel_graph); - // TODO-NM: The ordering might be non-deterministic - for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) { NVF_ERROR(!iel_expr->empty()); const std::vector iel_inp_groups = diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 4a322d1bbcb..c6dd49dbff9 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -194,14 +194,29 @@ class IdModel : public PolymorphicBase { const ValGraph& iel_graph, const StatefulInliningInfo& info); - // Helper function for building loop promotion map. Propagate - // promotions of root IEL groups to leaf IEL groups - void propagatePromotionsInIELGraph( - const ValGraph& iel_graph, - std::unordered_map& iel_promotion_map); - - // Same as the other version but also propagates promotoins of loop - // groups as well + // Helper function for building loop promotion map. + // + // Propagate promotion mappings from root IEL groups to intermediate + // and leaf IEL groups by traversing IEL exprs. For each expr, if an + // input is promoted, the output needs to be promoted too. If + // there's already a domain that the output domain should be + // promoted to, create a mapping to it from the promoted output + // domain. If not, a new domain is created by replaying the expr + // with the promoted inputs. + // + // This is used twice when building the promotion map. The first time + // it is used there's no loop graph promotion yet, so only the IEL + // promotions are propagated. In that case, loop_graph_promotion_map + // should be just empty. + // + // Propagation uses iel_promotion_map and + // loop_graph_promotion_map. If both are available for an IEL group, + // the former has the precedence. This is because when this function + // is used for step 4, the given iel_promotion_map starts as an + // empty map and gets populated during this propagation, so any + // mapping in the map is guaranteed to be the correct final mapping, + // whereas the loop graph may have invalid mappings for partially + // inlined domains. void propagatePromotionsInIELGraph( const ValGraph& iel_graph, std::unordered_map& iel_promotion_map, @@ -209,10 +224,17 @@ class IdModel : public PolymorphicBase { const std::unordered_map& loop_promotion_map, bool require_loop_mapped_promotion); + // Same as the other propagatePromotionsInIELGraph but without loop + // graph map. This is used for step 2, where there's no loop + // graph map yet. + void propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map); + // Errors if self mapping occurs void assertNoSelfMapping(); - // Replay Expr but with the inputs provided. IterDomainGraphss will be updated + // Replay Expr but with the inputs provided. ValGraphs will be updated // for all maps that have entries, adding the output iter domains of the // replayed expression and adding potential mappings through the expression. Expr* addReplayAs(std::vector new_inputs, Expr* expr); diff --git a/csrc/id_model/transform_replay.h b/csrc/id_model/transform_replay.h index a37e9ab4aa0..531dcc9729d 100644 --- a/csrc/id_model/transform_replay.h +++ b/csrc/id_model/transform_replay.h @@ -16,6 +16,9 @@ namespace nvfuser { +// TODO: Consider merging this class with the existing replay +// classes. The use cases are not exactly the same, so it isn't +// immediately clear if they could be trivially merge. class ReplayTransform : OptInConstDispatch { public: // Replays expression_to_match with the provided ordered_inputs. Inputs should From fc861a6af3563d482b4edd5143baf6f07df05dc8 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 22 Feb 2024 21:09:31 -0800 Subject: [PATCH 154/178] update --- csrc/id_model/id_model.cpp | 243 ++++++++++++++++++------------------- csrc/id_model/id_model.h | 11 +- 2 files changed, 123 insertions(+), 131 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index ac8076f973e..62f18ea9b94 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -958,10 +958,9 @@ ValGraph IdModel::buildIntersection( namespace { -// When replaying the transformations we can't blindly apply loop promotion -// to all iter domains within a loop group as it would replay the -// transformations within that loop group on the promoted id of that loop -// group. +// When propagating loop promotions from inputs to outputs of an IEL +// expr, we can't blindly apply loop promotion when all of the input +// domains are loop mapped with the outputs. // // i.e. if we have the inlined domains from: // T2[i0*i1] pa(1) = T0[i0*b1]ca(1) + T1[i0*i1]ca(1) @@ -978,9 +977,9 @@ namespace { // Therefore only promote i0*b1 to i0*i1, or i0*i1 to i0*i1 (i.e. don't // promote an input to any transformation within the loop group). // -// So if we have an iel_expr make sure it's inputs and outputs are not in +// So if we have an iel_expr make sure its inputs and outputs are not in // the same loop group. -bool hasUniqueOutputLoopGroups( +bool hasUniqueInputLoopGroups( const ExprGroup& iel_expr, const ValGraph& iel_graph, const ValGraph& loop_graph) { @@ -997,13 +996,99 @@ bool hasUniqueOutputLoopGroups( out_loop_groups.pushBack(loop_graph.toGroup(iel_out_group->front())); } - // Check if output groups that are not included in the input group set + // Check if input groups that are not included in the output group set return !inp_loop_groups.computeSubtract(out_loop_groups).empty(); } +// Check if there's an equivalent expression as iel_expr that uses +// maybe_promoted_inputs. This is used to avoid redundantly replaying +// expressions. +// NOTE: This is currently overly conservative and some +// opportunities for reuse are lost, althought it doesn't affect +// the correctness of the analysis. +Expr* findMatchingExpr( + const ExprGroup& iel_expr, + const ValGraph& iel_graph, + const std::vector& maybe_promoted_inputs, + bool require_loop_mapped_promotion, + const ValGraph& loop_graph) { + // Grab all uses of the promoted inputs + ExprGroups maybe_promoted_input_uses; + for (auto inp_id : maybe_promoted_inputs) { + // inp_id may have been just replayed, in which case it should + // not exist in the IEL graph. It should be just ignored as it + // should not have any use yet. + if (!iel_graph.hasGroup(inp_id)) { + continue; + } + const auto& inp_exact_group = iel_graph.toGroup(inp_id); + maybe_promoted_input_uses.pushBack(iel_graph.getUses(inp_exact_group)); + } + + // Look for exprs that have inputs that are mapped in the IEL + // graph with the (promoted) inputs of iel_expr. + for (const ExprGroup& maybe_promoted_input_use_group : + maybe_promoted_input_uses) { + NVF_ERROR(!maybe_promoted_input_use_group->empty()); + // TODO: why skip this? If iel_expr is also an use of the promoted + // inputs, shouldn't it be also a candidate? + if (iel_expr == maybe_promoted_input_use_group) { + continue; + } + Expr* maybe_promoted_input_use = maybe_promoted_input_use_group->front(); + if (!iel_expr->front()->sameOp(maybe_promoted_input_use)) { + continue; + } + // Check if all inputs are mapped + NVF_ERROR( + maybe_promoted_inputs.size() == + maybe_promoted_input_use->inputs().size()); + bool all_inputs_match = true; + for (const auto inp_i : c10::irange(maybe_promoted_inputs.size())) { + // Here, new promoted ids are not added to iel_graph, so + // once promoted, this should not return true anymore. Also, + // strictAreMapped doesn't work as promoted domains are not + // in the graph + all_inputs_match = all_inputs_match && + iel_graph.disjointValSets().permissiveAreMapped( + maybe_promoted_inputs[inp_i], + maybe_promoted_input_use->inputs().at(inp_i)); + } + if (!all_inputs_match) { + continue; + } + + // For the final loop promotion map, we want to find + // promotions within the same loop groups. Note that that's + // guaranteed when replayed. + if (require_loop_mapped_promotion) { + if (!loop_graph.disjointExprSets().permissiveAreMapped( + iel_expr->front(), maybe_promoted_input_use_group->front())) { + continue; + } + // This is just an extra sanity check. Make sure all exprs in + // the use group are mapped + NVF_ERROR( + std::all_of( + maybe_promoted_input_use_group->vector().begin(), + maybe_promoted_input_use_group->vector().end(), + [&](Expr* iel_use) { + return loop_graph.disjointExprSets().permissiveAreMapped( + iel_expr->front(), iel_use); + }), + "Not all mapped: ", + nvfuser::toString(iel_expr), + "\n", + nvfuser::toString(maybe_promoted_input_use_group)); + } + return maybe_promoted_input_use; + } + + return nullptr; +} + } // namespace -// The loop_graph pamameter may not be up-to-date. void IdModel::propagatePromotionsInIELGraph( const ValGraph& iel_graph, std::unordered_map& iel_promotion_map, @@ -1022,7 +1107,7 @@ void IdModel::propagatePromotionsInIELGraph( // Propagate loop graph promotion only when the inputs and outputs are // not in the same loop group. const bool loop_promote_inputs = !loop_graph_promotion_map.empty() && - hasUniqueOutputLoopGroups(iel_expr, iel_graph, loop_graph); + hasUniqueInputLoopGroups(iel_expr, iel_graph, loop_graph); // Check if any inputs need promotion indicating this expr group needs to // be replayed with promoted inputs @@ -1034,11 +1119,7 @@ void IdModel::propagatePromotionsInIELGraph( // Assumed all inputs are IterDomains NVF_ERROR(iel_inp_group->front()->isA()); - // Even when loop promotions are given, We still could require - // an input promotion. We could be traversing across non-inlined - // groups. Meaning we have inputs that were promoted in an - // inlined loop group traversing through the non-inlined - // portions of the iel graph. + // Propagate IEL promotions when available. if (auto inp_promo_it = iel_promotion_map.find(iel_inp_group); inp_promo_it != iel_promotion_map.end()) { maybe_promoted_inputs.push_back(inp_promo_it->second); @@ -1047,11 +1128,7 @@ void IdModel::propagatePromotionsInIELGraph( } // Promote loops based on the loop promotion map. If the loop promotion - // map should be used and has an entry we should use that promotion. This - // happen when an iel expression is across a loop group boundary. - // Signifying and capturing instances when we traverse across an inlined - // loop group to a non-inlined loop group boundary (think of the iel graph - // projected onto the loop graph). + // map should be used and has an entry we should use that promotion. if (loop_promote_inputs) { const ValGroup& loop_copy_group = loop_graph.toGroup(iel_inp_group->front()); @@ -1072,97 +1149,14 @@ void IdModel::propagatePromotionsInIELGraph( continue; } - // Before replaying, check if there's already an expression like this, if so - // use that for promotion. We would need the iel entries for non-promoted - // inputs to match exactly to reuse the expression. - auto findMatchingExpr = - [this, &require_loop_mapped_promotion]( - const ExprGroup& iel_expr, - const ValGraph& iel_graph, - const std::vector& maybe_promoted_inputs) -> Expr* { - ExprGroups maybe_promoted_input_uses; - - for (auto inp_id : maybe_promoted_inputs) { - // inp_id may have been just replayed, in which case it should - // not exist in the IEL graph. It should be just ignored as it - // should not have any use yet. - if (!iel_graph.hasGroup(inp_id)) { - continue; - } - const auto& inp_exact_group = iel_graph.toGroup(inp_id); - maybe_promoted_input_uses.pushBack(iel_graph.getUses(inp_exact_group)); - } - - // Look for exprs that have inputs that are mapped in the IEL - // graph with the (promoted) inputs of iel_expr. If found, no need - // to create a new expr to produce promoted outputs - for (const ExprGroup& maybe_promoted_input_use_group : - maybe_promoted_input_uses) { - NVF_ERROR(!maybe_promoted_input_use_group->empty()); - // No need to check itself - if (iel_expr == maybe_promoted_input_use_group) { - continue; - } - Expr* maybe_promoted_input_use = - maybe_promoted_input_use_group->front(); - if (!iel_expr->front()->sameOp(maybe_promoted_input_use)) { - continue; - } - // Check if all inputs are mapped - NVF_ERROR( - maybe_promoted_inputs.size() == - maybe_promoted_input_use->inputs().size()); - bool inps_match = true; - for (const auto inp_i : c10::irange(maybe_promoted_inputs.size())) { - // Here, new promoted ids are not added to iel_graph, so - // once promoted, this should not return true anymore. Also, - // strictAreMapped doesn't work as promoted domains are not - // in the graph - inps_match = inps_match && - iel_graph.disjointValSets().permissiveAreMapped( - maybe_promoted_inputs[inp_i], - maybe_promoted_input_use->inputs().at(inp_i)); - } - if (!inps_match) { - continue; - } - - // For the final loop promotion map, we want to find - // promotions within the same loop groups. Note that that's - // guaranteed when replayed. - if (require_loop_mapped_promotion) { - if (!idGraph(IdMappingMode::LOOP) - .disjointExprSets() - .permissiveAreMapped( - iel_expr->front(), - maybe_promoted_input_use_group->front())) { - continue; - } - // This is just an extra sanity check. Make sure all exprs in - // the use group are mapped - NVF_ERROR( - std::all_of( - maybe_promoted_input_use_group->vector().begin(), - maybe_promoted_input_use_group->vector().end(), - [&](Expr* iel_use) { - return idGraph(IdMappingMode::LOOP) - .disjointExprSets() - .permissiveAreMapped(iel_expr->front(), iel_use); - }), - "Not all mapped: ", - nvfuser::toString(iel_expr), - "\n", - nvfuser::toString(maybe_promoted_input_use_group)); - } - return maybe_promoted_input_use; - } - - return nullptr; - }; + Expr* promoted_expr = findMatchingExpr( + iel_expr, + iel_graph, + maybe_promoted_inputs, + require_loop_mapped_promotion, + idGraph(IdMappingMode::LOOP)); bool replayed = false; - Expr* promoted_expr = - findMatchingExpr(iel_expr, iel_graph, maybe_promoted_inputs); if (!promoted_expr) { promoted_expr = addReplayAs(maybe_promoted_inputs, iel_expr->front()); @@ -1224,10 +1218,6 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { initialized_modes.push_back(mode); } - auto orig_inputs = ir_utils::filterByType(expr->inputs()); - std::vector orig_input_ids( - orig_inputs.begin(), orig_inputs.end()); - // Replace the provided inputs with IterType::Iteration domains as // reduction domains cannot be merged with non-reduction domains. if (std::any_of( @@ -1238,20 +1228,24 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { return !id->isReduction(); })) { // Inputs have mismatched type, replace new_inputs - decltype(new_inputs) tmp_inputs; - std::swap(tmp_inputs, new_inputs); - for (auto tmp_input : tmp_inputs) { - new_inputs.push_back( - IterDomainBuilder(tmp_input).iter_type(IterType::Iteration).build()); - id_definitions_[new_inputs.back()]; - id_uses_[new_inputs.back()]; + auto tmp_inputs = new_inputs; + for (const auto i : c10::irange(new_inputs.size())) { + new_inputs.at(i) = IterDomainBuilder(tmp_inputs.at(i)) + .iter_type(IterType::Iteration) + .build(); + id_definitions_[new_inputs.at(i)]; + id_uses_[new_inputs.at(i)]; for (auto mode : initialized_modes) { - idGraph(mode).initializeVal(new_inputs.back(), {}, {}); - idGraph(mode).mapVals(new_inputs.back(), tmp_input); + idGraph(mode).initializeVal(new_inputs.at(i), {}, {}); + idGraph(mode).mapVals(new_inputs.at(i), tmp_inputs.at(i)); } } } + const std::vector orig_input_ids = + ir_utils::filterByType(expr->inputs()).vector(); + + // Sanity check of the original inputs { NVF_ERROR( new_inputs.size() == orig_input_ids.size(), @@ -1260,13 +1254,8 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { " does not match number of iter domain inputs for ", expr->toString()); - VectorOfUniqueEntries all_inputs{ - orig_input_ids.begin(), orig_input_ids.end()}; - - all_inputs.pushBack(new_inputs); - for (auto mode : initialized_modes) { - for (auto inp : all_inputs) { + for (auto inp : orig_input_ids) { NVF_ERROR( idGraph(mode).hasGroup(inp), "All inputs for replay need to be initialized in all graphs, ", diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index c6dd49dbff9..ff5c3b2e25c 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -199,10 +199,13 @@ class IdModel : public PolymorphicBase { // Propagate promotion mappings from root IEL groups to intermediate // and leaf IEL groups by traversing IEL exprs. For each expr, if an // input is promoted, the output needs to be promoted too. If - // there's already a domain that the output domain should be - // promoted to, create a mapping to it from the promoted output - // domain. If not, a new domain is created by replaying the expr - // with the promoted inputs. + // there's already an equivalent expr that uses the promoted inputs, + // create a mapping from the outputs of the IEL expr to the outputs + // of the equivalent expr. When require_loop_mapped_promotion is + // true, the equivalent expr needs to be already loop mapped. If no + // such expr is found, the IEL expr is replayed iwth the promoted + // inputs. require_loop_mapped_promotion is true when this function + // is used for step 3. // // This is used twice when building the promotion map. The first time // it is used there's no loop graph promotion yet, so only the IEL From e6071e566757e77cc1afd5004d5f39e09939dc59 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 22 Feb 2024 22:04:22 -0800 Subject: [PATCH 155/178] cleanup --- csrc/id_model/utils.h | 55 ------------------------------------------ test/test_id_model.cpp | 33 +++---------------------- 2 files changed, 3 insertions(+), 85 deletions(-) delete mode 100644 csrc/id_model/utils.h diff --git a/csrc/id_model/utils.h b/csrc/id_model/utils.h deleted file mode 100644 index 2d6327bf586..00000000000 --- a/csrc/id_model/utils.h +++ /dev/null @@ -1,55 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include - -#include -#include -#include - -#define VERBOSE() verbose(__LINE__) -#define WARN() warn(__LINE__) - -namespace nvfuser { - -// Temporary logging utility -class DebugStream { - public: - DebugStream() - : enabled_(getNvFuserEnv("ID_MODEL_VERBOSE")), out_(std::cerr) {} - - template - DebugStream& operator<<(const T& v) { - if (enabled_) { - out_ << v; - } - return *this; - } - - DebugStream& operator<<(std::ostream& (*endl)(std::ostream&)) { - if (enabled_) { - out_ << endl; - } - return *this; - } - - private: - bool enabled_ = false; - std::ostream& out_; -}; - -inline DebugStream verbose(int line) { - return DebugStream() << "[DEBUG@" << line << "] "; -} - -inline DebugStream warn(int line) { - return DebugStream() << "[WARN@" << line << "] "; -} - -} // namespace nvfuser diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index b382b802bdb..3e25e0334b8 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -96,42 +95,16 @@ class IdModelTester : public IdModel { initializeLoopGraph(inlining_info); - VERBOSE() << "Initial loop graph:\n"; - for (const auto& group : - idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { - VERBOSE() << nvfuser::toString(group) << std::endl; - } - ValGraph iel_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); std::unordered_map root_promotion_map = buildInlineRootResolutionMap(iel_graph, inlining_info); - { - std::stringstream ss; - ss << "Step 1: Root promotion map\n"; - for (const auto& [iel_group, promoted_id] : root_promotion_map) { - ss << "\t" << nvfuser::toString(iel_group) << " -> " - << promoted_id->name() << std::endl; - } - VERBOSE() << ss.str(); - } - auto iel_promotion_map = root_promotion_map; propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); - { - std::stringstream ss; - ss << "Step 2: IEL promotion map\n"; - for (const auto& [iel_group, promoted_id] : iel_promotion_map) { - ss << "\t" << nvfuser::toString(iel_group) << " -> " - << promoted_id->name() << std::endl; - } - VERBOSE() << ss.str(); - } - return { std::move(iel_graph), std::move(root_promotion_map), @@ -996,7 +969,7 @@ TEST_F(IdModelTest, LoopPromotion7) { // produce_pos( 1 ) root domain : (bS4{1}, iS5{i0}) validateIELResolution( tv->getRootDomain().at(0), - getTensorByName(all_tvs, 4)->getRootDomain().at(0), + tv4->getRootDomain().at(0), iel_graph, tester.idGraph(IdMappingMode::EXACT), root_resolution_map); @@ -1079,7 +1052,7 @@ TEST_F(IdModelTest, LoopPromotion8) { // produce_pos( 1 ) root domain : (bS2{1}, iS3{5}) validateIELResolution( tv->getRootDomain().at(0), - getTensorByName(all_tvs, 7)->getRootDomain().at(0), + tv7->getRootDomain().at(0), iel_graph, tester.idGraph(IdMappingMode::EXACT), root_resolution_map); @@ -1090,7 +1063,7 @@ TEST_F(IdModelTest, LoopPromotion8) { // iS9{5}, bS10{1}) validateIELResolution( tv->getRootDomain().at(2), - getTensorByName(all_tvs, 7)->getRootDomain().at(2), + tv7->getRootDomain().at(2), iel_graph, tester.idGraph(IdMappingMode::EXACT), root_resolution_map); From 0e6951354da06ecf05e2435b020e4150e5cc7694 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 23 Feb 2024 04:36:04 -0800 Subject: [PATCH 156/178] Simplify for step 2 --- csrc/id_model/id_model.cpp | 65 +++----------------------------------- csrc/id_model/id_model.h | 30 +----------------- 2 files changed, 5 insertions(+), 90 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 62f18ea9b94..7ea0c97e4ac 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -1009,9 +1009,7 @@ bool hasUniqueInputLoopGroups( Expr* findMatchingExpr( const ExprGroup& iel_expr, const ValGraph& iel_graph, - const std::vector& maybe_promoted_inputs, - bool require_loop_mapped_promotion, - const ValGraph& loop_graph) { + const std::vector& maybe_promoted_inputs) { // Grab all uses of the promoted inputs ExprGroups maybe_promoted_input_uses; for (auto inp_id : maybe_promoted_inputs) { @@ -1058,29 +1056,6 @@ Expr* findMatchingExpr( continue; } - // For the final loop promotion map, we want to find - // promotions within the same loop groups. Note that that's - // guaranteed when replayed. - if (require_loop_mapped_promotion) { - if (!loop_graph.disjointExprSets().permissiveAreMapped( - iel_expr->front(), maybe_promoted_input_use_group->front())) { - continue; - } - // This is just an extra sanity check. Make sure all exprs in - // the use group are mapped - NVF_ERROR( - std::all_of( - maybe_promoted_input_use_group->vector().begin(), - maybe_promoted_input_use_group->vector().end(), - [&](Expr* iel_use) { - return loop_graph.disjointExprSets().permissiveAreMapped( - iel_expr->front(), iel_use); - }), - "Not all mapped: ", - nvfuser::toString(iel_expr), - "\n", - nvfuser::toString(maybe_promoted_input_use_group)); - } return maybe_promoted_input_use; } @@ -1091,10 +1066,7 @@ Expr* findMatchingExpr( void IdModel::propagatePromotionsInIELGraph( const ValGraph& iel_graph, - std::unordered_map& iel_promotion_map, - const ValGraph& loop_graph, - const std::unordered_map& loop_graph_promotion_map, - bool require_loop_mapped_promotion) { + std::unordered_map& iel_promotion_map) { // In order to make this traversal work, the traversal order must be // topologically sorted. ValGraphStmtSort iel_stmt_sort(iel_graph); @@ -1104,11 +1076,6 @@ void IdModel::propagatePromotionsInIELGraph( const std::vector iel_inp_groups = iel_graph.inputGroups(iel_expr); - // Propagate loop graph promotion only when the inputs and outputs are - // not in the same loop group. - const bool loop_promote_inputs = !loop_graph_promotion_map.empty() && - hasUniqueInputLoopGroups(iel_expr, iel_graph, loop_graph); - // Check if any inputs need promotion indicating this expr group needs to // be replayed with promoted inputs bool an_input_was_promoted = false; @@ -1127,19 +1094,6 @@ void IdModel::propagatePromotionsInIELGraph( continue; } - // Promote loops based on the loop promotion map. If the loop promotion - // map should be used and has an entry we should use that promotion. - if (loop_promote_inputs) { - const ValGroup& loop_copy_group = - loop_graph.toGroup(iel_inp_group->front()); - auto inp_loop_promo_it = loop_graph_promotion_map.find(loop_copy_group); - if (inp_loop_promo_it != loop_graph_promotion_map.end()) { - maybe_promoted_inputs.push_back(inp_loop_promo_it->second); - an_input_was_promoted = true; - continue; - } - } - // No promotion found. Just use the non-promoted domain maybe_promoted_inputs.push_back(iel_inp_group->front()->as()); } @@ -1149,12 +1103,8 @@ void IdModel::propagatePromotionsInIELGraph( continue; } - Expr* promoted_expr = findMatchingExpr( - iel_expr, - iel_graph, - maybe_promoted_inputs, - require_loop_mapped_promotion, - idGraph(IdMappingMode::LOOP)); + Expr* promoted_expr = + findMatchingExpr(iel_expr, iel_graph, maybe_promoted_inputs); bool replayed = false; @@ -1192,13 +1142,6 @@ void IdModel::propagatePromotionsInIELGraph( } } -void IdModel::propagatePromotionsInIELGraph( - const ValGraph& iel_graph, - std::unordered_map& iel_promotion_map) { - propagatePromotionsInIELGraph( - iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), {}, false); -} - // Replay Expr but with the inputs provided. Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { // Figure out which graphs are already initialized to make sure we add the new diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index ff5c3b2e25c..6c14cb27609 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -201,35 +201,7 @@ class IdModel : public PolymorphicBase { // input is promoted, the output needs to be promoted too. If // there's already an equivalent expr that uses the promoted inputs, // create a mapping from the outputs of the IEL expr to the outputs - // of the equivalent expr. When require_loop_mapped_promotion is - // true, the equivalent expr needs to be already loop mapped. If no - // such expr is found, the IEL expr is replayed iwth the promoted - // inputs. require_loop_mapped_promotion is true when this function - // is used for step 3. - // - // This is used twice when building the promotion map. The first time - // it is used there's no loop graph promotion yet, so only the IEL - // promotions are propagated. In that case, loop_graph_promotion_map - // should be just empty. - // - // Propagation uses iel_promotion_map and - // loop_graph_promotion_map. If both are available for an IEL group, - // the former has the precedence. This is because when this function - // is used for step 4, the given iel_promotion_map starts as an - // empty map and gets populated during this propagation, so any - // mapping in the map is guaranteed to be the correct final mapping, - // whereas the loop graph may have invalid mappings for partially - // inlined domains. - void propagatePromotionsInIELGraph( - const ValGraph& iel_graph, - std::unordered_map& iel_promotion_map, - const ValGraph& loop_graph, - const std::unordered_map& loop_promotion_map, - bool require_loop_mapped_promotion); - - // Same as the other propagatePromotionsInIELGraph but without loop - // graph map. This is used for step 2, where there's no loop - // graph map yet. + // of the equivalent expr. void propagatePromotionsInIELGraph( const ValGraph& iel_graph, std::unordered_map& iel_promotion_map); From 30e3d5bdf6c1e8e2ffd6e71ad4d3579720222400 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 23 Feb 2024 08:28:33 -0800 Subject: [PATCH 157/178] cleanup --- csrc/id_model/id_model.cpp | 42 -------------------------------------- 1 file changed, 42 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 7ea0c97e4ac..6e8f4fb7a5e 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -958,48 +958,6 @@ ValGraph IdModel::buildIntersection( namespace { -// When propagating loop promotions from inputs to outputs of an IEL -// expr, we can't blindly apply loop promotion when all of the input -// domains are loop mapped with the outputs. -// -// i.e. if we have the inlined domains from: -// T2[i0*i1] pa(1) = T0[i0*b1]ca(1) + T1[i0*i1]ca(1) -// The inlined loop group would be: -// -// i0, i1, b1, i0*i1, b0*i1 -// Then if we replayed the iel transformations they would be: -// merge(i0, i1) -// merge(i0, b1) -// -// So if we replayed them with loop promotion, then i0, i1, b1 would be -// promoted to i0*i1, and the merges would be replayed. -// -// Therefore only promote i0*b1 to i0*i1, or i0*i1 to i0*i1 (i.e. don't -// promote an input to any transformation within the loop group). -// -// So if we have an iel_expr make sure its inputs and outputs are not in -// the same loop group. -bool hasUniqueInputLoopGroups( - const ExprGroup& iel_expr, - const ValGraph& iel_graph, - const ValGraph& loop_graph) { - const std::vector iel_inp_groups = iel_graph.inputGroups(iel_expr); - - const std::vector iel_out_groups = iel_graph.outputGroups(iel_expr); - - ValGroups inp_loop_groups; - for (const ValGroup& iel_inp_group : iel_inp_groups) { - inp_loop_groups.pushBack(loop_graph.toGroup(iel_inp_group->front())); - } - ValGroups out_loop_groups; - for (const ValGroup& iel_out_group : iel_out_groups) { - out_loop_groups.pushBack(loop_graph.toGroup(iel_out_group->front())); - } - - // Check if input groups that are not included in the output group set - return !inp_loop_groups.computeSubtract(out_loop_groups).empty(); -} - // Check if there's an equivalent expression as iel_expr that uses // maybe_promoted_inputs. This is used to avoid redundantly replaying // expressions. From b30682aa0ab40015201022b9a3031b2f8013a26c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 24 Feb 2024 11:16:07 -0800 Subject: [PATCH 158/178] format --- csrc/id_model/id_model.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index daf4fd2b9e4..94163d7753b 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -245,7 +245,7 @@ class IdModel : public PolymorphicBase { void propagatePromotionsInIELGraph( const ValGraph& iel_graph, std::unordered_map& iel_promotion_map); - + // Returns a similar thing to buildInlinePromotions but also includes iter // domains that are not inlined. std::unordered_map projectIELPromotionToLoopGraph( From 04b1479140ca615375836d0edc21b5be34e0e78b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 18 Mar 2024 15:09:30 -0700 Subject: [PATCH 159/178] Test cleanup --- tests/cpp/test_id_model.cpp | 657 ++++++++++++++++++++++++++++-------- 1 file changed, 523 insertions(+), 134 deletions(-) diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 2f8a7b20373..94f37d3fd75 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -24,7 +24,7 @@ namespace nvfuser { -class IdModelTest : public NVFuserTest {}; +using IdModelTest = NVFuserTest; TEST_F(IdModelTest, DetectSelfMapping) { Fusion fusion; @@ -37,14 +37,33 @@ TEST_F(IdModelTest, DetectSelfMapping) { fusion.addOutput(tv2); EXPECT_THAT( - [&]() { - IdModel id_model(&fusion); - id_model.buildAllGraphs(); - }, + [&]() { IdModel id_model(&fusion, /*build_graphs=*/true); }, ::testing::ThrowsMessage( ::testing::HasSubstr("!hasSelfMapping"))); } +TEST_F(IdModelTest, PerTensorSelfMapping) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* x0 = makeConcreteTensor({2, 2}); + fusion.addInput(x0); + TensorView* x1 = makeConcreteTensor({2, 2}); + fusion.addInput(x1); + + TensorView* y0 = transpose(x0, 0, 1); + y0 = add(x0, y0); + fusion.addOutput(y0); + + TensorView* y1 = transpose(x1, 0, 1); + fusion.addOutput(y1); + + IdModel id_model(&fusion, /*build_graphs=*/true, /*allow_self_mapping=*/true); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + EXPECT_TRUE(hasSelfMapping(y0, exact_graph).has_value()); + EXPECT_FALSE(hasSelfMapping(y1, exact_graph).has_value()); +} + namespace { // Get n-th parent expr traversing through the first input of each @@ -58,14 +77,34 @@ Expr* getParentExpr(Val* val, int n) { return val->definition(); }; -TensorView* getTensorByName( - const std::vector& tvs, - StmtNameType name) { +IterDomain* getParentId(IterDomain* id, int n) { + for (int i = 0; i < n; ++i) { + NVF_ERROR(id->definition() != nullptr); + NVF_ERROR(id->definition()->input(0)->isA()); + id = id->definition()->input(0)->as(); + } + NVF_ERROR(id != nullptr); + return id; +}; + +// Get the n-th descendant by traversing a sibling +IterDomain* getChildId(IterDomain* id, int n, int sibling_idx = 0) { + for (int i = 0; i < n; ++i) { + NVF_ERROR(!id->uses().empty()); + NVF_ERROR(id->uses().front()->output(sibling_idx)->isA()); + id = id->uses().front()->output(sibling_idx)->as(); + } + NVF_ERROR(id != nullptr); + return id; +}; + +template +ValType* getValByName(const std::vector& vals, StmtNameType name) { if (auto it = std::find_if( - tvs.begin(), - tvs.end(), - [&](TensorView* tv) { return tv->name() == name; }); - it != tvs.end()) { + vals.begin(), + vals.end(), + [&](auto val) { return val->name() == name; }); + it != vals.end()) { return *it; } else { return nullptr; @@ -76,14 +115,7 @@ TensorView* getTensorByName( class IdModelTester : public IdModel { public: // Do not automatically build the graphs - IdModelTester(Fusion* fusion) : IdModel(fusion, /* build_graphs */ false) {} - - // Returns the IEL graph and the results of Steps 1 and 2 - std::tuple< - ValGraph, - std::unordered_map, - std::unordered_map> - getLoopPromotionInfo() { + IdModelTester(Fusion* fusion) : IdModel(fusion, /*build_graphs=*/false) { // Make sure the depedent graphs are already built maybeBuildGraph(IdMappingMode::EXACT); maybeBuildGraph(IdMappingMode::PERMISSIVE); @@ -102,41 +134,47 @@ class IdModelTester : public IdModel { VERBOSE() << nvfuser::toString(group) << std::endl; } - ValGraph iel_graph = buildIntersection( + iel_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); - std::unordered_map root_promotion_map = + s1_root_resolution_map = buildInlineRootResolutionMap(iel_graph, inlining_info); { std::stringstream ss; ss << "Step 1: Root promotion map\n"; - for (const auto& [iel_group, promoted_id] : root_promotion_map) { + for (const auto& [iel_group, promoted_id] : s1_root_resolution_map) { ss << "\t" << nvfuser::toString(iel_group) << " -> " << promoted_id->name() << std::endl; } VERBOSE() << ss.str(); } - auto iel_promotion_map = root_promotion_map; + s2_iel_promotion_map = s1_root_resolution_map; - propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); + propagatePromotionsInIELGraph(iel_graph, s2_iel_promotion_map); { std::stringstream ss; ss << "Step 2: IEL promotion map\n"; - for (const auto& [iel_group, promoted_id] : iel_promotion_map) { + for (const auto& [iel_group, promoted_id] : s2_iel_promotion_map) { ss << "\t" << nvfuser::toString(iel_group) << " -> " << promoted_id->name() << std::endl; } VERBOSE() << ss.str(); } - return { - std::move(iel_graph), - std::move(root_promotion_map), - std::move(iel_promotion_map)}; + s3_loop_promotion_map = projectIELPromotionToLoopGraph( + iel_graph, + s2_iel_promotion_map, + idGraph(IdMappingMode::LOOP), + inlining_info); } + + ValGraph iel_graph; + std::unordered_map s1_root_resolution_map; + std::unordered_map s2_iel_promotion_map; + std::unordered_map s3_loop_promotion_map; }; // Test if id is resolved to an ID that is exact mapped with @@ -241,6 +279,46 @@ void checkStep2Results( } } +// Validate the loop promotion map at Step 3. This validation ensures +// the promotion map is exactly the same as a given reference +// map. Since the valid promotion map may not be unique, the exact +// equality is not required, however, as long as everything is done +// deterministically, the resulting map should always be the +// same. The exact equality helps ensure the determinism as well. +void checkStep3Results( + const ValGraph& loop_graph, + const std::unordered_map& loop_promotion_map, + const std::vector, IterDomain*>>& + ref_promotion_map) { + for (const auto& loop_group : loop_graph.disjointValSets().disjointSets()) { + auto promotion_it = loop_promotion_map.find(loop_group); + ASSERT_NE(promotion_it, loop_promotion_map.end()) + << "No promotion found for: " << nvfuser::toString(loop_group); + IterDomain* promotion_id = promotion_it->second; + + auto ref_promotion_it = std::find_if( + ref_promotion_map.begin(), + ref_promotion_map.end(), + [&](const auto& ref_promotion) { + return ref_promotion.first == loop_group->set(); + }); + + // Self promotion omitted in the reference + if (ref_promotion_it == ref_promotion_map.end()) { + ASSERT_EQ(loop_group->size(), 1); + ASSERT_EQ(loop_group->front(), promotion_id) + << "Expected promotion: " << loop_group->front()->toString() + << ". Actual: " << promotion_id->toString(); + continue; + } + + auto ref_promotion_id = ref_promotion_it->second; + ASSERT_EQ(promotion_id, ref_promotion_id) + << "Expected promotion: " << ref_promotion_id->toString() + << ". Actual: " << promotion_id->toString(); + } +} + // Create a fusion where we're missing a valid concrete id so the compute at map // processing will fail. We need to be able to create the concrete ID not just // look for one. It is not yet possible to lower this fusion as the @@ -495,14 +573,14 @@ TEST_F(IdModelTest, ValGraphStmtSort4) { ValGraphStmtSort vg_stmt_sort(vg); - auto tv1 = getTensorByName(all_tvs, 1); - auto tv2 = getTensorByName(all_tvs, 2); - auto tv4 = getTensorByName(all_tvs, 4); - auto tv5 = getTensorByName(all_tvs, 5); - auto tv6 = getTensorByName(all_tvs, 6); - auto tv8 = getTensorByName(all_tvs, 8); - auto tv9 = getTensorByName(all_tvs, 9); - auto tv10 = getTensorByName(all_tvs, 10); + auto tv1 = getValByName(all_tvs, 1); + auto tv2 = getValByName(all_tvs, 2); + auto tv4 = getValByName(all_tvs, 4); + auto tv5 = getValByName(all_tvs, 5); + auto tv6 = getValByName(all_tvs, 6); + auto tv8 = getValByName(all_tvs, 8); + auto tv9 = getValByName(all_tvs, 9); + auto tv10 = getValByName(all_tvs, 10); // Expected reference order: // @@ -575,11 +653,9 @@ TEST_F(IdModelTest, LoopPromotion1) { { IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); // Nothing inlined. Should be no resolution - ASSERT_TRUE(root_resolution_map.empty()); + ASSERT_TRUE(tester.s1_root_resolution_map.empty()); } t2->inlineAt(2); @@ -587,8 +663,6 @@ TEST_F(IdModelTest, LoopPromotion1) { { IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); // Check Step 1 results // t2 is now fully inlined. Its root broadcast domain should be @@ -596,15 +670,26 @@ TEST_F(IdModelTest, LoopPromotion1) { validateIELResolution( t2->getRootDomain().at(0), t3->getRootDomain().at(0), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); // Check Step 2 results // Nothing to propagate in this fusion, so iel_promotion_map // should be equivalent to root_resolution_map - ASSERT_EQ(root_resolution_map, iel_promotion_map) + ASSERT_EQ(tester.s1_root_resolution_map, tester.s2_iel_promotion_map) << "Unexpected IEL promotion map"; + + // Check Step 3 results. See the design doc for the expected results + std::vector, IterDomain*>> + s3_reference_map = { + {std::unordered_set{t2->axis(0), t3->axis(0)}, t3->axis(0)}, + {std::unordered_set{t2->axis(1), t3->axis(1)}, t3->axis(1)}}; + + checkStep3Results( + tester.idGraph(IdMappingMode::LOOP), + tester.s3_loop_promotion_map, + s3_reference_map); } } @@ -626,30 +711,42 @@ TEST_F(IdModelTest, LoopPromotion2) { inlineMost(); IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); // Check Step 1 results // Validate t2 and t3 as they have root broadcast domains validateIELResolution( t2->getRootDomain().at(0), t4->getRootDomain().at(1), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); validateIELResolution( t3->getRootDomain().at(0), t4->getRootDomain().at(0), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); // Check Step 2 results // Nothing to propagate in this fusion, so iel_promotion_map // should be equivalent to root_resolution_map - ASSERT_EQ(root_resolution_map, iel_promotion_map) + ASSERT_EQ(tester.s1_root_resolution_map, tester.s2_iel_promotion_map) << "Unexpected IEL promotion map"; + + // Check Step 3 results. See the design doc for the expected results + std::vector, IterDomain*>> + s3_reference_map = { + {std::unordered_set{t2->axis(0), t3->axis(1), t4->axis(1)}, + t4->axis(1)}, + {std::unordered_set{t2->axis(1), t3->axis(2), t4->axis(2)}, + t4->axis(2)}, + {std::unordered_set{t3->axis(0), t4->axis(0)}, t4->axis(0)}}; + + checkStep3Results( + tester.idGraph(IdMappingMode::LOOP), + tester.s3_loop_promotion_map, + s3_reference_map); } // Multiple inlined and non-inlined broadcast domains @@ -679,8 +776,6 @@ TEST_F(IdModelTest, LoopPromotion3) { // tv3: [i0*i1, i2*i3] IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); // Check Step 1 results // The b1 broadcast domain tv2 should be resolved as it's inlined, @@ -688,31 +783,48 @@ TEST_F(IdModelTest, LoopPromotion3) { validateIELResolution( tv2->getRootDomain().at(1), tv3->getRootDomain().at(1), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); validateIELResolution( tv2->getRootDomain().at(3), nullptr, - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); // Check Step 2 results validateIELResolution( tv2->axis(0), tv3->axis(0), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - iel_promotion_map); + tester.s2_iel_promotion_map); validateIELResolution( tv2->axis(1), nullptr, - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - iel_promotion_map); + tester.s2_iel_promotion_map); + + // Check Step 3 results. See the design doc for the expected results + std::vector, IterDomain*>> + s3_reference_map = { + {std::unordered_set{ + tv2->axis(0), + tv2->getRootDomain().at(0), + tv2->getRootDomain().at(1), + tv3->axis(0), + tv3->getRootDomain().at(0), + tv3->getRootDomain().at(1)}, + tv3->axis(0)}}; + + checkStep3Results( + tester.idGraph(IdMappingMode::LOOP), + tester.s3_loop_promotion_map, + s3_reference_map); } // Test root resolution with a fusion with outer split. @@ -748,8 +860,6 @@ TEST_F(IdModelTest, LoopPromotion4) { } IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); // Verify all tensors with root broadcast have correct resolutions for (auto tv : ir_utils::allTvs(&fusion)) { @@ -769,9 +879,9 @@ TEST_F(IdModelTest, LoopPromotion4) { validateIELResolution( tv->getRootDomain().at(0), tv4->getRootDomain().at(0), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); @@ -780,9 +890,41 @@ TEST_F(IdModelTest, LoopPromotion4) { checkStep2Results( &fusion, - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - iel_promotion_map); + tester.s2_iel_promotion_map); + + // Check Step 3 results. See the design doc for the expected results + std::vector, IterDomain*>> + s3_reference_map = { + // 4, 6, 8 -> 8 + {std::unordered_set{ + tv2->getRootDomain().at(0), + tv3->getRootDomain().at(0), + tv4->getRootDomain().at(0)}, + tv4->getRootDomain().at(0)}, + // 5, 7, 9 -> 9 + {std::unordered_set{ + tv2->getRootDomain().at(1), + tv3->getRootDomain().at(1), + tv4->getRootDomain().at(1)}, + tv4->getRootDomain().at(1)}, + // 10, 13, 19 -> 10 + {std::unordered_set{ + getParentId(tv2->axis(0), 1), + getParentId(tv3->axis(0), 1), + getParentId(tv4->axis(0), 1)}, + getParentId(tv4->axis(0), 1)}, + // 11, 14, 20 -> 11 + {std::unordered_set{tv2->axis(0), tv3->axis(0), tv4->axis(0)}, + tv4->axis(0)}, + // 21 -> 12 + {std::unordered_set{tv2->axis(1)}, tv4->axis(1)}}; + + checkStep3Results( + tester.idGraph(IdMappingMode::LOOP), + tester.s3_loop_promotion_map, + s3_reference_map); } // Test root resolution with the same fusion as Indexing1 @@ -823,8 +965,6 @@ TEST_F(IdModelTest, LoopPromotion5) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); // Check Step 1 results for (auto tv : all_tvs) { @@ -845,9 +985,9 @@ TEST_F(IdModelTest, LoopPromotion5) { validateIELResolution( tv->getRootDomain().at(0), tv4->getRootDomain().at(0), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); @@ -857,9 +997,66 @@ TEST_F(IdModelTest, LoopPromotion5) { // Check Step 2 results checkStep2Results( &fusion, - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - iel_promotion_map); + tester.s2_iel_promotion_map); + + // Check Step 3 results. See the design doc for the expected results + std::vector, IterDomain*>> + s3_reference_map = { + // 7, 10, 11, 25, 14, 15, 18 -> 18 + {std::unordered_set{ + tv2->getRootDomain().at(0), + tv3->getRootDomain().at(0), + tv3->getRootDomain().at(1), + getParentId(tv3->axis(0), 4), + tv4->getRootDomain().at(0), + tv4->getRootDomain().at(1), + getParentId(tv4->axis(0), 4)}, + getParentId(tv4->axis(0), 4)}, + // 8, 12, 16 -> 16 + {std::unordered_set{ + tv2->getRootDomain().at(1), + tv3->getRootDomain().at(2), + tv4->getRootDomain().at(2)}, + tv4->getRootDomain().at(2)}, + // 9, 13, 17 -> 17 + {std::unordered_set{ + tv2->getRootDomain().at(2), + tv3->getRootDomain().at(3), + tv4->getRootDomain().at(3)}, + tv4->getRootDomain().at(3)}, + // 32, 26, 19 -> 19 + {std::unordered_set{ + getParentId(tv2->axis(0), 3), + getParentId(tv3->axis(0), 3), + getParentId(tv4->axis(0), 3)}, + getParentId(tv4->axis(0), 3)}, + // 33, 27, 20 -> 20 + {std::unordered_set{ + getParentId(tv2->axis(0), 2), + getParentId(tv3->axis(0), 2), + getParentId(tv4->axis(0), 2)}, + getParentId(tv4->axis(0), 2)}, + // 34, 28, 21 -> 21 + {std::unordered_set{ + getParentId(tv2->axis(0), 1), + getParentId(tv3->axis(0), 1), + getParentId(tv4->axis(0), 1)}, + getParentId(tv4->axis(0), 1)}, + // 29 -> 22 + {std::unordered_set{tv3->axis(2)}, tv4->axis(2)}, + // 31 -> 24 + {std::unordered_set{tv3->axis(1)}, tv4->axis(1)}, + // 36, 30, 23 -> 23 + {std::unordered_set{tv2->axis(0), tv3->axis(0), tv4->axis(0)}, + tv4->axis(0)}, + }; + + checkStep3Results( + tester.idGraph(IdMappingMode::LOOP), + tester.s3_loop_promotion_map, + s3_reference_map); } // Test root resolution with the same fusion as Indexing19 @@ -869,8 +1066,14 @@ TEST_F(IdModelTest, LoopPromotion6) { auto all_tvs = ir_utils::allTvs(fusion.get()); IdModelTester tester(fusion.get()); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); + + auto tv1 = getValByName(all_tvs, 1); + auto tv2 = getValByName(all_tvs, 2); + auto tv4 = getValByName(all_tvs, 4); + auto tv5 = getValByName(all_tvs, 5); + auto tv6 = getValByName(all_tvs, 6); + auto tv8 = getValByName(all_tvs, 8); + auto tv9 = getValByName(all_tvs, 9); // Check Step 1 results for (auto tv : all_tvs) { @@ -891,10 +1094,10 @@ TEST_F(IdModelTest, LoopPromotion6) { // Resolution: Resolved by the immediate consumer (T4) validateIELResolution( tv->getRootDomain().at(1), - getTensorByName(all_tvs, 4)->getRootDomain().at(1), - iel_graph, + tv4->getRootDomain().at(1), + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; case 5: // T5_l[ iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), 5) ), 3) )}, @@ -905,10 +1108,10 @@ TEST_F(IdModelTest, LoopPromotion6) { // as T8 or T9. validateIELResolution( tv->getRootDomain().at(2), - getTensorByName(all_tvs, 9)->getRootDomain().at(2), - iel_graph, + tv9->getRootDomain().at(2), + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; case 6: // T6_l[ iS64{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3}, @@ -917,10 +1120,10 @@ TEST_F(IdModelTest, LoopPromotion6) { // Resolution: Resolved by the immediate consumer (T8) validateIELResolution( tv->getRootDomain().at(1), - getTensorByName(all_tvs, 8)->getRootDomain().at(1), - iel_graph, + tv8->getRootDomain().at(1), + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; case 9: // T9_l[ iS33{( ceilDiv(( ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )}, @@ -931,10 +1134,10 @@ TEST_F(IdModelTest, LoopPromotion6) { // as T4 or T5 validateIELResolution( tv->getRootDomain().at(1), - getTensorByName(all_tvs, 5)->getRootDomain().at(1), - iel_graph, + tv5->getRootDomain().at(1), + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); @@ -943,9 +1146,113 @@ TEST_F(IdModelTest, LoopPromotion6) { checkStep2Results( fusion.get(), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - iel_promotion_map); + tester.s2_iel_promotion_map); + + auto id79 = + getValByName(ir_utils::consumerValsOf(tv9->getRootDomain().at(2)), 79) + ->as(); + ASSERT_NE(id79, nullptr) << "IterDomain 79 not found"; + auto id80 = + getValByName(ir_utils::consumerValsOf(tv9->getRootDomain().at(2)), 80) + ->as(); + ASSERT_NE(id80, nullptr) << "IterDomain 80 not found"; + auto id81 = getChildId(id79, 1); + ASSERT_EQ(id81->name(), 81); + auto id82 = getChildId(id79, 1, 1); + ASSERT_EQ(id82->name(), 82); + auto id83 = getChildId(id80, 1); + ASSERT_EQ(id83->name(), 83); + auto id84 = getChildId(id80, 1, 1); + ASSERT_EQ(id84->name(), 84); + auto id85 = getChildId(id81, 1); + ASSERT_EQ(id85->name(), 85); + auto id86 = getChildId(id81, 1, 1); + ASSERT_EQ(id86->name(), 86); + auto id87 = getChildId(id83, 1); + ASSERT_EQ(id87->name(), 87); + auto id88 = getChildId(id83, 1, 1); + ASSERT_EQ(id88->name(), 88); + + // Check Step 3 results. See the design doc for the expected results + std::vector, IterDomain*>> + s3_reference_map = { + // 1 2 3 6 7 8 9 10 11 12 15 16 17 18 19 29 30 35 36 41 46 56 61 + // 79 80 -> 80 + {std::unordered_set{ + tv1->getRootDomain().at(0), + tv2->getRootDomain().at(0), + tv2->getRootDomain().at(1), + getChildId(tv2->getRootDomain().at(0), 1), + tv4->getRootDomain().at(0), + tv4->getRootDomain().at(1), + getChildId(tv4->getRootDomain().at(0), 1), + tv5->getRootDomain().at(0), + tv5->getRootDomain().at(1), + tv5->getRootDomain().at(2), + getChildId(tv5->getRootDomain().at(0), 1), + getChildId(tv5->getRootDomain().at(2), 1), + tv6->getRootDomain().at(0), + tv6->getRootDomain().at(1), + getChildId(tv6->getRootDomain().at(0), 1), + tv8->getRootDomain().at(0), + tv8->getRootDomain().at(1), + getChildId(tv8->getRootDomain().at(0), 1), + tv9->getRootDomain().at(0), + tv9->getRootDomain().at(1), + tv9->getRootDomain().at(2), + getChildId(tv9->getRootDomain().at(0), 1), + getChildId(tv9->getRootDomain().at(0), 2), + id79, + id80}, + id80}, + // 31 37 42 47 57 62 71 81 83 -> 83 + {std::unordered_set{ + getChildId(tv1->getRootDomain().at(0), 1), + getChildId(tv2->getRootDomain().at(0), 2), + getChildId(tv4->getRootDomain().at(0), 2), + getChildId(tv5->getRootDomain().at(0), 3), + getChildId(tv6->getRootDomain().at(0), 2), + getChildId(tv8->getRootDomain().at(0), 2), + getChildId(tv9->getRootDomain().at(0), 3), + id81, + id83}, + id83}, + // 33 39 44 49 59 64 73 85 87 -> 87 + {std::unordered_set{ + tv1->axis(0), + tv2->axis(0), + tv4->axis(0), + tv5->axis(0), + tv6->axis(0), + tv8->axis(0), + tv9->axis(0), + id85, + id87}, + id87}, + // 48 -> 43 + {std::unordered_set{tv2->axis(2)}, tv4->axis(2)}, + // 50 -> 45 + {std::unordered_set{tv2->axis(1)}, tv4->axis(1)}, + // 40 88 -> 88 + {std::unordered_set{tv5->axis(1), id88}, id88}, + // 63 -> 58 + {std::unordered_set{tv6->axis(2)}, tv8->axis(2)}, + // 65 -> 60 + {std::unordered_set{tv6->axis(1)}, tv8->axis(1)}, + // 34 86 -> 86 + {std::unordered_set{tv9->axis(1), id86}, id86}, + // 38 84 -> 84 + {std::unordered_set{tv5->axis(2), id84}, id84}, + // 32 82 -> 82 + {std::unordered_set{tv9->axis(2), id82}, id82}, + }; + + checkStep3Results( + tester.idGraph(IdMappingMode::LOOP), + tester.s3_loop_promotion_map, + s3_reference_map); } // Same fusion as NvFuserTest.FusionInlineBroadcastIndexing0 @@ -976,8 +1283,6 @@ TEST_F(IdModelTest, LoopPromotion7) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { @@ -997,9 +1302,9 @@ TEST_F(IdModelTest, LoopPromotion7) { validateIELResolution( tv->getRootDomain().at(0), tv4->getRootDomain().at(0), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); @@ -1008,9 +1313,33 @@ TEST_F(IdModelTest, LoopPromotion7) { checkStep2Results( &fusion, - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - iel_promotion_map); + tester.s2_iel_promotion_map); + + // Check Step 3 results. See the design doc for the expected results + std::vector, IterDomain*>> + s3_reference_map = { + // 3, 4, 5, 14, 6, 7, 8, -> 8 + {std::unordered_set{ + tv2->getRootDomain().at(0), + tv3->getRootDomain().at(0), + tv3->getRootDomain().at(1), + getChildId(tv3->getRootDomain().at(0), 1), + tv4->getRootDomain().at(0), + tv4->getRootDomain().at(1), + getChildId(tv4->getRootDomain().at(0), 1)}, + getChildId(tv4->getRootDomain().at(0), 1)}, + // 17, 15, 9 -> 9 + {std::unordered_set{tv2->axis(0), tv3->axis(0), tv4->axis(0)}, + tv4->axis(0)}, + // 16 -> 10 + {std::unordered_set{tv3->axis(1)}, tv4->axis(1)}}; + + checkStep3Results( + tester.idGraph(IdMappingMode::LOOP), + tester.s3_loop_promotion_map, + s3_reference_map); } // Same fusion as NvFuserTest.FusionIndexing20 @@ -1059,8 +1388,6 @@ TEST_F(IdModelTest, LoopPromotion8) { auto all_tvs = ir_utils::allTvs(&fusion); IdModelTester tester(&fusion); - const auto& [iel_graph, root_resolution_map, iel_promotion_map] = - tester.getLoopPromotionInfo(); // Verify all tensors with root broadcast have correct resolutions for (auto tv : all_tvs) { @@ -1080,9 +1407,9 @@ TEST_F(IdModelTest, LoopPromotion8) { validateIELResolution( tv->getRootDomain().at(0), tv7->getRootDomain().at(0), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; case 5: // T5_l[ iS27{2}, iS40{4}, iS41{( ceilDiv(( ( ceilDiv(( 3 * 5 ), 2) ) * @@ -1091,9 +1418,9 @@ TEST_F(IdModelTest, LoopPromotion8) { validateIELResolution( tv->getRootDomain().at(2), tv7->getRootDomain().at(2), - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - root_resolution_map); + tester.s1_root_resolution_map); break; default: FAIL() << "Unexpected tensor: " << tv->toString(); @@ -1102,9 +1429,60 @@ TEST_F(IdModelTest, LoopPromotion8) { checkStep2Results( &fusion, - iel_graph, + tester.iel_graph, tester.idGraph(IdMappingMode::EXACT), - iel_promotion_map); + tester.s2_iel_promotion_map); + + // Check Step 3 results. See the design doc for the expected results + std::vector, IterDomain*>> + s3_reference_map = { + // 1, 2, 3, 20, 6, 7, 17, 8, 9, 26, 14, 15, 29 -> 29 + {std::unordered_set{ + tv1->getRootDomain().at(0), + tv2->getRootDomain().at(0), + tv2->getRootDomain().at(1), + getChildId(tv2->getRootDomain().at(0), 1), + tv4->getRootDomain().at(0), + tv4->getRootDomain().at(1), + getChildId(tv4->getRootDomain().at(0), 1), + tv5->getRootDomain().at(0), + tv5->getRootDomain().at(1), + getChildId(tv5->getRootDomain().at(0), 1), + tv7->getRootDomain().at(0), + tv7->getRootDomain().at(1), + getChildId(tv7->getRootDomain().at(0), 1)}, + getChildId(tv7->getRootDomain().at(0), 1)}, + // 35, 21, 18, 27, 30 -> 30 + {std::unordered_set{ + tv1->axis(0), + tv2->axis(0), + tv4->axis(0), + tv5->axis(0), + tv7->axis(0)}, + tv7->axis(0)}, + // 28, 10, 39, 31, 16, 42 -> 42 + {std::unordered_set{ + getChildId( + getChildId(tv5->getRootDomain().at(0), 1), 1, 1), // 28 + tv5->getRootDomain().at(2), // 10 + getChildId(tv5->getRootDomain().at(2), 1), // 39 + getChildId( + getChildId(tv7->getRootDomain().at(0), 1), 1, 1), // 31 + tv7->getRootDomain().at(2), // 16 + getChildId(tv7->getRootDomain().at(2), 1)}, // 42 + getChildId(tv7->getRootDomain().at(2), 1)}, + // 22 -> 19 + {std::unordered_set{tv2->axis(1)}, tv4->axis(1)}, + // 40, 43 -> 43 + {std::unordered_set{tv5->axis(1), tv7->axis(1)}, tv7->axis(1)}, + // 41 -> 44 + {std::unordered_set{tv5->axis(2)}, tv7->axis(2)}, + }; + + checkStep3Results( + tester.idGraph(IdMappingMode::LOOP), + tester.s3_loop_promotion_map, + s3_reference_map); } // A repro that produces an invalid loop graph due to the compliment @@ -1191,6 +1569,16 @@ TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { // loop_graph.toGroup(tv10->axis(1)), loop_graph.toGroup(tv10->axis(2))); } +namespace { +bool iterDomainsAreMapped( + const IdModel& id_model, + IterDomain* a, + IterDomain* b) { + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + return exact_graph.disjointValSets().strictAreMapped(a, b); +} +} // namespace + TEST_F(IdModelTest, SomeButNotAllArePermuted) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1206,24 +1594,20 @@ TEST_F(IdModelTest, SomeButNotAllArePermuted) { IdModel id_model( fusion.get(), /*build_graphs=*/true, /*allow_self_mapping=*/true); - const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s0->axis(0), t0->axis(1))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s0->axis(1), t0->axis(0))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s0->axis(2), t0->axis(2))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, s0->axis(0), t0->axis(1))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, s0->axis(1), t0->axis(0))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, s0->axis(2), t0->axis(2))); } TEST_F(IdModelTest, PermutedDifferently) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - TensorView* in = makeContigConcreteTensor({2, 2, 2, 5}); - TensorView* s0 = slice(in, {0, 0, 0, 0}, {2, 2, 2, 2}); - TensorView* s1 = slice(in, {0, 0, 0, 2}, {2, 2, 2, 5}); - TensorView* t0 = permute(s0, {2, 1, 0, 3}); - TensorView* t1 = permute(s1, {1, 0, 2, 3}); + TensorView* in = makeContigConcreteTensor({2, 2, 5}); + TensorView* s0 = slice(in, {0, 0, 0}, {2, 2, 2}); + TensorView* s1 = slice(in, {0, 0, 2}, {2, 2, 5}); + TensorView* t0 = permute(s0, {1, 0, 2}); + TensorView* t1 = set(s1); TensorView* out = cat({t0, t1}, /*dim=*/-1); fusion->addInput(in); @@ -1231,23 +1615,28 @@ TEST_F(IdModelTest, PermutedDifferently) { IdModel id_model( fusion.get(), /*build_graphs=*/true, /*allow_self_mapping=*/true); - const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s0->axis(2), t0->axis(0))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s0->axis(1), t0->axis(1))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s0->axis(0), t0->axis(2))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s0->axis(3), t0->axis(3))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s1->axis(1), t1->axis(0))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s1->axis(0), t1->axis(1))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s1->axis(2), t1->axis(2))); - EXPECT_TRUE( - exact_graph.disjointValSets().strictAreMapped(s1->axis(3), t1->axis(3))); + + // Due to the `slice`s, `s0` and `s1`'s non-split dimensions (0 and 1) are + // mapped respectively. The split dimension (2) isn't. + EXPECT_TRUE(iterDomainsAreMapped(id_model, s0->axis(0), s1->axis(0))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, s0->axis(1), s1->axis(1))); + EXPECT_FALSE(iterDomainsAreMapped(id_model, s0->axis(2), s1->axis(2))); + + // Due to the `cat`, t0' and `t1`'s non-catted dimensions (0 and 1) are + // respectively mapped. The catted dimension (2) isn't. + EXPECT_TRUE(iterDomainsAreMapped(id_model, t0->axis(0), t1->axis(0))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, t0->axis(1), t1->axis(1))); + EXPECT_FALSE(iterDomainsAreMapped(id_model, t0->axis(2), t1->axis(2))); + + // Check the mapping introduced by `t0 = permute(s0, ...)`. + EXPECT_TRUE(iterDomainsAreMapped(id_model, s0->axis(1), t0->axis(0))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, s0->axis(0), t0->axis(1))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, s0->axis(2), t0->axis(2))); + + // Check the mapping introduced by `t1 = set(s1, ...)`. + EXPECT_TRUE(iterDomainsAreMapped(id_model, s1->axis(0), t1->axis(0))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, s1->axis(1), t1->axis(1))); + EXPECT_TRUE(iterDomainsAreMapped(id_model, s1->axis(2), t1->axis(2))); } } // namespace nvfuser From 4b391d8990f6745c32b5f98921bc473d6da8002d Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 25 Mar 2024 12:28:30 -0700 Subject: [PATCH 160/178] cleanup --- csrc/id_model/id_model.cpp | 118 +------------------------------------ 1 file changed, 2 insertions(+), 116 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 68251b09672..2eebcdce708 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -156,120 +156,6 @@ ValGraph& IdModel::idGraph(IdMappingMode mode) { return graph_it->second; } -namespace { - -// Returns the first pair of id's in ids detected to match each other on the -// exact ID graph. TODO: what this is really looking for is if -// there's any overlapping between the iter domains in the provided set. -// -// i.e. if we have: -// tv0 = arange(6).reshape({3, 2}) -// tv1 = tv0[3, 2].t() -// tv2 = tv0[3, 2].reshape({2, 3}) -// tv3 = tv1 + tv2 -// -// Then we can see this overlap in the tv3 expression as: -// -// tv0 = { {0, 1, 2}, -// {3, 4, 5} } -// -// tv1 = { {0, 3}, -// {1, 4}, -// {2, 5} } -// -// tv2 = { {0, 1}, -// {2, 3}, -// {4, 5} } -// -// The elements in tv1 {3, 1, 4, 2}, map respectively to the elements in tv2 -// {1, 2, 3, 4}. The reason this is so important is it means that generating -// tv3 is no longer a trivially parallelizable problem (if we include the dag -// all the way to tv0). So tv0's axes cannot be inlined across both the tv0 -// and tv1 path. This breaks some assumptions we have today in schedulers that -// will assume tv2 can be trivially inlined/parallelized. Instead we'd need to -// take into consideration the effective communication going on here, so that -// we pull multiple values of tv0 to compute tv3. -// -// Note, however, that the above example is not detectable at this -// moment as the self mapping is partial through reshape. The analysis -// below would need to be extended to consider producer and consumers -// of domains as well rather than just root, rfactor and leaf domains. -std::optional> detectMappablePair( - const std::vector& ids, - const IdModel& id_graph, - IdMappingMode mode) { - for (auto id1 : ids) { - for (auto id2 : ids) { - if (id1 == id2) { - continue; - } - if (id_graph.idGraph(mode).disjointValSets().permissiveAreMapped( - id1, id2)) { - return std::make_pair(id1, id2); - } - } - } - - return std::nullopt; -} - -// It is assumed that for any tensor represented by a list of domains, -// those domains should never be mapped with each other. It may be -// possible to lift this assumption, but it's unclear if it could -// matter in practice. -std::optional> -findFirstSelfMapping( - const std::vector& all_tvs, - const IdModel& id_model) { - for (auto tv : all_tvs) { - // For each tensor, make sure root, rfactor and leaf domains - // should not include domains that are mapped with another domain - // in the same set of domains. This may be overly conservative, - // and it maybe enough to check the root domains. - - // Root domains - auto self_mappped_root_pair = - detectMappablePair(tv->getRootDomain(), id_model, IdMappingMode::EXACT); - if (self_mappped_root_pair.has_value()) { - return std::make_tuple( - tv, - self_mappped_root_pair->first, - self_mappped_root_pair->second, - "Root"); - } - - // Rfactor domains - if (tv->hasRFactor()) { - auto self_mappped_rf_pair = detectMappablePair( - tv->getRFactorDomain(), id_model, IdMappingMode::EXACT); - if (self_mappped_rf_pair.has_value()) { - return std::make_tuple( - tv, - self_mappped_rf_pair->first, - self_mappped_rf_pair->second, - "RFactor"); - } - } - - // Leaf domains - // TODO: Exact map isn't quite right here, it should be based on the index - // map. However, it should also be impossible for index map to generate a - // case like this. - auto self_mappped_leaf_pair = detectMappablePair( - tv->domain()->leaf(), id_model, IdMappingMode::EXACT); - if (self_mappped_leaf_pair.has_value()) { - return std::make_tuple( - tv, - self_mappped_leaf_pair->first, - self_mappped_leaf_pair->second, - "Leaf"); - } - } - return std::nullopt; -} - -} // namespace - void IdModel::buildIterDomainDefinitionsAndUses() { for (const auto tv : tvs_) { VectorOfUniqueEntries root_domain_ids{ @@ -905,8 +791,8 @@ void validateLoopGraphHasNoSelfMappedLeafDomains( const std::vector& tvs, const IdModel& id_model) { for (auto tv : tvs) { - auto self_mappped_leaf_pair = - detectMappablePair(tv->domain()->leaf(), id_model, IdMappingMode::LOOP); + auto self_mappped_leaf_pair = detectSelfMapping( + tv->domain()->leaf(), id_model.idGraph(IdMappingMode::LOOP)); NVF_ERROR( !self_mappped_leaf_pair.has_value(), "Detected leaf domains are mapped in the loop graph. Tensor: ", From 270ac19d43322736a8728081ae5e424477b17f69 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 25 Mar 2024 17:17:52 -0700 Subject: [PATCH 161/178] test update for step 4 --- csrc/id_model/id_model.cpp | 52 +++-- csrc/id_model/id_model.h | 7 + tests/cpp/test_id_model.cpp | 394 +++++++++++++++++++++++++++++++----- 3 files changed, 373 insertions(+), 80 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 2eebcdce708..439d36bb42d 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -681,31 +681,6 @@ std::vector> resolvedRootBroadcasts( return resolved_bcast_domains; } -// Update a map of ValGroups to ID from an old Valgraph to a new -// ValGraph. The new graph must be a superset of the old graph. -std::unordered_map updateMap( - const std::unordered_map& stale_map, - ValGraph& new_graph) { - std::unordered_map new_map; - - for (const auto& [stale_group, mapped_id] : stale_map) { - const ValGroups& new_groups = new_graph.toGroups(*stale_group); - NVF_ERROR( - new_groups.size() == 1, - "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", - "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", - "old:", - nvfuser::toString(stale_group), - "new: ", - nvfuser::toString(new_groups)); - NVF_ERROR( - new_map.emplace(new_groups.front(), mapped_id).second, - "Expected only a single mapping but multiple entries detected for ", - nvfuser::toString(new_groups.front())); - } - return new_map; -} - } // namespace // Grab inlining relationships @@ -945,7 +920,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( loop_graph_copy = idGraph(IdMappingMode::LOOP); loop_graph_copy_promotion_map = - updateMap(loop_graph_copy_promotion_map, loop_graph_copy); + updateValGroupIdMap(loop_graph_copy_promotion_map, loop_graph_copy); // Step 4: In order to fully propagate the loop graph promotions, first // propagate them to the IEL groups, which are then used to @@ -976,7 +951,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // The loop map is built for loop_graph_copy. Update the map to the // latest loop graph final_loop_promotion_map = - updateMap(final_loop_promotion_map, idGraph(IdMappingMode::LOOP)); + updateValGroupIdMap(final_loop_promotion_map, idGraph(IdMappingMode::LOOP)); sanityCheckLoopPromotionMap(final_loop_promotion_map); @@ -1906,6 +1881,29 @@ void IdModel::sanityCheckLoopPromotionMap( } } +std::unordered_map updateValGroupIdMap( + const std::unordered_map& stale_map, + ValGraph& new_graph) { + std::unordered_map new_map; + + for (const auto& [stale_group, mapped_id] : stale_map) { + const ValGroups& new_groups = new_graph.toGroups(*stale_group); + NVF_ERROR( + new_groups.size() == 1, + "\nUpdate map assumes that new graph is equivalent to old graph plus extra mappings.\n", + "i.e. all mappings in new_graph should exist in the graph stale_map was produced on.\n", + "old:", + nvfuser::toString(stale_group), + "new: ", + nvfuser::toString(new_groups)); + NVF_ERROR( + new_map.emplace(new_groups.front(), mapped_id).second, + "Expected only a single mapping but multiple entries detected for ", + nvfuser::toString(new_groups.front())); + } + return new_map; +} + std::unordered_map IdModel::buildIndexGraph( const std::vector& exprs, const std::vector& all_tvs, diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 3ff5c71a580..391641b3adf 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -394,4 +394,11 @@ class IdModel : public PolymorphicBase { std::unordered_map loop_promotion_map_; }; +// A utility function to update a map of ValGroups to ID from an old +// Valgraph to a new ValGraph. The new graph must be a superset of the +// old graph. +std::unordered_map updateValGroupIdMap( + const std::unordered_map& stale_map, + ValGraph& new_graph); + } // namespace nvfuser diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 94f37d3fd75..3d6247982d0 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -111,6 +111,13 @@ ValType* getValByName(const std::vector& vals, StmtNameType name) { } } +IterDomain* getChildIdByName(IterDomain* id, StmtNameType name) { + auto named_val = getValByName(ir_utils::consumerValsOf(id), name); + NVF_ERROR(named_val != nullptr, "Cannot find a child ID named ", name); + NVF_ERROR(named_val->isA()); + return named_val->as(); +}; + // Helper class to test IdModel class IdModelTester : public IdModel { public: @@ -164,17 +171,60 @@ class IdModelTester : public IdModel { VERBOSE() << ss.str(); } - s3_loop_promotion_map = projectIELPromotionToLoopGraph( + const auto s3_original_loop_promotion_map = projectIELPromotionToLoopGraph( iel_graph, s2_iel_promotion_map, idGraph(IdMappingMode::LOOP), inlining_info); + + // Make a copy for validation as idGraph(IdMappingMode::LOOP) will + // be updated in the later steps + s3_loop_graph = idGraph(IdMappingMode::LOOP); + s3_loop_promotion_map = + updateValGroupIdMap(s3_original_loop_promotion_map, s3_loop_graph); + + for (const auto& loop_group : + s3_loop_graph.disjointValSets().disjointSets()) { + NVF_ERROR( + s3_loop_promotion_map.find(loop_group) != s3_loop_promotion_map.end(), + "No promotion found for: ", + nvfuser::toString(loop_group)); + } + + { + VERBOSE() << "Step 3: initial loop promotion map:" << std::endl; + for (const auto& [loop_group, id] : s3_loop_promotion_map) { + VERBOSE() << nvfuser::toString(loop_group) << " -> " << id->name() + << std::endl; + } + } + + // Note that s4_iel_promotion_map is an empty map at this + // point. It'll be populated with the Step-3 map + propagatePromotionsInIELGraph( + iel_graph, + s4_iel_promotion_map, + idGraph(IdMappingMode::LOOP), + s3_original_loop_promotion_map, + true); + + { + std::stringstream ss; + ss << "Step 4: IEL promotion map\n"; + for (const auto& [iel_group, promoted_id] : s4_iel_promotion_map) { + ss << "\t" << nvfuser::toString(iel_group) << " -> " + << promoted_id->name() << std::endl; + } + VERBOSE() << ss.str(); + } } ValGraph iel_graph; std::unordered_map s1_root_resolution_map; std::unordered_map s2_iel_promotion_map; + ValGraph s3_loop_graph; std::unordered_map s3_loop_promotion_map; + std::unordered_map s4_iel_promotion_map; }; // Test if id is resolved to an ID that is exact mapped with @@ -319,6 +369,37 @@ void checkStep3Results( } } +void checkStep4Results( + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const std::vector, IterDomain*>>& + ref_promotion_map) { + EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size()) + << "Mismatched Step-4 result map. " + << "Expected to have " << ref_promotion_map.size() + << " mappings but found " << iel_promotion_map.size(); + + // for (const auto& [iel_group, promotion_id] : iel_promotion_map) { + for (const auto& ref_promotion_pair : ref_promotion_map) { + const auto& ref_promotion_group = ref_promotion_pair.first; + const auto& ref_promotion_id = ref_promotion_pair.second; + + auto iel_promotion_it = std::find_if( + iel_promotion_map.begin(), + iel_promotion_map.end(), + [&](const auto& iel_promotion) { + return iel_promotion.first->set() == ref_promotion_group; + }); + + auto iel_promotion_id = iel_promotion_it->second; + ASSERT_EQ(ref_promotion_id, iel_promotion_id) + << "Expected promotion: " << ref_promotion_id->toString() + << ". Actual: " << iel_promotion_id->toString(); + } + + std::cerr << "checkStep4Results done\n"; +} + // Create a fusion where we're missing a valid concrete id so the compute at map // processing will fail. We need to be able to create the concrete ID not just // look for one. It is not yet possible to lower this fusion as the @@ -687,9 +768,10 @@ TEST_F(IdModelTest, LoopPromotion1) { {std::unordered_set{t2->axis(1), t3->axis(1)}, t3->axis(1)}}; checkStep3Results( - tester.idGraph(IdMappingMode::LOOP), - tester.s3_loop_promotion_map, - s3_reference_map); + tester.s3_loop_graph, tester.s3_loop_promotion_map, s3_reference_map); + + ASSERT_TRUE(tester.s4_iel_promotion_map.empty()) + << "No step-4 IEL promotion expected"; } } @@ -744,9 +826,10 @@ TEST_F(IdModelTest, LoopPromotion2) { {std::unordered_set{t3->axis(0), t4->axis(0)}, t4->axis(0)}}; checkStep3Results( - tester.idGraph(IdMappingMode::LOOP), - tester.s3_loop_promotion_map, - s3_reference_map); + tester.s3_loop_graph, tester.s3_loop_promotion_map, s3_reference_map); + + ASSERT_TRUE(tester.s4_iel_promotion_map.empty()) + << "No step-4 IEL promotion expected"; } // Multiple inlined and non-inlined broadcast domains @@ -822,9 +905,10 @@ TEST_F(IdModelTest, LoopPromotion3) { tv3->axis(0)}}; checkStep3Results( - tester.idGraph(IdMappingMode::LOOP), - tester.s3_loop_promotion_map, - s3_reference_map); + tester.s3_loop_graph, tester.s3_loop_promotion_map, s3_reference_map); + + ASSERT_TRUE(tester.s4_iel_promotion_map.empty()) + << "No step-4 IEL promotion expected"; } // Test root resolution with a fusion with outer split. @@ -922,9 +1006,26 @@ TEST_F(IdModelTest, LoopPromotion4) { {std::unordered_set{tv2->axis(1)}, tv4->axis(1)}}; checkStep3Results( - tester.idGraph(IdMappingMode::LOOP), - tester.s3_loop_promotion_map, - s3_reference_map); + tester.s3_loop_graph, tester.s3_loop_promotion_map, s3_reference_map); + + auto id10 = getParentId(tv4->axis(0), 1); + ASSERT_EQ(id10->name(), 10); + auto id32 = + getValByName(ir_utils::consumerValsOf(id10), 32)->as(); + auto id33 = + getValByName(ir_utils::consumerValsOf(id10), 33)->as(); + + std::vector, IterDomain*>> + s4_reference_map = { + // 19 -> 10 + {std::unordered_set{getParentId(tv2->axis(0), 1)}, id10}, + // 20 -> 32 + {std::unordered_set{tv2->axis(0)}, id32}, + // 21 -> 33 + {std::unordered_set{tv2->axis(1)}, id33}}; + + checkStep4Results( + tester.iel_graph, tester.s4_iel_promotion_map, s4_reference_map); } // Test root resolution with the same fusion as Indexing1 @@ -1054,9 +1155,50 @@ TEST_F(IdModelTest, LoopPromotion5) { }; checkStep3Results( - tester.idGraph(IdMappingMode::LOOP), - tester.s3_loop_promotion_map, - s3_reference_map); + tester.s3_loop_graph, tester.s3_loop_promotion_map, s3_reference_map); + + auto id19 = getParentId(tv4->axis(0), 3); + ASSERT_EQ(id19->name(), 19); + auto id20 = getParentId(tv4->axis(0), 2); + ASSERT_EQ(id20->name(), 20); + auto id40 = getChildIdByName(id20, 40); + auto id41 = getChildIdByName(id20, 41); + auto id42 = getChildIdByName(id20, 42); + auto id43 = getChildIdByName(id20, 43); + auto id46 = getChildIdByName(id40, 46); + auto id47 = getChildIdByName(id40, 47); + auto id48 = getChildIdByName(id42, 48); + auto id49 = getChildIdByName(id42, 49); + + std::vector, IterDomain*>> + s4_reference_map = { + // 32 -> 19 + {std::unordered_set{getParentId(tv2->axis(0), 3)}, id19}, + // 33 -> 20 + {std::unordered_set{getParentId(tv2->axis(0), 2)}, id20}, + // 34 -> 40 + {std::unordered_set{getParentId(tv2->axis(0), 1)}, id40}, + // 35 -> 41 + {std::unordered_set{tv2->axis(2)}, id41}, + // 36 -> 46 + {std::unordered_set{tv2->axis(0)}, id46}, + // 37 -> 47 + {std::unordered_set{tv2->axis(1)}, id47}, + // 26 -> 19 + {std::unordered_set{getParentId(tv3->axis(0), 3)}, id19}, + // 27 -> 20 + {std::unordered_set{getParentId(tv3->axis(0), 2)}, id20}, + // 28 -> 42 + {std::unordered_set{getParentId(tv3->axis(0), 1)}, id42}, + // 29 -> 43 + {std::unordered_set{tv3->axis(2)}, id43}, + // 30 -> 48 + {std::unordered_set{tv3->axis(0)}, id48}, + // 31 -> 49 + {std::unordered_set{tv3->axis(1)}, id49}}; + + checkStep4Results( + tester.iel_graph, tester.s4_iel_promotion_map, s4_reference_map); } // Test root resolution with the same fusion as Indexing19 @@ -1150,30 +1292,16 @@ TEST_F(IdModelTest, LoopPromotion6) { tester.idGraph(IdMappingMode::EXACT), tester.s2_iel_promotion_map); - auto id79 = - getValByName(ir_utils::consumerValsOf(tv9->getRootDomain().at(2)), 79) - ->as(); - ASSERT_NE(id79, nullptr) << "IterDomain 79 not found"; - auto id80 = - getValByName(ir_utils::consumerValsOf(tv9->getRootDomain().at(2)), 80) - ->as(); - ASSERT_NE(id80, nullptr) << "IterDomain 80 not found"; - auto id81 = getChildId(id79, 1); - ASSERT_EQ(id81->name(), 81); - auto id82 = getChildId(id79, 1, 1); - ASSERT_EQ(id82->name(), 82); - auto id83 = getChildId(id80, 1); - ASSERT_EQ(id83->name(), 83); - auto id84 = getChildId(id80, 1, 1); - ASSERT_EQ(id84->name(), 84); - auto id85 = getChildId(id81, 1); - ASSERT_EQ(id85->name(), 85); - auto id86 = getChildId(id81, 1, 1); - ASSERT_EQ(id86->name(), 86); - auto id87 = getChildId(id83, 1); - ASSERT_EQ(id87->name(), 87); - auto id88 = getChildId(id83, 1, 1); - ASSERT_EQ(id88->name(), 88); + auto id79 = getChildIdByName(tv9->getRootDomain().at(2), 79); + auto id80 = getChildIdByName(tv9->getRootDomain().at(2), 80); + auto id81 = getChildIdByName(id79, 81); + auto id82 = getChildIdByName(id79, 82); + auto id83 = getChildIdByName(id80, 83); + auto id84 = getChildIdByName(id80, 84); + auto id85 = getChildIdByName(id81, 85); + auto id86 = getChildIdByName(id81, 86); + auto id87 = getChildIdByName(id83, 87); + auto id88 = getChildIdByName(id83, 88); // Check Step 3 results. See the design doc for the expected results std::vector, IterDomain*>> @@ -1250,9 +1378,111 @@ TEST_F(IdModelTest, LoopPromotion6) { }; checkStep3Results( - tester.idGraph(IdMappingMode::LOOP), - tester.s3_loop_promotion_map, - s3_reference_map); + tester.s3_loop_graph, tester.s3_loop_promotion_map, s3_reference_map); + + // For tv1 + auto id94 = getChildIdByName(id80, 94); + auto id95 = getChildIdByName(id80, 95); + auto id109 = getChildIdByName(id94, 109); + auto id110 = getChildIdByName(id94, 110); + + // For tv2 + auto id98 = getChildIdByName(id80, 98); + auto id99 = getChildIdByName(id80, 99); + auto id113 = getChildIdByName(id98, 113); + auto id114 = getChildIdByName(id98, 114); + + // For tv6 + auto id102 = getChildIdByName(id80, 102); + auto id103 = getChildIdByName(id80, 103); + auto id117 = getChildIdByName(id102, 117); + auto id118 = getChildIdByName(id102, 118); + + // For tv4 + auto id111 = getChildIdByName(id80, 111); + auto id112 = getChildIdByName(id80, 112); + auto id129 = getChildIdByName(id111, 129); + auto id130 = getChildIdByName(id111, 130); + + // For tv5 + auto id127 = getChildIdByName(id80, 127); + auto id128 = getChildIdByName(id80, 128); + auto id135 = getChildIdByName(id127, 135); + auto id136 = getChildIdByName(id127, 136); + + // For tv8 + auto id107 = getChildIdByName(id80, 107); + auto id108 = getChildIdByName(id80, 108); + auto id125 = getChildIdByName(id107, 125); + auto id126 = getChildIdByName(id107, 126); + + // For tv9 + auto id121 = getChildIdByName(id80, 121); + auto id122 = getChildIdByName(id80, 122); + auto id131 = getChildIdByName(id121, 131); + auto id132 = getChildIdByName(id121, 132); + + std::vector, IterDomain*>> + s4_reference_map = { + // tv1: 71 -> 94 + {std::unordered_set{getParentId(tv1->axis(0), 1)}, id94}, + // tv1: 72 -> 95 + {std::unordered_set{tv1->axis(2)}, id95}, + // tv1: 73 -> 109 + {std::unordered_set{tv1->axis(0)}, id109}, + // tv1: 74 -> 110 + {std::unordered_set{tv1->axis(1)}, id110}, + // tv2: 47 -> 98 + {std::unordered_set{getParentId(tv2->axis(0), 1)}, id98}, + // tv2: 48 -> 99 + {std::unordered_set{tv2->axis(2)}, id99}, + // tv2: 49 -> 113 + {std::unordered_set{tv2->axis(0)}, id113}, + // tv2: 50 -> 114 + {std::unordered_set{tv2->axis(1)}, id114}, + // tv4: 42 -> 111 + {std::unordered_set{getParentId(tv4->axis(0), 1)}, id111}, + // tv4: 43 -> 112 + {std::unordered_set{tv4->axis(2)}, id112}, + // tv4: 44 -> 129 + {std::unordered_set{tv4->axis(0)}, id129}, + // tv4: 45 -> 130 + {std::unordered_set{tv4->axis(1)}, id130}, + // tv5: 37 -> 127 + {std::unordered_set{getParentId(tv5->axis(0), 1)}, id127}, + // tv5: 38 -> 128 + {std::unordered_set{tv5->axis(2)}, id128}, + // tv5: 39 -> 135 + {std::unordered_set{tv5->axis(0)}, id135}, + // tv5: 40 -> 136 + {std::unordered_set{tv5->axis(1)}, id136}, + // tv6: 62 -> 102 + {std::unordered_set{getParentId(tv6->axis(0), 1)}, id102}, + // tv6: 63 -> 103 + {std::unordered_set{tv6->axis(2)}, id103}, + // tv6: 64 -> 117 + {std::unordered_set{tv6->axis(0)}, id117}, + // tv6: 65 -> 118 + {std::unordered_set{tv6->axis(1)}, id118}, + // tv8: 57 -> 107 + {std::unordered_set{getParentId(tv8->axis(0), 1)}, id107}, + // tv8: 58 -> 108 + {std::unordered_set{tv8->axis(2)}, id108}, + // tv8: 59 -> 125 + {std::unordered_set{tv8->axis(0)}, id125}, + // tv8: 60 -> 126 + {std::unordered_set{tv8->axis(1)}, id126}, + // tv9: 31 -> 121 + {std::unordered_set{getParentId(tv9->axis(0), 1)}, id121}, + // tv9: 32 -> 122 + {std::unordered_set{tv9->axis(2)}, id122}, + // tv9: 33 -> 131 + {std::unordered_set{tv9->axis(0)}, id131}, + // tv9: 34 -> 132 + {std::unordered_set{tv9->axis(1)}, id132}}; + + checkStep4Results( + tester.iel_graph, tester.s4_iel_promotion_map, s4_reference_map); } // Same fusion as NvFuserTest.FusionInlineBroadcastIndexing0 @@ -1317,6 +1547,8 @@ TEST_F(IdModelTest, LoopPromotion7) { tester.idGraph(IdMappingMode::EXACT), tester.s2_iel_promotion_map); + auto id8 = getChildIdByName(tv4->getRootDomain().at(0), 8); + // Check Step 3 results. See the design doc for the expected results std::vector, IterDomain*>> s3_reference_map = { @@ -1328,8 +1560,8 @@ TEST_F(IdModelTest, LoopPromotion7) { getChildId(tv3->getRootDomain().at(0), 1), tv4->getRootDomain().at(0), tv4->getRootDomain().at(1), - getChildId(tv4->getRootDomain().at(0), 1)}, - getChildId(tv4->getRootDomain().at(0), 1)}, + id8}, + id8}, // 17, 15, 9 -> 9 {std::unordered_set{tv2->axis(0), tv3->axis(0), tv4->axis(0)}, tv4->axis(0)}, @@ -1337,9 +1569,36 @@ TEST_F(IdModelTest, LoopPromotion7) { {std::unordered_set{tv3->axis(1)}, tv4->axis(1)}}; checkStep3Results( - tester.idGraph(IdMappingMode::LOOP), - tester.s3_loop_promotion_map, - s3_reference_map); + tester.s3_loop_graph, tester.s3_loop_promotion_map, s3_reference_map); + + // For tv2 + auto id26 = getChildIdByName(id8, 26); + auto id27 = getChildIdByName(id8, 27); + auto id34 = getChildIdByName(id27, 34); + auto id35 = getChildIdByName(id27, 35); + + // For tv3 + auto id30 = getChildIdByName(id8, 30); + auto id31 = getChildIdByName(id8, 31); + + std::vector, IterDomain*>> + s4_reference_map = { + // tv2: 17 -> 26 + {std::unordered_set{tv2->axis(0)}, id26}, + // tv2: 18 -> 27 + {std::unordered_set{getParentId(tv2->axis(1), 1)}, id27}, + // tv2: 21 -> 34 + {std::unordered_set{tv2->axis(1)}, id34}, + // tv2: 22 -> 35 + {std::unordered_set{tv2->axis(2)}, id35}, + // tv3: 15 -> 26 + {std::unordered_set{tv3->axis(0)}, id30}, + // tv3: 16 -> 27 + {std::unordered_set{tv3->axis(1)}, id31}, + }; + + checkStep4Results( + tester.iel_graph, tester.s4_iel_promotion_map, s4_reference_map); } // Same fusion as NvFuserTest.FusionIndexing20 @@ -1433,6 +1692,11 @@ TEST_F(IdModelTest, LoopPromotion8) { tester.idGraph(IdMappingMode::EXACT), tester.s2_iel_promotion_map); + auto id29 = getParentId(tv7->axis(0), 1); + ASSERT_EQ(id29->name(), 29) << "Unexpected ID: " << id29->toString(); + auto id42 = getParentId(tv7->axis(1), 1); + ASSERT_EQ(id42->name(), 42); + // Check Step 3 results. See the design doc for the expected results std::vector, IterDomain*>> s3_reference_map = { @@ -1469,8 +1733,8 @@ TEST_F(IdModelTest, LoopPromotion8) { getChildId( getChildId(tv7->getRootDomain().at(0), 1), 1, 1), // 31 tv7->getRootDomain().at(2), // 16 - getChildId(tv7->getRootDomain().at(2), 1)}, // 42 - getChildId(tv7->getRootDomain().at(2), 1)}, + id42}, // 42 + id42}, // 22 -> 19 {std::unordered_set{tv2->axis(1)}, tv4->axis(1)}, // 40, 43 -> 43 @@ -1480,9 +1744,33 @@ TEST_F(IdModelTest, LoopPromotion8) { }; checkStep3Results( - tester.idGraph(IdMappingMode::LOOP), - tester.s3_loop_promotion_map, - s3_reference_map); + tester.s3_loop_graph, tester.s3_loop_promotion_map, s3_reference_map); + + auto id49 = getChildIdByName(id29, 49); + auto id50 = getChildIdByName(id29, 50); + auto id51 = getChildIdByName(id29, 51); + auto id52 = getChildIdByName(id29, 52); + auto id63 = getChildIdByName(id42, 63); + auto id64 = getChildIdByName(id42, 64); + + std::vector, IterDomain*>> + s4_reference_map = { + // tv1: 35 -> 49 + {std::unordered_set{tv1->axis(0)}, id49}, + // tv1: 36 -> 50 + {std::unordered_set{tv1->axis(1)}, id50}, + // tv2: 21 -> 51 + {std::unordered_set{tv2->axis(0)}, id51}, + // tv2: 22 -> 52 + {std::unordered_set{tv2->axis(1)}, id52}, + // tv5: 40 -> 63 + {std::unordered_set{tv5->axis(1)}, id63}, + // tv5: 41 -> 64 + {std::unordered_set{tv5->axis(2)}, id64}, + }; + + checkStep4Results( + tester.iel_graph, tester.s4_iel_promotion_map, s4_reference_map); } // A repro that produces an invalid loop graph due to the compliment From e17fa3557a57cc18f1cf54aea6671e0e126fe87f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 3 Apr 2024 18:30:13 -0700 Subject: [PATCH 162/178] cleanup --- csrc/id_model/id_model.h | 2 +- tests/cpp/test_id_model.cpp | 13 +------------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 391641b3adf..1c8cd8812fa 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -208,7 +208,7 @@ class IdModel : public PolymorphicBase { // create a mapping from the outputs of the IEL expr to the outputs // of the equivalent expr. When require_loop_mapped_promotion is // true, the equivalent expr needs to be already loop mapped. If no - // such expr is found, the IEL expr is replayed iwth the promoted + // such expr is found, the IEL expr is replayed with the promoted // inputs. require_loop_mapped_promotion is true when this function // is used for step 3. // diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index b6817993dd5..ee0304e5c2e 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -183,14 +183,6 @@ class IdModelTester : public IdModel { s3_loop_promotion_map = updateValGroupIdMap(s3_original_loop_promotion_map, s3_loop_graph); - for (const auto& loop_group : - s3_loop_graph.disjointValSets().disjointSets()) { - NVF_ERROR( - s3_loop_promotion_map.find(loop_group) != s3_loop_promotion_map.end(), - "No promotion found for: ", - nvfuser::toString(loop_group)); - } - { VERBOSE() << "Step 3: initial loop promotion map:" << std::endl; for (const auto& [loop_group, id] : s3_loop_promotion_map) { @@ -379,7 +371,6 @@ void checkStep4Results( << "Expected to have " << ref_promotion_map.size() << " mappings but found " << iel_promotion_map.size(); - // for (const auto& [iel_group, promotion_id] : iel_promotion_map) { for (const auto& ref_promotion_pair : ref_promotion_map) { const auto& ref_promotion_group = ref_promotion_pair.first; const auto& ref_promotion_id = ref_promotion_pair.second; @@ -392,12 +383,10 @@ void checkStep4Results( }); auto iel_promotion_id = iel_promotion_it->second; - ASSERT_EQ(ref_promotion_id, iel_promotion_id) + EXPECT_EQ(ref_promotion_id, iel_promotion_id) << "Expected promotion: " << ref_promotion_id->toString() << ". Actual: " << iel_promotion_id->toString(); } - - std::cerr << "checkStep4Results done\n"; } // Create a fusion where we're missing a valid concrete id so the compute at map From e21e4e55a9fecbb8f2574a6ac401894beffdb537 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 8 May 2024 16:15:22 -0700 Subject: [PATCH 163/178] clang-tidy --- csrc/id_model/id_model.cpp | 6 +++--- tests/cpp/test_id_model.cpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 72d6e306b49..4a0a69f903b 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -949,8 +949,8 @@ std::unordered_map IdModel::buildLoopPromotionMap( // The loop map is built for loop_graph_copy. Update the map to the // latest loop graph - final_loop_promotion_map = - updateValGroupIdMap(final_loop_promotion_map, idGraph(IdMappingMode::LOOP)); + final_loop_promotion_map = updateValGroupIdMap( + final_loop_promotion_map, idGraph(IdMappingMode::LOOP)); sanityCheckLoopPromotionMap(final_loop_promotion_map); @@ -1321,7 +1321,7 @@ Expr* findMatchingExpr( iel_expr->front(), maybe_promoted_input_use_group->front())) { continue; } - + // This is just an extra sanity check. Make sure all exprs in // the use group are mapped NVF_ERROR( diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 01bb1635977..e86c91fe375 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -197,7 +197,7 @@ class IdModelTester : public IdModel { iel_graph, s4_iel_promotion_map, idGraph(IdMappingMode::LOOP), - s3_original_loop_promotion_map); + s3_original_loop_promotion_map); { std::stringstream ss; From 801a7b4de39961ea691629d9ff6c7f1ffffc819f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 8 May 2024 16:44:24 -0700 Subject: [PATCH 164/178] Step 5 of loop promotion analysis --- csrc/device_lower/lower2device.cpp | 2 +- csrc/id_model/id_model.cpp | 87 ++++++++++++-- csrc/id_model/id_model.h | 26 +++- tests/cpp/test_id_model.cpp | 183 +++++++++++++++++++++++++---- 4 files changed, 266 insertions(+), 32 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 76181bd72b4..f988ce65c4f 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -391,7 +391,7 @@ void GpuLower::analysis(Fusion* fusion) { // functionality should be affected. New IterDomains may be created, // so it is expected that generated code may use diffrent variable // names - if (isOptionEnabled(EnableOption::IdModel)) { + if (true || isOptionEnabled(EnableOption::IdModel)) { IdModel id_model(fusion_); } diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index c35784e7826..281ff0f8536 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -74,7 +74,8 @@ IdModel::IdModel( const std::vector& exprs, const std::vector& additional_tvs, bool build_graphs, - bool allow_self_mapping) { + bool allow_self_mapping) + : allow_self_mapping_(allow_self_mapping) { std::copy_if( exprs.begin(), exprs.end(), @@ -570,8 +571,14 @@ void IdModel::buildLoopGraph() { initializeLoopGraph(inlining_info); + validateLoopGraphHasNoSelfMappedLeafDomains(); + loop_promotion_map_ = buildLoopPromotionMap(inlining_info); + // New domains are added. Make sure there's still no self mapping in + // the leaf domains + validateLoopGraphHasNoSelfMappedLeafDomains(); + idGraph(IdMappingMode::LOOP).validateConsistency(); } @@ -645,10 +652,32 @@ std::unordered_map IdModel::buildLoopPromotionMap( idGraph(IdMappingMode::LOOP), loop_promotion_map); - // This is not a right map to return but just a placeholder since - // the loop promotion map is not yet completely merged. It will be - // replaced by a proper map. - return final_iel_promotion_map; + // Step 5: Find the final promotion of each loop group based on the + // final IEL promotion map + auto final_loop_promotion_map = projectIELPromotionToLoopGraph( + iel_graph, + final_iel_promotion_map, + idGraph(IdMappingMode::LOOP), + inlining_info); + + // The promotion map produced in Step 5 only includes those are + // further propagated at Step 4, so the correct mappings produced at + // Step 3 may not be included in the Step-5 results. Any Step-3 mappings + // that are not found in the Step-5 results are already valid + // results, so merge them into the Step-5 results. + + // Update the Step-3 map to the latest LOOP graph + loop_promotion_map = + updateValGroupIdMap(loop_promotion_map, idGraph(IdMappingMode::LOOP)); + + // Insert the updated Step-3 results into the Step-5 resutls. Note + // that this insertion does not overwrite the existing mappings. + final_iel_promotion_map.insert( + loop_promotion_map.begin(), loop_promotion_map.end()); + + sanityCheckLoopPromotionMap(final_loop_promotion_map); + + return final_loop_promotion_map; } std::unordered_map IdModel::buildInlineRootResolutionMap( @@ -1226,10 +1255,10 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { // Initialize output ids in map. The replay expr will be // registered as a definition by registerExpr for (auto out_id : ir_utils::filterByType(replay->outputs())) { - idGraph(mode).initializeVal(out_id, {}, {}); + graph.initializeVal(out_id, {}, {}); } - idGraph(mode).registerExpr(replay); + graph.registerExpr(replay); // Propagate through all the uses of the iter domain groups of the inputs // with the new expression. @@ -1455,6 +1484,50 @@ VectorOfUniqueEntries IdModel::computeTerminalLoopIds( return terminal_loop_ids; } +void IdModel::sanityCheckLoopPromotionMap( + const std::unordered_map& loop_promotion_map) const { + const auto& loop_graph = idGraph(IdMappingMode::LOOP); + for (const ValGroup& loop_group : + loop_graph.disjointValSets().disjointSets()) { + // Non-leaf loop groups are not guaranteed to have valid + // promotions. See for example FusionRepro1713, where root domains + // are all grouped together but there's no valid promotion. + if (loop_graph.hasUses(loop_group)) { + continue; + } + // Make sure the loop group is promoted to a domain that is mapped + // in the LOOP graph + auto promotion_it = loop_promotion_map.find(loop_group); + NVF_ERROR( + promotion_it != loop_promotion_map.end(), + "Loop promotion not found for ", + nvfuser::toString(loop_group)); + IterDomain* promotion = promotion_it->second; + // Make sure the promotion domain is also loop-mapped + NVF_ERROR( + loop_group->has(promotion), + "Loop promotion not loop-mapped. Loop group: ", + nvfuser::toString(loop_group), + ". Promotion domain: ", + promotion->name()); + } +} + +void IdModel::validateLoopGraphHasNoSelfMappedLeafDomains() const { + for (auto tv : tvs_) { + auto self_mappped_leaf_pair = + detectSelfMapping(tv->domain()->leaf(), idGraph(IdMappingMode::LOOP)); + NVF_ERROR( + !self_mappped_leaf_pair.has_value(), + "Detected leaf domains are mapped in the loop graph. Tensor: ", + tv->toString(), + ". Mapped leaf domains: ", + self_mappped_leaf_pair->first->toString(), + " and ", + self_mappped_leaf_pair->second->toString()); + } +} + std::unordered_map updateValGroupIdMap( const std::unordered_map& stale_map, ValGraph& new_graph) { diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 4c7da6cfaa7..8255e6a8000 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -64,6 +64,19 @@ StatefulInliningInfo buildStatefulInliningInfo( // considered the exact same size operating on matching dimensions from the root // domain mapping. // +// LOOP mode is important to resolve inlined broadcassts. If we have something +// like: consumer[i0o, threadIdx.x{i0i}] = producer[i0o, +// threadIdx.y{i0i}](computeAt = 1) which can easily happen when using shared +// memory. Loop is actually defined for all iteration domains, and resembles +// groups of iter domains that are effectively inlined with each other. +// Therefore iter domain's that are a common dependency of inlined leaf domains +// may be loop mapped together. +// +// Loop promotion is a mechanism by which to capture inlined resolved +// broadcasts. If a consumer resolves a broadcast of a producer, and the +// producer's broadcast is inlined (in total or partially). Then the producer's +// iter domain will be "promoted" to the size of the consumers iter domain. +// // IdMappingMode::EXACT // Don't map any broadcast axes to non-broadcast axes // Do not forward through any broadcast IDs @@ -80,8 +93,7 @@ StatefulInliningInfo buildStatefulInliningInfo( // Forward through split one axes, i.e. id{ceilDiv(i0, 1)}, id{i0} are mapped // IdMappingMode::LOOP // Subgraph of the permissive graph. Maps only CA and their -// dependent domains -// +// dependent domains. class IdModel : public PolymorphicBase { public: // Sometimes fusion inputs or outputs are disconnected from expressions, in @@ -260,6 +272,16 @@ class IdModel : public PolymorphicBase { // Errors if self mapping occurs void assertNoSelfMapping(); + // Basic consistency check of the given loop promotion map + void sanityCheckLoopPromotionMap( + const std::unordered_map& loop_promotion_map) + const; + + // Loop graph represents the loop structure of the given fusion, so + // there must not be any mapping between the leaf domains of each + // tensor. + void validateLoopGraphHasNoSelfMappedLeafDomains() const; + // Replay Expr but with the inputs provided. ValGraphs will be updated // for all maps that have entries, adding the output iter domains of the // replayed expression and adding potential mappings through the expression. diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index f8daa1a20ed..a0a004d5c88 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -134,6 +134,8 @@ class IdModelTester : public IdModel { initializeLoopGraph(inlining_info); + validateLoopGraphHasNoSelfMappedLeafDomains(); + iel_graph = buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); @@ -163,6 +165,27 @@ class IdModelTester : public IdModel { s4_iel_promotion_map, idGraph(IdMappingMode::LOOP), s3_original_loop_promotion_map); + + // Step 5: Find the final promotion of each loop group based on the + // final IEL promotion map + s5_loop_promotion_map = projectIELPromotionToLoopGraph( + iel_graph, + s4_iel_promotion_map, + idGraph(IdMappingMode::LOOP), + inlining_info); + + auto updated_s3_loop_promotion_map = updateValGroupIdMap( + s3_loop_promotion_map, idGraph(IdMappingMode::LOOP)); + s5_loop_promotion_map.insert( + updated_s3_loop_promotion_map.begin(), + updated_s3_loop_promotion_map.end()); + + sanityCheckLoopPromotionMap(s5_loop_promotion_map); + validateLoopGraphHasNoSelfMappedLeafDomains(); + + s5_loop_graph = idGraph(IdMappingMode::LOOP); + s5_loop_promotion_map = + updateValGroupIdMap(s5_loop_promotion_map, s5_loop_graph); } void print(std::ostream& os) const { @@ -182,6 +205,10 @@ class IdModelTester : public IdModel { for (const auto& [g, id] : s4_iel_promotion_map) { os << nvfuser::toString(g) << " -> " << id->toString() << std::endl; } + os << "Step 5 results:\n"; + for (const auto& [g, id] : s5_loop_promotion_map) { + os << nvfuser::toString(g) << " -> " << id->toString() << std::endl; + } } ValGraph iel_graph; @@ -190,6 +217,8 @@ class IdModelTester : public IdModel { ValGraph s3_loop_graph; std::unordered_map s3_loop_promotion_map; std::unordered_map s4_iel_promotion_map; + ValGraph s5_loop_graph; + std::unordered_map s5_loop_promotion_map; }; // Test if id is resolved to an ID that is exact mapped with @@ -374,6 +403,42 @@ void checkStep4Results( } } +void checkStep5Results( + const IdModelTester& tester, + const std::unordered_map>& + ref_promotion_map) { + const auto& loop_graph = tester.s5_loop_graph; + const auto& loop_promotion_map = tester.s5_loop_promotion_map; + + // Record if each entry of ref_promotion_map is found + std::vector ref_promotion_map_found(ref_promotion_map.size(), false); + + for (const auto& [tv, ref_promotion_domains] : ref_promotion_map) { + ASSERT_EQ(ref_promotion_domains.size(), tv->nDims()) + << "Invalid number of domains: " + << toDelimitedString(ref_promotion_domains); + for (const auto i : c10::irange(tv->nDims())) { + IterDomain* loop_id = tv->axis(i); + const ValGroup& loop_group = loop_graph.toGroup(loop_id); + + auto promotion_it = loop_promotion_map.find(loop_group); + ASSERT_NE(promotion_it, loop_promotion_map.end()) + << "No promotion found for: " << nvfuser::toString(loop_group); + + IterDomain* promotion_id = promotion_it->second; + + ASSERT_EQ(promotion_id, ref_promotion_domains.at(i)) + << "Expected promotion: " << ref_promotion_domains.at(i)->toString() + << ". Actual: " << promotion_id->toString(); + + ASSERT_EQ(loop_graph.toGroup(promotion_id), loop_group) + << "Loop group promoted to a non-mapped domain. Loop group: " + << nvfuser::toString(loop_group) + << ". Promotion: " << promotion_id->toString(); + } + } +} + // Create a fusion where we're missing a valid concrete id so the compute at map // processing will fail. We need to be able to create the concrete ID not just // look for one. It is not yet possible to lower this fusion as the @@ -932,7 +997,10 @@ TEST_F(IdModelTest, LoopPromotion4) { checkStep2Results(&fusion, tester); auto id10 = getChildIdByName(tv4->getRootDomain()[0], 10); + auto id11 = getChildIdByName(id10, 11); + auto id12 = getChildIdByName(id10, 12); auto id13 = getChildIdByName(tv3->getRootDomain()[0], 13); + auto id15 = getChildIdByName(id13, 15); auto id19 = getChildIdByName(tv2->getRootDomain()[0], 19); auto id25 = getChildIdByName(id10, 25); auto id26 = getChildIdByName(id10, 26); @@ -956,7 +1024,7 @@ TEST_F(IdModelTest, LoopPromotion4) { // 11, 14, 20, 25 -> 11 {std::unordered_set{ tv2->axis(0), tv3->axis(0), tv4->axis(0), id25}, - tv4->axis(0)}, + id11}, // 21, 26 -> 26 {std::unordered_set{tv2->axis(1), id26}, id26}}; @@ -975,6 +1043,15 @@ TEST_F(IdModelTest, LoopPromotion4) { {std::unordered_set{tv2->axis(1)}, id35}}; checkStep4Results(tester, s4_reference_map); + + // Check Step 5 results. See the design doc for the expected results + std::unordered_map> s5_reference_map = { + {tv2, {id11, id35}}, + {tv3, {id11, id15}}, + {tv4, {id11, id12}}, + }; + + checkStep5Results(tester, s5_reference_map); } // Test root resolution with the same fusion as Indexing1 @@ -1050,6 +1127,10 @@ TEST_F(IdModelTest, LoopPromotion5) { ASSERT_EQ(id19->name(), 19); auto id20 = getParentId(tv4->axis(0), 2); ASSERT_EQ(id20->name(), 20); + auto id21 = getChildIdByName(id20, 21); + auto id22 = getChildIdByName(id20, 22); + auto id23 = getChildIdByName(id21, 23); + auto id24 = getChildIdByName(id21, 24); auto id38 = getChildIdByName(id20, 38); auto id39 = getChildIdByName(id20, 39); auto id40 = getChildIdByName(id38, 40); @@ -1096,7 +1177,7 @@ TEST_F(IdModelTest, LoopPromotion5) { {std::unordered_set{ getParentId(tv2->axis(0), 1), getParentId(tv3->axis(0), 1), - getParentId(tv4->axis(0), 1), + id21, id38}, getParentId(tv4->axis(0), 1)}, // 29, 39 -> 29 @@ -1104,9 +1185,8 @@ TEST_F(IdModelTest, LoopPromotion5) { // 31, 41 -> 41 {std::unordered_set{tv3->axis(1), id41}, id41}, // 23, 30, 36, 40 -> 23 - {std::unordered_set{ - tv2->axis(0), tv3->axis(0), tv4->axis(0), id40}, - tv4->axis(0)}, + {std::unordered_set{tv2->axis(0), tv3->axis(0), id23, id40}, + id23}, }; checkStep3Results(tester, s3_reference_map); @@ -1149,6 +1229,15 @@ TEST_F(IdModelTest, LoopPromotion5) { {std::unordered_set{tv3->axis(1)}, id53}}; checkStep4Results(tester, s4_reference_map); + + // Check Step 5 results. See the design doc for the expected results + std::unordered_map> s5_reference_map = { + {tv2, {id23, id51, id45}}, + {tv3, {id23, id53, id47}}, + {tv4, {id23, id24, id22}}, + }; + + checkStep5Results(tester, s5_reference_map); } // Test root resolution with the same fusion as Indexing19 @@ -1448,6 +1537,19 @@ TEST_F(IdModelTest, LoopPromotion6) { {std::unordered_set{tv9->axis(1)}, id140}}; checkStep4Results(tester, s4_reference_map); + + // Check Step 5 results. See the design doc for the expected results + std::unordered_map> s5_reference_map = { + {tv1, {id143, id118, id103}}, + {tv2, {id143, id122, id107}}, + {tv4, {id143, id138, id120}}, + {tv5, {id143, id144, id136}}, + {tv6, {id143, id126, id111}}, + {tv8, {id143, id134, id116}}, + {tv9, {id143, id140, id130}}, + }; + + checkStep5Results(tester, s5_reference_map); } // Same fusion as NvFuserTest.FusionInlineBroadcastIndexing0 @@ -1508,27 +1610,29 @@ TEST_F(IdModelTest, LoopPromotion7) { checkStep2Results(&fusion, tester); auto id8 = getChildIdByName(tv4->getRootDomain().at(0), 8); + auto id9 = getChildIdByName(id8, 9); + auto id10 = getChildIdByName(id8, 10); auto id23 = getChildIdByName(id8, 23); auto id24 = getChildIdByName(id8, 24); // Check Step 3 results. See the design doc for the expected results std::vector, IterDomain*>> - s3_reference_map = {// 3, 4, 5, 14, 6, 7, 8, -> 8 - {std::unordered_set{ - tv2->getRootDomain().at(0), - tv3->getRootDomain().at(0), - tv3->getRootDomain().at(1), - getChildId(tv3->getRootDomain().at(0), 1), - tv4->getRootDomain().at(0), - tv4->getRootDomain().at(1), - id8}, - id8}, - // 9, 15, 17, 23 -> 9 - {std::unordered_set{ - tv2->axis(0), tv3->axis(0), tv4->axis(0), id23}, - tv4->axis(0)}, - // 16, 24 -> 24 - {std::unordered_set{tv3->axis(1), id24}, id24}}; + s3_reference_map = { + // 3, 4, 5, 14, 6, 7, 8, -> 8 + {std::unordered_set{ + tv2->getRootDomain().at(0), + tv3->getRootDomain().at(0), + tv3->getRootDomain().at(1), + getChildId(tv3->getRootDomain().at(0), 1), + tv4->getRootDomain().at(0), + tv4->getRootDomain().at(1), + id8}, + id8}, + // 9, 15, 17, 23 -> 9 + {std::unordered_set{tv2->axis(0), tv3->axis(0), id9, id23}, + id9}, + // 16, 24 -> 24 + {std::unordered_set{tv3->axis(1), id24}, id24}}; checkStep3Results(tester, s3_reference_map); @@ -1559,6 +1663,15 @@ TEST_F(IdModelTest, LoopPromotion7) { }; checkStep4Results(tester, s4_reference_map); + + // Check Step 5 results. See the design doc for the expected results + std::unordered_map> s5_reference_map = { + {tv2, {id9, id36, id37}}, + {tv3, {id9, id33}}, + {tv4, {id9, id10}}, + }; + + checkStep5Results(tester, s5_reference_map); } // Same fusion as NvFuserTest.FusionIndexing20 @@ -1675,6 +1788,7 @@ TEST_F(IdModelTest, LoopPromotion8) { auto id40 = tv5->axis(1); auto id41 = tv5->axis(2); auto id42 = getChildIdByName(id16, 42); + auto id44 = getChildIdByName(id42, 44); auto id47 = getChildIdByName(id42, 47); auto id48 = getChildIdByName(id42, 48); @@ -1682,7 +1796,8 @@ TEST_F(IdModelTest, LoopPromotion8) { auto id6 = tv4->getRootDomain().at(0); auto id7 = tv4->getRootDomain().at(1); auto id17 = getChildIdByName(id6, 17); - auto id18 = tv4->axis(0); + auto id18 = getChildIdByName(id17, 18); + auto id19 = getChildIdByName(id17, 19); // tv1 auto id1 = tv1->getRootDomain().at(0); @@ -1750,6 +1865,17 @@ TEST_F(IdModelTest, LoopPromotion8) { }; checkStep4Results(tester, s4_reference_map); + + // Check Step 5 results. See the design doc for the expected results + std::unordered_map> s5_reference_map = { + {tv1, {id30, id54}}, + {tv2, {id30, id56}}, + {tv4, {id30, id19}}, + {tv5, {id30, id43, id68}}, + {tv7, {id30, id43, id44}}, + }; + + checkStep5Results(tester, s5_reference_map); } // A case to illustrate the effect of the below issue and RR. @@ -1801,14 +1927,18 @@ TEST_F(IdModelTest, LoopPromotionPromoteToSameLoopGroup) { auto id8 = tv4->getRootDomain().at(1); auto id11 = getChildIdByName(id7, 11); auto id9 = getChildIdByName(id8, 9); + auto id10 = getChildIdByName(id8, 10); auto id13 = getChildIdByName(id11, 13); + auto id14 = getChildIdByName(id10, 14); // tv3 auto id5 = tv3->getRootDomain().at(0); auto id6 = tv3->getRootDomain().at(1); auto id15 = getChildIdByName(id5, 15); + auto id16 = getChildIdByName(id5, 16); auto id17 = getChildIdByName(id6, 17); auto id19 = getChildIdByName(id15, 19); + auto id20 = getChildIdByName(id16, 20); // tv2 auto id3 = tv2->getRootDomain().at(0); @@ -1850,6 +1980,15 @@ TEST_F(IdModelTest, LoopPromotionPromoteToSameLoopGroup) { {std::unordered_set{id32}, id60}}; checkStep4Results(tester, s4_reference_map); + + // Check Step 5 results. See the design doc for the expected results + std::unordered_map> s5_reference_map = { + {tv2, {id13, id60}}, + {tv3, {id13, id20}}, + {tv4, {id13, id14}}, + }; + + checkStep5Results(tester, s5_reference_map); } namespace { From 649563a57c387b2cfcc11595175e090dc0a20909 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 9 May 2024 23:20:11 -0700 Subject: [PATCH 165/178] comment --- csrc/id_model/id_model.cpp | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 281ff0f8536..d69e1620531 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -665,6 +665,37 @@ std::unordered_map IdModel::buildLoopPromotionMap( // Step 3 may not be included in the Step-5 results. Any Step-3 mappings // that are not found in the Step-5 results are already valid // results, so merge them into the Step-5 results. + // + // For example, in the below case, nothing will be propated at Step + // 4. + // + // t0: [i0] + // t1: [i1, i2] + // t2 = broadcast(t0, {true, false}) + // t3 = t2 + t1 + // + // t2: [b3, i4] + // t3: [i5, i6] + // + // t3->merge(0) + // propagate-and-inline-most + // + // t0: [i0] ca_pos(1) + // t1: [i1*i2] ca_pos(1) + // t2: [b3*i4] ca_pos(1) + // t3: [i5*i6] + // + // In this case, all domains will be grouped together and there will + // be just a single group in the Loop graph: + // + // - {i0, i1, i2, b3, i4, i5, i6, i1*i2, b3*i4, i5*i6} + // + // Step 3 will identify i5*i6 is the promotion domain. Since all + // domains are promoted to i5*i6, there will be no propagation in + // Step 4 (i.e., loop_promote_inputs will be false). Since the + // result of Step 4 is empty, the Step 5 result will also be empty, + // but that just means there's no change is necessary from the Step + // 3 results. // Update the Step-3 map to the latest LOOP graph loop_promotion_map = From b25ced1eed8f9a2ce6feac14c24577d2c2c90960 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 18 Mar 2024 15:19:29 -0700 Subject: [PATCH 166/178] repro of issue #1759 --- tests/cpp/test_id_model.cpp | 85 +++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index a0a004d5c88..5033119f758 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -1991,6 +1991,91 @@ TEST_F(IdModelTest, LoopPromotionPromoteToSameLoopGroup) { checkStep5Results(tester, s5_reference_map); } +// A repro that produces an invalid loop graph due to the compliment +// mapping. This is not currently supported. See +// https://github.com/NVIDIA/Fuser/issues/1759 +TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({7}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({7, 8}); + fusion.addInput(tv1); + auto tv2 = makeConcreteTensor({7, 9}); + fusion.addInput(tv2); + + auto tv3 = broadcast(tv0, {false, true}); + auto tv4 = add(tv1, tv3); + auto tv5 = broadcast(tv4, {false, false, true}); + + auto tv6 = broadcast(tv0, {false, true}); + auto tv7 = add(tv2, tv6); + auto tv8 = broadcast(tv7, {false, true, false}); + + auto tv9 = add(tv5, tv8); + + auto tv10 = set(tv9); + auto tv11 = set(tv10); + fusion.addOutput(tv11); + + // Merge all domains except for tv10 and tv11 + for (auto tv : ir_utils::allTvs(&fusion)) { + if (tv == tv10 || tv == tv11) { + continue; + } + while (tv->nDims() > 1) { + tv->merge(0); + } + } + + // Fully inline all tensors up until tv10 + for (auto tv : ir_utils::allTvs(&fusion)) { + if (tv == tv9 || tv == tv10 || tv == tv11) { + continue; + } + tv->inlineAt(1); + } + + // Fully inline tv10 to tv11 without merging + tv10->inlineAt(-1); + + // Due to the compliment mapping, the leaf domains of tv10 and tv11 + // are loop mapped, which is invalid. + // + // Specifically, here are the tv10 and tv11 tensors: + // + // T10_l[ iS22{7}, iS23{8}, iS24{9} ] ca_pos( 3 ) + // root domain : (iS22{7}, iS23{8}, iS24{9}) + // contiguity: t t t + // leaf domain : (iS22{7}, iS23{8}, iS24{9}) + // T11_g[ iS25{7}, iS26{8}, iS27{9} ] produce_pos( 3 ) + // root domain : (iS25{7}, iS26{8}, iS27{9}) + // contiguity: t t t + // leaf domain : (iS25{7}, iS26{8}, iS27{9}) + // + // Here's the loop graph for tv10 and tv11: + // idg{22 23 24 25 26 27} + + // Due to the invalid mapping, building IdModel should fail for now + EXPECT_THAT( + [&]() { IdModel id_model(&fusion, true, false, false); }, + ::testing::ThrowsMessage(::testing::HasSubstr( + "Detected leaf domains are mapped in the loop graph"))); + + // Enable the below validation once the above problem is resolved. + // + // const ValGraph& loop_graph = id_model.idGraph(IdMappingMode::LOOP); + // + // These assertions should fail at this moment. + // ASSERT_NE( + // loop_graph.toGroup(tv10->axis(0)), loop_graph.toGroup(tv10->axis(1))); + // ASSERT_NE( + // loop_graph.toGroup(tv10->axis(0)), loop_graph.toGroup(tv10->axis(2))); + // ASSERT_NE( + // loop_graph.toGroup(tv10->axis(1)), loop_graph.toGroup(tv10->axis(2))); +} + namespace { bool iterDomainsAreMapped( const IdModel& id_model, From fa8455a19a98cea693e0cbcd8f363afe6e6a2786 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 11:54:42 -0700 Subject: [PATCH 167/178] LoopPromotionBuilder class --- CMakeLists.txt | 1 + csrc/id_model/id_model.cpp | 5 +- csrc/id_model/id_model.h | 2 + csrc/id_model/loop_promotion.cpp | 159 +++++++++++++++++++++++++++++++ csrc/id_model/loop_promotion.h | 39 ++++++++ 5 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 csrc/id_model/loop_promotion.cpp create mode 100644 csrc/id_model/loop_promotion.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 04a7996933e..f0dfb955f31 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -135,6 +135,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/host_ir/executor.cpp ${NVFUSER_SRCS_DIR}/host_ir/host_ir.cpp ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp + ${NVFUSER_SRCS_DIR}/id_model/loop_promotion.cpp ${NVFUSER_SRCS_DIR}/id_model/to_string.cpp ${NVFUSER_SRCS_DIR}/id_model/transform_replay.cpp ${NVFUSER_SRCS_DIR}/id_model/validation_utils.cpp diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index d69e1620531..58521ee1a9c 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include #include @@ -580,6 +581,8 @@ void IdModel::buildLoopGraph() { validateLoopGraphHasNoSelfMappedLeafDomains(); idGraph(IdMappingMode::LOOP).validateConsistency(); + + auto loop_promotion_map2 = LoopPromotionMapBuilder::get(*this, inlining_info); } std::unordered_map IdModel::buildLoopPromotionMap( @@ -703,7 +706,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( // Insert the updated Step-3 results into the Step-5 resutls. Note // that this insertion does not overwrite the existing mappings. - final_iel_promotion_map.insert( + final_loop_promotion_map.insert( loop_promotion_map.begin(), loop_promotion_map.end()); sanityCheckLoopPromotionMap(final_loop_promotion_map); diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 8255e6a8000..167f0007561 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -20,6 +20,7 @@ namespace nvfuser { class ValGraph; +class LoopPromotionMapBuilder; struct StatefulInliningInfo { // All producer ids within (including dependencies of) inlined leaf domains, @@ -95,6 +96,7 @@ StatefulInliningInfo buildStatefulInliningInfo( // Subgraph of the permissive graph. Maps only CA and their // dependent domains. class IdModel : public PolymorphicBase { + friend class LoopPromotionMapBuilder; public: // Sometimes fusion inputs or outputs are disconnected from expressions, in // those cases we still may want to send in some additional tensor views from diff --git a/csrc/id_model/loop_promotion.cpp b/csrc/id_model/loop_promotion.cpp new file mode 100644 index 00000000000..4874df0c7cb --- /dev/null +++ b/csrc/id_model/loop_promotion.cpp @@ -0,0 +1,159 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +namespace nvfuser { + +LoopPromotionMapBuilder::LoopPromotionMapBuilder( + IdModel& id_model, + const StatefulInliningInfo& inlining_info) + : id_model_(id_model), inlining_info_(inlining_info) {} + +void LoopPromotionMapBuilder::build() { + auto& loop_graph = id_model_.idGraph(IdMappingMode::LOOP); + + std::cerr << nvfuser::idGroupsString(loop_graph); + std::cerr << "Size: " << inlining_info_.ordered_p_ca_ids.size() << std::endl; +#if 0 + // Make an intersection of the exact and loop map. This will group together + // entries in each loop group that are exact with each other. This provides a + // better graph to do promotion and replays. + // + // It's tempting to use the intersection of the almost exact and loop, but we + // need to model broadcast promotion, and if we have two tensors like: + // + // T1[i0, b1] = T0[i0] + // T2[i0, b2] = T0[i0] + // Then resolution of: + // T4 = T1[i0, b1] + T3[i0, i1] + // T6 = T2[i0, b2] + T5[i0, i2] + // + // Then merge(0, 1) with all tensors except for T0 + // + // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 + // are being resolved to i1 and i2 respectively. So we want to have separate + // entries so we can have an easy to process promotion map. + // + // Loop is a permissive like map, it could have many entries, use the exact + // map as the one we iterate on to reduce complexity as it hopefully has + // smaller groups and this algorithm scales with the number of groups * + // (number of entries in groups ^ 2) + // + // iel stands for Intersection of the Exact and Loop graphs. + ValGraph iel_graph = buildIntersection( + id_model_.idGraph(IdMappingMode::EXACT), id_model_.idGraph(IdMappingMode::LOOP), false); + + // Step 1: Build a map of the IEL groups of root broadcast domains + // to resolving domains. + std::unordered_map iel_promotion_map = + buildInlineRootResolutionMap(iel_graph, inlining_info_); + + // Step 2: Propagate the root promotions to intermediate and leaf groups. + // At this point, the promotion may not be final as the analysis is + // localized to IEL groups. The map is used in the next step to + // build mappings of the loop groups. + propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); + + // Step 3: Determine the promotion of each loop graph based on the + // IEL promotion map. For each loop group, examine all the IEL + // promotions and find the most representative one that captures all + // the dependent input domains of the loop group + std::unordered_map loop_promotion_map = + projectIELPromotionToLoopGraph( + iel_graph, + iel_promotion_map, + idGraph(IdMappingMode::LOOP), + inlining_info); + + // At this point, most of loop groups should have correct promoted + // IDs. However, non-inlined loop groups may miss promotion that + // should be propagated from parent ID groups, e.g., iS50 of T2 in + // Indexing19. Its parent ID loop group is promoted, but the loop + // group of iS50 is not found yet. + + // Step 4: In order to fully propagate the loop graph promotions, first + // propagate them to the IEL groups, which are then used to + // propagate back to the loop groups in Step 5. Unlike Step 2, the + // initial IEL promotion map is empty and is populated with the loop + // promotion map as we traverse down the IEL graph. + std::unordered_map final_iel_promotion_map; + propagatePromotionsInIELGraph( + iel_graph, + final_iel_promotion_map, + idGraph(IdMappingMode::LOOP), + loop_promotion_map); + + // Step 5: Find the final promotion of each loop group based on the + // final IEL promotion map + auto final_loop_promotion_map = projectIELPromotionToLoopGraph( + iel_graph, + final_iel_promotion_map, + idGraph(IdMappingMode::LOOP), + inlining_info); + + // The promotion map produced in Step 5 only includes those are + // further propagated at Step 4, so the correct mappings produced at + // Step 3 may not be included in the Step-5 results. Any Step-3 mappings + // that are not found in the Step-5 results are already valid + // results, so merge them into the Step-5 results. + // + // For example, in the below case, nothing will be propated at Step + // 4. + // + // t0: [i0] + // t1: [i1, i2] + // t2 = broadcast(t0, {true, false}) + // t3 = t2 + t1 + // + // t2: [b3, i4] + // t3: [i5, i6] + // + // t3->merge(0) + // propagate-and-inline-most + // + // t0: [i0] ca_pos(1) + // t1: [i1*i2] ca_pos(1) + // t2: [b3*i4] ca_pos(1) + // t3: [i5*i6] + // + // In this case, all domains will be grouped together and there will + // be just a single group in the Loop graph: + // + // - {i0, i1, i2, b3, i4, i5, i6, i1*i2, b3*i4, i5*i6} + // + // Step 3 will identify i5*i6 is the promotion domain. Since all + // domains are promoted to i5*i6, there will be no propagation in + // Step 4 (i.e., loop_promote_inputs will be false). Since the + // result of Step 4 is empty, the Step 5 result will also be empty, + // but that just means there's no change is necessary from the Step + // 3 results. + + // Update the Step-3 map to the latest LOOP graph + loop_promotion_map = + updateValGroupIdMap(loop_promotion_map, idGraph(IdMappingMode::LOOP)); + + // Insert the updated Step-3 results into the Step-5 resutls. Note + // that this insertion does not overwrite the existing mappings. + final_loop_promotion_map.insert( + loop_promotion_map.begin(), loop_promotion_map.end()); + + sanityCheckLoopPromotionMap(final_loop_promotion_map); +#endif +} + +std::unordered_map LoopPromotionMapBuilder::get( + IdModel& id_model, + const StatefulInliningInfo& inlining_info) { + LoopPromotionMapBuilder builder(id_model, inlining_info); + builder.build(); + return builder.loop_promotion_map_; +} + +} // namespace nvfuser diff --git a/csrc/id_model/loop_promotion.h b/csrc/id_model/loop_promotion.h new file mode 100644 index 00000000000..dfc9812fe09 --- /dev/null +++ b/csrc/id_model/loop_promotion.h @@ -0,0 +1,39 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser { + +class IdModel; +struct StatefulInliningInfo; + +class LoopPromotionMapBuilder { + public: + // Build a map of loop groups to IterDomains that represent actual + // loops. The map is built based on the broadcast resolution with + // root domains between inlined producer and consumer tensors. + static std::unordered_map get( + IdModel& id_model, + const StatefulInliningInfo& inlining_info); + + private: + LoopPromotionMapBuilder( + IdModel& id_model, + const StatefulInliningInfo& inlining_info); + + void build(); + + private: + IdModel& id_model_; + const StatefulInliningInfo& inlining_info_; + std::unordered_map loop_promotion_map_; +}; + +} // namespace nvfuser From c05fd851e2b3f96abb3e1b02178b8f199bc4fa48 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 12:58:29 -0700 Subject: [PATCH 168/178] Copied all loop promotion code --- csrc/id_model/id_model.cpp | 6 +- csrc/id_model/id_model.h | 27 +- csrc/id_model/loop_promotion.cpp | 665 ++++++++++++++++++++++++++++++- csrc/id_model/loop_promotion.h | 84 ++++ 4 files changed, 752 insertions(+), 30 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 58521ee1a9c..ff5d3fd5cb7 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -226,7 +226,7 @@ std::string IdModel::toString() const { return ss.str(); } -ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { +ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) const { ValGraph id_graph(propagate_through_exprs); // To deterministically initialize the graph, the order of adding @@ -716,7 +716,7 @@ std::unordered_map IdModel::buildLoopPromotionMap( std::unordered_map IdModel::buildInlineRootResolutionMap( const ValGraph& iel_graph, - const StatefulInliningInfo& info) { + const StatefulInliningInfo& info) const { std::unordered_map iel_promotion_map; // This should probably work just on terminating inputs, as we shouldn't be @@ -918,7 +918,7 @@ void IdModel::maybeBuildGraph(IdMappingMode mode) { ValGraph IdModel::buildIntersection( const ValGraph& graph0, const ValGraph& graph1, - bool propagate_exprs) { + bool propagate_exprs) const { ValGraph intersection = initializeIdGraph(propagate_exprs); for (const ValGroup& group0 : graph0.disjointValSets().disjointSets()) { auto set_size = group0->size(); diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 167f0007561..c58c17531b0 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -96,7 +96,6 @@ StatefulInliningInfo buildStatefulInliningInfo( // Subgraph of the permissive graph. Maps only CA and their // dependent domains. class IdModel : public PolymorphicBase { - friend class LoopPromotionMapBuilder; public: // Sometimes fusion inputs or outputs are disconnected from expressions, in // those cases we still may want to send in some additional tensor views from @@ -127,6 +126,16 @@ class IdModel : public PolymorphicBase { const ValGraph& idGraph(IdMappingMode mode) const; ValGraph& idGraph(IdMappingMode mode); + const std::unordered_map>& idUses() + const { + return id_uses_; + } + + const std::unordered_map>& + idDefinitions() const { + return id_definitions_; + } + // TODO: Seems a bit unfortunate that this isn't IterDomain local information. const std::unordered_set& viewRfactorIds() const { return view_rfactor_ids_; @@ -165,19 +174,24 @@ class IdModel : public PolymorphicBase { // Iterates over all IterDomains in id_definitions_ and calls initializeVal on // a new ValGraph and returns it. - ValGraph initializeIdGraph(bool propagate_through_exprs = true); + ValGraph initializeIdGraph(bool propagate_through_exprs = true) const; // Returns an IdGraph with all Id's mapped that are mapped both in graph0 and // graph1. ValGraph buildIntersection( const ValGraph& graph0, const ValGraph& graph1, - bool propagate_exprs = true); + bool propagate_exprs = true) const; const std::unordered_map& loopPromotionMap() const { return loop_promotion_map_; } + // Replay Expr but with the inputs provided. ValGraphs will be updated + // for all maps that have entries, adding the output iter domains of the + // replayed expression and adding potential mappings through the expression. + Expr* addReplayAs(std::vector new_inputs, Expr* expr); + protected: // Fills id_uses_ and id_definitions_ for all IterDomains active in the // fusion. @@ -197,7 +211,7 @@ class IdModel : public PolymorphicBase { // IterDomain picked from its IEL group. std::unordered_map buildInlineRootResolutionMap( const ValGraph& iel_graph, - const StatefulInliningInfo& info); + const StatefulInliningInfo& info) const; // Helper function for building loop promotion map. // @@ -284,11 +298,6 @@ class IdModel : public PolymorphicBase { // tensor. void validateLoopGraphHasNoSelfMappedLeafDomains() const; - // Replay Expr but with the inputs provided. ValGraphs will be updated - // for all maps that have entries, adding the output iter domains of the - // replayed expression and adding potential mappings through the expression. - Expr* addReplayAs(std::vector new_inputs, Expr* expr); - protected: // All tensor expressions that this model analyzes std::vector tv_exprs_; diff --git a/csrc/id_model/loop_promotion.cpp b/csrc/id_model/loop_promotion.cpp index 4874df0c7cb..73af2c3ee08 100644 --- a/csrc/id_model/loop_promotion.cpp +++ b/csrc/id_model/loop_promotion.cpp @@ -8,6 +8,8 @@ #include #include #include +#include +#include namespace nvfuser { @@ -16,12 +18,15 @@ LoopPromotionMapBuilder::LoopPromotionMapBuilder( const StatefulInliningInfo& inlining_info) : id_model_(id_model), inlining_info_(inlining_info) {} -void LoopPromotionMapBuilder::build() { - auto& loop_graph = id_model_.idGraph(IdMappingMode::LOOP); +ValGraph& LoopPromotionMapBuilder::idGraph(IdMappingMode mode) { + return id_model_.idGraph(mode); +} + +const ValGraph& LoopPromotionMapBuilder::idGraph(IdMappingMode mode) const { + return id_model_.idGraph(mode); +} - std::cerr << nvfuser::idGroupsString(loop_graph); - std::cerr << "Size: " << inlining_info_.ordered_p_ca_ids.size() << std::endl; -#if 0 +void LoopPromotionMapBuilder::build() { // Make an intersection of the exact and loop map. This will group together // entries in each loop group that are exact with each other. This provides a // better graph to do promotion and replays. @@ -47,8 +52,8 @@ void LoopPromotionMapBuilder::build() { // (number of entries in groups ^ 2) // // iel stands for Intersection of the Exact and Loop graphs. - ValGraph iel_graph = buildIntersection( - id_model_.idGraph(IdMappingMode::EXACT), id_model_.idGraph(IdMappingMode::LOOP), false); + ValGraph iel_graph = id_model_.buildIntersection( + idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); // Step 1: Build a map of the IEL groups of root broadcast domains // to resolving domains. @@ -65,12 +70,12 @@ void LoopPromotionMapBuilder::build() { // IEL promotion map. For each loop group, examine all the IEL // promotions and find the most representative one that captures all // the dependent input domains of the loop group - std::unordered_map loop_promotion_map = + std::unordered_map initial_loop_promotion_map = projectIELPromotionToLoopGraph( iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), - inlining_info); + inlining_info_); // At this point, most of loop groups should have correct promoted // IDs. However, non-inlined loop groups may miss promotion that @@ -88,15 +93,15 @@ void LoopPromotionMapBuilder::build() { iel_graph, final_iel_promotion_map, idGraph(IdMappingMode::LOOP), - loop_promotion_map); + initial_loop_promotion_map); // Step 5: Find the final promotion of each loop group based on the // final IEL promotion map - auto final_loop_promotion_map = projectIELPromotionToLoopGraph( + loop_promotion_map_ = projectIELPromotionToLoopGraph( iel_graph, final_iel_promotion_map, idGraph(IdMappingMode::LOOP), - inlining_info); + inlining_info_); // The promotion map produced in Step 5 only includes those are // further propagated at Step 4, so the correct mappings produced at @@ -136,16 +141,640 @@ void LoopPromotionMapBuilder::build() { // 3 results. // Update the Step-3 map to the latest LOOP graph - loop_promotion_map = - updateValGroupIdMap(loop_promotion_map, idGraph(IdMappingMode::LOOP)); + initial_loop_promotion_map = updateValGroupIdMap( + initial_loop_promotion_map, idGraph(IdMappingMode::LOOP)); // Insert the updated Step-3 results into the Step-5 resutls. Note // that this insertion does not overwrite the existing mappings. - final_loop_promotion_map.insert( - loop_promotion_map.begin(), loop_promotion_map.end()); + loop_promotion_map_.insert( + initial_loop_promotion_map.begin(), initial_loop_promotion_map.end()); + + sanityCheckLoopPromotionMap(loop_promotion_map_); +} + +std::unordered_map LoopPromotionMapBuilder:: + buildInlineRootResolutionMap( + const ValGraph& iel_graph, + const StatefulInliningInfo& info) const { + std::unordered_map iel_promotion_map; + + // This should probably work just on terminating inputs, as we shouldn't be + // able to modify a broadcast domain between root and rfactor which would be + // required to resolve a non input broadcast domain. But for now leaving it as + // traversal on all broadcast groups. + // + + // We first visit all broadcast root domains. If a broadcast is + // resovled, see if it's promoted. Note that a domain be resolved to + // a domain that may not be loop mapped, yet it can still be + // promoted. In other words, there can be a domain that is exactly + // mapped with the resolving domain *and* is mapped with the + // broadcast domain by the loop map. The algorihm here is: + // + // 1. For a broadcast domain, find the domain that the broadcast is + // resolved to. + // 2. If the resolving domain is also loop-mapped with the + // broadcast, that is the promotion domain, but the resolving + // domain may not be loop mapped as mentioned above. Instead, + // find all loop-mapped domains with the broadcast domain and + // pick one that is exactly mapped with the resolving domain + // + // Note again this process is only done for root domains. Once we + // find promotion relationships for root domains, we propagate the + // mappings to derived domains + for (const ValGroup& iel_group : iel_graph.disjointValSets().disjointSets()) { + NVF_ERROR(!iel_group->empty()); + + IterDomain* iel_group_id = iel_group->front()->as(); + + if (!iel_group_id->isBroadcast()) { + continue; + } + + // Collect all the exact groups of the resolutions of the broadcast id's + ValGroups resolved_exact_groups; + for (Val* bcast_id : *iel_group) { + if (auto p2c_root_broadcast_resolution_map_it = + info.p2c_root_broadcast_resolution_map.find( + bcast_id->as()); + p2c_root_broadcast_resolution_map_it != + info.p2c_root_broadcast_resolution_map.end()) { + resolved_exact_groups.pushBack( + idGraph(IdMappingMode::EXACT) + .toGroups(p2c_root_broadcast_resolution_map_it->second)); + } + } + + if (resolved_exact_groups.empty()) { + // No resolution + continue; + } + + // resolved_exact_groups is a list of IDs that resolves the + // broadcast. We only care those that are also in the same loop + // group, and there must be just one or none. Otherwise, the + // resolution is ambiguous. + + // Collect all the exact groups in the loop set containing this iel_group + const ValGroup& loop_group = + idGraph(IdMappingMode::LOOP).toGroup(iel_group_id); + ValGroups loop_covered_exact_groups = + idGraph(IdMappingMode::EXACT).toGroups(*loop_group); + + // The intersection of the exact groups that the broadcast domains can be + // broadcasted to, and those that exist within the same loop groop are is + // the promotion needed for this iel_group. The promotion should + // be none or unique. + ValGroups loop_exact_resolved_intersection = + resolved_exact_groups.computeIntersect(loop_covered_exact_groups); + + if (loop_exact_resolved_intersection.empty()) { + // No promotion + continue; + } + + if (loop_exact_resolved_intersection.size() > 1) { + // Ambiguous promotion. This should not happen. + std::stringstream err_msg; + err_msg + << "Invalid multiple broadcast resolution within shared loops detected, group:\n " + << iel_group->toString() << "\nIs being broadcasted to:"; + for (const ValGroup& entry : loop_exact_resolved_intersection) { + err_msg << "\n " << entry->toString(); + } + NVF_ERROR(false, err_msg.str()); + } + + const ValGroup& exact_resolution_group = + loop_exact_resolved_intersection.front(); + + // Within the loop group, find the IDs that the broadcast IDs are + // resolved to + VectorOfUniqueEntries resolved_ids = + exact_resolution_group->computeIntersect(*loop_group); + + NVF_ERROR(!resolved_ids.empty()); + + // All the IDs in resolved_ids are mapped with both of the exact + // and loop graphs, so any of them can be used as an IEL promotion + // ID. Just to make it extra clear, look for corresponding + // groups in the IEL graph and make sure there's only one such group. + ValGroups promoted_iel_groups = iel_graph.toGroups(resolved_ids); + + NVF_ERROR(!promoted_iel_groups.empty()); + + if (promoted_iel_groups.size() > 1) { + std::stringstream err_msg; + err_msg + << "Invalid multiple broadcast resolution within shared loops detected, group:\n " + << iel_group->toString() << "\nIs being broadcasted to:"; + for (const ValGroup& entry : promoted_iel_groups) { + err_msg << "\n " << entry->toString(); + } + NVF_ERROR(false, err_msg.str()); + } + + iel_promotion_map[iel_group] = + promoted_iel_groups.front()->front()->as(); + } + + return iel_promotion_map; +} + +namespace { + +// Check if there's an equivalent expression as iel_expr that uses +// maybe_promoted_inputs. This is used to avoid redundantly replaying +// expressions. +// NOTE: This is currently overly conservative and some +// opportunities for reuse are lost, althought it doesn't affect +// the correctness of the analysis. +Expr* findMatchingExpr( + const ExprGroup& iel_expr, + const ValGraph& iel_graph, + const std::vector& maybe_promoted_inputs, + const ValGraph& loop_graph) { + // If any of domains in maybe_promoted_inputs is not found in + // iel_graph, it means the domain is just replayed and by definition + // has no mapping with any existing domain, which means there's no + // matching expr. + if (std::any_of( + maybe_promoted_inputs.begin(), + maybe_promoted_inputs.end(), + [&](IterDomain* maybe_promoted_input) -> bool { + return !iel_graph.hasGroup(maybe_promoted_input); + })) { + return nullptr; + } + + // Grab all eligible uses of the promoted inputs. + // Note that any eligible matching expr should be a use of all + // inputs in maybe_promoted_input_uses, no matter it's promoted or + // not. So it isn't necessary to look at all of + // maybe_promoted_input_uses but just need to grab one. + NVF_ERROR(!maybe_promoted_inputs.empty()); + ExprGroups maybe_promoted_input_uses = + iel_graph.getUses(iel_graph.toGroup(maybe_promoted_inputs.front())); + + if (maybe_promoted_input_uses.empty()) { + return nullptr; + } + + // Look for exprs that have inputs that are mapped in the IEL + // graph with the (promoted) inputs of iel_expr. + for (const ExprGroup& maybe_promoted_input_use_group : + maybe_promoted_input_uses) { + NVF_ERROR(!maybe_promoted_input_use_group->empty()); + // maybe_promoted_inputs may include non-promoted inputs as + // well, so maybe_promoted_input_uses may include the original + // iel_expr itself. Since there must at least be a promoted input, + // iel_expr itself should not be an expr group we are looking for. + if (iel_expr == maybe_promoted_input_use_group) { + continue; + } + Expr* maybe_promoted_input_use = maybe_promoted_input_use_group->front(); + if (!iel_expr->front()->sameOp(maybe_promoted_input_use)) { + continue; + } + // Check if all inputs are mapped + NVF_ERROR( + maybe_promoted_inputs.size() == + maybe_promoted_input_use->inputs().size()); + bool all_inputs_match = true; + for (const auto inp_i : c10::irange(maybe_promoted_inputs.size())) { + all_inputs_match = all_inputs_match && + iel_graph.disjointValSets().strictAreMapped( + maybe_promoted_inputs[inp_i], + maybe_promoted_input_use->inputs().at(inp_i)); + } + if (!all_inputs_match) { + continue; + } + + // We always want to find promotions within the same loop + // groups since we are looking for domains that represent actual + // loops. Note that that's guaranteed when a new domain is + // replayed instead of reusing an existing domain. + if (!loop_graph.disjointExprSets().permissiveAreMapped( + iel_expr->front(), maybe_promoted_input_use_group->front())) { + continue; + } + // This is just an extra sanity check. Make sure all exprs in + // the use group are mapped + NVF_ERROR( + std::all_of( + maybe_promoted_input_use_group->vector().begin(), + maybe_promoted_input_use_group->vector().end(), + [&](Expr* iel_use) { + return loop_graph.disjointExprSets().permissiveAreMapped( + iel_expr->front(), iel_use); + }), + "Not all mapped: ", + nvfuser::toString(iel_expr), + "\n", + nvfuser::toString(maybe_promoted_input_use_group)); + + return maybe_promoted_input_use; + } + + return nullptr; +} + +// When propagating loop promotions from inputs to outputs of an IEL +// expr, we can't blindly apply loop promotion when all of the input +// domains are loop mapped with the outputs. +// +// i.e. if we have the inlined domains from: +// Inputs: +// T0[i0] +// T1[i0, i1] +// +// T2[i0, b2] = broadcast(T0) +// T3[i0, i1] = T2 + T1 +// +// {T1, T2, T3}->merge(0, 1) +// inlineMost +// +// The inlined loop group would consist of: +// +// {i0, i1, b2, i0*b2, i0*i1} +// +// Note that all these domains would have promotion to i0*i1 at the +// end of Step 3. When the IEL expression of merge(i0, i1) is visited by +// propagatePromotionsInIELGraph again, the promotion to i0*i1 of both +// inputs would be propagated to its output, resulting in promotion of +// i0*i1 to (i0*i1)*(i0*i1), which is not the correct propagation. +// +// Therefore only promote i0*b1 to i0*i1, or i0*i1 to i0*i1 (i.e. don't +// promote an input to any transformation within the loop group). +// +// So if we have an iel_expr make sure its inputs and outputs are not in +// the same loop group. +bool hasUniqueInputLoopGroups( + const ExprGroup& iel_expr, + const ValGraph& iel_graph, + const ValGraph& loop_graph) { + const std::vector iel_inp_groups = iel_graph.inputGroups(iel_expr); + + const std::vector iel_out_groups = iel_graph.outputGroups(iel_expr); + + ValGroups inp_loop_groups; + for (const ValGroup& iel_inp_group : iel_inp_groups) { + inp_loop_groups.pushBack(loop_graph.toGroup(iel_inp_group->front())); + } + ValGroups out_loop_groups; + for (const ValGroup& iel_out_group : iel_out_groups) { + out_loop_groups.pushBack(loop_graph.toGroup(iel_out_group->front())); + } + + // Check if input groups that are not included in the output group set + return !inp_loop_groups.computeSubtract(out_loop_groups).empty(); +} + +} // namespace + +void LoopPromotionMapBuilder::propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const std::unordered_map& loop_graph_promotion_map) { + // In order to make this traversal work, the traversal order must be + // topologically sorted. + ValGraphStmtSort iel_stmt_sort(iel_graph); + + for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) { + NVF_ERROR(!iel_expr->empty()); + const std::vector iel_inp_groups = + iel_graph.inputGroups(iel_expr); + + // Check if any inputs need promotion indicating this expr group needs to + // be replayed with promoted inputs + bool an_input_was_promoted = false; + std::vector maybe_promoted_inputs; + maybe_promoted_inputs.reserve(iel_inp_groups.size()); + + // Propagate loop graph promotion only when the inputs and outputs are + // not in the same loop group. + const bool loop_promote_inputs = !loop_graph_promotion_map.empty() && + hasUniqueInputLoopGroups(iel_expr, iel_graph, loop_graph); + + for (const ValGroup& iel_inp_group : iel_inp_groups) { + // Assumed all inputs are IterDomains + NVF_ERROR(iel_inp_group->front()->isA()); + + // Propagate IEL promotions when available. + if (auto inp_promo_it = iel_promotion_map.find(iel_inp_group); + inp_promo_it != iel_promotion_map.end()) { + maybe_promoted_inputs.push_back(inp_promo_it->second); + an_input_was_promoted = true; + continue; + } + + // Promote loops based on the loop promotion map. If the loop promotion + // map should be used and has an entry we should use that promotion. + if (loop_promote_inputs) { + const ValGroup& loop_copy_group = + loop_graph.toGroup(iel_inp_group->front()); + auto inp_loop_promo_it = loop_graph_promotion_map.find(loop_copy_group); + if (inp_loop_promo_it != loop_graph_promotion_map.end()) { + maybe_promoted_inputs.push_back(inp_loop_promo_it->second); + an_input_was_promoted = true; + continue; + } + } + + // No promotion found. Just use the non-promoted domain + maybe_promoted_inputs.push_back(iel_inp_group->front()->as()); + } + + if (!an_input_was_promoted) { + // No inputs need promotion so just continue + continue; + } + + Expr* promoted_expr = findMatchingExpr( + iel_expr, + iel_graph, + maybe_promoted_inputs, + idGraph(IdMappingMode::LOOP)); + + bool replayed = false; + + if (!promoted_expr) { + promoted_expr = + id_model_.addReplayAs(maybe_promoted_inputs, iel_expr->front()); + replayed = true; + } + + // Mark outputs as having a promoted iter domain + std::vector out_groups = iel_graph.outputGroups(iel_expr); + NVF_ERROR(promoted_expr->outputs().size() == out_groups.size()); + NVF_ERROR( + ir_utils::filterByType(promoted_expr->outputs()).size() == + out_groups.size(), + "Unexpected non IterDomain outputs found: ", + promoted_expr->toString()); + + for (const auto i : c10::irange(out_groups.size())) { + // Promote if necessary, if the output is already in the same exact map + // it doesn't need a promotion. + if (idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped( + promoted_expr->output(i), out_groups[i]->front())) { + continue; + } + iel_promotion_map[out_groups[i]] = + promoted_expr->output(i)->as(); + // Explicitly map loop map since expr propagation doesn't happen + if (replayed) { + idGraph(IdMappingMode::LOOP) + .mapVals(iel_expr->front()->output(i), promoted_expr->output(i)); + } + } + } +} + +void LoopPromotionMapBuilder::propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map) { + propagatePromotionsInIELGraph( + iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), {}); +} + +namespace { + +// Returns for each ValGroup in provided IdGraph what the input ValGroups are +// traversing on definitions. Ignoring broadcast ValGroups and resetting inputs +// at RFactor ValGroups. +std::unordered_map computeCoveredGroups( + const ValGraph& graph, + const std::unordered_set& view_rfactor_ids) { + // Map from an exact iter domain group, to all the exact iter domain groups it + // covers + std::unordered_map covered_ids; + + for (const ValGroup& id_group : graph.disjointValSets().disjointSets()) { + // Initialize inputs + const ExprGroups& id_group_defs = graph.getDefinitions(id_group); + if (id_group_defs.empty()) { + covered_ids[id_group] = {id_group}; + } + + // Initialize rfactor groups + if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) { + return view_rfactor_ids.find(id->as()) != + view_rfactor_ids.end(); + })) { + covered_ids[id_group] = {id_group}; + } + + // Initialize broadcast groups to empty since broadcast domains + // don't matter for indexing + if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) { + return id->as()->isBroadcast(); + })) { + covered_ids[id_group] = {}; + } + } + + ValGraphStmtSort exact_stmt_sort(graph); + + for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) { + std::vector input_groups = graph.inputGroups(exact_expr); + + ValGroups covered; + for (const ValGroup& inp_group : input_groups) { + covered.pushBack(covered_ids.at(inp_group)); + } + + for (const ValGroup& output_group : graph.outputGroups(exact_expr)) { + // Don't overwrite initialized cases due to rfactor markings. + if (covered_ids.find(output_group) == covered_ids.end()) { + covered_ids[output_group] = covered; + } + } + } + + return covered_ids; +} + +}; // namespace + +std::unordered_map LoopPromotionMapBuilder:: + projectIELPromotionToLoopGraph( + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const StatefulInliningInfo& inlining_info) { + const std::unordered_map exact_covered_ids = + computeCoveredGroups( + idGraph(IdMappingMode::EXACT), id_model_.viewRfactorIds()); + + // Grab terminal iter domain in the loop groups. + const VectorOfUniqueEntries terminal_loop_ids = + computeTerminalLoopIds(inlining_info); + + std::unordered_map loop_promotion_map; + + for (const ValGroup& loop_group : + loop_graph.disjointValSets().disjointSets()) { + IterDomain* promotion_id = findPromotionOfLoopGroup( + loop_group, + iel_graph, + iel_promotion_map, + exact_covered_ids, + terminal_loop_ids); + if (promotion_id) { + loop_promotion_map[loop_group] = promotion_id; + } + } + + return loop_promotion_map; +} + +IterDomain* LoopPromotionMapBuilder::findPromotionOfLoopGroup( + const ValGroup& loop_group, + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const std::unordered_map& exact_covered_ids, + const VectorOfUniqueEntries& terminal_loop_ids) { + const ValGraph& exact_graph = idGraph(IdMappingMode::EXACT); + + // Grab all the (potentially promoted) terminal iter domains in this group. + // Save the exact group and the iter domain in this vector. + std::vector> exact_promoted_terminal_ids; + for (auto loop_id : *loop_group) { + // If not a terminal id in the group skip + if (!terminal_loop_ids.has(loop_id->as())) { + continue; + } + + // Grab the iel entry. There can be iter domains that were added + // after the IEL graph was built. All the promotion information is + // associated with the domains that exist in the original graph, + // so the new domains can be simply ignored. + if (!iel_graph.hasGroup(loop_id)) { + continue; + } + + const ValGroup& iel_group = iel_graph.toGroup(loop_id); + + // Does it still need iel_promotion_map? The loop group already has + // the replayed domains, so we should be able to find it. + auto iel_promo_it = iel_promotion_map.find(iel_group); + if (iel_promo_it == iel_promotion_map.end()) { + // If this terminal ID doesn't have a promotion associated with it, save + // the terminal ID. + exact_promoted_terminal_ids.emplace_back( + exact_graph.toGroup(loop_id), loop_id->as()); + } else { + // If this terminal ID has a promotion, grab the promoted ID. + exact_promoted_terminal_ids.emplace_back( + exact_graph.toGroup(iel_promo_it->second), iel_promo_it->second); + } + } + + // All the exact groups of the iter domains in the loop group + ValGroups exact_groups = exact_graph.toGroups(*loop_group); + + // All exact groups covered by all iter domains in this loop group + ValGroups loop_group_covered_ids; + for (const ValGroup& exact_group : exact_groups) { + auto covered_it = exact_covered_ids.find(exact_group); + NVF_ERROR(covered_it != exact_covered_ids.end()); + loop_group_covered_ids.pushBack(covered_it->second); + } + + // Check if any of the candidate Iter Domains we collected cover all the + // exact groups of loop_group_covered_ids. If so, that's the correct + // promoted iter domain of this group. + for (const auto& entry : exact_promoted_terminal_ids) { + const ValGroup& terminal_id_group = entry.first; + IterDomain* terminal_id = entry.second; + auto covered_it = exact_covered_ids.find(terminal_id_group); + NVF_ERROR(covered_it != exact_covered_ids.end()); + if (loop_group_covered_ids.computeSubtract(covered_it->second).empty()) { + return terminal_id; + } + } + + return nullptr; +} + +VectorOfUniqueEntries LoopPromotionMapBuilder:: + computeTerminalLoopIds(const StatefulInliningInfo& info) { + VectorOfUniqueEntries terminal_loop_ids; + for (const ValGroup& group : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { + if (group->size() == 1) { + terminal_loop_ids.pushBack(group->front()->as()); + } + + // Don't select producer iter domains + for (auto loop_id : *group) { + if (info.p2c_ca_permissive_maps.find(loop_id->as()) != + info.p2c_ca_permissive_maps.end()) { + continue; + } + + // It's terminal if there's no use group + auto uses_it = id_model_.idUses().find(loop_id->as()); + if (uses_it == id_model_.idUses().end() || uses_it->second.empty()) { + terminal_loop_ids.pushBack(loop_id->as()); + continue; + } + + // If there's an output group that is not in the same group, + // then it's a terminal ID + bool all_outs_in_loop_group = true; + for (auto use : uses_it->second) { + if (std::any_of( + use->outputs().begin(), + use->outputs().end(), + [&](Val* out) -> bool { + return group != idGraph(IdMappingMode::LOOP).toGroup(out); + })) { + all_outs_in_loop_group = false; + break; + } + } + + if (!all_outs_in_loop_group) { + terminal_loop_ids.pushBack(loop_id->as()); + } + } + } + return terminal_loop_ids; +} - sanityCheckLoopPromotionMap(final_loop_promotion_map); -#endif +void LoopPromotionMapBuilder::sanityCheckLoopPromotionMap( + const std::unordered_map& loop_promotion_map) const { + const auto& loop_graph = idGraph(IdMappingMode::LOOP); + for (const ValGroup& loop_group : + loop_graph.disjointValSets().disjointSets()) { + // Non-leaf loop groups are not guaranteed to have valid + // promotions. See for example FusionRepro1713, where root domains + // are all grouped together but there's no valid promotion. + if (loop_graph.hasUses(loop_group)) { + continue; + } + // Make sure the loop group is promoted to a domain that is mapped + // in the LOOP graph + auto promotion_it = loop_promotion_map.find(loop_group); + NVF_ERROR( + promotion_it != loop_promotion_map.end(), + "Loop promotion not found for ", + nvfuser::toString(loop_group)); + IterDomain* promotion = promotion_it->second; + // Make sure the promotion domain is also loop-mapped + NVF_ERROR( + loop_group->has(promotion), + "Loop promotion not loop-mapped. Loop group: ", + nvfuser::toString(loop_group), + ". Promotion domain: ", + promotion->name()); + } } std::unordered_map LoopPromotionMapBuilder::get( diff --git a/csrc/id_model/loop_promotion.h b/csrc/id_model/loop_promotion.h index dfc9812fe09..c2f647d81a6 100644 --- a/csrc/id_model/loop_promotion.h +++ b/csrc/id_model/loop_promotion.h @@ -30,6 +30,90 @@ class LoopPromotionMapBuilder { void build(); + ValGraph& idGraph(IdMappingMode mode); + const ValGraph& idGraph(IdMappingMode mode) const; + + std::unordered_map buildInlineRootResolutionMap( + const ValGraph& iel_graph, + const StatefulInliningInfo& info) const; + + // Helper function for building loop promotion map. + // + // Propagate promotion mappings from root IEL groups to intermediate + // and leaf IEL groups by traversing IEL exprs. For each expr, if an + // input is promoted, the output needs to be promoted too. If + // there's already an equivalent expr that uses the promoted inputs, + // create a mapping from the outputs of the IEL expr to the outputs + // of the equivalent expr. We only consider exprs that are mapped + // in the loop graph as we are looking for domains that represent + // the actual loops of the input and output domains of the IEL + // expr. If no such expr is found, the IEL expr is replayed with the + // promoted inputs. + // + // This is used twice when building the promotion map. The first time + // it is used there's no loop graph promotion yet, so only the IEL + // promotions are propagated. In that case, loop_graph_promotion_map + // should be just empty. + // + // Propagation uses iel_promotion_map and + // loop_graph_promotion_map. If both are available for an IEL group, + // the former has the precedence. This is because when this function + // is used for step 4, the given iel_promotion_map starts as an + // empty map and gets populated during this propagation, so any + // mapping in the map is guaranteed to be the correct final mapping, + // whereas the loop graph may have invalid mappings for partially + // inlined domains. + void propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const std::unordered_map& loop_promotion_map); + + // Same as the other propagatePromotionsInIELGraph but without loop + // graph map. This is used for step 2, where there's no loop + // graph map yet. + void propagatePromotionsInIELGraph( + const ValGraph& iel_graph, + std::unordered_map& iel_promotion_map); + + // Given an IEL promotion map, identify the mapping of each loop + // group. The promotion must represent all the domains in each loop + // group. If a valid representative promotion is not found for a + // loop group, no mapping is added for the group. + std::unordered_map projectIELPromotionToLoopGraph( + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const ValGraph& loop_graph, + const StatefulInliningInfo& inlining_info); + + // Find a promoted iter domain of a given loop group that covers all + // the exact groups representative of the resolved transformations + // within the loop group. Specifically, we examine each IEL group of + // the loop group, and if an IEL group has a promotion, we consider it as a + // candidate of the promotion of this loop group. If not, we include a + // domain of the IEL group as a candidate too. Once all candidates are + // obtained, we pick one that covers all the exact domains (cf. concrete + // domains in ComputeAtMap) + IterDomain* findPromotionOfLoopGroup( + const ValGroup& loop_group, + const ValGraph& iel_graph, + const std::unordered_map& iel_promotion_map, + const std::unordered_map& exact_covered_ids, + const VectorOfUniqueEntries& terminal_loop_ids); + + // Terminal loop ids are iteration domains in each loop group that: + // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a + // consumer TV's iter domain maps to this domain in a way that that domain + // is also in the same loop group + // 2) Don't have a direct IterDomain consumer within the group + VectorOfUniqueEntries computeTerminalLoopIds( + const StatefulInliningInfo& info); + + // Basic consistency check of the given loop promotion map + void sanityCheckLoopPromotionMap( + const std::unordered_map& loop_promotion_map) + const; + private: IdModel& id_model_; const StatefulInliningInfo& inlining_info_; From 31d28e570f30f10b86b0a211bd61420eee3d3210 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 12:59:43 -0700 Subject: [PATCH 169/178] enable idmodel --- csrc/device_lower/lower2device.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 76181bd72b4..f988ce65c4f 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -391,7 +391,7 @@ void GpuLower::analysis(Fusion* fusion) { // functionality should be affected. New IterDomains may be created, // so it is expected that generated code may use diffrent variable // names - if (isOptionEnabled(EnableOption::IdModel)) { + if (true || isOptionEnabled(EnableOption::IdModel)) { IdModel id_model(fusion_); } From d0bdc8ede7e9737af9234f09bf4382c6c541cc71 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 13:02:56 -0700 Subject: [PATCH 170/178] Switch to the new builder --- csrc/id_model/id_model.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index ff5d3fd5cb7..d3d8446f96e 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -574,15 +574,13 @@ void IdModel::buildLoopGraph() { validateLoopGraphHasNoSelfMappedLeafDomains(); - loop_promotion_map_ = buildLoopPromotionMap(inlining_info); + loop_promotion_map_ = LoopPromotionMapBuilder::get(*this, inlining_info); // New domains are added. Make sure there's still no self mapping in // the leaf domains validateLoopGraphHasNoSelfMappedLeafDomains(); idGraph(IdMappingMode::LOOP).validateConsistency(); - - auto loop_promotion_map2 = LoopPromotionMapBuilder::get(*this, inlining_info); } std::unordered_map IdModel::buildLoopPromotionMap( From aa9b414e3f0ba2674821dca1ed8b637355ad0bd9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 13:08:48 -0700 Subject: [PATCH 171/178] const --- csrc/id_model/loop_promotion.cpp | 6 +++--- csrc/id_model/loop_promotion.h | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/id_model/loop_promotion.cpp b/csrc/id_model/loop_promotion.cpp index 73af2c3ee08..68491e1a9c5 100644 --- a/csrc/id_model/loop_promotion.cpp +++ b/csrc/id_model/loop_promotion.cpp @@ -606,7 +606,7 @@ std::unordered_map LoopPromotionMapBuilder:: const ValGraph& iel_graph, const std::unordered_map& iel_promotion_map, const ValGraph& loop_graph, - const StatefulInliningInfo& inlining_info) { + const StatefulInliningInfo& inlining_info) const { const std::unordered_map exact_covered_ids = computeCoveredGroups( idGraph(IdMappingMode::EXACT), id_model_.viewRfactorIds()); @@ -638,7 +638,7 @@ IterDomain* LoopPromotionMapBuilder::findPromotionOfLoopGroup( const ValGraph& iel_graph, const std::unordered_map& iel_promotion_map, const std::unordered_map& exact_covered_ids, - const VectorOfUniqueEntries& terminal_loop_ids) { + const VectorOfUniqueEntries& terminal_loop_ids) const { const ValGraph& exact_graph = idGraph(IdMappingMode::EXACT); // Grab all the (potentially promoted) terminal iter domains in this group. @@ -703,7 +703,7 @@ IterDomain* LoopPromotionMapBuilder::findPromotionOfLoopGroup( } VectorOfUniqueEntries LoopPromotionMapBuilder:: - computeTerminalLoopIds(const StatefulInliningInfo& info) { + computeTerminalLoopIds(const StatefulInliningInfo& info) const { VectorOfUniqueEntries terminal_loop_ids; for (const ValGroup& group : idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { diff --git a/csrc/id_model/loop_promotion.h b/csrc/id_model/loop_promotion.h index c2f647d81a6..bbde115a623 100644 --- a/csrc/id_model/loop_promotion.h +++ b/csrc/id_model/loop_promotion.h @@ -84,7 +84,7 @@ class LoopPromotionMapBuilder { const ValGraph& iel_graph, const std::unordered_map& iel_promotion_map, const ValGraph& loop_graph, - const StatefulInliningInfo& inlining_info); + const StatefulInliningInfo& inlining_info) const; // Find a promoted iter domain of a given loop group that covers all // the exact groups representative of the resolved transformations @@ -99,7 +99,7 @@ class LoopPromotionMapBuilder { const ValGraph& iel_graph, const std::unordered_map& iel_promotion_map, const std::unordered_map& exact_covered_ids, - const VectorOfUniqueEntries& terminal_loop_ids); + const VectorOfUniqueEntries& terminal_loop_ids) const; // Terminal loop ids are iteration domains in each loop group that: // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a @@ -107,7 +107,7 @@ class LoopPromotionMapBuilder { // is also in the same loop group // 2) Don't have a direct IterDomain consumer within the group VectorOfUniqueEntries computeTerminalLoopIds( - const StatefulInliningInfo& info); + const StatefulInliningInfo& info) const; // Basic consistency check of the given loop promotion map void sanityCheckLoopPromotionMap( From 7aef29d2d73ea668568a1071c16818317b604cb7 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 13:15:23 -0700 Subject: [PATCH 172/178] cleanup --- csrc/id_model/loop_promotion.cpp | 13 +++++++------ csrc/id_model/loop_promotion.h | 3 +-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/id_model/loop_promotion.cpp b/csrc/id_model/loop_promotion.cpp index 68491e1a9c5..6ea37b7f5f6 100644 --- a/csrc/id_model/loop_promotion.cpp +++ b/csrc/id_model/loop_promotion.cpp @@ -26,7 +26,7 @@ const ValGraph& LoopPromotionMapBuilder::idGraph(IdMappingMode mode) const { return id_model_.idGraph(mode); } -void LoopPromotionMapBuilder::build() { +std::unordered_map LoopPromotionMapBuilder::build() { // Make an intersection of the exact and loop map. This will group together // entries in each loop group that are exact with each other. This provides a // better graph to do promotion and replays. @@ -97,7 +97,7 @@ void LoopPromotionMapBuilder::build() { // Step 5: Find the final promotion of each loop group based on the // final IEL promotion map - loop_promotion_map_ = projectIELPromotionToLoopGraph( + auto final_loop_promotion_map = projectIELPromotionToLoopGraph( iel_graph, final_iel_promotion_map, idGraph(IdMappingMode::LOOP), @@ -146,10 +146,12 @@ void LoopPromotionMapBuilder::build() { // Insert the updated Step-3 results into the Step-5 resutls. Note // that this insertion does not overwrite the existing mappings. - loop_promotion_map_.insert( + final_loop_promotion_map.insert( initial_loop_promotion_map.begin(), initial_loop_promotion_map.end()); - sanityCheckLoopPromotionMap(loop_promotion_map_); + sanityCheckLoopPromotionMap(final_loop_promotion_map); + + return final_loop_promotion_map; } std::unordered_map LoopPromotionMapBuilder:: @@ -781,8 +783,7 @@ std::unordered_map LoopPromotionMapBuilder::get( IdModel& id_model, const StatefulInliningInfo& inlining_info) { LoopPromotionMapBuilder builder(id_model, inlining_info); - builder.build(); - return builder.loop_promotion_map_; + return builder.build(); } } // namespace nvfuser diff --git a/csrc/id_model/loop_promotion.h b/csrc/id_model/loop_promotion.h index bbde115a623..a16b8220fb9 100644 --- a/csrc/id_model/loop_promotion.h +++ b/csrc/id_model/loop_promotion.h @@ -28,7 +28,7 @@ class LoopPromotionMapBuilder { IdModel& id_model, const StatefulInliningInfo& inlining_info); - void build(); + std::unordered_map build(); ValGraph& idGraph(IdMappingMode mode); const ValGraph& idGraph(IdMappingMode mode) const; @@ -117,7 +117,6 @@ class LoopPromotionMapBuilder { private: IdModel& id_model_; const StatefulInliningInfo& inlining_info_; - std::unordered_map loop_promotion_map_; }; } // namespace nvfuser From 24a0c8ce33f9be2278f60ffb6d6705374bac759e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 13:25:13 -0700 Subject: [PATCH 173/178] Remove loop promotion code from IdModel --- csrc/id_model/id_model.cpp | 751 ------------------------------- csrc/id_model/id_model.h | 85 ---- csrc/id_model/loop_promotion.cpp | 6 +- 3 files changed, 3 insertions(+), 839 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index d3d8446f96e..b436e1fc1f1 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -583,263 +583,6 @@ void IdModel::buildLoopGraph() { idGraph(IdMappingMode::LOOP).validateConsistency(); } -std::unordered_map IdModel::buildLoopPromotionMap( - const StatefulInliningInfo& inlining_info) { - // Make an intersection of the exact and loop map. This will group together - // entries in each loop group that are exact with each other. This provides a - // better graph to do promotion and replays. - // - // It's tempting to use the intersection of the almost exact and loop, but we - // need to model broadcast promotion, and if we have two tensors like: - // - // T1[i0, b1] = T0[i0] - // T2[i0, b2] = T0[i0] - // Then resolution of: - // T4 = T1[i0, b1] + T3[i0, i1] - // T6 = T2[i0, b2] + T5[i0, i2] - // - // Then merge(0, 1) with all tensors except for T0 - // - // The almost exact map will map i0, i0*b1, and i0*b2 together, but b1 and b2 - // are being resolved to i1 and i2 respectively. So we want to have separate - // entries so we can have an easy to process promotion map. - // - // Loop is a permissive like map, it could have many entries, use the exact - // map as the one we iterate on to reduce complexity as it hopefully has - // smaller groups and this algorithm scales with the number of groups * - // (number of entries in groups ^ 2) - // - // iel stands for Intersection of the Exact and Loop graphs. - ValGraph iel_graph = buildIntersection( - idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); - - // Step 1: Build a map of the IEL groups of root broadcast domains - // to resolving domains. - std::unordered_map iel_promotion_map = - buildInlineRootResolutionMap(iel_graph, inlining_info); - - // Step 2: Propagate the root promotions to intermediate and leaf groups. - // At this point, the promotion may not be final as the analysis is - // localized to IEL groups. The map is used in the next step to - // build mappings of the loop groups. - propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); - - // Step 3: Determine the promotion of each loop graph based on the - // IEL promotion map. For each loop group, examine all the IEL - // promotions and find the most representative one that captures all - // the dependent input domains of the loop group - std::unordered_map loop_promotion_map = - projectIELPromotionToLoopGraph( - iel_graph, - iel_promotion_map, - idGraph(IdMappingMode::LOOP), - inlining_info); - - // At this point, most of loop groups should have correct promoted - // IDs. However, non-inlined loop groups may miss promotion that - // should be propagated from parent ID groups, e.g., iS50 of T2 in - // Indexing19. Its parent ID loop group is promoted, but the loop - // group of iS50 is not found yet. - - // Step 4: In order to fully propagate the loop graph promotions, first - // propagate them to the IEL groups, which are then used to - // propagate back to the loop groups in Step 5. Unlike Step 2, the - // initial IEL promotion map is empty and is populated with the loop - // promotion map as we traverse down the IEL graph. - std::unordered_map final_iel_promotion_map; - propagatePromotionsInIELGraph( - iel_graph, - final_iel_promotion_map, - idGraph(IdMappingMode::LOOP), - loop_promotion_map); - - // Step 5: Find the final promotion of each loop group based on the - // final IEL promotion map - auto final_loop_promotion_map = projectIELPromotionToLoopGraph( - iel_graph, - final_iel_promotion_map, - idGraph(IdMappingMode::LOOP), - inlining_info); - - // The promotion map produced in Step 5 only includes those are - // further propagated at Step 4, so the correct mappings produced at - // Step 3 may not be included in the Step-5 results. Any Step-3 mappings - // that are not found in the Step-5 results are already valid - // results, so merge them into the Step-5 results. - // - // For example, in the below case, nothing will be propated at Step - // 4. - // - // t0: [i0] - // t1: [i1, i2] - // t2 = broadcast(t0, {true, false}) - // t3 = t2 + t1 - // - // t2: [b3, i4] - // t3: [i5, i6] - // - // t3->merge(0) - // propagate-and-inline-most - // - // t0: [i0] ca_pos(1) - // t1: [i1*i2] ca_pos(1) - // t2: [b3*i4] ca_pos(1) - // t3: [i5*i6] - // - // In this case, all domains will be grouped together and there will - // be just a single group in the Loop graph: - // - // - {i0, i1, i2, b3, i4, i5, i6, i1*i2, b3*i4, i5*i6} - // - // Step 3 will identify i5*i6 is the promotion domain. Since all - // domains are promoted to i5*i6, there will be no propagation in - // Step 4 (i.e., loop_promote_inputs will be false). Since the - // result of Step 4 is empty, the Step 5 result will also be empty, - // but that just means there's no change is necessary from the Step - // 3 results. - - // Update the Step-3 map to the latest LOOP graph - loop_promotion_map = - updateValGroupIdMap(loop_promotion_map, idGraph(IdMappingMode::LOOP)); - - // Insert the updated Step-3 results into the Step-5 resutls. Note - // that this insertion does not overwrite the existing mappings. - final_loop_promotion_map.insert( - loop_promotion_map.begin(), loop_promotion_map.end()); - - sanityCheckLoopPromotionMap(final_loop_promotion_map); - - return final_loop_promotion_map; -} - -std::unordered_map IdModel::buildInlineRootResolutionMap( - const ValGraph& iel_graph, - const StatefulInliningInfo& info) const { - std::unordered_map iel_promotion_map; - - // This should probably work just on terminating inputs, as we shouldn't be - // able to modify a broadcast domain between root and rfactor which would be - // required to resolve a non input broadcast domain. But for now leaving it as - // traversal on all broadcast groups. - // - - // We first visit all broadcast root domains. If a broadcast is - // resovled, see if it's promoted. Note that a domain be resolved to - // a domain that may not be loop mapped, yet it can still be - // promoted. In other words, there can be a domain that is exactly - // mapped with the resolving domain *and* is mapped with the - // broadcast domain by the loop map. The algorihm here is: - // - // 1. For a broadcast domain, find the domain that the broadcast is - // resolved to. - // 2. If the resolving domain is also loop-mapped with the - // broadcast, that is the promotion domain, but the resolving - // domain may not be loop mapped as mentioned above. Instead, - // find all loop-mapped domains with the broadcast domain and - // pick one that is exactly mapped with the resolving domain - // - // Note again this process is only done for root domains. Once we - // find promotion relationships for root domains, we propagate the - // mappings to derived domains - for (const ValGroup& iel_group : iel_graph.disjointValSets().disjointSets()) { - NVF_ERROR(!iel_group->empty()); - - IterDomain* iel_group_id = iel_group->front()->as(); - - if (!iel_group_id->isBroadcast()) { - continue; - } - - // Collect all the exact groups of the resolutions of the broadcast id's - ValGroups resolved_exact_groups; - for (Val* bcast_id : *iel_group) { - if (auto p2c_root_broadcast_resolution_map_it = - info.p2c_root_broadcast_resolution_map.find( - bcast_id->as()); - p2c_root_broadcast_resolution_map_it != - info.p2c_root_broadcast_resolution_map.end()) { - resolved_exact_groups.pushBack( - idGraph(IdMappingMode::EXACT) - .toGroups(p2c_root_broadcast_resolution_map_it->second)); - } - } - - if (resolved_exact_groups.empty()) { - // No resolution - continue; - } - - // resolved_exact_groups is a list of IDs that resolves the - // broadcast. We only care those that are also in the same loop - // group, and there must be just one or none. Otherwise, the - // resolution is ambiguous. - - // Collect all the exact groups in the loop set containing this iel_group - const ValGroup& loop_group = - idGraph(IdMappingMode::LOOP).toGroup(iel_group_id); - ValGroups loop_covered_exact_groups = - idGraph(IdMappingMode::EXACT).toGroups(*loop_group); - - // The intersection of the exact groups that the broadcast domains can be - // broadcasted to, and those that exist within the same loop groop are is - // the promotion needed for this iel_group. The promotion should - // be none or unique. - ValGroups loop_exact_resolved_intersection = - resolved_exact_groups.computeIntersect(loop_covered_exact_groups); - - if (loop_exact_resolved_intersection.empty()) { - // No promotion - continue; - } - - if (loop_exact_resolved_intersection.size() > 1) { - // Ambiguous promotion. This should not happen. - std::stringstream err_msg; - err_msg - << "Invalid multiple broadcast resolution within shared loops detected, group:\n " - << iel_group->toString() << "\nIs being broadcasted to:"; - for (const ValGroup& entry : loop_exact_resolved_intersection) { - err_msg << "\n " << entry->toString(); - } - NVF_ERROR(false, err_msg.str()); - } - - const ValGroup& exact_resolution_group = - loop_exact_resolved_intersection.front(); - - // Within the loop group, find the IDs that the broadcast IDs are - // resolved to - VectorOfUniqueEntries resolved_ids = - exact_resolution_group->computeIntersect(*loop_group); - - NVF_ERROR(!resolved_ids.empty()); - - // All the IDs in resolved_ids are mapped with both of the exact - // and loop graphs, so any of them can be used as an IEL promotion - // ID. Just to make it extra clear, look for corresponding - // groups in the IEL graph and make sure there's only one such group. - ValGroups promoted_iel_groups = iel_graph.toGroups(resolved_ids); - - NVF_ERROR(!promoted_iel_groups.empty()); - - if (promoted_iel_groups.size() > 1) { - std::stringstream err_msg; - err_msg - << "Invalid multiple broadcast resolution within shared loops detected, group:\n " - << iel_group->toString() << "\nIs being broadcasted to:"; - for (const ValGroup& entry : promoted_iel_groups) { - err_msg << "\n " << entry->toString(); - } - NVF_ERROR(false, err_msg.str()); - } - - iel_promotion_map[iel_group] = - promoted_iel_groups.front()->front()->as(); - } - - return iel_promotion_map; -} - void IdModel::buildAllGraphs() { if (tvs_.empty()) { return; @@ -935,266 +678,6 @@ ValGraph IdModel::buildIntersection( return intersection; } -namespace { - -// Check if there's an equivalent expression as iel_expr that uses -// maybe_promoted_inputs. This is used to avoid redundantly replaying -// expressions. -// NOTE: This is currently overly conservative and some -// opportunities for reuse are lost, althought it doesn't affect -// the correctness of the analysis. -Expr* findMatchingExpr( - const ExprGroup& iel_expr, - const ValGraph& iel_graph, - const std::vector& maybe_promoted_inputs, - const ValGraph& loop_graph) { - // If any of domains in maybe_promoted_inputs is not found in - // iel_graph, it means the domain is just replayed and by definition - // has no mapping with any existing domain, which means there's no - // matching expr. - if (std::any_of( - maybe_promoted_inputs.begin(), - maybe_promoted_inputs.end(), - [&](IterDomain* maybe_promoted_input) -> bool { - return !iel_graph.hasGroup(maybe_promoted_input); - })) { - return nullptr; - } - - // Grab all eligible uses of the promoted inputs. - // Note that any eligible matching expr should be a use of all - // inputs in maybe_promoted_input_uses, no matter it's promoted or - // not. So it isn't necessary to look at all of - // maybe_promoted_input_uses but just need to grab one. - NVF_ERROR(!maybe_promoted_inputs.empty()); - ExprGroups maybe_promoted_input_uses = - iel_graph.getUses(iel_graph.toGroup(maybe_promoted_inputs.front())); - - if (maybe_promoted_input_uses.empty()) { - return nullptr; - } - - // Look for exprs that have inputs that are mapped in the IEL - // graph with the (promoted) inputs of iel_expr. - for (const ExprGroup& maybe_promoted_input_use_group : - maybe_promoted_input_uses) { - NVF_ERROR(!maybe_promoted_input_use_group->empty()); - // maybe_promoted_inputs may include non-promoted inputs as - // well, so maybe_promoted_input_uses may include the original - // iel_expr itself. Since there must at least be a promoted input, - // iel_expr itself should not be an expr group we are looking for. - if (iel_expr == maybe_promoted_input_use_group) { - continue; - } - Expr* maybe_promoted_input_use = maybe_promoted_input_use_group->front(); - if (!iel_expr->front()->sameOp(maybe_promoted_input_use)) { - continue; - } - // Check if all inputs are mapped - NVF_ERROR( - maybe_promoted_inputs.size() == - maybe_promoted_input_use->inputs().size()); - bool all_inputs_match = true; - for (const auto inp_i : c10::irange(maybe_promoted_inputs.size())) { - all_inputs_match = all_inputs_match && - iel_graph.disjointValSets().strictAreMapped( - maybe_promoted_inputs[inp_i], - maybe_promoted_input_use->inputs().at(inp_i)); - } - if (!all_inputs_match) { - continue; - } - - // We always want to find promotions within the same loop - // groups since we are looking for domains that represent actual - // loops. Note that that's guaranteed when a new domain is - // replayed instead of reusing an existing domain. - if (!loop_graph.disjointExprSets().permissiveAreMapped( - iel_expr->front(), maybe_promoted_input_use_group->front())) { - continue; - } - // This is just an extra sanity check. Make sure all exprs in - // the use group are mapped - NVF_ERROR( - std::all_of( - maybe_promoted_input_use_group->vector().begin(), - maybe_promoted_input_use_group->vector().end(), - [&](Expr* iel_use) { - return loop_graph.disjointExprSets().permissiveAreMapped( - iel_expr->front(), iel_use); - }), - "Not all mapped: ", - nvfuser::toString(iel_expr), - "\n", - nvfuser::toString(maybe_promoted_input_use_group)); - - return maybe_promoted_input_use; - } - - return nullptr; -} - -// When propagating loop promotions from inputs to outputs of an IEL -// expr, we can't blindly apply loop promotion when all of the input -// domains are loop mapped with the outputs. -// -// i.e. if we have the inlined domains from: -// Inputs: -// T0[i0] -// T1[i0, i1] -// -// T2[i0, b2] = broadcast(T0) -// T3[i0, i1] = T2 + T1 -// -// {T1, T2, T3}->merge(0, 1) -// inlineMost -// -// The inlined loop group would consist of: -// -// {i0, i1, b2, i0*b2, i0*i1} -// -// Note that all these domains would have promotion to i0*i1 at the -// end of Step 3. When the IEL expression of merge(i0, i1) is visited by -// propagatePromotionsInIELGraph again, the promotion to i0*i1 of both -// inputs would be propagated to its output, resulting in promotion of -// i0*i1 to (i0*i1)*(i0*i1), which is not the correct propagation. -// -// Therefore only promote i0*b1 to i0*i1, or i0*i1 to i0*i1 (i.e. don't -// promote an input to any transformation within the loop group). -// -// So if we have an iel_expr make sure its inputs and outputs are not in -// the same loop group. -bool hasUniqueInputLoopGroups( - const ExprGroup& iel_expr, - const ValGraph& iel_graph, - const ValGraph& loop_graph) { - const std::vector iel_inp_groups = iel_graph.inputGroups(iel_expr); - - const std::vector iel_out_groups = iel_graph.outputGroups(iel_expr); - - ValGroups inp_loop_groups; - for (const ValGroup& iel_inp_group : iel_inp_groups) { - inp_loop_groups.pushBack(loop_graph.toGroup(iel_inp_group->front())); - } - ValGroups out_loop_groups; - for (const ValGroup& iel_out_group : iel_out_groups) { - out_loop_groups.pushBack(loop_graph.toGroup(iel_out_group->front())); - } - - // Check if input groups that are not included in the output group set - return !inp_loop_groups.computeSubtract(out_loop_groups).empty(); -} - -} // namespace - -void IdModel::propagatePromotionsInIELGraph( - const ValGraph& iel_graph, - std::unordered_map& iel_promotion_map, - const ValGraph& loop_graph, - const std::unordered_map& loop_graph_promotion_map) { - // In order to make this traversal work, the traversal order must be - // topologically sorted. - ValGraphStmtSort iel_stmt_sort(iel_graph); - - for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) { - NVF_ERROR(!iel_expr->empty()); - const std::vector iel_inp_groups = - iel_graph.inputGroups(iel_expr); - - // Check if any inputs need promotion indicating this expr group needs to - // be replayed with promoted inputs - bool an_input_was_promoted = false; - std::vector maybe_promoted_inputs; - maybe_promoted_inputs.reserve(iel_inp_groups.size()); - - // Propagate loop graph promotion only when the inputs and outputs are - // not in the same loop group. - const bool loop_promote_inputs = !loop_graph_promotion_map.empty() && - hasUniqueInputLoopGroups(iel_expr, iel_graph, loop_graph); - - for (const ValGroup& iel_inp_group : iel_inp_groups) { - // Assumed all inputs are IterDomains - NVF_ERROR(iel_inp_group->front()->isA()); - - // Propagate IEL promotions when available. - if (auto inp_promo_it = iel_promotion_map.find(iel_inp_group); - inp_promo_it != iel_promotion_map.end()) { - maybe_promoted_inputs.push_back(inp_promo_it->second); - an_input_was_promoted = true; - continue; - } - - // Promote loops based on the loop promotion map. If the loop promotion - // map should be used and has an entry we should use that promotion. - if (loop_promote_inputs) { - const ValGroup& loop_copy_group = - loop_graph.toGroup(iel_inp_group->front()); - auto inp_loop_promo_it = loop_graph_promotion_map.find(loop_copy_group); - if (inp_loop_promo_it != loop_graph_promotion_map.end()) { - maybe_promoted_inputs.push_back(inp_loop_promo_it->second); - an_input_was_promoted = true; - continue; - } - } - - // No promotion found. Just use the non-promoted domain - maybe_promoted_inputs.push_back(iel_inp_group->front()->as()); - } - - if (!an_input_was_promoted) { - // No inputs need promotion so just continue - continue; - } - - Expr* promoted_expr = findMatchingExpr( - iel_expr, - iel_graph, - maybe_promoted_inputs, - idGraph(IdMappingMode::LOOP)); - - bool replayed = false; - - if (!promoted_expr) { - promoted_expr = addReplayAs(maybe_promoted_inputs, iel_expr->front()); - replayed = true; - } - - // Mark outputs as having a promoted iter domain - std::vector out_groups = iel_graph.outputGroups(iel_expr); - NVF_ERROR(promoted_expr->outputs().size() == out_groups.size()); - NVF_ERROR( - ir_utils::filterByType(promoted_expr->outputs()).size() == - out_groups.size(), - "Unexpected non IterDomain outputs found: ", - promoted_expr->toString()); - - for (const auto i : c10::irange(out_groups.size())) { - // Promote if necessary, if the output is already in the same exact map - // it doesn't need a promotion. - if (idGraph(IdMappingMode::EXACT) - .disjointValSets() - .strictAreMapped( - promoted_expr->output(i), out_groups[i]->front())) { - continue; - } - iel_promotion_map[out_groups[i]] = - promoted_expr->output(i)->as(); - // Explicitly map loop map since expr propagation doesn't happen - if (replayed) { - idGraph(IdMappingMode::LOOP) - .mapVals(iel_expr->front()->output(i), promoted_expr->output(i)); - } - } - } -} - -void IdModel::propagatePromotionsInIELGraph( - const ValGraph& iel_graph, - std::unordered_map& iel_promotion_map) { - propagatePromotionsInIELGraph( - iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), {}); -} - // Replay Expr but with the inputs provided. Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { // Figure out which graphs are already initialized to make sure we add the new @@ -1311,240 +794,6 @@ Expr* IdModel::addReplayAs(std::vector new_inputs, Expr* expr) { return replay; } -namespace { - -// Returns for each ValGroup in provided IdGraph what the input ValGroups are -// traversing on definitions. Ignoring broadcast ValGroups and resetting inputs -// at RFactor ValGroups. -std::unordered_map computeCoveredGroups( - const ValGraph& graph, - const std::unordered_set& view_rfactor_ids) { - // Map from an exact iter domain group, to all the exact iter domain groups it - // covers - std::unordered_map covered_ids; - - for (const ValGroup& id_group : graph.disjointValSets().disjointSets()) { - // Initialize inputs - const ExprGroups& id_group_defs = graph.getDefinitions(id_group); - if (id_group_defs.empty()) { - covered_ids[id_group] = {id_group}; - } - - // Initialize rfactor groups - if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) { - return view_rfactor_ids.find(id->as()) != - view_rfactor_ids.end(); - })) { - covered_ids[id_group] = {id_group}; - } - - // Initialize broadcast groups to empty since broadcast domains - // don't matter for indexing - if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) { - return id->as()->isBroadcast(); - })) { - covered_ids[id_group] = {}; - } - } - - ValGraphStmtSort exact_stmt_sort(graph); - - for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) { - std::vector input_groups = graph.inputGroups(exact_expr); - - ValGroups covered; - for (const ValGroup& inp_group : input_groups) { - covered.pushBack(covered_ids.at(inp_group)); - } - - for (const ValGroup& output_group : graph.outputGroups(exact_expr)) { - // Don't overwrite initialized cases due to rfactor markings. - if (covered_ids.find(output_group) == covered_ids.end()) { - covered_ids[output_group] = covered; - } - } - } - - return covered_ids; -} - -}; // namespace - -std::unordered_map IdModel:: - projectIELPromotionToLoopGraph( - const ValGraph& iel_graph, - const std::unordered_map& iel_promotion_map, - const ValGraph& loop_graph, - const StatefulInliningInfo& inlining_info) { - const std::unordered_map exact_covered_ids = - computeCoveredGroups(idGraph(IdMappingMode::EXACT), view_rfactor_ids_); - - // Grab terminal iter domain in the loop groups. - const VectorOfUniqueEntries terminal_loop_ids = - computeTerminalLoopIds(inlining_info); - - std::unordered_map loop_promotion_map; - - for (const ValGroup& loop_group : - loop_graph.disjointValSets().disjointSets()) { - IterDomain* promotion_id = findPromotionOfLoopGroup( - loop_group, - iel_graph, - iel_promotion_map, - exact_covered_ids, - terminal_loop_ids); - if (promotion_id) { - loop_promotion_map[loop_group] = promotion_id; - } - } - - return loop_promotion_map; -} - -IterDomain* IdModel::findPromotionOfLoopGroup( - const ValGroup& loop_group, - const ValGraph& iel_graph, - const std::unordered_map& iel_promotion_map, - const std::unordered_map& exact_covered_ids, - const VectorOfUniqueEntries& terminal_loop_ids) { - const ValGraph& exact_graph = idGraph(IdMappingMode::EXACT); - - // Grab all the (potentially promoted) terminal iter domains in this group. - // Save the exact group and the iter domain in this vector. - std::vector> exact_promoted_terminal_ids; - for (auto loop_id : *loop_group) { - // If not a terminal id in the group skip - if (!terminal_loop_ids.has(loop_id->as())) { - continue; - } - - // Grab the iel entry. There can be iter domains that were added - // after the IEL graph was built. All the promotion information is - // associated with the domains that exist in the original graph, - // so the new domains can be simply ignored. - if (!iel_graph.hasGroup(loop_id)) { - continue; - } - - const ValGroup& iel_group = iel_graph.toGroup(loop_id); - - // Does it still need iel_promotion_map? The loop group already has - // the replayed domains, so we should be able to find it. - auto iel_promo_it = iel_promotion_map.find(iel_group); - if (iel_promo_it == iel_promotion_map.end()) { - // If this terminal ID doesn't have a promotion associated with it, save - // the terminal ID. - exact_promoted_terminal_ids.emplace_back( - exact_graph.toGroup(loop_id), loop_id->as()); - } else { - // If this terminal ID has a promotion, grab the promoted ID. - exact_promoted_terminal_ids.emplace_back( - exact_graph.toGroup(iel_promo_it->second), iel_promo_it->second); - } - } - - // All the exact groups of the iter domains in the loop group - ValGroups exact_groups = exact_graph.toGroups(*loop_group); - - // All exact groups covered by all iter domains in this loop group - ValGroups loop_group_covered_ids; - for (const ValGroup& exact_group : exact_groups) { - auto covered_it = exact_covered_ids.find(exact_group); - NVF_ERROR(covered_it != exact_covered_ids.end()); - loop_group_covered_ids.pushBack(covered_it->second); - } - - // Check if any of the candidate Iter Domains we collected cover all the - // exact groups of loop_group_covered_ids. If so, that's the correct - // promoted iter domain of this group. - for (const auto& entry : exact_promoted_terminal_ids) { - const ValGroup& terminal_id_group = entry.first; - IterDomain* terminal_id = entry.second; - auto covered_it = exact_covered_ids.find(terminal_id_group); - NVF_ERROR(covered_it != exact_covered_ids.end()); - if (loop_group_covered_ids.computeSubtract(covered_it->second).empty()) { - return terminal_id; - } - } - - return nullptr; -} - -VectorOfUniqueEntries IdModel::computeTerminalLoopIds( - const StatefulInliningInfo& info) { - VectorOfUniqueEntries terminal_loop_ids; - for (const ValGroup& group : - idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { - if (group->size() == 1) { - terminal_loop_ids.pushBack(group->front()->as()); - } - - // Don't select producer iter domains - for (auto loop_id : *group) { - if (info.p2c_ca_permissive_maps.find(loop_id->as()) != - info.p2c_ca_permissive_maps.end()) { - continue; - } - - // It's terminal if there's no use group - auto uses_it = id_uses_.find(loop_id->as()); - if (uses_it == id_uses_.end() || uses_it->second.empty()) { - terminal_loop_ids.pushBack(loop_id->as()); - continue; - } - - // If there's an output group that is not in the same group, - // then it's a terminal ID - bool all_outs_in_loop_group = true; - for (auto use : uses_it->second) { - if (std::any_of( - use->outputs().begin(), - use->outputs().end(), - [&](Val* out) -> bool { - return group != idGraph(IdMappingMode::LOOP).toGroup(out); - })) { - all_outs_in_loop_group = false; - break; - } - } - - if (!all_outs_in_loop_group) { - terminal_loop_ids.pushBack(loop_id->as()); - } - } - } - return terminal_loop_ids; -} - -void IdModel::sanityCheckLoopPromotionMap( - const std::unordered_map& loop_promotion_map) const { - const auto& loop_graph = idGraph(IdMappingMode::LOOP); - for (const ValGroup& loop_group : - loop_graph.disjointValSets().disjointSets()) { - // Non-leaf loop groups are not guaranteed to have valid - // promotions. See for example FusionRepro1713, where root domains - // are all grouped together but there's no valid promotion. - if (loop_graph.hasUses(loop_group)) { - continue; - } - // Make sure the loop group is promoted to a domain that is mapped - // in the LOOP graph - auto promotion_it = loop_promotion_map.find(loop_group); - NVF_ERROR( - promotion_it != loop_promotion_map.end(), - "Loop promotion not found for ", - nvfuser::toString(loop_group)); - IterDomain* promotion = promotion_it->second; - // Make sure the promotion domain is also loop-mapped - NVF_ERROR( - loop_group->has(promotion), - "Loop promotion not loop-mapped. Loop group: ", - nvfuser::toString(loop_group), - ". Promotion domain: ", - promotion->name()); - } -} - void IdModel::validateLoopGraphHasNoSelfMappedLeafDomains() const { for (auto tv : tvs_) { auto self_mappped_leaf_pair = diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index c58c17531b0..64db01c0064 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -20,7 +20,6 @@ namespace nvfuser { class ValGraph; -class LoopPromotionMapBuilder; struct StatefulInliningInfo { // All producer ids within (including dependencies of) inlined leaf domains, @@ -206,93 +205,9 @@ class IdModel : public PolymorphicBase { std::unordered_map buildLoopPromotionMap( const StatefulInliningInfo& info); - // Helper function for buildLoopPromotionMap. Returns a map of - // root broadcast ValGroups in the IEL graph to a representative - // IterDomain picked from its IEL group. - std::unordered_map buildInlineRootResolutionMap( - const ValGraph& iel_graph, - const StatefulInliningInfo& info) const; - - // Helper function for building loop promotion map. - // - // Propagate promotion mappings from root IEL groups to intermediate - // and leaf IEL groups by traversing IEL exprs. For each expr, if an - // input is promoted, the output needs to be promoted too. If - // there's already an equivalent expr that uses the promoted inputs, - // create a mapping from the outputs of the IEL expr to the outputs - // of the equivalent expr. We only consider exprs that are mapped - // in the loop graph as we are looking for domains that represent - // the actual loops of the input and output domains of the IEL - // expr. If no such expr is found, the IEL expr is replayed with the - // promoted inputs. - // - // This is used twice when building the promotion map. The first time - // it is used there's no loop graph promotion yet, so only the IEL - // promotions are propagated. In that case, loop_graph_promotion_map - // should be just empty. - // - // Propagation uses iel_promotion_map and - // loop_graph_promotion_map. If both are available for an IEL group, - // the former has the precedence. This is because when this function - // is used for step 4, the given iel_promotion_map starts as an - // empty map and gets populated during this propagation, so any - // mapping in the map is guaranteed to be the correct final mapping, - // whereas the loop graph may have invalid mappings for partially - // inlined domains. - void propagatePromotionsInIELGraph( - const ValGraph& iel_graph, - std::unordered_map& iel_promotion_map, - const ValGraph& loop_graph, - const std::unordered_map& loop_promotion_map); - - // Same as the other propagatePromotionsInIELGraph but without loop - // graph map. This is used for step 2, where there's no loop - // graph map yet. - void propagatePromotionsInIELGraph( - const ValGraph& iel_graph, - std::unordered_map& iel_promotion_map); - - // Given an IEL promotion map, identify the mapping of each loop - // group. The promotion must represent all the domains in each loop - // group. If a valid representative promotion is not found for a - // loop group, no mapping is added for the group. - std::unordered_map projectIELPromotionToLoopGraph( - const ValGraph& iel_graph, - const std::unordered_map& iel_promotion_map, - const ValGraph& loop_graph, - const StatefulInliningInfo& inlining_info); - - // Find a promoted iter domain of a given loop group that covers all - // the exact groups representative of the resolved transformations - // within the loop group. Specifically, we examine each IEL group of - // the loop group, and if an IEL group has a promotion, we consider it as a - // candidate of the promotion of this loop group. If not, we include a - // domain of the IEL group as a candidate too. Once all candidates are - // obtained, we pick one that covers all the exact domains (cf. concrete - // domains in ComputeAtMap) - IterDomain* findPromotionOfLoopGroup( - const ValGroup& loop_group, - const ValGraph& iel_graph, - const std::unordered_map& iel_promotion_map, - const std::unordered_map& exact_covered_ids, - const VectorOfUniqueEntries& terminal_loop_ids); - - // Terminal loop ids are iteration domains in each loop group that: - // 1) Don't have an entry in p2c_ca_permissive_maps, which would mean a - // consumer TV's iter domain maps to this domain in a way that that domain - // is also in the same loop group - // 2) Don't have a direct IterDomain consumer within the group - VectorOfUniqueEntries computeTerminalLoopIds( - const StatefulInliningInfo& info); - // Errors if self mapping occurs void assertNoSelfMapping(); - // Basic consistency check of the given loop promotion map - void sanityCheckLoopPromotionMap( - const std::unordered_map& loop_promotion_map) - const; - // Loop graph represents the loop structure of the given fusion, so // there must not be any mapping between the leaf domains of each // tensor. diff --git a/csrc/id_model/loop_promotion.cpp b/csrc/id_model/loop_promotion.cpp index 6ea37b7f5f6..fe9b494371d 100644 --- a/csrc/id_model/loop_promotion.cpp +++ b/csrc/id_model/loop_promotion.cpp @@ -70,7 +70,7 @@ std::unordered_map LoopPromotionMapBuilder::build() { // IEL promotion map. For each loop group, examine all the IEL // promotions and find the most representative one that captures all // the dependent input domains of the loop group - std::unordered_map initial_loop_promotion_map = + const std::unordered_map initial_loop_promotion_map = projectIELPromotionToLoopGraph( iel_graph, iel_promotion_map, @@ -141,13 +141,13 @@ std::unordered_map LoopPromotionMapBuilder::build() { // 3 results. // Update the Step-3 map to the latest LOOP graph - initial_loop_promotion_map = updateValGroupIdMap( + const auto updated_initial_loop_promotion_map = updateValGroupIdMap( initial_loop_promotion_map, idGraph(IdMappingMode::LOOP)); // Insert the updated Step-3 results into the Step-5 resutls. Note // that this insertion does not overwrite the existing mappings. final_loop_promotion_map.insert( - initial_loop_promotion_map.begin(), initial_loop_promotion_map.end()); + updated_initial_loop_promotion_map.begin(), updated_initial_loop_promotion_map.end()); sanityCheckLoopPromotionMap(final_loop_promotion_map); From c7c04b16a89c32c6ee60fd8805986e4398f42bf5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 14:53:57 -0700 Subject: [PATCH 174/178] replace tester with callback --- csrc/id_model/id_model.cpp | 14 ++-- csrc/id_model/id_model.h | 11 +++- csrc/id_model/loop_promotion.cpp | 32 +++++++-- csrc/id_model/loop_promotion.h | 23 ++++++- tests/cpp/test_id_model.cpp | 107 +++++++++++++------------------ 5 files changed, 112 insertions(+), 75 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index b436e1fc1f1..3eacaea16f2 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -75,8 +75,9 @@ IdModel::IdModel( const std::vector& exprs, const std::vector& additional_tvs, bool build_graphs, - bool allow_self_mapping) - : allow_self_mapping_(allow_self_mapping) { + bool allow_self_mapping, + LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback) + : allow_self_mapping_(allow_self_mapping), loop_promotion_map_builder_callback_(loop_promotion_map_builder_callback) { std::copy_if( exprs.begin(), exprs.end(), @@ -103,8 +104,11 @@ IdModel::IdModel( Fusion* fusion, bool build_graphs, bool allow_self_mapping, - bool validate) - : allow_self_mapping_(allow_self_mapping), validate_(validate) { + bool validate, + LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback) + : allow_self_mapping_(allow_self_mapping), + validate_(validate), + loop_promotion_map_builder_callback_(loop_promotion_map_builder_callback) { auto all_exprs = fusion->exprs(); std::copy_if( all_exprs.begin(), @@ -574,7 +578,7 @@ void IdModel::buildLoopGraph() { validateLoopGraphHasNoSelfMappedLeafDomains(); - loop_promotion_map_ = LoopPromotionMapBuilder::get(*this, inlining_info); + loop_promotion_map_ = LoopPromotionMapBuilder::get(*this, inlining_info, loop_promotion_map_builder_callback_); // New domains are added. Make sure there's still no self mapping in // the leaf domains diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 64db01c0064..a2565dd94e4 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -20,6 +20,7 @@ namespace nvfuser { class ValGraph; +class LoopPromotionMapBuilderCallback; struct StatefulInliningInfo { // All producer ids within (including dependencies of) inlined leaf domains, @@ -106,7 +107,8 @@ class IdModel : public PolymorphicBase { const std::vector& exprs, const std::vector& additional_tvs = {}, bool build_graphs = true, - bool allow_self_mapping = false); + bool allow_self_mapping = false, + LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback = nullptr); // Same as the above constructor with fusion->exprs() excpet fusion may have // some dangling inputs/outputs that are expected to have IterDomain entries @@ -118,7 +120,8 @@ class IdModel : public PolymorphicBase { Fusion* fusion, bool build_graphs = true, bool allow_self_mapping = false, - bool validate = true); + bool validate = true, + LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback = nullptr); // Returns iter domain graph of provided mode. The graph must have // been already built. @@ -227,6 +230,10 @@ class IdModel : public PolymorphicBase { // If true, validate graphs by comparing them with ComputeAtMap bool validate_ = false; + // Optional callback for the loop promotion map builder for + // debugging and testing + LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback_ = nullptr; + // By default, the permissive graph should map compliment domains as // well. See the design doc for more details bool permissive_graph_map_compliment_ids_ = true; diff --git a/csrc/id_model/loop_promotion.cpp b/csrc/id_model/loop_promotion.cpp index fe9b494371d..7fb16ce4fcd 100644 --- a/csrc/id_model/loop_promotion.cpp +++ b/csrc/id_model/loop_promotion.cpp @@ -15,8 +15,9 @@ namespace nvfuser { LoopPromotionMapBuilder::LoopPromotionMapBuilder( IdModel& id_model, - const StatefulInliningInfo& inlining_info) - : id_model_(id_model), inlining_info_(inlining_info) {} + const StatefulInliningInfo& inlining_info, + LoopPromotionMapBuilderCallback* callback) + : id_model_(id_model), inlining_info_(inlining_info), callback_(callback) {} ValGraph& LoopPromotionMapBuilder::idGraph(IdMappingMode mode) { return id_model_.idGraph(mode); @@ -52,7 +53,7 @@ std::unordered_map LoopPromotionMapBuilder::build() { // (number of entries in groups ^ 2) // // iel stands for Intersection of the Exact and Loop graphs. - ValGraph iel_graph = id_model_.buildIntersection( + const ValGraph iel_graph = id_model_.buildIntersection( idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); // Step 1: Build a map of the IEL groups of root broadcast domains @@ -60,12 +61,20 @@ std::unordered_map LoopPromotionMapBuilder::build() { std::unordered_map iel_promotion_map = buildInlineRootResolutionMap(iel_graph, inlining_info_); + if (callback_) { + callback_->postStep1(iel_promotion_map, iel_graph); + } + // Step 2: Propagate the root promotions to intermediate and leaf groups. // At this point, the promotion may not be final as the analysis is // localized to IEL groups. The map is used in the next step to // build mappings of the loop groups. propagatePromotionsInIELGraph(iel_graph, iel_promotion_map); + if (callback_) { + callback_->postStep2(iel_promotion_map, iel_graph); + } + // Step 3: Determine the promotion of each loop graph based on the // IEL promotion map. For each loop group, examine all the IEL // promotions and find the most representative one that captures all @@ -77,6 +86,10 @@ std::unordered_map LoopPromotionMapBuilder::build() { idGraph(IdMappingMode::LOOP), inlining_info_); + if (callback_) { + callback_->postStep3(initial_loop_promotion_map); + } + // At this point, most of loop groups should have correct promoted // IDs. However, non-inlined loop groups may miss promotion that // should be propagated from parent ID groups, e.g., iS50 of T2 in @@ -95,6 +108,10 @@ std::unordered_map LoopPromotionMapBuilder::build() { idGraph(IdMappingMode::LOOP), initial_loop_promotion_map); + if (callback_) { + callback_->postStep4(final_iel_promotion_map, iel_graph); + } + // Step 5: Find the final promotion of each loop group based on the // final IEL promotion map auto final_loop_promotion_map = projectIELPromotionToLoopGraph( @@ -151,6 +168,10 @@ std::unordered_map LoopPromotionMapBuilder::build() { sanityCheckLoopPromotionMap(final_loop_promotion_map); + if (callback_) { + callback_->postStep5(final_loop_promotion_map); + } + return final_loop_promotion_map; } @@ -781,8 +802,9 @@ void LoopPromotionMapBuilder::sanityCheckLoopPromotionMap( std::unordered_map LoopPromotionMapBuilder::get( IdModel& id_model, - const StatefulInliningInfo& inlining_info) { - LoopPromotionMapBuilder builder(id_model, inlining_info); + const StatefulInliningInfo& inlining_info, + LoopPromotionMapBuilderCallback* callback) { + LoopPromotionMapBuilder builder(id_model, inlining_info, callback); return builder.build(); } diff --git a/csrc/id_model/loop_promotion.h b/csrc/id_model/loop_promotion.h index a16b8220fb9..30d3a9e5b6e 100644 --- a/csrc/id_model/loop_promotion.h +++ b/csrc/id_model/loop_promotion.h @@ -14,6 +14,22 @@ namespace nvfuser { class IdModel; struct StatefulInliningInfo; +class LoopPromotionMapBuilderCallback { + public: + virtual ~LoopPromotionMapBuilderCallback() = default; + + virtual void postStep1( + const std::unordered_map& iel_root_resolution_map, + const ValGraph& iel_graph) {} + virtual void postStep2( + const std::unordered_map& iel_promotion_map, + const ValGraph& iel_graph) {} + virtual void postStep3(const std::unordered_map& loop_promotion_map) {} + virtual void postStep4(const std::unordered_map& iel_promotion_map, + const ValGraph& iel_graph) {} + virtual void postStep5(const std::unordered_map& loop_promotion_map) {} +}; + class LoopPromotionMapBuilder { public: // Build a map of loop groups to IterDomains that represent actual @@ -21,12 +37,14 @@ class LoopPromotionMapBuilder { // root domains between inlined producer and consumer tensors. static std::unordered_map get( IdModel& id_model, - const StatefulInliningInfo& inlining_info); + const StatefulInliningInfo& inlining_info, + LoopPromotionMapBuilderCallback* callback = nullptr); private: LoopPromotionMapBuilder( IdModel& id_model, - const StatefulInliningInfo& inlining_info); + const StatefulInliningInfo& inlining_info, + LoopPromotionMapBuilderCallback* callback = nullptr); std::unordered_map build(); @@ -117,6 +135,7 @@ class LoopPromotionMapBuilder { private: IdModel& id_model_; const StatefulInliningInfo& inlining_info_; + LoopPromotionMapBuilderCallback* callback_ = nullptr; }; } // namespace nvfuser diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 5033119f758..973dfb6d7ee 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -118,74 +119,57 @@ IterDomain* getChildIdByName(IterDomain* id, StmtNameType name) { }; // Helper class to test IdModel -class IdModelTester : public IdModel { +class IdModelTester : public LoopPromotionMapBuilderCallback { public: // Do not automatically build the graphs - IdModelTester(Fusion* fusion) : IdModel(fusion, /*build_graphs=*/false) { - // Make sure the depedent graphs are already built - maybeBuildGraph(IdMappingMode::EXACT); - maybeBuildGraph(IdMappingMode::PERMISSIVE); - - // Gather broadcast resolution and inlining information - const StatefulInliningInfo inlining_info = buildStatefulInliningInfo( - tv_exprs_, - idGraph(IdMappingMode::EXACT), - idGraph(IdMappingMode::PERMISSIVE)); - - initializeLoopGraph(inlining_info); - - validateLoopGraphHasNoSelfMappedLeafDomains(); - - iel_graph = buildIntersection( - idGraph(IdMappingMode::EXACT), idGraph(IdMappingMode::LOOP), false); + IdModelTester(Fusion* fusion) { + id_model = std::make_unique( + fusion, + /*build_graphs=*/false, + /*allow_self_mapping=*/false, + /*validate=*/true, + /*loop_promotion_map_builder_callback=*/this); + + // Only build the loop graph + id_model->buildLoopGraph(); + } + void postStep1( + const std::unordered_map& iel_root_resolution_map, + const ValGraph& iel_graph) override { + this->iel_graph = iel_graph; + // this->iel_graph is a copy of the original IEL graph. The given + // map is for the original graph and needs to be updated. s1_root_resolution_map = - buildInlineRootResolutionMap(iel_graph, inlining_info); + updateValGroupIdMap(iel_root_resolution_map, this->iel_graph); + } - s2_iel_promotion_map = s1_root_resolution_map; + void postStep2( + const std::unordered_map& iel_promotion_map, + const ValGraph& iel_graph) override { + s2_iel_promotion_map = + updateValGroupIdMap(iel_promotion_map, this->iel_graph); + } - propagatePromotionsInIELGraph(iel_graph, s2_iel_promotion_map); + void postStep3(const std::unordered_map& + loop_promotion_map) override { + s3_loop_graph = id_model->idGraph(IdMappingMode::LOOP); + s3_loop_promotion_map = + updateValGroupIdMap(loop_promotion_map, s3_loop_graph); + } - const auto s3_original_loop_promotion_map = projectIELPromotionToLoopGraph( - iel_graph, - s2_iel_promotion_map, - idGraph(IdMappingMode::LOOP), - inlining_info); + void postStep4( + const std::unordered_map& iel_promotion_map, + const ValGraph& iel_graph) override { + s4_iel_promotion_map = + updateValGroupIdMap(iel_promotion_map, this->iel_graph); + } - // Make a copy for validation as idGraph(IdMappingMode::LOOP) will - // be updated in the later steps - s3_loop_graph = idGraph(IdMappingMode::LOOP); - s3_loop_promotion_map = - updateValGroupIdMap(s3_original_loop_promotion_map, s3_loop_graph); - - // Note that s4_iel_promotion_map is an empty map at this - // point. It'll be populated with the Step-3 map - propagatePromotionsInIELGraph( - iel_graph, - s4_iel_promotion_map, - idGraph(IdMappingMode::LOOP), - s3_original_loop_promotion_map); - - // Step 5: Find the final promotion of each loop group based on the - // final IEL promotion map - s5_loop_promotion_map = projectIELPromotionToLoopGraph( - iel_graph, - s4_iel_promotion_map, - idGraph(IdMappingMode::LOOP), - inlining_info); - - auto updated_s3_loop_promotion_map = updateValGroupIdMap( - s3_loop_promotion_map, idGraph(IdMappingMode::LOOP)); - s5_loop_promotion_map.insert( - updated_s3_loop_promotion_map.begin(), - updated_s3_loop_promotion_map.end()); - - sanityCheckLoopPromotionMap(s5_loop_promotion_map); - validateLoopGraphHasNoSelfMappedLeafDomains(); - - s5_loop_graph = idGraph(IdMappingMode::LOOP); + void postStep5(const std::unordered_map& + loop_promotion_map) override { + s5_loop_graph = id_model->idGraph(IdMappingMode::LOOP); s5_loop_promotion_map = - updateValGroupIdMap(s5_loop_promotion_map, s5_loop_graph); + updateValGroupIdMap(loop_promotion_map, s5_loop_graph); } void print(std::ostream& os) const { @@ -211,6 +195,7 @@ class IdModelTester : public IdModel { } } + std::unique_ptr id_model; ValGraph iel_graph; std::unordered_map s1_root_resolution_map; std::unordered_map s2_iel_promotion_map; @@ -230,8 +215,8 @@ void validateIELResolution( const IdModelTester& tester, const std::unordered_map& iel_promotion_map) { const auto& iel_graph = tester.iel_graph; - const auto& exact_graph = tester.idGraph(IdMappingMode::EXACT); - const auto& loop_graph = tester.idGraph(IdMappingMode::LOOP); + const auto& exact_graph = tester.id_model->idGraph(IdMappingMode::EXACT); + const auto& loop_graph = tester.id_model->idGraph(IdMappingMode::LOOP); const auto& iel_group = iel_graph.toGroup(id); auto iel_promotion_map_it = iel_promotion_map.find(iel_group); From 2b605cd513dedb451c2aeba385ea11fdf24fd5d6 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 14:55:42 -0700 Subject: [PATCH 175/178] clang-format --- csrc/id_model/id_model.cpp | 10 +++++++--- csrc/id_model/id_model.h | 9 ++++++--- csrc/id_model/loop_promotion.cpp | 3 ++- csrc/id_model/loop_promotion.h | 11 +++++++---- 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 3eacaea16f2..bc63de3f452 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -77,7 +77,9 @@ IdModel::IdModel( bool build_graphs, bool allow_self_mapping, LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback) - : allow_self_mapping_(allow_self_mapping), loop_promotion_map_builder_callback_(loop_promotion_map_builder_callback) { + : allow_self_mapping_(allow_self_mapping), + loop_promotion_map_builder_callback_( + loop_promotion_map_builder_callback) { std::copy_if( exprs.begin(), exprs.end(), @@ -108,7 +110,8 @@ IdModel::IdModel( LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback) : allow_self_mapping_(allow_self_mapping), validate_(validate), - loop_promotion_map_builder_callback_(loop_promotion_map_builder_callback) { + loop_promotion_map_builder_callback_( + loop_promotion_map_builder_callback) { auto all_exprs = fusion->exprs(); std::copy_if( all_exprs.begin(), @@ -578,7 +581,8 @@ void IdModel::buildLoopGraph() { validateLoopGraphHasNoSelfMappedLeafDomains(); - loop_promotion_map_ = LoopPromotionMapBuilder::get(*this, inlining_info, loop_promotion_map_builder_callback_); + loop_promotion_map_ = LoopPromotionMapBuilder::get( + *this, inlining_info, loop_promotion_map_builder_callback_); // New domains are added. Make sure there's still no self mapping in // the leaf domains diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index a2565dd94e4..618599e634e 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -108,7 +108,8 @@ class IdModel : public PolymorphicBase { const std::vector& additional_tvs = {}, bool build_graphs = true, bool allow_self_mapping = false, - LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback = nullptr); + LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback = + nullptr); // Same as the above constructor with fusion->exprs() excpet fusion may have // some dangling inputs/outputs that are expected to have IterDomain entries @@ -121,7 +122,8 @@ class IdModel : public PolymorphicBase { bool build_graphs = true, bool allow_self_mapping = false, bool validate = true, - LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback = nullptr); + LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback = + nullptr); // Returns iter domain graph of provided mode. The graph must have // been already built. @@ -232,7 +234,8 @@ class IdModel : public PolymorphicBase { // Optional callback for the loop promotion map builder for // debugging and testing - LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback_ = nullptr; + LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback_ = + nullptr; // By default, the permissive graph should map compliment domains as // well. See the design doc for more details diff --git a/csrc/id_model/loop_promotion.cpp b/csrc/id_model/loop_promotion.cpp index 7fb16ce4fcd..26b9413a079 100644 --- a/csrc/id_model/loop_promotion.cpp +++ b/csrc/id_model/loop_promotion.cpp @@ -164,7 +164,8 @@ std::unordered_map LoopPromotionMapBuilder::build() { // Insert the updated Step-3 results into the Step-5 resutls. Note // that this insertion does not overwrite the existing mappings. final_loop_promotion_map.insert( - updated_initial_loop_promotion_map.begin(), updated_initial_loop_promotion_map.end()); + updated_initial_loop_promotion_map.begin(), + updated_initial_loop_promotion_map.end()); sanityCheckLoopPromotionMap(final_loop_promotion_map); diff --git a/csrc/id_model/loop_promotion.h b/csrc/id_model/loop_promotion.h index 30d3a9e5b6e..336823d0ad7 100644 --- a/csrc/id_model/loop_promotion.h +++ b/csrc/id_model/loop_promotion.h @@ -24,10 +24,13 @@ class LoopPromotionMapBuilderCallback { virtual void postStep2( const std::unordered_map& iel_promotion_map, const ValGraph& iel_graph) {} - virtual void postStep3(const std::unordered_map& loop_promotion_map) {} - virtual void postStep4(const std::unordered_map& iel_promotion_map, - const ValGraph& iel_graph) {} - virtual void postStep5(const std::unordered_map& loop_promotion_map) {} + virtual void postStep3( + const std::unordered_map& loop_promotion_map) {} + virtual void postStep4( + const std::unordered_map& iel_promotion_map, + const ValGraph& iel_graph) {} + virtual void postStep5( + const std::unordered_map& loop_promotion_map) {} }; class LoopPromotionMapBuilder { From 0f5ab07d5cebe4a8a1b9538288e18d51f170e8e5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 May 2024 15:16:07 -0700 Subject: [PATCH 176/178] comment --- csrc/id_model/loop_promotion.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/csrc/id_model/loop_promotion.h b/csrc/id_model/loop_promotion.h index 336823d0ad7..ba3d7aae4de 100644 --- a/csrc/id_model/loop_promotion.h +++ b/csrc/id_model/loop_promotion.h @@ -14,21 +14,31 @@ namespace nvfuser { class IdModel; struct StatefulInliningInfo; +// Callback interface for LoopPromotionMapBuilder. Allow exposing the +// temporary maps for testing and debugging class LoopPromotionMapBuilderCallback { public: virtual ~LoopPromotionMapBuilderCallback() = default; + // Called after Step 1 with the root resolution map and the + // corresponding IEL graph virtual void postStep1( const std::unordered_map& iel_root_resolution_map, const ValGraph& iel_graph) {} + // Called after Step 2 with the IEL promotion map and the + // corresponding IEL graph virtual void postStep2( const std::unordered_map& iel_promotion_map, const ValGraph& iel_graph) {} + // Called after Step 3 with the loop promotion map virtual void postStep3( const std::unordered_map& loop_promotion_map) {} + // Called after Step 4 with the IEL promotion map and the + // corresponding IEL graph virtual void postStep4( const std::unordered_map& iel_promotion_map, const ValGraph& iel_graph) {} + // Called after Step 3 with the final loop promotion map virtual void postStep5( const std::unordered_map& loop_promotion_map) {} }; From f0aeab717dbb8c31b7b2a96fd94a6f2b667f12a7 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 20 May 2024 21:11:11 -0700 Subject: [PATCH 177/178] cleanup --- csrc/val_graph.cpp | 48 ---------------------------------------------- csrc/val_graph.h | 3 --- 2 files changed, 51 deletions(-) diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index e55f101f78c..a5e2f8b1729 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -508,54 +508,6 @@ std::string ValGraph::toString() const { return ss.str(); } -bool ValGraph::transformAtributesMatch(Expr* first, Expr* second) { - if (first == nullptr || second == nullptr) { - return false; - } - - NVF_ERROR( - first->isA() || first->isA() || first->isA() || - first->isA() || first->isA(), - "Unsupported rfactor expressions in compute at map:\n", - first->toString()); - - if (typeid(*first) != typeid(*second)) { - return false; - } - - if (first->isA()) { - auto first_split = first->as(); - auto second_split = second->as(); - if (!first_split->factor()->sameAs(second_split->factor()) || - first_split->innerSplit() != second_split->innerSplit() || - !first_split->startOffset()->sameAs(second_split->startOffset()) || - !first_split->stopOffset()->sameAs(second_split->stopOffset())) { - return false; - } - } - - if (first->isA()) { - auto first_swizzle = first->as(); - auto second_swizzle = second->as(); - if (first_swizzle->swizzleMode() != second_swizzle->swizzleMode() || - first_swizzle->swizzleType() != second_swizzle->swizzleType()) { - return false; - } - } - - if (first->isA()) { - auto swizzle_1 = first->as(); - auto swizzle_2 = first->as(); - if (swizzle_1->swizzleType() != swizzle_2->swizzleType()) { - return false; - } - } - - // TODO: Resize properties - - return true; -} - void ValGraph::initializeVal( Val* val, const VectorOfUniqueEntries& definitions, diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 57b3f88b49a..947807b3cd8 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -189,9 +189,6 @@ class ValGraph { std::string toString() const; - // Returns if all atributes of the ID transforms first and second are the same - static bool transformAtributesMatch(Expr* first, Expr* second); - // Initializes entries for the provided Val with its definitions and // uses. void initializeVal( From 18de52378cf58da06451d9dbca79dfd6722b122f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 28 May 2024 11:17:06 -0700 Subject: [PATCH 178/178] enable --- csrc/device_lower/lower2device.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 0206ab7f885..6871e6a592d 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -390,7 +390,7 @@ void GpuLower::analysis(Fusion* fusion) { // functionality should be affected. New IterDomains may be created, // so it is expected that generated code may use diffrent variable // names - if (isOptionEnabled(EnableOption::IdModel)) { + if (true || isOptionEnabled(EnableOption::IdModel)) { IdModel id_model(fusion_); }