Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions c/src/preprocessing/quantize/pq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,10 @@ void* _build(cuvsResources_t res,
auto dataset = dataset_tensor->dl_tensor;

auto res_ptr = reinterpret_cast<raft::resources*>(res);

auto quantizer_params = cuvs::preprocessing::quantize::pq::params{
.pq_bits = params->pq_bits,
.pq_dim = params->pq_dim,
.use_subspaces = params->use_subspaces,
.use_vq = params->use_vq,
.vq_n_centers = params->vq_n_centers,
.kmeans_n_iters = params->kmeans_n_iters,
.pq_kmeans_type = static_cast<cuvs::cluster::kmeans::kmeans_type>(params->pq_kmeans_type),
.max_train_points_per_pq_code = params->max_train_points_per_pq_code,
.max_train_points_per_vq_cluster = params->max_train_points_per_vq_cluster
};
cuvs::preprocessing::quantize::pq::params quantizer_params(
params->pq_bits, params->pq_dim, params->use_subspaces, params->use_vq, params->vq_n_centers,
params->kmeans_n_iters, static_cast<cuvs::cluster::kmeans::kmeans_type>(params->pq_kmeans_type), params->max_train_points_per_pq_code,
params->max_train_points_per_vq_cluster);
cuvs::preprocessing::quantize::pq::quantizer<T>* ret = nullptr;

if (cuvs::core::is_dlpack_device_compatible(dataset)) {
Expand Down
70 changes: 62 additions & 8 deletions cpp/include/cuvs/preprocessing/quantize/pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,77 @@

#pragma once

#include <cuvs/cluster/kmeans.hpp>
#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>

#include <variant>

namespace cuvs::preprocessing::quantize::pq {

/**
* @defgroup pq Product Quantizer utilities
* @{
*/

/** Alias for the variant holding either balanced or regular k-means parameters. */
using kmeans_params_variant =
std::variant<cuvs::cluster::kmeans::balanced_params, cuvs::cluster::kmeans::params>;

/**
* @brief Product Quantizer parameters.
*/
struct params {
/**
* Simplified constructor that will build an appropriate kmeans params object.
*/
params(uint32_t pq_bits,
uint32_t pq_dim,
bool use_subspaces,
bool use_vq,
uint32_t vq_n_centers,
uint32_t kmeans_n_iters,
cuvs::cluster::kmeans::kmeans_type pq_kmeans_type,
uint32_t max_train_points_per_pq_code,
uint32_t max_train_points_per_vq_cluster)
{
this->pq_bits = pq_bits;
this->pq_dim = pq_dim;
this->use_subspaces = use_subspaces;
this->use_vq = use_vq;
this->vq_n_centers = vq_n_centers;
this->kmeans_params =
pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced
? kmeans_params_variant{cuvs::cluster::kmeans::balanced_params{.n_iters = kmeans_n_iters}}
: kmeans_params_variant{cuvs::cluster::kmeans::params{.n_clusters = 1 << pq_bits,
.max_iter = (int)kmeans_n_iters}};
this->max_train_points_per_pq_code = max_train_points_per_pq_code;
this->max_train_points_per_vq_cluster = max_train_points_per_vq_cluster;
}

params(uint32_t pq_bits,
uint32_t pq_dim,
bool use_subspaces,
bool use_vq,
uint32_t vq_n_centers,
kmeans_params_variant kmeans_params,
uint32_t max_train_points_per_pq_code,
uint32_t max_train_points_per_vq_cluster)
{
this->pq_bits = pq_bits;
this->pq_dim = pq_dim;
this->use_subspaces = use_subspaces;
this->use_vq = use_vq;
this->vq_n_centers = vq_n_centers;
this->kmeans_params = kmeans_params;
this->max_train_points_per_pq_code = max_train_points_per_pq_code;
this->max_train_points_per_vq_cluster = max_train_points_per_vq_cluster;
}

params() = default;

/**
* The bit length of the vector element after compression by PQ.
*
Expand Down Expand Up @@ -53,16 +108,15 @@ struct params {
* When zero, an optimal value is selected using a heuristic.
*/
uint32_t vq_n_centers = 0;
/** The number of iterations searching for kmeans centers (both VQ & PQ phases). */
uint32_t kmeans_n_iters = 25;
/**
* Type of k-means algorithm for PQ training.
* Balanced k-means tends to be faster than regular k-means for PQ training, for
* problem sets where the number of points per cluster are approximately equal.
* Regular k-means may be better for skewed cluster distributions.
* K-means parameters for PQ codebook training.
*
* Set to cuvs::cluster::kmeans::balanced_params for balanced k-means (default),
* or cuvs::cluster::kmeans::params for regular k-means.
* The active variant type selects the algorithm; balanced k-means tends to be faster
* for PQ training where cluster sizes are approximately equal.
*/
cuvs::cluster::kmeans::kmeans_type pq_kmeans_type =
cuvs::cluster::kmeans::kmeans_type::KMeansBalanced;
kmeans_params_variant kmeans_params = cuvs::cluster::kmeans::balanced_params{};
/**
* The max number of data points to use per PQ code during PQ codebook training. Using more data
* points per PQ code may increase the quality of PQ codebook but may also increase the build
Expand Down
94 changes: 50 additions & 44 deletions cpp/src/neighbors/detail/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <cuvs/neighbors/common.hpp>
#include <cuvs/preprocessing/quantize/pq.hpp>

#include "../../cluster/kmeans_balanced.cuh"
#include "../../preprocessing/quantize/detail/pq_codepacking.cuh" // pq_bits-bitfield
Expand Down Expand Up @@ -74,50 +75,51 @@ namespace cuvs::neighbors::detail {
template <typename MathT, typename IdxT>
void train_pq_centers(
const raft::resources& res,
const cuvs::neighbors::vpq_params& params,
const cuvs::preprocessing::quantize::pq::kmeans_params_variant& kmeans_params,
const raft::device_matrix_view<const MathT, IdxT, raft::row_major> pq_trainset_view,
const raft::device_matrix_view<MathT, uint32_t, raft::row_major> pq_centers_view,
raft::device_vector_view<uint32_t, IdxT> sub_labels_view,
raft::device_vector_view<uint32_t, IdxT> pq_cluster_sizes_view)
{
if (params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) {
cuvs::cluster::kmeans::balanced_params kmeans_params;
kmeans_params.n_iters = params.kmeans_n_iters;
kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded;

cuvs::cluster::kmeans_balanced::helpers::build_clusters<
MathT,
MathT,
IdxT,
uint32_t,
uint32_t,
cuvs::spatial::knn::detail::utils::mapping<MathT>>(
res,
kmeans_params,
pq_trainset_view,
pq_centers_view,
sub_labels_view,
pq_cluster_sizes_view,
cuvs::spatial::knn::detail::utils::mapping<MathT>{});
} else {
const auto pq_n_centers = pq_centers_view.extent(0);
cuvs::cluster::kmeans::params kmeans_params;
kmeans_params.n_clusters = pq_n_centers;
kmeans_params.max_iter = params.kmeans_n_iters;
kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded;
kmeans_params.init = cuvs::cluster::kmeans::params::InitMethod::Random;

std::optional<raft::device_vector_view<const MathT, IdxT>> sample_weight = std::nullopt;
MathT inertia;
IdxT n_iter;
cuvs::cluster::kmeans::fit(res,
kmeans_params,
pq_trainset_view,
sample_weight,
pq_centers_view,
raft::make_host_scalar_view<MathT>(&inertia),
raft::make_host_scalar_view<IdxT>(&n_iter));
}
std::visit(
[&](auto const& base_kmeans_params) {
using KP = std::decay_t<decltype(base_kmeans_params)>;
if constexpr (std::is_same_v<KP, cuvs::cluster::kmeans::balanced_params>) {
auto bal_params = base_kmeans_params;
bal_params.metric = cuvs::distance::DistanceType::L2Expanded;
cuvs::cluster::kmeans_balanced::helpers::build_clusters<
MathT,
MathT,
IdxT,
uint32_t,
uint32_t,
cuvs::spatial::knn::detail::utils::mapping<MathT>>(
res,
bal_params,
pq_trainset_view,
pq_centers_view,
sub_labels_view,
pq_cluster_sizes_view,
cuvs::spatial::knn::detail::utils::mapping<MathT>{});
} else {
auto classic_params = base_kmeans_params;
classic_params.n_clusters = pq_centers_view.extent(0);
classic_params.metric = cuvs::distance::DistanceType::L2Expanded;
RAFT_EXPECTS(classic_params.init != cuvs::cluster::kmeans::params::InitMethod::Array,
"Array initialization is not supported for PQ training");
std::optional<raft::device_vector_view<const MathT, IdxT>> sample_weight = std::nullopt;
MathT inertia;
IdxT n_iter;
cuvs::cluster::kmeans::fit(res,
classic_params,
pq_trainset_view,
sample_weight,
pq_centers_view,
raft::make_host_scalar_view<MathT>(&inertia),
raft::make_host_scalar_view<IdxT>(&n_iter));
}
},
kmeans_params);
}

template <typename DatasetT>
Expand Down Expand Up @@ -219,7 +221,7 @@ auto predict_vq(const raft::resources& res,

template <typename MathT, typename DatasetT>
auto train_pq(const raft::resources& res,
const vpq_params& params,
const cuvs::preprocessing::quantize::pq::params params,
const DatasetT& dataset,
const raft::device_matrix_view<const MathT, uint32_t, raft::row_major> vq_centers)
-> raft::device_matrix<MathT, uint32_t, raft::row_major>
Expand All @@ -230,8 +232,8 @@ auto train_pq(const raft::resources& res,
const ix_t pq_bits = params.pq_bits;
const ix_t pq_n_centers = ix_t{1} << pq_bits;
const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim);
const ix_t n_rows_train = std::min((ix_t)(n_rows * params.pq_kmeans_trainset_fraction),
params.max_train_points_per_pq_code * pq_n_centers);
const ix_t n_rows_train =
std::min<ix_t>(n_rows, params.max_train_points_per_pq_code * pq_n_centers);
RAFT_EXPECTS(
n_rows_train >= pq_n_centers,
"The number of training samples must be greater than or equal to the number of PQ centers");
Expand Down Expand Up @@ -261,8 +263,12 @@ auto train_pq(const raft::resources& res,
pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len);
auto sub_labels = raft::make_device_vector<uint32_t, ix_t>(res, pq_trainset_view.extent(0));
auto pq_cluster_sizes = raft::make_device_vector<uint32_t, ix_t>(res, pq_centers.extent(0));
train_pq_centers<MathT, ix_t>(
res, params, pq_trainset_view, pq_centers.view(), sub_labels.view(), pq_cluster_sizes.view());
train_pq_centers<MathT, ix_t>(res,
params.kmeans_params,
pq_trainset_view,
pq_centers.view(),
sub_labels.view(),
pq_cluster_sizes.view());

return pq_centers;
}
Expand Down
18 changes: 10 additions & 8 deletions cpp/src/neighbors/scann/detail/scann_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,16 @@ index<T, IdxT> build(
int dim_per_subspace = params.pq_dim;
int num_clusters = 1 << params.pq_bits;

cuvs::preprocessing::quantize::pq::params pq_build_params;
pq_build_params.pq_bits = params.pq_bits;
pq_build_params.pq_dim = num_subspaces;
pq_build_params.use_subspaces = true;
pq_build_params.use_vq = false; // We already computed residuals
pq_build_params.kmeans_n_iters = params.pq_train_iters;
pq_build_params.max_train_points_per_pq_code = pq_n_rows_train / num_clusters;
pq_build_params.pq_kmeans_type = cuvs::cluster::kmeans::kmeans_type::KMeansBalanced;
cuvs::preprocessing::quantize::pq::params pq_build_params(
params.pq_bits,
num_subspaces,
true,
false,
0,
params.pq_train_iters,
cuvs::cluster::kmeans::kmeans_type::KMeansBalanced,
pq_n_rows_train / num_clusters,
1024);

auto pq_quantizer = cuvs::preprocessing::quantize::pq::build(
res, pq_build_params, raft::make_const_mdspan(trainset_residuals.view()));
Expand Down
52 changes: 46 additions & 6 deletions cpp/src/preprocessing/quantize/detail/pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,38 @@ inline void fill_missing_params_heuristics(cuvs::preprocessing::quantize::pq::pa
}
}

inline bool is_balanced_kmeans(const cuvs::preprocessing::quantize::pq::params& params)
{
return std::holds_alternative<cuvs::cluster::kmeans::balanced_params>(params.kmeans_params);
}

inline uint32_t get_kmeans_n_iters(const cuvs::preprocessing::quantize::pq::params& params)
{
return std::visit(
[](const auto& kp) -> uint32_t {
using kmeans_type = std::decay_t<decltype(kp)>;
if constexpr (std::is_same_v<kmeans_type, cuvs::cluster::kmeans::balanced_params>) {
return kp.n_iters;
} else {
return static_cast<uint32_t>(kp.max_iter);
}
},
params.kmeans_params);
}

inline auto to_vpq_params(const cuvs::preprocessing::quantize::pq::params& params)
-> cuvs::neighbors::vpq_params
{
auto kmeans_type = is_balanced_kmeans(params) ? cuvs::cluster::kmeans::kmeans_type::KMeansBalanced
: cuvs::cluster::kmeans::kmeans_type::KMeans;
return cuvs::neighbors::vpq_params{
.pq_bits = params.pq_bits,
.pq_dim = params.pq_dim,
.vq_n_centers = params.vq_n_centers,
.kmeans_n_iters = params.kmeans_n_iters,
.kmeans_n_iters = get_kmeans_n_iters(params),
.vq_kmeans_trainset_fraction = 1.0,
.pq_kmeans_trainset_fraction = 1.0,
.pq_kmeans_type = params.pq_kmeans_type,
.pq_kmeans_type = kmeans_type,
.max_train_points_per_pq_code = params.max_train_points_per_pq_code,
.max_train_points_per_vq_cluster = params.max_train_points_per_vq_cluster};
}
Expand Down Expand Up @@ -92,7 +113,7 @@ auto train_pq_subspaces(
auto sub_labels = raft::make_device_vector<uint32_t, ix_t>(res, 0);
auto pq_cluster_sizes = raft::make_device_vector<uint32_t, ix_t>(res, 0);
auto device_memory = raft::resource::get_workspace_resource(res);
if (params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) {
if (is_balanced_kmeans(params)) {
sub_labels = raft::make_device_mdarray<uint32_t>(
res, device_memory, raft::make_extents<ix_t>(n_rows_train));
pq_cluster_sizes = raft::make_device_mdarray<uint32_t>(
Expand All @@ -111,7 +132,7 @@ auto train_pq_subspaces(
pq_centers.data_handle() + m * pq_n_centers * pq_len, pq_n_centers, pq_len);
cuvs::neighbors::detail::train_pq_centers<MathT, ix_t>(
res,
to_vpq_params(params),
params.kmeans_params,
raft::make_const_mdspan(sub_dataset.view()),
pq_centers_subspace_view,
sub_labels.view(),
Expand Down Expand Up @@ -152,7 +173,7 @@ quantizer<MathT> build(
res, filled_params, dataset, raft::make_const_mdspan(vq_code_book.view()));
} else {
pq_code_book = cuvs::neighbors::detail::train_pq<MathT>(
res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view()));
res, filled_params, dataset, raft::make_const_mdspan(vq_code_book.view()));
}
return {filled_params,
cuvs::neighbors::vpq_dataset<MathT, int64_t>{
Expand Down Expand Up @@ -344,6 +365,24 @@ void vpq_convert_math_type(const raft::resources& res,
raft::make_const_mdspan(src.pq_code_book.view()));
}

inline auto from_vpq_params(const cuvs::neighbors::vpq_params& in_params, const uint64_t n_rows)
-> cuvs::preprocessing::quantize::pq::params
{
const uint32_t pq_n_centers = 1 << in_params.pq_bits;
return cuvs::preprocessing::quantize::pq::params{
in_params.pq_bits,
in_params.pq_dim,
true,
true,
in_params.vq_n_centers,
in_params.kmeans_n_iters,
in_params.pq_kmeans_type,
std::min<uint32_t>(in_params.max_train_points_per_pq_code,
n_rows * in_params.pq_kmeans_trainset_fraction / pq_n_centers),
std::min<uint32_t>(in_params.max_train_points_per_vq_cluster,
n_rows * in_params.vq_kmeans_trainset_fraction / in_params.vq_n_centers)};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could vq_n_centers be set to zero. Perhaps we can just add a guard to avoid zero division.

}

template <typename DatasetT, typename MathT, typename IdxT>
auto vpq_build(const raft::resources& res,
const cuvs::neighbors::vpq_params& params,
Expand All @@ -353,10 +392,11 @@ auto vpq_build(const raft::resources& res,
// Use a heuristic to impute missing parameters.
auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset);

auto pq_params = from_vpq_params(ps, dataset.extent(0));
// Train codes
auto vq_code_book = cuvs::neighbors::detail::train_vq<MathT>(res, ps, dataset);
auto pq_code_book = cuvs::neighbors::detail::train_pq<MathT>(
res, ps, dataset, raft::make_const_mdspan(vq_code_book.view()));
res, pq_params, dataset, raft::make_const_mdspan(vq_code_book.view()));

// Encode dataset
const IdxT n_rows = dataset.extent(0);
Expand Down
Loading
Loading