Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
225 commits
Select commit Hold shift + click to select a range
a024f61
jit lto interleaved scan
divyegala Oct 2, 2025
45da4aa
fix dependencies.yaml
divyegala Oct 2, 2025
a7c8621
generate files at build time, use tags to avoid compilation of types
divyegala Oct 4, 2025
eb2d74b
passing tests
divyegala Oct 5, 2025
d2318e8
update gitignore
divyegala Oct 6, 2025
5e6afcd
separate out distance function from main kernel
divyegala Oct 6, 2025
6eee4da
fix deps
divyegala Oct 6, 2025
1de8f28
add filters as jit device functions, rework caching logic
divyegala Oct 7, 2025
84c6020
lto post lambda, cleanup files, generate cmake in build dir
divyegala Oct 7, 2025
22680c8
don't read hardcoded kernels, use generator properly
divyegala Oct 8, 2025
37f1163
random cmake changes carried over from 25.10
divyegala Oct 8, 2025
0ae5383
cmake format
divyegala Oct 8, 2025
fe56aec
remove dep on kernel list
divyegala Oct 8, 2025
40c8fd6
attempt to solve overlinking problem
divyegala Oct 9, 2025
e87a8c7
reorder if-else in compiler check
divyegala Oct 9, 2025
179d733
Merge branch 'branch-25.12' into jit-lto-ivf-flat-interleaved
divyegala Oct 9, 2025
32a67bd
use cudart apis
divyegala Oct 9, 2025
c27612e
merge
divyegala Oct 9, 2025
a4b48b1
attempt to link cudart
divyegala Oct 9, 2025
d5d692e
revert cudart link, try all arch build of jit lto fatbin sources
divyegala Oct 9, 2025
1c6dd94
cmake format
divyegala Oct 9, 2025
30f5ab6
missing shared mem setting
divyegala Oct 10, 2025
9674969
separate cuda 12 and 13 compilation
divyegala Oct 22, 2025
24fc47d
merge upstream
divyegala Oct 22, 2025
db9a487
remove bench
divyegala Oct 22, 2025
aa9294f
c include directory
divyegala Oct 22, 2025
2eb77fe
style check
divyegala Oct 22, 2025
6c685fa
merge upstream
divyegala Oct 22, 2025
3e35b99
guard cuda calls and use shared_ptr
divyegala Oct 23, 2025
d0ff62c
add AlgorithmPlanner to main target
divyegala Oct 23, 2025
eb87577
merge upstream
divyegala Oct 23, 2025
445a6c4
remove nvjitlink as cuda 12 dep
divyegala Oct 23, 2025
92a27d4
address review
divyegala Oct 24, 2025
8549172
merge upstream
divyegala Oct 24, 2025
67579f4
add include guard
divyegala Oct 27, 2025
7ad8774
add and remove couple of comments
divyegala Oct 27, 2025
816a480
merge upstream
divyegala Oct 27, 2025
ab35ef3
delete readme
divyegala Oct 27, 2025
cdd4c85
increase warmup time
divyegala Oct 27, 2025
87334b2
merge upstream
divyegala Oct 27, 2025
c1eff9f
use new copyright
divyegala Oct 27, 2025
ece09b8
new copyright
divyegala Oct 27, 2025
4dacc6e
remove one more straggling comment
divyegala Oct 27, 2025
1fd95cd
use raft expects
divyegala Oct 27, 2025
64cde0d
Merge branch 'main' into jit-lto-ivf-flat-interleaved
divyegala Oct 27, 2025
5ac127b
merge upstream
divyegala Dec 12, 2025
78002c6
address review
divyegala Dec 12, 2025
9ad6a0b
pre-commit
divyegala Dec 12, 2025
bf4c4ad
address review
divyegala Dec 12, 2025
18b2af9
Generate kernel files in CMake instead of Python
KyleFromNVIDIA Dec 12, 2025
ece5cad
Merge remote-tracking branch 'refs/remotes/github/divyegala/jit-lto-i…
KyleFromNVIDIA Dec 12, 2025
8ce70c2
Style
KyleFromNVIDIA Dec 12, 2025
fdc4239
Style
KyleFromNVIDIA Dec 12, 2025
be3cf0d
Style
KyleFromNVIDIA Dec 12, 2025
7e644c3
Lint
KyleFromNVIDIA Dec 12, 2025
235938a
Style, lint
KyleFromNVIDIA Dec 12, 2025
e3b749d
Fix nvjitlink_checker
KyleFromNVIDIA Dec 15, 2025
f42ae3f
Style
KyleFromNVIDIA Dec 15, 2025
b606df9
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Dec 15, 2025
5ce7aab
Refactor JIT LTO kernel compilation
KyleFromNVIDIA Dec 15, 2025
eaad347
Style
KyleFromNVIDIA Dec 15, 2025
eb3b468
pic
KyleFromNVIDIA Dec 15, 2025
912279c
style
KyleFromNVIDIA Dec 15, 2025
19f1af3
Verbose build
KyleFromNVIDIA Dec 15, 2025
087b943
static
KyleFromNVIDIA Dec 15, 2025
c16e109
style
KyleFromNVIDIA Dec 15, 2025
323b79f
TARGET_OBJECTS
KyleFromNVIDIA Dec 15, 2025
9f13e73
Disable sccache
KyleFromNVIDIA Dec 16, 2025
eaf9d39
Recache
KyleFromNVIDIA Dec 16, 2025
ce40c51
Revert CI debugging
KyleFromNVIDIA Dec 16, 2025
0d0abb9
Install and link object library
KyleFromNVIDIA Dec 17, 2025
84bfa92
Style
KyleFromNVIDIA Dec 17, 2025
21241eb
Alias
KyleFromNVIDIA Dec 17, 2025
7c0ac13
Make cuvs_jit_lto_kernels a static library
KyleFromNVIDIA Dec 17, 2025
880dbf2
Style
KyleFromNVIDIA Dec 17, 2025
d04d7c1
rapids_cuda_init_architectures() for C tests
KyleFromNVIDIA Dec 17, 2025
19581f9
Be more specific about where we search for libclang
KyleFromNVIDIA Dec 17, 2025
a61f019
More libclang updates
KyleFromNVIDIA Dec 17, 2025
2eeb913
Revert "Fix libclang download for Rust, CUDA initialization for C tests"
KyleFromNVIDIA Dec 17, 2025
55ec26c
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Dec 18, 2025
10228c5
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Dec 18, 2025
031ce21
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Jan 14, 2026
088c21e
Copyright
KyleFromNVIDIA Jan 14, 2026
8ca1062
Apply suggestions from code review
divyegala Jan 22, 2026
d5ab5bf
merge upstream
divyegala Jan 22, 2026
b8c0d42
address some review comments
divyegala Jan 22, 2026
17d34ae
remove too many underscores
divyegala Jan 22, 2026
45a5146
FEA Add initial commit of prototype/pseudo-code for proposed UDF APIs…
dantegd Jan 26, 2026
447532e
stitch together
divyegala Jan 30, 2026
e1627d1
add udf to cmakelists
divyegala Jan 30, 2026
f7ea581
udfs working e2e
divyegala Jan 30, 2026
8b2775c
run benchmarks
divyegala Feb 3, 2026
e9c77d9
working through
divyegala Feb 3, 2026
adcfb8f
fixed overhead
divyegala Feb 4, 2026
282b376
Simplify
KyleFromNVIDIA Feb 4, 2026
609a4d6
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Feb 4, 2026
3115d07
address reviews
divyegala Feb 4, 2026
bb524ae
Merge remote-tracking branch 'origin/main' into jit-lto-ivf-flat-inte…
divyegala Feb 4, 2026
30a8a9f
Merge branch 'jit-lto-ivf-flat-interleaved' of github.com:divyegala/c…
divyegala Feb 4, 2026
72ddb36
Merge branch 'main' into jit-lto-ivf-flat-interleaved
divyegala Feb 5, 2026
4bd2102
add to docs and log about jit
divyegala Feb 10, 2026
fb722f0
Merge branch 'jit-lto-ivf-flat-interleaved' of github.com:divyegala/c…
divyegala Feb 10, 2026
3523b96
Merge remote-tracking branch 'origin/main' into jit-lto-ivf-flat-inte…
divyegala Feb 10, 2026
ba758a2
address review
divyegala Feb 10, 2026
42b78ae
rename inner_product to inner_prod
divyegala Feb 10, 2026
2e3a471
Merge remote-tracking branch 'origin/main' into jit-lto-ivf-flat-inte…
divyegala Feb 10, 2026
bfc6c09
fix merge conflict
divyegala Feb 10, 2026
f6377fa
include header and form better log
divyegala Feb 10, 2026
26abc7b
Merge branch 'jit-lto-ivf-flat-interleaved' into ivf-flat-search-udf
divyegala Feb 10, 2026
fb7f105
merge
divyegala Feb 10, 2026
432bb32
working through
divyegala Feb 11, 2026
533b770
address review and move
divyegala Feb 11, 2026
af23585
Merge remote-tracking branch 'origin/main' into jit-lto-ivf-flat-inte…
divyegala Feb 11, 2026
78c59d9
one more fix
divyegala Feb 11, 2026
9274868
Merge branch 'jit-lto-ivf-flat-interleaved' into cagra-search-jit-lto
divyegala Feb 11, 2026
7f8802b
correct path
divyegala Feb 11, 2026
f432aad
Merge branch 'jit-lto-ivf-flat-interleaved' into cagra-search-jit-lto
divyegala Feb 11, 2026
39ce9e3
in the middle of stuff
divyegala Feb 13, 2026
27acbb6
merge upstream
divyegala Feb 13, 2026
d11edfd
Merge branch 'jit-lto-ivf-flat-interleaved' into ivf-flat-search-udf
divyegala Feb 13, 2026
dd23671
multi-cta still failing
divyegala Feb 13, 2026
4f287c1
attempting to solve 2 kernel issue
divyegala Feb 15, 2026
64f6ad8
merge upstream
divyegala Feb 15, 2026
f1888a2
more cleaning
divyegala Feb 15, 2026
b596e79
merge cleanly
divyegala Feb 15, 2026
9c4980f
add nvrtc as a dependency
divyegala Feb 15, 2026
f27eeb2
fix build errors
divyegala Feb 15, 2026
bc5c90e
guard udf use
divyegala Feb 15, 2026
09dc56c
analyzing cubins
divyegala Feb 15, 2026
55c32f4
compiler definition on headers
divyegala Feb 15, 2026
1866475
guard udf test
divyegala Feb 15, 2026
c419173
remove
divyegala Feb 15, 2026
04cc166
missing include
divyegala Feb 15, 2026
1113afc
cleaning up
divyegala Feb 15, 2026
e372917
merge upstream
divyegala Feb 15, 2026
d8341ac
Merge remote-tracking branch 'divye/unneeded-cccl-includes' into cagr…
divyegala Feb 15, 2026
6feecce
most errors resolved
divyegala Feb 17, 2026
3e9f5f3
Merge branch 'main' into ivf-flat-search-udf
divyegala Feb 17, 2026
52e05c2
debug filter fragment
divyegala Feb 17, 2026
caf8d03
Merge branch 'main' into ivf-flat-search-udf
divyegala Feb 18, 2026
b65f599
occassional failure on dgx spark
divyegala Feb 18, 2026
5239a1a
fix compile
divyegala Feb 18, 2026
736dc75
Ignore cache-host run exports
bdice Feb 18, 2026
f83f595
Merge branch 'main' into ivf-flat-search-udf
divyegala Feb 18, 2026
a7a4ef7
pull out metric
divyegala Feb 19, 2026
5390c4c
Merge branch 'cagra-search-jit-lto' of github.com:divyegala/cuvs into…
divyegala Feb 19, 2026
07a158c
use void* for desc and create more fragments
divyegala Feb 19, 2026
0e201e8
attempt to fix cuda 12 builds
divyegala Feb 19, 2026
88a4b6e
respond to reviews
divyegala Feb 19, 2026
101c5ee
Merge remote-tracking branch 'origin/main' into ivf-flat-search-udf
divyegala Feb 19, 2026
5d3a9df
Merge branch 'ivf-flat-search-udf' of github.com:divyegala/cuvs into …
divyegala Feb 19, 2026
63c7300
pin cupy to <14.0 for cuda 12 wheels
divyegala Feb 19, 2026
0c0b6b5
fix cuda 12
divyegala Feb 19, 2026
faa9339
add includes
divyegala Feb 19, 2026
73e8fa0
fix logging
divyegala Feb 19, 2026
fef68d3
fix macro
divyegala Feb 19, 2026
05cc149
major refactor to reduce # of fragments
divyegala Feb 20, 2026
b6c9031
merge upstream udf pr
divyegala Feb 20, 2026
995f998
Merge branch 'main' into ivf-flat-search-udf
divyegala Feb 20, 2026
75e2616
Account for different QueryT
divyegala Feb 20, 2026
387d9ea
Merge remote-tracking branch 'origin/main' into cagra-search-jit-lto
divyegala Feb 20, 2026
1ccb01c
cleanup some stuff
divyegala Feb 20, 2026
3256a8e
attempt to fix devcontainer error
divyegala Feb 20, 2026
32a5d9f
Merge remote-tracking branch 'origin/main' into ivf-flat-search-udf
divyegala Feb 20, 2026
592af70
Merge branch 'ivf-flat-search-udf' of github.com:divyegala/cuvs into …
divyegala Feb 20, 2026
43501b7
address review comments
divyegala Feb 20, 2026
b5342d6
Merge branch 'ivf-flat-search-udf' into cagra-search-jit-lto
divyegala Feb 20, 2026
b85f16b
Add matrix JSON files
KyleFromNVIDIA Feb 24, 2026
e79de08
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Feb 24, 2026
de0a2b5
Fix
KyleFromNVIDIA Feb 24, 2026
c7909c3
more refactors and fix stream serialization bug
divyegala Feb 24, 2026
bbbfb25
launch correctly
divyegala Feb 25, 2026
22c40fd
Use new kernel matrix system
KyleFromNVIDIA Feb 25, 2026
d404869
remove debug prints
divyegala Feb 25, 2026
53ce0aa
Merge remote-tracking branch 'github/divyegala/cagra-search-jit-lto' …
KyleFromNVIDIA Feb 25, 2026
9fc9185
Merge remote-tracking branch 'github/divyegala/cagra-search-jit-lto' …
KyleFromNVIDIA Feb 25, 2026
1eef8c5
Remove preprocessor branch
KyleFromNVIDIA Feb 25, 2026
0af09e2
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Feb 25, 2026
b2e418b
reconcile pr 1807 and add nvjitlink/nvrtc to jit target
divyegala Feb 26, 2026
53195ef
Fix ivf flat
KyleFromNVIDIA Feb 26, 2026
f589b26
Fix kernel names and matrices
KyleFromNVIDIA Feb 26, 2026
6b8d175
Fix query
KyleFromNVIDIA Feb 26, 2026
426625e
Fix another query
KyleFromNVIDIA Feb 26, 2026
97dfa18
More
KyleFromNVIDIA Feb 26, 2026
29881c8
Make naming and matrices more consistent
KyleFromNVIDIA Feb 26, 2026
bb01ec6
add func specialization for smem launcher
divyegala Feb 26, 2026
6b32331
Merge branch 'cagra-search-jit-lto' of github.com:divyegala/cuvs into…
divyegala Feb 26, 2026
6516f78
fix ivf flat udf key
divyegala Feb 27, 2026
d737706
remove debug
divyegala Feb 27, 2026
a809041
Remove comments and debug statement, fix query, copyright
KyleFromNVIDIA Feb 27, 2026
0d48be2
Merge remote-tracking branch 'github/divyegala/cagra-search-jit-lto' …
KyleFromNVIDIA Feb 27, 2026
49f999f
missing query tag
divyegala Feb 27, 2026
d66edf0
Merge branch 'cagra-search-jit-lto' of github.com:divyegala/cuvs into…
divyegala Feb 27, 2026
b52f8c2
Refactor and make thread-safe
KyleFromNVIDIA Feb 27, 2026
0349746
remove prints
divyegala Feb 27, 2026
6e07abb
remove unnecessary includes
divyegala Feb 27, 2026
e9e2ff0
Don't build fatbins with debug symbols
KyleFromNVIDIA Mar 4, 2026
9bd6100
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 5, 2026
34ed3e2
Merge branch 'main' into cagra-search-jit-lto
divyegala Mar 5, 2026
e6f06fc
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 5, 2026
5552b2f
Merge remote-tracking branch 'github/divyegala/cagra-search-jit-lto' …
KyleFromNVIDIA Mar 5, 2026
582d6a0
unpin raft
divyegala Mar 6, 2026
fb13ea5
Merge remote-tracking branch 'origin/main' into cagra-search-jit-lto
divyegala Mar 6, 2026
98a1dce
Update cpp/cmake/thirdparty/get_raft.cmake
divyegala Mar 6, 2026
a39c150
Update cpp/cmake/thirdparty/get_raft.cmake
divyegala Mar 6, 2026
c3a8d73
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 9, 2026
8dfb354
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 9, 2026
33e1bc5
Add L1 dist op
KyleFromNVIDIA Mar 9, 2026
f050b77
Fix L1 distance
KyleFromNVIDIA Mar 10, 2026
d6eec0a
Explicitly install cudart
KyleFromNVIDIA Mar 11, 2026
1f3b75b
use function ptr indirection
divyegala Mar 12, 2026
832eaf2
Merge remote-tracking branch 'origin/main' into cagra-search-jit-lto
divyegala Mar 12, 2026
9243390
const
KyleFromNVIDIA Mar 12, 2026
dca579a
extern
KyleFromNVIDIA Mar 12, 2026
f11daf5
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 12, 2026
1c2da37
Re-run CI
KyleFromNVIDIA Mar 12, 2026
ff3527b
fix bug and simplify json
divyegala Mar 12, 2026
671e8a7
Merge branch 'cagra-search-jit-lto' of github.com:divyegala/cuvs into…
divyegala Mar 12, 2026
e14a119
simply function ptr usage
divyegala Mar 13, 2026
39e67f3
call functions directly
divyegala Mar 13, 2026
59f8911
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 16, 2026
be21da4
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 18, 2026
fb5cf1e
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Apr 15, 2026
8e6797c
merge upstream, make tests pass
divyegala Apr 17, 2026
ca67dea
delete extra files
divyegala Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions conda/recipes/libcuvs/recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ cache:
- ${{ stdlib("c") }}
host:
- libnvjitlink-dev
- cuda-nvrtc-dev
- librmm =${{ minor_version }}
- libraft-headers =${{ minor_version }}
- nccl ${{ nccl_version }}
Expand Down Expand Up @@ -122,6 +123,7 @@ outputs:
- libcusolver-dev
- libcusparse-dev
- libnvjitlink-dev
- cuda-nvrtc-dev
run:
- ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }}
- libraft-headers =${{ minor_version }}
Expand Down Expand Up @@ -184,6 +186,7 @@ outputs:
- libcusolver-dev
- libcusparse-dev
- libnvjitlink-dev
- cuda-nvrtc-dev
run:
- ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }}
- ${{ pin_subpackage("libcuvs-headers", exact=True) }}
Expand Down Expand Up @@ -245,6 +248,7 @@ outputs:
- libcusolver-dev
- libcusparse-dev
- libnvjitlink-dev
- cuda-nvrtc-dev
run:
- ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }}
- ${{ pin_subpackage("libcuvs-headers", exact=True) }}
Expand Down Expand Up @@ -327,6 +331,10 @@ outputs:
- librmm
- mkl
- nccl
- if: cuda_major == "13"
then:
- cuda-nvrtc
- libnvjitlink
about:
homepage: ${{ load_from_file("python/libcuvs/pyproject.toml").project.urls.Homepage }}
license: ${{ load_from_file("python/libcuvs/pyproject.toml").project.license }}
Expand Down
184 changes: 177 additions & 7 deletions cpp/CMakeLists.txt

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions cpp/cmake/modules/generate_jit_lto_kernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ function(process_jit_lto_matrix_entry source_list_var)
cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN})

