diff --git a/.gitignore b/.gitignore index 9a3e95b82b..dac058c615 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# custom tooling +/.tools/ +/.state/ + # ctest-created files Testing/ @@ -5,7 +9,6 @@ Testing/ compile_commands.json # Compiled files -/.tools/ /python/flexflow_python /python/flexflow/core/legion_cffi.py python/flexflow/core/flexflow_cffi_header.py diff --git a/.gitmodules b/.gitmodules index e6068aa368..f0bd8a9ff8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,9 @@ [submodule "deps/any"] path = deps/any url = https://github.com/thelink2012/any.git +[submodule "deps/nameof"] + path = deps/nameof + url = git@github.com:Neargye/nameof.git +[submodule "deps/boost_preprocessor"] + path = deps/boost_preprocessor + url = https://github.com/boostorg/preprocessor.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 3bbdf13b22..c00597c56b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,84 +3,11 @@ project(FlexFlow) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/cmake) -set(FF_MAX_DIM "5" CACHE STRING "Maximum tensor order") -set(FF_MAX_OPNAME "128" CACHE STRING "Maximum op name length") -set(FF_MAX_NUM_OUTPUTS "256" CACHE STRING "Maximum number of outputs (per operator)") -set(FF_MAX_NUM_INPUTS "256" CACHE STRING "Maximum number of inputs (per operator)") -set(FF_MAX_NUM_WEIGHTS "64" CACHE STRING "Maximum number of weights (per operator)") -set(FF_MAX_NUM_FUSED_OPERATORS "64" CACHE STRING "Maximum number of fused tensors") -set(FF_MAX_NUM_FUSED_TENSORS "64" CACHE STRING "Maximum number of input and output tensors per fused op") -set(FF_MAX_NUM_WORKERS "1024" CACHE STRING "Maximum number of GPUs") -set(FF_MAX_NUM_TASK_REGIONS "20" CACHE STRING - "Maximum number of regions that can be passed to a task through the TaskSpec interface") -set(FF_MAX_NUM_TASK_ARGUMENTS "5" CACHE STRING - "Maximum number of arguments that can be declared in a TaskSignature") -option(FF_USE_NCCL "Run FlexFlow with NCCL" OFF) -option(FF_USE_PREBUILT_NCCL "Enable use of NCCL pre-compiled library, if available" ON) -option(FF_USE_PREBUILT_LEGION "Enable use of Legion pre-compiled library, if available" ON) -option(FF_USE_ALL_PREBUILT_LIBRARIES "Enable use of all pre-compiled libraries, if available" OFF) -option(FF_USE_PYTHON "Enable Python" ON) -option(FF_BUILD_FROM_PYPI "Build from pypi" OFF) - -set(FF_GASNET_CONDUITS aries udp mpi ibv ucx) -set(FF_GASNET_CONDUIT "mpi" CACHE STRING "Select GASNet conduit ${FF_GASNET_CONDUITS}") -set_property(CACHE FF_GASNET_CONDUIT PROPERTY STRINGS ${FF_GASNET_CONDUITS}) -set(FF_LEGION_NETWORKS "" CACHE STRING "Network backend(s) to use") - -set(FF_GPU_BACKENDS cuda hip_cuda hip_rocm intel) -set(FF_GPU_BACKEND "cuda" CACHE STRING "Select GPU Backend ${FF_GPU_BACKENDS}") -set_property(CACHE FF_GPU_BACKEND PROPERTY STRINGS ${FF_GPU_BACKENDS}) - -option(FF_USE_EXTERNAL_LEGION "Use pre-installed Legion" OFF) -option(FF_BUILD_RESNET "build resnet example" OFF) -option(FF_BUILD_RESNEXT "build resnext example" OFF) -option(FF_BUILD_ALEXNET "build alexnet example" OFF) -option(FF_BUILD_DLRM "build DLRM example" OFF) -option(FF_BUILD_XDL "build XDL example" OFF) -option(FF_BUILD_INCEPTION "build inception example" OFF) -option(FF_BUILD_CANDLE_UNO "build candle uno example" OFF) -option(FF_BUILD_TRANSFORMER "build transformer example" OFF) -option(FF_BUILD_MOE "build mixture of experts example" OFF) -option(FF_BUILD_MLP_UNIFY "build mlp unify example" OFF) -option(FF_BUILD_SPLIT_TEST "build split test example" OFF) -option(FF_BUILD_SPLIT_TEST_2 "build split test 2 example" OFF) -option(FF_BUILD_ALL_EXAMPLES "build all examples. Overrides others" OFF) -option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) -option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) -option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" OFF) -option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) - -set(FF_CUDA_ARCH "autodetect" CACHE STRING "Target CUDA Arch") -if (FF_CUDA_ARCH STREQUAL "") - message(FATAL_ERROR "FF_CUDA_ARCH cannot be an empty string. Set it to `autodetect`, `all`, or pass one or multiple valid CUDA archs.") -endif() - -if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") - set(LIBEXT ".so") -endif() - -include(cuda) -include(cudnn) -include(nccl) # set_property(CACHE FF_GPU_BACKEND PROPERTY STRINGS ${FF_GPU_BACKENDS}) -include(json) -include(optional) -include(expected) -include(spdlog) -include(variant) -include(doctest) -include(visit_struct) include(CTest) -include(fmt) -include(legion) -include(rapidcheck) -include(invoke) -include(any) -#include(gtest) -#include(fmt) - -include(flexflow-utils) +include(utils) +include(deps) # TODO @lockshaw remove me # https://discourse.nixos.org/t/get-clangd-to-find-standard-headers-in-nix-shell/11268/6 diff --git a/cmake/any.cmake b/cmake/any.cmake deleted file mode 100644 index 9a6164da4f..0000000000 --- a/cmake/any.cmake +++ /dev/null @@ -1,16 +0,0 @@ -add_library( - any - INTERFACE -) -target_include_directories( - any - INTERFACE - ${CMAKE_CURRENT_SOURCE_DIR}/deps/any/ -) -set_target_properties( - any - PROPERTIES - CXX_STANDARD 11 - CXX_STANDARD_REQUIRED YES - CXX_EXTENSIONS NO -) diff --git a/cmake/deps.cmake b/cmake/deps.cmake new file mode 100644 index 0000000000..f9e095e3c8 --- /dev/null +++ b/cmake/deps.cmake @@ -0,0 +1,14 @@ +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/deps) + +include(boost_preprocessor) +include(cuda) +include(cudnn) +include(doctest) +include(fmt) +include(json) +include(legion) +include(nameof) +include(nccl) +include(rapidcheck) +include(spdlog) +include(visit_struct) diff --git a/cmake/deps/boost_preprocessor.cmake b/cmake/deps/boost_preprocessor.cmake new file mode 100644 index 0000000000..8294aa95f1 --- /dev/null +++ b/cmake/deps/boost_preprocessor.cmake @@ -0,0 +1 @@ +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/boost_preprocessor) diff --git a/cmake/cuda.cmake b/cmake/deps/cuda.cmake similarity index 100% rename from cmake/cuda.cmake rename to cmake/deps/cuda.cmake diff --git a/cmake/cudnn.cmake b/cmake/deps/cudnn.cmake similarity index 100% rename from cmake/cudnn.cmake rename to cmake/deps/cudnn.cmake diff --git a/cmake/doctest.cmake b/cmake/deps/doctest.cmake similarity index 100% rename from cmake/doctest.cmake rename to cmake/deps/doctest.cmake diff --git a/cmake/fmt.cmake b/cmake/deps/fmt.cmake similarity index 100% rename from cmake/fmt.cmake rename to cmake/deps/fmt.cmake diff --git a/cmake/json.cmake b/cmake/deps/json.cmake similarity index 100% rename from cmake/json.cmake rename to cmake/deps/json.cmake diff --git a/cmake/legion.cmake b/cmake/deps/legion.cmake similarity index 100% rename from cmake/legion.cmake rename to cmake/deps/legion.cmake diff --git a/cmake/deps/nameof.cmake b/cmake/deps/nameof.cmake new file mode 100644 index 0000000000..e1c22c23af --- /dev/null +++ b/cmake/deps/nameof.cmake @@ -0,0 +1 @@ +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/nameof) diff --git a/cmake/nccl.cmake b/cmake/deps/nccl.cmake similarity index 100% rename from cmake/nccl.cmake rename to cmake/deps/nccl.cmake diff --git a/cmake/rapidcheck.cmake b/cmake/deps/rapidcheck.cmake similarity index 100% rename from cmake/rapidcheck.cmake rename to cmake/deps/rapidcheck.cmake diff --git a/cmake/spdlog.cmake b/cmake/deps/spdlog.cmake similarity index 100% rename from cmake/spdlog.cmake rename to cmake/deps/spdlog.cmake diff --git a/cmake/ucx.cmake b/cmake/deps/ucx.cmake similarity index 100% rename from cmake/ucx.cmake rename to cmake/deps/ucx.cmake diff --git a/cmake/visit_struct.cmake b/cmake/deps/visit_struct.cmake similarity index 100% rename from cmake/visit_struct.cmake rename to cmake/deps/visit_struct.cmake diff --git a/cmake/zlib.cmake b/cmake/deps/zlib.cmake similarity index 100% rename from cmake/zlib.cmake rename to cmake/deps/zlib.cmake diff --git a/cmake/expected.cmake b/cmake/expected.cmake deleted file mode 100644 index 7ae0749354..0000000000 --- a/cmake/expected.cmake +++ /dev/null @@ -1,4 +0,0 @@ -set(EXPECTED_BUILD_TESTS OFF) -set(EXPECTED_BUILD_PACKAGE OFF) - -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/expected) diff --git a/cmake/invoke.cmake b/cmake/invoke.cmake deleted file mode 100644 index 3ec406ed05..0000000000 --- a/cmake/invoke.cmake +++ /dev/null @@ -1,5 +0,0 @@ -include(aliasing) - -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/invoke) - -alias_library(invoke invoke.hpp::invoke.hpp) diff --git a/cmake/optional.cmake b/cmake/optional.cmake deleted file mode 100644 index afaa6330c0..0000000000 --- a/cmake/optional.cmake +++ /dev/null @@ -1,4 +0,0 @@ -set(OPTIONAL_BUILD_TESTS OFF) -set(OPTIONAL_BUILD_PACKAGE OFF) - -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/optional) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 4e23ed2e3f..8c9102ee54 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -1,56 +1,6 @@ -set(known_gpu_archs "") -function(remove_duplicate_args __string) - if(${__string}) - set(__list ${${__string}}) - separate_arguments(__list) - list(REMOVE_DUPLICATES __list) - foreach(__e ${__list}) - set(__str "${__str} ${__e}") - endforeach() - set(${__string} ${__str} PARENT_SCOPE) - endif() -endfunction() -function(detect_installed_gpus out_variable) - if(NOT CUDA_gpu_detect_output) - set(__cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu) - file(WRITE ${__cufile} "" - "#include \n" - "int main()\n" - "{\n" - " int count = 0;\n" - " if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n" - " if (count == 0) return -1;\n" - " for (int device = 0; device < count; ++device)\n" - " {\n" - " cudaDeviceProp prop;\n" - " if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n" - " std::printf(\"%d.%d \", prop.major, prop.minor);\n" - " }\n" - " return 0;\n" - "}\n") - execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${__cufile}" - WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" - RESULT_VARIABLE __nvcc_res OUTPUT_VARIABLE __nvcc_out - ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) - if(__nvcc_res EQUAL 0) - message(STATUS "No result from nvcc so building for 2.0") - string(REPLACE "2.1" "2.1(2.0)" __nvcc_out "${__nvcc_out}") - set(CUDA_gpu_detect_output ${__nvcc_out} CACHE INTERNAL "Returned GPU architetures from detect_gpus tool" FORCE) - endif() - endif() - if(NOT CUDA_gpu_detect_output) - message(STATUS "Automatic GPU detection failed, Architecture is not set: ${known_gpu_archs}.") - set(${out_variable} ${known_gpu_archs} PARENT_SCOPE) - else() - remove_duplicate_args(CUDA_gpu_detect_output) - #Strip leading and trailing whitespaces - string(STRIP "${CUDA_gpu_detect_output}" CUDA_gpu_detect_output) - #Replace spaces in between with commas so you go from "5.2 6.1" to "5.2,6.1" - string(REGEX REPLACE " " "," CUDA_gpu_detect_output "${CUDA_gpu_detect_output}") - # message(${CUDA_gpu_detect_output}) - string(REPLACE "." "" CUDA_gpu_detect_output "${CUDA_gpu_detect_output}") - # message(${CUDA_gpu_detect_output}) - set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE) - # message(STATUS "Automatic GPU ARCH detection: ${CUDA_gpu_detect_output}") - endif() -endfunction() +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/utils) + +include(flexflow-utils) +include(aliasing) +include(build-options) +include(libext) diff --git a/cmake/aliasing.cmake b/cmake/utils/aliasing.cmake similarity index 100% rename from cmake/aliasing.cmake rename to cmake/utils/aliasing.cmake diff --git a/cmake/utils/build-options.cmake b/cmake/utils/build-options.cmake new file mode 100644 index 0000000000..8f59e807a7 --- /dev/null +++ b/cmake/utils/build-options.cmake @@ -0,0 +1,51 @@ +set(FF_MAX_DIM "5" CACHE STRING "Maximum tensor order") +set(FF_MAX_OPNAME "128" CACHE STRING "Maximum op name length") +set(FF_MAX_NUM_OUTPUTS "256" CACHE STRING "Maximum number of outputs (per operator)") +set(FF_MAX_NUM_INPUTS "256" CACHE STRING "Maximum number of inputs (per operator)") +set(FF_MAX_NUM_WEIGHTS "64" CACHE STRING "Maximum number of weights (per operator)") +set(FF_MAX_NUM_FUSED_OPERATORS "64" CACHE STRING "Maximum number of fused tensors") +set(FF_MAX_NUM_FUSED_TENSORS "64" CACHE STRING "Maximum number of input and output tensors per fused op") +set(FF_MAX_NUM_WORKERS "1024" CACHE STRING "Maximum number of GPUs") +set(FF_MAX_NUM_TASK_REGIONS "20" CACHE STRING + "Maximum number of regions that can be passed to a task through the TaskSpec interface") +set(FF_MAX_NUM_TASK_ARGUMENTS "5" CACHE STRING + "Maximum number of arguments that can be declared in a TaskSignature") +option(FF_USE_NCCL "Run FlexFlow with NCCL" OFF) +option(FF_USE_PREBUILT_NCCL "Enable use of NCCL pre-compiled library, if available" ON) +option(FF_USE_PREBUILT_LEGION "Enable use of Legion pre-compiled library, if available" ON) +option(FF_USE_ALL_PREBUILT_LIBRARIES "Enable use of all pre-compiled libraries, if available" OFF) +option(FF_USE_PYTHON "Enable Python" ON) +option(FF_BUILD_FROM_PYPI "Build from pypi" OFF) + +set(FF_GASNET_CONDUITS aries udp mpi ibv ucx) +set(FF_GASNET_CONDUIT "mpi" CACHE STRING "Select GASNet conduit ${FF_GASNET_CONDUITS}") +set_property(CACHE FF_GASNET_CONDUIT PROPERTY STRINGS ${FF_GASNET_CONDUITS}) +set(FF_LEGION_NETWORKS "" CACHE STRING "Network backend(s) to use") + +set(FF_GPU_BACKENDS cuda hip_cuda hip_rocm intel) +set(FF_GPU_BACKEND "cuda" CACHE STRING "Select GPU Backend ${FF_GPU_BACKENDS}") +set_property(CACHE FF_GPU_BACKEND PROPERTY STRINGS ${FF_GPU_BACKENDS}) + +option(FF_USE_EXTERNAL_LEGION "Use pre-installed Legion" OFF) +option(FF_BUILD_RESNET "build resnet example" OFF) +option(FF_BUILD_RESNEXT "build resnext example" OFF) +option(FF_BUILD_ALEXNET "build alexnet example" OFF) +option(FF_BUILD_DLRM "build DLRM example" OFF) +option(FF_BUILD_XDL "build XDL example" OFF) +option(FF_BUILD_INCEPTION "build inception example" OFF) +option(FF_BUILD_CANDLE_UNO "build candle uno example" OFF) +option(FF_BUILD_TRANSFORMER "build transformer example" OFF) +option(FF_BUILD_MOE "build mixture of experts example" OFF) +option(FF_BUILD_MLP_UNIFY "build mlp unify example" OFF) +option(FF_BUILD_SPLIT_TEST "build split test example" OFF) +option(FF_BUILD_SPLIT_TEST_2 "build split test 2 example" OFF) +option(FF_BUILD_ALL_EXAMPLES "build all examples. Overrides others" OFF) +option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) +option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) +option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" OFF) +option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) + +set(FF_CUDA_ARCH "autodetect" CACHE STRING "Target CUDA Arch") +if (FF_CUDA_ARCH STREQUAL "") + message(FATAL_ERROR "FF_CUDA_ARCH cannot be an empty string. Set it to `autodetect`, `all`, or pass one or multiple valid CUDA archs.") +endif() diff --git a/cmake/flexflow-utils.cmake b/cmake/utils/flexflow-utils.cmake similarity index 64% rename from cmake/flexflow-utils.cmake rename to cmake/utils/flexflow-utils.cmake index d41573acab..635cac3546 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/utils/flexflow-utils.cmake @@ -43,6 +43,49 @@ function(ff_set_cxx_properties target) ) endfunction() +function(ff_get_source_files) + file(GLOB_RECURSE SRC + LIST_DIRECTORIES False + "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cc") + list(FILTER SRC EXCLUDE REGEX "\\.test\\.") + set(FF_SOURCE_FILES ${SRC} PARENT_SCOPE) +endfunction() + +function(ff_add_intree_test_executable) + ff_parse_args( + PREFIX + FF_TEST_EXEC + ARGS + NAME + VARIADIC_ARGS + DEPS + PARSE + ${ARGN} + ) + + project(${FF_TEST_EXEC_NAME}) + file(GLOB_RECURSE FF_SOURCE_FILES + LIST_DIRECTORIES False + "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cc") + + add_executable( + ${FF_TEST_EXEC_NAME} + ${FF_SOURCE_FILES}) + + target_link_libraries( + ${FF_TEST_EXEC_NAME} + ${FF_TEST_EXEC_DEPS} + utils-testing + utils-rapidcheck_extra + utils-test_types) + + target_compile_definitions(${FF_TEST_EXEC_NAME} PRIVATE FF_TEST_SUITE="${FF_TEST_EXEC_NAME}") + + define_ff_vars(${FF_TEST_EXEC_NAME}) + ff_set_cxx_properties(${FF_TEST_EXEC_NAME}) + doctest_discover_tests(${FF_TEST_EXEC_NAME} ADD_LABELS 1) +endfunction() + function(ff_add_library) ff_parse_args( PREFIX @@ -50,9 +93,6 @@ function(ff_add_library) ARGS NAME VARIADIC_ARGS - SRC_PATTERNS - PUBLIC_INCLUDE - PRIVATE_INCLUDE DEPS PRIVATE_DEPS PARSE @@ -60,22 +100,19 @@ function(ff_add_library) ) project(${FF_LIBRARY_NAME}) - file(GLOB_RECURSE SRC - CONFIGURE_DEPENDS - LIST_DIRECTORIES False - ${FF_LIBRARY_SRC_PATTERNS}) + ff_get_source_files() add_library( ${FF_LIBRARY_NAME} SHARED - ${SRC}) + ${FF_SOURCE_FILES}) target_include_directories( ${FF_LIBRARY_NAME} PUBLIC - ${FF_LIBRARY_PUBLIC_INCLUDE} + "${CMAKE_CURRENT_SOURCE_DIR}/include" PRIVATE - ${FF_LIBRARY_PRIVATE_INCLUDE}) + "${CMAKE_CURRENT_SOURCE_DIR}/src") target_link_libraries( ${FF_LIBRARY_NAME} @@ -86,6 +123,14 @@ function(ff_add_library) ) define_ff_vars(${FF_LIBRARY_NAME}) ff_set_cxx_properties(${FF_LIBRARY_NAME}) + + ff_add_intree_test_executable( + NAME + "${FF_LIBRARY_NAME}-tests" + ${FF_LIBRARY_NO_TEST_LIB} + DEPS + ${FF_LIBRARY_NAME} + ) endfunction() function(ff_add_test_executable) @@ -100,25 +145,27 @@ function(ff_add_test_executable) DEPS PARSE ${ARGN} - rapidcheck - doctest ) project(${FF_TEST_EXEC_NAME}) - file(GLOB_RECURSE SRC - CONFIGURE_DEPENDS + ff_get_source_files() + file(GLOB_RECURSE IN_TREE_FILES LIST_DIRECTORIES False - ${FF_TEST_EXEC_SRC_PATTERNS}) + "${CMAKE_CURRENT_SOURCE_DIR}/src/*.test.cc") + list(APPEND FF_SOURCE_FILES ${IN_TREE_FILES}) add_executable( ${FF_TEST_EXEC_NAME} - ${SRC}) + ${FF_SOURCE_FILES}) target_link_libraries( ${FF_TEST_EXEC_NAME} - ${FF_TEST_EXEC_DEPS}) + ${FF_TEST_EXEC_DEPS} + rapidcheck + doctest) define_ff_vars(${FF_TEST_EXEC_NAME}) ff_set_cxx_properties(${FF_TEST_EXEC_NAME}) doctest_discover_tests(${FF_TEST_EXEC_NAME}) + endfunction() diff --git a/cmake/utils/gpu_utils.cmake b/cmake/utils/gpu_utils.cmake new file mode 100644 index 0000000000..4e23ed2e3f --- /dev/null +++ b/cmake/utils/gpu_utils.cmake @@ -0,0 +1,56 @@ +set(known_gpu_archs "") +function(remove_duplicate_args __string) + if(${__string}) + set(__list ${${__string}}) + separate_arguments(__list) + list(REMOVE_DUPLICATES __list) + foreach(__e ${__list}) + set(__str "${__str} ${__e}") + endforeach() + set(${__string} ${__str} PARENT_SCOPE) + endif() +endfunction() +function(detect_installed_gpus out_variable) + if(NOT CUDA_gpu_detect_output) + set(__cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu) + file(WRITE ${__cufile} "" + "#include \n" + "int main()\n" + "{\n" + " int count = 0;\n" + " if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n" + " if (count == 0) return -1;\n" + " for (int device = 0; device < count; ++device)\n" + " {\n" + " cudaDeviceProp prop;\n" + " if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n" + " std::printf(\"%d.%d \", prop.major, prop.minor);\n" + " }\n" + " return 0;\n" + "}\n") + execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${__cufile}" + WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" + RESULT_VARIABLE __nvcc_res OUTPUT_VARIABLE __nvcc_out + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + if(__nvcc_res EQUAL 0) + message(STATUS "No result from nvcc so building for 2.0") + string(REPLACE "2.1" "2.1(2.0)" __nvcc_out "${__nvcc_out}") + set(CUDA_gpu_detect_output ${__nvcc_out} CACHE INTERNAL "Returned GPU architetures from detect_gpus tool" FORCE) + endif() + endif() + if(NOT CUDA_gpu_detect_output) + message(STATUS "Automatic GPU detection failed, Architecture is not set: ${known_gpu_archs}.") + set(${out_variable} ${known_gpu_archs} PARENT_SCOPE) + else() + remove_duplicate_args(CUDA_gpu_detect_output) + #Strip leading and trailing whitespaces + string(STRIP "${CUDA_gpu_detect_output}" CUDA_gpu_detect_output) + #Replace spaces in between with commas so you go from "5.2 6.1" to "5.2,6.1" + string(REGEX REPLACE " " "," CUDA_gpu_detect_output "${CUDA_gpu_detect_output}") + # message(${CUDA_gpu_detect_output}) + string(REPLACE "." "" CUDA_gpu_detect_output "${CUDA_gpu_detect_output}") + # message(${CUDA_gpu_detect_output}) + set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE) + # message(STATUS "Automatic GPU ARCH detection: ${CUDA_gpu_detect_output}") + endif() +endfunction() diff --git a/cmake/utils/libext.cmake b/cmake/utils/libext.cmake new file mode 100644 index 0000000000..cb8efc2a8f --- /dev/null +++ b/cmake/utils/libext.cmake @@ -0,0 +1,3 @@ +if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + set(LIBEXT ".so") +endif() diff --git a/cmake/variant.cmake b/cmake/variant.cmake deleted file mode 100644 index ddf5781281..0000000000 --- a/cmake/variant.cmake +++ /dev/null @@ -1,5 +0,0 @@ -include(aliasing) - -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/variant) - -alias_library(variant mpark_variant) diff --git a/deps/boost_preprocessor b/deps/boost_preprocessor new file mode 160000 index 0000000000..667e87b339 --- /dev/null +++ b/deps/boost_preprocessor @@ -0,0 +1 @@ +Subproject commit 667e87b3392db338a919cbe0213979713aca52e3 diff --git a/deps/nameof b/deps/nameof new file mode 160000 index 0000000000..8aeb677413 --- /dev/null +++ b/deps/nameof @@ -0,0 +1 @@ +Subproject commit 8aeb6774132a01765d8c8679d016b728acd069f5 diff --git a/scripts/mnist_mlp_run.sh b/examples/multi-node/mnist_mlp_run.sh similarity index 100% rename from scripts/mnist_mlp_run.sh rename to examples/multi-node/mnist_mlp_run.sh diff --git a/scripts/osdi22ae/bert.sh b/examples/papers/osdi22ae/bert.sh similarity index 100% rename from scripts/osdi22ae/bert.sh rename to examples/papers/osdi22ae/bert.sh diff --git a/scripts/osdi22ae/candle_uno.sh b/examples/papers/osdi22ae/candle_uno.sh similarity index 100% rename from scripts/osdi22ae/candle_uno.sh rename to examples/papers/osdi22ae/candle_uno.sh diff --git a/scripts/osdi22ae/dlrm.sh b/examples/papers/osdi22ae/dlrm.sh similarity index 100% rename from scripts/osdi22ae/dlrm.sh rename to examples/papers/osdi22ae/dlrm.sh diff --git a/scripts/osdi22ae/inception.sh b/examples/papers/osdi22ae/inception.sh similarity index 100% rename from scripts/osdi22ae/inception.sh rename to examples/papers/osdi22ae/inception.sh diff --git a/scripts/osdi22ae/mlp.sh b/examples/papers/osdi22ae/mlp.sh similarity index 100% rename from scripts/osdi22ae/mlp.sh rename to examples/papers/osdi22ae/mlp.sh diff --git a/scripts/osdi22ae/resnext-50.sh b/examples/papers/osdi22ae/resnext-50.sh similarity index 100% rename from scripts/osdi22ae/resnext-50.sh rename to examples/papers/osdi22ae/resnext-50.sh diff --git a/scripts/osdi22ae/xdl.sh b/examples/papers/osdi22ae/xdl.sh similarity index 100% rename from scripts/osdi22ae/xdl.sh rename to examples/papers/osdi22ae/xdl.sh diff --git a/lib/kernels/CMakeLists.txt b/lib/kernels/CMakeLists.txt index 59c7d44b60..f68ab0fb2c 100644 --- a/lib/kernels/CMakeLists.txt +++ b/lib/kernels/CMakeLists.txt @@ -1,37 +1,20 @@ -set(project_target kernels) - -project(${project_target} - LANGUAGES CUDA) - -file(GLOB_RECURSE SRC - CONFIGURE_DEPENDS - LIST_DIRECTORIES False - src/*.cc) - -add_library( - ${project_target} - SHARED - ${SRC} -) -target_include_directories( - ${project_target} - PRIVATE - src/cuda/ - PUBLIC +ff_add_library( + NAME + kernels + SRC_PATTERNS + src/*.cc + PUBLIC_INCLUDE include/ -) -target_link_libraries( - ${project_target} - op-attrs - cuda - cudnn - nccl -) - -define_ff_vars(kernels) - + PRIVATE_INCLUDE + src/ + DEPS + op-attrs + cuda + cudnn + nccl +) set_target_properties( - ${project_target} + kernels PROPERTIES CUDA_STANDARD 11 ) diff --git a/lib/op-attrs/CMakeLists.txt b/lib/op-attrs/CMakeLists.txt index 778be53d7c..23d8bfabf5 100644 --- a/lib/op-attrs/CMakeLists.txt +++ b/lib/op-attrs/CMakeLists.txt @@ -11,4 +11,5 @@ ff_add_library( utils ) +add_subdirectory(test) add_subdirectory(ffi) diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 5fd067313e..4e3daefc7f 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -66,65 +66,24 @@ using SharedOperatorAttrs = variant; - -CHECK_VALID_OP_ATTR(AggregateAttrs); - -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); +CHECK_WELL_BEHAVED_VALUE_TYPE(SharedOperatorAttrs); +CHECK_FMTABLE(SharedOperatorAttrs); using ParallelOperatorAttrs = variant; +CHECK_WELL_BEHAVED_VALUE_TYPE(ParallelOperatorAttrs); +CHECK_FMTABLE(ParallelOperatorAttrs); using ComputationGraphAttrs = variant_join>; using CompGraphOperatorAttrs = ComputationGraphAttrs; +CHECK_WELL_BEHAVED_VALUE_TYPE(ComputationGraphAttrs); +CHECK_FMTABLE(ComputationGraphAttrs); using PCGOperatorAttrs = variant_join; - -static_assert(is_equal_comparable::value, - "ComputationGraphAttrs must support =="); -static_assert(elements_satisfy::value, - ""); -static_assert(is_neq_comparable::value, - "ComputationGraphAttrs must support !="); -static_assert(is_lt_comparable::value, - "ComputationGraphAttrs must support <"); -static_assert(is_hashable::value, - "ComputationGraphAttrs must be hashable"); - -static_assert(is_equal_comparable::value, - "PCGOperatorAttrs must support =="); -static_assert(is_neq_comparable::value, - "PCGOperatorAttrs must support !="); -static_assert(is_lt_comparable::value, - "PCGOperatorAttrs must support <"); -static_assert(is_hashable::value, - "PCGOperatorAttrs must be hashable"); +CHECK_WELL_BEHAVED_VALUE_TYPE(PCGOperatorAttrs); +CHECK_FMTABLE(PCGOperatorAttrs); /* OperatorType get_op_type(CompGraphOperatorAttrs const &); */ /* OperatorType get_op_type(PCGOperatorAttrs const &); */ diff --git a/lib/op-attrs/include/op-attrs/ops/aggregate.h b/lib/op-attrs/include/op-attrs/ops/aggregate.h index faf16472b1..381bd8344f 100644 --- a/lib/op-attrs/include/op-attrs/ops/aggregate.h +++ b/lib/op-attrs/include/op-attrs/ops/aggregate.h @@ -12,10 +12,11 @@ namespace FlexFlow { struct AggregateAttrs { - req n; + int n; req lambda_bal; }; FF_VISITABLE_STRUCT(AggregateAttrs, n, lambda_bal); +FF_VISIT_FMTABLE(AggregateAttrs); DataType get_datatype(AggregateAttrs const &); bool is_valid(AggregateAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/aggregate_spec.h b/lib/op-attrs/include/op-attrs/ops/aggregate_spec.h index 8373452dfa..0e717db2ca 100644 --- a/lib/op-attrs/include/op-attrs/ops/aggregate_spec.h +++ b/lib/op-attrs/include/op-attrs/ops/aggregate_spec.h @@ -8,10 +8,11 @@ namespace FlexFlow { struct AggregateSpecAttrs { - req n; + int n; req lambda_bal; }; FF_VISITABLE_STRUCT(AggregateSpecAttrs, n, lambda_bal); +FF_VISIT_FMTABLE(AggregateSpecAttrs); ParallelTensorShape get_output_shape(AggregateSpecAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index ec3e592607..198c76deda 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -8,9 +8,14 @@ namespace FlexFlow { struct MultiHeadAttentionAttrs { - req embed_dim, num_heads, kdim, vdim; - req dropout; - req bias, add_bias_kv, add_zero_attn; + int embed_dim; + int num_heads; + int kdim; + int vdim; + float dropout; + bool bias; + bool add_bias_kv; + req add_zero_attn; }; FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, embed_dim, @@ -21,6 +26,7 @@ FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, bias, add_bias_kv, add_zero_attn); +FF_VISIT_FMTABLE(MultiHeadAttentionAttrs); template struct MultiHeadAttentionInputs diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..96325d4373 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -8,9 +8,11 @@ namespace FlexFlow { struct BatchMatmulAttrs { - req a_seq_length_dim, b_seq_length_dim; + int a_seq_length_dim; + req b_seq_length_dim; }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); +FF_VISIT_FMTABLE(BatchMatmulAttrs); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 4ec823d4ae..ab168557d1 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -11,6 +11,7 @@ struct BatchNormAttrs { req relu; }; FF_VISITABLE_STRUCT(BatchNormAttrs, relu); +FF_VISIT_FMTABLE(BatchNormAttrs); ParallelTensorShape get_output_shape(BatchNormAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 433bf23241..9fd96505d5 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -11,6 +11,7 @@ struct BroadcastAttrs { req> target_dims; }; FF_VISITABLE_STRUCT(BroadcastAttrs, target_dims); +FF_VISIT_FMTABLE(BroadcastAttrs); CHECK_VALID_OP_ATTR(BroadcastAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 63563f8df8..6cb46979be 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -12,6 +12,7 @@ struct CastAttrs { req dtype; }; FF_VISITABLE_STRUCT(CastAttrs, dtype); +FF_VISIT_FMTABLE(CastAttrs); CHECK_VALID_OP_ATTR(CastAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index deaba9e093..92ca3e9a8e 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -13,6 +13,8 @@ struct CombineAttrs { req combine_degree; }; FF_VISITABLE_STRUCT(CombineAttrs, combine_dim, combine_degree); +FF_VISIT_FMTABLE(CombineAttrs); + CHECK_VALID_OP_ATTR(CombineAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index b9bd14a231..0f400c4a3b 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -12,6 +12,8 @@ struct ConcatAttrs { ff_dim_t axis; }; FF_VISITABLE_STRUCT(ConcatAttrs, axis); +FF_VISIT_FMTABLE(ConcatAttrs); + CHECK_VALID_OP_ATTR(ConcatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 3034dc8c62..728c883e19 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -27,6 +27,8 @@ FF_VISITABLE_STRUCT(Conv2DAttrs, groups, activation, use_bias); +FF_VISIT_FMTABLE(Conv2DAttrs); + CHECK_VALID_OP_ATTR(Conv2DAttrs); TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/core.h b/lib/op-attrs/include/op-attrs/ops/core.h index 611b53def5..14dddcdcfc 100644 --- a/lib/op-attrs/include/op-attrs/ops/core.h +++ b/lib/op-attrs/include/op-attrs/ops/core.h @@ -5,7 +5,9 @@ namespace FlexFlow { -#define CHECK_VALID_OP_ATTR(TYPENAME) CHECK_WELL_BEHAVED_VALUE_TYPE(TYPENAME) +#define CHECK_VALID_OP_ATTR(TYPENAME) \ + CHECK_WELL_BEHAVED_VALUE_TYPE(TYPENAME); \ + CHECK_FMTABLE(TYPENAME) template using is_valid_opattr = is_well_behaved_value_type; diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 8e0049f526..073eea1713 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -12,6 +12,8 @@ struct DropoutAttrs { req seed; }; FF_VISITABLE_STRUCT(DropoutAttrs, rate, seed); +FF_VISIT_FMTABLE(DropoutAttrs); + CHECK_VALID_OP_ATTR(DropoutAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index c4a096166d..eabaf9e6e0 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -10,9 +10,9 @@ namespace FlexFlow { struct ElementBinaryAttrs { - req type; - req compute_type; - req should_broadcast_lhs; + Op type; + DataType compute_type; + bool should_broadcast_lhs; req should_broadcast_rhs; }; FF_VISITABLE_STRUCT(ElementBinaryAttrs, @@ -20,6 +20,8 @@ FF_VISITABLE_STRUCT(ElementBinaryAttrs, compute_type, should_broadcast_lhs, should_broadcast_rhs); +FF_VISIT_FMTABLE(ElementBinaryAttrs); + CHECK_VALID_OP_ATTR(ElementBinaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 1b72e83cb5..92b56e24bc 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -9,17 +9,20 @@ namespace FlexFlow { struct ElementScalarUnaryAttrs { - req op; - /* bool inplace; */ + Op op; req scalar; }; FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op, scalar); +FF_VISIT_FMTABLE(ElementScalarUnaryAttrs); + CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); struct ElementUnaryAttrs { req op; }; FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); +FF_VISIT_FMTABLE(ElementUnaryAttrs); + CHECK_VALID_OP_ATTR(ElementUnaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 8b00fa22ce..8838ff05d0 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -15,40 +15,22 @@ enum class AggregateOp { AVG, }; +std::string format_as(AggregateOp const &); +CHECK_FMTABLE(AggregateOp); + struct EmbeddingAttrs { - req num_entries, out_channels; - req aggr; + int num_entries; + int out_channels; + AggregateOp aggr; req data_type; }; FF_VISITABLE_STRUCT(EmbeddingAttrs, num_entries, out_channels, aggr, data_type); +FF_VISIT_FMTABLE(EmbeddingAttrs); + CHECK_VALID_OP_ATTR(EmbeddingAttrs); TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); } // namespace FlexFlow -namespace fmt { - -template <> -struct formatter<::FlexFlow::AggregateOp> : formatter { - template - auto format(::FlexFlow::AggregateOp o, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (o) { - case AggregateOp::SUM: - name = "Sum"; - break; - case AggregateOp::AVG: - name = "Avg"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 706689199d..705f42ae3d 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -9,6 +9,8 @@ namespace FlexFlow { struct FlatAttrs {}; FF_VISITABLE_STRUCT(FlatAttrs); +FF_VISIT_FMTABLE(FlatAttrs); + CHECK_VALID_OP_ATTR(FlatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index ca2406ef75..c546edb923 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -12,6 +12,8 @@ struct GatherAttrs { ff_dim_t dim; }; FF_VISITABLE_STRUCT(GatherAttrs, dim); +FF_VISIT_FMTABLE(GatherAttrs); + CHECK_VALID_OP_ATTR(GatherAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/groupby.h b/lib/op-attrs/include/op-attrs/ops/groupby.h index 174c40242e..f4232a7a41 100644 --- a/lib/op-attrs/include/op-attrs/ops/groupby.h +++ b/lib/op-attrs/include/op-attrs/ops/groupby.h @@ -8,10 +8,12 @@ namespace FlexFlow { struct Group_byAttrs { - req n; + int n; req alpha; }; FF_VISITABLE_STRUCT(Group_byAttrs, n, alpha); +FF_VISIT_FMTABLE(Group_byAttrs); + CHECK_VALID_OP_ATTR(Group_byAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 26c486c9ac..1eaef3b9d3 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -8,6 +8,8 @@ namespace FlexFlow { struct InputAttrs {}; FF_VISITABLE_STRUCT(InputAttrs); +FF_VISIT_FMTABLE(InputAttrs); + CHECK_VALID_OP_ATTR(InputAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index dab055b2c9..271eed6ffe 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -10,10 +10,12 @@ namespace FlexFlow { struct LayerNormAttrs { stack_vector axes; - req elementwise_affine; + bool elementwise_affine; req eps; }; FF_VISITABLE_STRUCT(LayerNormAttrs, axes, elementwise_affine, eps); +FF_VISIT_FMTABLE(LayerNormAttrs); + CHECK_VALID_OP_ATTR(LayerNormAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 3be8be2040..251873aed5 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -13,25 +13,28 @@ struct L1RegularizerAttrs { req lambda; }; FF_VISITABLE_STRUCT(L1RegularizerAttrs, lambda); +FF_VISIT_FMTABLE(L1RegularizerAttrs); CHECK_VALID_OP_ATTR(L1RegularizerAttrs); struct L2RegularizerAttrs { req lambda; }; FF_VISITABLE_STRUCT(L2RegularizerAttrs, lambda); +FF_VISIT_FMTABLE(L2RegularizerAttrs); CHECK_VALID_OP_ATTR(L2RegularizerAttrs); using RegularizerAttrs = variant; struct LinearAttrs { - req out_channels; - req use_bias; - req data_type; - req activation; + int out_channels; + bool use_bias; + DataType data_type; + Activation activation; req> regularizer; }; FF_VISITABLE_STRUCT( LinearAttrs, out_channels, use_bias, data_type, activation, regularizer); +FF_VISIT_FMTABLE(LinearAttrs); CHECK_VALID_OP_ATTR(LinearAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions.h index 7a3db05329..47462c3bc6 100644 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions.h +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions.h @@ -16,22 +16,28 @@ enum class LossFunction { IDENTITY }; +std::string format_as(LossFunction const &); +CHECK_FMTABLE(LossFunction); + LossFunction parse_loss_function_name(std::string const &); struct SparseCategoricalCrossEntropyLossAttrs { req replace_labels; // for aggregate_spec: More predictions than labels }; FF_VISITABLE_STRUCT(SparseCategoricalCrossEntropyLossAttrs, replace_labels); +FF_VISIT_FMTABLE(SparseCategoricalCrossEntropyLossAttrs); CHECK_VALID_OP_ATTR(SparseCategoricalCrossEntropyLossAttrs); struct OtherLossAttrs { req loss_type; }; FF_VISITABLE_STRUCT(OtherLossAttrs, loss_type); +FF_VISIT_FMTABLE(OtherLossAttrs); CHECK_VALID_OP_ATTR(OtherLossAttrs); using LossAttrs = variant; +CHECK_FMTABLE(LossAttrs); LossFunction get_loss_function(OtherLossAttrs const &); LossFunction get_loss_function(SparseCategoricalCrossEntropyLossAttrs const &); @@ -39,37 +45,4 @@ LossFunction get_loss_function(LossAttrs const &); } // namespace FlexFlow -namespace fmt { - -template <> -struct formatter<::FlexFlow::LossFunction> : formatter { - template - auto format(::FlexFlow::LossFunction d, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (d) { - case LossFunction::CATEGORICAL_CROSSENTROPY: - name = "CategoricalCrossEntropy"; - break; - case LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY: - name = "SparseCategoricalCrossEntropy"; - break; - case LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE: - name = "MeanSquaredErrorAvgReduce"; - break; - case LossFunction::MEAN_SQUARED_ERROR_SUM_REDUCE: - name = "MeanSquaredErrorSumReduce"; - break; - case LossFunction::IDENTITY: - name = "Identity"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index 658e1b7d98..e9a64fab36 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -8,6 +8,8 @@ namespace FlexFlow { struct NoopAttrs {}; FF_VISITABLE_STRUCT(NoopAttrs); +FF_VISIT_FMTABLE(NoopAttrs); + CHECK_VALID_OP_ATTR(NoopAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index efe29b3b2e..2e9e903411 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -13,9 +13,16 @@ enum class PoolOp { AVG, }; +std::string format_as(PoolOp); + struct Pool2DAttrs { - req kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w; - req pool_type; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int padding_h; + int padding_w; + PoolOp pool_type; req activation; }; FF_VISITABLE_STRUCT(Pool2DAttrs, @@ -27,32 +34,10 @@ FF_VISITABLE_STRUCT(Pool2DAttrs, padding_w, pool_type, activation); +FF_VISIT_FMTABLE(Pool2DAttrs); + CHECK_VALID_OP_ATTR(Pool2DAttrs); } // namespace FlexFlow -namespace fmt { - -template <> -struct formatter<::FlexFlow::PoolOp> : formatter { - template - auto format(::FlexFlow::PoolOp o, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (o) { - case PoolOp::AVG: - name = "Avg"; - break; - case PoolOp::MAX: - name = "Max"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 193d3b0dc8..24d8a78a54 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -12,10 +12,12 @@ namespace FlexFlow { struct ReduceAttrs { stack_vector axes; - req op_type; + Op op_type; req keepdims; }; FF_VISITABLE_STRUCT(ReduceAttrs, axes, op_type, keepdims); +FF_VISIT_FMTABLE(ReduceAttrs); + CHECK_VALID_OP_ATTR(ReduceAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index f848f879fc..8f03051d58 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -13,6 +13,8 @@ struct ReductionAttrs { req reduction_degree; }; FF_VISITABLE_STRUCT(ReductionAttrs, reduction_dim, reduction_degree); +FF_VISIT_FMTABLE(ReductionAttrs); + CHECK_VALID_OP_ATTR(ReductionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 83c4ae870b..acffb12337 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -13,6 +13,8 @@ struct RepartitionAttrs { req repartition_degree; }; FF_VISITABLE_STRUCT(RepartitionAttrs, repartition_dim, repartition_degree); +FF_VISIT_FMTABLE(RepartitionAttrs); + CHECK_VALID_OP_ATTR(RepartitionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 92e64a4120..d4d15bb54a 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -13,6 +13,8 @@ struct ReplicateAttrs { req replicate_degree; }; FF_VISITABLE_STRUCT(ReplicateAttrs, replicate_dim, replicate_degree); +FF_VISIT_FMTABLE(ReplicateAttrs); + CHECK_VALID_OP_ATTR(ReplicateAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index b118482a2b..930c76b47a 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -11,6 +11,8 @@ struct ReshapeAttrs { TensorShape shape; }; FF_VISITABLE_STRUCT(ReshapeAttrs, shape); +FF_VISIT_FMTABLE(ReshapeAttrs); + CHECK_VALID_OP_ATTR(ReshapeAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 6030285f14..1f1918aa50 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -11,6 +11,8 @@ struct ReverseAttrs { ff_dim_t axis; }; FF_VISITABLE_STRUCT(ReverseAttrs, axis); +FF_VISIT_FMTABLE(ReverseAttrs); + CHECK_VALID_OP_ATTR(ReverseAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 9a776737f5..1ce475b217 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -12,6 +12,8 @@ struct SoftmaxAttrs { ff_dim_t dim; }; FF_VISITABLE_STRUCT(SoftmaxAttrs, dim); +FF_VISIT_FMTABLE(SoftmaxAttrs); + CHECK_VALID_OP_ATTR(SoftmaxAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index fa66bc46f5..f21fa6a2bf 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -9,10 +9,12 @@ namespace FlexFlow { struct SplitAttrs { - req> splits; + stack_vector splits; ff_dim_t axis; }; FF_VISITABLE_STRUCT(SplitAttrs, splits, axis); +FF_VISIT_FMTABLE(SplitAttrs); + CHECK_VALID_OP_ATTR(SplitAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index 413855913c..8b74d04895 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -8,10 +8,12 @@ namespace FlexFlow { struct TopKAttrs { - req k; + int k; req sorted; }; FF_VISITABLE_STRUCT(TopKAttrs, k, sorted); +FF_VISIT_FMTABLE(TopKAttrs); + CHECK_VALID_OP_ATTR(TopKAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 87db435979..60b5c2c853 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -12,6 +12,8 @@ struct TransposeAttrs { req> perm; }; FF_VISITABLE_STRUCT(TransposeAttrs, perm); +FF_VISIT_FMTABLE(TransposeAttrs); + CHECK_VALID_OP_ATTR(TransposeAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/embedding.cc index 02cbfaa031..538ae1ff52 100644 --- a/lib/op-attrs/src/embedding.cc +++ b/lib/op-attrs/src/embedding.cc @@ -1,3 +1,17 @@ #include "op-attrs/ops/embedding.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +std::string format_as(AggregateOp const &o) { + switch (o) { + case AggregateOp::SUM: + return "Sum"; + case AggregateOp::AVG: + return "Avg"; + default: + throw mk_runtime_error("Unknown AggregateOp with value {}", + static_cast(o)); + } +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/loss_functions.cc b/lib/op-attrs/src/loss_functions.cc index 8b95722f7c..ca878f9ade 100644 --- a/lib/op-attrs/src/loss_functions.cc +++ b/lib/op-attrs/src/loss_functions.cc @@ -5,6 +5,23 @@ namespace FlexFlow { +std::string format_as(LossFunction const &d) { + switch (d) { + case LossFunction::CATEGORICAL_CROSSENTROPY: + return "CategoricalCrossEntropy"; + case LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY: + return "SparseCategoricalCrossEntropy"; + case LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE: + return "MeanSquaredErrorAvgReduce"; + case LossFunction::MEAN_SQUARED_ERROR_SUM_REDUCE: + return "MeanSquaredErrorSumReduce"; + case LossFunction::IDENTITY: + return "Identity"; + default: + throw mk_runtime_error("Unknown LossFunction with value {}", static_cast(d)); + } +} + LossFunction get_loss_type(OtherLossAttrs const &attrs) { return attrs.loss_type; } @@ -42,4 +59,6 @@ LossFunction parse_loss_name(std::string const &raw_name) { } } + + } // namespace FlexFlow diff --git a/lib/op-attrs/src/pool_2d.cc b/lib/op-attrs/src/pool_2d.cc index 0867aeb344..800f0f09d0 100644 --- a/lib/op-attrs/src/pool_2d.cc +++ b/lib/op-attrs/src/pool_2d.cc @@ -4,6 +4,17 @@ namespace FlexFlow { +std::string format_as(PoolOp o) { + switch (o) { + case PoolOp::AVG: + return "Avg"; + case PoolOp::MAX: + return "Max"; + default: + throw mk_runtime_error("Unknown PoolOp value {}", static_cast(o)); + } +} + namespace Input { constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4; diff --git a/lib/utils/test/common/CMakeLists.txt b/lib/op-attrs/test/CMakeLists.txt similarity index 52% rename from lib/utils/test/common/CMakeLists.txt rename to lib/op-attrs/test/CMakeLists.txt index 18681af9c6..15b420ec48 100644 --- a/lib/utils/test/common/CMakeLists.txt +++ b/lib/op-attrs/test/CMakeLists.txt @@ -1,14 +1,11 @@ -ff_add_library( +ff_add_test_executable( NAME - utils-test-common + op-attrs-test SRC_PATTERNS src/*.cc - PUBLIC_INCLUDE - include/ PRIVATE_INCLUDE src/ DEPS - utils - rapidcheck - doctest + op-attrs + utils-test-common ) diff --git a/lib/op-attrs/test/src/test_fmt.cc b/lib/op-attrs/test/src/test_fmt.cc new file mode 100644 index 0000000000..442ca69c54 --- /dev/null +++ b/lib/op-attrs/test/src/test_fmt.cc @@ -0,0 +1,10 @@ +#include "test/utils/all.h" +#include "op-attrs/operator_attrs.h" + +using namespace FlexFlow; + +TEST_CASE("ComputationGraphAttrs is fmtable") { + rc::dc_check([](ComputationGraphAttrs const &a) { + CHECK(fmt::to_string(a) != ""); + }); +} diff --git a/lib/pcg/CMakeLists.txt b/lib/pcg/CMakeLists.txt index 81009b0f1f..d064d95b12 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -1,12 +1,6 @@ ff_add_library( NAME pcg - SRC_PATTERNS - src/*.cc - PUBLIC_INCLUDE - include/ - PRIVATE_INCLUDE - src/ DEPS op-attrs utils diff --git a/lib/pcg/include/pcg/file_format/v1/data_type.h b/lib/pcg/include/pcg/file_format/v1/data_type.h index dad98e462d..73b2e6e82b 100644 --- a/lib/pcg/include/pcg/file_format/v1/data_type.h +++ b/lib/pcg/include/pcg/file_format/v1/data_type.h @@ -45,7 +45,7 @@ struct adl_serializer { namespace FlexFlow { static_assert(is_jsonable::value, ""); -static_assert(is_json_serializable::value, ""); +CHECK_AUTO_JSON_SERIALIZABLE(V1DataTypeValue); static_assert(is_json_deserializable::value, ""); static_assert(is_jsonable::value, ""); } // namespace FlexFlow diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index ac23248db6..9d2ce4ed03 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -1,23 +1,31 @@ -ff_add_library( - NAME - utils - SRC_PATTERNS - src/*.cc - PUBLIC_INCLUDE - include/ - PRIVATE_INCLUDE - src/ - DEPS - optional - expected - variant - visit_struct - fmt - invoke - json - any - cuda -) - -add_subdirectory(ffi) -add_subdirectory(test) +add_subdirectory(algorithms) +add_subdirectory(backports) +add_subdirectory(bidict) +# add_subdirectory(compile_time_sequence) +add_subdirectory(deduplicated_priority_queue) +add_subdirectory(disjoint_set) +# add_subdirectory(dot_format) +add_subdirectory(ff_exceptions) +add_subdirectory(ffi_helper) +add_subdirectory(fmt_extra) +add_subdirectory(fp16) +# add_subdirectory(graph) +add_subdirectory(hash_extra) +add_subdirectory(json_extra) +add_subdirectory(optional_extra) +add_subdirectory(overload) +add_subdirectory(preprocessor_extra) +add_subdirectory(random_extra) +add_subdirectory(rapidcheck_extra) +add_subdirectory(smart_ptrs) +add_subdirectory(stack_containers) +add_subdirectory(string_extra) +add_subdirectory(strong_typedef) +add_subdirectory(test_types) +add_subdirectory(testing) +add_subdirectory(tuple_extra) +add_subdirectory(type_index_extra) +add_subdirectory(type_list) +add_subdirectory(type_traits_extra) +add_subdirectory(variant_extra) +add_subdirectory(visitable) diff --git a/lib/utils/README.md b/lib/utils/README.md index 59140912d9..35d88a79ba 100644 --- a/lib/utils/README.md +++ b/lib/utils/README.md @@ -1,8 +1,6 @@ -# utils +# `visitable` -## `visitable` - -### Motivation +## Motivation FlexFlow's codebase makes heavy use of "plain old data"[^2] types[^1] (referred to as _product types_ in the rest of this document) such as the following: ```cpp @@ -246,7 +244,7 @@ FlexFlow's codebase contains tens if not hundreds of these product types, and so [^1]: aka product types, aka Haskell's `data`. Essentially types that are just a tuple of fields with names. [^2]: by "plain old data" we refer to the general idea behind [C++'s POD](https://en.cppreference.com/w/cpp/named_req/PODType), but not its exact definition -### Adding new `visitable` types +## Adding new `visitable` types FlexFlow's `visitable` support provides an easy way to express product types, and prevents any of the bugs listed above. To express the above definition of `Person` using `visitable`, we would write the following code: @@ -272,7 +270,7 @@ and any subset of the fields would raise an error at compile time. Without any a [^3]: The full list of properties is detailed in [Macros Details](#macro-reference) -### Limitations +## Limitations `visitable` types have two primary limitations. First, they do not support initialization with `(...)`: ```cpp @@ -294,7 +292,7 @@ FF_VISITABLE_STRUCT(MyInts, list1, list2); // CORRECT ``` A smaller limitation is that `FF_VISITABLE_STRUCT` only works from within the `FlexFlow` namespace (this is not much of an issue as all of the `FlexFlow` code resides in a single namespace). -### Advanced Features +## Advanced Features While `FF_VISITABLE_STRUCT` matches the behavior of many product types in FlexFlow's codebase, there are exceptions. Many of these resemble the code below: ```cpp @@ -395,11 +393,11 @@ struct TownPopulation { }; ``` -#### JSON Serialization +### JSON Serialization TODO -### Macro Reference +## Macro Reference The properties that are checked by each macro are as follows: @@ -427,20 +425,6 @@ The properties that are checked by each macro are as follows: [^4]: This is usually resolved by either wrapping the last field in a `req` or using `FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION` -### Internals +## Internals TODO - -## `stack_vector`, `stack_string`, `stack_map` - -## `strong_typedef` - -## `containers.h` - -## `graph` - -## `bidict` - -## `type_traits` - -## `test_types` diff --git a/lib/utils/algorithms/CMakeLists.txt b/lib/utils/algorithms/CMakeLists.txt new file mode 100644 index 0000000000..c416080912 --- /dev/null +++ b/lib/utils/algorithms/CMakeLists.txt @@ -0,0 +1,6 @@ +ff_add_library( + NAME utils-algorithms + DEPS + utils-type_traits_extra + utils-backports +) diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/algorithms/include/utils/algorithms/containers.decl.h similarity index 90% rename from lib/utils/include/utils/containers.decl.h rename to lib/utils/algorithms/include/utils/algorithms/containers.decl.h index 8ad65a4488..a33719a242 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/algorithms/include/utils/algorithms/containers.decl.h @@ -3,8 +3,10 @@ #include "utils/bidict.h" #include "utils/invoke.h" -#include "utils/optional.decl" +#include "utils/optional.decl.h" #include "utils/required_core.h" +#include "utils/sequence.decl.h" +#include "utils/type_traits.decl.h" #include "utils/type_traits_core.h" #include #include @@ -30,6 +32,10 @@ std::string template std::string join_strings(Container const &c, std::string const &delimiter); +template +std::string + join_strings(Container const &c, std::string const &delimiter, F const &f); + template typename Container::const_iterator find(Container const &c, typename Container::value_type const &e); @@ -223,6 +229,16 @@ template ()(std::declval()))> std::vector transform(std::vector const &v, F const &f); +template +auto transform(std::tuple const &tup, F const &f); + +template +void for_each(std::tuple const &tup, F const &f); + +template +enable_if_t, std::vector> + to_vector(std::tuple const &tup); + template auto transform(req const &c, F const &f) -> decltype(transform(std::declval(), std::declval())); @@ -250,10 +266,15 @@ std::vector count(size_t n); template ()( - std::declval()))::value_type> + typename Out = get_element_type_t>> std::vector flatmap(std::vector const &v, F const &f); +template >, + std::string>>> +std::string flatmap(std::string const &v, F const &f); + template >> @@ -266,8 +287,11 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, template void inplace_sorted_by(C &c, F const &f); -template -std::vector sorted_by(C const &c, F const &f); +template +std::vector> sorted_by(C const &c, F const &f); + +template +std::vector> sorted(C const &c); template std::function compare_by(F const &f); diff --git a/lib/utils/algorithms/include/utils/algorithms/containers.h b/lib/utils/algorithms/include/utils/algorithms/containers.h new file mode 100644 index 0000000000..8721535928 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/containers.h @@ -0,0 +1,415 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL + +#include "bidict.h" +#include "containers.decl.h" +#include "invoke.h" +#include "required_core.h" +#include "type_traits_core.h" +#include "utils/exception.h" +#include "utils/fmt.h" +#include "utils/optional.h" +#include "utils/sequence.h" +#include "utils/type_traits.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "utils/type_traits.h" + +namespace FlexFlow { + +template +typename Container::const_iterator + find(Container const &c, typename Container::value_type const &e) { + return std::find(c.cbegin(), c.cend(), e); +} + +template +bool contains(Container const &c, typename Container::value_type const &e) { + return find(c, e) != c.cend(); +} + +template +std::unordered_map map_keys(std::unordered_map const &m, + F const &f) { + std::unordered_map result; + for (auto const &kv : m) { + result.insert({f(kv.first), kv.second}); + } + return result; +} + +template +std::unordered_map filter_keys(std::unordered_map const &m, + F const &f) { + std::unordered_map result; + for (auto const &kv : m) { + if (f(kv.first)) { + result.insert(kv); + } + } + return result; +} + +template +bidict filter_values(bidict const &m, F const &f) { + std::unordered_map result; + for (auto const &kv : m) { + if (f(kv.second)) { + result.equate(kv); + } + } + return result; +} + +template +std::unordered_map map_values(std::unordered_map const &m, + F const &f) { + std::unordered_map result; + for (auto const &kv : m) { + result.insert({kv.first, f(kv.second)}); + } + return result; +} + +template +std::unordered_map filter_values(std::unordered_map const &m, + F const &f) { + std::unordered_map result; + for (auto const &kv : m) { + if (f(kv.second)) { + result.insert(kv); + } + } + return result; +} + +template +bool is_submap(std::unordered_map const &m, + std::unordered_map const &sub) { + return restrict_keys(m, keys(sub)) == sub; +} + +template +std::unordered_set keys(C const &c) { + std::unordered_set result; + for (auto const &kv : c) { + result.insert(kv.first); + } + return result; +} + +template +std::vector values(C const &c) { + std::vector result; + for (auto const &kv : c) { + result.push_back(kv.second); + } + return result; +} + +template +std::unordered_set> + items(C const &c) { + return {c.begin(), c.end()}; +} + +template +std::unordered_set unique(C const &c) { + return {c.cbegin(), c.cend()}; +} + +template +std::unordered_set without_order(C const &c) { + return unique(c); +} + +template +tl::optional index_of(Container const &c, Element const &e) { + auto it = std::find(c.cbegin(), c.cend(), e); + if (it == c.cend()) { + return tl::nullopt; + } else { + return std::distance(c.cbegin(), it); + } +} + +template +std::unordered_set intersection(std::unordered_set const &l, + std::unordered_set const &r) { + std::unordered_set result; + for (T const &ll : l) { + if (contains(r, ll)) { + result.insert(ll); + } + } + return result; +} + +template +optional intersection(C const &c) { + optional result; + for (T const &t : c) { + result = intersection(result.value_or(t), t); + } + + return result; +} + +template +bool are_disjoint(std::unordered_set const &l, + std::unordered_set const &r) { + return intersection(l, r).empty(); +} + +template +std::unordered_map restrict_keys(std::unordered_map const &m, + std::unordered_set const &mask) { + std::unordered_map result; + for (auto const &kv : m) { + if (contains(mask, kv.first)) { + result.insert(kv); + } + } + return result; +} + +template +std::unordered_map merge_maps(std::unordered_map const &lhs, + std::unordered_map const &rhs) { + assert(are_disjoint(keys(lhs), keys(rhs))); + + std::unordered_map result; + for (auto const &kv : lhs) { + result.insert(kv); + } + for (auto const &kv : rhs) { + result.insert(kv); + } + + return result; +} + +template +std::unordered_map generate_map(C const &c, F const &f) { + static_assert(is_hashable::value, + "Key type should be hashable (but is not)"); + + auto transformed = transform(c, [&](K const &k) -> std::pair { + return {k, f(k)}; + }); + return {transformed.cbegin(), transformed.cend()}; +} + +template +std::function lookup_in(std::unordered_map const &m) { + return [&m](K const &k) -> V { return m.at(k); }; +} + +template +std::unordered_set set_union(C const &sets) { + std::unordered_set result; + for (std::unordered_set const &s : sets) { + for (T const &element : s) { + result.insert(element); + } + } + return result; +} + + +template +bool is_supserseteq_of(std::unordered_set const &l, + std::unordered_set const &r) { + return is_subseteq_of(r, l); +} + +template +optional maybe_get_only(C const &c) { + if (c.size() == 1) { + return *c.cbegin(); + } else { + return nullopt; + } +} + +template +typename C::value_type get_only(C const &c) { + return unwrap(maybe_get_only(c), [&] { + throw mk_runtime_error("Encountered container with size {} in get_only", + c.size()); + }); +} + +template +bool all_of(C const &c, F const &f) { + for (auto const &v : c) { + if (!f(v)) { + return false; + } + } + return true; +} + +template +int count(C const &c, F const &f) { + int result = 0; + for (auto const &v : c) { + if (f(v)) { + result++; + } + } + return result; +} + +template +bool are_all_same(C const &c) { + auto const &first = *c.cbegin(); + for (auto const &v : c) { + if (v != first) { + return false; + } + } + return true; +} + +template +std::vector as_vector(C const &c) { + std::vector result(c.cbegin(), c.end()); + return result; +} + +template +std::vector transform(std::vector const &v, F const &f) { + std::vector result; + std::transform(v.cbegin(), v.cend(), std::back_inserter(result), f); + return result; +} + +template +auto transform(std::tuple const &tup, F const &f) { + return seq_transform( + [&](auto idx) { return f(std::get(tup)); }, + seq_enumerate_args_t{}); +} + +template +void for_each(std::tuple const &tup, F const &f) { + seq_for_each([&](auto idx) { f(std::get(tup)); }, + seq_enumerate_args_t{}); +} + +template +enable_if_t, std::vector> + to_vector(std::tuple const &tup) { + std::vector result; + for_each(tup, [&](Head const &h) { result.push_back(h); }); + return result; +} + +template +auto transform(req const &c, F const &f) + -> decltype(transform(std::declval(), std::declval())) { + return transform(static_cast(c), f); +} + +template +std::string transform(std::string const &s, F const &f) { + std::string result; + std::transform(s.cbegin(), s.cend(), std::back_inserter(result), f); + return result; +} + +template +bidict enumerate(std::unordered_set const &c) { + bidict m; + size_t idx = 0; + for (auto const &v : c) { + m.equate(idx++, v); + } + return m; +} + +std::vector count(size_t n); + +template +struct get_element_type { + using type = typename C::value_type; +}; + +template +struct get_element_type> { + using type = T; +}; + +template +using get_element_type_t = typename get_element_type::type; + +template +std::unordered_set flatmap(std::unordered_set const &v, F const &f) { + std::unordered_set result; + for (auto const &elem : v) { + extend(result, f(elem)); + } + return result; +} + +template +std::unordered_set flatmap_v2(std::unordered_set const &v, + std::unordered_set (*f)(In const &)) { + std::unordered_set result; + for (auto const &elem : v) { + extend(result, f(elem)); + } + return result; +} + + +template +C filter(C const &v, F const &f) { + C result(v); + inplace_filter(result, f); + return result; +} + +template +void inplace_filter(C &v, F const &f) { + std::remove_if(v.begin(), v.end(), [&](Elem const &e) { return !f(e); }); +} + +template +typename C::value_type maximum(C const &v) { + return *std::max_element(v.begin(), v.end()); +} + +template +T reversed(T const &t) { + T r; + for (auto i = t.cend() - 1; i >= t.begin(); i--) { + r.push_back(*i); + } + return r; +} + +template +std::vector value_all(std::vector> const &v) { + return transform(v, [](optional const &element) { + return unwrap(element, [] { + throw mk_runtime_error( + "Encountered element without value in call to value_all"); + }); + }); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/extend.h b/lib/utils/algorithms/include/utils/algorithms/extend.h new file mode 100644 index 0000000000..2966b53713 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/extend.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_EXTEND_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_EXTEND_H + +#include +#include +#include +#include "utils/type_traits_extra/is_optional.h" + +namespace FlexFlow { + +template +void extend(C &lhs, std::optional> const &e) { + if (e.has_value()) { + extend(lhs, std::vector{e.value()}); + } +} + +template +void extend(C &lhs, std::nullopt_t) { + return; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/flatmap.h b/lib/utils/algorithms/include/utils/algorithms/flatmap.h new file mode 100644 index 0000000000..0fe00be0ef --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/flatmap.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_FLATMAP_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_FLATMAP_H + +#include +#include +#include + +namespace FlexFlow { + +template +auto flatmap(std::vector const &v, F const &f) { + std::vector> result; + for (In const &elem : v) { + extend(result, f(elem)); + } + return result; +} + +template >, + std::string>>> +std::string flatmap(std::string const &s, F const &f) { + std::ostringstream oss; + for (char c : s) { + oss << f(c); + } + return oss.str(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/generic/contains.h b/lib/utils/algorithms/include/utils/algorithms/generic/contains.h new file mode 100644 index 0000000000..7783a0aee2 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/generic/contains.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_CONTAINS_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_CONTAINS_H + +#include "find.h" + +namespace FlexFlow { + +template +bool contains(Container const &c, typename Container::value_type const &e) { + return find(c, e) != c.cend(); +} + +template +bool contains_key(C const &m, typename C::key_type const &k) { + return m.find(k) != m.end(); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/generic/extended.h b/lib/utils/algorithms/include/utils/algorithms/generic/extended.h new file mode 100644 index 0000000000..bcd89830b8 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/generic/extended.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_EXTENDED_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_EXTENDED_H + +#include "utils/algorithms/type/functor/functor.h" +#include "utils/algorithms/type/monoid/monoid.h" + +namespace FlexFlow { + +template , typename FunctorInstance = default_functor_t> +void extend(T1 &t1, T2 const &t2) { + static_assert(std::is_constructible_v>); + + T1 t2_converted = mconcat(fmap(t2, [](element_type_t const &a) { return T1{a}; })); + mappend_inplace(t1, t2_converted); +} + +template +T1 extended(T1 const &t1, T2 const &t2) { + T1 result = t1; + extend(result, t2); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/generic/find.h b/lib/utils/algorithms/include/utils/algorithms/generic/find.h new file mode 100644 index 0000000000..93c26b07e9 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/generic/find.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_FIND_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_FIND_H + +#include + +namespace FlexFlow { + +template +typename Container::const_iterator + find(Container const &c, typename Container::value_type const &e) { + return std::find(c.cbegin(), c.cend(), e); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/generic/product.h b/lib/utils/algorithms/include/utils/algorithms/generic/product.h new file mode 100644 index 0000000000..a4e8e72d3a --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/generic/product.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_PRODUCT_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_PRODUCT_H + +#include "utils/algorithms/type/monoid/functions/mproduct.h" + +namespace FlexFlow { + +template +auto product(C const &c) { + return mproduct(c); +} + +template +auto product_where(C const &c, F const &f) { + return mproductWhere(c, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/generic/reversed.h b/lib/utils/algorithms/include/utils/algorithms/generic/reversed.h new file mode 100644 index 0000000000..a1a0f6a512 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/generic/reversed.h @@ -0,0 +1,85 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_REVERSED_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_REVERSED_H + +namespace FlexFlow { + +template +struct reversed_container_t { + reversed_container_t() = delete; + reversed_container_t(C const &c) : container(c) {} + + reversed_container_t(reversed_container_t const &) = delete; + reversed_container_t(reversed_container_t &&) = delete; + reversed_container_t &operator=(reversed_container_t const &) = delete; + reversed_container_t &operator=(reversed_container_t &&) = delete; + + using iterator = typename C::reverse_iterator; + using const_iterator = typename C::const_reverse_iterator; + using reverse_iterator = typename C::iterator; + using const_reverse_iterator = typename C::const_iterator; + using value_type = typename C::value_type; + using pointer = typename C::pointer; + using const_pointer = typename C::const_pointer; + using reference = typename C::reference; + using const_reference = typename C::const_reference; + + iterator begin() { + return this->container.rend(); + } + + iterator end() { + return this->container.rbegin(); + } + + const_iterator cbegin() const { + return this->container.crend(); + } + + const_iterator cend() const { + return this->container.crbegin(); + } + + const_iterator begin() const { + return this->cbegin(); + } + + const_iterator end() const { + return this->cend(); + } + + reverse_iterator rbegin() { + return this->container.begin(); + } + + reverse_iterator rend() { + return this->container.end(); + } + + const_reverse_iterator crbegin() const { + return this->container.cbegin(); + } + + const_reverse_iterator crend() const { + return this->container.cend(); + } + + const_reverse_iterator rbegin() const { + return this->crbegin(); + } + + const_reverse_iterator rend() const { + return this->crend(); + } + +private: + C const &container; +}; + +template +reversed_container_t reversed_container(C const &c) { + return reversed_container_t(c); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/generic/sum.h b/lib/utils/algorithms/include/utils/algorithms/generic/sum.h new file mode 100644 index 0000000000..89015ef849 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/generic/sum.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_SUM_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_SUM_H + +#include "utils/algorithms/type/monoid/functions/msum.h" + +namespace FlexFlow { + +template +auto sum(C const &c) { + return msum(c); +} + +template +auto sum_where(C const &c, F const &f) { + return msumWhere(c, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/generic/transform.h b/lib/utils/algorithms/include/utils/algorithms/generic/transform.h new file mode 100644 index 0000000000..775f6c6315 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/generic/transform.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_GENERIC_TRANSFORM_H + +#include "utils/algorithms/type/functor/functor.h" + +namespace FlexFlow { + +template +auto transform(C const &c, F const &f) { + return fmap(c, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/set/difference.h b/lib/utils/algorithms/include/utils/algorithms/set/difference.h new file mode 100644 index 0000000000..aa790ca6a1 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/set/difference.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_DIFFERENCE_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_DIFFERENCE_H + +#include + +namespace FlexFlow { + +template +std::unordered_set set_difference(std::unordered_set const &l, + std::unordered_set const &r) { + return filter(l, [&](T const &element) { return !contains(r, element); }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/set/extend.h b/lib/utils/algorithms/include/utils/algorithms/set/extend.h new file mode 100644 index 0000000000..0b3675f831 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/set/extend.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_EXTEND_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_EXTEND_H + +#include +#include "utils/type_traits_extra/is_optional.h" + +namespace FlexFlow { + +template >> +void extend(std::unordered_set &lhs, C const &rhs) { + lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end())); + lhs.insert(rhs.cbegin(), rhs.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/set/filter.h b/lib/utils/algorithms/include/utils/algorithms/set/filter.h new file mode 100644 index 0000000000..8db9381551 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/set/filter.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_FILTER_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_FILTER_H + +#include + +namespace FlexFlow { + +template +std::unordered_set filter(std::unordered_set const &v, F const &f) { + std::unordered_set result; + for (T const &t : v) { + if (f(t)) { + result.insert(t); + } + } + return result; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/set/get_first.h b/lib/utils/algorithms/include/utils/algorithms/set/get_first.h new file mode 100644 index 0000000000..213bc507a6 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/set/get_first.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_GET_FIRST_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_GET_FIRST_H + +namespace FlexFlow { + +template +T get_first(std::unordered_set const &s) { + return *s.cbegin(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/set/is_subseteq.h b/lib/utils/algorithms/include/utils/algorithms/set/is_subseteq.h new file mode 100644 index 0000000000..91967c65f3 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/set/is_subseteq.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_IS_SUBSETEQ_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_IS_SUBSETEQ_H + +#include + +namespace FlexFlow { + +template +bool is_subseteq_of(std::unordered_set const &l, + std::unordered_set const &r) { + if (l.size() > r.size()) { + return false; + } + + for (auto const &ll : l) { + if (!contains(r, ll)) { + return false; + } + } + return true; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/set/union.h b/lib/utils/algorithms/include/utils/algorithms/set/union.h new file mode 100644 index 0000000000..bf8de9f7b8 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/set/union.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_UNION_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_SET_UNION_H + +#include + +namespace FlexFlow { + +template +std::unordered_set set_union(std::unordered_set const &l, + std::unordered_set const &r) { + std::unordered_set result = l; + result.insert(r.cbegin(), r.cend()); + return result; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/sorting.h b/lib/utils/algorithms/include/utils/algorithms/sorting.h new file mode 100644 index 0000000000..1ec2b0e5e6 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/sorting.h @@ -0,0 +1,62 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_ALGORITHMS_SORTING_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_ALGORITHMS_SORTING_H + +#include "utils/type_traits_extra/iterator.h" +#include +#include "utils/backports/type_identity.h" +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct sorted_elem_type : type_identity {}; + +template +struct sorted_elem_type> : type_identity> {}; + +template +struct sorted_elem_type> : type_identity> {}; + +template +using sorted_elem_type_t = typename sorted_elem_type::type; + +template +void inplace_sorted_by(C &c, F const &f) { + CHECK_SUPPORTS_ITERATOR_TAG(std::random_access_iterator_tag, C); + + using Elem = sorted_elem_type_t; + + auto custom_comparator = [&](Elem const &lhs, Elem const &rhs) -> bool { + return f(lhs, rhs); + }; + std::sort(c.begin(), c.end(), custom_comparator); +} + +template +std::vector> sorted_by(C const &c, F const &f) { + using Elem = sorted_elem_type_t; + + std::vector result(c.begin(), c.end()); + inplace_sorted_by(result, f); + return result; +} + +template +std::vector> sorted(C const &c) { + using Elem = sorted_elem_type_t; + + return sorted_by(c, + [](Elem const &lhs, Elem const &rhs) { return lhs < rhs; }); +} + +template +std::function compare_by(F const &f) { + return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; +} + +} + +#endif diff --git a/lib/utils/algorithms/include/utils/algorithms/type/array.h b/lib/utils/algorithms/include/utils/algorithms/type/array.h new file mode 100644 index 0000000000..9970b8664e --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/type/array.h @@ -0,0 +1 @@ +#include "utils/algorithms/type/functor/instances/array.h" diff --git a/lib/utils/algorithms/include/utils/algorithms/type/list.h b/lib/utils/algorithms/include/utils/algorithms/type/list.h new file mode 100644 index 0000000000..5fe22362df --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/type/list.h @@ -0,0 +1,2 @@ +#include "utils/algorithms/type/monoid/instances/list.h" +#include "utils/algorithms/type/functor/instances/list.h" diff --git a/lib/utils/algorithms/include/utils/algorithms/type/optional.h b/lib/utils/algorithms/include/utils/algorithms/type/optional.h new file mode 100644 index 0000000000..22d02df89b --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/type/optional.h @@ -0,0 +1,2 @@ +#include "utils/algorithms/typeclass/functor/instances/optional.h" +#include "utils/algorithms/typeclass/monoid/instances/optional.h" diff --git a/lib/utils/algorithms/include/utils/algorithms/type/unordered_map.h b/lib/utils/algorithms/include/utils/algorithms/type/unordered_map.h new file mode 100644 index 0000000000..f2605545e1 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/type/unordered_map.h @@ -0,0 +1,2 @@ +#include "utils/algorithms/typeclass/functor/instances/unordered_map.h" +#include "utils/algorithms/typeclass/monoid/instances/unordered_map.h" diff --git a/lib/utils/algorithms/include/utils/algorithms/type/unordered_set.h b/lib/utils/algorithms/include/utils/algorithms/type/unordered_set.h new file mode 100644 index 0000000000..a1e3f38e7e --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/type/unordered_set.h @@ -0,0 +1,2 @@ +#include "utils/algorithms/typeclass/functor/instances/unordered_set.h" +#include "utils/algorithms/typeclass/monoid/instances/unordered_set.h" diff --git a/lib/utils/algorithms/include/utils/algorithms/type/vector.h b/lib/utils/algorithms/include/utils/algorithms/type/vector.h new file mode 100644 index 0000000000..46b98d4f78 --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/type/vector.h @@ -0,0 +1,2 @@ +#include "utils/algorithms/typeclass/monoid/instances/vector.h" +#include "utils/algorithms/typeclass/functor/instances/vector.h" diff --git a/lib/utils/algorithms/include/utils/algorithms/typeclass/functor/functor.h b/lib/utils/algorithms/include/utils/algorithms/typeclass/functor/functor.h new file mode 100644 index 0000000000..d8e2fe03ce --- /dev/null +++ b/lib/utils/algorithms/include/utils/algorithms/typeclass/functor/functor.h @@ -0,0 +1,64 @@ +#ifndef _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_TYPE_FUNCTOR_FUNCTOR_H +#define _FLEXFLOW_LIB_UTILS_ALGORITHMS_INCLUDE_UTILS_ALGORITHMS_TYPE_FUNCTOR_FUNCTOR_H + +#include +#include +#include "utils/backports/type_identity.h" + +namespace FlexFlow { + +struct opaque_input_type_t { }; +struct opaque_output_type_t { }; + + + +template < + typename Instance, + typename B = opaque_output_type_t, + typename F_B = typename Instance::template F, + typename A = typename Instance::A, + typename F_A = typename Instance::template F> +struct is_valid_functor_instance + : std::is_invocable< + decltype(Instance::template fmap>), + F_A, + std::function + > { }; + +template +inline constexpr bool is_valid_functor_instance_v = is_valid_functor_instance::value; + +/* template */ +/* struct is_valid_functor_instance2 */ +/* : std::is_same< */ +/* std::invoke_result_t), T const &, std::function>, */ +/* T */ +/* > { }; */ + +// template