Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
13725b3
add a new max_node_id parameter to the CAGRA search API, allowing use…
irina-resh-nvda Feb 6, 2026
0746446
Changed the max node id parameter name to graph_size for clarity; rem…
irina-resh-nvda Feb 6, 2026
70a69d9
wrote test
irina-resh-nvda Feb 6, 2026
f428e54
minor pre-commit changes
irina-resh-nvda Feb 6, 2026
c69c067
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Feb 9, 2026
7d3b52c
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Feb 16, 2026
13b1f77
addressed comments regarding type cast
irina-resh-nvda Feb 16, 2026
cb60c6b
Pre-commit style fix
irina-resh-nvda Feb 20, 2026
4be18ec
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Feb 20, 2026
c6bbece
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Feb 23, 2026
58d224a
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Feb 24, 2026
b27ad1b
style fix
irina-resh-nvda Feb 24, 2026
40e960f
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Mar 2, 2026
7ec1360
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Mar 3, 2026
3843a7b
smaller test
irina-resh-nvda Mar 3, 2026
86f8b86
comment fix
irina-resh-nvda Mar 3, 2026
d07140d
Merge branch 'main' into add-max-node-id-parameter
tfeher Mar 6, 2026
122c141
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Mar 16, 2026
a180187
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Mar 16, 2026
dcd7e6b
Merge branch 'main' into add-max-node-id-parameter
tfeher Mar 17, 2026
aa6c666
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Mar 23, 2026
28e215f
Merge branch 'main' into add-max-node-id-parameter
tfeher Mar 25, 2026
545af85
Merge branch 'main' into add-max-node-id-parameter
tfeher Mar 31, 2026
2c3e430
Merge branch 'main' into add-max-node-id-parameter
aamijar Apr 6, 2026
f63b805
Merge branch 'main' into add-max-node-id-parameter
irina-resh-nvda Apr 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,13 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
IndexT* __restrict__ traversed_hash_ptr,
const uint32_t traversed_hash_bitlen,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
const uint32_t num_blocks = 1,
const IndexT graph_size = 0)
{
const auto team_size_bits = dataset_desc.team_size_bitshift_from_smem();
const auto max_i = raft::round_up_safe<uint32_t>(num_pickup, warp_size >> team_size_bits);
const auto compute_distance = dataset_desc.compute_distance_impl;
const IndexT seed_index_limit = graph_size > 0 ? graph_size : dataset_desc.size;

for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) {
const bool valid_i = (i < num_pickup);
Expand All @@ -121,7 +124,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
if (seed_ptr && (gid < num_seeds)) {
seed_index = seed_ptr[gid];
} else {
seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_desc.size;
seed_index = device::xorshift64(gid ^ rand_xor_mask) % seed_index_limit;
}
}

Expand Down
50 changes: 25 additions & 25 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
const uint32_t min_iteration,
const uint32_t max_iteration,
uint32_t* const num_executed_iterations, /* stats */
SAMPLE_FILTER_T sample_filter)
SAMPLE_FILTER_T sample_filter,
const typename DATASET_DESCRIPTOR_T::INDEX_T graph_size = 0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the default value used anywhere besides tests?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default is used everywhere except the tests and the future iterative CAGRA-Q implementation.

{
using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T;
using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T;
Expand Down Expand Up @@ -282,7 +283,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
local_traversed_hashmap_ptr,
traversed_hash_bitlen,
block_id,
num_blocks);
num_blocks,
graph_size);
__syncthreads();
_CLK_REC(clk_compute_1st_distance);

Expand Down Expand Up @@ -607,29 +609,27 @@ void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dat
num_queries,
smem_size);

auto const& kernel_launcher = [&](auto const& kernel) -> void {
kernel<<<grid_dims, block_dims, smem_size, stream>>>(topk_indices_ptr,
topk_distances_ptr,
dataset_desc.dev_ptr(stream),
queries_ptr,
graph.data_handle(),
max_elements,
graph.extent(1),
source_indices_ptr,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
visited_hash_bitlen,
traversed_hashmap_ptr,
traversed_hash_bitlen,
ps.itopk_size,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
sample_filter);
};
cuvs::neighbors::detail::safely_launch_kernel_with_smem_size(kernel, smem_size, kernel_launcher);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(topk_indices_ptr,
topk_distances_ptr,
dataset_desc.dev_ptr(stream),
queries_ptr,
graph.data_handle(),
max_elements,
graph.extent(1),
source_indices_ptr,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
visited_hash_bitlen,
traversed_hashmap_ptr,
traversed_hash_bitlen,
ps.itopk_size,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
sample_filter,
static_cast<IndexT>(graph.extent(0)));
}

} // namespace multi_cta_search
Expand Down
18 changes: 12 additions & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -104,7 +104,8 @@ RAFT_KERNEL random_pickup_kernel(
typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldr]
const std::uint32_t ldr, // (*) ldr >= num_pickup
typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen]
const std::uint32_t hash_bitlen)
const std::uint32_t hash_bitlen,
const typename DATASET_DESCRIPTOR_T::INDEX_T graph_size = 0)
{
using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T;
using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T;
Expand All @@ -119,6 +120,8 @@ RAFT_KERNEL random_pickup_kernel(
dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id);
__syncthreads();

const INDEX_T seed_index_limit = graph_size > 0 ? graph_size : dataset_desc->size;

INDEX_T best_index_team_local;
DISTANCE_T best_norm2_team_local = utils::get_max_value<DISTANCE_T>();
for (unsigned i = 0; i < num_distilation; i++) {
Expand All @@ -128,7 +131,7 @@ RAFT_KERNEL random_pickup_kernel(
} else {
// Choose a seed node randomly
seed_index =
device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_desc->size;
device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % seed_index_limit;
}

DISTANCE_T norm2 = dataset_desc->compute_distance(seed_index, true);
Expand Down Expand Up @@ -166,7 +169,8 @@ void random_pickup(const dataset_descriptor_host<DataT, IndexT, DistanceT>& data
std::size_t ldr, // (*) ldr >= num_pickup
IndexT* visited_hashmap_ptr, // [num_queries, 1 << bitlen]
std::uint32_t hash_bitlen,
cudaStream_t cuda_stream)
cudaStream_t cuda_stream,
IndexT graph_size = 0)
{
const auto block_size = 256u;
const auto num_teams_per_threadblock = block_size / dataset_desc.team_size;
Expand All @@ -185,7 +189,8 @@ void random_pickup(const dataset_descriptor_host<DataT, IndexT, DistanceT>& data
result_distances_ptr,
ldr,
visited_hashmap_ptr,
hash_bitlen);
hash_bitlen,
graph_size);
}

