Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,57 +3,52 @@

#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");

return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {

torch::Tensor bwd(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results, float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");

return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}