2 changes: 1 addition & 1 deletion aiter/dist/communication_op.py
@@ -32,7 +32,7 @@ def tensor_model_parallel_all_reduce(

def tensor_model_parallel_fused_allreduce_rmsnorm(
input_: torch.Tensor, weight_: torch.Tensor, eps: float
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
return get_tp_group().fused_allreduce_rmsnorm(input_, weight_, eps)


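tensor_model_parallel_fused_allreduce_rmsnorm now returns a (residual_out, out) pair instead of a single tensor. A minimal caller-side sketch, assuming the function is importable from aiter.dist.communication_op as the file path suggests; the wrapper name decoder_layer_norm_step is illustrative and not part of this PR:

import torch
from aiter.dist.communication_op import tensor_model_parallel_fused_allreduce_rmsnorm

def decoder_layer_norm_step(hidden: torch.Tensor, rms_weight: torch.Tensor, eps: float = 1e-6):
    # residual_out holds the all-reduced (pre-norm) activations, to be carried as the residual;
    # out holds the RMS-normalized activations that feed the next projection.
    residual_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(hidden, rms_weight, eps)
    return residual_out, out
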
17 changes: 11 additions & 6 deletions aiter/dist/device_communicators/communicator_cuda.py
@@ -158,10 +158,14 @@ def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor:
torch.distributed.all_reduce(out, group=self.device_group)
return out

def fused_allreduce_rmsnorm(self, input_, weight_, eps) -> torch.Tensor:
def fused_allreduce_rmsnorm(
self, input_, weight_, eps
) -> tuple[torch.Tensor, torch.Tensor]:
n = input_.shape[-1]
can_use_fuse_ar_rms = (
n <= 16384 and input_.numel() * input_.element_size() < 8 * 1024 * 8192
n <= 16384
and input_.numel() * input_.element_size() < 8 * 1024 * 8192
and self.world_size != 6
)
ca_comm = self.ca_comm
if (
@@ -170,11 +174,12 @@ def fused_allreduce_rmsnorm(self, input_, weight_, eps) -> torch.Tensor:
and ca_comm.should_custom_ar(input_)
and can_use_fuse_ar_rms
):
out = ca_comm.custom_fused_ar_rms(input_, weight_, eps)
res_out, out = ca_comm.custom_fused_ar_rms(input_, weight_, eps)
assert out is not None
return out
assert res_out is not None
return res_out, out
# call split kernel
ar_out = all_reduce(input_)
ar_out = self.all_reduce(input_)
out = torch.empty_like(ar_out)
residual_out = torch.empty_like(ar_out)
from aiter import rmsnorm2d_fwd_with_add
@@ -188,7 +193,7 @@ def fused_allreduce_rmsnorm(self, input_, weight_, eps) -> torch.Tensor:
eps,
0,
)
return out
return residual_out, out

def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
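Semantically, the fused path above is expected to produce residual_out = all-reduced input and out = RMSNorm(residual_out, weight, eps), matching the split fallback. A plain-PyTorch reference sketch for checking both outputs, assuming an initialized TP process group; this is a checking aid, not code from this PR:

import torch
import torch.distributed as dist

def reference_allreduce_rmsnorm(x, weight, eps, group=None):
    # residual_out: elementwise sum of x across all ranks in the group.
    residual_out = x.clone()
    dist.all_reduce(residual_out, group=group)
    # out: RMSNorm of the reduced tensor over the last dimension.
    variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
    out = (residual_out.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return residual_out, out
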
8 changes: 6 additions & 2 deletions aiter/dist/device_communicators/custom_all_reduce.py
@@ -340,22 +340,26 @@ def fused_ar_rms(
self,
inp: torch.Tensor,
*,
res_out: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
w: torch.Tensor,
eps: float,
registered: bool = False,
):
if out is None:
out = torch.empty_like(inp)
if res_out is None:
res_out = torch.empty_like(inp)
ops.fused_allreduce_rmsnorm(
self._ptr,
inp,
res_out,
out,
w,
eps,
None if registered else self.buffer,
)
return out
return res_out, out

def custom_fused_ar_rms(
self, input: torch.Tensor, weight: torch.Tensor, eps: float
@@ -367,7 +371,7 @@ def custom_fused_ar_rms(
if torch.cuda.is_current_stream_capturing():
return self.fused_ar_rms(input, w=weight, eps=eps, registered=True)
else:
return torch.empty_like(input)
return torch.empty_like(input), torch.empty_like(input)
else:
return self.fused_ar_rms(input, w=weight, eps=eps, registered=False)

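fused_ar_rms now also takes an optional preallocated res_out buffer alongside out. A sketch of reusing both buffers across calls, which is the pattern one would want under CUDA graph capture; the helper name and the registered=True setting are illustrative assumptions:

import torch

def fused_ar_rms_with_buffers(ca, hidden_states, rms_weight, eps=1e-6):
    # ca is a CustomAllreduce instance whose signal/data buffers are already set up (assumed).
    out_buf = torch.empty_like(hidden_states)   # RMS-normalized result lands here
    res_buf = torch.empty_like(hidden_states)   # pre-norm all-reduce result lands here
    return ca.fused_ar_rms(
        hidden_states,
        res_out=res_buf,
        out=out_buf,
        w=rms_weight,
        eps=eps,
        registered=True,  # assumes the input already lives in a registered buffer
    )
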
7 changes: 4 additions & 3 deletions aiter/dist/parallel_state.py
@@ -128,7 +128,7 @@ def fused_allreduce_rmsnorm_fake(
@torch_compile_guard(gen_fake=fused_allreduce_rmsnorm_fake)
def fused_allreduce_rmsnorm_(
inp: torch.Tensor, w: torch.Tensor, eps: float, group_name: str
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
@@ -348,14 +348,14 @@ def _all_reduce_out_place(

def fused_allreduce_rmsnorm(
self, input_: torch.Tensor, weight_: torch.Tensor, eps: float
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
return fused_allreduce_rmsnorm_(
input_, weight_, eps, group_name=self.unique_name
)

def _fused_allreduce_rmsnorm_out_place(
self, input_: torch.Tensor, weight_: torch.Tensor, eps: float
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.fused_allreduce_rmsnorm(input_, weight_, eps)
@@ -861,6 +861,7 @@ def get_pp_group() -> GroupCoordinator:


from typing import Optional

_DP: Optional[GroupCoordinator] = None


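Since the registered custom op now returns a tuple, its fake (meta) implementation has to advertise two outputs to torch.compile as well. The updated fake is not shown in this hunk; a hedged sketch of what it could look like:

import torch

def fused_allreduce_rmsnorm_fake(
    inp: torch.Tensor, w: torch.Tensor, eps: float, group_name: str
) -> tuple[torch.Tensor, torch.Tensor]:
    # Shapes and dtypes only; no communication happens while tracing.
    return torch.empty_like(inp), torch.empty_like(inp)
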
1 change: 1 addition & 0 deletions aiter/ops/custom_all_reduce.py
@@ -45,6 +45,7 @@ def all_gather_unreg(
def fused_allreduce_rmsnorm(
_fa: int,
inp: torch.Tensor,
res_out: torch.Tensor,
out: torch.Tensor,
w: torch.Tensor,
eps: float,
42 changes: 24 additions & 18 deletions csrc/include/custom_all_reduce.cuh
@@ -976,6 +976,7 @@ namespace aiter
template <typename T, int tnum, int n_loop>
__global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm_naive(
RankSignals sg,
T* __restrict__ residual_out,
T* __restrict__ results,
T* __restrict__ weight,
float eps,
@@ -1028,6 +1029,7 @@
}
int write_idx = bid * n_loop * blockDim.x + n_iter * blockDim.x + threadIdx.x;
*(reinterpret_cast<P*>(results) + write_idx) = rmsnorm_rslt;
*(reinterpret_cast<P*>(residual_out) + write_idx) = rmsnorm_inp[n_iter];
}
}
}
@@ -1039,6 +1041,7 @@
template <typename T, int tnum, int n_loop>
__global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm(
RankSignals sg,
T* __restrict__ residual_out,
T* __restrict__ results,
T* __restrict__ weight,
float eps,
@@ -1096,6 +1099,7 @@
}
int write_idx = bid * (n / pack_size) + n_iter * tnum + threadIdx.x;
*(reinterpret_cast<P*>(results) + write_idx) = rmsnorm_rslt;
*(reinterpret_cast<P*>(residual_out) + write_idx) = rmsnorm_inp[n_iter];
}
}
}
@@ -1104,6 +1108,7 @@
template <typename T, int n_loop>
__global__ void __launch_bounds__(256, 1) local_device_load_rmsnorm_512n(
RankSignals sg,
T* __restrict__ residual_out,
T* __restrict__ results,
T* __restrict__ weight,
float eps,
@@ -1156,6 +1161,7 @@
}
int write_idx = bid * 64 * n_loop + n_iter * 64 + lane_id;
*(reinterpret_cast<P*>(results) + write_idx) = rmsnorm_rslt;
*(reinterpret_cast<P*>(residual_out) + write_idx) = rmsnorm_inp[n_iter];
}
}
}
@@ -1489,17 +1495,17 @@
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);

#define dispatch(ngpus, name) \
do \
{ \
if (bytes % 128 == 0) \
{ \
KL(ngpus, name) \
} \
else \
{ \
KL(ngpus, name##_naive) \
} \
#define dispatch(ngpus, name) \
do \
{ \
if (bytes % 128 == 0 && world_size_ != 6) \
{ \
KL(ngpus, name) \
} \
else \
{ \
KL(ngpus, name##_naive) \
} \
} while(0)

#define REDUCE_CASE(ngpus) \
@@ -1581,7 +1587,7 @@
}

template <typename T>
void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* output, T* weight, float eps, int m, int n)
void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* residual_out, T* output, T* weight, float eps, int m, int n)
{
auto d = packed_t<T>::P::size;
int size = m * n;
@@ -1627,12 +1633,12 @@
grid.x = naive_grid_size < num_cu * occupancy ? naive_grid_size : num_cu * occupancy;
};

#define launch_fused_allreduce_rmsnorm(template_kernel) \
do \
{ \
auto kernel_ptr = reinterpret_cast<const void*>(template_kernel); \
setGrid(naive_grid_size, kernel_ptr); \
template_kernel<<<grid, block, 0, stream>>>(sg_, output, weight, eps, rank_, m, n); \
#define launch_fused_allreduce_rmsnorm(template_kernel) \
do \
{ \
auto kernel_ptr = reinterpret_cast<const void*>(template_kernel); \
setGrid(naive_grid_size, kernel_ptr); \
template_kernel<<<grid, block, 0, stream>>>(sg_, residual_out, output, weight, eps, rank_, m, n); \
} while (0)

if (n_bytes % 1024 == 0)
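The dispatch macro above now routes to the vectorized kernel only when the payload size is 128-byte aligned and world_size_ != 6, otherwise it falls back to the *_naive variant. The same predicate, restated as a small Python helper purely for clarity (a sketch, not part of the sources):

def use_vectorized_fused_kernel(num_elements: int, element_size: int, world_size: int) -> bool:
    # Mirrors the gating in the dispatch macro: 128-byte divisibility of the payload,
    # with the 6-GPU case always taking the naive kernel.
    total_bytes = num_elements * element_size
    return total_bytes % 128 == 0 and world_size != 6
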
1 change: 1 addition & 0 deletions csrc/include/custom_all_reduce.h
@@ -40,6 +40,7 @@ void all_gather_unreg(fptr_t _fa,
torch::Tensor& out);
void fused_allreduce_rmsnorm(fptr_t _fa,
torch::Tensor& inp,
torch::Tensor& res_out,
torch::Tensor& out,
torch::Tensor& w,
float eps,
1 change: 1 addition & 0 deletions csrc/include/rocm_ops.hpp
@@ -327,6 +327,7 @@ namespace py = pybind11;
&aiter::fused_allreduce_rmsnorm, \
py::arg("_fa"), \
py::arg("inp"), \
py::arg("res_out"), \
py::arg("out"), \
py::arg("w"), \
py::arg("eps"), \
10 changes: 7 additions & 3 deletions csrc/kernels/custom_all_reduce.cu
@@ -217,7 +217,7 @@ void all_gather_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
}

void _fused_allreduce_rmsnorm(
fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, torch::Tensor& w, int eps, int m, int n, hipStream_t stream)
fptr_t _fa, torch::Tensor& inp, torch::Tensor& residual_out, torch::Tensor& out, torch::Tensor& w, int eps, int m, int n, hipStream_t stream)
{
auto fa = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
@@ -226,6 +226,7 @@ void _fused_allreduce_rmsnorm(
case at::ScalarType::Float: {
fa->dispatchFusedAllReduceRMSNorm<float>(stream,
reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(residual_out.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
reinterpret_cast<float*>(w.data_ptr()),
eps, m, n);
@@ -234,6 +235,7 @@ void _fused_allreduce_rmsnorm(
case at::ScalarType::Half: {
fa->dispatchFusedAllReduceRMSNorm<half>(stream,
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(residual_out.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
reinterpret_cast<half*>(w.data_ptr()),
eps, m, n);
@@ -243,6 +245,7 @@ void _fused_allreduce_rmsnorm(
case at::ScalarType::BFloat16: {
fa->dispatchFusedAllReduceRMSNorm<__hip_bfloat16>(stream,
reinterpret_cast<__hip_bfloat16*>(inp.data_ptr()),
reinterpret_cast<__hip_bfloat16*>(residual_out.data_ptr()),
reinterpret_cast<__hip_bfloat16*>(out.data_ptr()),
reinterpret_cast<__hip_bfloat16*>(w.data_ptr()),
eps, m, n);
Expand All @@ -256,6 +259,7 @@ void _fused_allreduce_rmsnorm(

void fused_allreduce_rmsnorm(fptr_t _fa,
torch::Tensor& inp,
torch::Tensor& res_out,
torch::Tensor& out,
torch::Tensor& w,
float eps,
@@ -278,11 +282,11 @@ void fused_allreduce_rmsnorm(fptr_t _fa,
input_size,
hipMemcpyDeviceToDevice,
stream));
_fused_allreduce_rmsnorm(_fa, reg_buffer.value(), out, w, eps, m, n, stream);
_fused_allreduce_rmsnorm(_fa, reg_buffer.value(), res_out, out, w, eps, m, n, stream);
}
else
{
_fused_allreduce_rmsnorm(_fa, inp, out, w, eps, m, n, stream);
_fused_allreduce_rmsnorm(_fa, inp, res_out, out, w, eps, m, n, stream);
}
}

27 changes: 8 additions & 19 deletions op_tests/multigpu_tests/test_fused_ar_rms.py
@@ -62,8 +62,11 @@ def fused_ar_rmsnorm(tp_size, pp_size, rankID, x, weight, eps, withGraph=False):
graph = torch.cuda.CUDAGraph()
with graph_capture() as gc:
with torch.cuda.graph(graph, stream=gc.stream):
out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps)
res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(
x, weight, eps
)
out.fill_(0)
res_out.fill_(0)

@perftest()
def run_ca():
@@ -75,7 +78,8 @@ def run_ca():

@perftest()
def run_ca(x):
return tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps)
res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps)
return out

out = run_ca(x)

@@ -113,7 +117,7 @@ def get_acc_value_with_cudagraph(tp_size, pp_size, rankID, x, weight, eps, loop_
with graph_capture() as gc:
with torch.cuda.graph(graph, stream=gc.stream):
# out = torch.empty_like(x)
out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps)
res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps)
out.fill_(0)

def run_ca():
@@ -154,7 +158,7 @@ def get_acc_value_only(tp_size, pp_size, rankID, x, weight, eps, loop_time=1):
torch.cuda.synchronize()

for i in range(loop_time):
out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps)
res, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps)

# destroy
if dist.is_initialized():
@@ -238,19 +242,6 @@ def run_ca(x):
return out


def run_cu(input, weight, eps, device_id):
device = f"cuda:{device_id}"
input = input.to(device)
weight = weight.to(device)

@perftest()
def compute():
output = torch.empty_like(input)
aiter.rms_norm_cu(output, input, weight, eps)

return compute()


@benchmark()
def test_split_ar_rmsnorm(tp_size, pp_size, shape, dtype, withGraph=False):
os.environ["MASTER_ADDR"] = "127.0.0.1"
@@ -275,7 +266,6 @@ def test_split_ar_rmsnorm(tp_size, pp_size, shape, dtype, withGraph=False):
pool.apply_async(
split_ar_rmsnorm, args=(tp_size, pp_size, i, x, weight, eps, withGraph)
)
# pool.apply_async(run_cu, args=(x, weight, eps, i))
)
pool.close()
pool.join()
@@ -320,7 +310,6 @@ def test_fused_ar_rmsnorm(tp_size, pp_size, shape, dtype, withGraph=False):
pool.apply_async(
fused_ar_rmsnorm, args=(tp_size, pp_size, i, x, weight, eps, withGraph)
)
# pool.apply_async(run_cu, args=(x, weight, eps, i))
)
pool.close()
pool.join()