diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 97376fa88f..60cab78b25 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -306,7 +306,6 @@ if(BUILD_SHARED_LIBS) src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu src/neighbors/mg/omp_checks.cpp - src/neighbors/mg/nccl_comm.cpp ) endif() @@ -609,6 +608,7 @@ if(BUILD_SHARED_LIBS) if(BUILD_MG_ALGOS) target_compile_definitions(cuvs PUBLIC CUVS_BUILD_MG_ALGOS) target_compile_definitions(cuvs_objs PUBLIC CUVS_BUILD_MG_ALGOS) + target_compile_definitions(cuvs-cagra-search PUBLIC CUVS_BUILD_MG_ALGOS) endif() if(BUILD_CAGRA_HNSWLIB) diff --git a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu index 8930972365..821b80cc72 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu @@ -30,38 +30,38 @@ namespace cuvs::bench { #ifdef CUVS_ANN_BENCH_USE_CUVS_MG -void add_distribution_mode(cuvs::neighbors::mg::distribution_mode* dist_mode, +void add_distribution_mode(cuvs::neighbors::distribution_mode* dist_mode, const nlohmann::json& conf) { if (conf.contains("distribution_mode")) { std::string distribution_mode = conf.at("distribution_mode"); if (distribution_mode == "replicated") { - *dist_mode = cuvs::neighbors::mg::distribution_mode::REPLICATED; + *dist_mode = cuvs::neighbors::distribution_mode::REPLICATED; } else if (distribution_mode == "sharded") { - *dist_mode = cuvs::neighbors::mg::distribution_mode::SHARDED; + *dist_mode = cuvs::neighbors::distribution_mode::SHARDED; } else { throw std::runtime_error("invalid value for distribution_mode"); } } else { // default - *dist_mode = cuvs::neighbors::mg::distribution_mode::SHARDED; + *dist_mode = cuvs::neighbors::distribution_mode::SHARDED; } }; -void add_merge_mode(cuvs::neighbors::mg::sharded_merge_mode* merge_mode, const nlohmann::json& conf) +void add_merge_mode(cuvs::neighbors::sharded_merge_mode* merge_mode, const nlohmann::json& conf) { if (conf.contains("merge_mode")) { std::string sharded_merge_mode = conf.at("merge_mode"); if (sharded_merge_mode == "tree_merge") { - *merge_mode = cuvs::neighbors::mg::sharded_merge_mode::TREE_MERGE; + *merge_mode = cuvs::neighbors::sharded_merge_mode::TREE_MERGE; } else if (sharded_merge_mode == "merge_on_root_rank") { - *merge_mode = cuvs::neighbors::mg::sharded_merge_mode::MERGE_ON_ROOT_RANK; + *merge_mode = cuvs::neighbors::sharded_merge_mode::MERGE_ON_ROOT_RANK; } else { throw std::runtime_error("invalid value for merge_mode"); } } else { // default - *merge_mode = cuvs::neighbors::mg::sharded_merge_mode::TREE_MERGE; + *merge_mode = cuvs::neighbors::sharded_merge_mode::TREE_MERGE; } }; #endif diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index cb2bbd9b5a..a655fdf674 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -17,8 +17,8 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_cagra_wrapper.h" -#include -#include +#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -33,21 +33,20 @@ class cuvs_mg_cagra : public algo, public algo_gpu { using algo::dim_; struct build_param : public cuvs::bench::cuvs_cagra::build_param { - cuvs::neighbors::mg::distribution_mode mode; + cuvs::neighbors::distribution_mode mode; }; struct search_param : public cuvs::bench::cuvs_cagra::search_param { - cuvs::neighbors::mg::sharded_merge_mode merge_mode; + cuvs::neighbors::sharded_merge_mode merge_mode; }; cuvs_mg_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; @@ -69,7 +68,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -87,11 +86,11 @@ class cuvs_mg_cagra : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; + raft::device_resources_snmg clique_; float refine_ratio_; build_param index_params_; - cuvs::neighbors::mg::search_params search_params_; - std::shared_ptr, T, IdxT>> + cuvs::neighbors::mg_search_params search_params_; + std::shared_ptr, T, IdxT>> index_; }; @@ -100,14 +99,14 @@ void cuvs_mg_cagra::build(const T* dataset, size_t nrow) { auto dataset_extents = raft::make_extents(nrow, dim_); index_params_.prepare_build_params(dataset_extents); - cuvs::neighbors::mg::index_params build_params = index_params_.cagra_params; - build_params.mode = index_params_.mode; + cuvs::neighbors::mg_index_params build_params = index_params_.cagra_params; + build_params.mode = index_params_.mode; auto dataset_view = raft::make_host_matrix_view(dataset, nrow, dim_); - auto idx = cuvs::neighbors::mg::build(handle_, build_params, dataset_view); + auto idx = cuvs::neighbors::cagra::build(clique_, build_params, dataset_view); index_ = - std::make_shared, T, IdxT>>( + std::make_shared, T, IdxT>>( std::move(idx)); } @@ -118,8 +117,7 @@ void cuvs_mg_cagra::set_search_param(const search_param_base& param, const void* filter_bitset) { if (filter_bitset != nullptr) { throw std::runtime_error("Filtering is not supported yet."); } - auto sp = dynamic_cast(param); - // search_params_ = static_cast>(sp.p); + auto sp = dynamic_cast(param); cagra::search_params* search_params_ptr_ = static_cast(&search_params_); *search_params_ptr_ = sp.p; search_params_.merge_mode = sp.merge_mode; @@ -134,15 +132,15 @@ void cuvs_mg_cagra::set_search_dataset(const T* dataset, size_t nrow) template void cuvs_mg_cagra::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::cagra::serialize(clique_, *index_, file); } template void cuvs_mg_cagra::load(const std::string& file) { index_ = - std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_cagra(handle_, file))); + std::make_shared, T, IdxT>>( + std::move(cuvs::neighbors::cagra::deserialize(clique_, file))); } template @@ -165,8 +163,8 @@ void cuvs_mg_cagra::search_base( auto distances_view = raft::make_host_matrix_view(distances, batch_size, k); - cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + cuvs::neighbors::cagra::search( + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } template diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index b9f8b4b322..ca05ecdf0a 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -18,8 +18,8 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_flat_wrapper.h" -#include -#include +#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -30,18 +30,18 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { using search_param_base = typename algo::search_param; using algo::dim_; - using build_param = cuvs::neighbors::mg::index_params; + using build_param = cuvs::neighbors::mg_index_params; struct search_param : public cuvs::bench::cuvs_ivf_flat::search_param { - cuvs::neighbors::mg::sharded_merge_mode merge_mode; + cuvs::neighbors::sharded_merge_mode merge_mode; }; cuvs_mg_ivf_flat(Metric metric, int dim, const build_param& param) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; @@ -62,7 +62,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -73,10 +73,10 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; + raft::device_resources_snmg clique_; build_param index_params_; - cuvs::neighbors::mg::search_params search_params_; - std::shared_ptr, T, IdxT>> + cuvs::neighbors::mg_search_params search_params_; + std::shared_ptr, T, IdxT>> index_; }; @@ -85,9 +85,10 @@ void cuvs_mg_ivf_flat::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(handle_, index_params_, dataset_view); - index_ = std::make_shared< - cuvs::neighbors::mg::index, T, IdxT>>(std::move(idx)); + auto idx = cuvs::neighbors::ivf_flat::build(clique_, index_params_, dataset_view); + index_ = + std::make_shared, T, IdxT>>( + std::move(idx)); } template @@ -107,15 +108,15 @@ void cuvs_mg_ivf_flat::set_search_param(const search_param_base& param, template void cuvs_mg_ivf_flat::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::ivf_flat::serialize(clique_, *index_, file); } template void cuvs_mg_ivf_flat::load(const std::string& file) { - index_ = std::make_shared< - cuvs::neighbors::mg::index, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_flat(handle_, file))); + index_ = + std::make_shared, T, IdxT>>( + std::move(cuvs::neighbors::ivf_flat::deserialize(clique_, file))); } template @@ -135,8 +136,8 @@ void cuvs_mg_ivf_flat::search( auto distances_view = raft::make_host_matrix_view( distances, IdxT(batch_size), IdxT(k)); - cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + cuvs::neighbors::ivf_flat::search( + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } } // namespace cuvs::bench diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index 26781c522b..c37bbdf18b 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -18,8 +18,8 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_pq_wrapper.h" -#include -#include +#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -30,18 +30,18 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { using search_param_base = typename algo::search_param; using algo::dim_; - using build_param = cuvs::neighbors::mg::index_params; + using build_param = cuvs::neighbors::mg_index_params; struct search_param : public cuvs::bench::cuvs_ivf_pq::search_param { - cuvs::neighbors::mg::sharded_merge_mode merge_mode; + cuvs::neighbors::sharded_merge_mode merge_mode; }; cuvs_mg_ivf_pq(Metric metric, int dim, const build_param& param) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; @@ -62,7 +62,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -73,10 +73,10 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; + raft::device_resources_snmg clique_; build_param index_params_; - cuvs::neighbors::mg::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; + cuvs::neighbors::mg_search_params search_params_; + std::shared_ptr, T, IdxT>> index_; }; template @@ -84,9 +84,9 @@ void cuvs_mg_ivf_pq::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(handle_, index_params_, dataset_view); + auto idx = cuvs::neighbors::ivf_pq::build(clique_, index_params_, dataset_view); index_ = - std::make_shared, T, IdxT>>( + std::make_shared, T, IdxT>>( std::move(idx)); } @@ -95,8 +95,7 @@ void cuvs_mg_ivf_pq::set_search_param(const search_param_base& param, const void* filter_bitset) { if (filter_bitset != nullptr) { throw std::runtime_error("Filtering is not supported yet."); } - auto sp = dynamic_cast(param); - // search_params_ = static_cast>(sp.pq_param); + auto sp = dynamic_cast(param); ivf_pq::search_params* search_params_ptr_ = static_cast(&search_params_); *search_params_ptr_ = sp.pq_param; search_params_.merge_mode = sp.merge_mode; @@ -106,15 +105,15 @@ void cuvs_mg_ivf_pq::set_search_param(const search_param_base& param, template void cuvs_mg_ivf_pq::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::ivf_pq::serialize(clique_, *index_, file); } template void cuvs_mg_ivf_pq::load(const std::string& file) { index_ = - std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_pq(handle_, file))); + std::make_shared, T, IdxT>>( + std::move(cuvs::neighbors::ivf_pq::deserialize(clique_, file))); } template @@ -134,8 +133,8 @@ void cuvs_mg_ivf_pq::search( auto distances_view = raft::make_host_matrix_view( distances, IdxT(batch_size), IdxT(k)); - cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + cuvs::neighbors::ivf_pq::search( + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } } // namespace cuvs::bench diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index f92835e282..07aa21e62c 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -2091,4 +2091,456 @@ auto merge(raft::resources const& res, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, float, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, half, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, int8_t, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, uint8_t, uint32_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, float, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, half, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, int8_t, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, uint8_t, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search(const raft::resources& clique, + const cuvs::neighbors::mg_index, float, uint32_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search(const raft::resources& clique, + const cuvs::neighbors::mg_index, half, uint32_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search( + const raft::resources& clique, + const cuvs::neighbors::mg_index, int8_t, uint32_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search( + const raft::resources& clique, + const cuvs::neighbors::mg_index, uint8_t, uint32_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::resources& clique, + const cuvs::neighbors::mg_index, float, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::resources& clique, + const cuvs::neighbors::mg_index, half, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::resources& clique, + const cuvs::neighbors::mg_index, int8_t, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::resources& clique, + const cuvs::neighbors::mg_index, uint8_t, uint32_t>& index, + const std::string& filename); + +/// \defgroup mg_cpp_deserialize ANN MG index deserialization + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes a CAGRA multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::cagra::deserialize(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::resources& clique, const std::string& filename) + -> cuvs::neighbors::mg_index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized CAGRA index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::cagra::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::cagra::distribute(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::resources& clique, const std::string& filename) + -> cuvs::neighbors::mg_index, T, IdxT>; + } // namespace cuvs::neighbors::cagra diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 038b6b1da5..4c31987989 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -740,9 +740,7 @@ enable_if_valid_list_t deserialize_list(const raft::resources& handle, const typename ListT::spec_type& store_spec, const typename ListT::spec_type& device_spec); } // namespace ivf -} // namespace cuvs::neighbors -namespace cuvs::neighbors { using namespace raft; template @@ -756,21 +754,21 @@ struct iface { }; template -void build(const raft::device_resources& handle, +void build(const raft::resources& handle, cuvs::neighbors::iface& interface, const cuvs::neighbors::index_params* index_params, raft::mdspan, row_major, Accessor> index_dataset); template void extend( - const raft::device_resources& handle, + const raft::resources& handle, cuvs::neighbors::iface& interface, raft::mdspan, row_major, Accessor1> new_vectors, std::optional, layout_c_contiguous, Accessor2>> new_indices); template -void search(const raft::device_resources& handle, +void search(const raft::resources& handle, const cuvs::neighbors::iface& interface, const cuvs::neighbors::search_params* search_params, raft::device_matrix_view h_queries, @@ -778,18 +776,98 @@ void search(const raft::device_resources& handle, raft::device_matrix_view d_distances); template -void serialize(const raft::device_resources& handle, +void serialize(const raft::resources& handle, const cuvs::neighbors::iface& interface, std::ostream& os); template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, std::istream& is); template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, const std::string& filename); -}; // namespace cuvs::neighbors +/// \defgroup mg_cpp_index_params ANN MG index build parameters + +/** Distribution mode */ +/// \ingroup mg_cpp_index_params +enum distribution_mode { + /** Index is replicated on each device, favors throughput */ + REPLICATED, + /** Index is split on several devices, favors scaling */ + SHARDED +}; + +/// \defgroup mg_cpp_search_params ANN MG search parameters + +/** Search mode when using a replicated index */ +/// \ingroup mg_cpp_search_params +enum replicated_search_mode { + /** Search queries are splited to maintain equal load on GPUs */ + LOAD_BALANCER, + /** Each search query is processed by a single GPU in a round-robin fashion */ + ROUND_ROBIN +}; + +/** Merge mode when using a sharded index */ +/// \ingroup mg_cpp_search_params +enum sharded_merge_mode { + /** Search batches are merged on the root rank */ + MERGE_ON_ROOT_RANK, + /** Search batches are merged in a tree reduction fashion */ + TREE_MERGE +}; + +/** Build parameters */ +/// \ingroup mg_cpp_index_params +template +struct mg_index_params : public Upstream { + mg_index_params() : mode(SHARDED) {} + + mg_index_params(const Upstream& sp) : Upstream(sp), mode(SHARDED) {} + + /** Distribution mode */ + cuvs::neighbors::distribution_mode mode = SHARDED; +}; + +/** Search parameters */ +/// \ingroup mg_cpp_search_params +template +struct mg_search_params : public Upstream { + mg_search_params() : search_mode(LOAD_BALANCER), merge_mode(TREE_MERGE) {} + + mg_search_params(const Upstream& sp) + : Upstream(sp), search_mode(LOAD_BALANCER), merge_mode(TREE_MERGE) + { + } + + /** Replicated search mode */ + cuvs::neighbors::replicated_search_mode search_mode = LOAD_BALANCER; + /** Sharded merge mode */ + cuvs::neighbors::sharded_merge_mode merge_mode = TREE_MERGE; + /** Number of rows per batch */ + int64_t n_rows_per_batch = 1 << 20; +}; + +template +struct mg_index { + mg_index(const raft::resources& clique, distribution_mode mode); + mg_index(const raft::resources& clique, const std::string& filename); + + mg_index(const mg_index&) = delete; + mg_index(mg_index&&) = default; + auto operator=(const mg_index&) -> mg_index& = delete; + auto operator=(mg_index&&) -> mg_index& = default; + + distribution_mode mode_; + int num_ranks_; + std::vector> ann_interfaces_; + + // for load balancing mechanism + std::shared_ptr> round_robin_counter_; +}; + +} // namespace cuvs::neighbors diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index e017946d9c..bba6436dbb 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -1598,6 +1598,359 @@ void deserialize(raft::resources const& handle, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, float, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, int8_t, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, uint8_t, int64_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, float, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, int8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, uint8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search(const raft::resources& clique, + const cuvs::neighbors::mg_index, float, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search( + const raft::resources& clique, + const cuvs::neighbors::mg_index, int8_t, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search( + const raft::resources& clique, + const cuvs::neighbors::mg_index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::resources& clique, + const cuvs::neighbors::mg_index, float, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::resources& clique, + const cuvs::neighbors::mg_index, int8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::resources& clique, + const cuvs::neighbors::mg_index, uint8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes an IVF-Flat multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::ivf_flat::deserialize(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::resources& clique, const std::string& filename) + -> cuvs::neighbors::mg_index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized IVF-Flat index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::ivf_flat::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::ivf_flat::distribute(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::resources& clique, const std::string& filename) + -> cuvs::neighbors::mg_index, T, IdxT>; + namespace helpers { /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index d85753b7f2..73f9a831bc 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -1728,6 +1728,449 @@ void deserialize(raft::resources const& handle, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, float, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, half, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, int8_t, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::resources& clique, + const cuvs::neighbors::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg_index, uint8_t, int64_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, float, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, half, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, int8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::resources& clique, + cuvs::neighbors::mg_index, uint8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search(const raft::resources& clique, + const cuvs::neighbors::mg_index, float, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search(const raft::resources& clique, + const cuvs::neighbors::mg_index, half, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search(const raft::resources& clique, + const cuvs::neighbors::mg_index, int8_t, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg_search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * + */ +void search(const raft::resources& clique, + const cuvs::neighbors::mg_index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::resources& clique, + const cuvs::neighbors::mg_index, float, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::resources& clique, + const cuvs::neighbors::mg_index, half, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::resources& clique, + const cuvs::neighbors::mg_index, int8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::resources& clique, + const cuvs::neighbors::mg_index, uint8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes an IVF-PQ multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg_index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::ivf_pq::deserialize(clique, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::resources& clique, const std::string& filename) + -> cuvs::neighbors::mg_index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized IVF-PQ index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::ivf_pq::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::ivf_pq::distribute(clique, filename); + * @endcode + * + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::resources& clique, const std::string& filename) + -> cuvs::neighbors::mg_index, T, IdxT>; + namespace helpers { /** * @defgroup ivf_pq_cpp_helpers IVF-PQ helper methods diff --git a/cpp/include/cuvs/neighbors/mg.hpp b/cpp/include/cuvs/neighbors/mg.hpp deleted file mode 100644 index 4657fa8fb0..0000000000 --- a/cpp/include/cuvs/neighbors/mg.hpp +++ /dev/null @@ -1,1367 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef CUVS_BUILD_MG_ALGOS - -#include -#include - -#include -#include - -#include -#include -#include -#include - -#define DEFAULT_SEARCH_BATCH_SIZE 1 << 20 - -/// \defgroup mg_cpp_index_params ANN MG index build parameters - -namespace cuvs::neighbors::mg { -/** Distribution mode */ -/// \ingroup mg_cpp_index_params -enum distribution_mode { - /** Index is replicated on each device, favors throughput */ - REPLICATED, - /** Index is split on several devices, favors scaling */ - SHARDED -}; - -/// \defgroup mg_cpp_search_params ANN MG search parameters - -/** Search mode when using a replicated index */ -/// \ingroup mg_cpp_search_params -enum replicated_search_mode { - /** Search queries are splited to maintain equal load on GPUs */ - LOAD_BALANCER, - /** Each search query is processed by a single GPU in a round-robin fashion */ - ROUND_ROBIN -}; - -/** Merge mode when using a sharded index */ -/// \ingroup mg_cpp_search_params -enum sharded_merge_mode { - /** Search batches are merged on the root rank */ - MERGE_ON_ROOT_RANK, - /** Search batches are merged in a tree reduction fashion */ - TREE_MERGE -}; - -/** Build parameters */ -/// \ingroup mg_cpp_index_params -template -struct index_params : public Upstream { - index_params() : mode(SHARDED) {} - - index_params(const Upstream& sp) : Upstream(sp), mode(SHARDED) {} - - /** Distribution mode */ - cuvs::neighbors::mg::distribution_mode mode = SHARDED; -}; - -/** Search parameters */ -/// \ingroup mg_cpp_search_params -template -struct search_params : public Upstream { - search_params() : search_mode(LOAD_BALANCER), merge_mode(TREE_MERGE) {} - - search_params(const Upstream& sp) - : Upstream(sp), search_mode(LOAD_BALANCER), merge_mode(TREE_MERGE) - { - } - - /** Replicated search mode */ - cuvs::neighbors::mg::replicated_search_mode search_mode = LOAD_BALANCER; - /** Sharded merge mode */ - cuvs::neighbors::mg::sharded_merge_mode merge_mode = TREE_MERGE; -}; - -} // namespace cuvs::neighbors::mg - -namespace cuvs::neighbors::mg { - -using namespace raft; - -template -struct index { - index(distribution_mode mode, int num_ranks_); - index(const raft::device_resources& handle, const std::string& filename); - - index(const index&) = delete; - index(index&&) = default; - auto operator=(const index&) -> index& = delete; - auto operator=(index&&) -> index& = default; - - distribution_mode mode_; - int num_ranks_; - std::vector> ann_interfaces_; - - // for load balancing mechanism - std::shared_ptr> round_robin_counter_; -}; - -/// \defgroup mg_cpp_index_build ANN MG index build - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, half, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, half, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * @endcode - * - * @param[in] handle - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources& handle, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, uint32_t>; - -/// \defgroup mg_cpp_index_extend ANN MG index extend - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, float, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, int8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, uint8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, float, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, half, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, int8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, uint8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, float, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, half, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, int8_t, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources& handle, - index, uint8_t, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \defgroup mg_cpp_index_search ANN MG index search - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, float, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, int8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, uint8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, float, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, half, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, int8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, uint8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, float, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, half, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, int8_t, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources& handle, - const index, uint8_t, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \defgroup mg_cpp_serialize ANN MG index serialization - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, float, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, int8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, uint8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, float, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, half, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, int8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, uint8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, float, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, half, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, int8_t, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * @endcode - * - * @param[in] handle - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources& handle, - const index, uint8_t, uint32_t>& index, - const std::string& filename); - -/// \defgroup mg_cpp_deserialize ANN MG index deserialization - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes an IVF-Flat multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_flat(handle, filename); - * - * @endcode - * - * @param[in] handle - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_flat(const raft::device_resources& handle, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes an IVF-PQ multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_pq(handle, filename); - * @endcode - * - * @param[in] handle - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_pq(const raft::device_resources& handle, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes a CAGRA multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_cagra(handle, filename); - * - * @endcode - * - * @param[in] handle - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_cagra(const raft::device_resources& handle, const std::string& filename) - -> index, T, IdxT>; - -/// \defgroup mg_cpp_distribute ANN MG local index distribution - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized IVF-Flat index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::ivf_flat::index_params index_params; - * auto index = cuvs::neighbors::ivf_flat::build(handle, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_flat::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_flat(handle, filename); - * - * @endcode - * - * @param[in] handle - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_flat(const raft::device_resources& handle, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized IVF-PQ index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::ivf_pq::index_params index_params; - * auto index = cuvs::neighbors::ivf_pq::build(handle, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_pq::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_pq(handle, filename); - * @endcode - * - * @param[in] handle - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_pq(const raft::device_resources& handle, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized CAGRA index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::handle_t handle; - * cuvs::neighbors::cagra::index_params index_params; - * auto index = cuvs::neighbors::cagra::build(handle, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::cagra::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_cagra(handle, filename); - * - * @endcode - * - * @param[in] handle - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_cagra(const raft::device_resources& handle, const std::string& filename) - -> index, T, IdxT>; - -} // namespace cuvs::neighbors::mg - -#else - -static_assert(false, - "FORBIDEN_MG_ALGORITHM_IMPORT\n\n" - "Please recompile the cuVS library with MG algorithms BUILD_MG_ALGOS=ON.\n"); - -#endif diff --git a/cpp/src/neighbors/iface/generate_iface.py b/cpp/src/neighbors/iface/generate_iface.py index 794219bbf3..698bee88d8 100644 --- a/cpp/src/neighbors/iface/generate_iface.py +++ b/cpp/src/neighbors/iface/generate_iface.py @@ -58,49 +58,49 @@ using IdxT_ha = raft::host_device_accessor, raft::memory_type::device>; \\ using IdxT_da = raft::host_device_accessor, raft::memory_type::host>; \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_ha> index_dataset); \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_da> index_dataset); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_ha> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_ha>> new_indices); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_da> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_da>> new_indices); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::device_matrix_view queries, \\ raft::device_matrix_view neighbors, \\ raft::device_matrix_view distances); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::host_matrix_view h_queries, \\ raft::device_matrix_view d_neighbors, \\ raft::device_matrix_view d_distances); \\ \\ - template void serialize(const raft::device_resources& handle, \\ + template void serialize(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ std::ostream& os); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ std::istream& is); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const std::string& filename); """ @@ -112,49 +112,49 @@ using IdxT_ha = raft::host_device_accessor, raft::memory_type::device>; \\ using IdxT_da = raft::host_device_accessor, raft::memory_type::host>; \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_ha> index_dataset); \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_da> index_dataset); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_ha> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_ha>> new_indices); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_da> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_da>> new_indices); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::device_matrix_view queries, \\ raft::device_matrix_view neighbors, \\ raft::device_matrix_view distances); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::host_matrix_view h_queries, \\ raft::device_matrix_view d_neighbors, \\ raft::device_matrix_view d_distances); \\ \\ - template void serialize(const raft::device_resources& handle, \\ + template void serialize(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ std::ostream& os); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ std::istream& is); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const std::string& filename); """ @@ -166,49 +166,49 @@ using IdxT_ha = raft::host_device_accessor, raft::memory_type::device>; \\ using IdxT_da = raft::host_device_accessor, raft::memory_type::host>; \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_ha> index_dataset); \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_da> index_dataset); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_ha> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_ha>> new_indices); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_da> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_da>> new_indices); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::device_matrix_view queries, \\ raft::device_matrix_view neighbors, \\ raft::device_matrix_view distances); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::host_matrix_view h_queries, \\ raft::device_matrix_view d_neighbors, \\ raft::device_matrix_view d_distances); \\ \\ - template void serialize(const raft::device_resources& handle, \\ + template void serialize(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ std::ostream& os); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ std::istream& is); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const std::string& filename); """ diff --git a/cpp/src/neighbors/iface/iface.hpp b/cpp/src/neighbors/iface/iface.hpp index 59d3e12d0e..4f7a6f5cc9 100644 --- a/cpp/src/neighbors/iface/iface.hpp +++ b/cpp/src/neighbors/iface/iface.hpp @@ -31,7 +31,7 @@ namespace cuvs::neighbors { using namespace raft; template -void build(const raft::device_resources& handle, +void build(const raft::resources& handle, cuvs::neighbors::iface& interface, const cuvs::neighbors::index_params* index_params, raft::mdspan, row_major, Accessor> index_dataset) @@ -56,7 +56,7 @@ void build(const raft::device_resources& handle, template void extend( - const raft::device_resources& handle, + const raft::resources& handle, cuvs::neighbors::iface& interface, raft::mdspan, row_major, Accessor1> new_vectors, std::optional, layout_c_contiguous, Accessor2>> @@ -79,7 +79,7 @@ void extend( } template -void search(const raft::device_resources& handle, +void search(const raft::resources& handle, const cuvs::neighbors::iface& interface, const cuvs::neighbors::search_params* search_params, raft::device_matrix_view queries, @@ -115,7 +115,7 @@ void search(const raft::device_resources& handle, // for MG ANN only template -void search(const raft::device_resources& handle, +void search(const raft::resources& handle, const cuvs::neighbors::iface& interface, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view h_queries, @@ -137,7 +137,7 @@ void search(const raft::device_resources& handle, } template -void serialize(const raft::device_resources& handle, +void serialize(const raft::resources& handle, const cuvs::neighbors::iface& interface, std::ostream& os) { @@ -150,10 +150,12 @@ void serialize(const raft::device_resources& handle, } else if constexpr (std::is_same>::value) { cagra::serialize(handle, os, interface.index_.value(), true); } + + resource::sync_stream(handle); } template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, std::istream& is) { @@ -162,20 +164,23 @@ void deserialize(const raft::device_resources& handle, if constexpr (std::is_same>::value) { ivf_flat::index idx(handle); ivf_flat::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { ivf_pq::index idx(handle); ivf_pq::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { cagra::index idx(handle); cagra::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } } template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, const std::string& filename) { @@ -187,14 +192,17 @@ void deserialize(const raft::device_resources& handle, if constexpr (std::is_same>::value) { ivf_flat::index idx(handle); ivf_flat::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { ivf_pq::index idx(handle); ivf_pq::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { cagra::index idx(handle); cagra::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } diff --git a/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu index b5e329dd82..b7ad428ad8 100644 --- a/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu index 23fcffc59a..86e0633bb1 100644 --- a/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu index 30377ab666..64f174184b 100644 --- a/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu index 59a1640e89..9f6db32df8 100644 --- a/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu index a0a4553753..0afffe0ad3 100644 --- a/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu @@ -38,39 +38,39 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ @@ -78,15 +78,15 @@ namespace cuvs::neighbors { raft::device_matrix_view d_distances); \ \ template void serialize( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu index 9fdd6464ff..5afd77053e 100644 --- a/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu @@ -38,39 +38,39 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ @@ -78,15 +78,15 @@ namespace cuvs::neighbors { raft::device_matrix_view d_distances); \ \ template void serialize( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu index daee59c4a8..4f2f85700c 100644 --- a/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu @@ -38,39 +38,39 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ @@ -78,15 +78,15 @@ namespace cuvs::neighbors { raft::device_matrix_view d_distances); \ \ template void serialize( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu index 7282d6bd07..90759d5f1a 100644 --- a/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu index 4d67f9aed1..c92d6fd651 100644 --- a/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu index 46537b3f9f..59269e9da1 100644 --- a/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu index 591ac881a8..c407e64cac 100644 --- a/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index af5e605456..49d4609fcd 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -40,190 +40,194 @@ """ include_macro = """ -#include "mg.cuh" -""" - -namespace_macro = """ -namespace cuvs::neighbors::mg { -""" - -footer = """ -} // namespace cuvs::neighbors::mg +#include "snmg.cuh" """ flat_macro = """ -#define CUVS_INST_MG_FLAT(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void extend(const raft::device_resources& handle, \\ - index, T, IdxT>& index, \\ - raft::host_matrix_view new_vectors, \\ - std::optional> new_indices) \\ - { \\ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \\ - } \\ - \\ - void search(const raft::device_resources& handle, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources& handle, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_flat(const raft::device_resources& handle, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(handle, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_flat(const raft::device_resources& handle, \\ - const std::string& filename) \\ - { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \\ +namespace cuvs::neighbors::ivf_flat { \\ + using namespace cuvs::neighbors; \\ + \\ + cuvs::neighbors::mg_index, T, IdxT> build( \\ + const raft::resources& res, \\ + const mg_index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::snmg::detail::build(res, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void extend(const raft::resources& res, \\ + cuvs::neighbors::mg_index, T, IdxT>& index, \\ + raft::host_matrix_view new_vectors, \\ + std::optional> new_indices) \\ + { \\ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \\ + } \\ + \\ + void search(const raft::resources& res, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const mg_search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances) \\ + { \\ + cuvs::neighbors::snmg::detail::search(res, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances); \\ + } \\ + \\ + void serialize(const raft::resources& res, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \\ + const raft::resources& res, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg_index, T, IdxT> distribute( \\ + const raft::resources& res, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::ivf_flat """ pq_macro = """ -#define CUVS_INST_MG_PQ(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void extend(const raft::device_resources& handle, \\ - index, T, IdxT>& index, \\ - raft::host_matrix_view new_vectors, \\ - std::optional> new_indices) \\ - { \\ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \\ - } \\ - \\ - void search(const raft::device_resources& handle, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources& handle, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_pq(const raft::device_resources& handle, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(handle, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \\ - const std::string& filename) \\ - { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_PQ(T, IdxT) \\ +namespace cuvs::neighbors::ivf_pq { \\ + using namespace cuvs::neighbors; \\ + \\ + cuvs::neighbors::mg_index, T, IdxT> build( \\ + const raft::resources& res, \\ + const mg_index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::snmg::detail::build(res, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void extend(const raft::resources& res, \\ + cuvs::neighbors::mg_index, T, IdxT>& index, \\ + raft::host_matrix_view new_vectors, \\ + std::optional> new_indices) \\ + { \\ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \\ + } \\ + \\ + void search(const raft::resources& res, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const mg_search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances) \\ + { \\ + cuvs::neighbors::snmg::detail::search(res, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances); \\ + } \\ + \\ + void serialize(const raft::resources& res, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \\ + const raft::resources& res, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg_index, T, IdxT> distribute( \\ + const raft::resources& res, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::ivf_pq """ cagra_macro = """ -#define CUVS_INST_MG_CAGRA(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void search(const raft::device_resources& handle, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources& handle, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_cagra(const raft::device_resources& handle, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(handle, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_cagra(const raft::device_resources& handle, \\ - const std::string& filename) \\ - { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \\ +namespace cuvs::neighbors::cagra { \\ + using namespace cuvs::neighbors; \\ + \\ + cuvs::neighbors::mg_index, T, IdxT> build( \\ + const raft::resources& res, \\ + const mg_index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::snmg::detail::build(res, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void search(const raft::resources& res, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const mg_search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances) \\ + { \\ + cuvs::neighbors::snmg::detail::search(res, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances); \\ + } \\ + \\ + void serialize(const raft::resources& res, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \\ + const raft::resources& res, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg_index, T, IdxT> distribute( \\ + const raft::resources& res, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::cagra """ flat_macros = dict ( @@ -277,10 +281,8 @@ with open(path, "w") as f: f.write(header) f.write(macro['include']) - f.write(namespace_macro) f.write(macro["definition"]) f.write(f"{macro['name']}({T}, {IdxT});\n\n") f.write(f"#undef {macro['name']}\n") - f.write(footer) print(f"src/neighbors/mg/{path}") diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index b11610fb48..3260df7050 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -23,70 +23,66 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(float, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 8f76c69a34..da8cbb1e49 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -23,70 +23,66 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(half, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 67b88d7429..5aae599e22 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -23,70 +23,66 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(int8_t, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index f721749233..8384a6a972 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -23,70 +23,66 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(uint8_t, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 4495e2527c..1948ab8c8a 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -23,78 +23,74 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(float, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 5494414a6a..1e0c9fddda 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -23,78 +23,74 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(int8_t, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index 35df2146bf..cc3945a05b 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -23,78 +23,74 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(uint8_t, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index c671740e61..d25cd3656d 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -23,78 +23,74 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(float, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index b167239c68..f9808df864 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -23,78 +23,74 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(half, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 127baf8fd5..f2205626d8 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -23,78 +23,74 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(int8_t, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 869e009a5b..1ec306a34a 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -23,78 +23,74 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::snmg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::snmg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(uint8_t, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/nccl_comm.cpp b/cpp/src/neighbors/mg/nccl_comm.cpp deleted file mode 100644 index c4556957ae..0000000000 --- a/cpp/src/neighbors/mg/nccl_comm.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include -#include - -namespace raft::comms { -void build_comms_nccl_only(raft::resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank) -{ -} -} // namespace raft::comms diff --git a/cpp/src/neighbors/mg/omp_checks.cpp b/cpp/src/neighbors/mg/omp_checks.cpp index c8cc274141..91a23c8f32 100644 --- a/cpp/src/neighbors/mg/omp_checks.cpp +++ b/cpp/src/neighbors/mg/omp_checks.cpp @@ -17,7 +17,7 @@ #include #include -namespace cuvs::neighbors::mg { +namespace cuvs::neighbors::snmg { void check_omp_threads(const int requirements) { @@ -30,4 +30,4 @@ void check_omp_threads(const int requirements) requirements); } -} // namespace cuvs::neighbors::mg +} // namespace cuvs::neighbors::snmg diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/snmg.cuh similarity index 72% rename from cpp/src/neighbors/mg/mg.cuh rename to cpp/src/neighbors/mg/snmg.cuh index e9cdc30f6b..3d9e795adc 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/snmg.cuh @@ -22,8 +22,10 @@ #include #include +#include #include -#include +#include +#include #include @@ -31,7 +33,7 @@ namespace cuvs::neighbors { using namespace raft; template -void search(const raft::device_resources& handle, +void search(const raft::resources& handle, const cuvs::neighbors::iface& interface, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view h_queries, @@ -39,55 +41,49 @@ void search(const raft::device_resources& handle, raft::device_matrix_view d_distances); } // namespace cuvs::neighbors -namespace cuvs::neighbors::mg { +namespace cuvs::neighbors::snmg { void check_omp_threads(const int requirements); -} // namespace cuvs::neighbors::mg +} // namespace cuvs::neighbors::snmg -namespace cuvs::neighbors::mg::detail { +namespace cuvs::neighbors::snmg::detail { using namespace cuvs::neighbors; using namespace raft; // local index deserialization and distribution template -void deserialize_and_distribute(const raft::device_resources& handle, - index& index, +void deserialize_and_distribute(const raft::resources& clique, + mg_index& index, const std::string& filename) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, filename); } } // MG index deserialization template -void deserialize(const raft::device_resources& handle, - index& index, +void deserialize(const raft::resources& clique, + mg_index& index, const std::string& filename) { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const auto& handle = raft::resource::set_current_device_to_root_rank(clique); + index.mode_ = (cuvs::neighbors::distribution_mode)deserialize_scalar(handle, is); + index.num_ranks_ = deserialize_scalar(handle, is); - index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); - index.num_ranks_ = deserialize_scalar(handle, is); - - if (index.num_ranks_ != clique.num_ranks_) { + if (index.num_ranks_ != raft::resource::get_nccl_num_ranks(clique)) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", index.num_ranks_, - clique.num_ranks_); + raft::resource::get_nccl_num_ranks(clique)); } for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, is); } @@ -95,24 +91,20 @@ void deserialize(const raft::device_resources& handle, } template -void build(const raft::device_resources& handle, - index& index, +void build(const raft::resources& clique, + mg_index& index, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); - RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); + RAFT_LOG_DEBUG("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset); resource::sync_stream(dev_res); } @@ -121,14 +113,12 @@ void build(const raft::device_resources& handle, int64_t n_cols = index_dataset.extent(1); int64_t n_rows_per_shard = raft::ceildiv(n_rows, (int64_t)index.num_ranks_); - RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard); + RAFT_LOG_DEBUG("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard); index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); int64_t offset = rank * n_rows_per_shard; int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); @@ -142,23 +132,19 @@ void build(const raft::device_resources& handle, } template -void extend(const raft::device_resources& handle, - index& index, +void extend(const raft::resources& clique, + mg_index& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { - RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); + RAFT_LOG_DEBUG("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::extend(dev_res, ann_if, new_vectors, new_indices); resource::sync_stream(dev_res); } @@ -166,13 +152,11 @@ void extend(const raft::device_resources& handle, int64_t n_cols = new_vectors.extent(1); int64_t n_rows_per_shard = raft::ceildiv(n_rows, (int64_t)index.num_ranks_); - RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", index.num_ranks_, n_rows_per_shard); + RAFT_LOG_DEBUG("SHARDED EXTEND: %d*%drows", index.num_ranks_, n_rows_per_shard); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); int64_t offset = rank * n_rows_per_shard; int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); @@ -193,8 +177,8 @@ void extend(const raft::device_resources& handle, } template -void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, - const index& index, +void sharded_search_with_direct_merge(const raft::resources& clique, + const mg_index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, @@ -205,7 +189,7 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, int64_t n_neighbors, int64_t n_batches) { - const auto& root_handle = clique.set_current_device_to_root_rank(); + const auto& root_handle = raft::resource::set_current_device_to_root_rank(clique); auto in_neighbors = raft::make_device_matrix( root_handle, index.num_ranks_ * n_rows_per_batch, n_neighbors); auto in_distances = raft::make_device_matrix( @@ -228,13 +212,11 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; - if (rank == clique.root_rank_) { // root rank - uint64_t batch_offset = clique.root_rank_ * part_size; + if (rank == raft::resource::get_nccl_clique_root_rank(clique)) { // root rank + uint64_t batch_offset = raft::resource::get_nccl_clique_root_rank(clique) * part_size; auto d_neighbors = raft::make_device_matrix_view( in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); auto d_distances = raft::make_device_matrix_view( @@ -245,21 +227,21 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, // wait for other ranks ncclGroupStart(); for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) { - if (from_rank == clique.root_rank_) continue; + if (from_rank == raft::resource::get_nccl_clique_root_rank(clique)) continue; batch_offset = from_rank * part_size; ncclRecv(in_neighbors.data_handle() + batch_offset, part_size * sizeof(IdxT), ncclUint8, from_rank, - clique.nccl_comms_[rank], - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclRecv(in_distances.data_handle() + batch_offset, part_size * sizeof(float), ncclUint8, from_rank, - clique.nccl_comms_[rank], - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); resource::sync_stream(dev_res); @@ -276,21 +258,21 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, ncclSend(d_neighbors.data_handle(), part_size * sizeof(IdxT), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_clique_root_rank(clique), + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclSend(d_distances.data_handle(), part_size * sizeof(float), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_clique_root_rank(clique), + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclGroupEnd(); resource::sync_stream(dev_res); } } - const auto& root_handle_ = clique.set_current_device_to_root_rank(); + const auto& root_handle_ = raft::resource::set_current_device_to_root_rank(clique); auto h_trans = std::vector(index.num_ranks_); int64_t translation_offset = 0; for (int rank = 0; rank < index.num_ranks_; rank++) { @@ -301,7 +283,7 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, raft::copy(d_trans.data_handle(), h_trans.data(), index.num_ranks_, - resource::get_cuda_stream(root_handle_)); + raft::resource::get_cuda_stream(root_handle_)); cuvs::neighbors::detail::knn_merge_parts(in_distances.data_handle(), in_neighbors.data_handle(), @@ -310,25 +292,25 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, n_rows_of_current_batch, index.num_ranks_, n_neighbors, - resource::get_cuda_stream(root_handle_), + raft::resource::get_cuda_stream(root_handle_), d_trans.data_handle()); raft::copy(neighbors.data_handle() + output_offset, out_neighbors.data_handle(), part_size, - resource::get_cuda_stream(root_handle_)); + raft::resource::get_cuda_stream(root_handle_)); raft::copy(distances.data_handle() + output_offset, out_distances.data_handle(), part_size, - resource::get_cuda_stream(root_handle_)); + raft::resource::get_cuda_stream(root_handle_)); resource::sync_stream(root_handle_); } } template -void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, - const index& index, +void sharded_search_with_tree_merge(const raft::resources& clique, + const mg_index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, @@ -351,10 +333,8 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; int64_t part_size = n_rows_of_current_batch * n_neighbors; @@ -377,11 +357,11 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, neighbors_view.data_handle(), (IdxT)translation_offset, part_size, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); auto d_trans = raft::make_device_vector(dev_res, 2); cudaMemsetAsync( - d_trans.data_handle(), 0, 2 * sizeof(IdxT), resource::get_cuda_stream(dev_res)); + d_trans.data_handle(), 0, 2 * sizeof(IdxT), raft::resource::get_cuda_stream(dev_res)); int64_t remaining = index.num_ranks_; int64_t radix = 2; @@ -399,14 +379,14 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclRecv(tmp_distances.data_handle() + part_size, part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); received_something = true; } } else if (rank % radix == offset) // This is one of the senders @@ -416,14 +396,14 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclSend(tmp_distances.data_handle(), part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -439,7 +419,7 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, n_rows_of_current_batch, 2, n_neighbors, - resource::get_cuda_stream(dev_res), + raft::resource::get_cuda_stream(dev_res), d_trans.data_handle()); // If done, copy the final result @@ -447,11 +427,11 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, raft::copy(neighbors.data_handle() + output_offset, tmp_neighbors.data_handle(), part_size, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); raft::copy(distances.data_handle() + output_offset, tmp_distances.data_handle(), part_size, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); resource::sync_stream(dev_res); } @@ -462,8 +442,8 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, } template -void run_search_batch(const raft::comms::nccl_clique& clique, - const index& index, +void run_search_batch(const raft::resources& clique, + const mg_index& index, int rank, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view& queries, @@ -475,10 +455,8 @@ void run_search_batch(const raft::comms::nccl_clique& clique, int64_t n_cols, int64_t n_neighbors) { - int dev_id = clique.device_ids_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - const raft::device_resources& dev_res = clique.device_resources_[rank]; - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; auto query_partition = raft::make_host_matrix_view( queries.data_handle() + query_offset, n_rows_of_current_batch, n_cols); @@ -493,46 +471,46 @@ void run_search_batch(const raft::comms::nccl_clique& clique, raft::copy(neighbors.data_handle() + output_offset, d_neighbors.data_handle(), n_rows_of_current_batch * n_neighbors, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); raft::copy(distances.data_handle() + output_offset, d_distances.data_handle(), n_rows_of_current_batch * n_neighbors, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); resource::sync_stream(dev_res); } template -void search(const raft::device_resources& handle, - const index& index, +void search(const raft::resources& clique, + const mg_index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch) + raft::host_matrix_view distances) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - int64_t n_rows = queries.extent(0); int64_t n_cols = queries.extent(1); int64_t n_neighbors = neighbors.extent(1); + int64_t n_rows_per_batch = -1; if (index.mode_ == REPLICATED) { - cuvs::neighbors::mg::replicated_search_mode search_mode; + cuvs::neighbors::replicated_search_mode search_mode; if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>( + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>( search_params); - search_mode = mg_search_params->search_mode; + search_mode = mg_search_params->search_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>( - search_params); - search_mode = mg_search_params->search_mode; + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>(search_params); + search_mode = mg_search_params->search_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>(search_params); - search_mode = mg_search_params->search_mode; + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>(search_params); + search_mode = mg_search_params->search_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } if (search_mode == LOAD_BALANCER) { @@ -542,7 +520,7 @@ void search(const raft::device_resources& handle, int64_t n_batches = raft::ceildiv(n_rows, (int64_t)n_rows_per_batch); if (n_batches <= 1) n_rows_per_batch = n_rows; - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "REPLICATED SEARCH IN LOAD BALANCER MODE: %d*%drows", n_batches, n_rows_per_batch); #pragma omp parallel for @@ -567,7 +545,7 @@ void search(const raft::device_resources& handle, n_neighbors); } } else if (search_mode == ROUND_ROBIN) { - RAFT_LOG_INFO("REPLICATED SEARCH IN ROUND ROBIN MODE: %d*%drows", 1, n_rows); + RAFT_LOG_DEBUG("REPLICATED SEARCH IN ROUND ROBIN MODE: %d*%drows", 1, n_rows); ASSERT(n_rows <= n_rows_per_batch, "In round-robin mode, n_rows must lower or equal to n_rows_per_batch"); @@ -590,30 +568,32 @@ void search(const raft::device_resources& handle, n_neighbors); } } else if (index.mode_ == SHARDED) { - cuvs::neighbors::mg::sharded_merge_mode merge_mode; + cuvs::neighbors::sharded_merge_mode merge_mode; if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>( + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>( search_params); - merge_mode = mg_search_params->merge_mode; + merge_mode = mg_search_params->merge_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>( - search_params); - merge_mode = mg_search_params->merge_mode; + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>(search_params); + merge_mode = mg_search_params->merge_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>(search_params); - merge_mode = mg_search_params->merge_mode; + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>(search_params); + merge_mode = mg_search_params->merge_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } int64_t n_batches = raft::ceildiv(n_rows, (int64_t)n_rows_per_batch); if (n_batches <= 1) n_rows_per_batch = n_rows; if (merge_mode == MERGE_ON_ROOT_RANK) { - RAFT_LOG_INFO("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows", - n_batches, - n_rows_per_batch); + RAFT_LOG_DEBUG("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows", + n_batches, + n_rows_per_batch); sharded_search_with_direct_merge(clique, index, search_params, @@ -626,7 +606,7 @@ void search(const raft::device_resources& handle, n_neighbors, n_batches); } else if (merge_mode == TREE_MERGE) { - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "SHARDED SEARCH WITH TREE_MERGE MERGE MODE %d*%drows", n_batches, n_rows_per_batch); sharded_search_with_tree_merge(clique, index, @@ -644,23 +624,20 @@ void search(const raft::device_resources& handle, } template -void serialize(const raft::device_resources& handle, - const index& index, +void serialize(const raft::resources& clique, + const mg_index& index, const std::string& filename) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - + const auto& handle = raft::resource::set_current_device_to_root_rank(clique); serialize_scalar(handle, of, (int)index.mode_); serialize_scalar(handle, of, index.num_ranks_); for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::serialize(dev_res, ann_if, of); } @@ -668,25 +645,24 @@ void serialize(const raft::device_resources& handle, if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } -} // namespace cuvs::neighbors::mg::detail +} // namespace cuvs::neighbors::snmg::detail -namespace cuvs::neighbors::mg { +namespace cuvs::neighbors { using namespace cuvs::neighbors; using namespace raft; template -index::index(distribution_mode mode, int num_ranks_) - : mode_(mode), - num_ranks_(num_ranks_), - round_robin_counter_(std::make_shared>(0)) +mg_index::mg_index(const raft::resources& clique, distribution_mode mode) + : mode_(mode), round_robin_counter_(std::make_shared>(0)) { + num_ranks_ = raft::resource::get_nccl_num_ranks(clique); } template -index::index(const raft::device_resources& handle, - const std::string& filename) +mg_index::mg_index(const raft::resources& clique, + const std::string& filename) : round_robin_counter_(std::make_shared>(0)) { - cuvs::neighbors::mg::detail::deserialize(handle, *this, filename); + cuvs::neighbors::snmg::detail::deserialize(clique, *this, filename); } -} // namespace cuvs::neighbors::mg +} // namespace cuvs::neighbors diff --git a/cpp/tests/neighbors/mg.cuh b/cpp/tests/neighbors/mg.cuh index be30ca6153..c2133393b6 100644 --- a/cpp/tests/neighbors/mg.cuh +++ b/cpp/tests/neighbors/mg.cuh @@ -19,8 +19,10 @@ #include "ann_utils.cuh" #include "naive_knn.cuh" -#include -#include +#include +#include +#include +#include namespace cuvs::neighbors::mg { @@ -46,14 +48,14 @@ template class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() - : stream_(resource::get_cuda_stream(handle_)), - clique_(raft::resource::get_nccl_clique(handle_)), + : clique_(), ps(::testing::TestWithParam::GetParam()), - d_index_dataset(0, stream_), - d_queries(0, stream_), + d_index_dataset(0, resource::get_cuda_stream(clique_)), + d_queries(0, resource::get_cuda_stream(clique_)), h_index_dataset(0), h_queries(0) { + clique_.set_memory_pool(80); } void testAnnMG() @@ -67,9 +69,10 @@ class AnnMGTest : public ::testing::TestWithParam { std::vector neighbors_snmg_ann_32bits(queries_size); { - rmm::device_uvector distances_ref_dev(queries_size, stream_); - rmm::device_uvector neighbors_ref_dev(queries_size, stream_); - cuvs::neighbors::naive_knn(handle_, + rmm::device_uvector distances_ref_dev(queries_size, resource::get_cuda_stream(clique_)); + rmm::device_uvector neighbors_ref_dev(queries_size, + resource::get_cuda_stream(clique_)); + cuvs::neighbors::naive_knn(clique_, distances_ref_dev.data(), neighbors_ref_dev.data(), d_queries.data(), @@ -79,9 +82,15 @@ class AnnMGTest : public ::testing::TestWithParam { ps.dim, ps.k, ps.metric); - update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_); - update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); + update_host(distances_ref.data(), + distances_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); + update_host(neighbors_ref.data(), + neighbors_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); + resource::sync_stream(clique_); } int64_t n_rows_per_search_batch = 3000; // [3000, 3000, 1000] == 7000 rows @@ -95,7 +104,7 @@ class AnnMGTest : public ::testing::TestWithParam { else d_mode = distribution_mode::SHARDED; - mg::index_params index_params; + mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.adaptive_centers = ps.adaptive_centers; @@ -104,7 +113,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.metric_arg = 0; index_params.mode = d_mode; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = LOAD_BALANCER; @@ -118,20 +127,22 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_flat_index"); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_flat::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::ivf_flat::serialize(clique_, index, "mg_ivf_flat_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_flat(handle_, "mg_ivf_flat_index"); + cuvs::neighbors::ivf_flat::deserialize(clique_, "mg_ivf_flat_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::ivf_flat::search( + clique_, new_index, search_params, queries, neighbors, distances); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -155,7 +166,7 @@ class AnnMGTest : public ::testing::TestWithParam { else d_mode = distribution_mode::SHARDED; - mg::index_params index_params; + mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.add_data_on_build = false; @@ -163,7 +174,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.metric_arg = 0; index_params.mode = d_mode; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = LOAD_BALANCER; @@ -177,20 +188,22 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_pq_index"); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_pq::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::ivf_pq::serialize(clique_, index, "mg_ivf_pq_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_pq(handle_, "mg_ivf_pq_index"); + cuvs::neighbors::ivf_pq::deserialize(clique_, "mg_ivf_pq_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::ivf_pq::search( + clique_, new_index, search_params, queries, neighbors, distances); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -214,12 +227,12 @@ class AnnMGTest : public ::testing::TestWithParam { else d_mode = distribution_mode::SHARDED; - mg::index_params index_params; + mg_index_params index_params; index_params.graph_build_params = cagra::graph_build_params::ivf_pq_params( raft::matrix_extent(ps.num_db_vecs, ps.dim)); index_params.mode = d_mode; - mg::search_params search_params; + mg_search_params search_params; auto index_dataset = raft::make_host_matrix_view( h_index_dataset.data(), ps.num_db_vecs, ps.dim); @@ -231,19 +244,21 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::serialize(handle_, index, "mg_cagra_index"); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); + cuvs::neighbors::cagra::serialize(clique_, index, "mg_cagra_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_cagra(handle_, "mg_cagra_index"); + cuvs::neighbors::cagra::deserialize(clique_, "mg_cagra_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::cagra::search( + clique_, new_index, search_params, queries, neighbors, distances); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits, @@ -267,15 +282,15 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = LOAD_BALANCER; { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_flat::build(handle_, index_params, index_dataset); - ivf_flat::serialize(handle_, "local_ivf_flat_index", index); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + ivf_flat::serialize(clique_, "local_ivf_flat_index", index); } auto queries = raft::make_host_matrix_view( @@ -286,17 +301,14 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_flat(handle_, "local_ivf_flat_index"); + cuvs::neighbors::ivf_flat::distribute(clique_, "local_ivf_flat_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); - resource::sync_stream(handle_); + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::ivf_flat::search( + clique_, distributed_index, search_params, queries, neighbors, distances); + + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -319,15 +331,15 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = LOAD_BALANCER; { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_pq::build(handle_, index_params, index_dataset); - ivf_pq::serialize(handle_, "local_ivf_pq_index", index); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + ivf_pq::serialize(clique_, "local_ivf_pq_index", index); } auto queries = raft::make_host_matrix_view( @@ -338,17 +350,14 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_pq(handle_, "local_ivf_pq_index"); + cuvs::neighbors::ivf_pq::distribute(clique_, "local_ivf_pq_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); - resource::sync_stream(handle_); + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::ivf_pq::search( + clique_, distributed_index, search_params, queries, neighbors, distances); + + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -368,13 +377,13 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.graph_build_params = cagra::graph_build_params::ivf_pq_params( raft::matrix_extent(ps.num_db_vecs, ps.dim)); - mg::search_params search_params; + mg_search_params search_params; { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::cagra::build(handle_, index_params, index_dataset); - cuvs::neighbors::cagra::serialize(handle_, "local_cagra_index", index); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); + cuvs::neighbors::cagra::serialize(clique_, "local_cagra_index", index); } auto queries = raft::make_host_matrix_view( @@ -385,18 +394,15 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_cagra(handle_, "local_cagra_index"); + cuvs::neighbors::cagra::distribute(clique_, "local_cagra_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); - resource::sync_stream(handle_); + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::cagra::search( + clique_, distributed_index, search_params, queries, neighbors, distances); + + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits, @@ -414,7 +420,7 @@ class AnnMGTest : public ::testing::TestWithParam { if (ps.algo == algo_t::IVF_FLAT && ps.d_mode == d_mode_t::ROUND_ROBIN) { ASSERT_TRUE(ps.num_queries <= 4); - mg::index_params index_params; + mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.adaptive_centers = ps.adaptive_centers; @@ -423,7 +429,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.metric_arg = 0; index_params.mode = REPLICATED; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = ROUND_ROBIN; @@ -432,8 +438,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_flat::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -448,13 +454,14 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::ivf_flat::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -479,7 +486,7 @@ class AnnMGTest : public ::testing::TestWithParam { if (ps.algo == algo_t::IVF_PQ && ps.d_mode == d_mode_t::ROUND_ROBIN) { ASSERT_TRUE(ps.num_queries <= 4); - mg::index_params index_params; + mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.add_data_on_build = false; @@ -487,7 +494,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.metric_arg = 0; index_params.mode = REPLICATED; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = ROUND_ROBIN; @@ -496,8 +503,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_pq::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -512,13 +519,14 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::ivf_pq::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -543,12 +551,12 @@ class AnnMGTest : public ::testing::TestWithParam { if (ps.algo == algo_t::CAGRA && ps.d_mode == d_mode_t::ROUND_ROBIN) { ASSERT_TRUE(ps.num_queries <= 4); - mg::index_params index_params; + mg_index_params index_params; index_params.graph_build_params = cagra::graph_build_params::ivf_pq_params( raft::matrix_extent(ps.num_db_vecs, ps.dim)); index_params.mode = REPLICATED; - mg::search_params search_params; + mg_search_params search_params; search_params.search_mode = ROUND_ROBIN; auto index_dataset = raft::make_host_matrix_view( @@ -556,7 +564,7 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -571,13 +579,14 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::cagra::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -602,37 +611,35 @@ class AnnMGTest : public ::testing::TestWithParam { void SetUp() override { - d_index_dataset.resize(ps.num_db_vecs * ps.dim, stream_); - d_queries.resize(ps.num_queries * ps.dim, stream_); + d_index_dataset.resize(ps.num_db_vecs * ps.dim, resource::get_cuda_stream(clique_)); + d_queries.resize(ps.num_queries * ps.dim, resource::get_cuda_stream(clique_)); h_index_dataset.resize(ps.num_db_vecs * ps.dim); h_queries.resize(ps.num_queries * ps.dim); raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { raft::random::uniform( - handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); - raft::random::uniform(handle_, r, d_queries.data(), d_queries.size(), DataT(0.1), DataT(2.0)); + clique_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); + raft::random::uniform(clique_, r, d_queries.data(), d_queries.size(), DataT(0.1), DataT(2.0)); } else { raft::random::uniformInt( - handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); - raft::random::uniformInt(handle_, r, d_queries.data(), d_queries.size(), DataT(1), DataT(20)); + clique_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); + raft::random::uniformInt(clique_, r, d_queries.data(), d_queries.size(), DataT(1), DataT(20)); } raft::copy(h_index_dataset.data(), d_index_dataset.data(), d_index_dataset.size(), - resource::get_cuda_stream(handle_)); + resource::get_cuda_stream(clique_)); raft::copy( - h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(handle_)); - resource::sync_stream(handle_); + h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(clique_)); + resource::sync_stream(clique_); } void TearDown() override {} private: - raft::device_resources handle_; - rmm::cuda_stream_view stream_; - raft::comms::nccl_clique clique_; + raft::device_resources_snmg clique_; AnnMGInputs ps; std::vector h_index_dataset; std::vector h_queries; diff --git a/docs/source/cpp_api/neighbors_mg.rst b/docs/source/cpp_api/neighbors_mg.rst index b68defec9f..6192c4f9d6 100644 --- a/docs/source/cpp_api/neighbors_mg.rst +++ b/docs/source/cpp_api/neighbors_mg.rst @@ -7,9 +7,9 @@ The SNMG (single-node multi-GPUs) ANN API provides a set of functions to deploy :language: c++ :class: highlight -``#include `` +``#include `` -namespace *cuvs::neighbors::mg* +namespace *cuvs::neighbors* Index build parameters ---------------------- @@ -20,7 +20,7 @@ Index build parameters :content-only: Search parameters ----------------------- +----------------- .. doxygengroup:: mg_cpp_search_params :project: cuvs