populate_matrix_variables("${_JIT_LTO_MATRIX_JSON_ENTRY}")

# Registration fragment tags (FRAGMENT_TAG_FORMAT) can track GPU template params in .cu.in; map
# those scalars to tag_* names here so matrix JSON does not duplicate tag strings.
Comment on lines +52 to +53
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please follow the existing conventions that we use in other matrix files. We use the *_abbrev substitutions for the kernel name and the tag name. Please don't pollute core CMake code with these implementation details.

if(DEFINED codebook_type)
if(codebook_type STREQUAL "void")
set(jit_fragment_codebook_tag "tag_codebook_none")
elseif(codebook_type STREQUAL "half")
set(jit_fragment_codebook_tag "tag_codebook_half")
else()
message(FATAL_ERROR "Unknown codebook_type for JIT fragment: ${codebook_type}")
endif()
endif()

string(CONFIGURE "${_JIT_LTO_NAME_FORMAT}" kernel_name @ONLY)
string(CONFIGURE "${_JIT_LTO_FRAGMENT_TAG_FORMAT}" fragment_tag @ONLY)

Expand Down
1 change: 1 addition & 0 deletions cpp/cmake/modules/register_fatbin.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "@fatbin_header_file@"
#include <cuvs/detail/jit_lto/FragmentEntry.hpp>
#include <cuvs/detail/jit_lto/registration_tags.hpp>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After restoring interleaved_scan_fragments.hpp, please remove this #include.


