From 6a2311c037529679821585de6c94588238b99c9b Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 9 Apr 2026 09:49:07 -0700 Subject: [PATCH 1/4] Initial exposition of KMeans for PQ Signed-off-by: Mickael Ide --- c/src/preprocessing/quantize/pq.cpp | 31 +++--- .../cuvs/preprocessing/quantize/pq.hpp | 22 +++-- cpp/src/neighbors/detail/vpq_dataset.cuh | 94 ++++++++++--------- .../neighbors/scann/detail/scann_build.cuh | 5 +- cpp/src/preprocessing/quantize/detail/pq.cuh | 60 ++++++++++-- .../preprocessing/product_quantization.cu | 12 ++- 6 files changed, 149 insertions(+), 75 deletions(-) diff --git a/c/src/preprocessing/quantize/pq.cpp b/c/src/preprocessing/quantize/pq.cpp index 35a8f5ed72..c7381cdcbd 100644 --- a/c/src/preprocessing/quantize/pq.cpp +++ b/c/src/preprocessing/quantize/pq.cpp @@ -63,17 +63,26 @@ void* _build(cuvsResources_t res, auto res_ptr = reinterpret_cast(res); - auto quantizer_params = cuvs::preprocessing::quantize::pq::params{ - .pq_bits = params->pq_bits, - .pq_dim = params->pq_dim, - .use_subspaces = params->use_subspaces, - .use_vq = params->use_vq, - .vq_n_centers = params->vq_n_centers, - .kmeans_n_iters = params->kmeans_n_iters, - .pq_kmeans_type = static_cast(params->pq_kmeans_type), - .max_train_points_per_pq_code = params->max_train_points_per_pq_code, - .max_train_points_per_vq_cluster = params->max_train_points_per_vq_cluster - }; + cuvs::preprocessing::quantize::pq::kmeans_params_variant kmeans_params_var; + if (params->pq_kmeans_type == CUVS_KMEANS_TYPE_KMEANS_BALANCED) { + cuvs::cluster::kmeans::balanced_params bp; + bp.n_iters = params->kmeans_n_iters; + kmeans_params_var = bp; + } else { + cuvs::cluster::kmeans::params classic_params; + classic_params.max_iter = params->kmeans_n_iters; + kmeans_params_var = classic_params; + } + + cuvs::preprocessing::quantize::pq::params quantizer_params; + quantizer_params.pq_bits = params->pq_bits; + quantizer_params.pq_dim = params->pq_dim; + quantizer_params.use_subspaces = params->use_subspaces; + quantizer_params.use_vq = params->use_vq; + quantizer_params.vq_n_centers = params->vq_n_centers; + quantizer_params.kmeans_params = kmeans_params_var; + quantizer_params.max_train_points_per_pq_code = params->max_train_points_per_pq_code; + quantizer_params.max_train_points_per_vq_cluster = params->max_train_points_per_vq_cluster; cuvs::preprocessing::quantize::pq::quantizer* ret = nullptr; if (cuvs::core::is_dlpack_device_compatible(dataset)) { diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 043bd92534..bfd79cdcc2 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -5,11 +5,14 @@ #pragma once +#include #include #include #include #include +#include + namespace cuvs::preprocessing::quantize::pq { /** @@ -17,6 +20,10 @@ namespace cuvs::preprocessing::quantize::pq { * @{ */ +/** Alias for the variant holding either balanced or regular k-means parameters. */ +using kmeans_params_variant = + std::variant; + /** * @brief Product Quantizer parameters. */ @@ -53,16 +60,15 @@ struct params { * When zero, an optimal value is selected using a heuristic. */ uint32_t vq_n_centers = 0; - /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ - uint32_t kmeans_n_iters = 25; /** - * Type of k-means algorithm for PQ training. - * Balanced k-means tends to be faster than regular k-means for PQ training, for - * problem sets where the number of points per cluster are approximately equal. - * Regular k-means may be better for skewed cluster distributions. + * K-means parameters for PQ codebook training. + * + * Set to cuvs::cluster::kmeans::balanced_params for balanced k-means (default), + * or cuvs::cluster::kmeans::params for regular k-means. + * The active variant type selects the algorithm; balanced k-means tends to be faster + * for PQ training where cluster sizes are approximately equal. */ - cuvs::cluster::kmeans::kmeans_type pq_kmeans_type = - cuvs::cluster::kmeans::kmeans_type::KMeansBalanced; + kmeans_params_variant kmeans_params = cuvs::cluster::kmeans::balanced_params{}; /** * The max number of data points to use per PQ code during PQ codebook training. Using more data * points per PQ code may increase the quality of PQ codebook but may also increase the build diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index cbe06f5ca4..1b60e64e0a 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -5,6 +5,7 @@ #pragma once #include +#include #include "../../cluster/kmeans_balanced.cuh" #include "../../preprocessing/quantize/detail/pq_codepacking.cuh" // pq_bits-bitfield @@ -74,50 +75,51 @@ namespace cuvs::neighbors::detail { template void train_pq_centers( const raft::resources& res, - const cuvs::neighbors::vpq_params& params, + const cuvs::preprocessing::quantize::pq::kmeans_params_variant& kmeans_params, const raft::device_matrix_view pq_trainset_view, const raft::device_matrix_view pq_centers_view, raft::device_vector_view sub_labels_view, raft::device_vector_view pq_cluster_sizes_view) { - if (params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) { - cuvs::cluster::kmeans::balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; - - cuvs::cluster::kmeans_balanced::helpers::build_clusters< - MathT, - MathT, - IdxT, - uint32_t, - uint32_t, - cuvs::spatial::knn::detail::utils::mapping>( - res, - kmeans_params, - pq_trainset_view, - pq_centers_view, - sub_labels_view, - pq_cluster_sizes_view, - cuvs::spatial::knn::detail::utils::mapping{}); - } else { - const auto pq_n_centers = pq_centers_view.extent(0); - cuvs::cluster::kmeans::params kmeans_params; - kmeans_params.n_clusters = pq_n_centers; - kmeans_params.max_iter = params.kmeans_n_iters; - kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; - kmeans_params.init = cuvs::cluster::kmeans::params::InitMethod::Random; - - std::optional> sample_weight = std::nullopt; - MathT inertia; - IdxT n_iter; - cuvs::cluster::kmeans::fit(res, - kmeans_params, - pq_trainset_view, - sample_weight, - pq_centers_view, - raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter)); - } + std::visit( + [&](auto const& base_kmeans_params) { + using KP = std::decay_t; + if constexpr (std::is_same_v) { + auto bal_params = base_kmeans_params; + bal_params.metric = cuvs::distance::DistanceType::L2Expanded; + cuvs::cluster::kmeans_balanced::helpers::build_clusters< + MathT, + MathT, + IdxT, + uint32_t, + uint32_t, + cuvs::spatial::knn::detail::utils::mapping>( + res, + bal_params, + pq_trainset_view, + pq_centers_view, + sub_labels_view, + pq_cluster_sizes_view, + cuvs::spatial::knn::detail::utils::mapping{}); + } else { + auto classic_params = base_kmeans_params; + classic_params.n_clusters = pq_centers_view.extent(0); + classic_params.metric = cuvs::distance::DistanceType::L2Expanded; + RAFT_EXPECTS(classic_params.init != cuvs::cluster::kmeans::params::InitMethod::Array, + "Array initialization is not supported for PQ training"); + std::optional> sample_weight = std::nullopt; + MathT inertia; + IdxT n_iter; + cuvs::cluster::kmeans::fit(res, + classic_params, + pq_trainset_view, + sample_weight, + pq_centers_view, + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + } + }, + kmeans_params); } template @@ -219,7 +221,7 @@ auto predict_vq(const raft::resources& res, template auto train_pq(const raft::resources& res, - const vpq_params& params, + const cuvs::preprocessing::quantize::pq::params params, const DatasetT& dataset, const raft::device_matrix_view vq_centers) -> raft::device_matrix @@ -230,8 +232,8 @@ auto train_pq(const raft::resources& res, const ix_t pq_bits = params.pq_bits; const ix_t pq_n_centers = ix_t{1} << pq_bits; const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); - const ix_t n_rows_train = std::min((ix_t)(n_rows * params.pq_kmeans_trainset_fraction), - params.max_train_points_per_pq_code * pq_n_centers); + const ix_t n_rows_train = + std::min(n_rows, params.max_train_points_per_pq_code * pq_n_centers); RAFT_EXPECTS( n_rows_train >= pq_n_centers, "The number of training samples must be greater than or equal to the number of PQ centers"); @@ -261,8 +263,12 @@ auto train_pq(const raft::resources& res, pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len); auto sub_labels = raft::make_device_vector(res, pq_trainset_view.extent(0)); auto pq_cluster_sizes = raft::make_device_vector(res, pq_centers.extent(0)); - train_pq_centers( - res, params, pq_trainset_view, pq_centers.view(), sub_labels.view(), pq_cluster_sizes.view()); + train_pq_centers(res, + params.kmeans_params, + pq_trainset_view, + pq_centers.view(), + sub_labels.view(), + pq_cluster_sizes.view()); return pq_centers; } diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 37f74f29bb..c4eae0ac9c 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -161,13 +161,14 @@ index build( int num_clusters = 1 << params.pq_bits; cuvs::preprocessing::quantize::pq::params pq_build_params; + cuvs::cluster::kmeans::balanced_params pq_bp; + pq_bp.n_iters = params.pq_train_iters; pq_build_params.pq_bits = params.pq_bits; pq_build_params.pq_dim = num_subspaces; pq_build_params.use_subspaces = true; pq_build_params.use_vq = false; // We already computed residuals - pq_build_params.kmeans_n_iters = params.pq_train_iters; + pq_build_params.kmeans_params = pq_bp; pq_build_params.max_train_points_per_pq_code = pq_n_rows_train / num_clusters; - pq_build_params.pq_kmeans_type = cuvs::cluster::kmeans::kmeans_type::KMeansBalanced; auto pq_quantizer = cuvs::preprocessing::quantize::pq::build( res, pq_build_params, raft::make_const_mdspan(trainset_residuals.view())); diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 77fb0ac4f9..8ad8068246 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -29,17 +29,38 @@ inline void fill_missing_params_heuristics(cuvs::preprocessing::quantize::pq::pa } } +inline bool is_balanced_kmeans(const cuvs::preprocessing::quantize::pq::params& params) +{ + return std::holds_alternative(params.kmeans_params); +} + +inline uint32_t get_kmeans_n_iters(const cuvs::preprocessing::quantize::pq::params& params) +{ + return std::visit( + [](const auto& kp) -> uint32_t { + using kmeans_type = std::decay_t; + if constexpr (std::is_same_v) { + return kp.n_iters; + } else { + return static_cast(kp.max_iter); + } + }, + params.kmeans_params); +} + inline auto to_vpq_params(const cuvs::preprocessing::quantize::pq::params& params) -> cuvs::neighbors::vpq_params { + auto kmeans_type = is_balanced_kmeans(params) ? cuvs::cluster::kmeans::kmeans_type::KMeansBalanced + : cuvs::cluster::kmeans::kmeans_type::KMeans; return cuvs::neighbors::vpq_params{ .pq_bits = params.pq_bits, .pq_dim = params.pq_dim, .vq_n_centers = params.vq_n_centers, - .kmeans_n_iters = params.kmeans_n_iters, + .kmeans_n_iters = get_kmeans_n_iters(params), .vq_kmeans_trainset_fraction = 1.0, .pq_kmeans_trainset_fraction = 1.0, - .pq_kmeans_type = params.pq_kmeans_type, + .pq_kmeans_type = kmeans_type, .max_train_points_per_pq_code = params.max_train_points_per_pq_code, .max_train_points_per_vq_cluster = params.max_train_points_per_vq_cluster}; } @@ -92,7 +113,7 @@ auto train_pq_subspaces( auto sub_labels = raft::make_device_vector(res, 0); auto pq_cluster_sizes = raft::make_device_vector(res, 0); auto device_memory = raft::resource::get_workspace_resource(res); - if (params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) { + if (is_balanced_kmeans(params)) { sub_labels = raft::make_device_mdarray( res, device_memory, raft::make_extents(n_rows_train)); pq_cluster_sizes = raft::make_device_mdarray( @@ -111,7 +132,7 @@ auto train_pq_subspaces( pq_centers.data_handle() + m * pq_n_centers * pq_len, pq_n_centers, pq_len); cuvs::neighbors::detail::train_pq_centers( res, - to_vpq_params(params), + params.kmeans_params, raft::make_const_mdspan(sub_dataset.view()), pq_centers_subspace_view, sub_labels.view(), @@ -152,7 +173,7 @@ quantizer build( res, filled_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } else { pq_code_book = cuvs::neighbors::detail::train_pq( - res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); + res, filled_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } return {filled_params, cuvs::neighbors::vpq_dataset{ @@ -344,6 +365,32 @@ void vpq_convert_math_type(const raft::resources& res, raft::make_const_mdspan(src.pq_code_book.view())); } +inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params) + -> cuvs::preprocessing::quantize::pq::params +{ + std::variant kmeans_params; + if (in_params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) { + kmeans_params = cuvs::cluster::kmeans::balanced_params{ + .n_iters = in_params.kmeans_n_iters, + }; + } else { + kmeans_params = cuvs::cluster::kmeans::params{ + .max_iter = (int)in_params.kmeans_n_iters, + }; + } + const uint32_t pq_n_centers = 1 << in_params.pq_bits; + return cuvs::preprocessing::quantize::pq::params{ + .pq_bits = in_params.pq_bits, + .pq_dim = in_params.pq_dim, + .vq_n_centers = in_params.vq_n_centers, + .kmeans_params = kmeans_params, + .max_train_points_per_pq_code = std::min( + in_params.max_train_points_per_pq_code, pq_n_centers * in_params.pq_kmeans_trainset_fraction), + .max_train_points_per_vq_cluster = + std::min(in_params.max_train_points_per_vq_cluster, + in_params.vq_n_centers * in_params.vq_kmeans_trainset_fraction)}; +} + template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, @@ -353,10 +400,11 @@ auto vpq_build(const raft::resources& res, // Use a heuristic to impute missing parameters. auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); + auto pq_params = from_vpq_params(ps); // Train codes auto vq_code_book = cuvs::neighbors::detail::train_vq(res, ps, dataset); auto pq_code_book = cuvs::neighbors::detail::train_pq( - res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + res, pq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); // Encode dataset const IdxT n_rows = dataset.extent(0); diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index f1b84854f1..aec590278b 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -160,14 +160,18 @@ class ProductQuantizationTest : public ::testing::TestWithParam Date: Fri, 10 Apr 2026 07:12:27 -0700 Subject: [PATCH 2/4] Fix trainset fraction Signed-off-by: Mickael Ide --- cpp/src/preprocessing/quantize/detail/pq.cuh | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 8ad8068246..cb3b3f7c09 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -365,7 +365,7 @@ void vpq_convert_math_type(const raft::resources& res, raft::make_const_mdspan(src.pq_code_book.view())); } -inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params) +inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params, const uint64_t n_rows) -> cuvs::preprocessing::quantize::pq::params { std::variant kmeans_params; @@ -380,15 +380,16 @@ inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params) } const uint32_t pq_n_centers = 1 << in_params.pq_bits; return cuvs::preprocessing::quantize::pq::params{ - .pq_bits = in_params.pq_bits, - .pq_dim = in_params.pq_dim, - .vq_n_centers = in_params.vq_n_centers, - .kmeans_params = kmeans_params, - .max_train_points_per_pq_code = std::min( - in_params.max_train_points_per_pq_code, pq_n_centers * in_params.pq_kmeans_trainset_fraction), + .pq_bits = in_params.pq_bits, + .pq_dim = in_params.pq_dim, + .vq_n_centers = in_params.vq_n_centers, + .kmeans_params = kmeans_params, + .max_train_points_per_pq_code = + std::min(in_params.max_train_points_per_pq_code, + n_rows * in_params.pq_kmeans_trainset_fraction / pq_n_centers), .max_train_points_per_vq_cluster = std::min(in_params.max_train_points_per_vq_cluster, - in_params.vq_n_centers * in_params.vq_kmeans_trainset_fraction)}; + n_rows * in_params.vq_kmeans_trainset_fraction / in_params.vq_n_centers)}; } template @@ -400,7 +401,7 @@ auto vpq_build(const raft::resources& res, // Use a heuristic to impute missing parameters. auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); - auto pq_params = from_vpq_params(ps); + auto pq_params = from_vpq_params(ps, dataset.extent(0)); // Train codes auto vq_code_book = cuvs::neighbors::detail::train_vq(res, ps, dataset); auto pq_code_book = cuvs::neighbors::detail::train_pq( From 3e66e9702e2298efb22d64c68e275633edf96136 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 14 Apr 2026 10:57:36 -0700 Subject: [PATCH 3/4] Add simplified constructor Signed-off-by: Mickael Ide --- c/src/preprocessing/quantize/pq.cpp | 25 ++-------- .../cuvs/preprocessing/quantize/pq.hpp | 48 +++++++++++++++++++ .../neighbors/scann/detail/scann_build.cuh | 19 ++++---- cpp/src/preprocessing/quantize/detail/pq.cuh | 31 +++++------- .../preprocessing/product_quantization.cu | 5 +- 5 files changed, 77 insertions(+), 51 deletions(-) diff --git a/c/src/preprocessing/quantize/pq.cpp b/c/src/preprocessing/quantize/pq.cpp index c7381cdcbd..1e3a48694a 100644 --- a/c/src/preprocessing/quantize/pq.cpp +++ b/c/src/preprocessing/quantize/pq.cpp @@ -62,27 +62,10 @@ void* _build(cuvsResources_t res, auto dataset = dataset_tensor->dl_tensor; auto res_ptr = reinterpret_cast(res); - - cuvs::preprocessing::quantize::pq::kmeans_params_variant kmeans_params_var; - if (params->pq_kmeans_type == CUVS_KMEANS_TYPE_KMEANS_BALANCED) { - cuvs::cluster::kmeans::balanced_params bp; - bp.n_iters = params->kmeans_n_iters; - kmeans_params_var = bp; - } else { - cuvs::cluster::kmeans::params classic_params; - classic_params.max_iter = params->kmeans_n_iters; - kmeans_params_var = classic_params; - } - - cuvs::preprocessing::quantize::pq::params quantizer_params; - quantizer_params.pq_bits = params->pq_bits; - quantizer_params.pq_dim = params->pq_dim; - quantizer_params.use_subspaces = params->use_subspaces; - quantizer_params.use_vq = params->use_vq; - quantizer_params.vq_n_centers = params->vq_n_centers; - quantizer_params.kmeans_params = kmeans_params_var; - quantizer_params.max_train_points_per_pq_code = params->max_train_points_per_pq_code; - quantizer_params.max_train_points_per_vq_cluster = params->max_train_points_per_vq_cluster; + cuvs::preprocessing::quantize::pq::params quantizer_params( + params->pq_bits, params->pq_dim, params->use_subspaces, params->use_vq, params->vq_n_centers, + params->kmeans_n_iters, static_cast(params->pq_kmeans_type), params->max_train_points_per_pq_code, + params->max_train_points_per_vq_cluster); cuvs::preprocessing::quantize::pq::quantizer* ret = nullptr; if (cuvs::core::is_dlpack_device_compatible(dataset)) { diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index bfd79cdcc2..4690c313b2 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -28,6 +28,54 @@ using kmeans_params_variant = * @brief Product Quantizer parameters. */ struct params { + /** + * Simplified constructor that will build an appropriate kmeans params object. + */ + params(uint32_t pq_bits, + uint32_t pq_dim, + bool use_subspaces, + bool use_vq, + uint32_t vq_n_centers, + uint32_t kmeans_n_iters, + cuvs::cluster::kmeans::kmeans_type pq_kmeans_type, + uint32_t max_train_points_per_pq_code, + uint32_t max_train_points_per_vq_cluster) + { + this->pq_bits = pq_bits; + this->pq_dim = pq_dim; + this->use_subspaces = use_subspaces; + this->use_vq = use_vq; + this->vq_n_centers = vq_n_centers; + this->kmeans_params = + pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced + ? kmeans_params_variant{cuvs::cluster::kmeans::balanced_params{.n_iters = kmeans_n_iters}} + : kmeans_params_variant{cuvs::cluster::kmeans::params{.n_clusters = 1 << pq_bits, + .max_iter = (int)kmeans_n_iters}}; + this->max_train_points_per_pq_code = max_train_points_per_pq_code; + this->max_train_points_per_vq_cluster = max_train_points_per_vq_cluster; + } + + params(uint32_t pq_bits, + uint32_t pq_dim, + bool use_subspaces, + bool use_vq, + uint32_t vq_n_centers, + kmeans_params_variant kmeans_params, + uint32_t max_train_points_per_pq_code, + uint32_t max_train_points_per_vq_cluster) + { + this->pq_bits = pq_bits; + this->pq_dim = pq_dim; + this->use_subspaces = use_subspaces; + this->use_vq = use_vq; + this->vq_n_centers = vq_n_centers; + this->kmeans_params = kmeans_params; + this->max_train_points_per_pq_code = max_train_points_per_pq_code; + this->max_train_points_per_vq_cluster = max_train_points_per_vq_cluster; + } + + params() = default; + /** * The bit length of the vector element after compression by PQ. * diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index c4eae0ac9c..1a09736d86 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -160,15 +160,16 @@ index build( int dim_per_subspace = params.pq_dim; int num_clusters = 1 << params.pq_bits; - cuvs::preprocessing::quantize::pq::params pq_build_params; - cuvs::cluster::kmeans::balanced_params pq_bp; - pq_bp.n_iters = params.pq_train_iters; - pq_build_params.pq_bits = params.pq_bits; - pq_build_params.pq_dim = num_subspaces; - pq_build_params.use_subspaces = true; - pq_build_params.use_vq = false; // We already computed residuals - pq_build_params.kmeans_params = pq_bp; - pq_build_params.max_train_points_per_pq_code = pq_n_rows_train / num_clusters; + cuvs::preprocessing::quantize::pq::params pq_build_params( + params.pq_bits, + num_subspaces, + true, + false, + 0, + params.pq_train_iters, + cuvs::cluster::kmeans::kmeans_type::KMeansBalanced, + pq_n_rows_train / num_clusters, + 1024); auto pq_quantizer = cuvs::preprocessing::quantize::pq::build( res, pq_build_params, raft::make_const_mdspan(trainset_residuals.view())); diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index cb3b3f7c09..a857f96875 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -368,28 +368,19 @@ void vpq_convert_math_type(const raft::resources& res, inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params, const uint64_t n_rows) -> cuvs::preprocessing::quantize::pq::params { - std::variant kmeans_params; - if (in_params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) { - kmeans_params = cuvs::cluster::kmeans::balanced_params{ - .n_iters = in_params.kmeans_n_iters, - }; - } else { - kmeans_params = cuvs::cluster::kmeans::params{ - .max_iter = (int)in_params.kmeans_n_iters, - }; - } const uint32_t pq_n_centers = 1 << in_params.pq_bits; return cuvs::preprocessing::quantize::pq::params{ - .pq_bits = in_params.pq_bits, - .pq_dim = in_params.pq_dim, - .vq_n_centers = in_params.vq_n_centers, - .kmeans_params = kmeans_params, - .max_train_points_per_pq_code = - std::min(in_params.max_train_points_per_pq_code, - n_rows * in_params.pq_kmeans_trainset_fraction / pq_n_centers), - .max_train_points_per_vq_cluster = - std::min(in_params.max_train_points_per_vq_cluster, - n_rows * in_params.vq_kmeans_trainset_fraction / in_params.vq_n_centers)}; + in_params.pq_bits, + in_params.pq_dim, + true, + true, + in_params.vq_n_centers, + in_params.kmeans_n_iters, + in_params.pq_kmeans_type, + std::min(in_params.max_train_points_per_pq_code, + n_rows * in_params.pq_kmeans_trainset_fraction / pq_n_centers), + std::min(in_params.max_train_points_per_vq_cluster, + n_rows * in_params.vq_kmeans_trainset_fraction / in_params.vq_n_centers)}; } template diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index aec590278b..e2dc37cc20 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -165,7 +165,10 @@ class ProductQuantizationTest : public ::testing::TestWithParam Date: Fri, 17 Apr 2026 08:00:21 -0700 Subject: [PATCH 4/4] Guard Div by zero Signed-off-by: Mickael Ide --- cpp/src/preprocessing/quantize/detail/pq.cuh | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index a857f96875..e9ac2257e2 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -368,7 +368,13 @@ void vpq_convert_math_type(const raft::resources& res, inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params, const uint64_t n_rows) -> cuvs::preprocessing::quantize::pq::params { - const uint32_t pq_n_centers = 1 << in_params.pq_bits; + const uint32_t pq_n_centers = 1 << in_params.pq_bits; + uint32_t max_train_points_per_vq_cluster = in_params.max_train_points_per_vq_cluster; + if (in_params.vq_n_centers > 0) { + max_train_points_per_vq_cluster = + std::min(max_train_points_per_vq_cluster, + n_rows * in_params.vq_kmeans_trainset_fraction / in_params.vq_n_centers); + } return cuvs::preprocessing::quantize::pq::params{ in_params.pq_bits, in_params.pq_dim, @@ -379,8 +385,7 @@ inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params, const in_params.pq_kmeans_type, std::min(in_params.max_train_points_per_pq_code, n_rows * in_params.pq_kmeans_trainset_fraction / pq_n_centers), - std::min(in_params.max_train_points_per_vq_cluster, - n_rows * in_params.vq_kmeans_trainset_fraction / in_params.vq_n_centers)}; + max_train_points_per_vq_cluster}; } template