From ee98593fa758555820ad35c25c4787565dea0060 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 8 Nov 2024 18:23:28 +0100 Subject: [PATCH 01/17] Account for RAFT update --- cpp/CMakeLists.txt | 1 - cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 2 +- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 2 +- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 2 +- cpp/src/neighbors/mg/generate_mg.py | 12 ++++++------ cpp/src/neighbors/mg/mg.cuh | 18 +++++++++--------- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu | 4 ++-- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 4 ++-- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu | 4 ++-- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/nccl_comm.cpp | 8 -------- cpp/test/neighbors/mg.cuh | 2 +- 19 files changed, 41 insertions(+), 50 deletions(-) delete mode 100644 cpp/src/neighbors/mg/nccl_comm.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c493af488e..ce925713e0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -298,7 +298,6 @@ if(BUILD_SHARED_LIBS) src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu src/neighbors/mg/omp_checks.cpp - src/neighbors/mg/nccl_comm.cpp ) endif() diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index 50c1ff4dbf..cf150bf985 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -47,7 +47,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 54a0d2facb..d313bbcbc0 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -41,7 +41,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { { index_params_.metric = parse_metric_type(metric); // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index 84aea7d4a3..588e4798a5 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -41,7 +41,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { { index_params_.metric = parse_metric_type(metric); // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index af5e605456..b9fdd0aa9c 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -57,7 +57,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::build(handle, index, \\ static_cast(&index_params), \\ @@ -105,7 +105,7 @@ index, T, IdxT> distribute_flat(const raft::device_resources& handle, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ return idx; \\ @@ -118,7 +118,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::build(handle, index, \\ static_cast(&index_params), \\ @@ -166,7 +166,7 @@ index, T, IdxT> distribute_pq(const raft::device_resources& handle, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ return idx; \\ @@ -179,7 +179,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::build(handle, index, \\ static_cast(&index_params), \\ @@ -219,7 +219,7 @@ index, T, IdxT> distribute_cagra(const raft::device_resources& handle, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ return idx; \\ diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index d3f635bc40..3679e89da1 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -51,7 +51,7 @@ void deserialize_and_distribute(const raft::device_resources& handle, index& index, const std::string& filename) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); for (int rank = 0; rank < index.num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -70,7 +70,7 @@ void deserialize(const raft::device_resources& handle, std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); index.num_ranks_ = deserialize_scalar(handle, is); @@ -98,7 +98,7 @@ void build(const raft::device_resources& handle, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); @@ -145,7 +145,7 @@ void extend(const raft::device_resources& handle, raft::host_matrix_view new_vectors, std::optional> new_indices) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { @@ -191,7 +191,7 @@ void extend(const raft::device_resources& handle, } template -void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, +void sharded_search_with_direct_merge(const raft::core::nccl_clique& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -325,7 +325,7 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, } template -void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, +void sharded_search_with_tree_merge(const raft::core::nccl_clique& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -460,7 +460,7 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, } template -void run_search_batch(const raft::comms::nccl_clique& clique, +void run_search_batch(const raft::core::nccl_clique& clique, const index& index, int rank, const cuvs::neighbors::search_params* search_params, @@ -509,7 +509,7 @@ void search(const raft::device_resources& handle, raft::host_matrix_view distances, int64_t n_rows_per_batch) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); int64_t n_rows = queries.extent(0); int64_t n_cols = queries.extent(1); @@ -649,7 +649,7 @@ void serialize(const raft::device_resources& handle, std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); serialize_scalar(handle, of, (int)index.mode_); serialize_scalar(handle, of, index.num_ranks_); diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index b11610fb48..0f6154395e 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -80,7 +80,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_cagra( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 8f76c69a34..ad041155b9 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -80,7 +80,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_cagra( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 67b88d7429..5099a04175 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -80,7 +80,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_cagra( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index f721749233..df33190f99 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -80,7 +80,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_cagra( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 4495e2527c..5fcf27301e 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_flat( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 5494414a6a..75128ea0c4 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_flat( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index 35df2146bf..cafdbdcbb0 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_flat( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index c671740e61..8e7a710143 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index b167239c68..7fe19a8994 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 127baf8fd5..ca56c90ed1 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 869e009a5b..ed40fa2305 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/nccl_comm.cpp b/cpp/src/neighbors/mg/nccl_comm.cpp deleted file mode 100644 index c4556957ae..0000000000 --- a/cpp/src/neighbors/mg/nccl_comm.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include -#include - -namespace raft::comms { -void build_comms_nccl_only(raft::resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank) -{ -} -} // namespace raft::comms diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index be30ca6153..6b98e975a0 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -632,7 +632,7 @@ class AnnMGTest : public ::testing::TestWithParam { private: raft::device_resources handle_; rmm::cuda_stream_view stream_; - raft::comms::nccl_clique clique_; + raft::core::nccl_clique clique_; AnnMGInputs ps; std::vector h_index_dataset; std::vector h_queries; From 3a99b406dbf71d58c02148b850524ff75248df52 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 15 Nov 2024 14:25:38 +0100 Subject: [PATCH 02/17] use new device_resources_snmg --- .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 16 +- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 15 +- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 15 +- cpp/include/cuvs/neighbors/mg.hpp | 494 +++++++++--------- cpp/src/neighbors/mg/generate_mg.py | 74 ++- cpp/src/neighbors/mg/mg.cuh | 40 +- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 22 +- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 22 +- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 22 +- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 22 +- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 26 +- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 26 +- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 26 +- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 128 +++-- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 128 +++-- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 128 +++-- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 128 +++-- cpp/test/neighbors/mg.cuh | 84 +-- 18 files changed, 687 insertions(+), 729 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index cf150bf985..f5a3944829 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -18,7 +18,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_cagra_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -41,13 +41,10 @@ class cuvs_mg_cagra : public algo, public algo_gpu { }; cuvs_mg_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); - - // init nccl clique outside as to not affect benchmark - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; @@ -88,6 +85,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { private: raft::device_resources handle_; + raft::device_resources_snmg clique_; float refine_ratio_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; @@ -105,7 +103,7 @@ void cuvs_mg_cagra::build(const T* dataset, size_t nrow) auto dataset_view = raft::make_host_matrix_view(dataset, nrow, dim_); - auto idx = cuvs::neighbors::mg::build(handle_, build_params, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, build_params, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -132,7 +130,7 @@ void cuvs_mg_cagra::set_search_dataset(const T* dataset, size_t nrow) template void cuvs_mg_cagra::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -140,7 +138,7 @@ void cuvs_mg_cagra::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_cagra(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_cagra(clique_, file))); } template @@ -164,7 +162,7 @@ void cuvs_mg_cagra::search_base( raft::make_host_matrix_view(distances, batch_size, k); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } template diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index d313bbcbc0..05e68b26ba 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -19,7 +19,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_flat_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -37,11 +37,9 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { }; cuvs_mg_ivf_flat(Metric metric, int dim, const build_param& param) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; @@ -74,6 +72,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { private: raft::device_resources handle_; + raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; std::shared_ptr, T, IdxT>> @@ -85,7 +84,7 @@ void cuvs_mg_ivf_flat::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(handle_, index_params_, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>(std::move(idx)); } @@ -105,7 +104,7 @@ void cuvs_mg_ivf_flat::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_flat::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -113,7 +112,7 @@ void cuvs_mg_ivf_flat::load(const std::string& file) { index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_flat(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_flat(clique_, file))); } template @@ -134,7 +133,7 @@ void cuvs_mg_ivf_flat::search( distances, IdxT(batch_size), IdxT(k)); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } } // namespace cuvs::bench \ No newline at end of file diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index 588e4798a5..d430d27bf0 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -19,7 +19,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_pq_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -37,11 +37,9 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { }; cuvs_mg_ivf_pq(Metric metric, int dim, const build_param& param) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; @@ -74,6 +72,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { private: raft::device_resources handle_; + raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; std::shared_ptr, T, IdxT>> index_; @@ -84,7 +83,7 @@ void cuvs_mg_ivf_pq::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(handle_, index_params_, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -104,7 +103,7 @@ void cuvs_mg_ivf_pq::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_pq::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -112,7 +111,7 @@ void cuvs_mg_ivf_pq::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_pq(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_pq(clique_, file))); } template @@ -133,7 +132,7 @@ void cuvs_mg_ivf_pq::search( distances, IdxT(batch_size), IdxT(k)); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } } // namespace cuvs::bench \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/mg.hpp b/cpp/include/cuvs/neighbors/mg.hpp index 4657fa8fb0..86572adebd 100644 --- a/cpp/include/cuvs/neighbors/mg.hpp +++ b/cpp/include/cuvs/neighbors/mg.hpp @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include @@ -101,7 +101,7 @@ using namespace raft; template struct index { index(distribution_mode mode, int num_ranks_); - index(const raft::device_resources& handle, const std::string& filename); + index(const raft::device_resources_snmg& clique, const std::string& filename); index(const index&) = delete; index(index&&) = default; @@ -124,18 +124,18 @@ struct index { * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, int64_t>; @@ -146,18 +146,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, int64_t>; @@ -168,18 +168,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, int64_t>; @@ -190,18 +190,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, int64_t>; @@ -212,18 +212,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, half, int64_t>; @@ -234,18 +234,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, int64_t>; @@ -256,18 +256,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, int64_t>; @@ -278,18 +278,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, uint32_t>; @@ -300,18 +300,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, half, uint32_t>; @@ -322,18 +322,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, uint32_t>; @@ -344,18 +344,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, uint32_t>; @@ -368,20 +368,20 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -392,20 +392,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -416,20 +416,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -440,20 +440,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -464,20 +464,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, half, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -488,20 +488,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -512,20 +512,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -536,20 +536,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -560,20 +560,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, half, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -584,20 +584,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -608,20 +608,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -634,15 +634,15 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -651,7 +651,7 @@ void extend(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -665,15 +665,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -682,7 +682,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -696,15 +696,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -713,7 +713,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -727,15 +727,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -744,7 +744,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -758,15 +758,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -775,7 +775,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, half, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -789,15 +789,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -806,7 +806,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -820,15 +820,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -837,7 +837,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -851,15 +851,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -868,7 +868,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -882,15 +882,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -899,7 +899,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, half, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -913,15 +913,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -930,7 +930,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -944,15 +944,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -961,7 +961,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -977,19 +977,19 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const std::string& filename); @@ -999,19 +999,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const std::string& filename); @@ -1021,19 +1021,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const std::string& filename); @@ -1043,19 +1043,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const std::string& filename); @@ -1065,19 +1065,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, half, int64_t>& index, const std::string& filename); @@ -1087,19 +1087,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const std::string& filename); @@ -1109,19 +1109,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const std::string& filename); @@ -1131,19 +1131,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, uint32_t>& index, const std::string& filename); @@ -1153,19 +1153,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, half, uint32_t>& index, const std::string& filename); @@ -1175,19 +1175,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, uint32_t>& index, const std::string& filename); @@ -1197,19 +1197,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, uint32_t>& index, const std::string& filename); @@ -1221,21 +1221,21 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_flat(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_flat(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_flat(const raft::device_resources& handle, const std::string& filename) +auto deserialize_flat(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_deserialize @@ -1244,20 +1244,20 @@ auto deserialize_flat(const raft::device_resources& handle, const std::string& f * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_pq(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_pq(clique, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_pq(const raft::device_resources& handle, const std::string& filename) +auto deserialize_pq(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_deserialize @@ -1266,21 +1266,21 @@ auto deserialize_pq(const raft::device_resources& handle, const std::string& fil * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_cagra(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_cagra(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_cagra(const raft::device_resources& handle, const std::string& filename) +auto deserialize_cagra(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -1292,21 +1292,21 @@ auto deserialize_cagra(const raft::device_resources& handle, const std::string& * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::ivf_flat::index_params index_params; - * auto index = cuvs::neighbors::ivf_flat::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_flat::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_flat(handle, filename); + * cuvs::neighbors::ivf_flat::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_flat(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_flat(const raft::device_resources& handle, const std::string& filename) +auto distribute_flat(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_distribute @@ -1316,20 +1316,20 @@ auto distribute_flat(const raft::device_resources& handle, const std::string& fi * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::ivf_pq::index_params index_params; - * auto index = cuvs::neighbors::ivf_pq::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_pq::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_pq(handle, filename); + * cuvs::neighbors::ivf_pq::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_pq(clique, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_pq(const raft::device_resources& handle, const std::string& filename) +auto distribute_pq(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_distribute @@ -1339,21 +1339,21 @@ auto distribute_pq(const raft::device_resources& handle, const std::string& file * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::cagra::index_params index_params; - * auto index = cuvs::neighbors::cagra::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::cagra::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_cagra(handle, filename); + * cuvs::neighbors::cagra::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_cagra(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_cagra(const raft::device_resources& handle, const std::string& filename) +auto distribute_cagra(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; } // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index b9fdd0aa9c..023f5baf36 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -53,27 +53,26 @@ flat_macro = """ #define CUVS_INST_MG_FLAT(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void extend(const raft::device_resources& handle, \\ + void extend(const raft::device_resources_snmg& clique, \\ index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -81,60 +80,58 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_flat(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_flat(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ pq_macro = """ #define CUVS_INST_MG_PQ(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void extend(const raft::device_resources& handle, \\ + void extend(const raft::device_resources_snmg& clique, \\ index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -142,52 +139,50 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_pq(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ cagra_macro = """ #define CUVS_INST_MG_CAGRA(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -195,33 +190,32 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_cagra(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_cagra(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 3679e89da1..0e113ef72e 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -17,7 +17,6 @@ #pragma once #include "../detail/knn_merge_parts.cuh" -#include #include #include #include @@ -47,11 +46,10 @@ using namespace raft; // local index deserialization and distribution template -void deserialize_and_distribute(const raft::device_resources& handle, +void deserialize_and_distribute(const raft::device_resources_snmg& clique, index& index, const std::string& filename) { - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); for (int rank = 0; rank < index.num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -63,17 +61,16 @@ void deserialize_and_distribute(const raft::device_resources& handle, // MG index deserialization template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::device_resources_snmg& clique, index& index, const std::string& filename) { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - - index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); - index.num_ranks_ = deserialize_scalar(handle, is); + const auto& handle = clique.set_current_device_to_root_rank(); + index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); + index.num_ranks_ = deserialize_scalar(handle, is); if (index.num_ranks_ != clique.num_ranks_) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", @@ -93,13 +90,11 @@ void deserialize(const raft::device_resources& handle, } template -void build(const raft::device_resources& handle, +void build(const raft::device_resources_snmg& clique, index& index, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); @@ -140,13 +135,11 @@ void build(const raft::device_resources& handle, } template -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); @@ -191,7 +184,7 @@ void extend(const raft::device_resources& handle, } template -void sharded_search_with_direct_merge(const raft::core::nccl_clique& clique, +void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -325,7 +318,7 @@ void sharded_search_with_direct_merge(const raft::core::nccl_clique& clique, } template -void sharded_search_with_tree_merge(const raft::core::nccl_clique& clique, +void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -460,7 +453,7 @@ void sharded_search_with_tree_merge(const raft::core::nccl_clique& clique, } template -void run_search_batch(const raft::core::nccl_clique& clique, +void run_search_batch(const raft::device_resources_snmg& clique, const index& index, int rank, const cuvs::neighbors::search_params* search_params, @@ -501,7 +494,7 @@ void run_search_batch(const raft::core::nccl_clique& clique, } template -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -509,8 +502,6 @@ void search(const raft::device_resources& handle, raft::host_matrix_view distances, int64_t n_rows_per_batch) { - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - int64_t n_rows = queries.extent(0); int64_t n_cols = queries.extent(1); int64_t n_neighbors = neighbors.extent(1); @@ -642,15 +633,14 @@ void search(const raft::device_resources& handle, } template -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index& index, const std::string& filename) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - + const auto& handle = clique.set_current_device_to_root_rank(); serialize_scalar(handle, of, (int)index.mode_); serialize_scalar(handle, of, index.num_ranks_); @@ -681,10 +671,10 @@ index::index(distribution_mode mode, int num_ranks_) } template -index::index(const raft::device_resources& handle, +index::index(const raft::device_resources_snmg& clique, const std::string& filename) : round_robin_counter_(std::make_shared>(0)) { - cuvs::neighbors::mg::detail::deserialize(handle, *this, filename); + cuvs::neighbors::mg::detail::deserialize(clique, *this, filename); } } // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index 0f6154395e..c3ef3705e6 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -29,21 +29,20 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_CAGRA(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -52,7 +51,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -61,28 +60,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index ad041155b9..ea9ec672ba 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -29,21 +29,20 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_CAGRA(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -52,7 +51,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -61,28 +60,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 5099a04175..aeae0f2ccf 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -29,21 +29,20 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_CAGRA(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -52,7 +51,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -61,28 +60,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index df33190f99..22421d6f08 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -29,21 +29,20 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_CAGRA(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -52,7 +51,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -61,28 +60,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 5fcf27301e..423aa02845 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -29,29 +29,28 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_FLAT(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources& handle, \ + void extend(const raft::device_resources_snmg& clique, \ index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -60,7 +59,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -69,28 +68,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 75128ea0c4..06bb7af267 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -29,29 +29,28 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_FLAT(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources& handle, \ + void extend(const raft::device_resources_snmg& clique, \ index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -60,7 +59,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -69,28 +68,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index cafdbdcbb0..bbf7d96f86 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -29,29 +29,28 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_FLAT(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources& handle, \ + void extend(const raft::device_resources_snmg& clique, \ index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -60,7 +59,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -69,28 +68,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index 8e7a710143..441a09e2fe 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index 7fe19a8994..bf6126feeb 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index ca56c90ed1..3921f810c3 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index ed40fa2305..8f4683fd79 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index 6b98e975a0..eb97b583cb 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -20,7 +20,7 @@ #include "naive_knn.cuh" #include -#include +#include namespace cuvs::neighbors::mg { @@ -47,7 +47,7 @@ class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() : stream_(resource::get_cuda_stream(handle_)), - clique_(raft::resource::get_nccl_clique(handle_)), + clique_(), ps(::testing::TestWithParam::GetParam()), d_index_dataset(0, stream_), d_queries(0, stream_), @@ -69,7 +69,7 @@ class AnnMGTest : public ::testing::TestWithParam { { rmm::device_uvector distances_ref_dev(queries_size, stream_); rmm::device_uvector neighbors_ref_dev(queries_size, stream_); - cuvs::neighbors::naive_knn(handle_, + cuvs::neighbors::naive_knn(clique_, distances_ref_dev.data(), neighbors_ref_dev.data(), d_queries.data(), @@ -118,19 +118,19 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_flat_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_flat_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_flat(handle_, "mg_ivf_flat_index"); + cuvs::neighbors::mg::deserialize_flat(clique_, "mg_ivf_flat_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -177,19 +177,19 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_pq_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_pq_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_pq(handle_, "mg_ivf_pq_index"); + cuvs::neighbors::mg::deserialize_pq(clique_, "mg_ivf_pq_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -231,18 +231,18 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::serialize(handle_, index, "mg_cagra_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::serialize(clique_, index, "mg_cagra_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_cagra(handle_, "mg_cagra_index"); + cuvs::neighbors::mg::deserialize_cagra(clique_, "mg_cagra_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -274,8 +274,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_flat::build(handle_, index_params, index_dataset); - ivf_flat::serialize(handle_, "local_ivf_flat_index", index); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + ivf_flat::serialize(clique_, "local_ivf_flat_index", index); } auto queries = raft::make_host_matrix_view( @@ -286,9 +286,9 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_flat(handle_, "local_ivf_flat_index"); + cuvs::neighbors::mg::distribute_flat(clique_, "local_ivf_flat_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -326,8 +326,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_pq::build(handle_, index_params, index_dataset); - ivf_pq::serialize(handle_, "local_ivf_pq_index", index); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + ivf_pq::serialize(clique_, "local_ivf_pq_index", index); } auto queries = raft::make_host_matrix_view( @@ -338,9 +338,9 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_pq(handle_, "local_ivf_pq_index"); + cuvs::neighbors::mg::distribute_pq(clique_, "local_ivf_pq_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -373,8 +373,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::cagra::build(handle_, index_params, index_dataset); - cuvs::neighbors::cagra::serialize(handle_, "local_cagra_index", index); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); + cuvs::neighbors::cagra::serialize(clique_, "local_cagra_index", index); } auto queries = raft::make_host_matrix_view( @@ -385,10 +385,10 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_cagra(handle_, "local_cagra_index"); + cuvs::neighbors::mg::distribute_cagra(clique_, "local_cagra_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -432,8 +432,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -448,7 +448,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -496,8 +496,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -512,7 +512,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -556,7 +556,7 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -571,7 +571,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -610,12 +610,12 @@ class AnnMGTest : public ::testing::TestWithParam { raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { raft::random::uniform( - handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); - raft::random::uniform(handle_, r, d_queries.data(), d_queries.size(), DataT(0.1), DataT(2.0)); + clique_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); + raft::random::uniform(clique_, r, d_queries.data(), d_queries.size(), DataT(0.1), DataT(2.0)); } else { raft::random::uniformInt( - handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); - raft::random::uniformInt(handle_, r, d_queries.data(), d_queries.size(), DataT(1), DataT(20)); + clique_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); + raft::random::uniformInt(clique_, r, d_queries.data(), d_queries.size(), DataT(1), DataT(20)); } raft::copy(h_index_dataset.data(), @@ -632,7 +632,7 @@ class AnnMGTest : public ::testing::TestWithParam { private: raft::device_resources handle_; rmm::cuda_stream_view stream_; - raft::core::nccl_clique clique_; + raft::device_resources_snmg clique_; AnnMGInputs ps; std::vector h_index_dataset; std::vector h_queries; From e16b68e7bddf1738d4189c3b5586520deb341ecb Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 15 Nov 2024 15:25:27 +0100 Subject: [PATCH 03/17] improved device_resources_snmg --- .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 2 + .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 2 + .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 2 + cpp/src/neighbors/mg/generate_mg.py | 12 +- cpp/src/neighbors/mg/mg.cuh | 96 ++++++------- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 110 +++++++-------- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 110 +++++++-------- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 110 +++++++-------- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 110 +++++++-------- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 126 +++++++++--------- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 126 +++++++++--------- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 126 +++++++++--------- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 126 +++++++++--------- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 126 +++++++++--------- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 126 +++++++++--------- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 126 +++++++++--------- cpp/test/neighbors/mg.cuh | 1 + 17 files changed, 712 insertions(+), 725 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index f5a3944829..27a0fd7acc 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -45,6 +45,8 @@ class cuvs_mg_cagra : public algo, public algo_gpu { { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 05e68b26ba..5e811da335 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -40,6 +40,8 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index d430d27bf0..c4a820cad9 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -40,6 +40,8 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index 023f5baf36..26e81da169 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -57,7 +57,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -104,7 +104,7 @@ index, T, IdxT> distribute_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } @@ -116,7 +116,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -163,7 +163,7 @@ index, T, IdxT> distribute_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } @@ -175,7 +175,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -214,7 +214,7 @@ index, T, IdxT> distribute_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 0e113ef72e..c6812b1e10 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -51,10 +51,8 @@ void deserialize_and_distribute(const raft::device_resources_snmg& clique, const std::string& filename) { for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, filename); } } @@ -72,17 +70,15 @@ void deserialize(const raft::device_resources_snmg& clique, index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); index.num_ranks_ = deserialize_scalar(handle, is); - if (index.num_ranks_ != clique.num_ranks_) { + if (index.num_ranks_ != clique.get_num_ranks()) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", index.num_ranks_, - clique.num_ranks_); + clique.get_num_ranks()); } for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, is); } @@ -102,10 +98,8 @@ void build(const raft::device_resources_snmg& clique, index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset); resource::sync_stream(dev_res); } @@ -119,13 +113,11 @@ void build(const raft::device_resources_snmg& clique, index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); - auto partition = raft::make_host_matrix_view( + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); + auto partition = raft::make_host_matrix_view( partition_ptr, n_rows_of_current_shard, n_cols); auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, partition); @@ -146,10 +138,8 @@ void extend(const raft::device_resources_snmg& clique, #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::extend(dev_res, ann_if, new_vectors, new_indices); resource::sync_stream(dev_res); } @@ -161,13 +151,11 @@ void extend(const raft::device_resources_snmg& clique, #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); - auto new_vectors_part = raft::make_host_matrix_view( + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); + auto new_vectors_part = raft::make_host_matrix_view( new_vectors_ptr, n_rows_of_current_shard, n_cols); std::optional> new_indices_part = std::nullopt; @@ -219,13 +207,11 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - if (rank == clique.root_rank_) { // root rank - uint64_t batch_offset = clique.root_rank_ * part_size; + if (rank == clique.get_root_rank()) { // root rank + uint64_t batch_offset = clique.get_root_rank() * part_size; auto d_neighbors = raft::make_device_matrix_view( in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); auto d_distances = raft::make_device_matrix_view( @@ -236,20 +222,20 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, // wait for other ranks ncclGroupStart(); for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) { - if (from_rank == clique.root_rank_) continue; + if (from_rank == clique.get_root_rank()) continue; batch_offset = from_rank * part_size; ncclRecv(in_neighbors.data_handle() + batch_offset, part_size * sizeof(IdxT), ncclUint8, from_rank, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclRecv(in_distances.data_handle() + batch_offset, part_size * sizeof(float), ncclUint8, from_rank, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -267,14 +253,14 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, ncclSend(d_neighbors.data_handle(), part_size * sizeof(IdxT), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], + clique.get_root_rank(), + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclSend(d_distances.data_handle(), part_size * sizeof(float), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], + clique.get_root_rank(), + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclGroupEnd(); resource::sync_stream(dev_res); @@ -342,10 +328,8 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); int64_t part_size = n_rows_of_current_batch * n_neighbors; @@ -390,13 +374,13 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclRecv(tmp_distances.data_handle() + part_size, part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); received_something = true; } @@ -407,13 +391,13 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclSend(tmp_distances.data_handle(), part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -466,9 +450,7 @@ void run_search_batch(const raft::device_resources_snmg& clique, int64_t n_cols, int64_t n_neighbors) { - int dev_id = clique.device_ids_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; auto query_partition = raft::make_host_matrix_view( @@ -645,10 +627,8 @@ void serialize(const raft::device_resources_snmg& clique, serialize_scalar(handle, of, index.num_ranks_); for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::serialize(dev_res, ann_if, of); } diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index c3ef3705e6..e179a56e38 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index ea9ec672ba..3e369d9ac6 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index aeae0f2ccf..5ebf223d12 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index 22421d6f08..923031b1c3 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 423aa02845..f90f6fcfbc 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 06bb7af267..2eefad5d57 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index bbf7d96f86..9684f19d8a 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index 441a09e2fe..c71133ac45 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index bf6126feeb..df148620fc 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 3921f810c3..afe5faa41d 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 8f4683fd79..c725d21398 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index eb97b583cb..f634765c95 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -54,6 +54,7 @@ class AnnMGTest : public ::testing::TestWithParam { h_index_dataset(0), h_queries(0) { + clique_.set_memory_pool(80); } void testAnnMG() From 96e69fc60db6c3848d77857eda9ff3c3a671ec00 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 15 Nov 2024 17:59:00 +0100 Subject: [PATCH 04/17] switch from RAFT_LOG_INFO to RAFT_LOG_DEBUG for mg logs --- cpp/src/neighbors/mg/mg.cuh | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 54ac32cc3b..14ffbce93c 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -95,7 +95,7 @@ void build(const raft::device_resources_snmg& clique, { if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); - RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); + RAFT_LOG_DEBUG("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for @@ -110,7 +110,7 @@ void build(const raft::device_resources_snmg& clique, int64_t n_cols = index_dataset.extent(1); int64_t n_rows_per_shard = raft::ceildiv(n_rows, (int64_t)index.num_ranks_); - RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard); + RAFT_LOG_DEBUG("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard); index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for @@ -136,7 +136,7 @@ void extend(const raft::device_resources_snmg& clique, { int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { - RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); + RAFT_LOG_DEBUG("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { @@ -149,7 +149,7 @@ void extend(const raft::device_resources_snmg& clique, int64_t n_cols = new_vectors.extent(1); int64_t n_rows_per_shard = raft::ceildiv(n_rows, (int64_t)index.num_ranks_); - RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", index.num_ranks_, n_rows_per_shard); + RAFT_LOG_DEBUG("SHARDED EXTEND: %d*%drows", index.num_ranks_, n_rows_per_shard); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { @@ -515,7 +515,7 @@ void search(const raft::device_resources_snmg& clique, int64_t n_batches = raft::ceildiv(n_rows, (int64_t)n_rows_per_batch); if (n_batches <= 1) n_rows_per_batch = n_rows; - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "REPLICATED SEARCH IN LOAD BALANCER MODE: %d*%drows", n_batches, n_rows_per_batch); #pragma omp parallel for @@ -540,7 +540,7 @@ void search(const raft::device_resources_snmg& clique, n_neighbors); } } else if (search_mode == ROUND_ROBIN) { - RAFT_LOG_INFO("REPLICATED SEARCH IN ROUND ROBIN MODE: %d*%drows", 1, n_rows); + RAFT_LOG_DEBUG("REPLICATED SEARCH IN ROUND ROBIN MODE: %d*%drows", 1, n_rows); ASSERT(n_rows <= n_rows_per_batch, "In round-robin mode, n_rows must lower or equal to n_rows_per_batch"); @@ -584,9 +584,9 @@ void search(const raft::device_resources_snmg& clique, if (n_batches <= 1) n_rows_per_batch = n_rows; if (merge_mode == MERGE_ON_ROOT_RANK) { - RAFT_LOG_INFO("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows", - n_batches, - n_rows_per_batch); + RAFT_LOG_DEBUG("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows", + n_batches, + n_rows_per_batch); sharded_search_with_direct_merge(clique, index, search_params, @@ -599,7 +599,7 @@ void search(const raft::device_resources_snmg& clique, n_neighbors, n_batches); } else if (merge_mode == TREE_MERGE) { - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "SHARDED SEARCH WITH TREE_MERGE MERGE MODE %d*%drows", n_batches, n_rows_per_batch); sharded_search_with_tree_merge(clique, index, From 657bf9e36c96bb578e514a605b2510a5a3b95f09 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 25 Nov 2024 15:05:06 +0100 Subject: [PATCH 05/17] clique as device_resource --- .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 3 +-- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 3 +-- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 3 +-- cpp/test/neighbors/mg.cuh | 24 +++++++++---------- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index 27a0fd7acc..6a6580f4f9 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -68,7 +68,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -86,7 +86,6 @@ class cuvs_mg_cagra : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; raft::device_resources_snmg clique_; float refine_ratio_; build_param index_params_; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 5e811da335..a2b91bc0ad 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -62,7 +62,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -73,7 +73,6 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index c4a820cad9..c2ce61cd86 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -62,7 +62,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -73,7 +73,6 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index f634765c95..853dc8c0ed 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -46,8 +46,7 @@ template class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() - : stream_(resource::get_cuda_stream(handle_)), - clique_(), + : stream_(resource::get_cuda_stream(clique_)), ps(::testing::TestWithParam::GetParam()), d_index_dataset(0, stream_), d_queries(0, stream_), @@ -82,7 +81,7 @@ class AnnMGTest : public ::testing::TestWithParam { ps.metric); update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_); update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); + resource::sync_stream(clique_); } int64_t n_rows_per_search_batch = 3000; // [3000, 3000, 1000] == 7000 rows @@ -132,7 +131,7 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -191,7 +190,7 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -244,7 +243,7 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits, @@ -297,7 +296,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -349,7 +348,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -397,7 +396,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits, @@ -622,16 +621,15 @@ class AnnMGTest : public ::testing::TestWithParam { raft::copy(h_index_dataset.data(), d_index_dataset.data(), d_index_dataset.size(), - resource::get_cuda_stream(handle_)); + resource::get_cuda_stream(clique_)); raft::copy( - h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(handle_)); - resource::sync_stream(handle_); + h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(clique_)); + resource::sync_stream(clique_); } void TearDown() override {} private: - raft::device_resources handle_; rmm::cuda_stream_view stream_; raft::device_resources_snmg clique_; AnnMGInputs ps; From 1fdccd4e5c46625e8f1f467052d7b2ab4b5556e7 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 26 Nov 2024 14:31:46 +0100 Subject: [PATCH 06/17] updating MG tests --- cpp/test/neighbors/mg.cuh | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index 853dc8c0ed..b4131acdb9 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -46,10 +46,10 @@ template class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() - : stream_(resource::get_cuda_stream(clique_)), + : clique_(), ps(::testing::TestWithParam::GetParam()), - d_index_dataset(0, stream_), - d_queries(0, stream_), + d_index_dataset(0, resource::get_cuda_stream(clique_)), + d_queries(0, resource::get_cuda_stream(clique_)), h_index_dataset(0), h_queries(0) { @@ -67,8 +67,9 @@ class AnnMGTest : public ::testing::TestWithParam { std::vector neighbors_snmg_ann_32bits(queries_size); { - rmm::device_uvector distances_ref_dev(queries_size, stream_); - rmm::device_uvector neighbors_ref_dev(queries_size, stream_); + rmm::device_uvector distances_ref_dev(queries_size, resource::get_cuda_stream(clique_)); + rmm::device_uvector neighbors_ref_dev(queries_size, + resource::get_cuda_stream(clique_)); cuvs::neighbors::naive_knn(clique_, distances_ref_dev.data(), neighbors_ref_dev.data(), @@ -79,8 +80,14 @@ class AnnMGTest : public ::testing::TestWithParam { ps.dim, ps.k, ps.metric); - update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_); - update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_); + update_host(distances_ref.data(), + distances_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); + update_host(neighbors_ref.data(), + neighbors_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); resource::sync_stream(clique_); } @@ -602,8 +609,8 @@ class AnnMGTest : public ::testing::TestWithParam { void SetUp() override { - d_index_dataset.resize(ps.num_db_vecs * ps.dim, stream_); - d_queries.resize(ps.num_queries * ps.dim, stream_); + d_index_dataset.resize(ps.num_db_vecs * ps.dim, resource::get_cuda_stream(clique_)); + d_queries.resize(ps.num_queries * ps.dim, resource::get_cuda_stream(clique_)); h_index_dataset.resize(ps.num_db_vecs * ps.dim); h_queries.resize(ps.num_queries * ps.dim); @@ -630,7 +637,6 @@ class AnnMGTest : public ::testing::TestWithParam { void TearDown() override {} private: - rmm::cuda_stream_view stream_; raft::device_resources_snmg clique_; AnnMGInputs ps; std::vector h_index_dataset; From 45a41faf2f1551ec2eccd5bfecf1b87c403a5ae8 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 20 Jan 2025 17:53:17 +0000 Subject: [PATCH 07/17] API unification (removal of mg namespace) --- cpp/CMakeLists.txt | 1 + .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 8 +- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 8 +- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 8 +- cpp/include/cuvs/neighbors/cagra.hpp | 462 ++++++ cpp/include/cuvs/neighbors/ivf_flat.hpp | 361 +++++ cpp/include/cuvs/neighbors/ivf_pq.hpp | 451 ++++++ cpp/include/cuvs/neighbors/mg.hpp | 1256 +---------------- cpp/src/neighbors/mg/generate_mg.py | 351 ++--- cpp/src/neighbors/mg/mg.cuh | 3 + .../neighbors/mg/mg_cagra_float_uint32_t.cu | 122 +- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 122 +- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 122 +- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 122 +- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 138 +- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 138 +- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 138 +- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 138 +- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 138 +- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 138 +- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 138 +- cpp/test/neighbors/mg.cuh | 132 +- 22 files changed, 2277 insertions(+), 2218 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 78862cb334..59eab62efa 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -580,6 +580,7 @@ if(BUILD_SHARED_LIBS) if(BUILD_MG_ALGOS) target_compile_definitions(cuvs PUBLIC CUVS_BUILD_MG_ALGOS) target_compile_definitions(cuvs_objs PUBLIC CUVS_BUILD_MG_ALGOS) + target_compile_definitions(cuvs-cagra-search PUBLIC CUVS_BUILD_MG_ALGOS) endif() if(BUILD_CAGRA_HNSWLIB) diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index 6a6580f4f9..6762287e1f 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -104,7 +104,7 @@ void cuvs_mg_cagra::build(const T* dataset, size_t nrow) auto dataset_view = raft::make_host_matrix_view(dataset, nrow, dim_); - auto idx = cuvs::neighbors::mg::build(clique_, build_params, dataset_view); + auto idx = cuvs::neighbors::cagra::build(clique_, build_params, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -131,7 +131,7 @@ void cuvs_mg_cagra::set_search_dataset(const T* dataset, size_t nrow) template void cuvs_mg_cagra::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(clique_, *index_, file); + cuvs::neighbors::cagra::serialize(clique_, *index_, file); } template @@ -139,7 +139,7 @@ void cuvs_mg_cagra::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_cagra(clique_, file))); + std::move(cuvs::neighbors::cagra::deserialize(clique_, file))); } template @@ -162,7 +162,7 @@ void cuvs_mg_cagra::search_base( auto distances_view = raft::make_host_matrix_view(distances, batch_size, k); - cuvs::neighbors::mg::search( + cuvs::neighbors::cagra::search( clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index a2b91bc0ad..de854b5084 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -85,7 +85,7 @@ void cuvs_mg_ivf_flat::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); + auto idx = cuvs::neighbors::ivf_flat::build(clique_, index_params_, dataset_view); index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>(std::move(idx)); } @@ -105,7 +105,7 @@ void cuvs_mg_ivf_flat::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_flat::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(clique_, *index_, file); + cuvs::neighbors::ivf_flat::serialize(clique_, *index_, file); } template @@ -113,7 +113,7 @@ void cuvs_mg_ivf_flat::load(const std::string& file) { index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_flat(clique_, file))); + std::move(cuvs::neighbors::ivf_flat::deserialize(clique_, file))); } template @@ -133,7 +133,7 @@ void cuvs_mg_ivf_flat::search( auto distances_view = raft::make_host_matrix_view( distances, IdxT(batch_size), IdxT(k)); - cuvs::neighbors::mg::search( + cuvs::neighbors::ivf_flat::search( clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index c2ce61cd86..9e6e20c988 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -84,7 +84,7 @@ void cuvs_mg_ivf_pq::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); + auto idx = cuvs::neighbors::ivf_pq::build(clique_, index_params_, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -104,7 +104,7 @@ void cuvs_mg_ivf_pq::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_pq::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(clique_, *index_, file); + cuvs::neighbors::ivf_pq::serialize(clique_, *index_, file); } template @@ -112,7 +112,7 @@ void cuvs_mg_ivf_pq::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_pq(clique_, file))); + std::move(cuvs::neighbors::ivf_pq::deserialize(clique_, file))); } template @@ -132,7 +132,7 @@ void cuvs_mg_ivf_pq::search( auto distances_view = raft::make_host_matrix_view( distances, IdxT(batch_size), IdxT(k)); - cuvs::neighbors::mg::search( + cuvs::neighbors::ivf_pq::search( clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index a4684ce267..f62fe7aea2 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -1752,4 +1753,465 @@ void serialize_to_hnswlib(raft::resources const& handle, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, float, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, half, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, int8_t, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, uint8_t, uint32_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, float, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, half, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, int8_t, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, uint32_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, half, uint32_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, uint32_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, half, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, + const std::string& filename); + +/// \defgroup mg_cpp_deserialize ANN MG index deserialization + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes a CAGRA multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::cagra::deserialize(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized CAGRA index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::cagra::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::cagra::distribute(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + } // namespace cuvs::neighbors::cagra diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index e017946d9c..a9cb02de61 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -19,6 +19,7 @@ #include "common.hpp" #include #include +#include #include #include @@ -1598,6 +1599,366 @@ void deserialize(raft::resources const& handle, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, float, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, int8_t, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, uint8_t, int64_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, float, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, int8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes an IVF-Flat multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::ivf_flat::deserialize(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized IVF-Flat index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::ivf_flat::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::ivf_flat::distribute(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + namespace helpers { /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index d85753b7f2..1be9a19240 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -1728,6 +1729,456 @@ void deserialize(raft::resources const& handle, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, float, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, half, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, int8_t, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, uint8_t, int64_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, float, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, half, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, int8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, half, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, half, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes an IVF-PQ multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::ivf_pq::deserialize(clique, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized IVF-PQ index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::ivf_pq::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::ivf_pq::distribute(clique, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; namespace helpers { /** * @defgroup ivf_pq_cpp_helpers IVF-PQ helper methods diff --git a/cpp/include/cuvs/neighbors/mg.hpp b/cpp/include/cuvs/neighbors/mg.hpp index 86572adebd..d2be229513 100644 --- a/cpp/include/cuvs/neighbors/mg.hpp +++ b/cpp/include/cuvs/neighbors/mg.hpp @@ -18,16 +18,8 @@ #ifdef CUVS_BUILD_MG_ALGOS -#include -#include - -#include -#include - -#include #include -#include -#include +#include #define DEFAULT_SEARCH_BATCH_SIZE 1 << 20 @@ -92,12 +84,6 @@ struct search_params : public Upstream { cuvs::neighbors::mg::sharded_merge_mode merge_mode = TREE_MERGE; }; -} // namespace cuvs::neighbors::mg - -namespace cuvs::neighbors::mg { - -using namespace raft; - template struct index { index(distribution_mode mode, int num_ranks_); @@ -116,1246 +102,6 @@ struct index { std::shared_ptr> round_robin_counter_; }; -/// \defgroup mg_cpp_index_build ANN MG index build - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, half, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, half, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, uint32_t>; - -/// \defgroup mg_cpp_index_extend ANN MG index extend - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, float, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, int8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, uint8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, float, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, half, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, int8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, uint8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, float, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, half, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, int8_t, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, uint8_t, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \defgroup mg_cpp_index_search ANN MG index search - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, float, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, int8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, uint8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, float, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, half, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, int8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, uint8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, float, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, half, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, int8_t, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, uint8_t, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \defgroup mg_cpp_serialize ANN MG index serialization - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, float, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, int8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, uint8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, float, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, half, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, int8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, uint8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, float, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, half, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, int8_t, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, uint8_t, uint32_t>& index, - const std::string& filename); - -/// \defgroup mg_cpp_deserialize ANN MG index deserialization - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes an IVF-Flat multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_flat(clique, filename); - * - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_flat(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes an IVF-PQ multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_pq(clique, filename); - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_pq(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes a CAGRA multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_cagra(clique, filename); - * - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_cagra(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \defgroup mg_cpp_distribute ANN MG local index distribution - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized IVF-Flat index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::ivf_flat::index_params index_params; - * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_flat::serialize(clique, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_flat(clique, filename); - * - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_flat(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized IVF-PQ index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::ivf_pq::index_params index_params; - * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_pq::serialize(clique, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_pq(clique, filename); - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_pq(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized CAGRA index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::cagra::index_params index_params; - * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::cagra::serialize(clique, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_cagra(clique, filename); - * - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_cagra(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - } // namespace cuvs::neighbors::mg #else diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index 26e81da169..d7089e56ce 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -43,181 +43,194 @@ #include "mg.cuh" """ -namespace_macro = """ -namespace cuvs::neighbors::mg { -""" - -footer = """ -} // namespace cuvs::neighbors::mg -""" - flat_macro = """ -#define CUVS_INST_MG_FLAT(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void extend(const raft::device_resources_snmg& clique, \\ - index, T, IdxT>& index, \\ - raft::host_matrix_view new_vectors, \\ - std::optional> new_indices) \\ - { \\ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ - } \\ - \\ - void search(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_flat(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(clique, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_flat(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \\ +namespace cuvs::neighbors::ivf_flat { \\ + using namespace cuvs::neighbors::mg; \\ + \\ + cuvs::neighbors::mg::index, T, IdxT> build( \\ + const raft::device_resources_snmg& clique, \\ + const mg::index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void extend(const raft::device_resources_snmg& clique, \\ + cuvs::neighbors::mg::index, T, IdxT>& index, \\ + raft::host_matrix_view new_vectors, \\ + std::optional> new_indices) \\ + { \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ + } \\ + \\ + void search(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const mg::search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances, \\ + int64_t n_rows_per_batch) \\ + { \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances, n_rows_per_batch); \\ + } \\ + \\ + void serialize(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::ivf_flat """ pq_macro = """ -#define CUVS_INST_MG_PQ(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void extend(const raft::device_resources_snmg& clique, \\ - index, T, IdxT>& index, \\ - raft::host_matrix_view new_vectors, \\ - std::optional> new_indices) \\ - { \\ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ - } \\ - \\ - void search(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_pq(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(clique, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_pq(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_PQ(T, IdxT) \\ +namespace cuvs::neighbors::ivf_pq { \\ + using namespace cuvs::neighbors::mg; \\ + \\ + cuvs::neighbors::mg::index, T, IdxT> build( \\ + const raft::device_resources_snmg& clique, \\ + const mg::index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void extend(const raft::device_resources_snmg& clique, \\ + cuvs::neighbors::mg::index, T, IdxT>& index, \\ + raft::host_matrix_view new_vectors, \\ + std::optional> new_indices) \\ + { \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ + } \\ + \\ + void search(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const mg::search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances, \\ + int64_t n_rows_per_batch) \\ + { \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances, n_rows_per_batch); \\ + } \\ + \\ + void serialize(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::ivf_pq """ cagra_macro = """ -#define CUVS_INST_MG_CAGRA(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void search(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_cagra(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(clique, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_cagra(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \\ +namespace cuvs::neighbors::cagra { \\ + using namespace cuvs::neighbors::mg; \\ + \\ + cuvs::neighbors::mg::index, T, IdxT> build( \\ + const raft::device_resources_snmg& clique, \\ + const mg::index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void search(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const mg::search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances, \\ + int64_t n_rows_per_batch) \\ + { \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances, n_rows_per_batch); \\ + } \\ + \\ + void serialize(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::cagra """ flat_macros = dict ( @@ -271,10 +284,8 @@ with open(path, "w") as f: f.write(header) f.write(macro['include']) - f.write(namespace_macro) f.write(macro["definition"]) f.write(f"{macro['name']}({T}, {IdxT});\n\n") f.write(f"#undef {macro['name']}\n") - f.write(footer) print(f"src/neighbors/mg/{path}") diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 14ffbce93c..f36ef8fa31 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -21,7 +21,10 @@ #include #include +#include #include +#include +#include #include #include diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index e179a56e38..fbba29b25c 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -25,66 +25,68 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(float, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 3e369d9ac6..4633cc77cf 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -25,66 +25,68 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(half, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 5ebf223d12..4c15e09f5b 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -25,66 +25,68 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(int8_t, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index 923031b1c3..8c585c2993 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -25,66 +25,68 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(uint8_t, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index f90f6fcfbc..ef84cff6c3 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ + REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(float, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 2eefad5d57..6e6daace7c 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ + REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(int8_t, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index 9684f19d8a..ab0f4fb10c 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ + REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(uint8_t, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index c71133ac45..3b5268c699 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(float, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index df148620fc..e6d18bd304 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(half, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index afe5faa41d..ead2186778 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(int8_t, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index c725d21398..27d36c34be 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(uint8_t, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index b4131acdb9..32532df3f0 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -19,7 +19,9 @@ #include "ann_utils.cuh" #include "naive_knn.cuh" -#include +#include +#include +#include #include namespace cuvs::neighbors::mg { @@ -125,18 +127,18 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_flat_index"); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_flat::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::ivf_flat::serialize(clique_, index, "mg_ivf_flat_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_flat(clique_, "mg_ivf_flat_index"); + cuvs::neighbors::ivf_flat::deserialize(clique_, "mg_ivf_flat_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( + cuvs::neighbors::ivf_flat::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(clique_); @@ -184,18 +186,18 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_pq_index"); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_pq::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::ivf_pq::serialize(clique_, index, "mg_ivf_pq_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_pq(clique_, "mg_ivf_pq_index"); + cuvs::neighbors::ivf_pq::deserialize(clique_, "mg_ivf_pq_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( + cuvs::neighbors::ivf_pq::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(clique_); @@ -238,17 +240,17 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::serialize(clique_, index, "mg_cagra_index"); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); + cuvs::neighbors::cagra::serialize(clique_, index, "mg_cagra_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_cagra(clique_, "mg_cagra_index"); + cuvs::neighbors::cagra::deserialize(clique_, "mg_cagra_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( + cuvs::neighbors::cagra::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(clique_); @@ -293,15 +295,15 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_flat(clique_, "local_ivf_flat_index"); + cuvs::neighbors::ivf_flat::distribute(clique_, "local_ivf_flat_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + cuvs::neighbors::ivf_flat::search(clique_, + distributed_index, + search_params, + queries, + neighbors, + distances, + n_rows_per_search_batch); resource::sync_stream(clique_); @@ -345,15 +347,15 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_pq(clique_, "local_ivf_pq_index"); + cuvs::neighbors::ivf_pq::distribute(clique_, "local_ivf_pq_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + cuvs::neighbors::ivf_pq::search(clique_, + distributed_index, + search_params, + queries, + neighbors, + distances, + n_rows_per_search_batch); resource::sync_stream(clique_); @@ -392,16 +394,16 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_cagra(clique_, "local_cagra_index"); + cuvs::neighbors::cagra::distribute(clique_, "local_cagra_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + cuvs::neighbors::cagra::search(clique_, + distributed_index, + search_params, + queries, + neighbors, + distances, + n_rows_per_search_batch); resource::sync_stream(clique_); @@ -439,8 +441,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_flat::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -455,13 +457,13 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(clique_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + cuvs::neighbors::ivf_flat::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances, + n_rows_per_search_batch); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -503,8 +505,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_pq::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -519,13 +521,13 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(clique_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + cuvs::neighbors::ivf_pq::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances, + n_rows_per_search_batch); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -563,7 +565,7 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -578,13 +580,13 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(clique_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + cuvs::neighbors::cagra::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances, + n_rows_per_search_batch); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), From db12ab92d7d2c538cee8b15a8e142f0d443fdff8 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 20 Jan 2025 18:43:05 +0000 Subject: [PATCH 08/17] Adding CUVS_BUILD_MG_ALGOS macro back --- cpp/include/cuvs/neighbors/cagra.hpp | 9 ++++++++- cpp/include/cuvs/neighbors/ivf_flat.hpp | 9 ++++++++- cpp/include/cuvs/neighbors/ivf_pq.hpp | 10 +++++++++- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index f62fe7aea2..7b311a1b55 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -32,6 +31,10 @@ #include #include +#ifdef CUVS_BUILD_MG_ALGOS +#include +#endif + #include #include @@ -1753,6 +1756,8 @@ void serialize_to_hnswlib(raft::resources const& handle, * @} */ +#ifdef CUVS_BUILD_MG_ALGOS + /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -2214,4 +2219,6 @@ template auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; +#endif + } // namespace cuvs::neighbors::cagra diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index a9cb02de61..df77542c6e 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -19,10 +19,13 @@ #include "common.hpp" #include #include -#include #include #include +#ifdef CUVS_BUILD_MG_ALGOS +#include +#endif + namespace cuvs::neighbors::ivf_flat { /** * @defgroup ivf_flat_cpp_index_params IVF-Flat index build parameters @@ -1599,6 +1602,8 @@ void deserialize(raft::resources const& handle, * @} */ +#ifdef CUVS_BUILD_MG_ALGOS + /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -1959,6 +1964,8 @@ template auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; +#endif + namespace helpers { /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 1be9a19240..1b4fc87fe1 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -19,7 +19,6 @@ #include #include -#include #include #include @@ -28,6 +27,10 @@ #include #include +#ifdef CUVS_BUILD_MG_ALGOS +#include +#endif + namespace cuvs::neighbors::ivf_pq { /** @@ -1729,6 +1732,8 @@ void deserialize(raft::resources const& handle, * @} */ +#ifdef CUVS_BUILD_MG_ALGOS + /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -2179,6 +2184,9 @@ auto deserialize(const raft::device_resources_snmg& clique, const std::string& f template auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; + +#endif + namespace helpers { /** * @defgroup ivf_pq_cpp_helpers IVF-PQ helper methods From db4cf11b85fbc43d4128e3f93fa6118d4ab86d74 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 29 Jan 2025 16:37:20 +0000 Subject: [PATCH 09/17] Main requested changes --- cpp/include/cuvs/neighbors/cagra.hpp | 52 +++---- cpp/include/cuvs/neighbors/ivf_flat.hpp | 40 +++--- cpp/include/cuvs/neighbors/ivf_pq.hpp | 52 +++---- cpp/include/cuvs/neighbors/mg.hpp | 9 +- cpp/src/neighbors/mg/generate_mg.py | 95 ++++++------- cpp/src/neighbors/mg/mg.cuh | 77 ++++++---- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 118 ++++++++------- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 118 ++++++++------- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 118 ++++++++------- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 118 ++++++++------- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 36 +++-- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 36 +++-- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 36 +++-- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 134 +++++++++--------- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 134 +++++++++--------- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 134 +++++++++--------- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 134 +++++++++--------- cpp/test/neighbors/mg.cuh | 60 ++++---- 18 files changed, 729 insertions(+), 772 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 7b311a1b55..30f36c8bcc 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -1777,7 +1777,7 @@ void serialize_to_hnswlib(raft::resources const& handle, * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, float, uint32_t>; @@ -1799,7 +1799,7 @@ auto build(const raft::device_resources_snmg& clique, * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, half, uint32_t>; @@ -1821,7 +1821,7 @@ auto build(const raft::device_resources_snmg& clique, * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, int8_t, uint32_t>; @@ -1843,7 +1843,7 @@ auto build(const raft::device_resources_snmg& clique, * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, uint8_t, uint32_t>; @@ -1869,7 +1869,7 @@ auto build(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, float, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1893,7 +1893,7 @@ void extend(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, half, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1917,7 +1917,7 @@ void extend(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, int8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1941,7 +1941,7 @@ void extend(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1968,16 +1968,14 @@ void extend(const raft::device_resources_snmg& clique, * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources_snmg& clique, +void search(const raft::resources& clique, const cuvs::neighbors::mg::index, float, uint32_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -1999,16 +1997,14 @@ void search(const raft::device_resources_snmg& clique, * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources_snmg& clique, +void search(const raft::resources& clique, const cuvs::neighbors::mg::index, half, uint32_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -2030,17 +2026,15 @@ void search(const raft::device_resources_snmg& clique, * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ void search( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, int8_t, uint32_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -2062,17 +2056,15 @@ void search( * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ void search( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \defgroup mg_cpp_serialize ANN MG index serialization @@ -2095,7 +2087,7 @@ void search( * */ void serialize( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, float, uint32_t>& index, const std::string& filename); @@ -2118,7 +2110,7 @@ void serialize( * */ void serialize( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, half, uint32_t>& index, const std::string& filename); @@ -2141,7 +2133,7 @@ void serialize( * */ void serialize( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, int8_t, uint32_t>& index, const std::string& filename); @@ -2164,7 +2156,7 @@ void serialize( * */ void serialize( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, const std::string& filename); @@ -2190,7 +2182,7 @@ void serialize( * */ template -auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) +auto deserialize(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -2216,7 +2208,7 @@ auto deserialize(const raft::device_resources_snmg& clique, const std::string& f * */ template -auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) +auto distribute(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; #endif diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index df77542c6e..e0f8e572cd 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -1623,7 +1623,7 @@ void deserialize(raft::resources const& handle, * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, float, int64_t>; @@ -1645,7 +1645,7 @@ auto build(const raft::device_resources_snmg& clique, * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, int8_t, int64_t>; @@ -1667,7 +1667,7 @@ auto build(const raft::device_resources_snmg& clique, * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, uint8_t, int64_t>; @@ -1693,7 +1693,7 @@ auto build(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1717,7 +1717,7 @@ void extend(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1741,7 +1741,7 @@ void extend(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1768,17 +1768,15 @@ void extend(const raft::device_resources_snmg& clique, * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ void search( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, float, int64_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -1800,17 +1798,15 @@ void search( * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ void search( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, int8_t, int64_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -1832,17 +1828,15 @@ void search( * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ void search( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \defgroup mg_cpp_serialize ANN MG index serialization @@ -1865,7 +1859,7 @@ void search( * */ void serialize( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, float, int64_t>& index, const std::string& filename); @@ -1888,7 +1882,7 @@ void serialize( * */ void serialize( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, int8_t, int64_t>& index, const std::string& filename); @@ -1911,7 +1905,7 @@ void serialize( * */ void serialize( - const raft::device_resources_snmg& clique, + const raft::resources& clique, const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, const std::string& filename); @@ -1935,7 +1929,7 @@ void serialize( * */ template -auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) +auto deserialize(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -1961,7 +1955,7 @@ auto deserialize(const raft::device_resources_snmg& clique, const std::string& f * */ template -auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) +auto distribute(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; #endif diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 1b4fc87fe1..8396baeb73 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -1753,7 +1753,7 @@ void deserialize(raft::resources const& handle, * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, float, int64_t>; @@ -1775,7 +1775,7 @@ auto build(const raft::device_resources_snmg& clique, * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, half, int64_t>; @@ -1797,7 +1797,7 @@ auto build(const raft::device_resources_snmg& clique, * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, int8_t, int64_t>; @@ -1819,7 +1819,7 @@ auto build(const raft::device_resources_snmg& clique, * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources_snmg& clique, +auto build(const raft::resources& clique, const cuvs::neighbors::mg::index_params& index_params, raft::host_matrix_view index_dataset) -> cuvs::neighbors::mg::index, uint8_t, int64_t>; @@ -1845,7 +1845,7 @@ auto build(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1869,7 +1869,7 @@ void extend(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, half, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1893,7 +1893,7 @@ void extend(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1917,7 +1917,7 @@ void extend(const raft::device_resources_snmg& clique, * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& clique, cuvs::neighbors::mg::index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1944,16 +1944,14 @@ void extend(const raft::device_resources_snmg& clique, * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources_snmg& clique, +void search(const raft::resources& clique, const cuvs::neighbors::mg::index, float, int64_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -1975,16 +1973,14 @@ void search(const raft::device_resources_snmg& clique, * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources_snmg& clique, +void search(const raft::resources& clique, const cuvs::neighbors::mg::index, half, int64_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -2006,16 +2002,14 @@ void search(const raft::device_resources_snmg& clique, * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources_snmg& clique, +void search(const raft::resources& clique, const cuvs::neighbors::mg::index, int8_t, int64_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -2037,16 +2031,14 @@ void search(const raft::device_resources_snmg& clique, * @param[in] queries a row-major matrix on host [n_rows, dim] * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources_snmg& clique, +void search(const raft::resources& clique, const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, const cuvs::neighbors::mg::search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + raft::host_matrix_view distances); /// \defgroup mg_cpp_serialize ANN MG index serialization @@ -2068,7 +2060,7 @@ void search(const raft::device_resources_snmg& clique, * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources_snmg& clique, +void serialize(const raft::resources& clique, const cuvs::neighbors::mg::index, float, int64_t>& index, const std::string& filename); @@ -2090,7 +2082,7 @@ void serialize(const raft::device_resources_snmg& clique, * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources_snmg& clique, +void serialize(const raft::resources& clique, const cuvs::neighbors::mg::index, half, int64_t>& index, const std::string& filename); @@ -2112,7 +2104,7 @@ void serialize(const raft::device_resources_snmg& clique, * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources_snmg& clique, +void serialize(const raft::resources& clique, const cuvs::neighbors::mg::index, int8_t, int64_t>& index, const std::string& filename); @@ -2134,7 +2126,7 @@ void serialize(const raft::device_resources_snmg& clique, * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources_snmg& clique, +void serialize(const raft::resources& clique, const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, const std::string& filename); @@ -2157,7 +2149,7 @@ void serialize(const raft::device_resources_snmg& clique, * */ template -auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) +auto deserialize(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -2182,7 +2174,7 @@ auto deserialize(const raft::device_resources_snmg& clique, const std::string& f * */ template -auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) +auto distribute(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; #endif diff --git a/cpp/include/cuvs/neighbors/mg.hpp b/cpp/include/cuvs/neighbors/mg.hpp index d2be229513..685185a644 100644 --- a/cpp/include/cuvs/neighbors/mg.hpp +++ b/cpp/include/cuvs/neighbors/mg.hpp @@ -19,9 +19,6 @@ #ifdef CUVS_BUILD_MG_ALGOS #include -#include - -#define DEFAULT_SEARCH_BATCH_SIZE 1 << 20 /// \defgroup mg_cpp_index_params ANN MG index build parameters @@ -82,12 +79,14 @@ struct search_params : public Upstream { cuvs::neighbors::mg::replicated_search_mode search_mode = LOAD_BALANCER; /** Sharded merge mode */ cuvs::neighbors::mg::sharded_merge_mode merge_mode = TREE_MERGE; + /** Number of rows per batch */ + int64_t n_rows_per_batch = 1 << 20; }; template struct index { - index(distribution_mode mode, int num_ranks_); - index(const raft::device_resources_snmg& clique, const std::string& filename); + index(const raft::resources& clique, distribution_mode mode); + index(const raft::resources& clique, const std::string& filename); index(const index&) = delete; index(index&&) = default; diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index d7089e56ce..20b570a7de 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -49,61 +49,60 @@ using namespace cuvs::neighbors::mg; \\ \\ cuvs::neighbors::mg::index, T, IdxT> build( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::mg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void extend(const raft::device_resources_snmg& clique, \\ + void extend(const raft::resources& res, \\ cuvs::neighbors::mg::index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \\ } \\ \\ - void search(const raft::device_resources_snmg& clique, \\ + void search(const raft::resources& res, \\ const cuvs::neighbors::mg::index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ + raft::host_matrix_view distances) \\ { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ + cuvs::neighbors::mg::detail::search(res, index, \\ static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ + queries, neighbors, distances); \\ } \\ \\ - void serialize(const raft::device_resources_snmg& clique, \\ + void serialize(const raft::resources& res, \\ const cuvs::neighbors::mg::index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \\ return idx; \\ } \\ \\ template<> \\ cuvs::neighbors::mg::index, T, IdxT> distribute( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ } // namespace cuvs::neighbors::ivf_flat @@ -115,61 +114,60 @@ using namespace cuvs::neighbors::mg; \\ \\ cuvs::neighbors::mg::index, T, IdxT> build( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::mg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void extend(const raft::device_resources_snmg& clique, \\ + void extend(const raft::resources& res, \\ cuvs::neighbors::mg::index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \\ } \\ \\ - void search(const raft::device_resources_snmg& clique, \\ + void search(const raft::resources& res, \\ const cuvs::neighbors::mg::index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ + raft::host_matrix_view distances) \\ { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ + cuvs::neighbors::mg::detail::search(res, index, \\ static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ + queries, neighbors, distances); \\ } \\ \\ - void serialize(const raft::device_resources_snmg& clique, \\ + void serialize(const raft::resources& res, \\ const cuvs::neighbors::mg::index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \\ return idx; \\ } \\ \\ template<> \\ cuvs::neighbors::mg::index, T, IdxT> distribute( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ } // namespace cuvs::neighbors::ivf_pq @@ -181,53 +179,52 @@ using namespace cuvs::neighbors::mg; \\ \\ cuvs::neighbors::mg::index, T, IdxT> build( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::mg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void search(const raft::device_resources_snmg& clique, \\ + void search(const raft::resources& res, \\ const cuvs::neighbors::mg::index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ + raft::host_matrix_view distances) \\ { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ + cuvs::neighbors::mg::detail::search(res, index, \\ static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ + queries, neighbors, distances); \\ } \\ \\ - void serialize(const raft::device_resources_snmg& clique, \\ + void serialize(const raft::resources& res, \\ const cuvs::neighbors::mg::index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \\ return idx; \\ } \\ \\ template<> \\ cuvs::neighbors::mg::index, T, IdxT> distribute( \\ - const raft::device_resources_snmg& clique, \\ + const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index f36ef8fa31..616e779f14 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -17,6 +17,7 @@ #pragma once #include "../detail/knn_merge_parts.cuh" +#include #include #include #include @@ -32,6 +33,18 @@ namespace cuvs::neighbors { using namespace raft; +inline const raft::device_resources_snmg& get_snmng_clique(const raft::resources& res) +{ + try { + const raft::device_resources_snmg& clique = + dynamic_cast(res); + return clique; + } catch (const std::bad_cast& e) { + throw std::runtime_error( + "MG function was not used with an appropriate raft::device_resources_snmg object"); + } +} + template void search(const raft::device_resources& handle, const cuvs::neighbors::iface& interface, @@ -51,10 +64,11 @@ using namespace raft; // local index deserialization and distribution template -void deserialize_and_distribute(const raft::device_resources_snmg& clique, +void deserialize_and_distribute(const raft::resources& res, index& index, const std::string& filename) { + const auto& clique = get_snmng_clique(res); for (int rank = 0; rank < index.num_ranks_; rank++) { const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_.emplace_back(); @@ -64,10 +78,11 @@ void deserialize_and_distribute(const raft::device_resources_snmg& clique, // MG index deserialization template -void deserialize(const raft::device_resources_snmg& clique, +void deserialize(const raft::resources& res, index& index, const std::string& filename) { + const auto& clique = get_snmng_clique(res); std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } @@ -91,11 +106,12 @@ void deserialize(const raft::device_resources_snmg& clique, } template -void build(const raft::device_resources_snmg& clique, +void build(const raft::resources& res, index& index, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { + const auto& clique = get_snmng_clique(res); if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); RAFT_LOG_DEBUG("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); @@ -132,12 +148,13 @@ void build(const raft::device_resources_snmg& clique, } template -void extend(const raft::device_resources_snmg& clique, +void extend(const raft::resources& res, index& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - int64_t n_rows = new_vectors.extent(0); + const auto& clique = get_snmng_clique(res); + int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { RAFT_LOG_DEBUG("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); @@ -177,7 +194,7 @@ void extend(const raft::device_resources_snmg& clique, } template -void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, +void sharded_search_with_direct_merge(const raft::resources& res, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -189,6 +206,7 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, int64_t n_neighbors, int64_t n_batches) { + const auto& clique = get_snmng_clique(res); const auto& root_handle = clique.set_current_device_to_root_rank(); auto in_neighbors = raft::make_device_matrix( root_handle, index.num_ranks_ * n_rows_per_batch, n_neighbors); @@ -309,7 +327,7 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, } template -void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, +void sharded_search_with_tree_merge(const raft::resources& res, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -321,6 +339,7 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, int64_t n_neighbors, int64_t n_batches) { + const auto& clique = get_snmng_clique(res); for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { int64_t offset = batch_idx * n_rows_per_batch; int64_t query_offset = offset * n_cols; @@ -442,7 +461,7 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, } template -void run_search_batch(const raft::device_resources_snmg& clique, +void run_search_batch(const raft::resources& res, const index& index, int rank, const cuvs::neighbors::search_params* search_params, @@ -455,6 +474,7 @@ void run_search_batch(const raft::device_resources_snmg& clique, int64_t n_cols, int64_t n_neighbors) { + const auto& clique = get_snmng_clique(res); const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; @@ -481,34 +501,38 @@ void run_search_batch(const raft::device_resources_snmg& clique, } template -void search(const raft::device_resources_snmg& clique, +void search(const raft::resources& res, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch) + raft::host_matrix_view distances) { + const auto& clique = get_snmng_clique(res); int64_t n_rows = queries.extent(0); int64_t n_cols = queries.extent(1); int64_t n_neighbors = neighbors.extent(1); + int64_t n_rows_per_batch = -1; if (index.mode_ == REPLICATED) { cuvs::neighbors::mg::replicated_search_mode search_mode; if constexpr (std::is_same>::value) { const cuvs::neighbors::mg::search_params* mg_search_params = static_cast*>( search_params); - search_mode = mg_search_params->search_mode; + search_mode = mg_search_params->search_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { const cuvs::neighbors::mg::search_params* mg_search_params = static_cast*>( search_params); - search_mode = mg_search_params->search_mode; + search_mode = mg_search_params->search_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { const cuvs::neighbors::mg::search_params* mg_search_params = static_cast*>(search_params); - search_mode = mg_search_params->search_mode; + search_mode = mg_search_params->search_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } if (search_mode == LOAD_BALANCER) { @@ -571,16 +595,19 @@ void search(const raft::device_resources_snmg& clique, const cuvs::neighbors::mg::search_params* mg_search_params = static_cast*>( search_params); - merge_mode = mg_search_params->merge_mode; + merge_mode = mg_search_params->merge_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { const cuvs::neighbors::mg::search_params* mg_search_params = static_cast*>( search_params); - merge_mode = mg_search_params->merge_mode; + merge_mode = mg_search_params->merge_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { const cuvs::neighbors::mg::search_params* mg_search_params = static_cast*>(search_params); - merge_mode = mg_search_params->merge_mode; + merge_mode = mg_search_params->merge_mode; + n_rows_per_batch = mg_search_params->n_rows_per_batch; } int64_t n_batches = raft::ceildiv(n_rows, (int64_t)n_rows_per_batch); @@ -620,10 +647,11 @@ void search(const raft::device_resources_snmg& clique, } template -void serialize(const raft::device_resources_snmg& clique, +void serialize(const raft::resources& res, const index& index, const std::string& filename) { + const auto& clique = get_snmng_clique(res); std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } @@ -648,18 +676,17 @@ using namespace cuvs::neighbors; using namespace raft; template -index::index(distribution_mode mode, int num_ranks_) - : mode_(mode), - num_ranks_(num_ranks_), - round_robin_counter_(std::make_shared>(0)) +index::index(const raft::resources& res, distribution_mode mode) + : mode_(mode), round_robin_counter_(std::make_shared>(0)) { + const auto& clique = get_snmng_clique(res); + num_ranks_ = clique.get_num_ranks(); } template -index::index(const raft::device_resources_snmg& clique, - const std::string& filename) +index::index(const raft::resources& res, const std::string& filename) : round_robin_counter_(std::make_shared>(0)) { - cuvs::neighbors::mg::detail::deserialize(clique, *this, filename); + cuvs::neighbors::mg::detail::deserialize(res, *this, filename); } } // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index fbba29b25c..ef2ae66d51 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -25,67 +25,63 @@ #include "mg.cuh" -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - namespace cuvs::neighbors::cagra { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::resources& res, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 4633cc77cf..16c8c436ba 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -25,67 +25,63 @@ #include "mg.cuh" -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - namespace cuvs::neighbors::cagra { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::resources& res, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 4c15e09f5b..dae594c30a 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -25,67 +25,63 @@ #include "mg.cuh" -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - namespace cuvs::neighbors::cagra { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::resources& res, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index 8c585c2993..2f486f5d24 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -25,67 +25,63 @@ #include "mg.cuh" -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - namespace cuvs::neighbors::cagra { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::resources& res, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index ef84cff6c3..3f96e68677 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -30,68 +30,64 @@ using namespace cuvs::neighbors::mg; \ \ cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ + const raft::resources& res, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ cuvs::neighbors::mg::detail::build( \ - clique, \ + res, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources_snmg& clique, \ + void extend(const raft::resources& res, \ cuvs::neighbors::mg::index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources_snmg& clique, \ + void search(const raft::resources& res, \ const cuvs::neighbors::mg::index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ + raft::host_matrix_view distances) \ { \ cuvs::neighbors::mg::detail::search( \ - clique, \ + res, \ index, \ static_cast(&search_params), \ queries, \ neighbors, \ - distances, \ - n_rows_per_batch); \ + distances); \ } \ \ - void serialize(const raft::device_resources_snmg& clique, \ + void serialize(const raft::resources& res, \ const cuvs::neighbors::mg::index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ } \ \ template <> \ cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ + const raft::resources& res, const std::string& filename) \ { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ return idx; \ } \ \ template <> \ cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ + const raft::resources& res, const std::string& filename) \ { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ - REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 6e6daace7c..108c1ce478 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -30,68 +30,64 @@ using namespace cuvs::neighbors::mg; \ \ cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ + const raft::resources& res, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ cuvs::neighbors::mg::detail::build( \ - clique, \ + res, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources_snmg& clique, \ + void extend(const raft::resources& res, \ cuvs::neighbors::mg::index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources_snmg& clique, \ + void search(const raft::resources& res, \ const cuvs::neighbors::mg::index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ + raft::host_matrix_view distances) \ { \ cuvs::neighbors::mg::detail::search( \ - clique, \ + res, \ index, \ static_cast(&search_params), \ queries, \ neighbors, \ - distances, \ - n_rows_per_batch); \ + distances); \ } \ \ - void serialize(const raft::device_resources_snmg& clique, \ + void serialize(const raft::resources& res, \ const cuvs::neighbors::mg::index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ } \ \ template <> \ cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ + const raft::resources& res, const std::string& filename) \ { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ return idx; \ } \ \ template <> \ cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ + const raft::resources& res, const std::string& filename) \ { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ - REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index ab0f4fb10c..d180bbf95c 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -30,68 +30,64 @@ using namespace cuvs::neighbors::mg; \ \ cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ + const raft::resources& res, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ cuvs::neighbors::mg::detail::build( \ - clique, \ + res, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources_snmg& clique, \ + void extend(const raft::resources& res, \ cuvs::neighbors::mg::index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources_snmg& clique, \ + void search(const raft::resources& res, \ const cuvs::neighbors::mg::index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ + raft::host_matrix_view distances) \ { \ cuvs::neighbors::mg::detail::search( \ - clique, \ + res, \ index, \ static_cast(&search_params), \ queries, \ neighbors, \ - distances, \ - n_rows_per_batch); \ + distances); \ } \ \ - void serialize(const raft::device_resources_snmg& clique, \ + void serialize(const raft::resources& res, \ const cuvs::neighbors::mg::index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ } \ \ template <> \ cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ + const raft::resources& res, const std::string& filename) \ { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ return idx; \ } \ \ template <> \ cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ + const raft::resources& res, const std::string& filename) \ { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ - REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index 3b5268c699..9dee9c7762 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -25,75 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_PQ(T, IdxT) \ - namespace cuvs::neighbors::ivf_pq { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::resources& res, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index e6d18bd304..eb9a2b834e 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -25,75 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_PQ(T, IdxT) \ - namespace cuvs::neighbors::ivf_pq { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::resources& res, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index ead2186778..3bc08bd984 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -25,75 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_PQ(T, IdxT) \ - namespace cuvs::neighbors::ivf_pq { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::resources& res, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 27d36c34be..5ce7e61f58 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -25,75 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_PQ(T, IdxT) \ - namespace cuvs::neighbors::ivf_pq { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ - clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::resources& res, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index 32532df3f0..f632f57986 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -138,8 +138,10 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; + + search_params.n_rows_per_batch = n_rows_per_search_batch; cuvs::neighbors::ivf_flat::search( - clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances); resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -197,8 +199,10 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; + + search_params.n_rows_per_batch = n_rows_per_search_batch; cuvs::neighbors::ivf_pq::search( - clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances); resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -250,8 +254,10 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; + + search_params.n_rows_per_batch = n_rows_per_search_batch; cuvs::neighbors::cagra::search( - clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances); resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -297,13 +303,10 @@ class AnnMGTest : public ::testing::TestWithParam { auto distributed_index = cuvs::neighbors::ivf_flat::distribute(clique_, "local_ivf_flat_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::ivf_flat::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::ivf_flat::search( + clique_, distributed_index, search_params, queries, neighbors, distances); resource::sync_stream(clique_); @@ -349,13 +352,10 @@ class AnnMGTest : public ::testing::TestWithParam { auto distributed_index = cuvs::neighbors::ivf_pq::distribute(clique_, "local_ivf_pq_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::ivf_pq::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::ivf_pq::search( + clique_, distributed_index, search_params, queries, neighbors, distances); resource::sync_stream(clique_); @@ -397,13 +397,10 @@ class AnnMGTest : public ::testing::TestWithParam { cuvs::neighbors::cagra::distribute(clique_, "local_cagra_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::cagra::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + + search_params.n_rows_per_batch = n_rows_per_search_batch; + cuvs::neighbors::cagra::search( + clique_, distributed_index, search_params, queries, neighbors, distances); resource::sync_stream(clique_); @@ -457,13 +454,14 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); + + search_params.n_rows_per_batch = n_rows_per_search_batch; cuvs::neighbors::ivf_flat::search(clique_, index, search_params, small_batch_query, small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + small_batch_distances); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -521,13 +519,14 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); + + search_params.n_rows_per_batch = n_rows_per_search_batch; cuvs::neighbors::ivf_pq::search(clique_, index, search_params, small_batch_query, small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + small_batch_distances); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -580,13 +579,14 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); + + search_params.n_rows_per_batch = n_rows_per_search_batch; cuvs::neighbors::cagra::search(clique_, index, search_params, small_batch_query, small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + small_batch_distances); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), From ddaeee9908315460cdc18ae3dbad4dac1d0d282e Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 28 Feb 2025 15:40:13 +0100 Subject: [PATCH 10/17] Updating MG API --- cpp/include/cuvs/neighbors/cagra.hpp | 72 +++---- cpp/include/cuvs/neighbors/common.hpp | 12 +- cpp/include/cuvs/neighbors/ivf_flat.hpp | 56 +++--- cpp/include/cuvs/neighbors/ivf_pq.hpp | 72 +++---- cpp/src/neighbors/iface/generate_iface.py | 54 +++--- cpp/src/neighbors/iface/iface.hpp | 14 +- .../iface/iface_cagra_float_uint32_t.cu | 18 +- .../iface/iface_cagra_half_uint32_t.cu | 18 +- .../iface/iface_cagra_int8_t_uint32_t.cu | 18 +- .../iface/iface_cagra_uint8_t_uint32_t.cu | 18 +- .../iface/iface_flat_float_int64_t.cu | 18 +- .../iface/iface_flat_int8_t_int64_t.cu | 18 +- .../iface/iface_flat_uint8_t_int64_t.cu | 18 +- .../neighbors/iface/iface_pq_float_int64_t.cu | 18 +- .../neighbors/iface/iface_pq_half_int64_t.cu | 18 +- .../iface/iface_pq_int8_t_int64_t.cu | 18 +- .../iface/iface_pq_uint8_t_int64_t.cu | 18 +- cpp/src/neighbors/mg/mg.cuh | 182 ++++++++---------- 18 files changed, 319 insertions(+), 341 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 6c32a29037..b6ea310bce 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -2003,12 +2003,12 @@ auto merge(raft::resources const& res, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -2025,12 +2025,12 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -2047,12 +2047,12 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -2069,12 +2069,12 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -2093,13 +2093,13 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -2117,13 +2117,13 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -2141,13 +2141,13 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -2165,13 +2165,13 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -2191,7 +2191,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -2199,7 +2199,7 @@ void extend(const raft::resources& clique, * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -2220,7 +2220,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -2228,7 +2228,7 @@ void search(const raft::resources& clique, * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -2249,7 +2249,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -2257,7 +2257,7 @@ void search(const raft::resources& clique, * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -2279,7 +2279,7 @@ void search( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -2287,7 +2287,7 @@ void search( * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -2311,14 +2311,14 @@ void search( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -2334,14 +2334,14 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -2357,14 +2357,14 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -2380,14 +2380,14 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -2405,7 +2405,7 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2414,7 +2414,7 @@ void serialize( * * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] filename path to the file to be deserialized * */ @@ -2431,7 +2431,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::cagra::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; @@ -2440,7 +2440,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] filename path to the file to be deserialized : a local index * */ diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 038b6b1da5..2e275242b6 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -756,21 +756,21 @@ struct iface { }; template -void build(const raft::device_resources& handle, +void build(const raft::resources& handle, cuvs::neighbors::iface& interface, const cuvs::neighbors::index_params* index_params, raft::mdspan, row_major, Accessor> index_dataset); template void extend( - const raft::device_resources& handle, + const raft::resources& handle, cuvs::neighbors::iface& interface, raft::mdspan, row_major, Accessor1> new_vectors, std::optional, layout_c_contiguous, Accessor2>> new_indices); template -void search(const raft::device_resources& handle, +void search(const raft::resources& handle, const cuvs::neighbors::iface& interface, const cuvs::neighbors::search_params* search_params, raft::device_matrix_view h_queries, @@ -778,17 +778,17 @@ void search(const raft::device_resources& handle, raft::device_matrix_view d_distances); template -void serialize(const raft::device_resources& handle, +void serialize(const raft::resources& handle, const cuvs::neighbors::iface& interface, std::ostream& os); template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, std::istream& is); template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, const std::string& filename); diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index e0f8e572cd..db132af0e8 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -1612,12 +1612,12 @@ void deserialize(raft::resources const& handle, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -1634,12 +1634,12 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -1656,12 +1656,12 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -1680,13 +1680,13 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -1704,13 +1704,13 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -1728,13 +1728,13 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -1754,7 +1754,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -1762,7 +1762,7 @@ void extend(const raft::resources& clique, * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -1784,7 +1784,7 @@ void search( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -1792,7 +1792,7 @@ void search( * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -1814,7 +1814,7 @@ void search( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -1822,7 +1822,7 @@ void search( * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -1846,14 +1846,14 @@ void search( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -1869,14 +1869,14 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -1892,14 +1892,14 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -1915,7 +1915,7 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -1924,7 +1924,7 @@ void serialize( * * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] filename path to the file to be deserialized * */ @@ -1941,7 +1941,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::ivf_flat::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; @@ -1950,7 +1950,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] filename path to the file to be deserialized : a local index * */ diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 8396baeb73..6bf1315be4 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -1742,12 +1742,12 @@ void deserialize(raft::resources const& handle, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -1764,12 +1764,12 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -1786,12 +1786,12 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -1808,12 +1808,12 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * @@ -1832,13 +1832,13 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -1856,13 +1856,13 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -1880,13 +1880,13 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -1904,13 +1904,13 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], @@ -1930,7 +1930,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -1938,7 +1938,7 @@ void extend(const raft::resources& clique, * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -1959,7 +1959,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -1967,7 +1967,7 @@ void search(const raft::resources& clique, * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -1988,7 +1988,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -1996,7 +1996,7 @@ void search(const raft::resources& clique, * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -2017,7 +2017,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; @@ -2025,7 +2025,7 @@ void search(const raft::resources& clique, * distances); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -2048,14 +2048,14 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -2070,14 +2070,14 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -2092,14 +2092,14 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -2114,14 +2114,14 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * @@ -2136,7 +2136,7 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::mg::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2144,7 +2144,7 @@ void serialize(const raft::resources& clique, * auto new_index = cuvs::neighbors::ivf_pq::deserialize(clique, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] filename path to the file to be deserialized * */ @@ -2161,7 +2161,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * * Usage example: * @code{.cpp} - * raft::device_resources_snmg clique; + * raft::resources clique; * cuvs::neighbors::ivf_pq::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; @@ -2169,7 +2169,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * auto new_index = cuvs::neighbors::ivf_pq::distribute(clique, filename); * @endcode * - * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] clique a `raft::resources` object specifying the NCCL clique configuration * @param[in] filename path to the file to be deserialized : a local index * */ diff --git a/cpp/src/neighbors/iface/generate_iface.py b/cpp/src/neighbors/iface/generate_iface.py index 794219bbf3..698bee88d8 100644 --- a/cpp/src/neighbors/iface/generate_iface.py +++ b/cpp/src/neighbors/iface/generate_iface.py @@ -58,49 +58,49 @@ using IdxT_ha = raft::host_device_accessor, raft::memory_type::device>; \\ using IdxT_da = raft::host_device_accessor, raft::memory_type::host>; \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_ha> index_dataset); \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_da> index_dataset); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_ha> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_ha>> new_indices); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_da> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_da>> new_indices); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::device_matrix_view queries, \\ raft::device_matrix_view neighbors, \\ raft::device_matrix_view distances); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::host_matrix_view h_queries, \\ raft::device_matrix_view d_neighbors, \\ raft::device_matrix_view d_distances); \\ \\ - template void serialize(const raft::device_resources& handle, \\ + template void serialize(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ std::ostream& os); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ std::istream& is); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const std::string& filename); """ @@ -112,49 +112,49 @@ using IdxT_ha = raft::host_device_accessor, raft::memory_type::device>; \\ using IdxT_da = raft::host_device_accessor, raft::memory_type::host>; \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_ha> index_dataset); \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_da> index_dataset); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_ha> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_ha>> new_indices); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_da> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_da>> new_indices); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::device_matrix_view queries, \\ raft::device_matrix_view neighbors, \\ raft::device_matrix_view distances); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::host_matrix_view h_queries, \\ raft::device_matrix_view d_neighbors, \\ raft::device_matrix_view d_distances); \\ \\ - template void serialize(const raft::device_resources& handle, \\ + template void serialize(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ std::ostream& os); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ std::istream& is); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const std::string& filename); """ @@ -166,49 +166,49 @@ using IdxT_ha = raft::host_device_accessor, raft::memory_type::device>; \\ using IdxT_da = raft::host_device_accessor, raft::memory_type::host>; \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_ha> index_dataset); \\ \\ - template void build(const raft::device_resources& handle, \\ + template void build(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::index_params* index_params, \\ raft::mdspan, row_major, T_da> index_dataset); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_ha> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_ha>> new_indices); \\ \\ - template void extend(const raft::device_resources& handle, \\ + template void extend(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ raft::mdspan, row_major, T_da> new_vectors, \\ std::optional, layout_c_contiguous, IdxT_da>> new_indices); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::device_matrix_view queries, \\ raft::device_matrix_view neighbors, \\ raft::device_matrix_view distances); \\ \\ - template void search(const raft::device_resources& handle, \\ + template void search(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ const cuvs::neighbors::search_params* search_params, \\ raft::host_matrix_view h_queries, \\ raft::device_matrix_view d_neighbors, \\ raft::device_matrix_view d_distances); \\ \\ - template void serialize(const raft::device_resources& handle, \\ + template void serialize(const raft::resources& handle, \\ const cuvs::neighbors::iface, T, IdxT>& interface, \\ std::ostream& os); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ std::istream& is); \\ \\ - template void deserialize(const raft::device_resources& handle, \\ + template void deserialize(const raft::resources& handle, \\ cuvs::neighbors::iface, T, IdxT>& interface, \\ const std::string& filename); """ diff --git a/cpp/src/neighbors/iface/iface.hpp b/cpp/src/neighbors/iface/iface.hpp index 59d3e12d0e..8ccc0d54b9 100644 --- a/cpp/src/neighbors/iface/iface.hpp +++ b/cpp/src/neighbors/iface/iface.hpp @@ -31,7 +31,7 @@ namespace cuvs::neighbors { using namespace raft; template -void build(const raft::device_resources& handle, +void build(const raft::resources& handle, cuvs::neighbors::iface& interface, const cuvs::neighbors::index_params* index_params, raft::mdspan, row_major, Accessor> index_dataset) @@ -56,7 +56,7 @@ void build(const raft::device_resources& handle, template void extend( - const raft::device_resources& handle, + const raft::resources& handle, cuvs::neighbors::iface& interface, raft::mdspan, row_major, Accessor1> new_vectors, std::optional, layout_c_contiguous, Accessor2>> @@ -79,7 +79,7 @@ void extend( } template -void search(const raft::device_resources& handle, +void search(const raft::resources& handle, const cuvs::neighbors::iface& interface, const cuvs::neighbors::search_params* search_params, raft::device_matrix_view queries, @@ -115,7 +115,7 @@ void search(const raft::device_resources& handle, // for MG ANN only template -void search(const raft::device_resources& handle, +void search(const raft::resources& handle, const cuvs::neighbors::iface& interface, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view h_queries, @@ -137,7 +137,7 @@ void search(const raft::device_resources& handle, } template -void serialize(const raft::device_resources& handle, +void serialize(const raft::resources& handle, const cuvs::neighbors::iface& interface, std::ostream& os) { @@ -153,7 +153,7 @@ void serialize(const raft::device_resources& handle, } template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, std::istream& is) { @@ -175,7 +175,7 @@ void deserialize(const raft::device_resources& handle, } template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, const std::string& filename) { diff --git a/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu index b5e329dd82..b7ad428ad8 100644 --- a/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu index 23fcffc59a..86e0633bb1 100644 --- a/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu index 30377ab666..64f174184b 100644 --- a/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu index 59a1640e89..9f6db32df8 100644 --- a/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu index a0a4553753..0afffe0ad3 100644 --- a/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu @@ -38,39 +38,39 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ @@ -78,15 +78,15 @@ namespace cuvs::neighbors { raft::device_matrix_view d_distances); \ \ template void serialize( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu index 9fdd6464ff..5afd77053e 100644 --- a/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu @@ -38,39 +38,39 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ @@ -78,15 +78,15 @@ namespace cuvs::neighbors { raft::device_matrix_view d_distances); \ \ template void serialize( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu index daee59c4a8..4f2f85700c 100644 --- a/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu @@ -38,39 +38,39 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ @@ -78,15 +78,15 @@ namespace cuvs::neighbors { raft::device_matrix_view d_distances); \ \ template void serialize( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu index 7282d6bd07..90759d5f1a 100644 --- a/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu index 4d67f9aed1..c92d6fd651 100644 --- a/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu index 46537b3f9f..59269e9da1 100644 --- a/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu index 591ac881a8..c407e64cac 100644 --- a/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu @@ -38,54 +38,54 @@ namespace cuvs::neighbors { raft::memory_type::host>; \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_ha> index_dataset); \ \ template void build( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::index_params* index_params, \ raft::mdspan, row_major, T_da> index_dataset); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_ha> new_vectors, \ std::optional, layout_c_contiguous, IdxT_ha>> \ new_indices); \ \ template void extend( \ - const raft::device_resources& handle, \ + const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ raft::mdspan, row_major, T_da> new_vectors, \ std::optional, layout_c_contiguous, IdxT_da>> \ new_indices); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); \ \ - template void search(const raft::device_resources& handle, \ + template void search(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ const cuvs::neighbors::search_params* search_params, \ raft::host_matrix_view h_queries, \ raft::device_matrix_view d_neighbors, \ raft::device_matrix_view d_distances); \ \ - template void serialize(const raft::device_resources& handle, \ + template void serialize(const raft::resources& handle, \ const cuvs::neighbors::iface, T, IdxT>& interface, \ std::ostream& os); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ std::istream& is); \ \ - template void deserialize(const raft::device_resources& handle, \ + template void deserialize(const raft::resources& handle, \ cuvs::neighbors::iface, T, IdxT>& interface, \ const std::string& filename); CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 616e779f14..65da31bbac 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -17,7 +17,7 @@ #pragma once #include "../detail/knn_merge_parts.cuh" -#include +#include #include #include #include @@ -33,20 +33,8 @@ namespace cuvs::neighbors { using namespace raft; -inline const raft::device_resources_snmg& get_snmng_clique(const raft::resources& res) -{ - try { - const raft::device_resources_snmg& clique = - dynamic_cast(res); - return clique; - } catch (const std::bad_cast& e) { - throw std::runtime_error( - "MG function was not used with an appropriate raft::device_resources_snmg object"); - } -} - template -void search(const raft::device_resources& handle, +void search(const raft::resources& handle, const cuvs::neighbors::iface& interface, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view h_queries, @@ -64,41 +52,39 @@ using namespace raft; // local index deserialization and distribution template -void deserialize_and_distribute(const raft::resources& res, +void deserialize_and_distribute(const raft::resources& clique, index& index, const std::string& filename) { - const auto& clique = get_snmng_clique(res); for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, filename); } } // MG index deserialization template -void deserialize(const raft::resources& res, +void deserialize(const raft::resources& clique, index& index, const std::string& filename) { - const auto& clique = get_snmng_clique(res); std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const auto& handle = clique.set_current_device_to_root_rank(); + const auto& handle = raft::resource::set_current_device_to_root_rank(clique); index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); index.num_ranks_ = deserialize_scalar(handle, is); - if (index.num_ranks_ != clique.get_num_ranks()) { + if (index.num_ranks_ != raft::resource::get_num_ranks(clique)) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", index.num_ranks_, - clique.get_num_ranks()); + raft::resource::get_num_ranks(clique)); } for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, is); } @@ -106,12 +92,11 @@ void deserialize(const raft::resources& res, } template -void build(const raft::resources& res, +void build(const raft::resources& clique, index& index, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { - const auto& clique = get_snmng_clique(res); if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); RAFT_LOG_DEBUG("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); @@ -119,8 +104,8 @@ void build(const raft::resources& res, index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset); resource::sync_stream(dev_res); } @@ -134,11 +119,11 @@ void build(const raft::resources& res, index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); - auto partition = raft::make_host_matrix_view( + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); + auto partition = raft::make_host_matrix_view( partition_ptr, n_rows_of_current_shard, n_cols); auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, partition); @@ -148,20 +133,19 @@ void build(const raft::resources& res, } template -void extend(const raft::resources& res, +void extend(const raft::resources& clique, index& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - const auto& clique = get_snmng_clique(res); - int64_t n_rows = new_vectors.extent(0); + int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { RAFT_LOG_DEBUG("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::extend(dev_res, ann_if, new_vectors, new_indices); resource::sync_stream(dev_res); } @@ -173,11 +157,11 @@ void extend(const raft::resources& res, #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); - auto new_vectors_part = raft::make_host_matrix_view( + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); + auto new_vectors_part = raft::make_host_matrix_view( new_vectors_ptr, n_rows_of_current_shard, n_cols); std::optional> new_indices_part = std::nullopt; @@ -194,7 +178,7 @@ void extend(const raft::resources& res, } template -void sharded_search_with_direct_merge(const raft::resources& res, +void sharded_search_with_direct_merge(const raft::resources& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -206,8 +190,7 @@ void sharded_search_with_direct_merge(const raft::resources& res, int64_t n_neighbors, int64_t n_batches) { - const auto& clique = get_snmng_clique(res); - const auto& root_handle = clique.set_current_device_to_root_rank(); + const auto& root_handle = raft::resource::set_current_device_to_root_rank(clique); auto in_neighbors = raft::make_device_matrix( root_handle, index.num_ranks_ * n_rows_per_batch, n_neighbors); auto in_distances = raft::make_device_matrix( @@ -230,11 +213,11 @@ void sharded_search_with_direct_merge(const raft::resources& res, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; - if (rank == clique.get_root_rank()) { // root rank - uint64_t batch_offset = clique.get_root_rank() * part_size; + if (rank == raft::resource::get_clique_root_rank(clique)) { // root rank + uint64_t batch_offset = raft::resource::get_clique_root_rank(clique) * part_size; auto d_neighbors = raft::make_device_matrix_view( in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); auto d_distances = raft::make_device_matrix_view( @@ -245,21 +228,21 @@ void sharded_search_with_direct_merge(const raft::resources& res, // wait for other ranks ncclGroupStart(); for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) { - if (from_rank == clique.get_root_rank()) continue; + if (from_rank == raft::resource::get_clique_root_rank(clique)) continue; batch_offset = from_rank * part_size; ncclRecv(in_neighbors.data_handle() + batch_offset, part_size * sizeof(IdxT), ncclUint8, from_rank, - clique.get_nccl_comm(rank), - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclRecv(in_distances.data_handle() + batch_offset, part_size * sizeof(float), ncclUint8, from_rank, - clique.get_nccl_comm(rank), - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); resource::sync_stream(dev_res); @@ -276,21 +259,21 @@ void sharded_search_with_direct_merge(const raft::resources& res, ncclSend(d_neighbors.data_handle(), part_size * sizeof(IdxT), ncclUint8, - clique.get_root_rank(), - clique.get_nccl_comm(rank), - resource::get_cuda_stream(dev_res)); + raft::resource::get_clique_root_rank(clique), + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclSend(d_distances.data_handle(), part_size * sizeof(float), ncclUint8, - clique.get_root_rank(), - clique.get_nccl_comm(rank), - resource::get_cuda_stream(dev_res)); + raft::resource::get_clique_root_rank(clique), + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclGroupEnd(); resource::sync_stream(dev_res); } } - const auto& root_handle_ = clique.set_current_device_to_root_rank(); + const auto& root_handle_ = raft::resource::set_current_device_to_root_rank(clique); auto h_trans = std::vector(index.num_ranks_); int64_t translation_offset = 0; for (int rank = 0; rank < index.num_ranks_; rank++) { @@ -301,7 +284,7 @@ void sharded_search_with_direct_merge(const raft::resources& res, raft::copy(d_trans.data_handle(), h_trans.data(), index.num_ranks_, - resource::get_cuda_stream(root_handle_)); + raft::resource::get_cuda_stream(root_handle_)); cuvs::neighbors::detail::knn_merge_parts(in_distances.data_handle(), in_neighbors.data_handle(), @@ -310,24 +293,24 @@ void sharded_search_with_direct_merge(const raft::resources& res, n_rows_of_current_batch, index.num_ranks_, n_neighbors, - resource::get_cuda_stream(root_handle_), + raft::resource::get_cuda_stream(root_handle_), d_trans.data_handle()); raft::copy(neighbors.data_handle() + output_offset, out_neighbors.data_handle(), part_size, - resource::get_cuda_stream(root_handle_)); + raft::resource::get_cuda_stream(root_handle_)); raft::copy(distances.data_handle() + output_offset, out_distances.data_handle(), part_size, - resource::get_cuda_stream(root_handle_)); + raft::resource::get_cuda_stream(root_handle_)); resource::sync_stream(root_handle_); } } template -void sharded_search_with_tree_merge(const raft::resources& res, +void sharded_search_with_tree_merge(const raft::resources& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -339,7 +322,6 @@ void sharded_search_with_tree_merge(const raft::resources& res, int64_t n_neighbors, int64_t n_batches) { - const auto& clique = get_snmng_clique(res); for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { int64_t offset = batch_idx * n_rows_per_batch; int64_t query_offset = offset * n_cols; @@ -352,8 +334,8 @@ void sharded_search_with_tree_merge(const raft::resources& res, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; int64_t part_size = n_rows_of_current_batch * n_neighbors; @@ -376,11 +358,11 @@ void sharded_search_with_tree_merge(const raft::resources& res, neighbors_view.data_handle(), (IdxT)translation_offset, part_size, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); auto d_trans = raft::make_device_vector(dev_res, 2); cudaMemsetAsync( - d_trans.data_handle(), 0, 2 * sizeof(IdxT), resource::get_cuda_stream(dev_res)); + d_trans.data_handle(), 0, 2 * sizeof(IdxT), raft::resource::get_cuda_stream(dev_res)); int64_t remaining = index.num_ranks_; int64_t radix = 2; @@ -398,14 +380,14 @@ void sharded_search_with_tree_merge(const raft::resources& res, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.get_nccl_comm(rank), - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclRecv(tmp_distances.data_handle() + part_size, part_size * sizeof(float), ncclUint8, other_id, - clique.get_nccl_comm(rank), - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); received_something = true; } } else if (rank % radix == offset) // This is one of the senders @@ -415,14 +397,14 @@ void sharded_search_with_tree_merge(const raft::resources& res, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.get_nccl_comm(rank), - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); ncclSend(tmp_distances.data_handle(), part_size * sizeof(float), ncclUint8, other_id, - clique.get_nccl_comm(rank), - resource::get_cuda_stream(dev_res)); + raft::resource::get_nccl_comm(dev_res), + raft::resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -438,7 +420,7 @@ void sharded_search_with_tree_merge(const raft::resources& res, n_rows_of_current_batch, 2, n_neighbors, - resource::get_cuda_stream(dev_res), + raft::resource::get_cuda_stream(dev_res), d_trans.data_handle()); // If done, copy the final result @@ -446,11 +428,11 @@ void sharded_search_with_tree_merge(const raft::resources& res, raft::copy(neighbors.data_handle() + output_offset, tmp_neighbors.data_handle(), part_size, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); raft::copy(distances.data_handle() + output_offset, tmp_distances.data_handle(), part_size, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); resource::sync_stream(dev_res); } @@ -461,7 +443,7 @@ void sharded_search_with_tree_merge(const raft::resources& res, } template -void run_search_batch(const raft::resources& res, +void run_search_batch(const raft::resources& clique, const index& index, int rank, const cuvs::neighbors::search_params* search_params, @@ -474,9 +456,8 @@ void run_search_batch(const raft::resources& res, int64_t n_cols, int64_t n_neighbors) { - const auto& clique = get_snmng_clique(res); - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; auto query_partition = raft::make_host_matrix_view( queries.data_handle() + query_offset, n_rows_of_current_batch, n_cols); @@ -491,24 +472,23 @@ void run_search_batch(const raft::resources& res, raft::copy(neighbors.data_handle() + output_offset, d_neighbors.data_handle(), n_rows_of_current_batch * n_neighbors, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); raft::copy(distances.data_handle() + output_offset, d_distances.data_handle(), n_rows_of_current_batch * n_neighbors, - resource::get_cuda_stream(dev_res)); + raft::resource::get_cuda_stream(dev_res)); resource::sync_stream(dev_res); } template -void search(const raft::resources& res, +void search(const raft::resources& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - const auto& clique = get_snmng_clique(res); int64_t n_rows = queries.extent(0); int64_t n_cols = queries.extent(1); int64_t n_neighbors = neighbors.extent(1); @@ -647,21 +627,20 @@ void search(const raft::resources& res, } template -void serialize(const raft::resources& res, +void serialize(const raft::resources& clique, const index& index, const std::string& filename) { - const auto& clique = get_snmng_clique(res); std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const auto& handle = clique.set_current_device_to_root_rank(); + const auto& handle = raft::resource::set_current_device_to_root_rank(clique); serialize_scalar(handle, of, (int)index.mode_); serialize_scalar(handle, of, index.num_ranks_); for (int rank = 0; rank < index.num_ranks_; rank++) { - const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::serialize(dev_res, ann_if, of); } @@ -676,17 +655,16 @@ using namespace cuvs::neighbors; using namespace raft; template -index::index(const raft::resources& res, distribution_mode mode) +index::index(const raft::resources& clique, distribution_mode mode) : mode_(mode), round_robin_counter_(std::make_shared>(0)) { - const auto& clique = get_snmng_clique(res); - num_ranks_ = clique.get_num_ranks(); + num_ranks_ = raft::resource::get_num_ranks(clique); } template -index::index(const raft::resources& res, const std::string& filename) +index::index(const raft::resources& clique, const std::string& filename) : round_robin_counter_(std::make_shared>(0)) { - cuvs::neighbors::mg::detail::deserialize(res, *this, filename); + cuvs::neighbors::mg::detail::deserialize(clique, *this, filename); } } // namespace cuvs::neighbors::mg From 3075869d951b10e8938fb925fb2cc25d3b1a92cd Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 4 Mar 2025 15:12:43 +0100 Subject: [PATCH 11/17] Updating CUVS_BUILD_MG_ALGOS ifdef in header files --- cpp/include/cuvs/neighbors/cagra.hpp | 15 +++++++++++---- cpp/include/cuvs/neighbors/ivf_flat.hpp | 15 +++++++++++---- cpp/include/cuvs/neighbors/ivf_pq.hpp | 15 +++++++++++---- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index b6ea310bce..c9e23bbf65 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -33,6 +33,17 @@ #ifdef CUVS_BUILD_MG_ALGOS #include +#else +namespace cuvs::neighbors::mg { +template +struct index_params; + +template +struct search_params; + +template +struct index; +} // namespace cuvs::neighbors::mg #endif #include @@ -1993,8 +2004,6 @@ auto merge(raft::resources const& res, * @} */ -#ifdef CUVS_BUILD_MG_ALGOS - /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -2448,6 +2457,4 @@ template auto distribute(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; -#endif - } // namespace cuvs::neighbors::cagra diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index db132af0e8..d3c519362d 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -24,6 +24,17 @@ #ifdef CUVS_BUILD_MG_ALGOS #include +#else +namespace cuvs::neighbors::mg { +template +struct index_params; + +template +struct search_params; + +template +struct index; +} // namespace cuvs::neighbors::mg #endif namespace cuvs::neighbors::ivf_flat { @@ -1602,8 +1613,6 @@ void deserialize(raft::resources const& handle, * @} */ -#ifdef CUVS_BUILD_MG_ALGOS - /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -1958,8 +1967,6 @@ template auto distribute(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; -#endif - namespace helpers { /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 6bf1315be4..fdfa0807e9 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -29,6 +29,17 @@ #ifdef CUVS_BUILD_MG_ALGOS #include +#else +namespace cuvs::neighbors::mg { +template +struct index_params; + +template +struct search_params; + +template +struct index; +} // namespace cuvs::neighbors::mg #endif namespace cuvs::neighbors::ivf_pq { @@ -1732,8 +1743,6 @@ void deserialize(raft::resources const& handle, * @} */ -#ifdef CUVS_BUILD_MG_ALGOS - /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -2177,8 +2186,6 @@ template auto distribute(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; -#endif - namespace helpers { /** * @defgroup ivf_pq_cpp_helpers IVF-PQ helper methods From 4a422dfbddd0fca799e6576cbf34a7c939de1b0a Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 5 Mar 2025 11:04:05 +0100 Subject: [PATCH 12/17] moving mg structs away from mg namespace --- cpp/bench/ann/src/cuvs/cuvs_benchmark.cu | 16 +-- .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 21 ++- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 21 +-- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 17 ++- cpp/include/cuvs/neighbors/cagra.hpp | 114 +++++++-------- cpp/include/cuvs/neighbors/common.hpp | 83 +++++++++++ cpp/include/cuvs/neighbors/ivf_flat.hpp | 96 ++++++------- cpp/include/cuvs/neighbors/ivf_pq.hpp | 109 +++++++-------- cpp/include/cuvs/neighbors/mg.hpp | 112 --------------- cpp/src/neighbors/mg/generate_mg.py | 70 +++++----- cpp/src/neighbors/mg/mg.cuh | 60 ++++---- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 114 +++++++-------- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 114 +++++++-------- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 114 +++++++-------- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 114 +++++++-------- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 130 +++++++++--------- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 130 +++++++++--------- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 130 +++++++++--------- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 130 +++++++++--------- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 130 +++++++++--------- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 130 +++++++++--------- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 130 +++++++++--------- cpp/tests/neighbors/mg.cuh | 30 ++-- docs/source/cpp_api/neighbors_mg.rst | 2 +- 24 files changed, 1019 insertions(+), 1098 deletions(-) delete mode 100644 cpp/include/cuvs/neighbors/mg.hpp diff --git a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu index 8930972365..821b80cc72 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu @@ -30,38 +30,38 @@ namespace cuvs::bench { #ifdef CUVS_ANN_BENCH_USE_CUVS_MG -void add_distribution_mode(cuvs::neighbors::mg::distribution_mode* dist_mode, +void add_distribution_mode(cuvs::neighbors::distribution_mode* dist_mode, const nlohmann::json& conf) { if (conf.contains("distribution_mode")) { std::string distribution_mode = conf.at("distribution_mode"); if (distribution_mode == "replicated") { - *dist_mode = cuvs::neighbors::mg::distribution_mode::REPLICATED; + *dist_mode = cuvs::neighbors::distribution_mode::REPLICATED; } else if (distribution_mode == "sharded") { - *dist_mode = cuvs::neighbors::mg::distribution_mode::SHARDED; + *dist_mode = cuvs::neighbors::distribution_mode::SHARDED; } else { throw std::runtime_error("invalid value for distribution_mode"); } } else { // default - *dist_mode = cuvs::neighbors::mg::distribution_mode::SHARDED; + *dist_mode = cuvs::neighbors::distribution_mode::SHARDED; } }; -void add_merge_mode(cuvs::neighbors::mg::sharded_merge_mode* merge_mode, const nlohmann::json& conf) +void add_merge_mode(cuvs::neighbors::sharded_merge_mode* merge_mode, const nlohmann::json& conf) { if (conf.contains("merge_mode")) { std::string sharded_merge_mode = conf.at("merge_mode"); if (sharded_merge_mode == "tree_merge") { - *merge_mode = cuvs::neighbors::mg::sharded_merge_mode::TREE_MERGE; + *merge_mode = cuvs::neighbors::sharded_merge_mode::TREE_MERGE; } else if (sharded_merge_mode == "merge_on_root_rank") { - *merge_mode = cuvs::neighbors::mg::sharded_merge_mode::MERGE_ON_ROOT_RANK; + *merge_mode = cuvs::neighbors::sharded_merge_mode::MERGE_ON_ROOT_RANK; } else { throw std::runtime_error("invalid value for merge_mode"); } } else { // default - *merge_mode = cuvs::neighbors::mg::sharded_merge_mode::TREE_MERGE; + *merge_mode = cuvs::neighbors::sharded_merge_mode::TREE_MERGE; } }; #endif diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index 823b80d4c3..a655fdf674 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -17,7 +17,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_cagra_wrapper.h" -#include +#include #include namespace cuvs::bench { @@ -33,11 +33,11 @@ class cuvs_mg_cagra : public algo, public algo_gpu { using algo::dim_; struct build_param : public cuvs::bench::cuvs_cagra::build_param { - cuvs::neighbors::mg::distribution_mode mode; + cuvs::neighbors::distribution_mode mode; }; struct search_param : public cuvs::bench::cuvs_cagra::search_param { - cuvs::neighbors::mg::sharded_merge_mode merge_mode; + cuvs::neighbors::sharded_merge_mode merge_mode; }; cuvs_mg_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1) @@ -89,8 +89,8 @@ class cuvs_mg_cagra : public algo, public algo_gpu { raft::device_resources_snmg clique_; float refine_ratio_; build_param index_params_; - cuvs::neighbors::mg::search_params search_params_; - std::shared_ptr, T, IdxT>> + cuvs::neighbors::mg_search_params search_params_; + std::shared_ptr, T, IdxT>> index_; }; @@ -99,14 +99,14 @@ void cuvs_mg_cagra::build(const T* dataset, size_t nrow) { auto dataset_extents = raft::make_extents(nrow, dim_); index_params_.prepare_build_params(dataset_extents); - cuvs::neighbors::mg::index_params build_params = index_params_.cagra_params; - build_params.mode = index_params_.mode; + cuvs::neighbors::mg_index_params build_params = index_params_.cagra_params; + build_params.mode = index_params_.mode; auto dataset_view = raft::make_host_matrix_view(dataset, nrow, dim_); auto idx = cuvs::neighbors::cagra::build(clique_, build_params, dataset_view); index_ = - std::make_shared, T, IdxT>>( + std::make_shared, T, IdxT>>( std::move(idx)); } @@ -117,8 +117,7 @@ void cuvs_mg_cagra::set_search_param(const search_param_base& param, const void* filter_bitset) { if (filter_bitset != nullptr) { throw std::runtime_error("Filtering is not supported yet."); } - auto sp = dynamic_cast(param); - // search_params_ = static_cast>(sp.p); + auto sp = dynamic_cast(param); cagra::search_params* search_params_ptr_ = static_cast(&search_params_); *search_params_ptr_ = sp.p; search_params_.merge_mode = sp.merge_mode; @@ -140,7 +139,7 @@ template void cuvs_mg_cagra::load(const std::string& file) { index_ = - std::make_shared, T, IdxT>>( + std::make_shared, T, IdxT>>( std::move(cuvs::neighbors::cagra::deserialize(clique_, file))); } diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 384d0c48f4..ca05ecdf0a 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -18,7 +18,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_flat_wrapper.h" -#include +#include #include namespace cuvs::bench { @@ -30,10 +30,10 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { using search_param_base = typename algo::search_param; using algo::dim_; - using build_param = cuvs::neighbors::mg::index_params; + using build_param = cuvs::neighbors::mg_index_params; struct search_param : public cuvs::bench::cuvs_ivf_flat::search_param { - cuvs::neighbors::mg::sharded_merge_mode merge_mode; + cuvs::neighbors::sharded_merge_mode merge_mode; }; cuvs_mg_ivf_flat(Metric metric, int dim, const build_param& param) @@ -75,8 +75,8 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { private: raft::device_resources_snmg clique_; build_param index_params_; - cuvs::neighbors::mg::search_params search_params_; - std::shared_ptr, T, IdxT>> + cuvs::neighbors::mg_search_params search_params_; + std::shared_ptr, T, IdxT>> index_; }; @@ -86,8 +86,9 @@ void cuvs_mg_ivf_flat::build(const T* dataset, size_t nrow) auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); auto idx = cuvs::neighbors::ivf_flat::build(clique_, index_params_, dataset_view); - index_ = std::make_shared< - cuvs::neighbors::mg::index, T, IdxT>>(std::move(idx)); + index_ = + std::make_shared, T, IdxT>>( + std::move(idx)); } template @@ -113,9 +114,9 @@ void cuvs_mg_ivf_flat::save(const std::string& file) const template void cuvs_mg_ivf_flat::load(const std::string& file) { - index_ = std::make_shared< - cuvs::neighbors::mg::index, T, IdxT>>( - std::move(cuvs::neighbors::ivf_flat::deserialize(clique_, file))); + index_ = + std::make_shared, T, IdxT>>( + std::move(cuvs::neighbors::ivf_flat::deserialize(clique_, file))); } template diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index 0bb17b3192..c37bbdf18b 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -18,7 +18,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_pq_wrapper.h" -#include +#include #include namespace cuvs::bench { @@ -30,10 +30,10 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { using search_param_base = typename algo::search_param; using algo::dim_; - using build_param = cuvs::neighbors::mg::index_params; + using build_param = cuvs::neighbors::mg_index_params; struct search_param : public cuvs::bench::cuvs_ivf_pq::search_param { - cuvs::neighbors::mg::sharded_merge_mode merge_mode; + cuvs::neighbors::sharded_merge_mode merge_mode; }; cuvs_mg_ivf_pq(Metric metric, int dim, const build_param& param) @@ -75,8 +75,8 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { private: raft::device_resources_snmg clique_; build_param index_params_; - cuvs::neighbors::mg::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; + cuvs::neighbors::mg_search_params search_params_; + std::shared_ptr, T, IdxT>> index_; }; template @@ -86,7 +86,7 @@ void cuvs_mg_ivf_pq::build(const T* dataset, size_t nrow) raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); auto idx = cuvs::neighbors::ivf_pq::build(clique_, index_params_, dataset_view); index_ = - std::make_shared, T, IdxT>>( + std::make_shared, T, IdxT>>( std::move(idx)); } @@ -95,8 +95,7 @@ void cuvs_mg_ivf_pq::set_search_param(const search_param_base& param, const void* filter_bitset) { if (filter_bitset != nullptr) { throw std::runtime_error("Filtering is not supported yet."); } - auto sp = dynamic_cast(param); - // search_params_ = static_cast>(sp.pq_param); + auto sp = dynamic_cast(param); ivf_pq::search_params* search_params_ptr_ = static_cast(&search_params_); *search_params_ptr_ = sp.pq_param; search_params_.merge_mode = sp.merge_mode; @@ -113,7 +112,7 @@ template void cuvs_mg_ivf_pq::load(const std::string& file) { index_ = - std::make_shared, T, IdxT>>( + std::make_shared, T, IdxT>>( std::move(cuvs::neighbors::ivf_pq::deserialize(clique_, file))); } diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index c9e23bbf65..0a47136264 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -31,21 +31,6 @@ #include #include -#ifdef CUVS_BUILD_MG_ALGOS -#include -#else -namespace cuvs::neighbors::mg { -template -struct index_params; - -template -struct search_params; - -template -struct index; -} // namespace cuvs::neighbors::mg -#endif - #include #include @@ -2013,7 +1998,7 @@ auto merge(raft::resources const& res, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode * @@ -2024,9 +2009,9 @@ auto merge(raft::resources const& res, * @return the constructed CAGRA MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, float, uint32_t>; + -> cuvs::neighbors::mg_index, float, uint32_t>; /// \ingroup mg_cpp_index_build /** @@ -2035,7 +2020,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode * @@ -2046,9 +2031,9 @@ auto build(const raft::resources& clique, * @return the constructed CAGRA MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, half, uint32_t>; + -> cuvs::neighbors::mg_index, half, uint32_t>; /// \ingroup mg_cpp_index_build /** @@ -2057,7 +2042,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode * @@ -2068,9 +2053,9 @@ auto build(const raft::resources& clique, * @return the constructed CAGRA MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, int8_t, uint32_t>; + -> cuvs::neighbors::mg_index, int8_t, uint32_t>; /// \ingroup mg_cpp_index_build /** @@ -2079,7 +2064,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode * @@ -2090,9 +2075,9 @@ auto build(const raft::resources& clique, * @return the constructed CAGRA MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, uint8_t, uint32_t>; + -> cuvs::neighbors::mg_index, uint8_t, uint32_t>; /// \defgroup mg_cpp_index_extend ANN MG index extend @@ -2103,7 +2088,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -2116,7 +2101,7 @@ auto build(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, float, uint32_t>& index, + cuvs::neighbors::mg_index, float, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -2127,7 +2112,7 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -2140,7 +2125,7 @@ void extend(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, half, uint32_t>& index, + cuvs::neighbors::mg_index, half, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -2151,7 +2136,7 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -2164,7 +2149,7 @@ void extend(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, int8_t, uint32_t>& index, + cuvs::neighbors::mg_index, int8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -2175,7 +2160,7 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -2188,7 +2173,7 @@ void extend(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, + cuvs::neighbors::mg_index, uint8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -2201,9 +2186,9 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -2217,8 +2202,8 @@ void extend(const raft::resources& clique, * */ void search(const raft::resources& clique, - const cuvs::neighbors::mg::index, float, uint32_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, float, uint32_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -2230,9 +2215,9 @@ void search(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -2246,8 +2231,8 @@ void search(const raft::resources& clique, * */ void search(const raft::resources& clique, - const cuvs::neighbors::mg::index, half, uint32_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, half, uint32_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -2259,9 +2244,9 @@ void search(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -2276,8 +2261,8 @@ void search(const raft::resources& clique, */ void search( const raft::resources& clique, - const cuvs::neighbors::mg::index, int8_t, uint32_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, int8_t, uint32_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -2289,9 +2274,9 @@ void search( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -2306,8 +2291,8 @@ void search( */ void search( const raft::resources& clique, - const cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, uint8_t, uint32_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -2321,7 +2306,7 @@ void search( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); @@ -2334,7 +2319,7 @@ void search( */ void serialize( const raft::resources& clique, - const cuvs::neighbors::mg::index, float, uint32_t>& index, + const cuvs::neighbors::mg_index, float, uint32_t>& index, const std::string& filename); /// \ingroup mg_cpp_serialize @@ -2344,7 +2329,7 @@ void serialize( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); @@ -2355,10 +2340,9 @@ void serialize( * @param[in] filename path to the file to be serialized * */ -void serialize( - const raft::resources& clique, - const cuvs::neighbors::mg::index, half, uint32_t>& index, - const std::string& filename); +void serialize(const raft::resources& clique, + const cuvs::neighbors::mg_index, half, uint32_t>& index, + const std::string& filename); /// \ingroup mg_cpp_serialize /** @@ -2367,7 +2351,7 @@ void serialize( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); @@ -2380,7 +2364,7 @@ void serialize( */ void serialize( const raft::resources& clique, - const cuvs::neighbors::mg::index, int8_t, uint32_t>& index, + const cuvs::neighbors::mg_index, int8_t, uint32_t>& index, const std::string& filename); /// \ingroup mg_cpp_serialize @@ -2390,7 +2374,7 @@ void serialize( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); @@ -2403,7 +2387,7 @@ void serialize( */ void serialize( const raft::resources& clique, - const cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, + const cuvs::neighbors::mg_index, uint8_t, uint32_t>& index, const std::string& filename); /// \defgroup mg_cpp_deserialize ANN MG index deserialization @@ -2415,7 +2399,7 @@ void serialize( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::cagra::serialize(clique, index, filename); @@ -2429,7 +2413,7 @@ void serialize( */ template auto deserialize(const raft::resources& clique, const std::string& filename) - -> cuvs::neighbors::mg::index, T, IdxT>; + -> cuvs::neighbors::mg_index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -2455,6 +2439,6 @@ auto deserialize(const raft::resources& clique, const std::string& filename) */ template auto distribute(const raft::resources& clique, const std::string& filename) - -> cuvs::neighbors::mg::index, T, IdxT>; + -> cuvs::neighbors::mg_index, T, IdxT>; } // namespace cuvs::neighbors::cagra diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 2e275242b6..6ddcebe961 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -793,3 +793,86 @@ void deserialize(const raft::resources& handle, const std::string& filename); }; // namespace cuvs::neighbors + +/// \defgroup mg_cpp_index_params ANN MG index build parameters + +namespace cuvs::neighbors { +/** Distribution mode */ +/// \ingroup mg_cpp_index_params +enum distribution_mode { + /** Index is replicated on each device, favors throughput */ + REPLICATED, + /** Index is split on several devices, favors scaling */ + SHARDED +}; + +/// \defgroup mg_cpp_search_params ANN MG search parameters + +/** Search mode when using a replicated index */ +/// \ingroup mg_cpp_search_params +enum replicated_search_mode { + /** Search queries are splited to maintain equal load on GPUs */ + LOAD_BALANCER, + /** Each search query is processed by a single GPU in a round-robin fashion */ + ROUND_ROBIN +}; + +/** Merge mode when using a sharded index */ +/// \ingroup mg_cpp_search_params +enum sharded_merge_mode { + /** Search batches are merged on the root rank */ + MERGE_ON_ROOT_RANK, + /** Search batches are merged in a tree reduction fashion */ + TREE_MERGE +}; + +/** Build parameters */ +/// \ingroup mg_cpp_index_params +template +struct mg_index_params : public Upstream { + mg_index_params() : mode(SHARDED) {} + + mg_index_params(const Upstream& sp) : Upstream(sp), mode(SHARDED) {} + + /** Distribution mode */ + cuvs::neighbors::distribution_mode mode = SHARDED; +}; + +/** Search parameters */ +/// \ingroup mg_cpp_search_params +template +struct mg_search_params : public Upstream { + mg_search_params() : search_mode(LOAD_BALANCER), merge_mode(TREE_MERGE) {} + + mg_search_params(const Upstream& sp) + : Upstream(sp), search_mode(LOAD_BALANCER), merge_mode(TREE_MERGE) + { + } + + /** Replicated search mode */ + cuvs::neighbors::replicated_search_mode search_mode = LOAD_BALANCER; + /** Sharded merge mode */ + cuvs::neighbors::sharded_merge_mode merge_mode = TREE_MERGE; + /** Number of rows per batch */ + int64_t n_rows_per_batch = 1 << 20; +}; + +template +struct mg_index { + mg_index(const raft::resources& clique, distribution_mode mode); + mg_index(const raft::resources& clique, const std::string& filename); + + mg_index(const mg_index&) = delete; + mg_index(mg_index&&) = default; + auto operator=(const mg_index&) -> mg_index& = delete; + auto operator=(mg_index&&) -> mg_index& = default; + + distribution_mode mode_; + int num_ranks_; + std::vector> ann_interfaces_; + + // for load balancing mechanism + std::shared_ptr> round_robin_counter_; +}; + +} // namespace cuvs::neighbors diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index d3c519362d..fd96729755 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -22,21 +22,6 @@ #include #include -#ifdef CUVS_BUILD_MG_ALGOS -#include -#else -namespace cuvs::neighbors::mg { -template -struct index_params; - -template -struct search_params; - -template -struct index; -} // namespace cuvs::neighbors::mg -#endif - namespace cuvs::neighbors::ivf_flat { /** * @defgroup ivf_flat_cpp_index_params IVF-Flat index build parameters @@ -1622,7 +1607,7 @@ void deserialize(raft::resources const& handle, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode * @@ -1633,9 +1618,9 @@ void deserialize(raft::resources const& handle, * @return the constructed IVF-Flat MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, float, int64_t>; + -> cuvs::neighbors::mg_index, float, int64_t>; /// \ingroup mg_cpp_index_build /** @@ -1644,7 +1629,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode * @@ -1655,9 +1640,9 @@ auto build(const raft::resources& clique, * @return the constructed IVF-Flat MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, int8_t, int64_t>; + -> cuvs::neighbors::mg_index, int8_t, int64_t>; /// \ingroup mg_cpp_index_build /** @@ -1666,7 +1651,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode * @@ -1677,9 +1662,9 @@ auto build(const raft::resources& clique, * @return the constructed IVF-Flat MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, uint8_t, int64_t>; + -> cuvs::neighbors::mg_index, uint8_t, int64_t>; /// \defgroup mg_cpp_index_extend ANN MG index extend @@ -1690,7 +1675,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -1703,7 +1688,7 @@ auto build(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, float, int64_t>& index, + cuvs::neighbors::mg_index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1714,7 +1699,7 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -1727,7 +1712,7 @@ void extend(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, int8_t, int64_t>& index, + cuvs::neighbors::mg_index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1738,7 +1723,7 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -1751,7 +1736,7 @@ void extend(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + cuvs::neighbors::mg_index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1764,9 +1749,9 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -1779,13 +1764,12 @@ void extend(const raft::resources& clique, * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] * */ -void search( - const raft::resources& clique, - const cuvs::neighbors::mg::index, float, int64_t>& index, - const cuvs::neighbors::mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances); +void search(const raft::resources& clique, + const cuvs::neighbors::mg_index, float, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances); /// \ingroup mg_cpp_index_search /** @@ -1794,9 +1778,9 @@ void search( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -1811,8 +1795,8 @@ void search( */ void search( const raft::resources& clique, - const cuvs::neighbors::mg::index, int8_t, int64_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, int8_t, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -1824,9 +1808,9 @@ void search( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -1841,8 +1825,8 @@ void search( */ void search( const raft::resources& clique, - const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -1856,7 +1840,7 @@ void search( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); @@ -1869,7 +1853,7 @@ void search( */ void serialize( const raft::resources& clique, - const cuvs::neighbors::mg::index, float, int64_t>& index, + const cuvs::neighbors::mg_index, float, int64_t>& index, const std::string& filename); /// \ingroup mg_cpp_serialize @@ -1879,7 +1863,7 @@ void serialize( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); @@ -1892,7 +1876,7 @@ void serialize( */ void serialize( const raft::resources& clique, - const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const cuvs::neighbors::mg_index, int8_t, int64_t>& index, const std::string& filename); /// \ingroup mg_cpp_serialize @@ -1902,7 +1886,7 @@ void serialize( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); @@ -1915,7 +1899,7 @@ void serialize( */ void serialize( const raft::resources& clique, - const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg_index, uint8_t, int64_t>& index, const std::string& filename); /// \ingroup mg_cpp_deserialize @@ -1925,7 +1909,7 @@ void serialize( * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); @@ -1939,7 +1923,7 @@ void serialize( */ template auto deserialize(const raft::resources& clique, const std::string& filename) - -> cuvs::neighbors::mg::index, T, IdxT>; + -> cuvs::neighbors::mg_index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -1965,7 +1949,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) */ template auto distribute(const raft::resources& clique, const std::string& filename) - -> cuvs::neighbors::mg::index, T, IdxT>; + -> cuvs::neighbors::mg_index, T, IdxT>; namespace helpers { diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index fdfa0807e9..4347a64a0b 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -27,21 +27,6 @@ #include #include -#ifdef CUVS_BUILD_MG_ALGOS -#include -#else -namespace cuvs::neighbors::mg { -template -struct index_params; - -template -struct search_params; - -template -struct index; -} // namespace cuvs::neighbors::mg -#endif - namespace cuvs::neighbors::ivf_pq { /** @@ -1752,7 +1737,7 @@ void deserialize(raft::resources const& handle, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode * @@ -1763,9 +1748,9 @@ void deserialize(raft::resources const& handle, * @return the constructed IVF-PQ MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, float, int64_t>; + -> cuvs::neighbors::mg_index, float, int64_t>; /// \ingroup mg_cpp_index_build /** @@ -1774,7 +1759,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode * @@ -1785,9 +1770,9 @@ auto build(const raft::resources& clique, * @return the constructed IVF-PQ MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, half, int64_t>; + -> cuvs::neighbors::mg_index, half, int64_t>; /// \ingroup mg_cpp_index_build /** @@ -1796,7 +1781,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode * @@ -1807,9 +1792,9 @@ auto build(const raft::resources& clique, * @return the constructed IVF-PQ MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, int8_t, int64_t>; + -> cuvs::neighbors::mg_index, int8_t, int64_t>; /// \ingroup mg_cpp_index_build /** @@ -1818,7 +1803,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode * @@ -1829,9 +1814,9 @@ auto build(const raft::resources& clique, * @return the constructed IVF-PQ MG index */ auto build(const raft::resources& clique, - const cuvs::neighbors::mg::index_params& index_params, + const cuvs::neighbors::mg_index_params& index_params, raft::host_matrix_view index_dataset) - -> cuvs::neighbors::mg::index, uint8_t, int64_t>; + -> cuvs::neighbors::mg_index, uint8_t, int64_t>; /// \defgroup mg_cpp_index_extend ANN MG index extend @@ -1842,7 +1827,7 @@ auto build(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -1855,7 +1840,7 @@ auto build(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, float, int64_t>& index, + cuvs::neighbors::mg_index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1866,7 +1851,7 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -1879,7 +1864,7 @@ void extend(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, half, int64_t>& index, + cuvs::neighbors::mg_index, half, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1890,7 +1875,7 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -1903,7 +1888,7 @@ void extend(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, int8_t, int64_t>& index, + cuvs::neighbors::mg_index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1914,7 +1899,7 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); * @endcode @@ -1927,7 +1912,7 @@ void extend(const raft::resources& clique, * */ void extend(const raft::resources& clique, - cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + cuvs::neighbors::mg_index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -1940,9 +1925,9 @@ void extend(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -1956,8 +1941,8 @@ void extend(const raft::resources& clique, * */ void search(const raft::resources& clique, - const cuvs::neighbors::mg::index, float, int64_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, float, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -1969,9 +1954,9 @@ void search(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -1985,8 +1970,8 @@ void search(const raft::resources& clique, * */ void search(const raft::resources& clique, - const cuvs::neighbors::mg::index, half, int64_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, half, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -1998,9 +1983,9 @@ void search(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -2014,8 +1999,8 @@ void search(const raft::resources& clique, * */ void search(const raft::resources& clique, - const cuvs::neighbors::mg::index, int8_t, int64_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, int8_t, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -2027,9 +2012,9 @@ void search(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::mg_search_params search_params; * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, * distances); * @endcode @@ -2043,8 +2028,8 @@ void search(const raft::resources& clique, * */ void search(const raft::resources& clique, - const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, - const cuvs::neighbors::mg::search_params& search_params, + const cuvs::neighbors::mg_index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg_search_params& search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -2058,7 +2043,7 @@ void search(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); @@ -2070,7 +2055,7 @@ void search(const raft::resources& clique, * */ void serialize(const raft::resources& clique, - const cuvs::neighbors::mg::index, float, int64_t>& index, + const cuvs::neighbors::mg_index, float, int64_t>& index, const std::string& filename); /// \ingroup mg_cpp_serialize @@ -2080,7 +2065,7 @@ void serialize(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); @@ -2092,7 +2077,7 @@ void serialize(const raft::resources& clique, * */ void serialize(const raft::resources& clique, - const cuvs::neighbors::mg::index, half, int64_t>& index, + const cuvs::neighbors::mg_index, half, int64_t>& index, const std::string& filename); /// \ingroup mg_cpp_serialize @@ -2102,7 +2087,7 @@ void serialize(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); @@ -2114,7 +2099,7 @@ void serialize(const raft::resources& clique, * */ void serialize(const raft::resources& clique, - const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const cuvs::neighbors::mg_index, int8_t, int64_t>& index, const std::string& filename); /// \ingroup mg_cpp_serialize @@ -2124,7 +2109,7 @@ void serialize(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); @@ -2136,7 +2121,7 @@ void serialize(const raft::resources& clique, * */ void serialize(const raft::resources& clique, - const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg_index, uint8_t, int64_t>& index, const std::string& filename); /// \ingroup mg_cpp_deserialize @@ -2146,7 +2131,7 @@ void serialize(const raft::resources& clique, * Usage example: * @code{.cpp} * raft::resources clique; - * cuvs::neighbors::mg::index_params index_params; + * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); @@ -2159,7 +2144,7 @@ void serialize(const raft::resources& clique, */ template auto deserialize(const raft::resources& clique, const std::string& filename) - -> cuvs::neighbors::mg::index, T, IdxT>; + -> cuvs::neighbors::mg_index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -2184,7 +2169,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) */ template auto distribute(const raft::resources& clique, const std::string& filename) - -> cuvs::neighbors::mg::index, T, IdxT>; + -> cuvs::neighbors::mg_index, T, IdxT>; namespace helpers { /** diff --git a/cpp/include/cuvs/neighbors/mg.hpp b/cpp/include/cuvs/neighbors/mg.hpp deleted file mode 100644 index 685185a644..0000000000 --- a/cpp/include/cuvs/neighbors/mg.hpp +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef CUVS_BUILD_MG_ALGOS - -#include - -/// \defgroup mg_cpp_index_params ANN MG index build parameters - -namespace cuvs::neighbors::mg { -/** Distribution mode */ -/// \ingroup mg_cpp_index_params -enum distribution_mode { - /** Index is replicated on each device, favors throughput */ - REPLICATED, - /** Index is split on several devices, favors scaling */ - SHARDED -}; - -/// \defgroup mg_cpp_search_params ANN MG search parameters - -/** Search mode when using a replicated index */ -/// \ingroup mg_cpp_search_params -enum replicated_search_mode { - /** Search queries are splited to maintain equal load on GPUs */ - LOAD_BALANCER, - /** Each search query is processed by a single GPU in a round-robin fashion */ - ROUND_ROBIN -}; - -/** Merge mode when using a sharded index */ -/// \ingroup mg_cpp_search_params -enum sharded_merge_mode { - /** Search batches are merged on the root rank */ - MERGE_ON_ROOT_RANK, - /** Search batches are merged in a tree reduction fashion */ - TREE_MERGE -}; - -/** Build parameters */ -/// \ingroup mg_cpp_index_params -template -struct index_params : public Upstream { - index_params() : mode(SHARDED) {} - - index_params(const Upstream& sp) : Upstream(sp), mode(SHARDED) {} - - /** Distribution mode */ - cuvs::neighbors::mg::distribution_mode mode = SHARDED; -}; - -/** Search parameters */ -/// \ingroup mg_cpp_search_params -template -struct search_params : public Upstream { - search_params() : search_mode(LOAD_BALANCER), merge_mode(TREE_MERGE) {} - - search_params(const Upstream& sp) - : Upstream(sp), search_mode(LOAD_BALANCER), merge_mode(TREE_MERGE) - { - } - - /** Replicated search mode */ - cuvs::neighbors::mg::replicated_search_mode search_mode = LOAD_BALANCER; - /** Sharded merge mode */ - cuvs::neighbors::mg::sharded_merge_mode merge_mode = TREE_MERGE; - /** Number of rows per batch */ - int64_t n_rows_per_batch = 1 << 20; -}; - -template -struct index { - index(const raft::resources& clique, distribution_mode mode); - index(const raft::resources& clique, const std::string& filename); - - index(const index&) = delete; - index(index&&) = default; - auto operator=(const index&) -> index& = delete; - auto operator=(index&&) -> index& = default; - - distribution_mode mode_; - int num_ranks_; - std::vector> ann_interfaces_; - - // for load balancing mechanism - std::shared_ptr> round_robin_counter_; -}; - -} // namespace cuvs::neighbors::mg - -#else - -static_assert(false, - "FORBIDEN_MG_ALGORITHM_IMPORT\n\n" - "Please recompile the cuVS library with MG algorithms BUILD_MG_ALGOS=ON.\n"); - -#endif diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index 20b570a7de..a3f3725c5f 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -46,14 +46,14 @@ flat_macro = """ #define CUVS_INST_MG_FLAT(T, IdxT) \\ namespace cuvs::neighbors::ivf_flat { \\ - using namespace cuvs::neighbors::mg; \\ + using namespace cuvs::neighbors; \\ \\ - cuvs::neighbors::mg::index, T, IdxT> build( \\ + cuvs::neighbors::mg_index, T, IdxT> build( \\ const raft::resources& res, \\ - const mg::index_params& index_params, \\ + const mg_index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ cuvs::neighbors::mg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -61,7 +61,7 @@ } \\ \\ void extend(const raft::resources& res, \\ - cuvs::neighbors::mg::index, T, IdxT>& index, \\ + cuvs::neighbors::mg_index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ @@ -69,8 +69,8 @@ } \\ \\ void search(const raft::resources& res, \\ - const cuvs::neighbors::mg::index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const mg_search_params& search_params, \\ raft::host_matrix_view queries, \\ raft::host_matrix_view neighbors, \\ raft::host_matrix_view distances) \\ @@ -81,27 +81,27 @@ } \\ \\ void serialize(const raft::resources& res, \\ - const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \\ const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \\ return idx; \\ } \\ \\ template<> \\ - cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + cuvs::neighbors::mg_index, T, IdxT> distribute( \\ const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ @@ -111,14 +111,14 @@ pq_macro = """ #define CUVS_INST_MG_PQ(T, IdxT) \\ namespace cuvs::neighbors::ivf_pq { \\ - using namespace cuvs::neighbors::mg; \\ + using namespace cuvs::neighbors; \\ \\ - cuvs::neighbors::mg::index, T, IdxT> build( \\ + cuvs::neighbors::mg_index, T, IdxT> build( \\ const raft::resources& res, \\ - const mg::index_params& index_params, \\ + const mg_index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ cuvs::neighbors::mg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -126,7 +126,7 @@ } \\ \\ void extend(const raft::resources& res, \\ - cuvs::neighbors::mg::index, T, IdxT>& index, \\ + cuvs::neighbors::mg_index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ @@ -134,8 +134,8 @@ } \\ \\ void search(const raft::resources& res, \\ - const cuvs::neighbors::mg::index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const mg_search_params& search_params, \\ raft::host_matrix_view queries, \\ raft::host_matrix_view neighbors, \\ raft::host_matrix_view distances) \\ @@ -146,27 +146,27 @@ } \\ \\ void serialize(const raft::resources& res, \\ - const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \\ const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \\ return idx; \\ } \\ \\ template<> \\ - cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + cuvs::neighbors::mg_index, T, IdxT> distribute( \\ const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ @@ -176,14 +176,14 @@ cagra_macro = """ #define CUVS_INST_MG_CAGRA(T, IdxT) \\ namespace cuvs::neighbors::cagra { \\ - using namespace cuvs::neighbors::mg; \\ + using namespace cuvs::neighbors; \\ \\ - cuvs::neighbors::mg::index, T, IdxT> build( \\ + cuvs::neighbors::mg_index, T, IdxT> build( \\ const raft::resources& res, \\ - const mg::index_params& index_params, \\ + const mg_index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \\ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ cuvs::neighbors::mg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -191,8 +191,8 @@ } \\ \\ void search(const raft::resources& res, \\ - const cuvs::neighbors::mg::index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ + const mg_search_params& search_params, \\ raft::host_matrix_view queries, \\ raft::host_matrix_view neighbors, \\ raft::host_matrix_view distances) \\ @@ -203,27 +203,27 @@ } \\ \\ void serialize(const raft::resources& res, \\ - const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const cuvs::neighbors::mg_index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \\ const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \\ return idx; \\ } \\ \\ template<> \\ - cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + cuvs::neighbors::mg_index, T, IdxT> distribute( \\ const raft::resources& res, \\ const std::string& filename) \\ { \\ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \\ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 65da31bbac..758bc028eb 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -26,7 +26,6 @@ #include #include #include -#include #include @@ -53,7 +52,7 @@ using namespace raft; // local index deserialization and distribution template void deserialize_and_distribute(const raft::resources& clique, - index& index, + mg_index& index, const std::string& filename) { for (int rank = 0; rank < index.num_ranks_; rank++) { @@ -66,14 +65,14 @@ void deserialize_and_distribute(const raft::resources& clique, // MG index deserialization template void deserialize(const raft::resources& clique, - index& index, + mg_index& index, const std::string& filename) { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } const auto& handle = raft::resource::set_current_device_to_root_rank(clique); - index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); + index.mode_ = (cuvs::neighbors::distribution_mode)deserialize_scalar(handle, is); index.num_ranks_ = deserialize_scalar(handle, is); if (index.num_ranks_ != raft::resource::get_num_ranks(clique)) { @@ -93,7 +92,7 @@ void deserialize(const raft::resources& clique, template void build(const raft::resources& clique, - index& index, + mg_index& index, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { @@ -134,7 +133,7 @@ void build(const raft::resources& clique, template void extend(const raft::resources& clique, - index& index, + mg_index& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { @@ -179,7 +178,7 @@ void extend(const raft::resources& clique, template void sharded_search_with_direct_merge(const raft::resources& clique, - const index& index, + const mg_index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, @@ -311,7 +310,7 @@ void sharded_search_with_direct_merge(const raft::resources& clique, template void sharded_search_with_tree_merge(const raft::resources& clique, - const index& index, + const mg_index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, @@ -444,7 +443,7 @@ void sharded_search_with_tree_merge(const raft::resources& clique, template void run_search_batch(const raft::resources& clique, - const index& index, + const mg_index& index, int rank, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view& queries, @@ -483,7 +482,7 @@ void run_search_batch(const raft::resources& clique, template void search(const raft::resources& clique, - const index& index, + const mg_index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, raft::host_matrix_view neighbors, @@ -495,22 +494,21 @@ void search(const raft::resources& clique, int64_t n_rows_per_batch = -1; if (index.mode_ == REPLICATED) { - cuvs::neighbors::mg::replicated_search_mode search_mode; + cuvs::neighbors::replicated_search_mode search_mode; if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>( + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>( search_params); search_mode = mg_search_params->search_mode; n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>( - search_params); + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>(search_params); search_mode = mg_search_params->search_mode; n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>(search_params); + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>(search_params); search_mode = mg_search_params->search_mode; n_rows_per_batch = mg_search_params->n_rows_per_batch; } @@ -570,22 +568,21 @@ void search(const raft::resources& clique, n_neighbors); } } else if (index.mode_ == SHARDED) { - cuvs::neighbors::mg::sharded_merge_mode merge_mode; + cuvs::neighbors::sharded_merge_mode merge_mode; if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>( + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>( search_params); merge_mode = mg_search_params->merge_mode; n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>( - search_params); + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>(search_params); merge_mode = mg_search_params->merge_mode; n_rows_per_batch = mg_search_params->n_rows_per_batch; } else if constexpr (std::is_same>::value) { - const cuvs::neighbors::mg::search_params* mg_search_params = - static_cast*>(search_params); + const cuvs::neighbors::mg_search_params* mg_search_params = + static_cast*>(search_params); merge_mode = mg_search_params->merge_mode; n_rows_per_batch = mg_search_params->n_rows_per_batch; } @@ -628,7 +625,7 @@ void search(const raft::resources& clique, template void serialize(const raft::resources& clique, - const index& index, + const mg_index& index, const std::string& filename) { std::ofstream of(filename, std::ios::out | std::ios::binary); @@ -650,21 +647,22 @@ void serialize(const raft::resources& clique, } // namespace cuvs::neighbors::mg::detail -namespace cuvs::neighbors::mg { +namespace cuvs::neighbors { using namespace cuvs::neighbors; using namespace raft; template -index::index(const raft::resources& clique, distribution_mode mode) +mg_index::mg_index(const raft::resources& clique, distribution_mode mode) : mode_(mode), round_robin_counter_(std::make_shared>(0)) { num_ranks_ = raft::resource::get_num_ranks(clique); } template -index::index(const raft::resources& clique, const std::string& filename) +mg_index::mg_index(const raft::resources& clique, + const std::string& filename) : round_robin_counter_(std::make_shared>(0)) { cuvs::neighbors::mg::detail::deserialize(clique, *this, filename); } -} // namespace cuvs::neighbors::mg +} // namespace cuvs::neighbors diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index ef2ae66d51..f196b9c596 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -25,63 +25,63 @@ #include "mg.cuh" -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - namespace cuvs::neighbors::cagra { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 16c8c436ba..8cb67a96f7 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -25,63 +25,63 @@ #include "mg.cuh" -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - namespace cuvs::neighbors::cagra { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index dae594c30a..25ed90c67f 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -25,63 +25,63 @@ #include "mg.cuh" -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - namespace cuvs::neighbors::cagra { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index 2f486f5d24..43b82426be 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -25,63 +25,63 @@ #include "mg.cuh" -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - namespace cuvs::neighbors::cagra { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 3f96e68677..d6b77faf2a 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -25,71 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_FLAT(T, IdxT) \ - namespace cuvs::neighbors::ivf_flat { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::resources& res, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 108c1ce478..fb2f85b817 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -25,71 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_FLAT(T, IdxT) \ - namespace cuvs::neighbors::ivf_flat { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::resources& res, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index d180bbf95c..33fe220974 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -25,71 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_FLAT(T, IdxT) \ - namespace cuvs::neighbors::ivf_flat { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::resources& res, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index 9dee9c7762..d6c3e271f8 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -25,71 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_PQ(T, IdxT) \ - namespace cuvs::neighbors::ivf_pq { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::resources& res, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index eb9a2b834e..bd8ccd0645 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -25,71 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_PQ(T, IdxT) \ - namespace cuvs::neighbors::ivf_pq { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::resources& res, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 3bc08bd984..220b5bd01c 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -25,71 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_PQ(T, IdxT) \ - namespace cuvs::neighbors::ivf_pq { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::resources& res, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 5ce7e61f58..a63428ae39 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -25,71 +25,71 @@ #include "mg.cuh" -#define CUVS_INST_MG_PQ(T, IdxT) \ - namespace cuvs::neighbors::ivf_pq { \ - using namespace cuvs::neighbors::mg; \ - \ - cuvs::neighbors::mg::index, T, IdxT> build( \ - const raft::resources& res, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - cuvs::neighbors::mg::index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ - res, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::resources& res, \ - cuvs::neighbors::mg::index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - cuvs::neighbors::mg::detail::search( \ - res, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances); \ - } \ - \ - void serialize(const raft::resources& res, \ - const cuvs::neighbors::mg::index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> deserialize( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, filename); \ - return idx; \ - } \ - \ - template <> \ - cuvs::neighbors::mg::index, T, IdxT> distribute( \ - const raft::resources& res, const std::string& filename) \ - { \ - auto idx = cuvs::neighbors::mg::index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ - return idx; \ - } \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors; \ + \ + cuvs::neighbors::mg_index, T, IdxT> build( \ + const raft::resources& res, \ + const mg_index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ + cuvs::neighbors::mg::detail::build( \ + res, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::resources& res, \ + cuvs::neighbors::mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const mg_search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + cuvs::neighbors::mg::detail::search( \ + res, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances); \ + } \ + \ + void serialize(const raft::resources& res, \ + const cuvs::neighbors::mg_index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> deserialize( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg_index, T, IdxT> distribute( \ + const raft::resources& res, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + return idx; \ + } \ } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/tests/neighbors/mg.cuh b/cpp/tests/neighbors/mg.cuh index f632f57986..c2133393b6 100644 --- a/cpp/tests/neighbors/mg.cuh +++ b/cpp/tests/neighbors/mg.cuh @@ -104,7 +104,7 @@ class AnnMGTest : public ::testing::TestWithParam { else d_mode = distribution_mode::SHARDED; - mg::index_params index_params; + mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.adaptive_centers = ps.adaptive_centers; @@ -113,7 +113,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.metric_arg = 0; index_params.mode = d_mode; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = LOAD_BALANCER; @@ -166,7 +166,7 @@ class AnnMGTest : public ::testing::TestWithParam { else d_mode = distribution_mode::SHARDED; - mg::index_params index_params; + mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.add_data_on_build = false; @@ -174,7 +174,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.metric_arg = 0; index_params.mode = d_mode; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = LOAD_BALANCER; @@ -227,12 +227,12 @@ class AnnMGTest : public ::testing::TestWithParam { else d_mode = distribution_mode::SHARDED; - mg::index_params index_params; + mg_index_params index_params; index_params.graph_build_params = cagra::graph_build_params::ivf_pq_params( raft::matrix_extent(ps.num_db_vecs, ps.dim)); index_params.mode = d_mode; - mg::search_params search_params; + mg_search_params search_params; auto index_dataset = raft::make_host_matrix_view( h_index_dataset.data(), ps.num_db_vecs, ps.dim); @@ -282,7 +282,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = LOAD_BALANCER; @@ -331,7 +331,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = LOAD_BALANCER; @@ -377,7 +377,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.graph_build_params = cagra::graph_build_params::ivf_pq_params( raft::matrix_extent(ps.num_db_vecs, ps.dim)); - mg::search_params search_params; + mg_search_params search_params; { auto index_dataset = raft::make_device_matrix_view( @@ -420,7 +420,7 @@ class AnnMGTest : public ::testing::TestWithParam { if (ps.algo == algo_t::IVF_FLAT && ps.d_mode == d_mode_t::ROUND_ROBIN) { ASSERT_TRUE(ps.num_queries <= 4); - mg::index_params index_params; + mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.adaptive_centers = ps.adaptive_centers; @@ -429,7 +429,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.metric_arg = 0; index_params.mode = REPLICATED; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = ROUND_ROBIN; @@ -486,7 +486,7 @@ class AnnMGTest : public ::testing::TestWithParam { if (ps.algo == algo_t::IVF_PQ && ps.d_mode == d_mode_t::ROUND_ROBIN) { ASSERT_TRUE(ps.num_queries <= 4); - mg::index_params index_params; + mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.add_data_on_build = false; @@ -494,7 +494,7 @@ class AnnMGTest : public ::testing::TestWithParam { index_params.metric_arg = 0; index_params.mode = REPLICATED; - mg::search_params search_params; + mg_search_params search_params; search_params.n_probes = ps.nprobe; search_params.search_mode = ROUND_ROBIN; @@ -551,12 +551,12 @@ class AnnMGTest : public ::testing::TestWithParam { if (ps.algo == algo_t::CAGRA && ps.d_mode == d_mode_t::ROUND_ROBIN) { ASSERT_TRUE(ps.num_queries <= 4); - mg::index_params index_params; + mg_index_params index_params; index_params.graph_build_params = cagra::graph_build_params::ivf_pq_params( raft::matrix_extent(ps.num_db_vecs, ps.dim)); index_params.mode = REPLICATED; - mg::search_params search_params; + mg_search_params search_params; search_params.search_mode = ROUND_ROBIN; auto index_dataset = raft::make_host_matrix_view( diff --git a/docs/source/cpp_api/neighbors_mg.rst b/docs/source/cpp_api/neighbors_mg.rst index b68defec9f..d30e4b7136 100644 --- a/docs/source/cpp_api/neighbors_mg.rst +++ b/docs/source/cpp_api/neighbors_mg.rst @@ -7,7 +7,7 @@ The SNMG (single-node multi-GPUs) ANN API provides a set of functions to deploy :language: c++ :class: highlight -``#include `` +``#include `` namespace *cuvs::neighbors::mg* From 39bda6b5bdae1458c284a43f926f674529582694 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 6 Mar 2025 13:51:41 +0100 Subject: [PATCH 13/17] Update doc --- cpp/include/cuvs/neighbors/cagra.hpp | 36 ++++++++++++------------- cpp/include/cuvs/neighbors/ivf_flat.hpp | 28 +++++++++---------- cpp/include/cuvs/neighbors/ivf_pq.hpp | 36 ++++++++++++------------- docs/source/cpp_api/neighbors_mg.rst | 4 +-- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 0a47136264..bb4d90aa35 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -1997,7 +1997,7 @@ auto merge(raft::resources const& res, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode @@ -2019,7 +2019,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode @@ -2041,7 +2041,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode @@ -2063,7 +2063,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * @endcode @@ -2087,7 +2087,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); @@ -2111,7 +2111,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); @@ -2135,7 +2135,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); @@ -2159,7 +2159,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); @@ -2185,7 +2185,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -2214,7 +2214,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -2243,7 +2243,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -2273,7 +2273,7 @@ void search( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -2305,7 +2305,7 @@ void search( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2328,7 +2328,7 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2350,7 +2350,7 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2373,7 +2373,7 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2398,7 +2398,7 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2424,7 +2424,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::cagra::index_params index_params; * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index fd96729755..bba6436dbb 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -1606,7 +1606,7 @@ void deserialize(raft::resources const& handle, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode @@ -1628,7 +1628,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode @@ -1650,7 +1650,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * @endcode @@ -1674,7 +1674,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); @@ -1698,7 +1698,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); @@ -1722,7 +1722,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); @@ -1748,7 +1748,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -1777,7 +1777,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -1807,7 +1807,7 @@ void search( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -1839,7 +1839,7 @@ void search( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -1862,7 +1862,7 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -1885,7 +1885,7 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -1908,7 +1908,7 @@ void serialize( * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -1934,7 +1934,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::ivf_flat::index_params index_params; * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 4347a64a0b..73f9a831bc 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -1736,7 +1736,7 @@ void deserialize(raft::resources const& handle, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode @@ -1758,7 +1758,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode @@ -1780,7 +1780,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode @@ -1802,7 +1802,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * @endcode @@ -1826,7 +1826,7 @@ auto build(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); @@ -1850,7 +1850,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); @@ -1874,7 +1874,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); @@ -1898,7 +1898,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); @@ -1924,7 +1924,7 @@ void extend(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -1953,7 +1953,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -1982,7 +1982,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -2011,7 +2011,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * cuvs::neighbors::mg_search_params search_params; @@ -2042,7 +2042,7 @@ void search(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2064,7 +2064,7 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2086,7 +2086,7 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2108,7 +2108,7 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2130,7 +2130,7 @@ void serialize(const raft::resources& clique, * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg_index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; @@ -2155,7 +2155,7 @@ auto deserialize(const raft::resources& clique, const std::string& filename) * * Usage example: * @code{.cpp} - * raft::resources clique; + * raft::device_resources_snmg clique; * cuvs::neighbors::ivf_pq::index_params index_params; * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; diff --git a/docs/source/cpp_api/neighbors_mg.rst b/docs/source/cpp_api/neighbors_mg.rst index d30e4b7136..6192c4f9d6 100644 --- a/docs/source/cpp_api/neighbors_mg.rst +++ b/docs/source/cpp_api/neighbors_mg.rst @@ -9,7 +9,7 @@ The SNMG (single-node multi-GPUs) ANN API provides a set of functions to deploy ``#include `` -namespace *cuvs::neighbors::mg* +namespace *cuvs::neighbors* Index build parameters ---------------------- @@ -20,7 +20,7 @@ Index build parameters :content-only: Search parameters ----------------------- +----------------- .. doxygengroup:: mg_cpp_search_params :project: cuvs From 028408b9b0358491e8d7961996278103a290a048 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 2 Apr 2025 17:23:56 +0200 Subject: [PATCH 14/17] Answering review --- cpp/include/cuvs/neighbors/common.hpp | 5 ---- cpp/src/neighbors/mg/generate_mg.py | 30 +++++++++---------- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 10 +++---- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 10 +++---- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 10 +++---- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 10 +++---- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 12 ++++---- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 12 ++++---- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 12 ++++---- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 12 ++++---- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 12 ++++---- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 12 ++++---- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 12 ++++---- cpp/src/neighbors/mg/omp_checks.cpp | 4 +-- cpp/src/neighbors/mg/{mg.cuh => snmg.cuh} | 26 ++++++++-------- 15 files changed, 92 insertions(+), 97 deletions(-) rename cpp/src/neighbors/mg/{mg.cuh => snmg.cuh} (97%) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 6ddcebe961..4c31987989 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -740,9 +740,7 @@ enable_if_valid_list_t deserialize_list(const raft::resources& handle, const typename ListT::spec_type& store_spec, const typename ListT::spec_type& device_spec); } // namespace ivf -} // namespace cuvs::neighbors -namespace cuvs::neighbors { using namespace raft; template @@ -792,11 +790,8 @@ void deserialize(const raft::resources& handle, cuvs::neighbors::iface& interface, const std::string& filename); -}; // namespace cuvs::neighbors - /// \defgroup mg_cpp_index_params ANN MG index build parameters -namespace cuvs::neighbors { /** Distribution mode */ /// \ingroup mg_cpp_index_params enum distribution_mode { diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index a3f3725c5f..49d4609fcd 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -40,7 +40,7 @@ """ include_macro = """ -#include "mg.cuh" +#include "snmg.cuh" """ flat_macro = """ @@ -54,7 +54,7 @@ raft::host_matrix_view index_dataset) \\ { \\ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ - cuvs::neighbors::mg::detail::build(res, index, \\ + cuvs::neighbors::snmg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ @@ -65,7 +65,7 @@ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \\ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \\ } \\ \\ void search(const raft::resources& res, \\ @@ -75,7 +75,7 @@ raft::host_matrix_view neighbors, \\ raft::host_matrix_view distances) \\ { \\ - cuvs::neighbors::mg::detail::search(res, index, \\ + cuvs::neighbors::snmg::detail::search(res, index, \\ static_cast(&search_params), \\ queries, neighbors, distances); \\ } \\ @@ -84,7 +84,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ @@ -102,7 +102,7 @@ const std::string& filename) \\ { \\ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ } // namespace cuvs::neighbors::ivf_flat @@ -119,7 +119,7 @@ raft::host_matrix_view index_dataset) \\ { \\ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ - cuvs::neighbors::mg::detail::build(res, index, \\ + cuvs::neighbors::snmg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ @@ -130,7 +130,7 @@ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \\ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \\ } \\ \\ void search(const raft::resources& res, \\ @@ -140,7 +140,7 @@ raft::host_matrix_view neighbors, \\ raft::host_matrix_view distances) \\ { \\ - cuvs::neighbors::mg::detail::search(res, index, \\ + cuvs::neighbors::snmg::detail::search(res, index, \\ static_cast(&search_params), \\ queries, neighbors, distances); \\ } \\ @@ -149,7 +149,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ @@ -167,7 +167,7 @@ const std::string& filename) \\ { \\ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ } // namespace cuvs::neighbors::ivf_pq @@ -184,7 +184,7 @@ raft::host_matrix_view index_dataset) \\ { \\ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \\ - cuvs::neighbors::mg::detail::build(res, index, \\ + cuvs::neighbors::snmg::detail::build(res, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ @@ -197,7 +197,7 @@ raft::host_matrix_view neighbors, \\ raft::host_matrix_view distances) \\ { \\ - cuvs::neighbors::mg::detail::search(res, index, \\ + cuvs::neighbors::snmg::detail::search(res, index, \\ static_cast(&search_params), \\ queries, neighbors, distances); \\ } \\ @@ -206,7 +206,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \\ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \\ } \\ \\ template<> \\ @@ -224,7 +224,7 @@ const std::string& filename) \\ { \\ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \\ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \\ return idx; \\ } \\ } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index f196b9c596..3260df7050 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_CAGRA(T, IdxT) \ namespace cuvs::neighbors::cagra { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -50,7 +50,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -63,7 +63,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -79,7 +79,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 8cb67a96f7..da8cbb1e49 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_CAGRA(T, IdxT) \ namespace cuvs::neighbors::cagra { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -50,7 +50,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -63,7 +63,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -79,7 +79,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 25ed90c67f..5aae599e22 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_CAGRA(T, IdxT) \ namespace cuvs::neighbors::cagra { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -50,7 +50,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -63,7 +63,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -79,7 +79,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index 43b82426be..8384a6a972 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_CAGRA(T, IdxT) \ namespace cuvs::neighbors::cagra { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -50,7 +50,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -63,7 +63,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -79,7 +79,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index d6b77faf2a..1948ab8c8a 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_FLAT(T, IdxT) \ namespace cuvs::neighbors::ivf_flat { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -48,7 +48,7 @@ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ void search(const raft::resources& res, \ @@ -58,7 +58,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -71,7 +71,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -87,7 +87,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index fb2f85b817..1e0c9fddda 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_FLAT(T, IdxT) \ namespace cuvs::neighbors::ivf_flat { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -48,7 +48,7 @@ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ void search(const raft::resources& res, \ @@ -58,7 +58,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -71,7 +71,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -87,7 +87,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index 33fe220974..cc3945a05b 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_FLAT(T, IdxT) \ namespace cuvs::neighbors::ivf_flat { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -48,7 +48,7 @@ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ void search(const raft::resources& res, \ @@ -58,7 +58,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -71,7 +71,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -87,7 +87,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index d6c3e271f8..d25cd3656d 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_PQ(T, IdxT) \ namespace cuvs::neighbors::ivf_pq { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -48,7 +48,7 @@ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ void search(const raft::resources& res, \ @@ -58,7 +58,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -71,7 +71,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -87,7 +87,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index bd8ccd0645..f9808df864 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_PQ(T, IdxT) \ namespace cuvs::neighbors::ivf_pq { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -48,7 +48,7 @@ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ void search(const raft::resources& res, \ @@ -58,7 +58,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -71,7 +71,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -87,7 +87,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 220b5bd01c..f2205626d8 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_PQ(T, IdxT) \ namespace cuvs::neighbors::ivf_pq { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -48,7 +48,7 @@ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ void search(const raft::resources& res, \ @@ -58,7 +58,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -71,7 +71,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -87,7 +87,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index a63428ae39..1ec306a34a 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -23,7 +23,7 @@ * */ -#include "mg.cuh" +#include "snmg.cuh" #define CUVS_INST_MG_PQ(T, IdxT) \ namespace cuvs::neighbors::ivf_pq { \ @@ -35,7 +35,7 @@ raft::host_matrix_view index_dataset) \ { \ cuvs::neighbors::mg_index, T, IdxT> index(res, index_params.mode); \ - cuvs::neighbors::mg::detail::build( \ + cuvs::neighbors::snmg::detail::build( \ res, \ index, \ static_cast(&index_params), \ @@ -48,7 +48,7 @@ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(res, index, new_vectors, new_indices); \ + cuvs::neighbors::snmg::detail::extend(res, index, new_vectors, new_indices); \ } \ \ void search(const raft::resources& res, \ @@ -58,7 +58,7 @@ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ - cuvs::neighbors::mg::detail::search( \ + cuvs::neighbors::snmg::detail::search( \ res, \ index, \ static_cast(&search_params), \ @@ -71,7 +71,7 @@ const cuvs::neighbors::mg_index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(res, index, filename); \ + cuvs::neighbors::snmg::detail::serialize(res, index, filename); \ } \ \ template <> \ @@ -87,7 +87,7 @@ const raft::resources& res, const std::string& filename) \ { \ auto idx = cuvs::neighbors::mg_index, T, IdxT>(res, REPLICATED); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(res, idx, filename); \ + cuvs::neighbors::snmg::detail::deserialize_and_distribute(res, idx, filename); \ return idx; \ } \ } // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/mg/omp_checks.cpp b/cpp/src/neighbors/mg/omp_checks.cpp index c8cc274141..91a23c8f32 100644 --- a/cpp/src/neighbors/mg/omp_checks.cpp +++ b/cpp/src/neighbors/mg/omp_checks.cpp @@ -17,7 +17,7 @@ #include #include -namespace cuvs::neighbors::mg { +namespace cuvs::neighbors::snmg { void check_omp_threads(const int requirements) { @@ -30,4 +30,4 @@ void check_omp_threads(const int requirements) requirements); } -} // namespace cuvs::neighbors::mg +} // namespace cuvs::neighbors::snmg diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/snmg.cuh similarity index 97% rename from cpp/src/neighbors/mg/mg.cuh rename to cpp/src/neighbors/mg/snmg.cuh index 758bc028eb..3d9e795adc 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/snmg.cuh @@ -41,11 +41,11 @@ void search(const raft::resources& handle, raft::device_matrix_view d_distances); } // namespace cuvs::neighbors -namespace cuvs::neighbors::mg { +namespace cuvs::neighbors::snmg { void check_omp_threads(const int requirements); -} // namespace cuvs::neighbors::mg +} // namespace cuvs::neighbors::snmg -namespace cuvs::neighbors::mg::detail { +namespace cuvs::neighbors::snmg::detail { using namespace cuvs::neighbors; using namespace raft; @@ -75,10 +75,10 @@ void deserialize(const raft::resources& clique, index.mode_ = (cuvs::neighbors::distribution_mode)deserialize_scalar(handle, is); index.num_ranks_ = deserialize_scalar(handle, is); - if (index.num_ranks_ != raft::resource::get_num_ranks(clique)) { + if (index.num_ranks_ != raft::resource::get_nccl_num_ranks(clique)) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", index.num_ranks_, - raft::resource::get_num_ranks(clique)); + raft::resource::get_nccl_num_ranks(clique)); } for (int rank = 0; rank < index.num_ranks_; rank++) { @@ -215,8 +215,8 @@ void sharded_search_with_direct_merge(const raft::resources& clique, const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank); auto& ann_if = index.ann_interfaces_[rank]; - if (rank == raft::resource::get_clique_root_rank(clique)) { // root rank - uint64_t batch_offset = raft::resource::get_clique_root_rank(clique) * part_size; + if (rank == raft::resource::get_nccl_clique_root_rank(clique)) { // root rank + uint64_t batch_offset = raft::resource::get_nccl_clique_root_rank(clique) * part_size; auto d_neighbors = raft::make_device_matrix_view( in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); auto d_distances = raft::make_device_matrix_view( @@ -227,7 +227,7 @@ void sharded_search_with_direct_merge(const raft::resources& clique, // wait for other ranks ncclGroupStart(); for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) { - if (from_rank == raft::resource::get_clique_root_rank(clique)) continue; + if (from_rank == raft::resource::get_nccl_clique_root_rank(clique)) continue; batch_offset = from_rank * part_size; ncclRecv(in_neighbors.data_handle() + batch_offset, @@ -258,13 +258,13 @@ void sharded_search_with_direct_merge(const raft::resources& clique, ncclSend(d_neighbors.data_handle(), part_size * sizeof(IdxT), ncclUint8, - raft::resource::get_clique_root_rank(clique), + raft::resource::get_nccl_clique_root_rank(clique), raft::resource::get_nccl_comm(dev_res), raft::resource::get_cuda_stream(dev_res)); ncclSend(d_distances.data_handle(), part_size * sizeof(float), ncclUint8, - raft::resource::get_clique_root_rank(clique), + raft::resource::get_nccl_clique_root_rank(clique), raft::resource::get_nccl_comm(dev_res), raft::resource::get_cuda_stream(dev_res)); ncclGroupEnd(); @@ -645,7 +645,7 @@ void serialize(const raft::resources& clique, if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } -} // namespace cuvs::neighbors::mg::detail +} // namespace cuvs::neighbors::snmg::detail namespace cuvs::neighbors { using namespace cuvs::neighbors; @@ -655,7 +655,7 @@ template mg_index::mg_index(const raft::resources& clique, distribution_mode mode) : mode_(mode), round_robin_counter_(std::make_shared>(0)) { - num_ranks_ = raft::resource::get_num_ranks(clique); + num_ranks_ = raft::resource::get_nccl_num_ranks(clique); } template @@ -663,6 +663,6 @@ mg_index::mg_index(const raft::resources& clique, const std::string& filename) : round_robin_counter_(std::make_shared>(0)) { - cuvs::neighbors::mg::detail::deserialize(clique, *this, filename); + cuvs::neighbors::snmg::detail::deserialize(clique, *this, filename); } } // namespace cuvs::neighbors From 4e0f512a31b5abef3267f35b7c60d5291924ac38 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 2 Apr 2025 18:41:12 +0200 Subject: [PATCH 15/17] Fix sync issue in iface --- cpp/src/neighbors/iface/iface.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cpp/src/neighbors/iface/iface.hpp b/cpp/src/neighbors/iface/iface.hpp index 8ccc0d54b9..4f7a6f5cc9 100644 --- a/cpp/src/neighbors/iface/iface.hpp +++ b/cpp/src/neighbors/iface/iface.hpp @@ -150,6 +150,8 @@ void serialize(const raft::resources& handle, } else if constexpr (std::is_same>::value) { cagra::serialize(handle, os, interface.index_.value(), true); } + + resource::sync_stream(handle); } template @@ -162,14 +164,17 @@ void deserialize(const raft::resources& handle, if constexpr (std::is_same>::value) { ivf_flat::index idx(handle); ivf_flat::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { ivf_pq::index idx(handle); ivf_pq::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { cagra::index idx(handle); cagra::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } } @@ -187,14 +192,17 @@ void deserialize(const raft::resources& handle, if constexpr (std::is_same>::value) { ivf_flat::index idx(handle); ivf_flat::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { ivf_pq::index idx(handle); ivf_pq::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { cagra::index idx(handle); cagra::deserialize(handle, is, &idx); + resource::sync_stream(handle); interface.index_.emplace(std::move(idx)); } From 1e1d97e83e1595a77a726021b4b1265778720080 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 23 Apr 2025 13:27:09 +0200 Subject: [PATCH 16/17] temporary pin to RAFT compilation fix --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index a71dc4cb43..ca519215b5 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -14,7 +14,7 @@ # Use RAPIDS_VERSION_MAJOR_MINOR from rapids_config.cmake set(RAFT_VERSION "${RAPIDS_VERSION_MAJOR_MINOR}") set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION_MAJOR_MINOR}") +set(RAFT_PINNED_TAG "fix-type-qualifier-specified-more-than-once") function(find_and_configure_raft) set(oneValueArgs VERSION FORK PINNED_TAG USE_RAFT_STATIC ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES CLONE_ON_PIN) From 34d67ee598b07cf60bdc7b0be988cb3844dbdb5b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 23 Apr 2025 11:34:01 -0400 Subject: [PATCH 17/17] Update cpp/cmake/thirdparty/get_raft.cmake Co-authored-by: Artem M. Chirkin <9253178+achirkin@users.noreply.github.com> --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index ca519215b5..a71dc4cb43 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -14,7 +14,7 @@ # Use RAPIDS_VERSION_MAJOR_MINOR from rapids_config.cmake set(RAFT_VERSION "${RAPIDS_VERSION_MAJOR_MINOR}") set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "fix-type-qualifier-specified-more-than-once") +set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION_MAJOR_MINOR}") function(find_and_configure_raft) set(oneValueArgs VERSION FORK PINNED_TAG USE_RAFT_STATIC ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES CLONE_ON_PIN)