@fragment_tag_header_files@

Expand Down
15 changes: 15 additions & 0 deletions cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,25 @@ struct AlgorithmLauncher {
this->call(stream, grid, block, shared_mem, kernel_args);
}

/**
 * Type-checked front end for `call_cooperative()`.
 *
 * The decayed-reference types of `args` must spell exactly the function type `FuncT`
 * (i.e. `void(Args...)` after reference removal); a mismatch is rejected at compile time.
 * The address of every argument is then collected into a `void*` array and handed to
 * `call_cooperative()` together with the launch geometry.
 *
 * @tparam FuncT      expected kernel function type, `void(Params...)`
 * @tparam Args       deduced argument types
 * @param  stream     stream forwarded to `call_cooperative()`
 * @param  grid       grid dimensions forwarded to `call_cooperative()`
 * @param  block      block dimensions forwarded to `call_cooperative()`
 * @param  shared_mem shared-memory byte count forwarded to `call_cooperative()`
 * @param  args       kernel arguments; only their addresses are taken, so they must stay alive
 *                    until `call_cooperative()` returns
 */
template <typename FuncT, typename... Args>
void dispatch_cooperative(
  cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, Args&&... args)
{
  // Function type implied by the arguments actually supplied by the caller.
  using supplied_signature = void(std::remove_reference_t<Args>...);
  static_assert(
    std::is_same_v<FuncT, supplied_signature>,
    "dispatch_cooperative() argument types do not match the kernel function signature FuncT");

  // One untyped pointer per argument, in call order; constness is stripped because the launch
  // interface takes `void**`.
  void* arg_ptrs[] = {const_cast<void*>(static_cast<void const*>(&args))...};
  this->call_cooperative(stream, grid, block, shared_mem, arg_ptrs);
}

