Reapply "Pipe quantize kernel through FusionExecutorCache (#4760)" (#4854) #4874

Merged: zasdfgbnm merged 5 commits into main from pipe-fp4 on Jul 29, 2025
Conversation

@zasdfgbnm (Collaborator) commented Jul 28, 2025

With #4852 fixed

@zasdfgbnm (Collaborator, Author) commented:

!test

github-actions bot commented Jul 28, 2025

Review updated until commit 7b7d722

Description

  • Rename getPrecisionOfProducerConsumerTensors to getPrecisionOfProducerConsumerTensorsBit

  • Update function to return precision in bits instead of bytes

  • Add support for sub-byte data types in schedulers

  • Fix ViewOp creation in reshape function


Changes walkthrough 📝

Relevant files:

Enhancement (12 files)

  • fusion_segmenter.cpp: Rename precision function and update usage (+5/-5)
  • utils.cpp: Rename precision function and update implementation (+3/-4)
  • matmul.cpp: Add sub-byte data type support in MatmulScheduler (+9/-0)
  • normalization_inner_outer.cpp: Add sub-byte data type support in InnerOuterPersistentKernelScheduler (+10/-0)
  • normalization_utils.cpp: Add sub-byte data type support in compileTimeCheck (+9/-0)
  • reduction.cpp: Add sub-byte data type support in ReductionScheduler (+10/-0)
  • resize.cpp: Add sub-byte data type support in ResizeScheduler (+9/-0)
  • transpose.cpp: Add sub-byte data type support in TransposeScheduler (+9/-0)
  • utils.cpp: Update precision function usage (+1/-1)
  • test_low_precision_recipe.cpp: Update fusion creation and add FusionExecutorCache usage (+30/-14)
  • helpers.cu: Add fmax and abs overloads for __half and __bfloat (+25/-0)
  • utils.h: Rename precision function and update documentation (+3/-3)

Bug fix (1 file)

  • alias.cpp: Fix ViewOp creation in reshape function (+1/-1)

Tests (1 file)

  • test_gpu3.cpp: Update precision checks in test cases (+7/-7)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Precision Calculation

The new function getPrecisionOfProducerConsumerTensorsBit calculates precision in bits, but the test cases in tests/cpp/test_gpu3.cpp and tests/cpp/test_low_precision_recipe.cpp still use byte-based expectations. Ensure that the test cases are updated to reflect the new bit-based precision calculations.

std::optional<std::pair<int64_t, int64_t>>
getPrecisionOfProducerConsumerTensorsBit(UnaryOp* uop) {
  NVF_CHECK(uop != nullptr);
  NVF_CHECK(
      uop->getUnaryOpType() == UnaryOpType::Cast,
      "Invalid expr: ",
      uop->toString());

  auto inp_tv = ir_utils::getTvInput(uop);
  auto out_tv = ir_utils::getTvOutput(uop);
  if (inp_tv == nullptr || out_tv == nullptr) {
    return std::nullopt;
  }

  auto inp_dtype = inp_tv->dtype().type;
  auto out_dtype = out_tv->dtype().type;
  auto inp_prim_type = std::get_if<PrimDataType>(&inp_dtype);
  auto out_prim_type = std::get_if<PrimDataType>(&out_dtype);

  if (inp_prim_type == nullptr || out_prim_type == nullptr ||
      *inp_prim_type == PrimDataType::Index ||
      *out_prim_type == PrimDataType::Index) {
    return std::nullopt;
  }

  return std::make_pair(
      primDataTypeSizeBit(*inp_prim_type), primDataTypeSizeBit(*out_prim_type));
}
Function Naming

The new function getProducerConsumerPrecisionBit has a misleading name as it returns a pair of precision values in bits, not a single bit value. Consider renaming the function to better reflect its purpose.

std::optional<std::pair<int64_t, int64_t>> getProducerConsumerPrecisionBit(
    SegmentedGroup* group) const {
  if (group->exprs().size() != 1) {
    return std::nullopt;
  }

  auto uop = dynamic_cast<UnaryOp*>(group->exprs().front());
  if (uop == nullptr || uop->getUnaryOpType() != UnaryOpType::Cast) {
    return std::nullopt;
  }

  return ir_utils::getPrecisionOfProducerConsumerTensorsBit(uop);
}
Template Function

The abs function template is defined for all types and then overloaded (not specialized) for __half and __bfloat. Mixing a generic template with type-specific overloads can lead to ambiguity or unexpected behavior for types convertible to both; consider constraining the template (e.g. with SFINAE) or providing explicit overloads for every supported type.

template <typename T>
__device__ T abs(T a) {
  return a > 0 ? a : -a;
}

__device__ __half abs(__half a) {
  return __float2half(fabs(__half2float(a)));
}

__device__ __bfloat abs(__bfloat a) {
  return __float2bfloat(fabs(__bfloat2float(a)));
}

@zasdfgbnm (Collaborator, Author) commented:

!test

@zasdfgbnm zasdfgbnm marked this pull request as ready for review July 29, 2025 16:04
@zasdfgbnm zasdfgbnm requested a review from naoyam July 29, 2025 16:05
@zasdfgbnm (Collaborator, Author) commented:

!test

@zasdfgbnm zasdfgbnm requested a review from naoyam July 29, 2025 16:54
@naoyam (Collaborator) left a comment:

LGTM

@zasdfgbnm zasdfgbnm merged commit 2996760 into main Jul 29, 2025
45 of 48 checks passed
@zasdfgbnm zasdfgbnm deleted the pipe-fp4 branch July 29, 2025 19:51