Skip to content

fix#4769

Merged
zasdfgbnm merged 3 commits intofp4-cast-fecfrom
cast-fec-fix
Jul 11, 2025
Merged

fix#4769
zasdfgbnm merged 3 commits intofp4-cast-fecfrom
cast-fec-fix

Conversation

@zasdfgbnm
Copy link
Collaborator

No description provided.

@zasdfgbnm zasdfgbnm marked this pull request as ready for review July 11, 2025 00:55
@github-actions
Copy link

github-actions bot commented Jul 11, 2025

Review updated until commit f8c9361

Description

  • Renamed function to clarify bit-based buffer size calculation

  • Updated buffer size calculation to use bit-based data type size

  • Ensured consistent naming and functionality across related functions


Changes walkthrough 📝

Relevant files
Enhancement
normalization_inner_outer_utils.cpp
Rename and update buffer size calculation                               

csrc/scheduler/normalization_inner_outer_utils.cpp

  • Renamed partialOuterReductionBufferSize to
    partialOuterReductionBufferSizeBit
  • Updated variable names to reflect bit-based calculations
  • Changed data type size calculation to use dataTypeSizeBit
  • +14/-14 
    normalization_inner_outer_utils.h
    Rename buffer size function                                                           

    csrc/scheduler/normalization_inner_outer_utils.h

  • Renamed partialOuterReductionBufferSize to
    partialOuterReductionBufferSizeBit
  • +1/-1     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Naming Consistency

    The function name partialOuterReductionBufferSizeBit suggests it returns a size in bits, but the original function partialOuterReductionBufferSize implies a size in bytes. Ensure the naming is consistent with the actual units returned.

    int64_t partialOuterReductionBufferSizeBit(
        const std::vector<TensorView*>& reduction_tvs,
        SchedulerRuntimeInfo& runtime_info) {
      int64_t partial_reduction_buffer_size_bit = 0;
      for (auto buffer : reduction_tvs) {
        if (scheduler_utils::isFastestDimReduction(buffer)) {
          continue;
        }
        int64_t buffer_size_bit = -1;
        for (auto id : buffer->getLogicalDomain()) {
          if (id->isReduction() || id->isBroadcast()) {
            continue;
          }
          auto id_size = runtime_info.expressionEvaluator().evaluate(id->extent());
          NVF_ERROR(id_size.hasValue(), "Could not infer persistent buffer size.");
          if (buffer_size_bit == -1) {
            buffer_size_bit = id_size.as<int64_t>();
          } else {
            buffer_size_bit *= id_size.as<int64_t>();
          }
        }
        buffer_size_bit = (buffer_size_bit == -1) ? 0
                                                  : buffer_size_bit *
                dataTypeSizeBit(buffer->getDataType().value(),
                                runtime_info.getIndexType());
        partial_reduction_buffer_size_bit += buffer_size_bit;
      }
      return partial_reduction_buffer_size_bit;
    }
    Error Handling

    The function partialOuterReductionBufferSizeBit initializes buffer_size_bit to -1 and checks if it remains -1 before using it. Ensure this logic is correct and consider adding more robust error handling.

    int64_t buffer_size_bit = -1;
    for (auto id : buffer->getLogicalDomain()) {
      if (id->isReduction() || id->isBroadcast()) {
        continue;
      }
      auto id_size = runtime_info.expressionEvaluator().evaluate(id->extent());
      NVF_ERROR(id_size.hasValue(), "Could not infer persistent buffer size.");
      if (buffer_size_bit == -1) {
        buffer_size_bit = id_size.as<int64_t>();
      } else {
        buffer_size_bit *= id_size.as<int64_t>();
      }
    }
    buffer_size_bit = (buffer_size_bit == -1) ? 0
                                              : buffer_size_bit *
    Data Type Conversion

    The function partialOuterReductionBufferSizeBit uses dataTypeSizeBit instead of dataTypeSizeByte. Verify that this change is intentional and that it does not introduce any unintended behavior.

    dataTypeSizeBit(buffer->getDataType().value(),
                    runtime_info.getIndexType());

    @zasdfgbnm zasdfgbnm merged commit 1275488 into fp4-cast-fec Jul 11, 2025
    8 of 9 checks passed
    @zasdfgbnm zasdfgbnm deleted the cast-fec-fix branch July 11, 2025 01:03
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    1 participant