diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh
index 5e6d2137c0..cc7587bfda 100644
--- a/csrc/include/custom_all_reduce.cuh
+++ b/csrc/include/custom_all_reduce.cuh
@@ -46,7 +46,7 @@ typedef __hip_bfloat16 nv_bfloat16;
         }                                                 \
     } while (0)
 
-namespace vllm
+namespace aiter
 {
 
 constexpr int kMaxBlocks = 64;
@@ -1028,11 +1028,11 @@ namespace vllm
             CUDACHECK(cudaIpcCloseMemHandle(ptr));
         }
     }
-}; // namespace vllm
+}; // namespace aiter
 
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add a template instantiation:
- * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
+ * template void aiter::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 half *, int, int, int);
 */
-} // namespace vllm
+} // namespace aiter
diff --git a/csrc/kernels/custom_all_reduce.cu b/csrc/kernels/custom_all_reduce.cu
index 2ff254c3d1..92edd73489 100644
--- a/csrc/kernels/custom_all_reduce.cu
+++ b/csrc/kernels/custom_all_reduce.cu
@@ -48,8 +48,8 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
     {
         std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
     }
-    return (fptr_t) new vllm::CustomAllreduce(
-        reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
+    return (fptr_t) new aiter::CustomAllreduce(
+        reinterpret_cast<aiter::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
         rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
 }
 
@@ -79,7 +79,7 @@ bool _is_weak_contiguous(torch::Tensor &t)
 void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                  cudaStream_t stream, bool open_fp8_quant)
 {
-    auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+    auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
     TORCH_CHECK(_is_weak_contiguous(out));
     switch (out.scalar_type())
     {
@@ -150,24 +150,24 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
 
 void dispose(fptr_t _fa)
 {
-    auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+    auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
     delete fa;
 }
 
-int64_t meta_size() { return sizeof(vllm::Signal); }
+int64_t meta_size() { return sizeof(aiter::Signal); }
 
 void register_buffer(fptr_t _fa, torch::Tensor &t,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets)
 {
-    auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+    auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
     fa->register_buffer(handles, offsets, t.data_ptr());
 }
 
 std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa)
 {
-    auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+    auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
     auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
     auto options =
         torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
@@ -180,7 +180,7 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                             const std::vector<std::vector<int64_t>> &offsets)
 {
-    auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+    auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
     fa->register_graph_buffers(handles, offsets);
 }
 
diff --git a/csrc/py_itfs_cu/asm_communication.cu b/csrc/py_itfs_cu/asm_communication.cu
index e63c085d61..7f029a9b85 100644
--- a/csrc/py_itfs_cu/asm_communication.cu
+++ b/csrc/py_itfs_cu/asm_communication.cu
@@ -29,8 +29,8 @@ torch::Tensor all_reduce_asm(torch::Tensor &input,
         inp_ptr = reg_buffer.data_ptr();
     }
 
-    auto ca = reinterpret_cast<vllm::CustomAllreduce *>(_ca);
-    using RD = vllm::RankData;
+    auto ca = reinterpret_cast<aiter::CustomAllreduce *>(_ca);
+    using RD = aiter::RankData;
     RD *input_rd = ca->get_buffer_RD(stream, inp_ptr);
     RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr());
 
@@ -138,8 +138,8 @@ std::tuple<torch::Tensor, torch::Tensor> all_reduce_rmsnorm(torch::Tensor &input,
         inp_ptr = reg_buffer.data_ptr();
     }
 
-    auto ca = reinterpret_cast<vllm::CustomAllreduce *>(_ca);
-    using RD = vllm::RankData;
+    auto ca = reinterpret_cast<aiter::CustomAllreduce *>(_ca);
+    using RD = aiter::RankData;
     RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr());
     RD *reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr());
 
@@ -280,8 +280,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> all_reduce_rmsnorm_quant
         inp_ptr = reg_buffer.data_ptr();
     }
 
-    auto ca = reinterpret_cast<vllm::CustomAllreduce *>(_ca);
-    using RD = vllm::RankData;
+    auto ca = reinterpret_cast<aiter::CustomAllreduce *>(_ca);
+    using RD = aiter::RankData;
     RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr());
     RD *reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr());
 
@@ -404,4 +404,4 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> all_reduce_rmsnorm_quant
         torch::from_blob(res_ptr, {input.sizes()}, opt_res),
         torch::from_blob(ys_ptr, {M, 1}, opt_ys),
     };
-};
\ No newline at end of file
+};
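
After the rename, the PTX/SASS inspection workflow described in the header comment is unchanged except for the namespace. A minimal sketch of the Compiler Explorer instantiation, mirroring the header's own comment (the signature is taken verbatim from that comment; `half` is CUDA's half-precision type):

    // Paste csrc/include/custom_all_reduce.cuh into Compiler Explorer and
    // append this explicit instantiation to emit PTX/SASS for the
    // half-precision allreduce path under the renamed namespace.
    template void aiter::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
                                                          half *, int, int, int);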