Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2021,7 +2021,9 @@ auto iterative_build_graph(

// Allocate memory for search results.
constexpr uint64_t max_chunk_size = 8192;
auto topk = intermediate_degree;
// +1 because the search may return the query node itself as a neighbor;
// this is consistent with the per-iteration curr_topk = next_graph_degree + 1
auto topk = intermediate_degree + 1;
auto dev_neighbors = raft::make_device_matrix<IdxT, int64_t>(res, max_chunk_size, topk);
auto dev_distances = raft::make_device_matrix<float, int64_t>(res, max_chunk_size, topk);

Expand Down
4 changes: 3 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ ConfigureTest(

ConfigureTest(
NAME NEIGHBORS_ANN_CAGRA_TEST_BUGS
PATH neighbors/ann_cagra/bug_extreme_inputs_oob.cu neighbors/ann_cagra/bug_multi_cta_crash.cu
PATH neighbors/ann_cagra/bug_extreme_inputs_oob.cu
neighbors/ann_cagra/bug_multi_cta_crash.cu
neighbors/ann_cagra/bug_iterative_cagra_build.cu
neighbors/ann_cagra/bug_issue_93_reproducer.cu
GPUS 1
PERCENT 100
Expand Down
85 changes: 85 additions & 0 deletions cpp/tests/neighbors/ann_cagra/bug_iterative_cagra_build.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <gtest/gtest.h>

#include <cuvs/neighbors/cagra.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/random/rng.cuh>

#include <cstdint>
#include <type_traits>

namespace cuvs::neighbors::cagra {

template <typename DataT>
class CagraIterativeBuildBugTest : public ::testing::Test {
public:
using data_type = DataT;

protected:
void run()
{
// Set up iterative CAGRA graph building
cagra::index_params index_params;
// The bug manifests when graph_degree is equal to intermediate_graph_degree
// see issue https://github.com/rapidsai/cuvs/issues/1818
index_params.graph_degree = 16;
index_params.intermediate_graph_degree = 16;
Comment thread
irina-resh-nvda marked this conversation as resolved.

// Use iterative CAGRA search for graph building
index_params.graph_build_params = graph_build_params::iterative_search_params();

// Build the index
auto cagra_index = cagra::build(res, index_params, raft::make_const_mdspan(dataset->view()));
raft::resource::sync_stream(res);

// Verify the index was built successfully
ASSERT_GT(cagra_index.size(), 0);
ASSERT_EQ(cagra_index.dim(), n_dim);
}

void SetUp() override
{
dataset.emplace(raft::make_device_matrix<data_type, int64_t>(res, n_samples, n_dim));
raft::random::RngState r(1234ULL);

// Generate random data based on type
if constexpr (std::is_same_v<data_type, float>) {
raft::random::normal(
res, r, dataset->data_handle(), n_samples * n_dim, data_type(0), data_type(1));
} else if constexpr (std::is_same_v<data_type, int8_t>) {
raft::random::uniformInt(
res, r, dataset->data_handle(), n_samples * n_dim, int8_t(-128), int8_t(127));
} else if constexpr (std::is_same_v<data_type, uint8_t>) {
raft::random::uniformInt(
res, r, dataset->data_handle(), n_samples * n_dim, uint8_t(0), uint8_t(255));
}
raft::resource::sync_stream(res);
}

void TearDown() override
{
dataset.reset();
raft::resource::sync_stream(res);
}

private:
raft::resources res;
std::optional<raft::device_matrix<data_type, int64_t>> dataset = std::nullopt;

constexpr static int64_t n_samples = 10000;
constexpr static int64_t n_dim = 1024;
};

// Instantiate test for different data types
using TestTypes = ::testing::Types<float, int8_t, uint8_t>;
TYPED_TEST_SUITE(CagraIterativeBuildBugTest, TestTypes);

TYPED_TEST(CagraIterativeBuildBugTest, IterativeBuildTest) { this->run(); }

} // namespace cuvs::neighbors::cagra