Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3684,7 +3684,7 @@ class MergeUpAndDownCast {
}

bool isUpCast(SegmentedGroup* group) const {
if (auto precision_bits = getProducerConsumerPrecision(group);
if (auto precision_bits = getProducerConsumerPrecisionBit(group);
precision_bits.has_value()) {
return precision_bits->first < precision_bits->second;
} else {
Expand All @@ -3693,15 +3693,15 @@ class MergeUpAndDownCast {
}

bool isDownCast(SegmentedGroup* group) const {
if (auto precision_bits = getProducerConsumerPrecision(group);
if (auto precision_bits = getProducerConsumerPrecisionBit(group);
precision_bits.has_value()) {
return precision_bits->first > precision_bits->second;
} else {
return false;
}
}

std::optional<std::pair<int64_t, int64_t>> getProducerConsumerPrecision(
std::optional<std::pair<int64_t, int64_t>> getProducerConsumerPrecisionBit(
SegmentedGroup* group) const {
if (group->exprs().size() != 1) {
return std::nullopt;
Expand All @@ -3712,7 +3712,7 @@ class MergeUpAndDownCast {
return std::nullopt;
}

return ir_utils::getPrecisionOfProducerConsumerTensors(uop);
return ir_utils::getPrecisionOfProducerConsumerTensorsBit(uop);
}

private:
Expand Down Expand Up @@ -4372,7 +4372,7 @@ void SegmentCandidateFinder::privatizeUpcast() {
}

auto precisions =
ir_utils::getPrecisionOfProducerConsumerTensors(maybe_upcast_op);
ir_utils::getPrecisionOfProducerConsumerTensorsBit(maybe_upcast_op);
if (!precisions.has_value() || precisions->first >= precisions->second) {
continue;
}
Expand Down
7 changes: 3 additions & 4 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1551,8 +1551,8 @@ std::vector<IterDomain*> strideOrderToAllocation(
return allocation_domain;
}

std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors(
UnaryOp* uop) {
std::optional<std::pair<int64_t, int64_t>>
getPrecisionOfProducerConsumerTensorsBit(UnaryOp* uop) {
NVF_CHECK(uop != nullptr);
NVF_CHECK(
uop->getUnaryOpType() == UnaryOpType::Cast,
Expand All @@ -1577,8 +1577,7 @@ std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors
}

return std::make_pair(
primDataTypeSizeByte(*inp_prim_type),
primDataTypeSizeByte(*out_prim_type));
primDataTypeSizeBit(*inp_prim_type), primDataTypeSizeBit(*out_prim_type));
}

int64_t getTMemLdStVectorizeSize(TensorView* consumer_tv) {
Expand Down
6 changes: 3 additions & 3 deletions csrc/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -794,10 +794,10 @@ std::vector<IterDomain*> strideOrderToAllocation(
const std::vector<IterDomain*>& logical_domain,
const std::vector<int64_t>& stride_order);

// Returns the number of bytes of data types of the producer and
// Returns the number of bits of data types of the producer and
// consumer tensors of a cast unary op
std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors(
UnaryOp* cast_op);
std::optional<std::pair<int64_t, int64_t>>
getPrecisionOfProducerConsumerTensorsBit(UnaryOp* cast_op);

// Get the <size> in the PTX instruction of TMem load/store:
// tcgen05.st.sync.aligned.32x32b.x<size>.b32
Expand Down
2 changes: 1 addition & 1 deletion csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ NVF_API TensorView* reshape(
logical_domain,
TensorDomain::getContiguityFilledWith(logical_domain, true)),
x->getDataType().value());
IrBuilder::create<ViewOp>(x, out_tv);
IrBuilder::create<ViewOp>(out_tv, x);
return out_tv;
}

Expand Down
9 changes: 9 additions & 0 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@
namespace nvfuser {

bool MatmulScheduler::canScheduleCompileTime(Fusion* fusion) {
for (auto tv : fusion->allTvs()) {
if (tv->dtype() != DataType::Index &&
dataTypeSizeBit(tv->dtype()) % 8 != 0) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "Does not support sub-byte data types.");
return false;
}
}

const auto msg = matmul_utils::getMatmulCompileTimeRejectReason(fusion);
if (!msg.empty()) {
scheduler_debug_utils::canScheduleRejectReason(schedulerType(), msg);
Expand Down
10 changes: 10 additions & 0 deletions csrc/scheduler/normalization_inner_outer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,16 @@ bool InnerOuterPersistentKernelScheduler::canScheduleCompileTime(
Fusion* fusion) {
FUSER_PERF_SCOPE(
"InnerOuterPersistentKernelScheduler::canScheduleCompileTime");

for (auto tv : fusion->allTvs()) {
if (tv->dtype() != DataType::Index &&
dataTypeSizeBit(tv->dtype()) % 8 != 0) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "Does not support sub-byte data types.");
return false;
}
}

// common checks for all persistent heuristics
if (!normalization_scheduler_utils::checkOpsAndInputs(
fusion, schedulerType())) {
Expand Down
9 changes: 9 additions & 0 deletions csrc/scheduler/normalization_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,15 @@ bool checkReductionPattern(
// The identical compile time check of InnerPersistentKernelScheduler and
// OuterPersistentKernelScheduler.
bool compileTimeCheck(Fusion* fusion, SchedulerType scheduler_type) {
for (auto tv : fusion->allTvs()) {
if (tv->dtype() != DataType::Index &&
dataTypeSizeBit(tv->dtype()) % 8 != 0) {
scheduler_debug_utils::canScheduleRejectReason(
scheduler_type, "Does not support sub-byte data types.");
return false;
}
}

// common checks for all persistent heuristics
if (!normalization_scheduler_utils::checkOpsAndInputs(
fusion, scheduler_type)) {
Expand Down
10 changes: 10 additions & 0 deletions csrc/scheduler/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1650,6 +1650,16 @@ void scheduleReduction(Fusion* fusion, const ReductionParams* rparams) {
//! Check if the reduction heuristics apply in given fusion
bool ReductionScheduler::canScheduleCompileTime(Fusion* fusion) {
FUSER_PERF_SCOPE("ReductionScheduler::canScheduleCompileTime");

for (auto tv : fusion->allTvs()) {
if (tv->dtype() != DataType::Index &&
dataTypeSizeBit(tv->dtype()) % 8 != 0) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "Does not support sub-byte data types.");
return false;
}
}

if (scheduler_utils::isResharding(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "Fusion is resharding.");
Expand Down
9 changes: 9 additions & 0 deletions csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
return false;
}

for (auto tv : fusion->allTvs()) {
if (tv->dtype() != DataType::Index &&
dataTypeSizeBit(tv->dtype()) % 8 != 0) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "Does not support sub-byte data types.");
return false;
}
}

if (!scheduler_tools::hasResizeBasedOps(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "No resize op to schedule");
Expand Down
9 changes: 9 additions & 0 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ namespace nvfuser {

bool TransposeScheduler::canScheduleCompileTime(Fusion* fusion) {
FUSER_PERF_SCOPE("TransposeScheduler::canScheduleCompileTime");
for (auto tv : fusion->allTvs()) {
if (tv->dtype() != DataType::Index &&
dataTypeSizeBit(tv->dtype()) % 8 != 0) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "Does not support sub-byte data types.");
return false;
}
}

if (scheduler_utils::isResharding(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "Fusion is resharding.");
Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ TensorView* getUpCastInputOf(const TensorView* tv) {
return nullptr;
}
// skip if the cast is not upcast
auto precisions = ir_utils::getPrecisionOfProducerConsumerTensors(uop);
auto precisions = ir_utils::getPrecisionOfProducerConsumerTensorsBit(uop);
if (!precisions.has_value() || precisions->first >= precisions->second) {
return nullptr;
}
Expand Down
25 changes: 25 additions & 0 deletions runtime/helpers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,31 @@ __device__ float fmax(float a, float b) {
}
}

// fmax overload for __half: there is no native half-precision fmax here, so
// promote both operands to float, reuse the float fmax defined above, and
// narrow the result back to __half.
__device__ __half fmax(__half a, __half b) {
  return __float2half(fmax(__half2float(a), __half2float(b)));
}

// fmax overload for __bfloat: computed in float precision (via the float
// fmax above) and converted back to bfloat, mirroring the __half overload.
__device__ __bfloat fmax(__bfloat a, __bfloat b) {
  return __float2bfloat(fmax(__bfloat2float(a), __bfloat2float(b)));
}

// Generic absolute value for types with operator> and unary minus.
// For a == -0.0 the (a > 0) test is false, so -a (== +0.0) is returned,
// matching fabs semantics.
template <typename T>
__device__ T abs(T a) {
  if (a > 0) {
    return a;
  }
  return -a;
}

// abs overload for __half: widen to float, take fabs, and narrow back.
__device__ __half abs(__half a) {
  auto widened = __half2float(a);
  return __float2half(fabs(widened));
}

// abs overload for __bfloat: widen to float, take fabs, and narrow back.
__device__ __bfloat abs(__bfloat a) {
  auto widened = __bfloat2float(a);
  return __float2bfloat(fabs(widened));
}

__device__ constexpr int min(int a, int b) {
return a > b ? b : a;
}
Expand Down
14 changes: 7 additions & 7 deletions tests/cpp/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8972,20 +8972,20 @@ TEST_F(NVFuserTest, CastPrecision) {
auto tv4 = castOp(DataType::Int, tv3);
fusion.addOutput(tv4);

auto tv1_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
auto tv1_precision = ir_utils::getPrecisionOfProducerConsumerTensorsBit(
tv1->definition()->as<UnaryOp>());
ASSERT_TRUE(tv1_precision.has_value());
EXPECT_EQ(tv1_precision->first, 2);
EXPECT_EQ(tv1_precision->second, 4);
EXPECT_EQ(tv1_precision->first, 16);
EXPECT_EQ(tv1_precision->second, 32);

auto tv2_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
auto tv2_precision = ir_utils::getPrecisionOfProducerConsumerTensorsBit(
tv2->definition()->as<UnaryOp>());
ASSERT_TRUE(tv2_precision.has_value());
EXPECT_EQ(tv2_precision->first, 4);
EXPECT_EQ(tv2_precision->second, 2);
EXPECT_EQ(tv2_precision->first, 32);
EXPECT_EQ(tv2_precision->second, 16);

// Precision of type Index is not possible to determine until lowering
auto tv4_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
auto tv4_precision = ir_utils::getPrecisionOfProducerConsumerTensorsBit(
tv4->definition()->as<UnaryOp>());
ASSERT_FALSE(tv4_precision.has_value());
}
Expand Down
44 changes: 30 additions & 14 deletions tests/cpp/test_low_precision_recipe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

namespace nvfuser {

using FP4RecipeTest = NVFuserTest;

// Testing the following function:
// https://github.com/pytorch/ao/blob/b1163dc63dfa22d403586672fd3648cd661c5003/torchao/prototype/mx_formats/nvfp4_tensor.py#L545-L617
//
Expand Down Expand Up @@ -105,17 +103,17 @@ constexpr double F4_E2M1_MAX = 6.0;
constexpr double E4M3_EPS = 0.015625;
constexpr double F8E4M3_MAX = 448.0;

class NVFP4QuantizeTest : public FP4RecipeTest,
class NVFP4QuantizeTest : public BlackwellBase,
public ::testing::WithParamInterface<DataType> {};

TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) {
auto data_hp_dtype = GetParam();

Fusion fusion;
FusionGuard fg(&fusion);
std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

auto tv_data_hp = makeContigTensor(2, data_hp_dtype);
fusion.addInput(tv_data_hp);
fusion->addInput(tv_data_hp);

auto tv_data_hp_reshaped =
reshape(tv_data_hp, [](auto& x) { x.split(-1, block_size); });
Expand All @@ -142,23 +140,32 @@ TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) {
tv_data_scaled,
IrBuilder::create<Val>(-F4_E2M1_MAX, DataType::Float),
IrBuilder::create<Val>(F4_E2M1_MAX, DataType::Float));

auto tv_data_lp_fp4 = castOp(DataType::Float4_e2m1fn, tv_data_scaled_clamp);
auto tv_data_lp = reshape(tv_data_lp_fp4, [](auto& x) { x.merge(-2); });

fusion.addOutput(tv_block_scale_fp8);
fusion.addOutput(tv_data_lp);
fusion->addOutput(tv_block_scale_fp8);
fusion->addOutput(tv_data_lp);

FusionExecutorCache fec(std::move(fusion));

std::vector<at::Tensor> inputs;
inputs.push_back(
at::randn({1024, 1024}, at::device(at::kCUDA).dtype(at::kFloat))
.to(data_type_to_aten(data_hp_dtype)));
auto outputs = fec.runFusionWithInputs(inputs);
}

TEST_P(NVFP4QuantizeTest, WithPerTensorAmax) {
auto data_hp_dtype = GetParam();

Fusion fusion;
FusionGuard fg(&fusion);
std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

auto tv_data_hp = makeContigTensor(2, data_hp_dtype);
auto tv_per_tensor_scale = makeContigTensor(0, DataType::Float);
fusion.addInput(tv_data_hp);
fusion.addInput(tv_per_tensor_scale);
fusion->addInput(tv_data_hp);
fusion->addInput(tv_per_tensor_scale);

auto tv_data_hp_reshaped =
reshape(tv_data_hp, [](auto& x) { x.split(-1, block_size); });
Expand Down Expand Up @@ -199,8 +206,17 @@ TEST_P(NVFP4QuantizeTest, WithPerTensorAmax) {
auto tv_data_lp_fp4 = castOp(DataType::Float4_e2m1fn, tv_data_scaled_clamp);
auto tv_data_lp = reshape(tv_data_lp_fp4, [](auto& x) { x.merge(-2); });

fusion.addOutput(tv_scaled_block_scales_fp8);
fusion.addOutput(tv_data_lp);
fusion->addOutput(tv_scaled_block_scales_fp8);
fusion->addOutput(tv_data_lp);

FusionExecutorCache fec(std::move(fusion));

std::vector<at::Tensor> inputs;
inputs.push_back(
at::randn({1024, 1024}, at::device(at::kCUDA).dtype(at::kFloat))
.to(data_type_to_aten(data_hp_dtype)));
inputs.push_back(at::randn({}, at::device(at::kCUDA).dtype(at::kFloat)));
auto outputs = fec.runFusionWithInputs(inputs);
}

INSTANTIATE_TEST_SUITE_P(
Expand Down