Sampling #727
Pull Request Overview
This PR adds sampling functionality for neural network inference, with support for top-k and top-p sampling strategies. The implementation includes GPU-accelerated kernels and an accompanying test suite.
- Adds top-k probability renormalization, top-p sampling, and joint top-k/top-p sampling operations
- Implements CUDA/HIP kernels with vectorized data types for efficient GPU computation
- Adds tests covering various batch sizes, vocabulary sizes, and sampling parameters
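
For readers unfamiliar with the three operations listed above, here is a minimal CPU reference sketch in plain Python. It is not the PR's kernel code: the function names (`top_k_renorm`, `top_p_filter`, `sample_top_k_top_p`) are hypothetical, and applying top-k before top-p in the joint case is an assumption about ordering rather than something the PR states.

```python
import random


def top_k_renorm(probs, k):
    """Keep the k largest probabilities, zero the rest, and renormalize."""
    if k >= len(probs):
        return list(probs)
    # Indices of the k most probable tokens (ties go to the lower index).
    keep = set(sorted(range(len(probs)), key=lambda i: -probs[i])[:k])
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


def top_p_filter(probs, p):
    """Keep the smallest high-probability prefix whose mass reaches p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    keep, cum = [], 0.0
    for i in order:
        keep.append(i)
        cum += probs[i]
        if cum >= p:
            break
    kept = set(keep)
    return [probs[i] / cum if i in kept else 0.0 for i in range(len(probs))]


def sample_top_k_top_p(probs, k, p, rng=random):
    """Assumed ordering: top-k first, then top-p, then draw one token index."""
    filtered = top_p_filter(top_k_renorm(probs, k), p)
    return rng.choices(range(len(filtered)), weights=filtered, k=1)[0]
```

A reference like this is mainly useful for checking GPU output on small inputs, since it works on a single probability vector rather than a batch.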
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| op_tests/test_sampling.py | Comprehensive test suite for all sampling operations |
| csrc/cpp_itfs/utils.py | Updates import path for FileBaton utility |
| csrc/cpp_itfs/torch_utils.py | Adds log_args function and updates imports |
| csrc/cpp_itfs/sampling/vec_dtypes.cuh | Core vectorized data types for GPU kernels |
| csrc/cpp_itfs/sampling/top_*.py | Python wrappers for sampling operations |
| csrc/cpp_itfs/sampling/top_*.cpp.jinja | CUDA kernel templates |
| csrc/cpp_itfs/sampling/sampling.cuh | Main sampling kernel implementations |
| csrc/cpp_itfs/file_baton.py | File-based synchronization utility |
| aiter/ops/sampling.py | High-level Python API for sampling operations |
Comments suppressed due to low confidence (1)
csrc/cpp_itfs/sampling/vec_dtypes.cuh:1587
- The namespace comment refers to 'flashinfer' but the actual namespace is 'aiter' as defined on line 52.
} // namespace flashinfer