template <class INDEX_T>
Expand Down Expand Up @@ -826,7 +831,8 @@ struct search
result_buffer_allocation_size,
hashmap.data(),
hash_bitlen,
stream);
stream,
static_cast<IndexT>(this->dataset_size));

unsigned iter = 0;
while (1) {
Expand Down
67 changes: 35 additions & 32 deletions cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,8 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core(
const std::uint32_t small_hash_bitlen,
const std::uint32_t small_hash_reset_interval,
const std::uint32_t query_id,
SAMPLE_FILTER_T sample_filter)
SAMPLE_FILTER_T sample_filter,
const typename DATASET_DESCRIPTOR_T::INDEX_T graph_size = 0)
{
using LOAD_T = device::LOAD_128BIT_T;

Expand Down Expand Up @@ -792,7 +793,10 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core(
local_visited_hashmap_ptr,
hash_bitlen,
(INDEX_T*)nullptr,
0);
0,
0,
1,
graph_size);
__syncthreads();
_CLK_REC(clk_compute_1st_distance);

Expand Down Expand Up @@ -1125,7 +1129,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
const std::uint32_t hash_bitlen,
const std::uint32_t small_hash_bitlen,
const std::uint32_t small_hash_reset_interval,
SAMPLE_FILTER_T sample_filter)
SAMPLE_FILTER_T sample_filter,
const typename DATASET_DESCRIPTOR_T::INDEX_T graph_size = 0)
{
const auto query_id = blockIdx.y;
search_core<TOPK_BY_BITONIC_SORT,
Expand Down Expand Up @@ -1156,7 +1161,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
small_hash_bitlen,
small_hash_reset_interval,
query_id,
sample_filter);
sample_filter,
graph_size);
}

// To make sure we avoid false sharing on both CPU and GPU, we enforce cache line size to the
Expand Down Expand Up @@ -2317,34 +2323,31 @@ control is returned in this thread (in persistent_runner_t constructor), so we'r
dim3 block_dims(1, num_queries, 1);
RAFT_LOG_DEBUG(
"Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size);
auto const& kernel_launcher = [&](auto const& kernel) -> void {
kernel<<<block_dims, thread_dims, smem_size, stream>>>(topk_indices_ptr,
topk_distances_ptr,
topk,
dataset_desc.dev_ptr(stream),
queries_ptr,
graph.data_handle(),
graph.extent(1),
source_indices_ptr,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
max_candidates,
max_itopk,
ps.itopk_size,
ps.search_width,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
hash_bitlen,
small_hash_bitlen,
small_hash_reset_interval,
sample_filter);
};
cuvs::neighbors::detail::safely_launch_kernel_with_smem_size(
kernel, smem_size, kernel_launcher);
kernel<<<block_dims, thread_dims, smem_size, stream>>>(topk_indices_ptr,
topk_distances_ptr,
topk,
dataset_desc.dev_ptr(stream),
queries_ptr,
graph.data_handle(),
graph.extent(1),
source_indices_ptr,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
max_candidates,
max_itopk,
ps.itopk_size,
ps.search_width,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
hash_bitlen,
small_hash_bitlen,
small_hash_reset_interval,
sample_filter,
static_cast<IndexT>(graph.extent(0)));
RAFT_CUDA_TRY(cudaPeekAtLastError());
}
}
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ ConfigureTest(
NAME NEIGHBORS_ANN_CAGRA_TEST_BUGS
PATH neighbors/ann_cagra/bug_extreme_inputs_oob.cu
neighbors/ann_cagra/bug_multi_cta_crash.cu
neighbors/ann_cagra/bug_graph_smaller_than_dataset.cu
neighbors/ann_cagra/bug_iterative_cagra_build.cu
neighbors/ann_cagra/bug_issue_93_reproducer.cu
GPUS 1
Expand Down
Loading
Loading