diff --git a/conda/recipes/libcuvs/recipe.yaml b/conda/recipes/libcuvs/recipe.yaml index b192f2af5f..7f7a180049 100644 --- a/conda/recipes/libcuvs/recipe.yaml +++ b/conda/recipes/libcuvs/recipe.yaml @@ -73,6 +73,7 @@ cache: - ${{ stdlib("c") }} host: - libnvjitlink-dev + - cuda-nvrtc-dev - librmm =${{ minor_version }} - libraft-headers =${{ minor_version }} - nccl ${{ nccl_version }} @@ -122,6 +123,7 @@ outputs: - libcusolver-dev - libcusparse-dev - libnvjitlink-dev + - cuda-nvrtc-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - libraft-headers =${{ minor_version }} @@ -184,6 +186,7 @@ outputs: - libcusolver-dev - libcusparse-dev - libnvjitlink-dev + - cuda-nvrtc-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - ${{ pin_subpackage("libcuvs-headers", exact=True) }} @@ -245,6 +248,7 @@ outputs: - libcusolver-dev - libcusparse-dev - libnvjitlink-dev + - cuda-nvrtc-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - ${{ pin_subpackage("libcuvs-headers", exact=True) }} @@ -327,6 +331,10 @@ outputs: - librmm - mkl - nccl + - if: cuda_major == "13" + then: + - cuda-nvrtc + - libnvjitlink about: homepage: ${{ load_from_file("python/libcuvs/pyproject.toml").project.urls.Homepage }} license: ${{ load_from_file("python/libcuvs/pyproject.toml").project.license }} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index b33daa635a..3bcd877190 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -223,6 +223,9 @@ if(NOT BUILD_CPU_ONLY) "$" "$" ) + target_compile_definitions( + cuvs_cpp_headers INTERFACE $<$:CUVS_ENABLE_JIT_LTO> + ) target_link_libraries(cuvs_cpp_headers INTERFACE raft::raft rmm::rmm) generate_inst_matrix( @@ -340,6 +343,13 @@ if(NOT BUILD_CPU_ONLY) CUDA_SEPARABLE_COMPILATION ON POSITION_INDEPENDENT_CODE ON ) + target_compile_definitions( + cuvs-cagra-search + PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> + $<$:NVTX_ENABLED> + # Temporary: mirror 
cuvs_cpp_headers so JIT sources always see the macro when LTO is on + $<$:CUVS_ENABLE_JIT_LTO> + ) target_link_libraries( cuvs-cagra-search PRIVATE cuvs::cuvs_cpp_headers $ @@ -382,6 +392,7 @@ if(NOT BUILD_CPU_ONLY) if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) set(JIT_LTO_TARGET_ARCHITECTURE "75-real") endif() + set(JIT_LTO_COMPILATION ON) # Generate interleaved scan kernel files at build time include(cmake/modules/generate_jit_lto_kernels.cmake) @@ -398,6 +409,10 @@ if(NOT BUILD_CPU_ONLY) ) target_compile_features(jit_lto_kernel_usage_requirements INTERFACE cuda_std_20) target_link_libraries(jit_lto_kernel_usage_requirements INTERFACE rmm::rmm raft::raft CCCL::CCCL) + # Kernel OBJECT targets (add_jit_lto_kernel) only pull usage requirements from here, not from + # cuvs_objs — they must see the same JIT/LTO preprocessor branch as host code (e.g. + # standard_dataset_descriptor_t<..., QueryT> in compute_distance_standard-impl.cuh). + target_compile_definitions(jit_lto_kernel_usage_requirements INTERFACE CUVS_ENABLE_JIT_LTO) block(PROPAGATE jit_lto_files) set(jit_lto_files) @@ -427,7 +442,7 @@ if(NOT BUILD_CPU_ONLY) "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_kernel.cu.in" FRAGMENT_TAG_FORMAT "${ivf_flat_ns}::fragment_tag_interleaved_scan<${ivf_flat_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, @capacity@, @ascending_value@>" - FRAGMENT_TAG_HEADER_FILES "" + FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/interleaved_scan" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -442,7 +457,7 @@ if(NOT BUILD_CPU_ONLY) "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_kernel.cu.in" FRAGMENT_TAG_FORMAT "${ivf_flat_ns}::fragment_tag_load_and_compute_dist<${ivf_flat_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, @compute_norm_value@, 
@veclen@>" - FRAGMENT_TAG_HEADER_FILES "" + FRAGMENT_TAG_HEADER_FILES "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/load_and_compute_dist" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) @@ -455,7 +470,7 @@ if(NOT BUILD_CPU_ONLY) "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/detail/jit_lto_kernels/metric_kernel.cu.in" FRAGMENT_TAG_FORMAT "${ivf_flat_ns}::fragment_tag_metric<${ivf_flat_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, ${ivf_flat_ns}::tag_metric_@metric_name@, @veclen@>" - FRAGMENT_TAG_HEADER_FILES "" + FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/metric" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -469,7 +484,7 @@ if(NOT BUILD_CPU_ONLY) "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/detail/jit_lto_kernels/filter_kernel.cu.in" FRAGMENT_TAG_FORMAT "${ivf_flat_ns}::fragment_tag_filter<${neighbors_ns}::tag_index_@index_abbrev@, ${neighbors_ns}::tag_filter_@filter_name@>" - FRAGMENT_TAG_HEADER_FILES "" + FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/filter" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements @@ -483,7 +498,7 @@ if(NOT BUILD_CPU_ONLY) "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/detail/jit_lto_kernels/post_process_kernel.cu.in" FRAGMENT_TAG_FORMAT "${ivf_flat_ns}::fragment_tag_post_lambda<${ivf_flat_ns}::tag_post_process_@post_process_name@>" - FRAGMENT_TAG_HEADER_FILES "" + FRAGMENT_TAG_HEADER_FILES "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/post_process" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) @@ -633,6 +648,157 @@ if(NOT BUILD_CPU_ONLY) OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_pq/increment_score" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) + set(cagra_ns "cuvs::neighbors::cagra::detail") + generate_jit_lto_kernels( + jit_lto_files + 
NAME_FORMAT + "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_setup_workspace<${cagra_ns}::tag_@data_abbrev@, ${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${cagra_ns}::tag_@query_abbrev@, ${cagra_ns}::@jit_fragment_codebook_tag@, ${cagra_ns}::tag_team_@team_size@, ${cagra_ns}::tag_dim_@dataset_block_dim@, ${cagra_ns}::tag_pq_bits_@pq_bits@, ${cagra_ns}::tag_pq_len_@pq_len@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/setup_workspace" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_compute_distance<${cagra_ns}::tag_@data_abbrev@, ${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${cagra_ns}::tag_@query_abbrev@, ${cagra_ns}::@jit_fragment_codebook_tag@, ${cagra_ns}::tag_team_@team_size@, ${cagra_ns}::tag_dim_@dataset_block_dim@, ${cagra_ns}::tag_pq_bits_@pq_bits@, ${cagra_ns}::tag_pq_len_@pq_len@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance" + 
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_dist_op_@metric_tag@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_dist_op<${cagra_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${cagra_ns}::@jit_metric_tag@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/dist_op" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_apply_normalization_standard_@norm_kind@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_apply_normalization_standard<${cagra_ns}::tag_@data_abbrev@, ${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${cagra_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_team_@team_size@, ${cagra_ns}::tag_dim_@dataset_block_dim@, ${cagra_ns}::tag_norm_@norm_kind@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/apply_normalization_standard" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_search_single_cta_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_data_@data_abbrev@" + MATRIX_JSON_FILE + 
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_search_single_cta<${cagra_ns}::tag_@data_abbrev@, ${cagra_ns}::tag_idx_@source_index_abbrev@, ${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, @topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_single_cta" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_search_single_cta_p_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_search_single_cta_p<${cagra_ns}::tag_@data_abbrev@, ${cagra_ns}::tag_idx_@source_index_abbrev@, ${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, @topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_single_cta_p" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_search_multi_cta_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in" + FRAGMENT_TAG_FORMAT + 
"${cagra_ns}::fragment_tag_search_multi_cta<${cagra_ns}::tag_@data_abbrev@, ${cagra_ns}::tag_idx_@source_index_abbrev@, ${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_multi_cta" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_random_pickup_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_random_pickup<${cagra_ns}::tag_@data_abbrev@, ${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/random_pickup" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_compute_distance_to_child_nodes_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_compute_distance_to_child_nodes<${cagra_ns}::tag_@data_abbrev@, ${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${cagra_ns}::tag_idx_@source_index_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance_to_child_nodes" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_apply_filter" + 
MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_apply_filter_kernel<${cagra_ns}::tag_idx_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${cagra_ns}::tag_idx_@source_index_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/apply_filter" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_sample_filter_@filter_name@_index_@source_index_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/filter_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/jit_lto_kernels/filter_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${neighbors_ns}::fragment_tag_sample_filter<${neighbors_ns}::tag_bitset_@bitset_abbrev@, ${cagra_ns}::tag_idx_@source_index_abbrev@, ${neighbors_ns}::tag_filter_@filter_name@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/filter" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) endblock() # Note that this matrix contains an `arch_includes` placeholder, since we don't currently have a @@ -944,8 +1110,11 @@ if(NOT BUILD_CPU_ONLY) ) target_compile_definitions( - cuvs_objs PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> - $<$:NVTX_ENABLED> + cuvs_objs + PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> + $<$:NVTX_ENABLED> + # Temporary: mirror cuvs_cpp_headers so JIT sources always see the macro when LTO is on + $<$:CUVS_ENABLE_JIT_LTO> ) target_link_libraries( @@ -956,6 +1125,7 @@ if(NOT BUILD_CPU_ONLY) ${CUVS_CTK_MATH_DEPENDENCIES} $ $ + $<$:CUDA::nvrtc> ) target_include_directories( diff --git a/cpp/cmake/modules/generate_jit_lto_kernels.cmake 
b/cpp/cmake/modules/generate_jit_lto_kernels.cmake index fc2166c734..da8982e10c 100644 --- a/cpp/cmake/modules/generate_jit_lto_kernels.cmake +++ b/cpp/cmake/modules/generate_jit_lto_kernels.cmake @@ -48,6 +48,19 @@ function(process_jit_lto_matrix_entry source_list_var) cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN}) populate_matrix_variables("${_JIT_LTO_MATRIX_JSON_ENTRY}") + + # Registration fragment tags (FRAGMENT_TAG_FORMAT) can track GPU template params in .cu.in; map + # those scalars to tag_* names here so matrix JSON does not duplicate tag strings. + if(DEFINED codebook_type) + if(codebook_type STREQUAL "void") + set(jit_fragment_codebook_tag "tag_codebook_none") + elseif(codebook_type STREQUAL "half") + set(jit_fragment_codebook_tag "tag_codebook_half") + else() + message(FATAL_ERROR "Unknown codebook_type for JIT fragment: ${codebook_type}") + endif() + endif() + string(CONFIGURE "${_JIT_LTO_NAME_FORMAT}" kernel_name @ONLY) string(CONFIGURE "${_JIT_LTO_FRAGMENT_TAG_FORMAT}" fragment_tag @ONLY) diff --git a/cpp/cmake/modules/register_fatbin.cpp.in b/cpp/cmake/modules/register_fatbin.cpp.in index b154132358..5881eff623 100644 --- a/cpp/cmake/modules/register_fatbin.cpp.in +++ b/cpp/cmake/modules/register_fatbin.cpp.in @@ -5,6 +5,7 @@ #include "@fatbin_header_file@" #include +#include @fragment_tag_header_files@ diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp index ae975630c5..b6a27cf68b 100644 --- a/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp @@ -37,10 +37,25 @@ struct AlgorithmLauncher { this->call(stream, grid, block, shared_mem, kernel_args); } + template + void dispatch_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, Args&&... 
args) + { + static_assert( + std::is_same_v...)>, + "dispatch_cooperative() argument types do not match the kernel function signature FuncT"); + + void* kernel_args[] = {const_cast(static_cast(&args))...}; + this->call_cooperative(stream, grid, block, shared_mem, kernel_args); + } + + cudaLibrary_t get_library() { return this->library; } cudaKernel_t get_kernel() { return this->kernel; } private: void call(cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args); + void call_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args); cudaKernel_t kernel; cudaLibrary_t library; }; diff --git a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp new file mode 100644 index 0000000000..8837c345b7 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp @@ -0,0 +1,74 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::cagra::detail { + +template +struct fragment_tag_setup_workspace {}; + +template +struct fragment_tag_compute_distance {}; + +template +struct fragment_tag_dist_op {}; + +template +struct fragment_tag_apply_normalization_standard {}; + +template +struct fragment_tag_search_single_cta {}; + +template +struct fragment_tag_search_single_cta_p {}; + +template +struct fragment_tag_search_multi_cta {}; + +template +struct fragment_tag_random_pickup {}; + +template +struct fragment_tag_compute_distance_to_child_nodes {}; + +template +struct fragment_tag_apply_filter_kernel {}; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp deleted file mode 100644 index 526d094d5c..0000000000 --- 
a/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -namespace cuvs::neighbors::ivf_flat::detail { - -// Tag types for data types -struct tag_f {}; -struct tag_h {}; -struct tag_i8 {}; -struct tag_u8 {}; - -// Tag types for accumulator types -struct tag_acc_f {}; -struct tag_acc_h {}; -struct tag_acc_i32 {}; -struct tag_acc_u32 {}; - -// Tag types for distance metrics with full template info -struct tag_metric_euclidean {}; -struct tag_metric_inner_product {}; -struct tag_metric_custom_udf {}; - -// Tag types for post-processing -struct tag_post_process_identity {}; -struct tag_post_process_sqrt {}; -struct tag_post_process_compose {}; - -template -struct fragment_tag_interleaved_scan {}; - -template -struct fragment_tag_load_and_compute_dist {}; - -template -struct fragment_tag_metric {}; - -template -struct fragment_tag_filter {}; - -template -struct fragment_tag_post_lambda {}; - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/detail/jit_lto/registration_tags.hpp b/cpp/include/cuvs/detail/jit_lto/registration_tags.hpp new file mode 100644 index 0000000000..2215f1e804 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/registration_tags.hpp @@ -0,0 +1,98 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::detail::jit_lto { + +struct tag_f {}; +struct tag_h {}; +struct tag_sc {}; +struct tag_uc {}; +struct tag_idx_l {}; +struct tag_filter_none {}; +struct tag_filter_bitset {}; + +} // namespace cuvs::detail::jit_lto + +namespace cuvs::neighbors::cagra::detail { + +using cuvs::detail::jit_lto::tag_f; +using cuvs::detail::jit_lto::tag_filter_bitset; +using cuvs::detail::jit_lto::tag_filter_none; +using cuvs::detail::jit_lto::tag_h; +using cuvs::detail::jit_lto::tag_idx_l; +using cuvs::detail::jit_lto::tag_sc; +using cuvs::detail::jit_lto::tag_uc; + +struct tag_idx_ui {}; +struct tag_dist_f {}; +struct tag_metric_l2 {}; +struct tag_metric_inner_product {}; +struct tag_metric_cosine {}; +struct tag_metric_hamming {}; +struct tag_team_8 {}; +struct tag_team_16 {}; +struct tag_team_32 {}; +struct tag_dim_128 {}; +struct tag_dim_256 {}; +struct tag_dim_512 {}; +struct tag_pq_bits_0 {}; +struct tag_pq_bits_8 {}; +struct tag_pq_len_0 {}; +struct tag_pq_len_2 {}; +struct tag_pq_len_4 {}; +struct tag_codebook_none {}; +struct tag_codebook_half {}; +struct tag_metric_l1 {}; +struct tag_norm_noop {}; +struct tag_norm_cosine {}; + +} // namespace cuvs::neighbors::cagra::detail + +namespace cuvs::neighbors::ivf_flat::detail { + +using cuvs::detail::jit_lto::tag_f; +using cuvs::detail::jit_lto::tag_filter_bitset; +using cuvs::detail::jit_lto::tag_filter_none; +using cuvs::detail::jit_lto::tag_h; +using cuvs::detail::jit_lto::tag_idx_l; +using cuvs::detail::jit_lto::tag_sc; +using cuvs::detail::jit_lto::tag_uc; + +struct tag_i8 {}; +struct tag_u8 {}; + +struct tag_acc_f {}; +struct tag_acc_h {}; +struct tag_acc_i32 {}; +struct tag_acc_u32 {}; + +struct tag_metric_euclidean {}; +struct tag_metric_inner_product {}; +struct tag_metric_custom_udf {}; + +struct tag_post_process_identity {}; +struct tag_post_process_sqrt {}; +struct tag_post_process_compose {}; + +template +struct 
fragment_tag_interleaved_scan {}; + +template +struct fragment_tag_load_and_compute_dist {}; + +template +struct fragment_tag_metric {}; + +template +struct fragment_tag_filter {}; + +template +struct fragment_tag_post_lambda {}; + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index f214db295c..722aff3e51 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -171,6 +171,8 @@ struct index : cuvs::neighbors::index { /** Distance metric used for clustering. */ cuvs::distance::DistanceType metric() const noexcept; + void set_metric(cuvs::distance::DistanceType metric); + /** Whether `centers()` change upon extending the index (ivf_flat::extend). */ bool adaptive_centers() const noexcept; diff --git a/cpp/src/detail/jit_lto/AlgorithmLauncher.cpp b/cpp/src/detail/jit_lto/AlgorithmLauncher.cpp index abe519e688..0076f7b959 100644 --- a/cpp/src/detail/jit_lto/AlgorithmLauncher.cpp +++ b/cpp/src/detail/jit_lto/AlgorithmLauncher.cpp @@ -24,7 +24,6 @@ AlgorithmLauncher::AlgorithmLauncher(AlgorithmLauncher&& other) noexcept AlgorithmLauncher& AlgorithmLauncher::operator=(AlgorithmLauncher&& other) noexcept { if (this != &other) { - // Unload current library if it exists if (library != nullptr) { cudaLibraryUnload(library); } kernel = other.kernel; library = other.library; @@ -47,3 +46,27 @@ void AlgorithmLauncher::call( RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args)); } + +void AlgorithmLauncher::call_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** kernel_args) +{ + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeCooperative; + attribute[0].val.cooperative = 1; + + cudaLaunchConfig_t config{}; + config.gridDim = grid; + config.blockDim = block; + config.stream = stream; + config.dynamicSmemBytes = shared_mem; + config.numAttrs = 1; + config.attrs = 
attribute; + + RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args)); +} + +std::unordered_map>& get_cached_launchers() +{ + static std::unordered_map> launchers; + return launchers; +} diff --git a/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp index 0556199ade..95a07bed44 100644 --- a/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp +++ b/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp @@ -73,8 +73,8 @@ std::shared_ptr AlgorithmPlanner::build() // Load the generated LTO IR and link them together nvJitLinkHandle handle; - const char* lopts[] = {"-lto", archs.c_str()}; - auto result = nvJitLinkCreate(&handle, 2, lopts); + const char* lopts[] = {"-lto", archs.c_str(), "-maxrregcount=64"}; + auto result = nvJitLinkCreate(&handle, 3, lopts); check_nvjitlink_result(handle, result); for (const auto& frag : this->fragments) { diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index f1650980e0..bca8d3314d 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -32,6 +32,7 @@ #include #include +// All includes are done before opening namespace to avoid nested namespace issues namespace cuvs::neighbors::cagra::detail { template ; using init_f = @@ -270,10 +278,21 @@ struct dataset_descriptor_host { }; template - dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init) + dataset_descriptor_host(const DescriptorImpl& dd_host, + InitF init, + cuvs::distance::DistanceType metric_val, + uint32_t dataset_block_dim_val, + bool is_vpq_val = false, + uint32_t pq_bits_val = 0, + uint32_t pq_len_val = 0) : value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, - team_size{dd_host.team_size()} + team_size{dd_host.team_size()}, + metric{metric_val}, + dataset_block_dim{dataset_block_dim_val}, + is_vpq{is_vpq_val}, + pq_bits{pq_bits_val}, + 
pq_len{pq_len_val} { } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh index 05adce20e9..e6dcd910cb 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh @@ -13,7 +13,40 @@ #include namespace cuvs::neighbors::cagra::detail { + +#if defined(CUVS_ENABLE_JIT_LTO) || defined(BUILD_KERNEL) + +// When JIT LTO is enabled or building kernel fragments, dist_op is an extern function that gets JIT +// linked from fragments Each fragment provides a metric-specific implementation (L2Expanded, +// InnerProduct, etc.) The planner will link the appropriate fragment based on the metric Note: +// extern functions cannot be constexpr, so we remove constexpr here Note: These are in the detail +// namespace (not anonymous) so they can be found by JIT linking +// QueryT can be float (for most metrics) or uint8_t (for BitwiseHamming) +template +extern __device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b); + +// Normalization is also JIT linked from fragments (no-op for most metrics, cosine normalization for +// CosineExpanded) The planner will link the appropriate fragment (cosine or noop) based on the +// metric +// QueryT is needed to match the descriptor template signature (always float for normalization) +template +extern __device__ DistanceT apply_normalization_standard( + DistanceT distance, + const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index); + +#endif + namespace { + +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + +// When JIT LTO is disabled, dist_op is a template function with Metric as a template parameter template requires(Metric == cuvs::distance::DistanceType::L2Expanded) RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) @@ -46,18 +79,33 @@ RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) 
DISTANCE_T diff = a - b; return raft::abs(diff); } + +#endif // #if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) } // namespace -template + typename DistanceT +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + , + cuvs::distance::DistanceType Metric +#else + , + typename QueryT +#endif + > struct standard_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; - using QUERY_T = typename std:: +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + // When JIT LTO is disabled, Metric is a template parameter + using QUERY_T = typename std:: conditional_t; +#else + // When JIT LTO is enabled, QueryT is passed as a template parameter + using QUERY_T = QueryT; +#endif using base_type::args; using base_type::smem_ws_size_in_bytes; using typename base_type::args_t; @@ -67,7 +115,9 @@ struct standard_dataset_descriptor_t : public dataset_descriptor_base_t uint32_t { return sizeof(standard_dataset_descriptor_t) + raft::round_up_safe(dim, DatasetBlockDim) * sizeof(QUERY_T); } + + private: }; template @@ -174,7 +226,6 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_standard( buf[j] = 0; } } - return const_cast(r); } @@ -211,7 +262,6 @@ RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker( if (k >= dim) break; #pragma unroll for (uint32_t v = 0; v < vlen; v++) { - // Note this loop can go above the dataset_dim for padded arrays. This is not a problem // because: // - Above the last element (dataset_dim-1), the query array is filled with zeros. // - The data buffer has to be also padded with zeros. 
@@ -220,8 +270,16 @@ RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker( d, query_smem_ptr + sizeof(QUERY_T) * device::swizzling(k + v)); +#if defined(CUVS_ENABLE_JIT_LTO) || defined(BUILD_KERNEL) + // When JIT LTO is enabled or building kernel fragments, dist_op is an extern function (no + // template parameters) + r += dist_op( + d, cuvs::spatial::knn::detail::utils::mapping{}(data[e][v])); +#else + // When JIT LTO is disabled, dist_op is a template function with Metric parameter r += dist_op( d, cuvs::spatial::knn::detail::utils::mapping{}(data[e][v])); +#endif } } } @@ -238,15 +296,32 @@ _RAFT_DEVICE __noinline__ auto compute_distance_standard( args.dim, args.smem_ws_ptr); +#if defined(CUVS_ENABLE_JIT_LTO) || defined(BUILD_KERNEL) + // Normalization is JIT linked from fragments (no-op or cosine normalization) + // The planner links the appropriate fragment based on the metric + distance = + apply_normalization_standard(distance, args, dataset_index); +#else + // When JIT LTO is disabled, kMetric is always available as a compile-time constant if constexpr (DescriptorT::kMetric == cuvs::distance::DistanceType::CosineExpanded) { const auto* dataset_norms = DescriptorT::dataset_norms_ptr(args); auto norm = dataset_norms[dataset_index]; if (norm > 0) { distance = distance / norm; } } +#endif return distance; } +#ifndef BUILD_KERNEL +// The init kernel is used for both JIT and non-JIT initialization +// When BUILD_KERNEL is defined, we're building a JIT fragment and don't want this kernel. 
+// The kernel handles JIT vs non-JIT via ifdef internally template ; + standard_dataset_descriptor_t; using base_type = typename desc_type::base_type; + + // For CUDA 12 (non-JIT), set the function pointers properly new (out) desc_type(reinterpret_cast( &setup_workspace_standard), reinterpret_cast( @@ -273,8 +352,30 @@ RAFT_KERNEL __launch_bounds__(1, 1) dim, ld, dataset_norms); +#else + // When JIT LTO is enabled, Metric is not a template parameter + using query_t = + std::conditional_t; + using desc_type = + standard_dataset_descriptor_t; + using base_type = typename desc_type::base_type; + + // For JIT, we don't use the function pointers, so set them to nullptr + // The free functions are called directly instead + new (out) desc_type(nullptr, // setup_workspace_impl - not used in JIT + nullptr, // compute_distance_impl - not used in JIT + ptr, + size, + dim, + ld, + dataset_norms); +#endif } +#endif // #ifndef BUILD_KERNEL +#ifndef BUILD_KERNEL +// The init_ function is used for both JIT and non-JIT initialization +// When BUILD_KERNEL is defined, we're building a JIT fragment and don't want this function. 
template ; + using base_type = typename desc_type::base_type; +#else + // When JIT LTO is enabled, Metric is not a template parameter + // QueryT depends on metric: uint8_t for BitwiseHamming, float for others + using query_t = + std::conditional_t; using desc_type = - standard_dataset_descriptor_t; + standard_dataset_descriptor_t; using base_type = typename desc_type::base_type; +#endif RAFT_EXPECTS(Metric != cuvs::distance::DistanceType::CosineExpanded || dataset_norms != nullptr, "Dataset norms must be provided for CosineExpanded metric"); - desc_type dd_host{nullptr, nullptr, ptr, size, dim, ld, dataset_norms}; - return host_type{dd_host, + return host_type{desc_type{nullptr, nullptr, ptr, size, dim, ld, dataset_norms}, [=](dataset_descriptor_base_t* dev_ptr, rmm::cuda_stream_view stream) { + // Use init kernel for both JIT and CUDA 12 + // The kernel handles JIT vs non-JIT via ifdef internally standard_dataset_descriptor_init_kernel <<<1, 1, 0, stream>>>(dev_ptr, ptr, size, dim, ld, dataset_norms); RAFT_CUDA_TRY(cudaPeekAtLastError()); - }}; + }, + Metric, + DatasetBlockDim, + false, // is_vpq + 0, // pq_bits + 0}; // pq_len } +#endif // #ifndef BUILD_KERNEL } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index cdafb173ed..1cb593830d 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -15,19 +15,30 @@ namespace cuvs::neighbors::cagra::detail { -template + typename DistanceT +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + , + cuvs::distance::DistanceType Metric +#else + , + typename QueryT +#endif + > struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; using CODE_BOOK_T = CodebookT; - using QUERY_T = half; +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + using QUERY_T = half; +#else + using QUERY_T = QueryT; +#endif using base_type::args; using base_type::extra_ptr3; using typename base_type::args_t; @@ -37,7 +48,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t uint32_t { /* SMEM workspace layout: @@ -121,6 +134,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t(dim, DatasetBlockDim) * sizeof(QUERY_T); } + + private: }; template @@ -347,6 +362,10 @@ _RAFT_DEVICE __noinline__ auto compute_distance_vpq( args.smem_ws_ptr); } +#ifndef BUILD_KERNEL +// The init kernel is not needed when building JIT fragments (BUILD_KERNEL is defined) +// It's only needed for non-JIT initialization. When BUILD_KERNEL is defined, we're building +// a JIT fragment and don't want this kernel to be instantiated. 
template ; + DistanceT +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + , + Metric +#else + , + half +#endif + >; using base_type = typename desc_type::base_type; +#ifdef CUVS_ENABLE_JIT_LTO + // For JIT, we don't use the function pointers, so set them to nullptr + // The free functions are called directly instead + new (out) desc_type(nullptr, // setup_workspace_impl - not used in JIT + nullptr, // compute_distance_impl - not used in JIT + encoded_dataset_ptr, + encoded_dataset_dim, + vq_code_book_ptr, + pq_code_book_ptr, + size, + dim); +#else + // For CUDA 12 (non-JIT), set the function pointers properly new (out) desc_type( reinterpret_cast(&setup_workspace_vpq), reinterpret_cast(&compute_distance_vpq), @@ -384,8 +423,14 @@ RAFT_KERNEL __launch_bounds__(1, 1) pq_code_book_ptr, size, dim); +#endif } +#endif // #ifndef BUILD_KERNEL +#ifndef BUILD_KERNEL +// The init_ function is not needed when building JIT fragments (BUILD_KERNEL is defined) +// It's only needed for non-JIT initialization. When BUILD_KERNEL is defined, we're building +// a JIT fragment and don't want this host function to be included. 
template ; + DistanceT +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + , + Metric +#else + , + half +#endif + >; using base_type = typename desc_type::base_type; - desc_type dd_host{nullptr, - nullptr, - encoded_dataset_ptr, - encoded_dataset_dim, - vq_code_book_ptr, - pq_code_book_ptr, - size, - dim}; - return host_type{dd_host, + return host_type{desc_type{nullptr, + nullptr, + encoded_dataset_ptr, + encoded_dataset_dim, + vq_code_book_ptr, + pq_code_book_ptr, + size, + dim}, [=](dataset_descriptor_base_t* dev_ptr, rmm::cuda_stream_view stream) { + // Use init kernel for both JIT and CUDA 12 + // The kernel handles JIT vs non-JIT via ifdef internally vpq_dataset_descriptor_init_kernel +#include + +namespace { + +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +extern "C" __global__ void apply_filter_kernel(const source_index_t* const source_indices_ptr, + index_t* const result_indices_ptr, + distance_t* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const index_t query_id_offset, + uint32_t* bitset_ptr, + source_index_t bitset_len, + source_index_t original_nbits) +{ + apply_filter_kernel_jit(source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + bitset_ptr, + bitset_len, + original_nbits); +} + +static_assert(std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json new file mode 100644 index 0000000000..4f14f7d8c0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json @@ -0,0 +1,20 @@ +{ + "_source_index": [ + { + 
"source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_cosine_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_cosine_impl.cuh new file mode 100644 index 0000000000..c691e58ef6 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_cosine_impl.cuh @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" + +namespace cuvs::neighbors::cagra::detail { + +// Cosine normalization fragment implementation +// This provides apply_normalization_standard that normalizes by dataset norm (for CosineExpanded +// metric) +// QueryT is needed to match the descriptor template signature, but not used in this function +template +__device__ DistanceT +apply_normalization_standard(DistanceT distance, + const typename cuvs::neighbors::cagra::detail:: + dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // CosineExpanded normalization: divide by dataset norm + const auto* dataset_norms = + standard_dataset_descriptor_t:: + dataset_norms_ptr(args); + auto norm = dataset_norms[dataset_index]; + if (norm > 0) { distance = distance / norm; } + return distance; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in new file mode 100644 index 0000000000..55a807a0f0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in @@ -0,0 +1,25 @@ +/* + * 
SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using query_t = @query_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +using args_t = typename dataset_descriptor_base_t::args_t; +template __device__ distance_t +apply_normalization_standard<@team_size@, @dataset_block_dim@, data_t, index_t, distance_t, query_t>(distance_t, + const args_t, + index_t); + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json new file mode 100644 index 0000000000..a4e64e5616 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json @@ -0,0 +1,66 @@ +{ + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_normalization": [ + { + "norm_kind": "noop", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "norm_kind": "cosine", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_noop_impl.cuh 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_noop_impl.cuh new file mode 100644 index 0000000000..e9b9bc6556 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_noop_impl.cuh @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" + +namespace cuvs::neighbors::cagra::detail { + +// No-op normalization fragment implementation +// This provides apply_normalization_standard that does nothing (for non-CosineExpanded metrics) +// QueryT is needed to match the descriptor template signature, but not used in this function +template +__device__ DistanceT +apply_normalization_standard(DistanceT distance, + const typename cuvs::neighbors::cagra::detail:: + dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // No normalization needed for non-CosineExpanded metrics + return distance; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp new file mode 100644 index 0000000000..0360a0769f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -0,0 +1,408 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "cagra_jit_launcher_factory.hpp included but CUVS_ENABLE_JIT_LTO not defined!" 
+#endif + +#include "../compute_distance.hpp" +#include "../shared_launcher_jit.hpp" +#include "search_multi_cta_planner.hpp" +#include "search_multi_kernel_planner.hpp" +#include "search_single_cta_planner.hpp" + +#include +#include + +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +/// Build a JIT AlgorithmLauncher for single-CTA CAGRA search (runtime VPQ / metric → tag dispatch). +template +std::shared_ptr make_cagra_single_cta_jit_launcher( + const dataset_descriptor_host& dataset_desc, + bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps, + bool persistent, + const std::string& filter_name) +{ + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + single_cta_search:: + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len, + persistent); + + planner.template add_setup_workspace_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_search_kernel_fragment( + topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + planner.add_sample_filter_device_function(filter_name); + return planner.get_launcher(); + } + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == 
cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + single_cta_search:: + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len, + persistent); + + planner.template add_setup_workspace_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_search_kernel_fragment( + topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + planner.add_sample_filter_device_function(filter_name); + return planner.get_launcher(); + } + using QueryTag = query_type_tag_standard_t; + single_cta_search:: + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len, + persistent); + + planner.template add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_search_kernel_fragment( + topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + planner.add_sample_filter_device_function(filter_name); + return planner.get_launcher(); +} + +/// Build a JIT 
AlgorithmLauncher for multi-CTA CAGRA search. +template +std::shared_ptr make_cagra_multi_cta_jit_launcher( + const dataset_descriptor_host& dataset_desc, + const std::string& filter_name) +{ + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + multi_cta_search:: + CagraMultiCtaSearchPlanner + planner(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.template add_setup_workspace_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_search_multi_cta_kernel_fragment(); + planner.add_sample_filter_device_function(filter_name); + return planner.get_launcher(); + } + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + multi_cta_search:: + CagraMultiCtaSearchPlanner + planner(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.template add_setup_workspace_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function( + dataset_desc.metric, + dataset_desc.team_size, 
+ dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_search_multi_cta_kernel_fragment(); + planner.add_sample_filter_device_function(filter_name); + return planner.get_launcher(); + } + using QueryTag = query_type_tag_standard_t; + multi_cta_search:: + CagraMultiCtaSearchPlanner + planner(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.template add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_search_multi_cta_kernel_fragment(); + planner.add_sample_filter_device_function(filter_name); + return planner.get_launcher(); +} + +/// Build a JIT AlgorithmLauncher for multi-kernel CAGRA helpers (random_pickup, compute_distance, +/// …). 
+template +std::shared_ptr make_cagra_multi_kernel_jit_launcher( + const dataset_descriptor_host& dataset_desc, + const char* linked_kernel_name) +{ + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + multi_kernel_search:: + CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + linked_kernel_name, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_setup_workspace_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_linked_kernel(linked_kernel_name); + return planner.get_launcher(); + } + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + multi_kernel_search:: + CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + linked_kernel_name, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_setup_workspace_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function( + dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + 
dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_linked_kernel(linked_kernel_name); + return planner.get_launcher(); + } + using QueryTag = query_type_tag_standard_t; + multi_kernel_search:: + CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + linked_kernel_name, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.template add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_linked_kernel(linked_kernel_name); + return planner.get_launcher(); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp new file mode 100644 index 0000000000..2b2bf94272 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -0,0 +1,385 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +struct CagraPlannerBase : AlgorithmPlanner { + static inline LauncherJitCache launcher_jit_cache{}; + + explicit CagraPlannerBase(std::string entrypoint) + : AlgorithmPlanner(std::move(entrypoint), launcher_jit_cache) + { + } + + template + void add_setup_workspace_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits, + uint32_t pq_len) + { + (void)metric; + (void)is_vpq; + (void)pq_bits; + auto add = [&]() { + this->add_static_fragment>(); + }; + if constexpr (std::is_same_v) { + if (pq_bits != 0 || pq_len != 0) { + RAFT_FAIL("CAGRA JIT standard path expects pq_bits==0 and pq_len==0"); + } + if (team_size == 8) { + if (dataset_block_dim == 128) { + add.template operator()(); + } else if (dataset_block_dim == 256) { + add.template operator()(); + } else if (dataset_block_dim == 512) { + add.template operator()(); + } + } else if (team_size == 16) { + if (dataset_block_dim == 128) { + add.template operator()(); + } else if (dataset_block_dim == 256) { + add.template operator()(); + } else if (dataset_block_dim == 512) { + add.template operator()(); + } + } else if (team_size == 32) { + if (dataset_block_dim == 128) { + add.template operator()(); + } else if (dataset_block_dim == 256) { + add.template operator()(); + } else if (dataset_block_dim == 512) { + add.template operator()(); + } + } + } else { + if (pq_bits != 8 || (pq_len != 2 && pq_len != 4)) { + RAFT_FAIL("CAGRA JIT VPQ path expects pq_bits==8 and pq_len in {2,4}"); + } + if (team_size == 8) { + if (dataset_block_dim == 128) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 256) { + if (pq_len == 2) { + add.template operator()(); + } else { + 
add.template operator()(); + } + } else if (dataset_block_dim == 512) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } + } else if (team_size == 16) { + if (dataset_block_dim == 128) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 256) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 512) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } + } else if (team_size == 32) { + if (dataset_block_dim == 128) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 256) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 512) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } + } + } + } + + template + void add_compute_distance_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits, + uint32_t pq_len) + { + (void)is_vpq; + if (!is_vpq) { + add_dist_op_device_function(metric); + add_normalization_device_function( + metric, team_size, dataset_block_dim); + } + auto add = [&]() { + this->add_static_fragment>(); + }; + if constexpr (std::is_same_v) { + if (pq_bits != 0 || pq_len != 0) { + RAFT_FAIL("CAGRA JIT standard path expects pq_bits==0 and pq_len==0"); + } + if (team_size == 8) { + if (dataset_block_dim == 128) { + add.template operator()(); + } else if (dataset_block_dim == 256) { + add.template operator()(); + } else if (dataset_block_dim == 512) { + add.template operator()(); + } + } else if (team_size == 16) { + if (dataset_block_dim == 128) { + add.template operator()(); + } else if (dataset_block_dim == 256) { + add.template 
operator()(); + } else if (dataset_block_dim == 512) { + add.template operator()(); + } + } else if (team_size == 32) { + if (dataset_block_dim == 128) { + add.template operator()(); + } else if (dataset_block_dim == 256) { + add.template operator()(); + } else if (dataset_block_dim == 512) { + add.template operator()(); + } + } + } else { + if (pq_bits != 8 || (pq_len != 2 && pq_len != 4)) { + RAFT_FAIL("CAGRA JIT VPQ path expects pq_bits==8 and pq_len in {2,4}"); + } + if (team_size == 8) { + if (dataset_block_dim == 128) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 256) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 512) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } + } else if (team_size == 16) { + if (dataset_block_dim == 128) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 256) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 512) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } + } else if (team_size == 32) { + if (dataset_block_dim == 128) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 256) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } else if (dataset_block_dim == 512) { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + } + } + } + } + + template + void add_dist_op_device_function(cuvs::distance::DistanceType metric) + { + switch (metric) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2Unexpanded: + 
this->add_static_fragment>(); + break; + case cuvs::distance::DistanceType::InnerProduct: + this->add_static_fragment< + fragment_tag_dist_op>(); + break; + case cuvs::distance::DistanceType::CosineExpanded: + this->add_static_fragment< + fragment_tag_dist_op>(); + break; + case cuvs::distance::DistanceType::BitwiseHamming: + this + ->add_static_fragment>(); + break; + case cuvs::distance::DistanceType::L1: + this->add_static_fragment>(); + break; + default: RAFT_FAIL("Unsupported metric for CAGRA JIT dist_op"); + } + } + + // Maps runtime dataset layout (same grid as the JIT matrix) to (TeamTag, DimTag). IVF-style + // planners pass these as template parameters; CAGRA reads team_size / dataset_block_dim from + // the host descriptor at planning time. + template + static void dispatch_cagra_team_dim(uint32_t team_size, uint32_t dataset_block_dim, Lambda&& l) + { + switch (team_size) { + case 8: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()(); return; + case 256: std::forward(l).template operator()(); return; + case 512: std::forward(l).template operator()(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()(); return; + case 256: std::forward(l).template operator()(); return; + case 512: std::forward(l).template operator()(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()(); return; + case 256: std::forward(l).template operator()(); return; + case 512: std::forward(l).template operator()(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported team_size / dataset_block_dim for CAGRA JIT normalization: team=%u dim=%u", + static_cast(team_size), + static_cast(dataset_block_dim)); + } + + template + void add_normalization_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim) + { + auto go = 
[&]() { + dispatch_cagra_team_dim(team_size, dataset_block_dim, [&]() { + this->add_static_fragment>(); + }); + }; + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + go.template operator()(); + } else { + go.template operator()(); + } + } + + void add_sample_filter_device_function(std::string const& filter_name) + { + if (filter_name == "filter_none_source_index_ui") { + this->add_static_fragment>(); + } else if (filter_name == "filter_bitset_source_index_ui") { + this->add_static_fragment>(); + } else { + RAFT_FAIL("Unknown CAGRA sample filter name for JIT: %s", filter_name.c_str()); + } + } +}; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in new file mode 100644 index 0000000000..31a639dc7b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in @@ -0,0 +1,59 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using query_t = @query_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +using args_t = typename dataset_descriptor_base_t::args_t; +template __device__ distance_t +compute_distance<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, data_t, index_t, distance_t, query_t>( + const args_t, index_t); + +template <> +__device__ distance_t compute_distance_base(const args_t args, + index_t dataset_index, + bool valid, + uint32_t team_size_bits) +{ + auto per_thread = valid ? 
compute_distance<@team_size@, + @dataset_block_dim@, + @pq_bits@, + @pq_len@, + @codebook_type@, + data_t, + index_t, + distance_t, + query_t>(args, dataset_index) + : 0; + return device::team_sum(per_thread, team_size_bits); +} + +template <> +__device__ distance_t compute_distance_per_thread_base( + const args_t args, index_t dataset_index) +{ + return compute_distance<@team_size@, + @dataset_block_dim@, + @pq_bits@, + @pq_len@, + @codebook_type@, + data_t, + index_t, + distance_t, + query_t>(args, dataset_index); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json new file mode 100644 index 0000000000..39cf9ad2c5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json @@ -0,0 +1,156 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "0", + "pq_bits": "0", + "pq_prefix": "_standard", + "pq_suffix": "" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_comma": "" + } + ] + }, + { + "_data": [ + { 
+ "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "2", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_len": "4", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_comma": ", " + } + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_standard_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_standard_impl.cuh new file mode 100644 index 0000000000..7aa1a12395 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_standard_impl.cuh @@ -0,0 +1,39 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" +#include "../device_common.hpp" // For dataset_descriptor_base_t + +namespace cuvs::neighbors::cagra::detail { + +// Unified compute_distance implementation for standard descriptors +// This is instantiated when PQ_BITS=0, PQ_LEN=0, CodebookT=void +// QueryT can be float (for most metrics) or uint8_t (for BitwiseHamming) +template +__device__ DistanceT +compute_distance(const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // For standard descriptors, PQ_BITS=0, PQ_LEN=0, CodebookT=void + static_assert(PQ_BITS == 0 && PQ_LEN == 0 && std::is_same_v, + "Standard descriptor requires PQ_BITS=0, PQ_LEN=0, CodebookT=void"); + + // Reconstruct the descriptor type with QueryT and call compute_distance_standard + using desc_t = + standard_dataset_descriptor_t; + return compute_distance_standard(args, dataset_index); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in new file mode 100644 index 0000000000..177ea68c9d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in @@ -0,0 +1,68 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +extern "C" __global__ void compute_distance_to_child_nodes( + const index_t* const parent_node_list, + index_t* const parent_candidates_ptr, + distance_t* const parent_distance_ptr, + const std::size_t lds, + const std::uint32_t search_width, + const dataset_desc_base* dataset_desc, + const index_t* const neighbor_graph_ptr, + const std::uint32_t graph_degree, + const source_index_t* source_indices_ptr, + const data_t* query_ptr, + index_t* const visited_hashmap_ptr, + const std::uint32_t hash_bitlen, + index_t* const result_indices_ptr, + distance_t* const result_distances_ptr, + const std::uint32_t ldd, + cuvs::neighbors::filtering::none_sample_filter sample_filter) +{ + compute_distance_to_child_nodes_kernel_jit( + parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dataset_desc, + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + sample_filter); +} + +static_assert( + std::is_same_v< + decltype(compute_distance_to_child_nodes), + compute_distance_to_child_nodes_jit_func_t>); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json new file mode 100644 index 0000000000..f934f26c11 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json @@ -0,0 +1,13 @@ +[ + 
{ + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "uc"}, + {"data_type": "int8_t", "data_abbrev": "sc"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "ui"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "ui"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl.cuh new file mode 100644 index 0000000000..e21d48a2f1 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl.cuh @@ -0,0 +1,48 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_vpq-impl.cuh" +#include "../device_common.hpp" // For dataset_descriptor_base_t + +namespace cuvs::neighbors::cagra::detail { + +// Unified compute_distance implementation for VPQ descriptors +// This is instantiated when PQ_BITS>0, PQ_LEN>0, CodebookT=half +// QueryT is always half for VPQ +template +__device__ DistanceT +compute_distance(const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // For VPQ descriptors, PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half + static_assert( + PQ_BITS > 0 && PQ_LEN > 0 && std::is_same_v && std::is_same_v, + "VPQ descriptor requires PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half"); + + // Reconstruct the descriptor type and call compute_distance_vpq + // QueryT is always half for VPQ + using desc_t = cagra_q_dataset_descriptor_t; + return compute_distance_vpq(args, dataset_index); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh new file mode 100644 index 0000000000..5146ec8bac --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh @@ -0,0 +1,185 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../device_common.hpp" +#include "../hashmap.hpp" +#include "../utils.hpp" +#include "extern_device_functions.cuh" + +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { +namespace device { + +// Helper to check if DescriptorT has kPqBits (VPQ descriptor) +template +struct has_kpq_bits { + template + static auto test(int) -> decltype(U::kPqBits, std::true_type{}); + template + static std::false_type test(...); + static constexpr bool value = decltype(test(0))::value; +}; + +template +inline constexpr bool has_kpq_bits_v = has_kpq_bits::value; + +// JIT version of compute_distance_to_random_nodes - uses const dataset_descriptor_base_t* (smem) +// Shared between single_cta and multi_cta JIT kernels +template +RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes_jit( + IndexT* __restrict__ result_indices_ptr, // [num_pickup] + DistanceT* __restrict__ result_distances_ptr, // [num_pickup] + const dataset_descriptor_base_t* smem_desc, + const uint32_t num_pickup, + const uint32_t num_distilation, + const uint64_t rand_xor_mask, + const IndexT* __restrict__ seed_ptr, // [num_seeds] + const uint32_t num_seeds, + IndexT* __restrict__ visited_hash_ptr, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hash_ptr, + const uint32_t traversed_hash_bitlen, + const uint32_t block_id = 0, + const uint32_t num_blocks = 1) +{ + constexpr unsigned warp_size = 32; + + uint32_t team_size_bits = smem_desc->team_size_bitshift_from_smem(); + IndexT dataset_size = smem_desc->size; + const auto args_load = smem_desc->args.load(); + + const auto max_i = 
raft::round_up_safe(num_pickup, warp_size >> team_size_bits); + + for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) { + const bool valid_i = (i < num_pickup); + + IndexT best_index_team_local = raft::upper_bound(); + DistanceT best_norm2_team_local = raft::upper_bound(); + for (uint32_t j = 0; j < num_distilation; j++) { + IndexT seed_index = 0; + if (valid_i) { + uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j))); + if (seed_ptr && (gid < num_seeds)) { + seed_index = seed_ptr[gid]; + } else { + seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_size; + } + } + + const auto norm2 = + cuvs::neighbors::cagra::detail::compute_distance_base( + args_load, seed_index, valid_i, team_size_bits); + + if (valid_i && (norm2 < best_norm2_team_local)) { + best_norm2_team_local = norm2; + best_index_team_local = seed_index; + } + } + + const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u); + if (valid_i && lane_id == 0) { + if (best_index_team_local != raft::upper_bound()) { + if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. + best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } else if ((traversed_hash_ptr != nullptr) && + hashmap::search( + traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { + // Deactivate this entry as it has been already used by others. 
+ best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } + } + result_distances_ptr[i] = best_norm2_team_local; + result_indices_ptr[i] = best_index_team_local; + } + } +} + +// JIT version of compute_distance_to_child_nodes - uses const dataset_descriptor_base_t* (smem) +// Shared between single_cta and multi_cta JIT kernels +template +RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes_jit( + IndexT* __restrict__ result_child_indices_ptr, + DistanceT* __restrict__ result_child_distances_ptr, + const dataset_descriptor_base_t* smem_desc, + const IndexT* __restrict__ knn_graph, + const uint32_t knn_k, + IndexT* __restrict__ visited_hashmap_ptr, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hashmap_ptr, + const uint32_t traversed_hash_bitlen, + const IndexT* __restrict__ parent_indices, + const IndexT* __restrict__ internal_topk_list, + const uint32_t search_width, + int* __restrict__ result_position = nullptr, + const int max_result_position = 0) +{ + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + constexpr IndexT invalid_index = ~static_cast(0); + + // Read child indices of parents from knn graph and check if the distance computation is + // necessary. 
+ for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += blockDim.x) { + const IndexT smem_parent_id = parent_indices[i / knn_k]; + IndexT child_id = invalid_index; + if (smem_parent_id != invalid_index) { + const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; + child_id = knn_graph[(i % knn_k) + (static_cast(knn_k) * parent_id)]; + } + if (child_id != invalid_index) { + if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) { + child_id = invalid_index; + } else if ((traversed_hashmap_ptr != nullptr) && + hashmap::search( + traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) { + child_id = invalid_index; + } + } + if (STATIC_RESULT_POSITION) { + result_child_indices_ptr[i] = child_id; + } else if (child_id != invalid_index) { + int j = atomicSub(result_position, 1) - 1; + result_child_indices_ptr[j] = child_id; + } + } + __syncthreads(); + + // Compute the distance to child nodes - same inline pattern as non-JIT (device_common.hpp) + constexpr unsigned warp_size = 32; + + const auto team_size_bits = smem_desc->team_size_bitshift_from_smem(); + const auto num_k = knn_k * search_width; + const auto max_i = raft::round_up_safe(num_k, warp_size >> team_size_bits); + const auto args = smem_desc->args.load(); + const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0; + const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0]; + + for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) { + const auto j = i + ofst; + const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position); + const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index; + + const auto per_thread = + (child_id != invalid_index) + ? cuvs::neighbors::cagra::detail:: + compute_distance_per_thread_base(args, child_id) + : (lead_lane ? 
raft::upper_bound() : 0); + const DistanceT child_dist = device::team_sum(per_thread, team_size_bits); + __syncwarp(); + + // Store the distance + if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; } + } +} + +} // namespace device +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_cosine_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_cosine_impl.cuh new file mode 100644 index 0000000000..ba6c270fa1 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_cosine_impl.cuh @@ -0,0 +1,16 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + return -static_cast(a) * static_cast(b); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_hamming_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_hamming_impl.cuh new file mode 100644 index 0000000000..cd4ed29ac6 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_hamming_impl.cuh @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + const auto v = (a ^ b) & 0xffu; + return __popc(v); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_inner_product_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_inner_product_impl.cuh new file mode 100644 index 0000000000..ba6c270fa1 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_inner_product_impl.cuh @@ -0,0 +1,16 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + return -static_cast(a) * static_cast(b); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in new file mode 100644 index 0000000000..cabc2c792e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in @@ -0,0 +1,21 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +namespace { + +using query_t = @query_type@; +using distance_t = @distance_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +template __device__ distance_t dist_op(query_t, query_t); + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l1_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l1_impl.cuh new file mode 100644 index 0000000000..693a84fddd --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l1_impl.cuh @@ -0,0 +1,19 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + DISTANCE_T diff = a - b; + return raft::abs(diff); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l2_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l2_impl.cuh new file mode 100644 index 0000000000..f74b62b4b0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l2_impl.cuh @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + DISTANCE_T diff = a - b; + return diff * diff; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json new file mode 100644 index 0000000000..9353c268e3 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json @@ -0,0 +1,34 @@ +{ + "_metric": [ + { + "metric_tag": "l2", + "jit_metric_tag": "tag_metric_l2", + "query_type": "float", + "query_abbrev": "f" + }, + { + "metric_tag": "inner_product", + "jit_metric_tag": "tag_metric_inner_product", + "query_type": "float", + "query_abbrev": "f" + }, + { + "metric_tag": "hamming", + "jit_metric_tag": "tag_metric_hamming", + "query_type": "uint8_t", + "query_abbrev": "uc" + }, + { + "metric_tag": "l1", + "jit_metric_tag": "tag_metric_l1", + "query_type": "float", + "query_abbrev": "f" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh new file mode 100644 index 0000000000..9d9c006c29 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh @@ -0,0 +1,62 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance.hpp" +#include + +namespace cuvs::neighbors::cagra::detail { + +template +extern __device__ const dataset_descriptor_base_t* setup_workspace( + const dataset_descriptor_base_t* desc_ptr, + void* smem, + const DataT* queries, + uint32_t query_id); + +template +extern __device__ const dataset_descriptor_base_t* setup_workspace_base( + const dataset_descriptor_base_t*, void*, const DataT*, uint32_t); + +template +extern __device__ DistanceT +compute_distance(const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index); + +template +extern __device__ DistanceT compute_distance_base( + const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index, + bool valid, + uint32_t team_size_bits); + +template +extern __device__ DistanceT compute_distance_per_thread_base( + const typename dataset_descriptor_base_t::args_t, IndexT); +} // namespace cuvs::neighbors::cagra::detail + +namespace cuvs::neighbors::detail { + +template +extern __device__ bool sample_filter(uint32_t query_id, SourceIndexT node_id, void* filter_data); + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/filter_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/filter_matrix.json new file mode 100644 index 0000000000..c4efae0572 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/filter_matrix.json @@ -0,0 +1,21 @@ +{ + "filter_name": ["none", "bitset"], + "_bitset": [ + { + "bitset_type": "uint32_t", + "bitset_abbrev": "u32" + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp new file mode 100644 index 0000000000..164b7a72d9 --- /dev/null 
+++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp @@ -0,0 +1,163 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +#include + +#include "../compute_distance.hpp" // dataset_descriptor_base_t (device_common.hpp alone is not enough) +#include "search_single_cta_device_helpers.cuh" + +namespace cuvs::neighbors::cagra::detail { + +// Function types for extern "C" __global__ JIT entry points — must match cudaLibraryGetKernel / +// AlgorithmLauncher::dispatch signatures exactly (see static_assert in each *_kernel.cu). + +template +using search_single_cta_jit_func_t = + void(uintptr_t, + DistanceT* const, + const std::uint32_t, + const DataT* const, + const IndexT* const, + const std::uint32_t, + const SourceIndexT*, + const unsigned, + const uint64_t, + const IndexT*, + const uint32_t, + IndexT* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + std::uint32_t* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const dataset_descriptor_base_t*, + uint32_t*, + SourceIndexT, + SourceIndexT); + +namespace single_cta_search { + +template +using search_single_cta_p_jit_func_t = + void(worker_handle_t*, + job_desc_t>*, + uint32_t*, + const IndexT* const, + const std::uint32_t, + const SourceIndexT*, + const unsigned, + const uint64_t, + const IndexT*, + const uint32_t, + IndexT* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + std::uint32_t* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const dataset_descriptor_base_t*, + uint32_t*, + SourceIndexT, + SourceIndexT); + +} // namespace single_cta_search + +namespace 
multi_cta_search { + +template +using search_multi_cta_jit_func_t = void(IndexT* const, + DistanceT* const, + const dataset_descriptor_base_t*, + const DataT* const, + const IndexT* const, + const std::uint32_t, + const std::uint32_t, + const SourceIndexT*, + const unsigned, + const uint64_t, + const IndexT*, + const std::uint32_t, + const std::uint32_t, + IndexT* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + std::uint32_t* const, + const std::uint32_t, + uint32_t*, + SourceIndexT, + SourceIndexT); + +} // namespace multi_cta_search + +namespace multi_kernel_search { + +template +using random_pickup_jit_func_t = void(const dataset_descriptor_base_t*, + const DataT* const, + const std::size_t, + const unsigned, + const uint64_t, + const IndexT*, + const std::uint32_t, + IndexT* const, + DistanceT* const, + const std::uint32_t, + IndexT* const, + const std::uint32_t); + +template +using compute_distance_to_child_nodes_jit_func_t = + void(const IndexT* const, + IndexT* const, + DistanceT* const, + const std::size_t, + const std::uint32_t, + const dataset_descriptor_base_t*, + const IndexT* const, + const std::uint32_t, + const SourceIndexT*, + const DataT*, + IndexT* const, + const std::uint32_t, + IndexT* const, + DistanceT* const, + const std::uint32_t, + cuvs::neighbors::filtering::none_sample_filter); + +template +using apply_filter_kernel_jit_func_t = void(const SourceIndexT* const, + IndexT* const, + DistanceT* const, + const std::size_t, + const std::uint32_t, + const std::uint32_t, + const IndexT, + uint32_t*, + SourceIndexT, + SourceIndexT); + +} // namespace multi_kernel_search + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in new file mode 100644 index 0000000000..6a99e423dd --- /dev/null +++ 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +extern "C" __global__ void random_pickup(const dataset_desc_base* dataset_desc, + const data_t* const queries_ptr, + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const index_t* seed_ptr, + const std::uint32_t num_seeds, + index_t* const result_indices_ptr, + distance_t* const result_distances_ptr, + const std::uint32_t ldr, + index_t* const visited_hashmap_ptr, + const std::uint32_t hash_bitlen) +{ + random_pickup_kernel_jit(dataset_desc, + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr, + visited_hashmap_ptr, + hash_bitlen); +} + +static_assert( + std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json new file mode 100644 index 0000000000..f934f26c11 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json @@ -0,0 +1,13 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "uc"}, + {"data_type": "int8_t", "data_abbrev": "sc"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "ui"}], + "_index": [{"index_type": 
"uint32_t", "index_abbrev": "ui"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_helpers.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_helpers.cuh new file mode 100644 index 0000000000..fe985f7275 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_helpers.cuh @@ -0,0 +1,138 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +template +RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parent( + INDEX_T* const next_parent_indices, + INDEX_T* const itopk_indices, // [itopk_size * 2] + DISTANCE_T* const itopk_distances, // [itopk_size * 2] + INDEX_T* const hash_ptr, + const uint32_t hash_bitlen) +{ + constexpr uint32_t itopk_size = 32; + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + constexpr INDEX_T invalid_index = ~static_cast(0); + + const unsigned warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + if (threadIdx.x == 0) { next_parent_indices[0] = invalid_index; } + __syncwarp(); + + int j = -1; + for (unsigned i = threadIdx.x; i < itopk_size * 2; i += 32) { + INDEX_T index = itopk_indices[i]; + int is_invalid = 0; + int is_candidate = 0; + if (index == invalid_index) { + is_invalid = 1; + } else if (index & index_msb_1_mask) { + } else { + is_candidate = 1; + } + + const auto ballot_mask = __ballot_sync(0xffffffff, is_candidate); + const auto candidate_id = __popc(ballot_mask & ((1 << threadIdx.x) - 1)); + for (int k = 0; k < __popc(ballot_mask); k++) { + int flag_done = 0; + if (is_candidate && candidate_id == k) { + is_candidate = 0; + if (hashmap::insert(hash_ptr, hash_bitlen, index)) { + // Use this candidate as next parent + index |= index_msb_1_mask; // set most significant 
bit as used node + if (i < itopk_size) { + next_parent_indices[0] = i; + itopk_indices[i] = index; + } else { + next_parent_indices[0] = j; + // Move the next parent node from i-th position to j-th position + itopk_indices[j] = index; + itopk_distances[j] = itopk_distances[i]; + itopk_indices[i] = invalid_index; + itopk_distances[i] = utils::get_max_value(); + } + flag_done = 1; + } else { + // Deactivate the node since it has been used by other CTA. + itopk_indices[i] = invalid_index; + itopk_distances[i] = utils::get_max_value(); + is_invalid = 1; + } + } + if (__any_sync(0xffffffff, (flag_done > 0))) { return; } + } + if (i < itopk_size) { + j = 31 - __clz(__ballot_sync(0xffffffff, is_invalid)); + if (j < 0) { return; } + } + } +} + +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num_elements] + INDEX_T* indices, // [num_elements] + const uint32_t num_elements) +{ + const unsigned warp_id = threadIdx.x / raft::warp_size(); + if (warp_id > 0) { return; } + const unsigned lane_id = threadIdx.x % raft::warp_size(); + constexpr unsigned N = (MAX_ELEMENTS + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + INDEX_T val[N]; + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i); + if (j < num_elements) { + key[i] = distances[j]; + val[i] = indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = ~static_cast(0); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store sorted results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_elements) { + distances[j] = key[i]; + indices[j] = val[i]; + } + } +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_64( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<64, uint32_t>(distances, indices, num_elements); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_128( + 
float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<128, uint32_t>(distances, indices, num_elements); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_256( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<256, uint32_t>(distances, indices, num_elements); +} + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh new file mode 100644 index 0000000000..4af367ad68 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh @@ -0,0 +1,369 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../hashmap.hpp" +#include "../utils.hpp" + +#include + +#include +#include +#include + +#ifdef _CLK_BREAKDOWN +#include +#endif + +#include "../../jit_lto_kernels/filter_data.h" +#include "device_common_jit.cuh" +#include "extern_device_functions.cuh" + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +using cuvs::neighbors::cagra::detail::device::compute_distance_to_child_nodes_jit; +using cuvs::neighbors::cagra::detail::device::compute_distance_to_random_nodes_jit; +using cuvs::neighbors::detail::sample_filter; +template +__device__ void search_kernel_jit( + IndexT* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DistanceT* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + const dataset_descriptor_base_t* dataset_desc, + const DataT* const queries_ptr, // [num_queries, dataset_dim] + const IndexT* const knn_graph, // [dataset_size, graph_degree] + const uint32_t max_elements, + const uint32_t graph_degree, + const SourceIndexT* 
source_indices_ptr, // [num_queries, search_width] + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + const uint32_t visited_hash_bitlen, + IndexT* const traversed_hashmap_ptr, // [num_queries, 1 << traversed_hash_bitlen] + const uint32_t traversed_hash_bitlen, + const uint32_t itopk_size, + const uint32_t min_iteration, + const uint32_t max_iteration, + uint32_t* const num_executed_iterations, /* stats */ + const uint32_t query_id_offset, // Offset to add to query_id when calling filter + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len, // Bitset length + SourceIndexT original_nbits) +{ + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + + auto to_source_index = [source_indices_ptr](INDEX_T x) { + return source_indices_ptr == nullptr ? static_cast(x) : source_indices_ptr[x]; + }; + + const auto num_queries = gridDim.y; + const auto query_id = blockIdx.y; + const auto num_cta_per_query = gridDim.x; + const auto cta_id = blockIdx.x; // local CTA ID + +#ifdef _CLK_BREAKDOWN + uint64_t clk_init = 0; + uint64_t clk_compute_1st_distance = 0; + uint64_t clk_topk = 0; + uint64_t clk_pickup_parents = 0; + uint64_t clk_compute_distance = 0; + uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ uint8_t smem[]; + + // Layout of result_buffer + // +----------------+---------+---------------------------+ + // | internal_top_k | padding | neighbors of parent nodes | + // | | upto 32 | | + // +----------------+---------+---------------------------+ + // |<--- result_buffer_size_32 --->| + const auto result_buffer_size = itopk_size + graph_degree; + const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); + 
assert(result_buffer_size_32 <= max_elements); + + // Get dim and smem_ws_size_in_bytes directly from base descriptor + uint32_t dim = dataset_desc->args.dim; + uint32_t smem_ws_size_in_bytes = dataset_desc->smem_ws_size_in_bytes(); + + auto smem_desc = + setup_workspace_base(dataset_desc, smem, queries_ptr, query_id); + + auto* __restrict__ result_indices_buffer = + reinterpret_cast(smem + smem_ws_size_in_bytes); + auto* __restrict__ result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto* __restrict__ local_visited_hashmap_ptr = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto* __restrict__ parent_indices_buffer = + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); + auto* __restrict__ result_position = reinterpret_cast(parent_indices_buffer + 1); + + INDEX_T* const local_traversed_hashmap_ptr = + traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); + + constexpr INDEX_T invalid_index = ~static_cast(0); + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes using JIT version + _CLK_START(); + const INDEX_T* const local_seed_ptr = seed_ptr ? 
seed_ptr + (num_seeds * query_id) : nullptr; + uint32_t block_id = cta_id + (num_cta_per_query * query_id); + uint32_t num_blocks = num_cta_per_query * num_queries; + + compute_distance_to_random_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + graph_degree, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + block_id, + num_blocks); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + uint32_t iter = 0; + while (1) { + _CLK_START(); + if (threadIdx.x < 32) { + // [1st warp] Topk with bitonic sort + if constexpr (std::is_same_v) { + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort + // function (vs post-inlining, this impacts register pressure) + if (max_elements <= 64) { + topk_by_bitonic_sort_wrapper_64( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort_wrapper_128( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort_wrapper_256( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } else { + if (max_elements <= 64) { + topk_by_bitonic_sort<64, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort<128, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort<256, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } + } + __syncthreads(); + _CLK_REC(clk_topk); + + if (iter + 1 >= max_iteration) { break; } + + _CLK_START(); + if (threadIdx.x < 32) { + // [1st warp] Pick up a next parent + pickup_next_parent(parent_indices_buffer, + result_indices_buffer, + 
result_distances_buffer, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); + } else { + // [Other warps] Reset visited hashmap + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen, 32); + } + __syncthreads(); + _CLK_REC(clk_pickup_parents); + + if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } + + _CLK_START(); + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + if ((i >= itopk_size) && (index & index_msb_1_mask)) { + // Remove nodes kicked out of the itopk list from the traversed hash table. + hashmap::remove( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } else { + // Restore visited hashmap by putting nodes on result buffer in it. + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); + } + } + // Initialize buffer for compute_distance_to_child_nodes. + if (threadIdx.x == blockDim.x - 1) { result_position[0] = result_buffer_size_32; } + __syncthreads(); + + // Compute the norms between child nodes and query node using JIT version + compute_distance_to_child_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + parent_indices_buffer, + result_indices_buffer, + 1, + result_position, + result_buffer_size_32); + __syncthreads(); + + // Check the state of the nodes in the result buffer which were not updated + // by the compute_distance_to_child_nodes above, and if it cannot be used as + // a parent node, it is deactivated. 
+ for (uint32_t i = threadIdx.x; i < result_position[0]; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index || index & index_msb_1_mask) { continue; } + if (hashmap::search(local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + _CLK_REC(clk_compute_distance); + + // Filtering - use extern sample_filter function (linked via JIT LTO) + for (unsigned p = threadIdx.x; p < 1; p += blockDim.x) { + if (parent_indices_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + if (!sample_filter(query_id + query_id_offset, + to_source_index(parent_id), + bitset_ptr != nullptr ? &filter_data : nullptr)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_indices_buffer[p]] = invalid_index; + } + } + } + __syncthreads(); + + iter++; + } + + // Filtering - use extern sample_filter function (linked via JIT LTO) + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + if (!sample_filter(query_id + query_id_offset, + to_source_index(index), + bitset_ptr != nullptr ? 
&filter_data : nullptr)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + + // Output search results (1st warp only). + if (threadIdx.x < 32) { + uint32_t offset = 0; + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += 32) { + INDEX_T index = result_indices_buffer[i]; + bool is_valid = false; + if (index != invalid_index) { + if (index & index_msb_1_mask) { + is_valid = true; + index &= ~index_msb_1_mask; + } else if ((offset < itopk_size) && + hashmap::insert( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + // If a node that is not used as a parent can be inserted into + // the traversed hash table, it is considered a valid result. + is_valid = true; + } + } + const auto mask = __ballot_sync(0xffffffff, is_valid); + if (is_valid) { + const auto j = offset + __popc(mask & ((1 << threadIdx.x) - 1)); + if (j < itopk_size) { + uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = index & ~index_msb_1_mask; + if (result_distances_ptr != nullptr) { + DISTANCE_T dist = result_distances_buffer[i]; + result_distances_ptr[k] = dist; + } + } else { + // If it is valid and registered in the traversed hash table but is + // not output as a result, it is removed from the hash table. + hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); + } + } + offset += __popc(mask); + } + // If the number of outputs is insufficient, fill in with invalid results. 
+ for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) { + uint32_t k = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = invalid_index; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = utils::get_max_value(); + } + } + } + + if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } + +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && (blockIdx.x == 0) && + ((query_id * 3) % gridDim.y < 3)) { + printf( + "%s:%d " + "query, %d, thread, %d" + ", init, %lu" + ", 1st_distance, %lu" + ", topk, %lu" + ", pickup_parents, %lu" + ", distance, %lu" + "\n", + __FILE__, + __LINE__, + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_pickup_parents, + clk_compute_distance); + } +#endif +} + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in new file mode 100644 index 0000000000..50f28d6df7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in @@ -0,0 +1,77 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+// NOTE(review): include targets were stripped by extraction — restore the original <...>
+// headers before merging.
+#include
+#include
+#include
+
+// Placeholders are substituted per row of search_multi_cta_matrix.json by the JIT generator.
+namespace {
+
+using data_t = @data_type@;
+using index_t = @index_type@;
+using distance_t = @distance_type@;
+using source_index_t = @source_index_type@;
+using dataset_desc_base =
+  cuvs::neighbors::cagra::detail::dataset_descriptor_base_t;
+
+} // namespace
+
+namespace cuvs::neighbors::cagra::detail::multi_cta_search {
+
+// Unmangled extern "C" entry point for JIT-LTO linking. __launch_bounds__(1024, 1) caps the
+// block size at 1024 threads and requests at least one resident block per SM; the search
+// logic itself lives in search_kernel_jit (search_multi_cta_jit.cuh).
+extern "C" __global__ __launch_bounds__(1024, 1) void search_multi_cta(
+  index_t* const result_indices_ptr,
+  distance_t* const result_distances_ptr,
+  const dataset_desc_base* dataset_desc,
+  const data_t* const queries_ptr,
+  const index_t* const knn_graph,
+  const std::uint32_t max_elements,
+  const std::uint32_t graph_degree,
+  const source_index_t* source_indices_ptr,
+  const unsigned num_distilation,
+  const uint64_t rand_xor_mask,
+  const index_t* seed_ptr,
+  const std::uint32_t num_seeds,
+  const std::uint32_t visited_hash_bitlen,
+  index_t* const traversed_hashmap_ptr,
+  const std::uint32_t traversed_hash_bitlen,
+  const std::uint32_t itopk_size,
+  const std::uint32_t min_iteration,
+  const std::uint32_t max_iteration,
+  std::uint32_t* const num_executed_iterations,
+  const std::uint32_t query_id_offset,
+  uint32_t* bitset_ptr,
+  source_index_t bitset_len,
+  source_index_t original_nbits)
+{
+  search_kernel_jit(result_indices_ptr,
+                    result_distances_ptr,
+                    dataset_desc,
+                    queries_ptr,
+                    knn_graph,
+                    max_elements,
+                    graph_degree,
+                    source_indices_ptr,
+                    num_distilation,
+                    rand_xor_mask,
+                    seed_ptr,
+                    num_seeds,
+                    visited_hash_bitlen,
+                    traversed_hashmap_ptr,
+                    traversed_hash_bitlen,
+                    itopk_size,
+                    min_iteration,
+                    max_iteration,
+                    num_executed_iterations,
+                    query_id_offset,
+                    bitset_ptr,
+                    bitset_len,
+                    original_nbits);
+}
+
+// NOTE(review): is_same_v's template arguments were stripped by extraction — presumably this
+// checks index_t vs the descriptor's INDEX_T; confirm against the original file.
+static_assert(
+  std::is_same_v>);
+
+} // namespace cuvs::neighbors::cagra::detail::multi_cta_search
diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json new file mode 100644 index 0000000000..f934f26c11 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json @@ -0,0 +1,13 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "uc"}, + {"data_type": "int8_t", "data_abbrev": "sc"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "ui"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "ui"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp new file mode 100644 index 0000000000..146c1b0dfa --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp @@ -0,0 +1,38 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+// NOTE(review): the two bare #include targets below were stripped by extraction — restore
+// the original <...> headers before merging.
+#include "cagra_planner_base.hpp"
+#include
+#include
+
+namespace cuvs::neighbors::cagra::detail::multi_cta_search {
+
+// Planner for the multi-CTA CAGRA search kernel. The constructor parameters mirror the other
+// CAGRA planners' signatures but are unused here (the kernel name is fixed to
+// "search_multi_cta"); they are kept so all planners can be constructed uniformly.
+// NOTE(review): the template parameter lists of this struct, CagraPlannerBase, and the
+// add_static_fragment call below were stripped by extraction — restore from the original.
+template
+struct CagraMultiCtaSearchPlanner : CagraPlannerBase {
+  CagraMultiCtaSearchPlanner(cuvs::distance::DistanceType /*metric*/,
+                             uint32_t /*team_size*/,
+                             uint32_t /*dataset_block_dim*/,
+                             bool /*is_vpq*/,
+                             uint32_t /*pq_bits*/,
+                             uint32_t /*pq_len*/)
+    : CagraPlannerBase("search_multi_cta")
+  {
+  }
+
+  // Registers the single static kernel fragment this planner links.
+  void add_search_multi_cta_kernel_fragment()
+  {
+    this->add_static_fragment<
+      fragment_tag_search_multi_cta>();
+  }
+};
+
+} // namespace cuvs::neighbors::cagra::detail::multi_cta_search
diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh
new file mode 100644
index 0000000000..86ddeb30d6
--- /dev/null
+++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh
@@ -0,0 +1,220 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../device_common.hpp" +#include "../hashmap.hpp" +#include "../utils.hpp" + +#include +#include +#include + +#include "../../jit_lto_kernels/filter_data.h" +#include "extern_device_functions.cuh" + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +template +__device__ void random_pickup_kernel_jit( + const dataset_descriptor_base_t* dataset_desc, + const DataT* const queries_ptr, // [num_queries, dataset_dim] + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + IndexT* const result_indices_ptr, // [num_queries, ldr] + DistanceT* const result_distances_ptr, // [num_queries, ldr] + const std::uint32_t ldr, // (*) ldr >= num_pickup + IndexT* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + const std::uint32_t hash_bitlen) +{ + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + + // Get team_size_bits directly from base descriptor + uint32_t team_size_bits = dataset_desc->team_size_bitshift(); + + const auto ldb = hashmap::get_size(hash_bitlen); + const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) >> team_size_bits; + const uint32_t query_id = blockIdx.y; + if (global_team_index >= num_pickup) { return; } + extern __shared__ uint8_t smem[]; + + auto smem_desc = + setup_workspace_base(dataset_desc, smem, queries_ptr, query_id); + __syncthreads(); + + IndexT dataset_size = smem_desc->size; + const auto args_load = smem_desc->args.load(); + + INDEX_T best_index_team_local; + DISTANCE_T best_norm2_team_local = utils::get_max_value(); + for (unsigned i = 0; i < num_distilation; i++) { + INDEX_T seed_index; + if (seed_ptr && (global_team_index < num_seeds)) { + seed_index = seed_ptr[global_team_index + (num_seeds * query_id)]; + } else { + seed_index = device::xorshift64((global_team_index ^ 
rand_xor_mask) * (i + 1)) % dataset_size; + } + + const auto norm2 = + compute_distance_base(args_load, seed_index, true, team_size_bits); + + if (norm2 < best_norm2_team_local) { + best_norm2_team_local = norm2; + best_index_team_local = seed_index; + } + } + + const auto store_gmem_index = global_team_index + (ldr * query_id); + if ((threadIdx.x & ((1u << team_size_bits) - 1u)) == 0) { + if (hashmap::insert( + visited_hashmap_ptr + (ldb * query_id), hash_bitlen, best_index_team_local)) { + result_distances_ptr[store_gmem_index] = best_norm2_team_local; + result_indices_ptr[store_gmem_index] = best_index_team_local; + } else { + result_distances_ptr[store_gmem_index] = utils::get_max_value(); + result_indices_ptr[store_gmem_index] = utils::get_max_value(); + } + } +} + +template +__device__ void compute_distance_to_child_nodes_kernel_jit( + const IndexT* const parent_node_list, // [num_queries, search_width] + IndexT* const parent_candidates_ptr, // [num_queries, search_width] + DistanceT* const parent_distance_ptr, // [num_queries, search_width] + const std::size_t lds, + const std::uint32_t search_width, + const dataset_descriptor_base_t* dataset_desc, + const IndexT* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const DataT* query_ptr, // [num_queries, data_dim] + IndexT* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t hash_bitlen, + IndexT* const result_indices_ptr, // [num_queries, ldd] + DistanceT* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + SAMPLE_FILTER_T sample_filter) +{ + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + + // Get team_size_bits directly from base descriptor + uint32_t team_size_bits = dataset_desc->team_size_bitshift(); + + const auto team_size = 1u << team_size_bits; + const uint32_t ldb = hashmap::get_size(hash_bitlen); + 
const auto tid = threadIdx.x + blockDim.x * blockIdx.x; + const auto global_team_id = tid >> team_size_bits; + const auto query_id = blockIdx.y; + + extern __shared__ uint8_t smem[]; + auto smem_desc = + setup_workspace_base(dataset_desc, smem, query_ptr, query_id); + + __syncthreads(); + if (global_team_id >= search_width * graph_degree) { return; } + + const std::size_t parent_list_index = + parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; + + if (parent_list_index == utils::get_max_value()) { return; } + + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * query_id)]; + + if (raw_parent_index == utils::get_max_value()) { + result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + return; + } + const auto parent_index = raw_parent_index & ~index_msb_1_mask; + + const auto neighbor_list_head_ptr = neighbor_graph_ptr + (graph_degree * parent_index); + + const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree]; + + const auto compute_distance_flag = hashmap::insert( + team_size, visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id); + + const auto args = smem_desc->args.load(); + DISTANCE_T norm2 = compute_distance_base( + args, static_cast(child_id), compute_distance_flag, team_size_bits); + + if (compute_distance_flag) { + if ((threadIdx.x & (team_size - 1)) == 0) { + result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; + result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; + } + } else { + if ((threadIdx.x & (team_size - 1)) == 0) { + result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + } + } + + if constexpr (!std::is_same::value) { + if (!sample_filter( + query_id, + source_indices_ptr == nullptr ? 
parent_index : source_indices_ptr[parent_index])) {
+      parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value();
+      parent_distance_ptr[parent_list_index + (lds * query_id)] =
+        utils::get_max_value();
+    }
+  }
+}
+
+using cuvs::neighbors::detail::sample_filter;
+// Deactivates result entries rejected by the (JIT-linked) sample_filter. One thread handles
+// one entry of the [num_queries, result_buffer_size] result matrix.
+// NOTE(review): the template parameter list was stripped by extraction (presumably
+// <class IndexT, class DistanceT, class SourceIndexT>; confirm against the original).
+template
+__device__ void apply_filter_kernel_jit(
+  const SourceIndexT* source_indices_ptr,  // [num_queries, search_width]
+  IndexT* const result_indices_ptr,
+  DistanceT* const result_distances_ptr,
+  const std::size_t lds,
+  const std::uint32_t result_buffer_size,
+  const std::uint32_t num_queries,
+  const IndexT query_id_offset,
+  uint32_t* bitset_ptr,  // Bitset data pointer (nullptr for none_filter) - in global memory
+  SourceIndexT bitset_len,  // Bitset length
+  SourceIndexT original_nbits)  // Original number of bits
+{
+  constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value;
+  const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid >= result_buffer_size * num_queries) { return; }
+  const auto i = tid % result_buffer_size;
+  const auto j = tid / result_buffer_size;
+  const auto index = i + j * lds;
+
+  // NOTE(review): `~index_msb_1_mask` (MSB clear, all lower bits set) is the "skip" sentinel
+  // compared against here, while sibling kernels in this file use ~INDEX_T{0} as the invalid
+  // marker — confirm this asymmetry is intentional.
+  if (result_indices_ptr[index] != ~index_msb_1_mask) {
+    // Use extern sample_filter function with 3 params: query_id, node_id, filter_data
+    // filter_data is a void* pointer to bitset_filter_data_t (or nullptr for none_filter)
+    SourceIndexT node_id = source_indices_ptr == nullptr
+                             ? static_cast(result_indices_ptr[index])
+                             : source_indices_ptr[result_indices_ptr[index]];
+
+    // Construct filter_data struct in registers (bitset data is in global memory)
+    cuvs::neighbors::detail::bitset_filter_data_t filter_data(
+      bitset_ptr, bitset_len, original_nbits);
+
+    if (!sample_filter(
+          query_id_offset + j, node_id, bitset_ptr != nullptr ?
+          &filter_data : nullptr)) {
+      result_indices_ptr[index] = utils::get_max_value();
+      result_distances_ptr[index] = utils::get_max_value();
+    }
+  }
+}
+
+} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search
diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp
new file mode 100644
index 0000000000..ca28acc14c
--- /dev/null
+++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp
@@ -0,0 +1,51 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+// NOTE(review): bare #include targets below were stripped by extraction — restore them.
+#include "cagra_planner_base.hpp"
+#include
+#include
+#include
+
+namespace cuvs::neighbors::cagra::detail::multi_kernel_search {
+
+// Planner for the multi-kernel CAGRA search pipeline: maps a kernel-name string to the
+// static fragment registered for the JIT-LTO link.
+// NOTE(review): template parameter lists below were stripped by extraction.
+template
+struct CagraMultiKernelSearchPlanner : CagraPlannerBase {
+  CagraMultiKernelSearchPlanner(cuvs::distance::DistanceType /*metric*/,
+                                const std::string& kernel_name,
+                                uint32_t /*team_size*/,
+                                uint32_t /*dataset_block_dim*/,
+                                bool /*is_vpq*/,
+                                uint32_t /*pq_bits*/,
+                                uint32_t /*pq_len*/)
+    : CagraPlannerBase(kernel_name)
+  {
+  }
+
+  // Adds the fragment for one of the three multi-kernel stages; fails fast on unknown names.
+  void add_linked_kernel(std::string const& kernel_name)
+  {
+    if (kernel_name == "random_pickup") {
+      this->add_static_fragment>();
+    } else if (kernel_name == "compute_distance_to_child_nodes") {
+      this->add_static_fragment>();
+    } else if (kernel_name == "apply_filter_kernel") {
+      this->add_static_fragment<
+        fragment_tag_apply_filter_kernel>();
+    } else {
+      RAFT_FAIL("Unknown CAGRA multi-kernel JIT kernel: %s", kernel_name.c_str());
+    }
+  }
+};
+
+} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search
diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh
new file mode 100644
index 0000000000..b549a435c9
--- /dev/null
+++
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh @@ -0,0 +1,672 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +// Device-only includes - no host-side headers +#include "../bitonic.hpp" +#include "../device_common.hpp" +#include "../hashmap.hpp" +#include "../utils.hpp" + +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include // For uint4 + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Descriptor tag for JIT persistent job queues (matches dataset_descriptor DATA_T / INDEX_T / +// DISTANCE_T) +template +struct job_desc_jit_helper_desc { + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; +}; + +// Constants for persistent kernels +constexpr size_t kCacheLineBytes = 64; +constexpr uint32_t kMaxJobsNum = 8192; + +// Worker handle for persistent kernels +struct alignas(kCacheLineBytes) worker_handle_t { + using handle_t = uint64_t; + struct value_t { + uint32_t desc_id; + uint32_t query_id; + }; + union data_t { + handle_t handle; + value_t value; + }; + cuda::atomic data; +}; +static_assert(sizeof(worker_handle_t::value_t) == sizeof(worker_handle_t::handle_t)); +static_assert( + cuda::atomic::is_always_lock_free); + +constexpr worker_handle_t::handle_t kWaitForWork = std::numeric_limits::max(); +constexpr worker_handle_t::handle_t kNoMoreWork = kWaitForWork - 1; + +// Job descriptor for persistent kernels +template +struct alignas(kCacheLineBytes) job_desc_t { + using index_type = typename DATASET_DESCRIPTOR_T::INDEX_T; + using distance_type = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using data_type = typename DATASET_DESCRIPTOR_T::DATA_T; + // The algorithm input parameters + struct value_t { + uintptr_t result_indices_ptr; // [num_queries, top_k] + distance_type* result_distances_ptr; // [num_queries, top_k] + const data_type* queries_ptr; 
// [num_queries, dataset_dim] + uint32_t top_k; + uint32_t n_queries; + }; + using blob_elem_type = uint4; + constexpr static inline size_t kBlobSize = + raft::div_rounding_up_safe(sizeof(value_t), sizeof(blob_elem_type)); + // Union facilitates loading the input by a warp in a single request + union input_t { + blob_elem_type blob[kBlobSize]; // NOLINT + value_t value; + } input; + // Last thread triggers this flag. + cuda::atomic completion_flag; +}; + +// Pick up next parent nodes from the internal topk list +template +RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const terminate_flag, + INDEX_T* const next_parent_indices, + INDEX_T* const internal_topk_indices, + const std::size_t internal_topk_size, + const std::uint32_t search_width) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (std::uint32_t i = threadIdx.x; i < search_width; i += 32) { + next_parent_indices[i] = utils::get_max_value(); + } + std::uint32_t itopk_max = internal_topk_size; + if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } + std::uint32_t num_new_parents = 0; + for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { + std::uint32_t jj = j; + if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } + INDEX_T index; + int new_parent = 0; + if (j < internal_topk_size) { + index = internal_topk_indices[jj]; + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set + new_parent = 1; + } + } + const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); + if (new_parent) { + const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; + if (i < search_width) { + next_parent_indices[i] = jj; + // set most significant bit as used node + internal_topk_indices[jj] |= index_msb_1_mask; + } + } + num_new_parents += __popc(ballot_mask); + if (num_new_parents >= search_width) { break; } + } + if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } +} + +// 
Helper function for bitonic sort and full +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + const unsigned lane_id = threadIdx.x % raft::warp_size(); + const unsigned warp_id = threadIdx.x / raft::warp_size(); + static_assert(MAX_CANDIDATES <= 256); + if constexpr (!MULTI_WARPS) { + if (warp_id > 0) { return; } + constexpr unsigned N = (MAX_CANDIDATES + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + IdxT val[N]; + /* Candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i); + if (j < num_candidates) { + key[i] = candidate_distances[j]; + val[i] = candidate_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Sort */ + bitonic::warp_sort(key, val); + /* Reg -> Temp_itopk */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_candidates && j < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } else { + assert(blockDim.x >= 64); + // Use two warps (64 threads) + constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; + static_assert(max_candidates_per_warp <= 128); + constexpr unsigned N = (max_candidates_per_warp + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + IdxT val[N]; + if (warp_id < 2) { + /* Candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = lane_id + (raft::warp_size() * i); + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates) { + key[i] = candidate_distances[j]; + val[i] = candidate_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Sort */ + bitonic::warp_sort(key, val); + /* Reg -> 
Temp_candidates */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates && jl < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } + __syncthreads(); + + unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp; + if (warp_id < num_warps_used) { + /* Temp_candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned kl = max_candidates_per_warp - 1 - jl; + unsigned j = jl + (max_candidates_per_warp * warp_id); + unsigned k = MAX_CANDIDATES - 1 - j; + if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue; + float temp_key = candidate_distances[device::swizzling(k)]; + if (key[i] == temp_key) continue; + if ((warp_id == 0) == (key[i] > temp_key)) { + key[i] = temp_key; + val[i] = candidate_indices[device::swizzling(k)]; + } + } + } + if (num_warps_used > 1) { __syncthreads(); } + if (warp_id < num_warps_used) { + /* Merge */ + bitonic::warp_merge(key, val, raft::warp_size()); + /* Reg -> Temp_itopk */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates && j < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } + if (num_warps_used > 1) { __syncthreads(); } + } +} + +// Wrapper functions to avoid pre-inlining (impacts register pressure) +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_64_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<64, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, 
num_itopk); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_128_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<128, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, num_itopk); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_256_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<256, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, num_itopk); +} + +// TopK by bitonic sort and merge (template version with MAX_ITOPK) +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( + float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + const unsigned lane_id = threadIdx.x % raft::warp_size(); + const unsigned warp_id = threadIdx.x / raft::warp_size(); + + static_assert(MAX_ITOPK <= 512); + if constexpr (!MULTI_WARPS) { + static_assert(MAX_ITOPK <= 256); + if (warp_id > 0) { return; } + constexpr unsigned N = (MAX_ITOPK + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + IdxT val[N]; + if (first) { + /* Load itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i); + if (j < num_itopk) { + key[i] = itopk_distances[j]; + val[i] = itopk_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + } else { + /* Load itopk 
results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + key[i] = itopk_distances[device::swizzling(j)]; + val[i] = itopk_indices[device::swizzling(j)]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + } + /* Merge candidates */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; // [0:max_itopk-1] + unsigned k = MAX_ITOPK - 1 - j; + if (k >= num_itopk || k >= num_candidates) continue; + float candidate_key = candidate_distances[device::swizzling(k)]; + if (key[i] > candidate_key) { + key[i] = candidate_key; + val[i] = candidate_indices[device::swizzling(k)]; + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, raft::warp_size()); + /* Store new itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + } else { + static_assert(MAX_ITOPK == 512); + assert(blockDim.x >= 64); + // Use two warps (64 threads) or more + constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; + constexpr unsigned N = (max_itopk_per_warp + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + IdxT val[N]; + if (first) { + /* Load itop results (not sorted) */ + if (warp_id < 2) { + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i) + (max_itopk_per_warp * warp_id); + if (j < num_itopk) { + key[i] = itopk_distances[j]; + val[i] = itopk_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store intermedidate results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + if (j >= num_itopk) continue; + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + __syncthreads(); + if (warp_id < 2) { + /* 
Load intermedidate results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + unsigned k = MAX_ITOPK - 1 - j; + if (k >= num_itopk) continue; + float temp_key = itopk_distances[device::swizzling(k)]; + if (key[i] == temp_key) continue; + if ((warp_id == 0) == (key[i] > temp_key)) { + key[i] = temp_key; + val[i] = itopk_indices[device::swizzling(k)]; + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, raft::warp_size()); + } + __syncthreads(); + /* Store itopk results (sorted) */ + if (warp_id < 2) { + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + if (j >= num_itopk) continue; + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + } + const uint32_t num_itopk_div2 = num_itopk / 2; + if (threadIdx.x < 3) { + // work_buf is used to obtain turning points in 1st and 2nd half of itopk afer merge. + work_buf[threadIdx.x] = num_itopk_div2; + } + __syncthreads(); + + // Merge candidates (using whole threads) + for (unsigned k = threadIdx.x; k < (num_candidates < num_itopk ? 
num_candidates : num_itopk); + k += blockDim.x) { + const unsigned j = num_itopk - 1 - k; + const float itopk_key = itopk_distances[device::swizzling(j)]; + const float candidate_key = candidate_distances[device::swizzling(k)]; + if (itopk_key > candidate_key) { + itopk_distances[device::swizzling(j)] = candidate_key; + itopk_indices[device::swizzling(j)] = candidate_indices[device::swizzling(k)]; + if (j < num_itopk_div2) { + atomicMin(work_buf + 2, j); + } else { + atomicMin(work_buf + 1, j - num_itopk_div2); + } + } + } + __syncthreads(); + + // Merge 1st and 2nd half of itopk (using whole threads) + for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) { + const unsigned k = j + num_itopk_div2; + float key_0 = itopk_distances[device::swizzling(j)]; + float key_1 = itopk_distances[device::swizzling(k)]; + if (key_0 > key_1) { + itopk_distances[device::swizzling(j)] = key_1; + itopk_distances[device::swizzling(k)] = key_0; + IdxT val_0 = itopk_indices[device::swizzling(j)]; + IdxT val_1 = itopk_indices[device::swizzling(k)]; + itopk_indices[device::swizzling(j)] = val_1; + itopk_indices[device::swizzling(k)] = val_0; + atomicMin(work_buf + 0, j); + } + } + if (threadIdx.x == blockDim.x - 1) { + if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; } + } + __syncthreads(); + // Warp-0 merges 1st half of itopk, warp-1 does 2nd half. 
+ if (warp_id < 2) { + // Load intermedidate itopk results + const uint32_t turning_point = work_buf[warp_id]; // turning_point <= num_itopk_div2 + for (unsigned i = 0; i < N; i++) { + unsigned k = num_itopk; + unsigned j = (N * lane_id) + i; + if (j < turning_point) { + k = j + (num_itopk_div2 * warp_id); + } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) { + j -= (MAX_ITOPK / 2 - num_itopk_div2); + if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); } + } + if (k < num_itopk) { + key[i] = itopk_distances[device::swizzling(k)]; + val[i] = itopk_indices[device::swizzling(k)]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, raft::warp_size()); + /* Store new itopk results */ + for (unsigned i = 0; i < N; i++) { + const unsigned j = (N * lane_id) + i; + if (j < num_itopk_div2) { + unsigned k = j + (num_itopk_div2 * warp_id); + itopk_distances[device::swizzling(k)] = key[i]; + itopk_indices[device::swizzling(k)] = val[i]; + } + } + } + } +} + +// Wrapper functions to avoid pre-inlining +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_64_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<64, false, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_128_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // 
[num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<128, false, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_256_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<256, false, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +// TopK by bitonic sort and merge (runtime version) +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( + float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] + const std::uint32_t max_itopk, + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t max_candidates, + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + static_assert(std::is_same_v); + assert(max_itopk <= 512); + assert(max_candidates <= 256); + assert(!MULTI_WARPS || blockDim.x >= 64); + + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_full + // function (vs post-inlining, this impacts register pressure) + if (max_candidates <= 64) { + topk_by_bitonic_sort_and_full_wrapper_64_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } else if (max_candidates <= 128) { + topk_by_bitonic_sort_and_full_wrapper_128_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } else { + 
topk_by_bitonic_sort_and_full_wrapper_256_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } + + if constexpr (!MULTI_WARPS) { + assert(max_itopk <= 256); + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_merge + // function (vs post-inlining, this impacts register pressure) + if (max_itopk <= 64) { + topk_by_bitonic_sort_and_merge_wrapper_64_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } else if (max_itopk <= 128) { + topk_by_bitonic_sort_and_merge_wrapper_128_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } else { + topk_by_bitonic_sort_and_merge_wrapper_256_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } + } else { + assert(max_itopk > 256); + topk_by_bitonic_sort_and_merge<512, MULTI_WARPS, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } +} + +// This function move the invalid index element to the end of the itopk list. +// Require : array_length % 32 == 0 && The invalid entry is only one. 
+template +RAFT_DEVICE_INLINE_FUNCTION void move_invalid_to_end_of_list(IdxT* const index_array, + float* const distance_array, + const std::uint32_t array_length) +{ + constexpr std::uint32_t warp_size = 32; + constexpr std::uint32_t invalid_index = utils::get_max_value(); + const std::uint32_t lane_id = threadIdx.x % warp_size; + + if (threadIdx.x >= warp_size) { return; } + + bool found_invalid = false; + if (array_length % warp_size == 0) { + for (std::uint32_t i = lane_id; i < array_length; i += warp_size) { + const auto index = index_array[i]; + const auto distance = distance_array[i]; + + if (found_invalid) { + index_array[i - 1] = index; + distance_array[i - 1] = distance; + } else { + // Check if the index is invalid + const auto I_found_invalid = (index == invalid_index); + const auto who_has_invalid = raft::ballot(I_found_invalid); + // if a value that is loaded by a smaller lane id thread, shift the array + if (who_has_invalid << (warp_size - lane_id)) { + index_array[i - 1] = index; + distance_array[i - 1] = distance; + } + + found_invalid = who_has_invalid; + } + } + } + if (lane_id == 0) { + index_array[array_length - 1] = invalid_index; + distance_array[array_length - 1] = utils::get_max_value(); + } +} + +template +RAFT_DEVICE_INLINE_FUNCTION void hashmap_restore(INDEX_T* const hashmap_ptr, + const size_t hashmap_bitlen, + const INDEX_T* itopk_indices, + const uint32_t itopk_size, + const uint32_t first_tid = 0) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + if (threadIdx.x < first_tid) return; + for (unsigned i = threadIdx.x - first_tid; i < itopk_size; i += blockDim.x - first_tid) { + auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit + hashmap::insert(hashmap_ptr, hashmap_bitlen, key); + } +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh new file mode 100644 index 0000000000..7d13ccea41 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh @@ -0,0 +1,659 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +// Device-only helpers - extracted from search_single_cta_kernel-inl.cuh to avoid host-side includes +#include "search_single_cta_device_helpers.cuh" + +// Additional device-side includes needed +#include "../device_common.hpp" +#include "../hashmap.hpp" +#include "../topk_by_radix.cuh" +#include "../utils.hpp" + +#include // For raft::shfl_xor +#include // For raft::round_up_safe +#include + +#include + +#include +#include + +#include // For assert() + +#ifdef _CLK_BREAKDOWN +#include // For printf() in debug mode +#endif + +// Include extern function declarations before namespace so they're available to kernel definitions +#include "../../jit_lto_kernels/filter_data.h" +#include "extern_device_functions.cuh" +// Include shared JIT device functions +#include "device_common_jit.cuh" + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// are defined in search_single_cta_kernel-inl.cuh which is included by the launcher. +// We don't redefine them here to avoid duplicate definitions. 
+ +// Sample filter extern function +// sample_filter is declared in extern_device_functions.cuh +using cuvs::neighbors::detail::sample_filter; + +// JIT versions of compute_distance_to_random_nodes and compute_distance_to_child_nodes +// are now shared in device_common_jit.cuh - use fully qualified names +using cuvs::neighbors::cagra::detail::device::compute_distance_to_child_nodes_jit; +using cuvs::neighbors::cagra::detail::device::compute_distance_to_random_nodes_jit; + +// JIT search_core - setup_workspace/compute_distance via function pointers +template +RAFT_DEVICE_INLINE_FUNCTION void search_core( + uintptr_t result_indices_ptr, + DistanceT* const result_distances_ptr, + const std::uint32_t top_k, + const DataT* const queries_ptr, + const IndexT* const knn_graph, + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, + const uint32_t num_seeds, + IndexT* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + const dataset_descriptor_base_t* dataset_desc, + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len, // Bitset length + SourceIndexT original_nbits) // Original number of bits +{ + using LOAD_T = device::LOAD_128BIT_T; + + auto to_source_index = [source_indices_ptr](IndexT x) { + return source_indices_ptr == nullptr ? 
static_cast(x) : source_indices_ptr[x]; + }; + +#ifdef _CLK_BREAKDOWN + std::uint64_t clk_init = 0; + std::uint64_t clk_compute_1st_distance = 0; + std::uint64_t clk_topk = 0; + std::uint64_t clk_reset_hash = 0; + std::uint64_t clk_pickup_parents = 0; + std::uint64_t clk_restore_hash = 0; + std::uint64_t clk_compute_distance = 0; + std::uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ uint8_t smem[]; + + // Layout of result_buffer + const auto result_buffer_size = internal_topk + (search_width * graph_degree); + const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); + const auto small_hash_size = hashmap::get_size(small_hash_bitlen); + + // Get dim and smem_ws_size directly from base descriptor + uint32_t dim = dataset_desc->args.dim; + uint32_t smem_ws_size_in_bytes = dataset_desc->smem_ws_size_in_bytes(); + + auto smem_desc = + setup_workspace_base(dataset_desc, smem, queries_ptr, query_id); + + auto* __restrict__ result_indices_buffer = + reinterpret_cast(smem + smem_ws_size_in_bytes); + auto* __restrict__ result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto* __restrict__ visited_hash_buffer = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto* __restrict__ parent_list_buffer = + reinterpret_cast(visited_hash_buffer + small_hash_size); + auto* __restrict__ topk_ws = reinterpret_cast(parent_list_buffer + search_width); + auto* terminate_flag = reinterpret_cast(topk_ws + 3); + auto* __restrict__ smem_work_ptr = reinterpret_cast(terminate_flag + 1); + + // A flag for filtering. 
+ auto filter_flag = terminate_flag; + + if (threadIdx.x == 0) { + terminate_flag[0] = 0; + topk_ws[0] = ~0u; + } + + // Init hashmap + IndexT* local_visited_hashmap_ptr; + if (small_hash_bitlen) { + local_visited_hashmap_ptr = visited_hash_buffer; + } else { + local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * blockIdx.y); + } + hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes using JIT version + _CLK_START(); + const IndexT* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + // Get dataset_size directly from base descriptor + IndexT dataset_size = smem_desc->size; + compute_distance_to_random_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen, + (IndexT*)nullptr, + 0); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + std::uint32_t iter = 0; + while (1) { + // sort + if constexpr (TOPK_BY_BITONIC_SORT) { + assert(blockDim.x >= 64); + const bool bitonic_sort_and_full_multi_warps = (max_candidates > 128) ? true : false; + + // reset small-hash table. 
+ if ((iter + 1) % small_hash_reset_interval == 0) { + _CLK_START(); + unsigned hash_start_tid; + if (blockDim.x == 32) { + hash_start_tid = 0; + } else if (blockDim.x == 64) { + if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { + hash_start_tid = 0; + } else { + hash_start_tid = 32; + } + } else { + if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { + hash_start_tid = 64; + } else { + hash_start_tid = 32; + } + } + hashmap::init(local_visited_hashmap_ptr, hash_bitlen, hash_start_tid); + _CLK_REC(clk_reset_hash); + } + + // topk with bitonic sort + _CLK_START(); + // For JIT version, we always check filter_flag at runtime since sample_filter is extern + if (*filter_flag != 0) { + // Move the filtered out index to the end of the itopk list + for (unsigned i = 0; i < search_width; i++) { + move_invalid_to_end_of_list( + result_indices_buffer, result_distances_buffer, internal_topk); + } + if (threadIdx.x == 0) { *terminate_flag = 0; } + } + topk_by_bitonic_sort_and_merge( + result_distances_buffer, + result_indices_buffer, + max_itopk, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + max_candidates, + search_width * graph_degree, + topk_ws, + (iter == 0)); + __syncthreads(); + _CLK_REC(clk_topk); + } else { + _CLK_START(); + // topk with radix block sort + topk_by_radix_sort{}(max_itopk, + internal_topk, + result_buffer_size, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + nullptr, + topk_ws, + true, + smem_work_ptr); + _CLK_REC(clk_topk); + + // reset small-hash table + if ((iter + 1) % small_hash_reset_interval == 0) { + _CLK_START(); + hashmap::init(local_visited_hashmap_ptr, hash_bitlen); + _CLK_REC(clk_reset_hash); + } + } + __syncthreads(); + + if (iter + 1 == max_iteration) { break; } + + // pick up next parents + if (threadIdx.x < 32) { + _CLK_START(); + 
pickup_next_parents( + terminate_flag, parent_list_buffer, result_indices_buffer, internal_topk, search_width); + _CLK_REC(clk_pickup_parents); + } + + // restore small-hash table by putting internal-topk indices in it + if ((iter + 1) % small_hash_reset_interval == 0) { + const unsigned first_tid = ((blockDim.x <= 32) ? 0 : 32); + _CLK_START(); + hashmap_restore( + local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk, first_tid); + _CLK_REC(clk_restore_hash); + } + __syncthreads(); + + if (*terminate_flag && iter >= min_iteration) { break; } + + __syncthreads(); + // compute the norms between child nodes and query node using JIT version + _CLK_START(); + compute_distance_to_child_nodes_jit( + result_indices_buffer + internal_topk, + result_distances_buffer + internal_topk, + smem_desc, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + (IndexT*)nullptr, + 0u, + parent_list_buffer, + result_indices_buffer, + search_width); + // Critical: __syncthreads() must be reached by ALL threads + // If any thread is stuck in compute_distance_to_child_nodes_jit, this will hang + __syncthreads(); + _CLK_REC(clk_compute_distance); + + // Filtering - use extern sample_filter function + if (threadIdx.x == 0) { *filter_flag = 0; } + __syncthreads(); + + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const IndexT invalid_index = utils::get_max_value(); + + for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { + if (parent_list_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + if (!sample_filter(query_id + query_id_offset, + to_source_index(parent_id), + bitset_ptr != nullptr ? 
&filter_data : nullptr)) { + result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_list_buffer[p]] = invalid_index; + *filter_flag = 1; + } + } + } + __syncthreads(); + + iter++; + } + + // Post process for filtering - use extern sample_filter function + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const IndexT invalid_index = utils::get_max_value(); + + for (unsigned i = threadIdx.x; i < internal_topk + search_width * graph_degree; i += blockDim.x) { + const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + if (node_id != (invalid_index & ~index_msb_1_mask) && + !sample_filter(query_id + query_id_offset, + to_source_index(node_id), + bitset_ptr != nullptr ? &filter_data : nullptr)) { + result_distances_buffer[i] = utils::get_max_value(); + result_indices_buffer[i] = invalid_index; + } + } + + __syncthreads(); + // Move invalid index items to the end of the buffer without sorting the entire buffer + using scan_op_t = cub::WarpScan; + auto& temp_storage = *reinterpret_cast(smem_work_ptr); + + constexpr std::uint32_t warp_size = 32; + if (threadIdx.x < warp_size) { + std::uint32_t num_found_valid = 0; + for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk; + buffer_offset += warp_size) { + const auto src_position = buffer_offset + threadIdx.x; + const std::uint32_t is_valid_index = + (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 
0 : 1; + std::uint32_t new_position; + scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position); + if (is_valid_index) { + const auto dst_position = num_found_valid + (new_position - 1); + result_indices_buffer[dst_position] = result_indices_buffer[src_position]; + result_distances_buffer[dst_position] = result_distances_buffer[src_position]; + } + + num_found_valid += new_position; + for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) { + const auto v = raft::shfl_xor(num_found_valid, offset); + if ((threadIdx.x & offset) == 0) { num_found_valid = v; } + } + + if (num_found_valid >= top_k) { break; } + } + + if (num_found_valid < top_k) { + for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + } + + // If the sufficient number of valid indexes are not in the internal topk, pick up from the + // candidate list. + if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) { + __syncthreads(); + topk_by_bitonic_sort_and_merge( + result_distances_buffer, + result_indices_buffer, + max_itopk, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + max_candidates, + search_width * graph_degree, + topk_ws, + (iter == 0)); + } + __syncthreads(); + + // NB: The indices pointer is tagged with its element size. + const uint32_t index_element_tag = result_indices_ptr & 0x3; + result_indices_ptr ^= index_element_tag; + auto write_indices = + index_element_tag == 3 + ? [](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : index_element_tag == 2 + ? [](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : index_element_tag == 1 + ? 
[](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : [](uintptr_t ptr, uint32_t i, SourceIndexT x) { + reinterpret_cast(ptr)[i] = static_cast(x); + }; + for (std::uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { + unsigned j = i + (top_k * query_id); + unsigned ii = i; + if constexpr (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } + + auto internal_index = + result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit + auto source_index = to_source_index(internal_index); + write_indices(result_indices_ptr, j, source_index); + } + if (threadIdx.x == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && ((query_id * 3) % gridDim.y < 3)) { + printf( + "%s:%d " + "query, %d, thread, %d" + ", init, %lu" + ", 1st_distance, %lu" + ", topk, %lu" + ", reset_hash, %lu" + ", pickup_parents, %lu" + ", restore_hash, %lu" + ", distance, %lu" + "\n", + __FILE__, + __LINE__, + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_reset_hash, + clk_pickup_parents, + clk_restore_hash, + clk_compute_distance); + } +#endif +} + +// JIT device implementation - called from extern "C" __global__ entry in generated .cu +template +__device__ void search_kernel_jit( + uintptr_t result_indices_ptr, + DistanceT* const result_distances_ptr, + const std::uint32_t top_k, + const DataT* const queries_ptr, + const IndexT* const knn_graph, + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, + const uint32_t num_seeds, + IndexT* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t 
internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + const dataset_descriptor_base_t* dataset_desc, + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len, // Bitset length + SourceIndexT original_nbits) // Original number of bits +{ + const auto query_id = blockIdx.y; + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + queries_ptr, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + query_id_offset, + dataset_desc, + bitset_ptr, + bitset_len, + original_nbits); +} + +// JIT persistent device implementation - called from extern "C" __global__ entry in generated .cu +template +__device__ void search_kernel_p_jit( + worker_handle_t* worker_handles, + job_desc_t>* job_descriptors, + uint32_t* completion_counters, + const IndexT* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + IndexT* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + 
std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + const dataset_descriptor_base_t* dataset_desc, + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len, // Bitset length + SourceIndexT original_nbits) // Original number of bits +{ + using job_desc_type = job_desc_t>; + __shared__ typename job_desc_type::input_t job_descriptor; + __shared__ worker_handle_t::data_t worker_data; + + auto& worker_handle = worker_handles[blockIdx.y].data; + uint32_t job_ix; + + while (true) { + // wait the writing phase + if (threadIdx.x == 0) { + worker_handle_t::data_t worker_data_local; + do { + worker_data_local = worker_handle.load(cuda::memory_order_relaxed); + } while (worker_data_local.handle == kWaitForWork); + if (worker_data_local.handle != kNoMoreWork) { + worker_handle.store({kWaitForWork}, cuda::memory_order_relaxed); + } + job_ix = worker_data_local.value.desc_id; + cuda::atomic_thread_fence(cuda::memory_order_acquire, cuda::thread_scope_system); + worker_data = worker_data_local; + } + if (threadIdx.x < raft::WarpSize) { + // Sync one warp and copy descriptor data + static_assert(job_desc_type::kBlobSize <= raft::WarpSize); + constexpr uint32_t kMaxJobsNum = 8192; + job_ix = raft::shfl(job_ix, 0); + if (threadIdx.x < job_desc_type::kBlobSize && job_ix < kMaxJobsNum) { + job_descriptor.blob[threadIdx.x] = job_descriptors[job_ix].input.blob[threadIdx.x]; + } + } + __syncthreads(); + if (worker_data.handle == kNoMoreWork) { break; } + + // reading phase + auto result_indices_ptr = job_descriptor.value.result_indices_ptr; + auto* result_distances_ptr = job_descriptor.value.result_distances_ptr; + auto* queries_ptr = job_descriptor.value.queries_ptr; + auto top_k = job_descriptor.value.top_k; + auto n_queries 
= job_descriptor.value.n_queries; + auto query_id = worker_data.value.query_id; + + // work phase - use JIT search_core + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + queries_ptr, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + query_id_offset, + dataset_desc, + bitset_ptr, + bitset_len, + original_nbits); + + // make sure all writes are visible even for the host + // (e.g. when result buffers are in pinned memory) + cuda::atomic_thread_fence(cuda::memory_order_release, cuda::thread_scope_system); + + // arrive to mark the end of the work phase + __syncthreads(); + if (threadIdx.x == 0) { + auto completed_count = atomicInc(completion_counters + job_ix, n_queries - 1) + 1; + if (completed_count >= n_queries) { + job_descriptors[job_ix].completion_flag.store(true, cuda::memory_order_relaxed); + } + } + } +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in new file mode 100644 index 0000000000..cba1bf525c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in @@ -0,0 +1,86 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta( + uintptr_t topk_indices_ptr, + distance_t* const topk_distances_ptr, + const std::uint32_t topk, + const data_t* const queries_ptr, + const index_t* const knn_graph, + const std::uint32_t graph_degree, + const source_index_t* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const index_t* seed_ptr, + const uint32_t num_seeds, + index_t* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, + const dataset_desc_base* dataset_desc, + uint32_t* bitset_ptr, + source_index_t bitset_len, + source_index_t original_nbits) +{ + single_cta_search:: + search_kernel_jit<@topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@, data_t, index_t, distance_t, source_index_t>( + topk_indices_ptr, + topk_distances_ptr, + topk, + queries_ptr, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id_offset, + dataset_desc, + 
bitset_ptr, + bitset_len, + original_nbits); +} + +static_assert( + std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json new file mode 100644 index 0000000000..a536af418e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json @@ -0,0 +1,21 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "uc"}, + {"data_type": "int8_t", "data_abbrev": "sc"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "ui"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "ui"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}], + "_topk_by_bitonic": [ + {"topk_by_bitonic_sort": "true", "topk_by_bitonic_sort_str": "topk_by_bitonic_sort"}, + {"topk_by_bitonic_sort": "false", "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort"} + ], + "_bitonic_sort_and_merge_multi_warps": [ + {"bitonic_sort_and_merge_multi_warps": "true", "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps"}, + {"bitonic_sort_and_merge_multi_warps": "false", "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps"} + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in new file mode 100644 index 0000000000..738d6bd98f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in @@ -0,0 +1,87 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace scta_jit = cuvs::neighbors::cagra::detail::single_cta_search; + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; +using job_desc_jit = + scta_jit::job_desc_t>; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta_p( + worker_handle_t* worker_handles, + job_desc_jit* job_descriptors, + uint32_t* completion_counters, + const index_t* const knn_graph, + const std::uint32_t graph_degree, + const source_index_t* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const index_t* seed_ptr, + const uint32_t num_seeds, + index_t* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, + const dataset_desc_base* dataset_desc, + uint32_t* bitset_ptr, + source_index_t bitset_len, + source_index_t original_nbits) +{ + search_kernel_p_jit<@topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@, data_t, index_t, distance_t, source_index_t>( + worker_handles, + job_descriptors, + completion_counters, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + 
small_hash_bitlen, + small_hash_reset_interval, + query_id_offset, + dataset_desc, + bitset_ptr, + bitset_len, + original_nbits); +} + +static_assert( + std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json new file mode 100644 index 0000000000..a536af418e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json @@ -0,0 +1,21 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "uc"}, + {"data_type": "int8_t", "data_abbrev": "sc"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "ui"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "ui"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}], + "_topk_by_bitonic": [ + {"topk_by_bitonic_sort": "true", "topk_by_bitonic_sort_str": "topk_by_bitonic_sort"}, + {"topk_by_bitonic_sort": "false", "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort"} + ], + "_bitonic_sort_and_merge_multi_warps": [ + {"bitonic_sort_and_merge_multi_warps": "true", "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps"}, + {"bitonic_sort_and_merge_multi_warps": "false", "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps"} + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp new file mode 100644 index 0000000000..afe856c495 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp @@ -0,0 +1,103 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "cagra_planner_base.hpp" +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +template +struct CagraSingleCtaSearchPlanner : CagraPlannerBase { + CagraSingleCtaSearchPlanner(cuvs::distance::DistanceType /*metric*/, + bool /*topk_by_bitonic_sort*/, + bool /*bitonic_sort_and_merge_multi_warps*/, + uint32_t /*team_size*/, + uint32_t /*dataset_block_dim*/, + bool /*is_vpq*/, + uint32_t /*pq_bits*/, + uint32_t /*pq_len*/, + bool persistent = false) + : CagraPlannerBase(persistent ? "search_single_cta_p" : "search_single_cta") + { + } + + void add_search_kernel_fragment(bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps, + bool persistent) + { + if (persistent) { + if (topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->add_static_fragment>(); + } else if (topk_by_bitonic_sort && !bitonic_sort_and_merge_multi_warps) { + this->add_static_fragment>(); + } else if (!topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->add_static_fragment>(); + } else { + this->add_static_fragment>(); + } + } else { + if (topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->add_static_fragment>(); + } else if (topk_by_bitonic_sort && !bitonic_sort_and_merge_multi_warps) { + this->add_static_fragment>(); + } else if (!topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->add_static_fragment>(); + } else { + this->add_static_fragment>(); + } + } + } +}; + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in new file mode 100644 index 0000000000..9bd9497b48 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 
2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using query_t = @query_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +// NOTE: Explicit instantiation may be redundant: setup_workspace_base<> below calls +// setup_workspace with the same args, which normally implicit-instantiates it. +// Revisit to drop this if NVCC/tooling does not require the explicit line. + +template __device__ const dataset_descriptor_base_t* +setup_workspace<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, data_t, index_t, distance_t, query_t>( + const dataset_descriptor_base_t*, void*, const data_t*, uint32_t); + +template <> +__device__ const dataset_descriptor_base_t* +setup_workspace_base( + const dataset_descriptor_base_t* desc, + void* smem, + const data_t* queries, + uint32_t query_id) +{ + return setup_workspace<@team_size@, + @dataset_block_dim@, + @pq_bits@, + @pq_len@, + @codebook_type@, + data_t, + index_t, + distance_t, + query_t>(desc, smem, queries, query_id); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json new file mode 100644 index 0000000000..23822b3996 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json @@ -0,0 +1,158 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + 
"query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_prefix": "_standard", + "pq_suffix": "", + "pq_bits": "0", + "pq_len": "0" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_tag_comma": "" + } + ], + "impl_file": "setup_workspace_standard_impl.cuh" + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "2", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_len": "4", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_tag_comma": ", " + } + ], + "impl_file": "setup_workspace_vpq_impl.cuh" + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_standard_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_standard_impl.cuh new file mode 100644 index 0000000000..3089feb68d --- /dev/null +++ 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_standard_impl.cuh @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" +#include "../device_common.hpp" + +namespace cuvs::neighbors::cagra::detail { + +// Unified setup_workspace implementation for standard descriptors +// This is instantiated when PQ_BITS=0, PQ_LEN=0, CodebookT=void +// Takes const dataset_descriptor_base_t* and reconstructs the derived descriptor inside smem +// QueryT can be float (for most metrics) or uint8_t (for BitwiseHamming) +template +__device__ const dataset_descriptor_base_t* setup_workspace( + const dataset_descriptor_base_t* desc_ptr, + void* smem, + const DataT* queries, + uint32_t query_id) +{ + // For standard descriptors, PQ_BITS=0, PQ_LEN=0, CodebookT=void + static_assert(PQ_BITS == 0 && PQ_LEN == 0 && std::is_same_v, + "Standard descriptor requires PQ_BITS=0, PQ_LEN=0, CodebookT=void"); + + // Reconstruct the descriptor pointer from base pointer with QueryT + using desc_t = + standard_dataset_descriptor_t; + const desc_t* desc = static_cast(desc_ptr); + + // Call the free function directly - it takes DescriptorT as template parameter + const desc_t* result = setup_workspace_standard(desc, smem, queries, query_id); + return static_cast*>(result); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_vpq_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_vpq_impl.cuh new file mode 100644 index 0000000000..7d21bbb8b7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_vpq_impl.cuh @@ -0,0 +1,55 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_vpq-impl.cuh" +#include "../device_common.hpp" + +namespace cuvs::neighbors::cagra::detail { + +// Unified setup_workspace implementation for VPQ descriptors +// This is instantiated when PQ_BITS>0, PQ_LEN>0, CodebookT=half +// Takes const dataset_descriptor_base_t* and reconstructs the derived descriptor inside smem +// QueryT is always half for VPQ +template +__device__ const dataset_descriptor_base_t* setup_workspace( + const dataset_descriptor_base_t* desc_ptr, + void* smem, + const DataT* queries, + uint32_t query_id) +{ + // For VPQ descriptors, PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half + static_assert( + PQ_BITS > 0 && PQ_LEN > 0 && std::is_same_v && std::is_same_v, + "VPQ descriptor requires PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half"); + + // Reconstruct the descriptor pointer from base pointer + // QueryT is always half for VPQ + using desc_t = cagra_q_dataset_descriptor_t; + const desc_t* desc = static_cast(desc_ptr); + + // Call the free function directly - it takes DescriptorT as template parameter + const desc_t* result = setup_workspace_vpq(desc, smem, queries, query_id); + return static_cast*>(result); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 4d4ddb9b80..15c9c87c98 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -30,6 +30,8 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp +#include + #include #include #include @@ -91,10 +93,10 @@ struct search constexpr static bool kNeedIndexCopy = sizeof(INDEX_T) != sizeof(OutputIndexT); uint32_t num_cta_per_query; - lightweight_uvector intermediate_indices; - lightweight_uvector intermediate_distances; + rmm::device_uvector 
intermediate_indices; + rmm::device_uvector intermediate_distances; size_t topk_workspace_size; - lightweight_uvector topk_workspace; + rmm::device_uvector topk_workspace; search(raft::resources const& res, search_params params, @@ -104,9 +106,9 @@ struct search int64_t graph_degree, uint32_t topk) : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), - intermediate_indices(res), - intermediate_distances(res), - topk_workspace(res) + intermediate_indices(0, raft::resource::get_cuda_stream(res)), + intermediate_distances(0, raft::resource::get_cuda_stream(res)), + topk_workspace(0, raft::resource::get_cuda_stream(res)) { set_params(res, params); } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh index bd4d25d8f3..834b7b21ee 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh @@ -1,37 +1,39 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include "../../sample_filter.cuh" +#include "sample_filter_utils.cuh" #include "search_multi_cta_kernel-inl.cuh" #include namespace cuvs::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ - template void select_and_run( \ - const dataset_descriptor_host& dataset_desc, \ - raft::device_matrix_view graph, \ - const IndexT* source_indices_ptr, \ - uint32_t* topk_indices_ptr, \ - DistanceT* topk_distances_ptr, \ - const DataT* queries_ptr, \ - uint32_t num_queries, \ - const uint32_t* dev_seed_ptr, \ - uint32_t* num_executed_iterations, \ - const search_params& ps, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - uint32_t small_hash_bitlen, \ - int64_t hash_bitlen, \ - uint32_t* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_seeds, \ - SampleFilterT sample_filter, \ +#define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ + template void select_and_run( \ + const dataset_descriptor_host& dataset_desc, \ + raft::device_matrix_view graph, \ + const IndexT* source_indices_ptr, \ + IndexT* topk_indices_ptr, \ + DistanceT* topk_distances_ptr, \ + const DataT* queries_ptr, \ + uint32_t num_queries, \ + const IndexT* dev_seed_ptr, \ + uint32_t* num_executed_iterations, \ + const search_params& ps, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + uint32_t small_hash_bitlen, \ + int64_t hash_bitlen, \ + IndexT* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_seeds, \ + SampleFilterT sample_filter, \ cudaStream_t stream); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 4ac0020c5c..17e7a5684f 100644 --- 
a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -14,6 +14,12 @@ #include "utils.hpp" #include +#ifdef CUVS_ENABLE_JIT_LTO +#include "search_multi_cta_kernel_launcher_jit.cuh" +#else +#include "set_value_batch.cuh" +#endif + #include #include #include @@ -455,7 +461,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * query_id))); result_indices_ptr[k] = index & ~index_msb_1_mask; if (result_distances_ptr != nullptr) { - result_distances_ptr[k] = result_distances_buffer[i]; + DISTANCE_T dist = result_distances_buffer[i]; + result_distances_ptr[k] = dist; } } else { // If it is valid and registered in the traversed hash table but is @@ -504,34 +511,6 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( #endif } -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - template struct search_kernel_config { // Search kernel function type. 
Note that the actual values for the template value @@ -576,6 +555,31 @@ void select_and_run(const dataset_descriptor_host& dat SampleFilterT sample_filter, cudaStream_t stream) { +#ifdef CUVS_ENABLE_JIT_LTO + // Use JIT version when JIT is enabled + select_and_run_jit(dataset_desc, + graph, + source_indices_ptr, + topk_indices_ptr, + topk_distances_ptr, + queries_ptr, + num_queries, + dev_seed_ptr, + num_executed_iterations, + ps, + topk, + block_size, + result_buffer_size, + smem_size, + visited_hash_bitlen, + traversed_hash_bitlen, + traversed_hashmap_ptr, + num_cta_per_query, + num_seeds, + sample_filter, + stream); +#else + // Non-JIT path auto kernel = search_kernel_config, SourceIndexT, @@ -609,27 +613,31 @@ void select_and_run(const dataset_descriptor_host& dat num_queries, smem_size); - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - dataset_desc.dev_ptr(stream), - queries_ptr, - graph.data_handle(), - max_elements, - graph.extent(1), - source_indices_ptr, - ps.num_random_samplings, - ps.rand_xor_mask, - dev_seed_ptr, - num_seeds, - visited_hash_bitlen, - traversed_hashmap_ptr, - traversed_hash_bitlen, - ps.itopk_size, - ps.min_iterations, - ps.max_iterations, - num_executed_iterations, - sample_filter, - static_cast(graph.extent(0))); + auto const& kernel_launcher = [&](auto const& kernel) -> void { + kernel<<>>(topk_indices_ptr, + topk_distances_ptr, + dataset_desc.dev_ptr(stream), + queries_ptr, + graph.data_handle(), + max_elements, + graph.extent(1), + source_indices_ptr, + ps.num_random_samplings, + ps.rand_xor_mask, + dev_seed_ptr, + num_seeds, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen, + ps.itopk_size, + ps.min_iterations, + ps.max_iterations, + num_executed_iterations, + sample_filter, + static_cast(graph.extent(0))); + }; + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size(kernel, smem_size, kernel_launcher); +#endif } } // namespace multi_cta_search diff --git 
a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..6eb7402f3e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh @@ -0,0 +1,179 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "search_multi_cta_kernel_launcher_jit.cuh included but CUVS_ENABLE_JIT_LTO not defined!" +#endif + +#include "../smem_utils.cuh" + +// Include tags header before any other includes that might open namespaces +#include + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "jit_lto_kernels/cagra_jit_launcher_factory.hpp" +#include "jit_lto_kernels/kernel_def.hpp" +#include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset +#include "search_plan.cuh" // For search_params +#include "set_value_batch.cuh" // For set_value_batch +#include "shared_launcher_jit.hpp" // For shared JIT helper functions +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +// JIT version of select_and_run for multi_cta +template +void select_and_run_jit( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + const SourceIndexT* source_indices_ptr, + IndexT* topk_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DistanceT* topk_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + // multi_cta_search (params struct) + uint32_t block_size, // + uint32_t result_buffer_size, + uint32_t smem_size, + 
uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, + uint32_t num_cta_per_query, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + // Extract bitset data from filter object (if it's a bitset_filter) + uint32_t* bitset_ptr = nullptr; + SourceIndexT bitset_len = 0; + SourceIndexT original_nbits = 0; + uint32_t query_id_offset = 0; + + // Check if it has the wrapper members (CagraSampleFilterWithQueryIdOffset) + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + using InnerFilter = decltype(sample_filter.filter); + // Always extract offset for wrapped filters + query_id_offset = sample_filter.offset; + if constexpr (is_bitset_filter::value) { + // Extract bitset data for bitset_filter (works for any bitset_filter instantiation) + auto bitset_view = sample_filter.filter.view(); + bitset_ptr = const_cast(bitset_view.data()); + bitset_len = static_cast(bitset_view.size()); + original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + + std::string const filter_name = get_sample_filter_name(); + std::shared_ptr launcher = + make_cagra_multi_cta_jit_launcher(dataset_desc, + filter_name); + + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher"); } + + uint32_t max_elements{}; + if (result_buffer_size <= 64) { + max_elements = 64; + } else if (result_buffer_size <= 128) { + max_elements = 128; + } else if (result_buffer_size <= 256) { + max_elements = 256; + } else { + THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); + } + + // Initialize hash table + const uint32_t traversed_hash_size = hashmap::get_size(traversed_hash_bitlen); + set_value_batch(traversed_hashmap_ptr, + traversed_hash_size, + ~static_cast(0), + traversed_hash_size, + num_queries, + stream); + + dim3 block_dims(block_size, 1, 1); + dim3 grid_dims(num_cta_per_query, num_queries, 1); + + // Get the device descriptor pointer + const 
dataset_descriptor_base_t* dev_desc_base = + dataset_desc.dev_ptr(stream); + const auto* dev_desc = dev_desc_base; + + // Note: dataset_desc is passed by const reference, so it stays alive for the duration of this + // function The descriptor's state is managed by a shared_ptr internally, so no need to explicitly + // keep it alive + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + // graph.extent(1) returns int64_t but kernel expects uint32_t + // traversed_hash_bitlen is int64_t but kernel expects uint32_t + // ps.itopk_size, ps.min_iterations, ps.max_iterations are size_t (8 bytes) but kernel expects + // uint32_t (4 bytes) ps.num_random_samplings is uint32_t but kernel expects unsigned - cast for + // consistency + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t traversed_hash_bitlen_u32 = static_cast(traversed_hash_bitlen); + const uint32_t itopk_size_u32 = static_cast(ps.itopk_size); + const uint32_t min_iterations_u32 = static_cast(ps.min_iterations); + const uint32_t max_iterations_u32 = static_cast(ps.max_iterations); + const unsigned num_random_samplings_u = static_cast(ps.num_random_samplings); + + auto kernel_launcher = [&](auto const& kernel) -> void { + launcher->dispatch< + multi_cta_search::search_multi_cta_jit_func_t>( + stream, + grid_dims, + block_dims, + smem_size, + topk_indices_ptr, + topk_distances_ptr, + dev_desc, + queries_ptr, + graph.data_handle(), + max_elements, + graph_degree_u32, + source_indices_ptr, + num_random_samplings_u, + ps.rand_xor_mask, + dev_seed_ptr, + num_seeds, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen_u32, + itopk_size_u32, + min_iterations_u32, + max_iterations_u32, + num_executed_iterations, + query_id_offset, + bitset_ptr, + bitset_len, + original_nbits); + }; + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size( + 
launcher->get_kernel(), smem_size, kernel_launcher); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index 7aab9e241b..f7609c0c33 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -4,7 +4,15 @@ */ #pragma once -#include "device_common.hpp" +// Include tags header before any namespace declarations to avoid issues when it's included inside +// functions +#ifdef CUVS_ENABLE_JIT_LTO +#include "search_multi_kernel_launcher_jit.cuh" +#include +#endif + +#include "set_value_batch.cuh" + #include "hashmap.hpp" #include "search_plan.cuh" #include "topk_for_cagra/topk.h" //todo replace with raft kernel @@ -172,25 +180,48 @@ void random_pickup(const dataset_descriptor_host& data cudaStream_t cuda_stream, IndexT graph_size = 0) { - const auto block_size = 256u; - const auto num_teams_per_threadblock = block_size / dataset_desc.team_size; - const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, - num_queries); - - random_pickup_kernel<<>>( - dataset_desc.dev_ptr(cuda_stream), - queries_ptr, - num_pickup, - num_distilation, - rand_xor_mask, - seed_ptr, - num_seeds, - result_indices_ptr, - result_distances_ptr, - ldr, - visited_hashmap_ptr, - hash_bitlen, - graph_size); +#ifdef CUVS_ENABLE_JIT_LTO + // Use JIT version when JIT is enabled + random_pickup_jit(dataset_desc, + queries_ptr, + num_queries, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr, + visited_hashmap_ptr, + hash_bitlen, + cuda_stream); +#else + // Non-JIT path + { + const auto block_size = 256u; + const auto num_teams_per_threadblock = block_size / dataset_desc.team_size; + const dim3 grid_size((num_pickup + 
num_teams_per_threadblock - 1) / num_teams_per_threadblock, + num_queries); + + random_pickup_kernel<<>>(dataset_desc.dev_ptr(cuda_stream), + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr, + visited_hashmap_ptr, + hash_bitlen, + graph_size); + } +#endif } template @@ -407,30 +438,55 @@ void compute_distance_to_child_nodes( SAMPLE_FILTER_T sample_filter, cudaStream_t cuda_stream) { - const auto block_size = 128; - const auto teams_per_block = block_size / dataset_desc.team_size; - const dim3 grid_size((search_width * graph_degree + teams_per_block - 1) / teams_per_block, - num_queries); - - compute_distance_to_child_nodes_kernel<<>>(parent_node_list, - parent_candidates_ptr, - parent_distance_ptr, - lds, - search_width, - dataset_desc.dev_ptr(cuda_stream), - neighbor_graph_ptr, - graph_degree, - source_indices_ptr, - query_ptr, - visited_hashmap_ptr, - hash_bitlen, - result_indices_ptr, - result_distances_ptr, - ldd, - sample_filter); +#ifdef CUVS_ENABLE_JIT_LTO + // Use JIT version when JIT is enabled + compute_distance_to_child_nodes_jit(parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dataset_desc, + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + num_queries, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + sample_filter, + cuda_stream); +#else + // Non-JIT path + { + const auto block_size = 128; + const auto teams_per_block = block_size / dataset_desc.team_size; + const dim3 grid_size((search_width * graph_degree + teams_per_block - 1) / teams_per_block, + num_queries); + + compute_distance_to_child_nodes_kernel<<>>(parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dataset_desc.dev_ptr(cuda_stream), + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + 
result_indices_ptr, + result_distances_ptr, + ldd, + sample_filter); + } +#endif } template @@ -502,17 +558,33 @@ void apply_filter(const SourceIndexT* source_indices_ptr, SAMPLE_FILTER_T sample_filter, cudaStream_t cuda_stream) { - const std::uint32_t block_size = 256; - const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); - - apply_filter_kernel<<>>(source_indices_ptr, - result_indices_ptr, - result_distances_ptr, - lds, - result_buffer_size, - num_queries, - query_id_offset, - sample_filter); +#ifdef CUVS_ENABLE_JIT_LTO + // Use JIT version when JIT is enabled + apply_filter_jit(source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + sample_filter, + cuda_stream); +#else + // Non-JIT path + { + const std::uint32_t block_size = 256; + const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); + + apply_filter_kernel<<>>(source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + sample_filter); + } +#endif } template @@ -547,34 +619,6 @@ void batched_memcpy(T* const dst, // [batch_size, ld_dst] <<>>(dst, ld_dst, src, ld_src, count, batch_size); } -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - 
<<>>(dev_ptr, ld, val, count, batch_size); -} - // result_buffer (work buffer) for "multi-kernel" // +--------------------+------------------------------+-------------------+ // | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) | @@ -634,18 +678,18 @@ struct search using base_type::num_seeds; size_t result_buffer_allocation_size; - lightweight_uvector result_indices; // results_indices_buffer - lightweight_uvector result_distances; // result_distances_buffer - lightweight_uvector parent_node_list; - lightweight_uvector topk_hint; - lightweight_uvector terminate_flag; // dev_terminate_flag, host_terminate_flag.; - lightweight_uvector topk_workspace; + rmm::device_uvector result_indices; // results_indices_buffer + rmm::device_uvector result_distances; // result_distances_buffer + rmm::device_uvector parent_node_list; + rmm::device_uvector topk_hint; + rmm::device_uvector terminate_flag; // dev_terminate_flag, host_terminate_flag.; + rmm::device_uvector topk_workspace; // temporary storage for _find_topk - lightweight_uvector input_keys_storage; - lightweight_uvector output_keys_storage; - lightweight_uvector input_values_storage; - lightweight_uvector output_values_storage; + rmm::device_uvector input_keys_storage; + rmm::device_uvector output_keys_storage; + rmm::device_uvector input_values_storage; + rmm::device_uvector output_values_storage; search(raft::resources const& res, search_params params, @@ -655,16 +699,16 @@ struct search int64_t graph_degree, uint32_t topk) : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), - result_indices(res), - result_distances(res), - parent_node_list(res), - topk_hint(res), - topk_workspace(res), - terminate_flag(res), - input_keys_storage(res), - output_keys_storage(res), - input_values_storage(res), - output_values_storage(res) + result_indices(0, raft::resource::get_cuda_stream(res)), + result_distances(0, raft::resource::get_cuda_stream(res)), + parent_node_list(0, 
raft::resource::get_cuda_stream(res)), + topk_hint(0, raft::resource::get_cuda_stream(res)), + topk_workspace(0, raft::resource::get_cuda_stream(res)), + terminate_flag(0, raft::resource::get_cuda_stream(res)), + input_keys_storage(0, raft::resource::get_cuda_stream(res)), + output_keys_storage(0, raft::resource::get_cuda_stream(res)), + input_values_storage(0, raft::resource::get_cuda_stream(res)), + output_values_storage(0, raft::resource::get_cuda_stream(res)) { set_params(res); } @@ -864,6 +908,7 @@ struct search // pickup parent nodes uint32_t _small_hash_bitlen = 0; if ((iter + 1) % small_hash_reset_interval == 0) { _small_hash_bitlen = small_hash_bitlen; } + pickup_next_parents(result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, result_buffer_allocation_size, itopk_size, @@ -878,9 +923,11 @@ struct search stream); // termination (2) - if (iter + 1 >= min_iterations && get_value(terminate_flag.data(), stream)) { - iter++; - break; + if (iter + 1 >= min_iterations) { + if (get_value(terminate_flag.data(), stream)) { + iter++; + break; + } } // Compute distance to child nodes that are adjacent to the parent node @@ -988,7 +1035,6 @@ struct search num_executed_iterations[i] = iter; } } - RAFT_CUDA_TRY(cudaPeekAtLastError()); } }; diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..780b186a9e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh @@ -0,0 +1,239 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "search_multi_kernel_launcher_jit.cuh included but CUVS_ENABLE_JIT_LTO not defined!" 
+#endif + +// Tags header should be included before this header (at file scope, not inside functions) +// to avoid namespace definition errors when this header is included inside function bodies + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "jit_lto_kernels/cagra_jit_launcher_factory.hpp" +#include "jit_lto_kernels/kernel_def.hpp" +#include "jit_lto_kernels/search_multi_kernel_planner.hpp" +#include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset +#include "search_plan.cuh" // For search_params +#include "shared_launcher_jit.hpp" // For shared JIT helper functions +#include +#include +#include +#include + +#include +#include +#include +#include +// - The launcher doesn't need the kernel function definitions +// - The kernel is dispatched via the JIT LTO launcher system +// - Including it would pull in impl files that cause namespace issues + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +// JIT version of random_pickup +template +void random_pickup_jit(const dataset_descriptor_host& dataset_desc, + const DataT* queries_ptr, // [num_queries, dataset_dim] + std::size_t num_queries, + std::size_t num_pickup, + unsigned num_distilation, + uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + uint32_t num_seeds, + IndexT* result_indices_ptr, // [num_queries, ldr] + DistanceT* result_distances_ptr, // [num_queries, ldr] + std::size_t ldr, // (*) ldr >= num_pickup + IndexT* visited_hashmap_ptr, // [num_queries, 1 << bitlen] + std::uint32_t hash_bitlen, + cudaStream_t cuda_stream) +{ + std::shared_ptr launcher = + make_cagra_multi_kernel_jit_launcher(dataset_desc, + "random_pickup"); + + const auto block_size = 256u; + const auto num_teams_per_threadblock = block_size / dataset_desc.team_size; + const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, + num_queries); + + // Get the device descriptor pointer + const auto* dev_desc = 
dataset_desc.dev_ptr(cuda_stream); + + // Cast size_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t ldr_u32 = static_cast(ldr); + + launcher->dispatch>( + cuda_stream, + grid_size, + dim3(block_size, 1, 1), + dataset_desc.smem_ws_size_in_bytes, + dev_desc, + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr_u32, + visited_hashmap_ptr, + hash_bitlen); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +// JIT version of compute_distance_to_child_nodes +template +void compute_distance_to_child_nodes_jit( + const IndexT* parent_node_list, // [num_queries, search_width] + IndexT* const parent_candidates_ptr, // [num_queries, search_width] + DistanceT* const parent_distance_ptr, // [num_queries, search_width] + std::size_t lds, + uint32_t search_width, + const dataset_descriptor_host& dataset_desc, + const IndexT* neighbor_graph_ptr, // [dataset_size, graph_degree] + std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const DataT* query_ptr, // [num_queries, data_dim] + std::uint32_t num_queries, + IndexT* visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + std::uint32_t hash_bitlen, + IndexT* result_indices_ptr, // [num_queries, ldd] + DistanceT* result_distances_ptr, // [num_queries, ldd] + std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + SAMPLE_FILTER_T sample_filter, + cudaStream_t cuda_stream) +{ + std::shared_ptr launcher = + make_cagra_multi_kernel_jit_launcher( + dataset_desc, "compute_distance_to_child_nodes"); + + const auto block_size = 128; + const auto teams_per_block = block_size / dataset_desc.team_size; + const dim3 grid_size((search_width * graph_degree + teams_per_block - 1) / teams_per_block, + num_queries); + + // Get the device descriptor pointer + const auto* dev_desc = dataset_desc.dev_ptr(cuda_stream); + + launcher + 
->dispatch>( + cuda_stream, + grid_size, + dim3(block_size, 1, 1), + dataset_desc.smem_ws_size_in_bytes, + parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dev_desc, + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + cuvs::neighbors::filtering::none_sample_filter{}); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +// JIT version of apply_filter +template +void apply_filter_jit(const SourceIndexT* source_indices_ptr, + INDEX_T* const result_indices_ptr, + DISTANCE_T* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const INDEX_T query_id_offset, + SAMPLE_FILTER_T sample_filter, + cudaStream_t cuda_stream) +{ + // Extract bitset data from filter object (if it's a bitset_filter) + uint32_t* bitset_ptr = nullptr; + SourceIndexT bitset_len = 0; + SourceIndexT original_nbits = 0; + + // Check if it has the wrapper members (CagraSampleFilterWithQueryIdOffset) + // Note: query_id_offset is already a parameter to this function, so we don't extract it here + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + using InnerFilter = decltype(sample_filter.filter); + if constexpr (is_bitset_filter::value) { + // Extract bitset data for bitset_filter (works for any bitset_filter instantiation) + auto bitset_view = sample_filter.filter.view(); + bitset_ptr = const_cast(bitset_view.data()); + bitset_len = static_cast(bitset_view.size()); + original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + + // Create planner with tags + using DataTag = + decltype(get_data_type_tag()); // Not used for apply_filter, but required by planner + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + // Create 
planner - apply_filter doesn't use dataset_descriptor, so we use dummy values + // The kernel name is "apply_filter_kernel" and build_entrypoint_name will handle it specially + using QueryTag = query_type_tag_standard_t; + using CodebookTag = tag_codebook_none; + CagraMultiKernelSearchPlanner + planner(cuvs::distance::DistanceType::L2Expanded, + "apply_filter_kernel", + 8, + 128, + false, + 0, + 0); // Dummy values, not used by apply_filter + + planner.add_sample_filter_device_function(get_sample_filter_name()); + planner.add_linked_kernel("apply_filter_kernel"); + + std::shared_ptr launcher = planner.get_launcher(); + + const std::uint32_t block_size = 256; + const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); + + // Alias avoids nested `dispatch< alias_template<...>>` which NVCC can misparse as + // comparison/shift. + using apply_filter_kernel_func_t = + apply_filter_kernel_jit_func_t; + // `template` required: in template code, `->dispatch<...>` is otherwise parsed as `dispatch <` … + launcher->template dispatch(cuda_stream, + dim3(grid_size, 1, 1), + dim3(block_size, 1, 1), + 0, + source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + bitset_ptr, + bitset_len, + original_nbits); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index 02bf1ff697..43098c817a 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -34,6 +34,7 @@ #include #include +// All includes are done before opening namespace to avoid nested namespace issues namespace cuvs::neighbors::cagra::detail { namespace single_cta_search { diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh 
b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh index 11b468cfca..d242e13b95 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh @@ -1,13 +1,16 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "search_single_cta_kernel-inl.cuh" #include +// Include explicit instantiations before namespace (launcher includes JIT LTO headers with +// namespace definitions) +#include "search_single_cta_kernel_explicit_inst.cuh" + namespace cuvs::neighbors::cagra::detail::single_cta_search { #define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 48553611bf..1c9d657728 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -15,6 +15,8 @@ #include "utils.hpp" #include +#include + #include #include #include @@ -2221,135 +2223,5 @@ auto get_runner(Args... 
args) -> std::shared_ptr weak = runner; return runner; } - -template -void select_and_run( - const dataset_descriptor_host& dataset_desc, - raft::device_matrix_view graph, - std::optional> source_indices, - uintptr_t topk_indices_ptr, // [num_queries, topk] - DistanceT* topk_distances_ptr, // [num_queries, topk] - const DataT* queries_ptr, // [num_queries, dataset_dim] - uint32_t num_queries, - const IndexT* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* num_executed_iterations, // [num_queries,] - const search_params& ps, - uint32_t topk, - uint32_t num_itopk_candidates, - uint32_t block_size, // - uint32_t smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_seeds, - SampleFilterT sample_filter, - cudaStream_t stream) -{ - const SourceIndexT* source_indices_ptr = - source_indices.has_value() ? source_indices->data_handle() : nullptr; - - uint32_t max_candidates{}; - if (num_itopk_candidates <= 64) { - max_candidates = 64; - } else if (num_itopk_candidates <= 128) { - max_candidates = 128; - } else if (num_itopk_candidates <= 256) { - max_candidates = 256; - } else { - max_candidates = - 32; // irrelevant, radix based topk is used (see choose_itopk_and_max_candidates) - } - - uint32_t max_itopk{}; - assert(ps.itopk_size <= 512); - if (num_itopk_candidates <= 256) { // bitonic sort - if (ps.itopk_size <= 64) { - max_itopk = 64; - } else if (ps.itopk_size <= 128) { - max_itopk = 128; - } else if (ps.itopk_size <= 256) { - max_itopk = 256; - } else { - max_itopk = 512; - } - } else { // radix sort - if (ps.itopk_size <= 256) { - max_itopk = 256; - } else { - max_itopk = 512; - } - } - - if (ps.persistent) { - using runner_type = persistent_runner_t; - - get_runner(/* -Note, we're passing the descriptor by reference here, and this reference is going to be passed to a -new spawned thread, which is dangerous. 
However, the descriptor is copied in that thread before the -control is returned in this thread (in persistent_runner_t constructor), so we're safe. -*/ - std::cref(dataset_desc), - graph, - source_indices_ptr, - max_candidates, - num_itopk_candidates, - block_size, - smem_size, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - ps.num_random_samplings, - ps.rand_xor_mask, - num_seeds, - max_itopk, - ps.itopk_size, - ps.search_width, - ps.min_iterations, - ps.max_iterations, - sample_filter, - ps.persistent_lifetime, - ps.persistent_device_usage) - ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); - } else { - using descriptor_base_type = dataset_descriptor_base_t; - auto kernel = search_kernel_config:: - choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size); - dim3 thread_dims(block_size, 1, 1); - dim3 block_dims(1, num_queries, 1); - RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - topk, - dataset_desc.dev_ptr(stream), - queries_ptr, - graph.data_handle(), - graph.extent(1), - source_indices_ptr, - ps.num_random_samplings, - ps.rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - max_candidates, - max_itopk, - ps.itopk_size, - ps.search_width, - ps.min_iterations, - ps.max_iterations, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - sample_filter, - static_cast(graph.extent(0))); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } -} } // namespace single_cta_search } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh new file mode 100644 index 0000000000..8f715bbbc4 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh @@ -0,0 +1,12 @@ +/* + 
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifdef CUVS_ENABLE_JIT_LTO +#include "search_single_cta_kernel_launcher_jit.cuh" +#else +#include "search_single_cta_kernel_launcher.cuh" +#endif diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher.cuh new file mode 100644 index 0000000000..10b863f5f9 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher.cuh @@ -0,0 +1,123 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../smem_utils.cuh" + +#include "search_single_cta_kernel-inl.cuh" // For search_kernel_config, persistent_runner_t, etc. +#include "search_single_cta_kernel_launcher_common.cuh" + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +template +void select_and_run( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + std::optional> source_indices, + uintptr_t topk_indices_ptr, // [num_queries, topk] + DistanceT* topk_distances_ptr, // [num_queries, topk] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + const SourceIndexT* source_indices_ptr = + source_indices.has_value() ? 
source_indices->data_handle() : nullptr; + + // Use common logic to compute launch config + auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); + uint32_t max_candidates = config.max_candidates; + uint32_t max_itopk = config.max_itopk; + + if (ps.persistent) { + using runner_type = persistent_runner_t; + + get_runner(/* +Note, we're passing the descriptor by reference here, and this reference is going to be passed to a +new spawned thread, which is dangerous. However, the descriptor is copied in that thread before the +control is returned in this thread (in persistent_runner_t constructor), so we're safe. +*/ + std::cref(dataset_desc), + graph, + source_indices_ptr, + max_candidates, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + ps.num_random_samplings, + ps.rand_xor_mask, + num_seeds, + max_itopk, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + sample_filter, + ps.persistent_lifetime, + ps.persistent_device_usage) + ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); + } else { + using descriptor_base_type = dataset_descriptor_base_t; + auto kernel = search_kernel_config:: + choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size); + + dim3 thread_dims(block_size, 1, 1); + dim3 block_dims(1, num_queries, 1); + RAFT_LOG_DEBUG( + "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); + auto const& kernel_launcher = [&](auto const& kernel) -> void { + kernel<<>>(topk_indices_ptr, + topk_distances_ptr, + topk, + dataset_desc.dev_ptr(stream), + queries_ptr, + graph.data_handle(), + graph.extent(1), + source_indices_ptr, + ps.num_random_samplings, + ps.rand_xor_mask, + dev_seed_ptr, + num_seeds, + hashmap_ptr, + max_candidates, + max_itopk, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + num_executed_iterations, + 
hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + sample_filter, + static_cast(graph.extent(0))); + }; + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size( + kernel, smem_size, kernel_launcher); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh new file mode 100644 index 0000000000..b1e2191fec --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Common logic for computing max_candidates and max_itopk +struct LaunchConfig { + uint32_t max_candidates; + uint32_t max_itopk; + bool topk_by_bitonic_sort; + bool bitonic_sort_and_merge_multi_warps; +}; + +inline LaunchConfig compute_launch_config(uint32_t num_itopk_candidates, + uint32_t itopk_size, + uint32_t block_size) +{ + LaunchConfig config{}; + + // Compute max_candidates + if (num_itopk_candidates <= 64) { + config.max_candidates = 64; + } else if (num_itopk_candidates <= 128) { + config.max_candidates = 128; + } else if (num_itopk_candidates <= 256) { + config.max_candidates = 256; + } else { + config.max_candidates = 32; // irrelevant, radix based topk is used + } + + // Compute max_itopk and sort flags + config.topk_by_bitonic_sort = (num_itopk_candidates <= 256); + config.bitonic_sort_and_merge_multi_warps = false; + + if (config.topk_by_bitonic_sort) { + if (itopk_size <= 64) { + config.max_itopk = 64; + } else if (itopk_size <= 128) { + config.max_itopk = 128; + } else if (itopk_size <= 256) { + config.max_itopk = 256; + } else { + config.max_itopk = 512; + 
config.bitonic_sort_and_merge_multi_warps = (block_size >= 64); + } + } else { + if (itopk_size <= 256) { + config.max_itopk = 256; + } else { + config.max_itopk = 512; + } + } + + return config; +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..c4b52afa99 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -0,0 +1,855 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "search_single_cta_kernel_launcher_jit.cuh included but CUVS_ENABLE_JIT_LTO not defined!" +#endif + +#include "../smem_utils.cuh" + +#include +#include + +// Include tags header before any other includes that might open namespaces +#include + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "jit_lto_kernels/cagra_jit_launcher_factory.hpp" +#include "jit_lto_kernels/kernel_def.hpp" +#include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset +#include "search_plan.cuh" // For search_params +#include "search_single_cta_kernel-inl.cuh" // For resource_queue_t, local_deque_t, launcher_t, persistent_runner_base_t, etc. 
+#include "search_single_cta_kernel_launcher_common.cuh" +#include "shared_launcher_jit.hpp" // For shared JIT helper functions + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// The launcher uses types from search_single_cta_kernel-inl.cuh (worker_handle_t, job_desc_t) +// The JIT kernel headers define _jit versions that are compatible + +// Forward declarations +template +auto get_runner_jit(Args... args) -> std::shared_ptr; + +template +auto create_runner_jit(Args... args) -> std::shared_ptr; + +// Helper functions are now in shared_launcher_jit.hpp + +// JIT-compatible launcher_t that works with worker_handle_t (same as non-JIT version) +struct alignas(kCacheLineBytes) launcher_jit_t { + using job_queue_type = resource_queue_t; + using worker_queue_type = resource_queue_t; + using pending_reads_queue_type = local_deque_t; + using completion_flag_type = cuda::atomic; + + pending_reads_queue_type pending_reads; + job_queue_type& job_ids; + worker_queue_type& idle_worker_ids; + worker_handle_t* worker_handles; + uint32_t job_id; + completion_flag_type* completion_flag; + bool all_done = false; + + static inline constexpr auto kDefaultLatency = std::chrono::nanoseconds(50000); + static inline constexpr auto kMaxExpectedLatency = + kDefaultLatency * std::max(10, kMaxJobsNum / 128); + static inline thread_local auto expected_latency = kDefaultLatency; + const std::chrono::time_point start; + std::chrono::time_point now; + const int64_t pause_factor; + int pause_count = 0; + std::chrono::time_point deadline; + + template + launcher_jit_t(job_queue_type& job_ids, + worker_queue_type& idle_worker_ids, + worker_handle_t* worker_handles, + uint32_t n_queries, + std::chrono::milliseconds max_wait_time, + RecordWork record_work) + : 
pending_reads{std::min(n_queries, kMaxWorkersPerThread)}, + job_ids{job_ids}, + idle_worker_ids{idle_worker_ids}, + worker_handles{worker_handles}, + job_id{job_ids.pop().wait()}, + completion_flag{record_work(job_id)}, + start{std::chrono::system_clock::now()}, + pause_factor{calc_pause_factor(n_queries)}, + now{start}, + deadline{start + max_wait_time + expected_latency} + { + submit_query(idle_worker_ids.pop().wait(), 0); + for (uint32_t i = 1; i < n_queries; i++) { + auto promised_worker = idle_worker_ids.pop(); + uint32_t worker_id; + while (!promised_worker.test(worker_id)) { + if (pending_reads.try_pop_front(worker_id)) { + bool returned_some = false; + for (bool keep_returning = true; keep_returning;) { + if (try_return_worker(worker_id)) { + keep_returning = pending_reads.try_pop_front(worker_id); + returned_some = true; + } else { + pending_reads.push_front(worker_id); + keep_returning = false; + } + } + if (!returned_some) { pause(); } + } else { + worker_id = promised_worker.wait(); + break; + } + } + pause_count = 0; + submit_query(worker_id, i); + if (i >= kSoftMaxWorkersPerThread && pending_reads.try_pop_front(worker_id)) { + if (!try_return_worker(worker_id)) { pending_reads.push_front(worker_id); } + } + } + } + + inline ~launcher_jit_t() noexcept + { + constexpr size_t kWindow = 100; + expected_latency = std::min( + ((kWindow - 1) * expected_latency + now - start) / kWindow, kMaxExpectedLatency); + if (job_id != job_queue_type::kEmpty) { job_ids.push(job_id); } + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + idle_worker_ids.push(worker_id); + } + } + + inline void submit_query(uint32_t worker_id, uint32_t query_id) + { + worker_handles[worker_id].data.store(worker_handle_t::data_t{.value = {job_id, query_id}}, + cuda::memory_order_relaxed); + while (!pending_reads.try_push_back(worker_id)) { + auto pending_worker_id = pending_reads.pop_front(); + while (!try_return_worker(pending_worker_id)) { + pause(); + } + } + 
pause_count = 0; + } + + inline auto try_return_worker(uint32_t worker_id) -> bool + { + if (all_done || + !is_worker_busy(worker_handles[worker_id].data.load(cuda::memory_order_relaxed).handle)) { + idle_worker_ids.push(worker_id); + return true; + } else { + return false; + } + } + + inline auto is_all_done() + { + if (all_done) { return true; } + all_done = completion_flag->load(cuda::memory_order_relaxed); + return all_done; + } + + [[nodiscard]] inline auto sleep_limit() const + { + constexpr auto kMinWakeTime = std::chrono::nanoseconds(10000); + constexpr double kSleepLimit = 0.6; + return start + expected_latency * kSleepLimit - kMinWakeTime; + } + + [[nodiscard]] inline auto overtime_threshold() const + { + constexpr auto kOvertimeFactor = 3; + return start + expected_latency * kOvertimeFactor; + } + + [[nodiscard]] inline auto calc_pause_factor(uint32_t n_queries) const -> uint32_t + { + constexpr uint32_t kMultiplier = 10; + return kMultiplier * raft::div_rounding_up_safe(n_queries, idle_worker_ids.capacity()); + } + + inline void pause() + { + constexpr auto kSpinLimit = 3; + constexpr auto kPauseTimeMin = std::chrono::nanoseconds(1000); + constexpr auto kPauseTimeMax = std::chrono::nanoseconds(50000); + if (pause_count++ < kSpinLimit) { + std::this_thread::yield(); + return; + } + now = std::chrono::system_clock::now(); + auto pause_time_base = std::max(now - start, expected_latency); + auto pause_time = std::clamp(pause_time_base / pause_factor, kPauseTimeMin, kPauseTimeMax); + if (now + pause_time < sleep_limit()) { + std::this_thread::sleep_for(pause_time); + } else if (now <= overtime_threshold()) { + std::this_thread::yield(); + } else if (now <= deadline) { + std::this_thread::sleep_for(pause_time); + } else { + throw raft::exception( + "The calling thread didn't receive the results from the persistent CAGRA kernel within the " + "expected kernel lifetime. 
Here are possible reasons of this failure:\n" + " (1) `persistent_lifetime` search parameter is too small - increase it;\n" + " (2) there is other work being executed on the same device and the kernel failed to " + "progress - decreasing `persistent_device_usage` may help (but not guaranteed);\n" + " (3) there is a bug in the implementation - please report it to cuVS team."); + } + } + + inline void wait() + { + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + while (!try_return_worker(worker_id)) { + if (!is_all_done()) { pause(); } + } + } + pause_count = 0; + now = std::chrono::system_clock::now(); + while (!is_all_done()) { + auto till_time = sleep_limit(); + if (now < till_time) { + std::this_thread::sleep_until(till_time); + now = std::chrono::system_clock::now(); + } else { + pause(); + } + } + job_ids.push(job_id); + job_id = job_queue_type::kEmpty; + } +}; + +// JIT persistent runner - uses AlgorithmLauncher instead of kernel function pointer +template +struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runner_base_t { + using index_type = IndexT; + using distance_type = DistanceT; + using data_type = DataT; + // Use non-JIT types - JIT kernel header will alias _jit versions to these + struct job_desc_helper_desc { + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + }; + using job_desc_type = job_desc_t; + + std::shared_ptr launcher; + uint32_t block_size; + dataset_descriptor_host dd_host; + rmm::device_uvector worker_handles; + rmm::device_uvector job_descriptors; + rmm::device_uvector completion_counters; + rmm::device_uvector hashmap; + std::atomic> last_touch; + uint64_t param_hash; + uint32_t* bitset_ptr; // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len; // Bitset length + SourceIndexT original_nbits; // Original number of bits + + static inline auto calculate_parameter_hash( + std::reference_wrapper> dataset_desc, + raft::device_matrix_view 
graph, + const SourceIndexT* source_indices_ptr, + uint32_t max_candidates, + uint32_t num_itopk_candidates, + uint32_t block_size, + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + uint32_t max_itopk, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SampleFilterT sample_filter, + float persistent_lifetime, + float persistent_device_usage, + std::shared_ptr /* launcher_ptr - not part of hash */, + const void* /* dataset_desc - not part of hash */) -> uint64_t + { + return uint64_t(graph.data_handle()) ^ uint64_t(source_indices_ptr) ^ + dataset_desc.get().team_size ^ num_itopk_candidates ^ block_size ^ smem_size ^ + hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^ rand_xor_mask ^ + num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^ + uint64_t(persistent_lifetime * 1000) ^ uint64_t(persistent_device_usage * 1000); + } + + persistent_runner_jit_t( + std::reference_wrapper> dataset_desc, + raft::device_matrix_view graph, + const SourceIndexT* source_indices_ptr, + uint32_t max_candidates, + uint32_t num_itopk_candidates, + uint32_t block_size, + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + uint32_t max_itopk, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SampleFilterT sample_filter, + float persistent_lifetime, + float persistent_device_usage, + std::shared_ptr launcher_ptr, + const void* /* dataset_desc - descriptor contains all needed info */) + : persistent_runner_base_t{persistent_lifetime}, + launcher{launcher_ptr}, + block_size{block_size}, + worker_handles(0, stream, worker_handles_mr), + job_descriptors(kMaxJobsNum, stream, job_descriptor_mr), + 
completion_counters(kMaxJobsNum, stream, device_mr), + hashmap(0, stream, device_mr), + dd_host{dataset_desc.get()}, + param_hash(calculate_parameter_hash(dd_host, + graph, + source_indices_ptr, + max_candidates, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + num_random_samplings, + rand_xor_mask, + num_seeds, + max_itopk, + itopk_size, + search_width, + min_iterations, + max_iterations, + sample_filter, + persistent_lifetime, + persistent_device_usage, + launcher_ptr, + nullptr)) // descriptor not needed in hash + { + // Extract bitset data from filter object (if it's a bitset_filter) + // Handle both direct bitset_filter and CagraSampleFilterWithQueryIdOffset wrapper + bitset_ptr = nullptr; + bitset_len = 0; + original_nbits = 0; + uint32_t query_id_offset = 0; + + // Check if it has the wrapper members (CagraSampleFilterWithQueryIdOffset) + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + using InnerFilter = decltype(sample_filter.filter); + // Always extract offset for wrapped filters + query_id_offset = sample_filter.offset; + if constexpr (is_bitset_filter::value) { + // Extract bitset data for bitset_filter (works for any bitset_filter instantiation) + auto bitset_view = sample_filter.filter.view(); + bitset_ptr = const_cast(bitset_view.data()); + bitset_len = static_cast(bitset_view.size()); + original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + + // set kernel launch parameters + dim3 gs = calc_coop_grid_size(block_size, smem_size, persistent_device_usage); + dim3 bs(block_size, 1, 1); + RAFT_LOG_DEBUG( + "Launching JIT persistent kernel with %u threads, %u block %u smem", bs.x, gs.y, smem_size); + + // initialize the job queue + auto* completion_counters_ptr = completion_counters.data(); + auto* job_descriptors_ptr = job_descriptors.data(); + for (uint32_t i = 0; i < kMaxJobsNum; i++) { + auto& jd = 
job_descriptors_ptr[i].input.value; + jd.result_indices_ptr = 0; + jd.result_distances_ptr = nullptr; + jd.queries_ptr = nullptr; + jd.top_k = 0; + jd.n_queries = 0; + job_descriptors_ptr[i].completion_flag.store(false); + job_queue.push(i); + } + + // initialize the worker queue + worker_queue.set_capacity(gs.y); + worker_handles.resize(gs.y, stream); + auto* worker_handles_ptr = worker_handles.data(); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (uint32_t i = 0; i < gs.y; i++) { + worker_handles_ptr[i].data.store({kWaitForWork}); + worker_queue.push(i); + } + + index_type* hashmap_ptr = nullptr; + if (small_hash_bitlen == 0) { + hashmap.resize(gs.y * hashmap::get_size(hash_bitlen), stream); + hashmap_ptr = hashmap.data(); + } + + // Prepare kernel arguments + // Get the device descriptor pointer - kernel will use the concrete type from template + const auto* dev_desc = dataset_desc.get().dev_ptr(stream); + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t hash_bitlen_u32 = static_cast(hash_bitlen); + const uint32_t small_hash_bitlen_u32 = static_cast(small_hash_bitlen); + const uint32_t small_hash_reset_interval_u32 = static_cast(small_hash_reset_interval); + const uint32_t itopk_size_u32 = static_cast(itopk_size); + const uint32_t search_width_u32 = static_cast(search_width); + const uint32_t min_iterations_u32 = static_cast(min_iterations); + const uint32_t max_iterations_u32 = static_cast(max_iterations); + const unsigned num_random_samplings_u = static_cast(num_random_samplings); + + const IndexT* seed_ptr_arg = nullptr; + uint32_t* num_executed_iterations_arg = nullptr; + // Launch the persistent kernel via AlgorithmLauncher + // The persistent kernel now takes the descriptor pointer directly + launcher->dispatch_cooperative< + 
single_cta_search::search_single_cta_p_jit_func_t>( + stream, + gs, + bs, + static_cast(smem_size), + worker_handles_ptr, + job_descriptors_ptr, + completion_counters_ptr, + graph.data_handle(), + graph_degree_u32, // Cast int64_t to uint32_t + source_indices_ptr, + num_random_samplings_u, // Cast uint32_t to unsigned for consistency + rand_xor_mask, // uint64_t matches kernel (8 bytes) + seed_ptr_arg, + num_seeds, + hashmap_ptr, + max_candidates, + max_itopk, + itopk_size_u32, // Cast size_t to uint32_t + search_width_u32, // Cast size_t to uint32_t + min_iterations_u32, // Cast size_t to uint32_t + max_iterations_u32, // Cast size_t to uint32_t + num_executed_iterations_arg, + hash_bitlen_u32, // Cast int64_t to uint32_t + small_hash_bitlen_u32, // Cast size_t to uint32_t + small_hash_reset_interval_u32, // Cast size_t to uint32_t + query_id_offset, // Offset to add to query_id when calling filter + dev_desc, + bitset_ptr, + bitset_len, + original_nbits); + + last_touch.store(std::chrono::system_clock::now(), std::memory_order_relaxed); + } + + ~persistent_runner_jit_t() noexcept override + { + auto whs = worker_handles.data(); + for (auto i = worker_handles.size(); i > 0; i--) { + whs[worker_queue.pop().wait()].data.store({kNoMoreWork}, cuda::memory_order_relaxed); + } + RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream)); + } + + void launch(uintptr_t result_indices_ptr, + distance_type* result_distances_ptr, + const data_type* queries_ptr, + uint32_t num_queries, + uint32_t top_k) + { + launcher_jit_t launcher{job_queue, + worker_queue, + worker_handles.data(), + num_queries, + this->lifetime, + [&job_descriptors = this->job_descriptors, + result_indices_ptr, + result_distances_ptr, + queries_ptr, + top_k, + num_queries](uint32_t job_ix) { + auto& jd = job_descriptors.data()[job_ix].input.value; + auto* cflag = &job_descriptors.data()[job_ix].completion_flag; + jd.result_indices_ptr = result_indices_ptr; + jd.result_distances_ptr = result_distances_ptr; + 
jd.queries_ptr = queries_ptr; + jd.top_k = top_k; + jd.n_queries = num_queries; + cflag->store(false, cuda::memory_order_relaxed); + cuda::atomic_thread_fence(cuda::memory_order_release, + cuda::thread_scope_system); + return cflag; + }}; + + auto prev_touch = last_touch.load(std::memory_order_relaxed); + if (prev_touch + lifetime / 10 < launcher.now) { + last_touch.store(launcher.now, std::memory_order_relaxed); + } + launcher.wait(); + } + + auto calc_coop_grid_size(uint32_t block_size, uint32_t smem_size, float persistent_device_usage) + -> dim3 + { + int ctas_per_sm = 1; + cudaKernel_t kernel_handle = launcher->get_kernel(); + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel_handle, block_size, smem_size)); + int num_sm = raft::getMultiProcessorCount(); + auto n_blocks = static_cast(persistent_device_usage * (ctas_per_sm * num_sm)); + if (n_blocks > kMaxWorkersNum) { + RAFT_LOG_WARN("Limiting the grid size limit due to the size of the queue: %u -> %u", + n_blocks, + kMaxWorkersNum); + n_blocks = kMaxWorkersNum; + } + return {1, n_blocks, 1}; + } +}; + +template +void select_and_run_jit( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + std::optional> source_indices, + uintptr_t topk_indices_ptr, // [num_queries, topk] + DistanceT* topk_distances_ptr, // [num_queries, topk] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + const SourceIndexT* source_indices_ptr = + source_indices.has_value() ? 
source_indices->data_handle() : nullptr; + + // Extract bitset data from filter object (if it's a bitset_filter) + // Handle both direct bitset_filter and CagraSampleFilterWithQueryIdOffset wrapper + uint32_t* bitset_ptr = nullptr; + SourceIndexT bitset_len = 0; + SourceIndexT original_nbits = 0; + uint32_t query_id_offset = 0; + + // Check if it has the wrapper members (CagraSampleFilterWithQueryIdOffset) + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + using InnerFilter = decltype(sample_filter.filter); + // Always extract offset for wrapped filters + query_id_offset = sample_filter.offset; + if constexpr (is_bitset_filter::value) { + // Extract bitset data for bitset_filter (works for any bitset_filter instantiation) + auto bitset_view = sample_filter.filter.view(); + bitset_ptr = const_cast(bitset_view.data()); + bitset_len = static_cast(bitset_view.size()); + original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + + // Use common logic to compute launch config + auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); + uint32_t max_candidates = config.max_candidates; + uint32_t max_itopk = config.max_itopk; + bool topk_by_bitonic_sort = config.topk_by_bitonic_sort; + bool bitonic_sort_and_merge_multi_warps = config.bitonic_sort_and_merge_multi_warps; + + // Handle persistent kernels + if (ps.persistent) { + // Use persistent runner for JIT kernels + using runner_type = + persistent_runner_jit_t; + + std::string const filter_name = get_sample_filter_name(); + std::shared_ptr launcher = + make_cagra_single_cta_jit_launcher( + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + true /* persistent */, + filter_name); + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA persistent search kernel"); } + + // Use get_runner pattern similar to non-JIT version + const auto* dev_desc_persistent = dataset_desc.dev_ptr(stream); + 
get_runner_jit(std::cref(dataset_desc), + graph, + source_indices_ptr, + max_candidates, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + ps.num_random_samplings, + ps.rand_xor_mask, + num_seeds, + max_itopk, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + sample_filter, + ps.persistent_lifetime, + ps.persistent_device_usage, + launcher, + dev_desc_persistent) // Pass descriptor pointer + ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); + return; + } else { + std::string const filter_name = get_sample_filter_name(); + std::shared_ptr launcher = + make_cagra_single_cta_jit_launcher( + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + false /* persistent */, + filter_name); + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA search kernel"); } + + // Get the device descriptor pointer - dev_ptr() initializes it if needed + const auto* dev_desc = dataset_desc.dev_ptr(stream); + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t hash_bitlen_u32 = static_cast(hash_bitlen); + const uint32_t small_hash_bitlen_u32 = static_cast(small_hash_bitlen); + const uint32_t small_hash_reset_interval_u32 = static_cast(small_hash_reset_interval); + const uint32_t itopk_size_u32 = static_cast(ps.itopk_size); + const uint32_t search_width_u32 = static_cast(ps.search_width); + const uint32_t min_iterations_u32 = static_cast(ps.min_iterations); + const uint32_t max_iterations_u32 = static_cast(ps.max_iterations); + const unsigned num_random_samplings_u = static_cast(ps.num_random_samplings); + + dim3 grid(1, num_queries, 1); + dim3 block(block_size, 1, 1); + + RAFT_LOG_DEBUG("Launching JIT kernel with %u threads, %u blocks, %u 
smem", + block_size, + num_queries, + smem_size); + + // Dispatch kernel via launcher + auto kernel_launcher = [&](auto const& kernel) -> void { + launcher->dispatch>( + stream, + grid, + block, + static_cast(smem_size), + topk_indices_ptr, + topk_distances_ptr, + topk, + queries_ptr, + graph.data_handle(), + graph_degree_u32, // Cast int64_t to uint32_t + source_indices_ptr, + num_random_samplings_u, // Cast uint32_t to unsigned for consistency + ps.rand_xor_mask, // uint64_t matches kernel (8 bytes) + dev_seed_ptr, + num_seeds, + hashmap_ptr, + max_candidates, + max_itopk, + itopk_size_u32, // Cast size_t to uint32_t + search_width_u32, // Cast size_t to uint32_t + min_iterations_u32, // Cast size_t to uint32_t + max_iterations_u32, // Cast size_t to uint32_t + num_executed_iterations, + hash_bitlen_u32, // Cast int64_t to uint32_t + small_hash_bitlen_u32, // Cast size_t to uint32_t + small_hash_reset_interval_u32, // Cast size_t to uint32_t + query_id_offset, // Offset to add to query_id when calling filter + dev_desc, + bitset_ptr, + bitset_len, + original_nbits); + }; + + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size( + launcher->get_kernel(), smem_size, kernel_launcher); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } +} + +// Wrapper to match the non-JIT interface +// This function MUST be called if JIT is enabled +template +void select_and_run( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + std::optional> source_indices, + uintptr_t topk_indices_ptr, // [num_queries, topk] + DistanceT* topk_distances_ptr, // [num_queries, topk] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t 
small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + select_and_run_jit(dataset_desc, + graph, + source_indices, + topk_indices_ptr, + topk_distances_ptr, + queries_ptr, + num_queries, + dev_seed_ptr, + num_executed_iterations, + ps, + topk, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + hashmap_ptr, + small_hash_bitlen, + small_hash_reset_interval, + num_seeds, + sample_filter, + stream); +} + +// get_runner for JIT persistent runners (similar to non-JIT version) +template +auto get_runner_jit(Args... args) -> std::shared_ptr +{ + static thread_local std::weak_ptr weak; + auto runner = weak.lock(); + if (runner) { + if (runner->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner; + } else { + weak.reset(); + runner.reset(); + } + } + launcher_jit_t::expected_latency = launcher_jit_t::kDefaultLatency; + runner = create_runner_jit(args...); + weak = runner; + return runner; +} + +template +auto create_runner_jit(Args... args) -> std::shared_ptr +{ + std::lock_guard guard(persistent.lock); + std::shared_ptr runner_outer = std::dynamic_pointer_cast(persistent.runner); + if (runner_outer) { + // calculate_parameter_hash needs all args to match constructor signature + // but only uses a subset for the actual hash + if (runner_outer->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner_outer; + } else { + runner_outer.reset(); + } + } + persistent.runner.reset(); + + cuda::std::atomic_flag ready{}; + ready.clear(cuda::std::memory_order_relaxed); + std::thread( + [&runner_outer, &ready](Args... 
thread_args) { + runner_outer = std::make_shared(thread_args...); + auto lifetime = runner_outer->lifetime; + persistent.runner = std::static_pointer_cast(runner_outer); + std::weak_ptr runner_weak = runner_outer; + ready.test_and_set(cuda::std::memory_order_release); + ready.notify_one(); + + while (true) { + std::this_thread::sleep_for(lifetime); + auto runner = runner_weak.lock(); + if (!runner) { return; } + if (runner->last_touch.load(std::memory_order_relaxed) + lifetime < + std::chrono::system_clock::now()) { + std::lock_guard guard(persistent.lock); + if (runner == persistent.runner) { persistent.runner.reset(); } + return; + } + } + }, + args...) + .detach(); + ready.wait(false, cuda::std::memory_order_acquire); + return runner_outer; +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/set_value_batch.cuh b/cpp/src/neighbors/detail/cagra/set_value_batch.cuh new file mode 100644 index 0000000000..a4433005a7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/set_value_batch.cuh @@ -0,0 +1,40 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +__global__ void set_value_batch_kernel(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= count * batch_size) { return; } + const auto batch_id = tid / count; + const auto elem_id = tid % count; + dev_ptr[elem_id + ld * batch_id] = val; +} + +template +void set_value_batch(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size, + cudaStream_t cuda_stream) +{ + constexpr std::uint32_t block_size = 256; + const auto grid_size = (count * batch_size + block_size - 1) / block_size; + set_value_batch_kernel + <<>>(dev_ptr, ld, val, count, batch_size); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp new file mode 100644 index 0000000000..f8798dfd04 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -0,0 +1,118 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "shared_launcher_jit.hpp included but CUVS_ENABLE_JIT_LTO not defined!" 
+#endif
+
+// Include tags header before any other includes that might open namespaces
+// NOTE(review): the angle-bracket include target on the next line (and all
+// template parameter lists / std::is_same_v arguments in this file) were
+// stripped in this paste — restore them from the original patch.
+#include
+
+#include "../../sample_filter.cuh"  // For none_sample_filter, bitset_filter
+
+#include
+#include
+#include
+#include
+#include
+
+namespace cuvs::neighbors::cagra::detail {
+
+// Helper functions to get tags for JIT LTO
+// Maps a compile-time data type to its JIT fragment tag (float -> tag_f,
+// half -> tag_h, signed char -> tag_sc, unsigned char -> tag_uc, presumably —
+// the stripped std::is_same_v arguments must be confirmed).
+template
+constexpr auto get_data_type_tag()
+{
+  if constexpr (std::is_same_v) { return tag_f{}; }
+  if constexpr (std::is_same_v) { return tag_h{}; }
+  if constexpr (std::is_same_v) { return tag_sc{}; }
+  if constexpr (std::is_same_v) { return tag_uc{}; }
+}
+
+// Maps the internal index type to its tag (only one case handled).
+template
+constexpr auto get_index_type_tag()
+{
+  if constexpr (std::is_same_v) { return tag_idx_ui{}; }
+}
+
+// Maps the distance type to its tag (only float handled).
+template
+constexpr auto get_distance_type_tag()
+{
+  if constexpr (std::is_same_v) { return tag_dist_f{}; }
+}
+
+// Maps the source (user-facing) index type to its tag: ui or l variant.
+template
+constexpr auto get_source_index_type_tag()
+{
+  if constexpr (std::is_same_v) { return tag_idx_ui{}; }
+  if constexpr (std::is_same_v) { return tag_idx_l{}; }
+}
+
+// Selects the query-type tag for standard (non-VPQ) descriptors:
+// tag_uc when the stripped condition holds, otherwise tag_f.
+template
+struct query_type_tag_standard {
+  using type = std::conditional_t,
+                                  tag_uc,
+                                  tag_f>;
+};
+
+template
+using query_type_tag_standard_t = typename query_type_tag_standard::type;
+
+// VPQ descriptors always use half-precision queries.
+template
+using query_type_tag_vpq_t = tag_h;
+
+// Per-metric convenience aliases over query_type_tag_standard_t.
+template
+using query_type_tag_standard_l2_t =
+  query_type_tag_standard_t;
+template
+using query_type_tag_standard_inner_product_t =
+  query_type_tag_standard_t;
+template
+using query_type_tag_standard_cosine_t =
+  query_type_tag_standard_t;
+template
+using query_type_tag_standard_hamming_t =
+  query_type_tag_standard_t;
+
+using codebook_tag_vpq_t      = tag_codebook_half;
+using codebook_tag_standard_t = tag_codebook_none;
+
+// Helper trait to detect if a type is a bitset_filter (regardless of template parameters)
+template
+struct is_bitset_filter : std::false_type {};
+
+template
+struct is_bitset_filter>
+  : std::true_type {};
+
+// Returns the JIT fragment name for a sample-filter type:
+// "filter_bitset_source_index_ui" for (wrapped) bitset filters,
+// "filter_none_source_index_ui" for none and any unrecognized filter.
+template
+std::string get_sample_filter_name()
+{
+  using namespace cuvs::neighbors::filtering;
+  using DecayedFilter =
+    std::decay_t;
+
+  // First check for none_sample_filter (the only unwrapped case)
+  if constexpr (std::is_same_v) {
+    return "filter_none_source_index_ui";
+  }
+
+  // All other filters are wrapped in CagraSampleFilterWithQueryIdOffset
+  // Access the inner filter type via decltype
+  if constexpr (requires { std::declval().filter; }) {
+    using InnerFilter = decltype(std::declval().filter);
+    if constexpr (is_bitset_filter::value ||
+                  std::is_same_v> ||
+                  std::is_same_v>) {
+      return "filter_bitset_source_index_ui";
+    }
+  }
+
+  // Default to none filter for unknown types
+  return "filter_none_source_index_ui";
+}
+
+}  // namespace cuvs::neighbors::cagra::detail
diff --git a/cpp/src/neighbors/detail/jit_lto_kernels/filter_bitset.cuh b/cpp/src/neighbors/detail/jit_lto_kernels/filter_bitset.cuh
new file mode 100644
index 0000000000..415fae7075
--- /dev/null
+++ b/cpp/src/neighbors/detail/jit_lto_kernels/filter_bitset.cuh
@@ -0,0 +1,77 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+#include "filter_data.h"
+
+namespace cuvs::neighbors::detail {
+
+// NOTE(review): template parameter lists below were stripped in this paste
+// and are reconstructed from body usage — verify against the original patch.
+
+// Inline implementation of bitset_view::test() to avoid including bitset.cuh
+// which transitively includes Thrust.
+//
+// Returns true when bit `sample_index` is set in the bitset backed by
+// `bitset_ptr`. `original_nbits` is the word width the bitset was created
+// with; when it differs from the current word width, the bit position is
+// remapped. `bitset_len` is currently unused (no bounds check is performed).
+template <typename bitset_t, typename index_t>
+__device__ inline bool bitset_view_test(const bitset_t* bitset_ptr,
+                                        index_t bitset_len,
+                                        index_t original_nbits,
+                                        index_t sample_index)
+{
+  // Bit width of one bitset word. (The original declared this value twice,
+  // as `bitset_element_size` and `nbits`; folded into one constant.)
+  constexpr index_t nbits = sizeof(bitset_t) * 8;
+  index_t bit_index       = 0;
+  index_t bit_offset      = 0;
+
+  if (original_nbits == 0 || nbits == original_nbits) {
+    // Fast path: stored layout matches the current word width.
+    bit_index  = sample_index / nbits;
+    bit_offset = sample_index % nbits;
+  } else {
+    // Remap the bit position from the original word layout to the current one.
+    const index_t original_bit_index  = sample_index / original_nbits;
+    const index_t original_bit_offset = sample_index % original_nbits;
+    bit_index  = original_bit_index * original_nbits / nbits;
+    bit_offset = 0;
+    if (original_nbits > nbits) {
+      bit_index += original_bit_offset / nbits;
+      bit_offset = original_bit_offset % nbits;
+    } else {
+      index_t ratio = nbits / original_nbits;
+      bit_offset += (original_bit_index % ratio) * original_nbits;
+      bit_offset += original_bit_offset % nbits;
+    }
+  }
+  const bitset_t bit_element = bitset_ptr[bit_index];
+  return (bit_element & (bitset_t{1} << bit_offset)) != 0;
+}
+
+// Unified sample_filter: takes query_id, node_id, and void* filter_data.
+// Used by both CAGRA and IVF Flat.
+// For IVF Flat: node_id should be computed from (cluster_ix, sample_ix) using
+// inds_ptrs from filter_data. `query_id` is unused by the bitset variant.
+template <typename SourceIndexT>
+__device__ bool sample_filter(uint32_t query_id, SourceIndexT node_id, void* filter_data)
+{
+  // No filter data at all: allow every sample.
+  if (filter_data == nullptr) { return true; }
+
+  // filter_data points to a bitset_filter_data_t struct.
+  auto* bitset_data = static_cast<bitset_filter_data_t<SourceIndexT>*>(filter_data);
+  if (bitset_data->bitset_ptr == nullptr) { return true; }
+
+  // The bitset marks ALLOWED indices (same convention as the non-JIT
+  // bitset_filter, which returns view.test() directly): a set bit means the
+  // node passes the filter; a clear bit means it is filtered out.
+  return bitset_view_test(
+    bitset_data->bitset_ptr, bitset_data->bitset_len, bitset_data->original_nbits, node_id);
+}
+
+}  // namespace cuvs::neighbors::detail
diff --git a/cpp/src/neighbors/detail/jit_lto_kernels/filter_data.h b/cpp/src/neighbors/detail/jit_lto_kernels/filter_data.h
new file mode 100644
index 0000000000..9fc4336872
--- /dev/null
+++ b/cpp/src/neighbors/detail/jit_lto_kernels/filter_data.h
@@ -0,0 +1,28 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+// NOTE(review): the two angle-bracket include targets here were stripped in
+// this paste; <cstdint> is required for uint32_t — confirm the second one.
+#include <cstdint>
+#include <cstddef>
+
+namespace cuvs::neighbors::detail {
+
+// Structure to hold bitset filter data.
+// This is passed as void* to the extern sample_filter function.
+// Used by both CAGRA and IVF Flat.
+template <typename SourceIndexT>
+struct bitset_filter_data_t {
+  uint32_t* bitset_ptr;         // Pointer to bitset data in global memory
+  SourceIndexT bitset_len;      // Length of bitset array
+  SourceIndexT original_nbits;  // Original number of bits
+
+  // NOTE(review): constructor is __device__-only — confirm instances are never
+  // constructed on the host via this constructor (host aggregate-init of the
+  // members would still be possible).
+  __device__ bitset_filter_data_t(uint32_t* ptr, SourceIndexT len, SourceIndexT nbits)
+    : bitset_ptr(ptr), bitset_len(len), original_nbits(nbits)
+  {
+  }
+};
+
+}  // namespace cuvs::neighbors::detail
diff --git a/cpp/src/neighbors/detail/jit_lto_kernels/filter_none.cuh b/cpp/src/neighbors/detail/jit_lto_kernels/filter_none.cuh
new file mode 100644
index 0000000000..e3ca5496c1
--- /dev/null
+++ b/cpp/src/neighbors/detail/jit_lto_kernels/filter_none.cuh
@@ -0,0 +1,22 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace cuvs::neighbors::detail {
+
+// Unified sample_filter: takes query_id, node_id, and void* filter_data.
+// Used by both CAGRA and IVF Flat.
+// "none" variant: never filters anything. All parameters are intentionally
+// unused (names commented out to avoid unused-parameter warnings).
+template <typename SourceIndexT>
+__device__ bool sample_filter(uint32_t /* query_id */,
+                              SourceIndexT /* node_id */,
+                              void* /* filter_data */)
+{
+  return true;
+}
+
+}  // namespace cuvs::neighbors::detail
diff --git a/cpp/src/neighbors/detail/smem_utils.cuh b/cpp/src/neighbors/detail/smem_utils.cuh
index 73bc8c578d..ba625bd890 100644
--- a/cpp/src/neighbors/detail/smem_utils.cuh
+++ b/cpp/src/neighbors/detail/smem_utils.cuh
@@ -6,6 +6,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -31,15 +32,16 @@ namespace cuvs::neighbors::detail {
  * @param launch The kernel launch function/lambda.
  */
 template
-void safely_launch_kernel_with_smem_size(KernelT const& kernel,
-                                         uint32_t smem_size,
-                                         KernelLauncherT const& launch)
+void safely_launch_kernel_with_smem_size_impl(KernelT const& kernel,
+                                              uint32_t smem_size,
+                                              KernelLauncherT const& launch,
+                                              std::mutex& mutex,
+                                              std::atomic& current_smem_size)
 {
   // last_smem_size is a monotonically growing high-water mark across all kernel pointers.
   // last_kernel tracks which kernel pointer was last used.
   static std::atomic last_smem_size{0};
   static std::atomic last_kernel{KernelT{}};
-  static std::mutex mutex;
+  // NOTE(review): the new `mutex` / `current_smem_size` reference parameters
+  // coexist with the retained function-local statics above. For cudaKernel_t
+  // callers every kernel shares these statics, which would defeat the
+  // per-kernel map in the caller — confirm the elided body (hunk @@ -70,4)
+  // was switched to use the parameters instead of the statics.
   // Fast path: skip the lock when the kernel matches and the smem size is within bounds.
   // Load order matters: last_smem_size (acquire) before last_kernel (relaxed). Inside the lock
   // we store in the opposite order: last_kernel (relaxed) then last_smem_size (release).
@@ -70,4 +72,52 @@ void safely_launch_kernel_with_smem_size(KernelT const& kernel,
   return launch(kernel);
 }
 
+/**
+ * @brief (Thread-)Safely invoke a kernel with a maximum dynamic shared memory size.
+ * This is required because the sequence `cudaFuncSetAttribute` + kernel launch is not executed
+ * atomically.
+ *
+ * Used this way, the cudaFuncAttributeMaxDynamicSharedMemorySize can only grow and thus
+ * guarantees that the kernel is safe to launch.
+ *
+ * @tparam KernelT The type of the kernel.
+ * @tparam InvocationT The type of the invocation function.
+ * @param kernel The kernel function address (for whom the smem-size is specified).
+ * @param smem_size The size of the dynamic shared memory to be set.
+ * @param launch The kernel launch function/lambda.
+ */
+// Specialization for cudaKernel_t (JIT LTO kernels) - track by kernel pointer
+template <typename KernelLauncherT>
+void safely_launch_kernel_with_smem_size(cudaKernel_t kernel,
+                                         uint32_t smem_size,
+                                         KernelLauncherT const& launch)
+{
+  // For JIT kernels, track by kernel pointer since all cudaKernel_t have the same type.
+  static std::unordered_map<cudaKernel_t, std::pair<std::mutex, std::atomic<uint32_t>>>
+    jit_smem_sizes;
+  // FIX(review): this mutex guards the shared static map above, so it must be
+  // shared across calls/threads as well. The original declared a plain
+  // function-local `std::mutex map_mutex;` (a fresh mutex per call), which
+  // serialized nothing and left concurrent map insertions racy.
+  static std::mutex map_mutex;
+
+  std::pair<std::mutex, std::atomic<uint32_t>>* current_smem_size;
+  {
+    // std::unordered_map never invalidates references to mapped values on
+    // insert, so this pointer remains valid after the lock is released.
+    std::lock_guard map_lock{map_mutex};
+    current_smem_size = &jit_smem_sizes[kernel];
+  }
+  safely_launch_kernel_with_smem_size_impl(
+    kernel, smem_size, launch, current_smem_size->first, current_smem_size->second);
+}
+
+// General template for regular function pointers
+template <typename KernelT, typename KernelLauncherT>
+void safely_launch_kernel_with_smem_size(KernelT const& kernel,
+                                         uint32_t smem_size,
+                                         KernelLauncherT const& launch)
+{
+  // The last smem size is parameterized by the kernel type thanks to the
+  // template parameter.
+  static std::atomic<uint32_t> current_smem_size{0};
+  static std::mutex mutex;
+
+  safely_launch_kernel_with_smem_size_impl(
+    kernel, smem_size, launch, mutex, current_smem_size);
+}
+
 } // namespace cuvs::neighbors::detail
diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp
index ed8191016b..a83a8cbd4d 100644
--- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp
+++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp
@@ -8,7 +8,7 @@
 #include
 #include
 #include
-#include
+#include
 #include
 #include
diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh
index 052b7bfe9a..f392d54026 100644
--- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh
+++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh
@@ -22,6 +22,7 @@
   typename
 cuvs::spatial::knn::detail::utils::config::value_t,                          \
     IdxT,                                                                    \
     SampleFilterT>(const index& index,                                       \
+    const search_params& params,                                             \
     const T* queries,                                                        \
     const uint32_t* coarse_query_results,                                    \
     const uint32_t n_queries,                                                \
diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh
index 3a782822b4..260aeaa29c 100644
--- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh
+++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh
@@ -21,6 +21,7 @@ namespace cuvs::neighbors::ivf_flat::detail {
 template
 void ivfflat_interleaved_scan(const index& index,
+                              const search_params& params,
                               const T* queries,
                               const uint32_t* coarse_query_results,
                               const uint32_t n_queries,
@@ -44,6 +45,7 @@
   typename cuvs::spatial::knn::detail::utils::config::value_t,               \
     IdxT,                                                                    \
     SampleFilterT>(const index& index,                                       \
+    const search_params& params,                                             \
     const T* queries,                                                        \
     const uint32_t* coarse_query_results,                                    \
     const uint32_t n_queries,                                                \
diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh
index c4113a83ce..261d3e60a4 100644
--- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh
+++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh
@@ -5,13 +5,14 @@
 #pragma once
 
+#include "../detail/jit_lto_kernels/filter_data.h"
 #include "../ivf_common.cuh"
 #include "detail/jit_lto_kernels/interleaved_scan_planner.hpp"
 #include "detail/jit_lto_kernels/kernel_def.hpp"
 #include
 #include
 #include
-#include
+#include
 #include
 #include
 #include
@@ -28,7 +29,6 @@
 namespace cuvs::neighbors::ivf_flat::detail {
-
 static constexpr int kThreadsPerBlock = 128;
 
 using namespace cuvs::spatial::knn::detail;  // NOLINT
@@ -137,6 +137,7 @@
 template
 void launch_kernel(const index& index,
+                   const search_params& params,
                    const T* queries,
                    const uint32_t* coarse_index,
                    const
 uint32_t num_queries,
@@ -203,6 +204,9 @@ void launch_kernel(const index& index,
     return;
   }
 
+  // Pass individual filter parameters like CAGRA does
+  // The kernel will construct filter_data struct internally when needed
+
   for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) {
     uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset);
     dim3 grid_dim(grid_dim_x, grid_dim_y, 1);
@@ -419,6 +423,7 @@
  */
 template
 void ivfflat_interleaved_scan(const index& index,
+                              const search_params& params,
                               const T* queries,
                               const uint32_t* coarse_query_results,
                               const uint32_t n_queries,
@@ -455,6 +460,7 @@
     select_min,
     metric,
     index,
+    params,
     queries,
     coarse_query_results,
     n_queries,
diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh
index f42ffdc837..17ddeabe06 100644
--- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh
+++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh
@@ -194,6 +194,7 @@ void search_impl(raft::resources const& handle,
   // query the gridDimX size to store probes topK output
   ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>(
     index,
+    params,
     nullptr,
     nullptr,
     n_queries,
@@ -250,6 +251,7 @@
   ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>(
     index,
+    params,
     queries,
     coarse_indices_dev.data(),
     n_queries,
diff --git a/cpp/src/neighbors/ivf_flat_index.cpp b/cpp/src/neighbors/ivf_flat_index.cpp
index 77b24d4690..6ab162a117 100644
--- a/cpp/src/neighbors/ivf_flat_index.cpp
+++ b/cpp/src/neighbors/ivf_flat_index.cpp
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
+ * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
  * SPDX-License-Identifier: Apache-2.0
  */
@@ -60,6 +60,12 @@
 cuvs::distance::DistanceType index::metric() const noexcept
 {
   return metric_;
 }
 
+// Setter counterpart of metric().
+template
+void index::set_metric(cuvs::distance::DistanceType metric)
+{
+  // NOTE(review): this only overwrites the stored metric enum; it does not
+  // re-normalize or revalidate vectors already added to the index — confirm
+  // callers invoke it only before population or with compatible metrics.
+  metric_ = metric;
+}
+
 template
 bool index::adaptive_centers() const noexcept
 {
diff --git a/cpp/src/neighbors/refine/refine_device.cuh b/cpp/src/neighbors/refine/refine_device.cuh
index e027dca53b..75bb0465aa 100644
--- a/cpp/src/neighbors/refine/refine_device.cuh
+++ b/cpp/src/neighbors/refine/refine_device.cuh
@@ -104,6 +104,7 @@ void refine_device(
   cuvs::neighbors::ivf_flat::detail::ivfflat_interleaved_scan(
     refinement_index,
+    cuvs::neighbors::ivf_flat::search_params(),
     queries.data_handle(),
     fake_coarse_idx.data(),
     static_cast(n_queries),
diff --git a/dependencies.yaml b/dependencies.yaml
index 2aae054862..cdfecd59c2 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -375,6 +375,7 @@
           - libcusolver-dev
           - libcusparse-dev
           - libnvjitlink-dev
+          - cuda-nvrtc-dev
       cuda_wheels:
         specific:
           # cuVS needs 'nvJitLink>={whatever-cuvs-was-built-against}' at runtime, and mixing