/** Returns the `library` handle stored in this launcher (no ownership is transferred). */
cudaLibrary_t get_library() { return this->library; }
/** Returns the `kernel` handle stored in this launcher. */
cudaKernel_t get_kernel() { return this->kernel; }

private:
void call(cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args);
void call_cooperative(
cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args);
cudaKernel_t kernel;
cudaLibrary_t library;
};
74 changes: 74 additions & 0 deletions cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cuvs/detail/jit_lto/common_fragments.hpp>

namespace cuvs::neighbors::cagra::detail {

// Every type in this namespace is an empty class template: it stores nothing and is
// distinguished only by its name and template arguments.
// NOTE(review): presumably each one identifies a registered JIT-LTO fragment of the same
// name; confirm against the fragment registration code.

/** Tag for `setup_workspace`, parameterised by nine type tags. */
template <class DataTag,
          class IndexTag,
          class DistanceTag,
          class QueryTag,
          class CodebookTag,
          class TeamTag,
          class BlockDimTag,
          class PqBitsTag,
          class PqLenTag>
struct fragment_tag_setup_workspace {};

/** Tag for `compute_distance`; takes the same nine type tags as `setup_workspace`. */
template <class DataTag,
          class IndexTag,
          class DistanceTag,
          class QueryTag,
          class CodebookTag,
          class TeamTag,
          class BlockDimTag,
          class PqBitsTag,
          class PqLenTag>
struct fragment_tag_compute_distance {};

/** Tag for `dist_op`, parameterised by query, distance and metric tags. */
template <class QueryTag, class DistanceTag, class MetricTag>
struct fragment_tag_dist_op {};

/** Tag for `apply_normalization_standard`, parameterised by seven type tags. */
template <class DataTag,
          class IndexTag,
          class DistanceTag,
          class QueryTag,
          class TeamTag,
          class BlockDimTag,
          class NormTag>
struct fragment_tag_apply_normalization_standard {};

/** Tag for `search_single_cta`: four type tags plus two boolean switches. */
template <class DataTag,
          class SourceIndexTag,
          class IndexTag,
          class DistanceTag,
          bool TopkByBitonicSort,
          bool BitonicSortAndMergeMultiWarps>
struct fragment_tag_search_single_cta {};

/** Tag for `search_single_cta_p`; same parameter list as `search_single_cta`. */
template <class DataTag,
          class SourceIndexTag,
          class IndexTag,
          class DistanceTag,
          bool TopkByBitonicSort,
          bool BitonicSortAndMergeMultiWarps>
struct fragment_tag_search_single_cta_p {};

/** Tag for `search_multi_cta`. */
template <class DataTag, class SourceIndexTag, class IndexTag, class DistanceTag>
struct fragment_tag_search_multi_cta {};

/** Tag for `random_pickup`. */
template <class DataTag, class IndexTag, class DistanceTag>
struct fragment_tag_random_pickup {};

/** Tag for `compute_distance_to_child_nodes`. */
template <class DataTag, class IndexTag, class DistanceTag, class SourceIndexTag>
struct fragment_tag_compute_distance_to_child_nodes {};

/** Tag for `apply_filter_kernel`. */
template <class IndexTag, class DistanceTag, class SourceIndexTag>
struct fragment_tag_apply_filter_kernel {};

}  // namespace cuvs::neighbors::cagra::detail

