8 changes: 4 additions & 4 deletions csrc/include/custom_all_reduce.cuh
@@ -46,7 +46,7 @@ typedef __hip_bfloat16 nv_bfloat16;
} \
} while (0)

-namespace vllm
+namespace aiter
{

constexpr int kMaxBlocks = 64;
@@ -1028,11 +1028,11 @@ namespace vllm
CUDACHECK(cudaIpcCloseMemHandle(ptr));
}
}
-}; // namespace vllm
+}; // namespace aiter
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
-* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
+* template void aiter::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
*/
-} // namespace vllm
+} // namespace aiter
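The header comment above already gives the Compiler Explorer recipe; a minimal sketch of that workflow under the new namespace, assuming an nvcc/hipcc toolchain with the header pasted alongside (the instantiation line is taken verbatim from the comment):

// Sketch: paste custom_all_reduce.cuh into Compiler Explorer, then force
// code generation for one dtype with an explicit instantiation so the
// PTX/SASS of the half specialization becomes inspectable.
#include <cuda_fp16.h>            // defines half
#include "custom_all_reduce.cuh"

template void aiter::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
                                                      half *, int, int, int);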
16 changes: 8 additions & 8 deletions csrc/kernels/custom_all_reduce.cu
@@ -48,8 +48,8 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
{
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
}
-return (fptr_t) new vllm::CustomAllreduce(
-reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
+return (fptr_t) new aiter::CustomAllreduce(
+reinterpret_cast<aiter::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

@@ -79,7 +79,7 @@ bool _is_weak_contiguous(torch::Tensor &t)
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
cudaStream_t stream, bool open_fp8_quant)
{
-auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type())
{
@@ -150,24 +150,24 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,

void dispose(fptr_t _fa)
{
-auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
delete fa;
}

-int64_t meta_size() { return sizeof(vllm::Signal); }
+int64_t meta_size() { return sizeof(aiter::Signal); }

void register_buffer(fptr_t _fa, torch::Tensor &t,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets)
{
-auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}

std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa)
{
-auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
@@ -180,7 +180,7 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets)
{
-auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+auto fa = reinterpret_cast<aiter::CustomAllreduce *>(_fa);
fa->register_graph_buffers(handles, offsets);
}

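Taken together, the bindings touched in this file form the host-side lifecycle of the renamed class. A hedged sketch of the call order, assuming IPC handles/offsets were already exchanged across ranks (e.g. via torch.distributed, which this diff does not show) and that the buffers are pre-allocated CUDA tensors:

// Sketch only: wiring the bindings above together. Names not present in
// the diff (handles, offsets, rank, reg_buffer, ...) are illustrative.
auto opts = torch::dtype(torch::kUInt8).device(torch::kCUDA);
auto meta = torch::empty({meta_size()}, opts);    // sizeof(aiter::Signal) bytes
auto rank_data = torch::empty({8 << 20}, opts);   // scratch for RankData entries

fptr_t fa = init_custom_ar(meta, rank_data, handles, offsets,
                           rank, /*full_nvlink=*/true);
register_buffer(fa, reg_buffer, buffer_handles, buffer_offsets);
// ... run allreduce ops through fa (see asm_communication.cu below) ...
dispose(fa);                                      // deletes aiter::CustomAllreduce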
14 changes: 7 additions & 7 deletions csrc/py_itfs_cu/asm_communication.cu
@@ -29,8 +29,8 @@ torch::Tensor all_reduce_asm(torch::Tensor &input,
inp_ptr = reg_buffer.data_ptr();
}

-auto ca = reinterpret_cast<vllm::CustomAllreduce *>(_ca);
-using RD = vllm::RankData;
+auto ca = reinterpret_cast<aiter::CustomAllreduce *>(_ca);
+using RD = aiter::RankData;

RD *input_rd = ca->get_buffer_RD(stream, inp_ptr);
RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr());
@@ -138,8 +138,8 @@ std::tuple<torch::Tensor, torch::Tensor> all_reduce_rmsnorm(torch::Tensor &input
inp_ptr = reg_buffer.data_ptr();
}

-auto ca = reinterpret_cast<vllm::CustomAllreduce *>(_ca);
-using RD = vllm::RankData;
+auto ca = reinterpret_cast<aiter::CustomAllreduce *>(_ca);
+using RD = aiter::RankData;

RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr());
RD *reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr());
@@ -280,8 +280,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> all_reduce_rmsnorm_quant
inp_ptr = reg_buffer.data_ptr();
}

-auto ca = reinterpret_cast<vllm::CustomAllreduce *>(_ca);
-using RD = vllm::RankData;
+auto ca = reinterpret_cast<aiter::CustomAllreduce *>(_ca);
+using RD = aiter::RankData;

RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr());
RD *reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr());
@@ -404,4 +404,4 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> all_reduce_rmsnorm_quant
torch::from_blob(res_ptr, {input.sizes()}, opt_res),
torch::from_blob(ys_ptr, {M, 1}, opt_ys),
};
-};
+};
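After the rename, all three fused entry points in this file (all_reduce_asm, all_reduce_rmsnorm, all_reduce_rmsnorm_quant) share the same preamble: cast the opaque handle back to the allreduce object, then resolve each registered tensor to its cross-rank view. A consolidating sketch of that shared pattern, with names taken from the hunks above:

// Shared preamble of the fused ops (sketch). _ca is the fptr_t handle from
// init_custom_ar; reg_sig / reg_buffer were registered across ranks earlier.
auto ca = reinterpret_cast<aiter::CustomAllreduce *>(_ca);
using RD = aiter::RankData;

// get_buffer_RD maps a local device pointer to the RankData entry that
// holds every peer's view of the same registered buffer.
RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr());
RD *reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr());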