diff --git a/c/src/preprocessing/quantize/pq.cpp b/c/src/preprocessing/quantize/pq.cpp index 35a8f5ed72..1e3a48694a 100644 --- a/c/src/preprocessing/quantize/pq.cpp +++ b/c/src/preprocessing/quantize/pq.cpp @@ -62,18 +62,10 @@ void* _build(cuvsResources_t res, auto dataset = dataset_tensor->dl_tensor; auto res_ptr = reinterpret_cast(res); - - auto quantizer_params = cuvs::preprocessing::quantize::pq::params{ - .pq_bits = params->pq_bits, - .pq_dim = params->pq_dim, - .use_subspaces = params->use_subspaces, - .use_vq = params->use_vq, - .vq_n_centers = params->vq_n_centers, - .kmeans_n_iters = params->kmeans_n_iters, - .pq_kmeans_type = static_cast(params->pq_kmeans_type), - .max_train_points_per_pq_code = params->max_train_points_per_pq_code, - .max_train_points_per_vq_cluster = params->max_train_points_per_vq_cluster - }; + cuvs::preprocessing::quantize::pq::params quantizer_params( + params->pq_bits, params->pq_dim, params->use_subspaces, params->use_vq, params->vq_n_centers, + params->kmeans_n_iters, static_cast(params->pq_kmeans_type), params->max_train_points_per_pq_code, + params->max_train_points_per_vq_cluster); cuvs::preprocessing::quantize::pq::quantizer* ret = nullptr; if (cuvs::core::is_dlpack_device_compatible(dataset)) { diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 043bd92534..4690c313b2 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -5,11 +5,14 @@ #pragma once +#include #include #include #include #include +#include + namespace cuvs::preprocessing::quantize::pq { /** @@ -17,10 +20,62 @@ namespace cuvs::preprocessing::quantize::pq { * @{ */ +/** Alias for the variant holding either balanced or regular k-means parameters. */ +using kmeans_params_variant = + std::variant; + /** * @brief Product Quantizer parameters. */ struct params { + /** + * Simplified constructor that will build an appropriate kmeans params object. + */ + params(uint32_t pq_bits, + uint32_t pq_dim, + bool use_subspaces, + bool use_vq, + uint32_t vq_n_centers, + uint32_t kmeans_n_iters, + cuvs::cluster::kmeans::kmeans_type pq_kmeans_type, + uint32_t max_train_points_per_pq_code, + uint32_t max_train_points_per_vq_cluster) + { + this->pq_bits = pq_bits; + this->pq_dim = pq_dim; + this->use_subspaces = use_subspaces; + this->use_vq = use_vq; + this->vq_n_centers = vq_n_centers; + this->kmeans_params = + pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced + ? kmeans_params_variant{cuvs::cluster::kmeans::balanced_params{.n_iters = kmeans_n_iters}} + : kmeans_params_variant{cuvs::cluster::kmeans::params{.n_clusters = 1 << pq_bits, + .max_iter = (int)kmeans_n_iters}}; + this->max_train_points_per_pq_code = max_train_points_per_pq_code; + this->max_train_points_per_vq_cluster = max_train_points_per_vq_cluster; + } + + params(uint32_t pq_bits, + uint32_t pq_dim, + bool use_subspaces, + bool use_vq, + uint32_t vq_n_centers, + kmeans_params_variant kmeans_params, + uint32_t max_train_points_per_pq_code, + uint32_t max_train_points_per_vq_cluster) + { + this->pq_bits = pq_bits; + this->pq_dim = pq_dim; + this->use_subspaces = use_subspaces; + this->use_vq = use_vq; + this->vq_n_centers = vq_n_centers; + this->kmeans_params = kmeans_params; + this->max_train_points_per_pq_code = max_train_points_per_pq_code; + this->max_train_points_per_vq_cluster = max_train_points_per_vq_cluster; + } + + params() = default; + /** * The bit length of the vector element after compression by PQ. * @@ -53,16 +108,15 @@ struct params { * When zero, an optimal value is selected using a heuristic. */ uint32_t vq_n_centers = 0; - /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ - uint32_t kmeans_n_iters = 25; /** - * Type of k-means algorithm for PQ training. - * Balanced k-means tends to be faster than regular k-means for PQ training, for - * problem sets where the number of points per cluster are approximately equal. - * Regular k-means may be better for skewed cluster distributions. + * K-means parameters for PQ codebook training. + * + * Set to cuvs::cluster::kmeans::balanced_params for balanced k-means (default), + * or cuvs::cluster::kmeans::params for regular k-means. + * The active variant type selects the algorithm; balanced k-means tends to be faster + * for PQ training where cluster sizes are approximately equal. */ - cuvs::cluster::kmeans::kmeans_type pq_kmeans_type = - cuvs::cluster::kmeans::kmeans_type::KMeansBalanced; + kmeans_params_variant kmeans_params = cuvs::cluster::kmeans::balanced_params{}; /** * The max number of data points to use per PQ code during PQ codebook training. Using more data * points per PQ code may increase the quality of PQ codebook but may also increase the build diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index cbe06f5ca4..1b60e64e0a 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -5,6 +5,7 @@ #pragma once #include +#include #include "../../cluster/kmeans_balanced.cuh" #include "../../preprocessing/quantize/detail/pq_codepacking.cuh" // pq_bits-bitfield @@ -74,50 +75,51 @@ namespace cuvs::neighbors::detail { template void train_pq_centers( const raft::resources& res, - const cuvs::neighbors::vpq_params& params, + const cuvs::preprocessing::quantize::pq::kmeans_params_variant& kmeans_params, const raft::device_matrix_view pq_trainset_view, const raft::device_matrix_view pq_centers_view, raft::device_vector_view sub_labels_view, raft::device_vector_view pq_cluster_sizes_view) { - if (params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) { - cuvs::cluster::kmeans::balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; - - cuvs::cluster::kmeans_balanced::helpers::build_clusters< - MathT, - MathT, - IdxT, - uint32_t, - uint32_t, - cuvs::spatial::knn::detail::utils::mapping>( - res, - kmeans_params, - pq_trainset_view, - pq_centers_view, - sub_labels_view, - pq_cluster_sizes_view, - cuvs::spatial::knn::detail::utils::mapping{}); - } else { - const auto pq_n_centers = pq_centers_view.extent(0); - cuvs::cluster::kmeans::params kmeans_params; - kmeans_params.n_clusters = pq_n_centers; - kmeans_params.max_iter = params.kmeans_n_iters; - kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; - kmeans_params.init = cuvs::cluster::kmeans::params::InitMethod::Random; - - std::optional> sample_weight = std::nullopt; - MathT inertia; - IdxT n_iter; - cuvs::cluster::kmeans::fit(res, - kmeans_params, - pq_trainset_view, - sample_weight, - pq_centers_view, - raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter)); - } + std::visit( + [&](auto const& base_kmeans_params) { + using KP = std::decay_t; + if constexpr (std::is_same_v) { + auto bal_params = base_kmeans_params; + bal_params.metric = cuvs::distance::DistanceType::L2Expanded; + cuvs::cluster::kmeans_balanced::helpers::build_clusters< + MathT, + MathT, + IdxT, + uint32_t, + uint32_t, + cuvs::spatial::knn::detail::utils::mapping>( + res, + bal_params, + pq_trainset_view, + pq_centers_view, + sub_labels_view, + pq_cluster_sizes_view, + cuvs::spatial::knn::detail::utils::mapping{}); + } else { + auto classic_params = base_kmeans_params; + classic_params.n_clusters = pq_centers_view.extent(0); + classic_params.metric = cuvs::distance::DistanceType::L2Expanded; + RAFT_EXPECTS(classic_params.init != cuvs::cluster::kmeans::params::InitMethod::Array, + "Array initialization is not supported for PQ training"); + std::optional> sample_weight = std::nullopt; + MathT inertia; + IdxT n_iter; + cuvs::cluster::kmeans::fit(res, + classic_params, + pq_trainset_view, + sample_weight, + pq_centers_view, + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + } + }, + kmeans_params); } template @@ -219,7 +221,7 @@ auto predict_vq(const raft::resources& res, template auto train_pq(const raft::resources& res, - const vpq_params& params, + const cuvs::preprocessing::quantize::pq::params params, const DatasetT& dataset, const raft::device_matrix_view vq_centers) -> raft::device_matrix @@ -230,8 +232,8 @@ auto train_pq(const raft::resources& res, const ix_t pq_bits = params.pq_bits; const ix_t pq_n_centers = ix_t{1} << pq_bits; const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); - const ix_t n_rows_train = std::min((ix_t)(n_rows * params.pq_kmeans_trainset_fraction), - params.max_train_points_per_pq_code * pq_n_centers); + const ix_t n_rows_train = + std::min(n_rows, params.max_train_points_per_pq_code * pq_n_centers); RAFT_EXPECTS( n_rows_train >= pq_n_centers, "The number of training samples must be greater than or equal to the number of PQ centers"); @@ -261,8 +263,12 @@ auto train_pq(const raft::resources& res, pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len); auto sub_labels = raft::make_device_vector(res, pq_trainset_view.extent(0)); auto pq_cluster_sizes = raft::make_device_vector(res, pq_centers.extent(0)); - train_pq_centers( - res, params, pq_trainset_view, pq_centers.view(), sub_labels.view(), pq_cluster_sizes.view()); + train_pq_centers(res, + params.kmeans_params, + pq_trainset_view, + pq_centers.view(), + sub_labels.view(), + pq_cluster_sizes.view()); return pq_centers; } diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 37f74f29bb..1a09736d86 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -160,14 +160,16 @@ index build( int dim_per_subspace = params.pq_dim; int num_clusters = 1 << params.pq_bits; - cuvs::preprocessing::quantize::pq::params pq_build_params; - pq_build_params.pq_bits = params.pq_bits; - pq_build_params.pq_dim = num_subspaces; - pq_build_params.use_subspaces = true; - pq_build_params.use_vq = false; // We already computed residuals - pq_build_params.kmeans_n_iters = params.pq_train_iters; - pq_build_params.max_train_points_per_pq_code = pq_n_rows_train / num_clusters; - pq_build_params.pq_kmeans_type = cuvs::cluster::kmeans::kmeans_type::KMeansBalanced; + cuvs::preprocessing::quantize::pq::params pq_build_params( + params.pq_bits, + num_subspaces, + true, + false, + 0, + params.pq_train_iters, + cuvs::cluster::kmeans::kmeans_type::KMeansBalanced, + pq_n_rows_train / num_clusters, + 1024); auto pq_quantizer = cuvs::preprocessing::quantize::pq::build( res, pq_build_params, raft::make_const_mdspan(trainset_residuals.view())); diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 77fb0ac4f9..e9ac2257e2 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -29,17 +29,38 @@ inline void fill_missing_params_heuristics(cuvs::preprocessing::quantize::pq::pa } } +inline bool is_balanced_kmeans(const cuvs::preprocessing::quantize::pq::params& params) +{ + return std::holds_alternative(params.kmeans_params); +} + +inline uint32_t get_kmeans_n_iters(const cuvs::preprocessing::quantize::pq::params& params) +{ + return std::visit( + [](const auto& kp) -> uint32_t { + using kmeans_type = std::decay_t; + if constexpr (std::is_same_v) { + return kp.n_iters; + } else { + return static_cast(kp.max_iter); + } + }, + params.kmeans_params); +} + inline auto to_vpq_params(const cuvs::preprocessing::quantize::pq::params& params) -> cuvs::neighbors::vpq_params { + auto kmeans_type = is_balanced_kmeans(params) ? cuvs::cluster::kmeans::kmeans_type::KMeansBalanced + : cuvs::cluster::kmeans::kmeans_type::KMeans; return cuvs::neighbors::vpq_params{ .pq_bits = params.pq_bits, .pq_dim = params.pq_dim, .vq_n_centers = params.vq_n_centers, - .kmeans_n_iters = params.kmeans_n_iters, + .kmeans_n_iters = get_kmeans_n_iters(params), .vq_kmeans_trainset_fraction = 1.0, .pq_kmeans_trainset_fraction = 1.0, - .pq_kmeans_type = params.pq_kmeans_type, + .pq_kmeans_type = kmeans_type, .max_train_points_per_pq_code = params.max_train_points_per_pq_code, .max_train_points_per_vq_cluster = params.max_train_points_per_vq_cluster}; } @@ -92,7 +113,7 @@ auto train_pq_subspaces( auto sub_labels = raft::make_device_vector(res, 0); auto pq_cluster_sizes = raft::make_device_vector(res, 0); auto device_memory = raft::resource::get_workspace_resource(res); - if (params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) { + if (is_balanced_kmeans(params)) { sub_labels = raft::make_device_mdarray( res, device_memory, raft::make_extents(n_rows_train)); pq_cluster_sizes = raft::make_device_mdarray( @@ -111,7 +132,7 @@ auto train_pq_subspaces( pq_centers.data_handle() + m * pq_n_centers * pq_len, pq_n_centers, pq_len); cuvs::neighbors::detail::train_pq_centers( res, - to_vpq_params(params), + params.kmeans_params, raft::make_const_mdspan(sub_dataset.view()), pq_centers_subspace_view, sub_labels.view(), @@ -152,7 +173,7 @@ quantizer build( res, filled_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } else { pq_code_book = cuvs::neighbors::detail::train_pq( - res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); + res, filled_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } return {filled_params, cuvs::neighbors::vpq_dataset{ @@ -344,6 +365,29 @@ void vpq_convert_math_type(const raft::resources& res, raft::make_const_mdspan(src.pq_code_book.view())); } +inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params, const uint64_t n_rows) + -> cuvs::preprocessing::quantize::pq::params +{ + const uint32_t pq_n_centers = 1 << in_params.pq_bits; + uint32_t max_train_points_per_vq_cluster = in_params.max_train_points_per_vq_cluster; + if (in_params.vq_n_centers > 0) { + max_train_points_per_vq_cluster = + std::min(max_train_points_per_vq_cluster, + n_rows * in_params.vq_kmeans_trainset_fraction / in_params.vq_n_centers); + } + return cuvs::preprocessing::quantize::pq::params{ + in_params.pq_bits, + in_params.pq_dim, + true, + true, + in_params.vq_n_centers, + in_params.kmeans_n_iters, + in_params.pq_kmeans_type, + std::min(in_params.max_train_points_per_pq_code, + n_rows * in_params.pq_kmeans_trainset_fraction / pq_n_centers), + max_train_points_per_vq_cluster}; +} + template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, @@ -353,10 +397,11 @@ auto vpq_build(const raft::resources& res, // Use a heuristic to impute missing parameters. auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); + auto pq_params = from_vpq_params(ps, dataset.extent(0)); // Train codes auto vq_code_book = cuvs::neighbors::detail::train_vq(res, ps, dataset); auto pq_code_book = cuvs::neighbors::detail::train_pq( - res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + res, pq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); // Encode dataset const IdxT n_rows = dataset.extent(0); diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index f1b84854f1..e2dc37cc20 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -160,14 +160,21 @@ class ProductQuantizationTest : public ::testing::TestWithParam