This file was deleted.

98 changes: 98 additions & 0 deletions cpp/include/cuvs/detail/jit_lto/registration_tags.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please restore interleaved_scan_fragments.hpp and remove this file. I'd like to keep the fragment/tag headers consistent.

* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cstdint>

namespace cuvs::detail::jit_lto {

// Empty tag types that the algorithm-specific namespaces later in this header re-export
// with using-declarations.

// NOTE(review): the suffixes look like element-type abbreviations (f/h/sc/uc); confirm the
// exact mapping against the kernel matrix files.
struct tag_uc {};
struct tag_sc {};
struct tag_h {};
struct tag_f {};

// Index-type tag.
struct tag_idx_l {};

// Filter-selection tags.
struct tag_filter_bitset {};
struct tag_filter_none {};

}  // namespace cuvs::detail::jit_lto

namespace cuvs::neighbors::cagra::detail {

// Shared tags re-exported from cuvs::detail::jit_lto so they can be named unqualified here.
using cuvs::detail::jit_lto::tag_f;
using cuvs::detail::jit_lto::tag_h;
using cuvs::detail::jit_lto::tag_sc;
using cuvs::detail::jit_lto::tag_uc;
using cuvs::detail::jit_lto::tag_idx_l;
using cuvs::detail::jit_lto::tag_filter_none;
using cuvs::detail::jit_lto::tag_filter_bitset;

// All structs below are empty tag types local to this namespace, grouped by name prefix.

// Index / distance value tags.
struct tag_idx_ui {};
struct tag_dist_f {};

// Metric tags.
struct tag_metric_l2 {};
struct tag_metric_l1 {};
struct tag_metric_inner_product {};
struct tag_metric_cosine {};
struct tag_metric_hamming {};

// Team tags.
struct tag_team_8 {};
struct tag_team_16 {};
struct tag_team_32 {};

// Dim tags.
struct tag_dim_128 {};
struct tag_dim_256 {};
struct tag_dim_512 {};

// PQ bits / PQ length tags.
struct tag_pq_bits_0 {};
struct tag_pq_bits_8 {};
struct tag_pq_len_0 {};
struct tag_pq_len_2 {};
struct tag_pq_len_4 {};

// Codebook tags.
struct tag_codebook_none {};
struct tag_codebook_half {};

// Normalization tags.
struct tag_norm_noop {};
struct tag_norm_cosine {};

}  // namespace cuvs::neighbors::cagra::detail

namespace cuvs::neighbors::ivf_flat::detail {

// Shared tags re-exported from cuvs::detail::jit_lto so they can be named unqualified here.
using cuvs::detail::jit_lto::tag_f;
using cuvs::detail::jit_lto::tag_h;
using cuvs::detail::jit_lto::tag_sc;
using cuvs::detail::jit_lto::tag_uc;
using cuvs::detail::jit_lto::tag_idx_l;
using cuvs::detail::jit_lto::tag_filter_none;
using cuvs::detail::jit_lto::tag_filter_bitset;

// All structs below are empty tag types local to this namespace, grouped by name prefix.

// 8-bit data tags.
struct tag_i8 {};
struct tag_u8 {};

// Accumulator tags.
struct tag_acc_f {};
struct tag_acc_h {};
struct tag_acc_i32 {};
struct tag_acc_u32 {};

// Metric tags.
struct tag_metric_euclidean {};
struct tag_metric_inner_product {};
struct tag_metric_custom_udf {};

// Post-process tags.
struct tag_post_process_identity {};
struct tag_post_process_sqrt {};
struct tag_post_process_compose {};

// Fragment tags: empty class templates distinguished only by name and template arguments.

/** Tag for `interleaved_scan`: three type tags, an integer capacity and a sort-order flag. */
template <class DataTag, class AccTag, class IdxTag, int Capacity, bool Ascending>
struct fragment_tag_interleaved_scan {};

/** Tag for `load_and_compute_dist`: two type tags, a norm flag and an integer `Veclen`. */
template <class DataTag, class AccTag, bool ComputeNorm, int Veclen>
struct fragment_tag_load_and_compute_dist {};

/** Tag for `metric`: three type tags and an integer `Veclen`. */
template <class DataTag, class AccTag, class MetricTag, int Veclen>
struct fragment_tag_metric {};

/** Tag for `filter`, parameterised by index and filter tags. */
template <class IndexTag, class FilterTag>
struct fragment_tag_filter {};

/** Tag for `post_lambda`, parameterised by a single post-lambda tag. */
template <class PostLambdaTag>
struct fragment_tag_post_lambda {};

}  // namespace cuvs::neighbors::ivf_flat::detail
2 changes: 2 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ struct index : cuvs::neighbors::index {
/** Distance metric used for clustering. */
cuvs::distance::DistanceType metric() const noexcept;

/**
 * Set the distance metric of this index.
 * NOTE(review): presumably replaces the value reported by `metric()`; the definition is not
 * visible from this header, so confirm that and document any restriction on when it is safe
 * to call (e.g. relative to build/search).
 */
void set_metric(cuvs::distance::DistanceType metric);

/** Whether `centers()` change upon extending the index (ivf_flat::extend). */
bool adaptive_centers() const noexcept;

Expand Down
25 changes: 24 additions & 1 deletion cpp/src/detail/jit_lto/AlgorithmLauncher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ AlgorithmLauncher::AlgorithmLauncher(AlgorithmLauncher&& other) noexcept
AlgorithmLauncher& AlgorithmLauncher::operator=(AlgorithmLauncher&& other) noexcept
{
if (this != &other) {
// Unload current library if it exists
if (library != nullptr) { cudaLibraryUnload(library); }
kernel = other.kernel;
library = other.library;
Expand All @@ -47,3 +46,27 @@ void AlgorithmLauncher::call(

RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args));
}

// Launches `kernel` through cudaLaunchKernelExC with the cooperative launch attribute set.
//
// stream / grid / block / shared_mem populate the launch configuration; kernel_args is the
// array of argument pointers passed straight through to the launch call. Any CUDA error is
// reported via RAFT_CUDA_TRY.
void AlgorithmLauncher::call_cooperative(
  cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** kernel_args)
{
  // The only launch attribute: request a cooperative launch.
  cudaLaunchAttribute cooperative_attr[1];
  cooperative_attr[0].id              = cudaLaunchAttributeCooperative;
  cooperative_attr[0].val.cooperative = 1;

  // Start from a zeroed configuration, then fill in attributes and geometry.
  cudaLaunchConfig_t launch_cfg{};
  launch_cfg.attrs            = cooperative_attr;
  launch_cfg.numAttrs         = 1;
  launch_cfg.stream           = stream;
  launch_cfg.dynamicSmemBytes = shared_mem;
  launch_cfg.blockDim         = block;
  launch_cfg.gridDim          = grid;

  RAFT_CUDA_TRY(cudaLaunchKernelExC(&launch_cfg, kernel, kernel_args));
}

std::unordered_map<std::string, std::shared_ptr<AlgorithmLauncher>>& get_cached_launchers()
{
static std::unordered_map<std::string, std::shared_ptr<AlgorithmLauncher>> launchers;
return launchers;
}
4 changes: 2 additions & 2 deletions cpp/src/detail/jit_lto/AlgorithmPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::build()

// Load the generated LTO IR and link them together
nvJitLinkHandle handle;
const char* lopts[] = {"-lto", archs.c_str()};
auto result = nvJitLinkCreate(&handle, 2, lopts);
const char* lopts[] = {"-lto", archs.c_str(), "-maxrregcount=64"};
auto result = nvJitLinkCreate(&handle, 3, lopts);
check_nvjitlink_result(handle, result);

for (const auto& frag : this->fragments) {
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <raft/linalg/norm.cuh>
#include <raft/linalg/reduce.cuh>

// All includes are done before opening namespace to avoid nested namespace issues
namespace cuvs::neighbors::cagra::detail {

template <typename DataT,
Expand Down
Loading
Loading