Fix detection of ldmatrix/stmatrix #5645

Merged
naoyam merged 5 commits into main from fix_ld_st_matrix_indexing
Dec 9, 2025

Conversation

@naoyam
Collaborator

@naoyam naoyam commented Dec 9, 2025

This was found while I was experimenting with TensorIndexer on the matmul tests (#5574). Ldmatrix and stmatrix use a special domain as an alternate loop domain for indexing. IIUC, we should not use the alternate domain when initializing tensors. This happens, for example, when a tensor is defined by an stmatrix op but is also initialized to zero for predicate elimination. It looks like the initialization should not be done at all, but I think that's a separate issue.
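
The failure mode described above can be illustrated with a toy model. This is only a sketch, not nvFuser code: `ToyTensor`, `OpKind`, `LoopDomainKind`, and both `pickDomain...` functions are hypothetical names, and the model only covers the stmatrix-output case mentioned in the description.

```cpp
#include <cassert>

// Toy model only: none of these names exist in nvFuser.
enum class OpKind { LdMatrix, StMatrix, ZeroInit };
enum class LoopDomainKind { Regular, Alternate };

struct ToyTensor {
  bool in_shared_memory;
  OpKind definition;  // the original defining op of the tensor
};

// Problematic approach (sketch): decide from the tensor's definition alone.
// A zero-initialization of an stmatrix output would also get Alternate.
LoopDomainKind pickDomainFromDefinition(const ToyTensor& tv) {
  const bool defined_by_ld_st = tv.definition == OpKind::LdMatrix ||
      tv.definition == OpKind::StMatrix;
  return (defined_by_ld_st && tv.in_shared_memory) ? LoopDomainKind::Alternate
                                                   : LoopDomainKind::Regular;
}

// Fixed approach (sketch): decide from the op that is actually being lowered.
LoopDomainKind pickDomainFromLoweredOp(const ToyTensor& tv, OpKind lowered_op) {
  const bool ld_st_matrix =
      lowered_op == OpKind::LdMatrix || lowered_op == OpKind::StMatrix;
  return (ld_st_matrix && tv.in_shared_memory) ? LoopDomainKind::Alternate
                                               : LoopDomainKind::Regular;
}

// Usage: the same shared-memory tensor is written by an init and an stmatrix.
inline void toyDomainExample() {
  const ToyTensor smem_out{true, OpKind::StMatrix};
  // Definition-based choice cannot tell the two writers apart.
  assert(pickDomainFromDefinition(smem_out) == LoopDomainKind::Alternate);
  // Op-based choice gives the initialization the regular loop domain.
  assert(
      pickDomainFromLoweredOp(smem_out, OpKind::ZeroInit) ==
      LoopDomainKind::Regular);
}
```

In this model the initialization and the stmatrix write to the same tensor, so only information about the op being lowered can separate them.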

Please see https://github.com/NVIDIA/Fuser/pull/5645/files#r2600720284. The other changes are just due to this change.

@naoyam
Collaborator Author

naoyam commented Dec 9, 2025

!test --diff

@github-actions

github-actions bot commented Dec 9, 2025

Review updated until commit cb755ef

Description

  • Add explicit ld_st_matrix parameter to indexing functions for ldmatrix/stmatrix detection

  • Fix alternate loop domain usage for shared memory tensors in ldmatrix/stmatrix operations

  • Prevent incorrect indexing when initialization ops exist on same tensor as ldmatrix/stmatrix

  • Update function signatures across IndexLowering, TensorIndexer, and Index classes

Changes walkthrough

Relevant files
Bug_fix

index.cpp (csrc/device_lower/pass/index.cpp, +31/-10)
Add ld_st_matrix parameter and detection logic

  • Added ld_st_matrix parameter to lowerSrcIndex and lowerDstIndex functions
  • Added detection of ldmatrix/stmatrix operations: const bool ld_st_matrix = ir_utils::isLdMatrixOp(ldst) || ir_utils::isStMatrixOp(ldst);
  • Pass ld_st_matrix flag through indexing calls for proper alternate domain handling
  • Updated calls to getLinearIndex and index lowering functions to include the new parameter

indexing.cpp (csrc/id_model/indexing.cpp, +30/-7)
Refactor alternate domain detection with explicit parameter

  • Renamed isSharedMemoryTvForLdStMatrix to shouldUseAlternateLoopDomain with explicit ld_st_matrix parameter
  • Updated getLinearIndex and getContigIndexFor to accept and use ld_st_matrix parameter
  • Fixed alternate loop domain detection to use explicit flag rather than checking expr directly
  • Added error checking to ensure expr is actually ldmatrix/stmatrix when flag is true

index_compute.cpp (csrc/index_compute.cpp, +6/-4)
Add ld_st_matrix parameter to Index functions

  • Added ld_st_matrix parameter to Index::getProducerIndex and Index::getConsumerIndex functions
  • Pass ld_st_matrix parameter through to tensorIndexer().getLinearIndex calls
  • Updated function signatures to match changes in indexing.cpp

index.h (csrc/device_lower/pass/index.h, +4/-2)
Update IndexLowering header declarations

  • Updated lowerSrcIndex and lowerDstIndex declarations to include ld_st_matrix parameter with default value false
  • Updated function documentation to reflect new parameter

indexing.h (csrc/id_model/indexing.h, +4/-2)
Update TensorIndexer header declarations

  • Updated getLinearIndex and getContigIndexFor declarations to include ld_st_matrix parameter with default value false
  • Updated TensorIndexer class method signatures

index_compute.h (csrc/index_compute.h, +4/-2)
Update Index header declarations

  • Updated getProducerIndex and getConsumerIndex declarations to include ld_st_matrix parameter with default value false
  • Updated Index class method signatures
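
The header changes listed above all add a trailing parameter with a default of false. A minimal sketch of why that keeps existing call sites compiling is shown here; the `toy...` functions and their string-based "index" are hypothetical stand-ins, since the real functions take many more parameters and return IR nodes.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for the real call chain; only the flag is modeled.
std::string toyGetLinearIndex(const std::string& tv, bool ld_st_matrix = false) {
  // The toy "index" just records which mode was requested.
  return tv + (ld_st_matrix ? "[ld_st_matrix]" : "[default]");
}

std::string toyGetConsumerIndex(
    const std::string& tv,
    bool ld_st_matrix = false) {
  return toyGetLinearIndex(tv, ld_st_matrix);  // forward the flag unchanged
}

std::string toyLowerDstIndex(const std::string& tv, bool ld_st_matrix = false) {
  return toyGetConsumerIndex(tv, ld_st_matrix);  // forward the flag unchanged
}

inline void toyDefaultArgExample() {
  // Existing call sites that do not know about the flag keep working.
  assert(toyLowerDstIndex("T1") == "T1[default]");
  // Only the ldmatrix/stmatrix handler opts in explicitly.
  assert(toyLowerDstIndex("T1", true) == "T1[ld_st_matrix]");
}
```

The trade-off of a defaulted flag is that a caller that should pass true can silently omit it, which is presumably why the PR also validates the flag against the expression inside shouldUseAlternateLoopDomain.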

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review
    Parameter Usage Consistency

    The new ld_st_matrix parameter is added to lowerSrcIndex and lowerDstIndex functions but is only used in specific code paths (lines 2127, 2189, 2227, 2240). Verify that all call sites properly handle this parameter and that there are no missing usages where ldmatrix/stmatrix operations should be detected.

    Val* IndexLowering::lowerSrcIndex(
        Val* src,
        Val* dst,
        const std::unordered_map<IterDomain*, Val*>& override_index,
        bool generate_pointer,
        DataType as_type,
        bool ld_st_matrix) const {
      if (auto tv = dynamic_cast<TensorView*>(src)) {
        NVF_ERROR(dst->isA<TensorView>());
        kir::TensorIndex* tind = Index::getProducerIndex(
            tv,
            dst->as<TensorView>(),
            for_loops_,
            getRotatedLoop(),
            override_index,
            generate_pointer,
            as_type,
            ld_st_matrix);
        if (TensorView* aliased_producer =
                GpuLower::current()->getTensorProducerAlias(tv)) {
          return IrBuilder::create<kir::TensorIndex>(
              aliased_producer, tind->index());
        } else {
          return tind;
        }
      } else {
        return src;
      }
    }
    
    Val* IndexLowering::lowerDstIndex(
        Val* dst,
        const std::unordered_map<IterDomain*, Val*>& override_index,
        bool generate_pointer,
        DataType as_type,
        bool ld_st_matrix) const {
      if (auto tv = dynamic_cast<TensorView*>(dst)) {
        return Index::getConsumerIndex(
            tv,
            for_loops_,
            getRotatedLoop(),
            override_index,
            generate_pointer,
            as_type,
            ld_st_matrix);
      } else {
        return dst;
      }
    }

    Logic Validation

    The shouldUseAlternateLoopDomain function now requires both ld_st_matrix flag to be true AND the expression to be a valid ldmatrix/stmatrix op. This double-check approach is sound, but verify that the ld_st_matrix flag is always correctly set based on the actual operation type to avoid missing legitimate ldmatrix/stmatrix cases.

    bool shouldUseAlternateLoopDomain(
        TensorView* tv,
        const Expr* expr,
        bool ld_st_matrix) {
      // short-circuit: not (ldmatrix or stmatrix)
      if (!ld_st_matrix) {
        return false;
      }
    
      NVF_ERROR(
          ir_utils::isLdMatrixOp(expr) || ir_utils::isStMatrixOp(expr),
          "Unexpected expr: ",
          expr->toString());
    
      // short-circuit: only the shared memory TensorView uses alternate loop
      // domain. For ldmatrix, it is the input TensorView. For stmatrix, it is the
      // output TensorView.
      if (tv->getMemoryType() != MemoryType::Shared) {
        return false;

    Flag Detection Logic

    The ld_st_matrix flag is computed once at the beginning of handle(const LoadStoreOp* ldst) using ir_utils::isLdMatrixOp(ldst) || ir_utils::isStMatrixOp(ldst). This is then passed to all indexing functions. Ensure this detection logic covers all relevant cases and that the flag propagation is consistent throughout the function.

    const bool ld_st_matrix =
        ir_utils::isLdMatrixOp(ldst) || ir_utils::isStMatrixOp(ldst);
    

    Comment on lines +932 to +936
    // stmatrix. Note that the explicit bool indicator of the expr is
    // required to correctly determine it is a ldmatrix/stmatrix op since
    // there can be an initialization op using the same output tensor
    // after the allocation lowering pass.
    bool shouldUseAlternateLoopDomain(
    Collaborator Author

    This is a somewhat ugly fix, but at the time of index lowering, the expr parameter is not necessarily the actual expression for the operation being lowered. The state of the fusion program is not well defined here, as we are still building the Kernel IR program. After the allocation lowering pass, a TensorView can have multiple defining expressions due to, e.g., buffer initializations, and thus the program is no longer SSA. tv->definition() returns the original expr, but we may be using it even when lowering the initialization.

    That can cause a problem here: even though expr is an ldmatrix or stmatrix, the op actually being lowered may be something else, such as an initialization. To determine whether the actual op is indeed an ldmatrix or stmatrix, that information needs to be passed down from the indexing pass itself.
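
    The non-SSA situation described in this comment can be sketched with a toy kernel in which two ops write the same tensor. Everything here (`ToyOp`, `ToyTv`, `ToyWrite`, and the two counting functions) is hypothetical and only models the distinction between a tensor's original definition and the op currently being lowered.

```cpp
#include <cassert>
#include <vector>

// Toy model only: none of these names exist in nvFuser.
enum class ToyOp { ZeroInit, StMatrix };

struct ToyTv {
  ToyOp definition;  // what a definition()-style query would report
};

struct ToyWrite {
  ToyOp op;          // the op actually being lowered
  const ToyTv* out;  // the tensor it writes
};

// Classify each write by the output tensor's original definition.
int countAlternateByDefinition(const std::vector<ToyWrite>& writes) {
  int n = 0;
  for (const ToyWrite& w : writes) {
    if (w.out->definition == ToyOp::StMatrix) {
      ++n;
    }
  }
  return n;
}

// Classify each write by a flag computed from the op being lowered.
int countAlternateByLoweredOp(const std::vector<ToyWrite>& writes) {
  int n = 0;
  for (const ToyWrite& w : writes) {
    const bool ld_st_matrix = (w.op == ToyOp::StMatrix);
    if (ld_st_matrix) {
      ++n;
    }
  }
  return n;
}

inline void toyNonSsaExample() {
  const ToyTv tv{ToyOp::StMatrix};
  // After allocation lowering, the same tensor has two writers.
  const std::vector<ToyWrite> writes = {
      {ToyOp::ZeroInit, &tv}, {ToyOp::StMatrix, &tv}};
  assert(countAlternateByDefinition(writes) == 2);  // init misclassified
  assert(countAlternateByLoweredOp(writes) == 1);   // only the stmatrix
}
```

    In the toy, the definition-based count treats the initialization as an stmatrix because both writes share one output tensor, while the per-op flag does not.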

    @naoyam
    Collaborator Author

    naoyam commented Dec 9, 2025

    !test --diff

    @naoyam
    Collaborator Author

    naoyam commented Dec 9, 2025

    !test --diff

    @naoyam naoyam marked this pull request as ready for review December 9, 2025 17:46
    @naoyam naoyam requested a review from rdspring1 December 9, 2025 17:46
    @greptile-apps
    Contributor

    greptile-apps bot commented Dec 9, 2025

    Greptile Overview

    Greptile Summary

    This PR fixes a bug in the detection of ldmatrix/stmatrix operations for indexing. The issue was that the previous code would infer whether to use alternate loop domains based on the expression type, but this approach failed when a tensor defined by stmatrix was also initialized (e.g., for predicate elimination). The initialization operation would incorrectly use the alternate loop domain intended only for the actual ld/stmatrix operation.

    Key Changes:

    • Added explicit ld_st_matrix boolean parameter that is propagated through the indexing call chain (lowerSrcIndex → getProducerIndex → getLinearIndex → getContigIndexFor → shouldUseAlternateLoopDomain)
    • Renamed isSharedMemoryTvForLdStMatrix to shouldUseAlternateLoopDomain with clearer semantics
    • Added NVF_ERROR validation to ensure the explicit flag is consistent with the actual expression type

    This ensures that initialization ops on tensors that are also used by ld/stmatrix operations don't incorrectly use the alternate loop domain for indexing.
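
    The guard logic summarized above can be sketched as a small standalone function. This is an assumption-laden toy: `ToyMemory`, `ToyExprKind`, and `toyShouldUseAlternateLoopDomain` are hypothetical, a thrown `std::logic_error` stands in for NVF_ERROR, and because the real function is truncated in the review excerpt, the final return on the memory type is an assumption.

```cpp
#include <cassert>
#include <stdexcept>

// Toy model only: none of these names exist in nvFuser.
enum class ToyMemory { Global, Shared, Local };
enum class ToyExprKind { LdMatrix, StMatrix, Other };

bool toyShouldUseAlternateLoopDomain(
    ToyMemory memory,
    ToyExprKind expr,
    bool ld_st_matrix) {
  // Short-circuit: the caller says this is not an ldmatrix/stmatrix lowering.
  if (!ld_st_matrix) {
    return false;
  }
  // Consistency check: a true flag must match the expression kind.
  if (expr != ToyExprKind::LdMatrix && expr != ToyExprKind::StMatrix) {
    throw std::logic_error("Unexpected expr");
  }
  // Assumption: only the shared-memory tensor uses the alternate domain.
  return memory == ToyMemory::Shared;
}

inline void toyGuardExample() {
  // Initialization of an stmatrix output: flag is false, so no alternate
  // domain even though the expression kind says StMatrix.
  assert(!toyShouldUseAlternateLoopDomain(
      ToyMemory::Shared, ToyExprKind::StMatrix, false));
  // The actual stmatrix lowering opts in.
  assert(toyShouldUseAlternateLoopDomain(
      ToyMemory::Shared, ToyExprKind::StMatrix, true));
}
```

    The check is one-directional: a true flag with a non-matching expression is an error, while a false flag with a matching expression is the legitimate initialization case.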

    Confidence Score: 5/5

    • This PR is safe to merge - it's a targeted bug fix with clear logic and consistent implementation across all affected files.
    • The change adds an explicit parameter to correctly identify ldmatrix/stmatrix operations instead of inferring from context. The logic is straightforward, all parameter names are consistent, and there's validation via NVF_ERROR to catch misuse. Previous review comments have been addressed.
    • No files require special attention.

    Important Files Changed

    File Analysis

    Filename | Score | Overview
    csrc/device_lower/pass/index.cpp | 5/5 | Adds ld_st_matrix parameter to lowerSrcIndex and lowerDstIndex functions and propagates it through to indexing calls for ldmatrix/stmatrix operations. The parameter correctly distinguishes actual ld/stmatrix operations from initialization ops that may use the same output tensor.
    csrc/device_lower/pass/index.h | 5/5 | Header file updated with new ld_st_matrix parameter for lowerSrcIndex and lowerDstIndex function signatures with default value of false.
    csrc/id_model/indexing.cpp | 5/5 | Renames isSharedMemoryTvForLdStMatrix to shouldUseAlternateLoopDomain with explicit ld_st_matrix parameter instead of inferring from expression type. Adds NVF_ERROR validation to ensure the flag is consistent with the expression type.
    csrc/id_model/indexing.h | 5/5 | Adds ld_st_matrix parameter to getLinearIndex and getContigIndexFor function signatures with default value of false.
    csrc/index_compute.cpp | 5/5 | Propagates ld_st_matrix parameter through getProducerIndex and getConsumerIndex to TensorIndexer's getLinearIndex method.
    csrc/index_compute.h | 5/5 | Header updated with new ld_st_matrix parameter for getProducerIndex and getConsumerIndex function signatures with default value of false.

    Sequence Diagram

    sequenceDiagram
        participant IL as IndexLowering::handle(LoadStoreOp)
        participant Lower as lowerSrcIndex/lowerDstIndex
        participant Index as Index::getProducerIndex/getConsumerIndex
        participant TI as TensorIndexer::getLinearIndex
        participant Contig as getContigIndexFor
        participant Check as shouldUseAlternateLoopDomain
    
        IL->>IL: Determine ld_st_matrix = isLdMatrixOp || isStMatrixOp
        IL->>Lower: Pass ld_st_matrix flag
        Lower->>Index: Forward ld_st_matrix flag
        Index->>TI: Forward ld_st_matrix flag
        TI->>Contig: Forward ld_st_matrix flag
        Contig->>Check: shouldUseAlternateLoopDomain(tv, expr, ld_st_matrix)
        Check-->>Check: if !ld_st_matrix return false
        Check-->>Check: NVF_ERROR validates expr matches flag
        Check-->>Contig: Return whether to use alternate domain
        Contig-->>TI: Return computed indices
        TI-->>Index: Return linear index
        Index-->>Lower: Return TensorIndex
        Lower-->>IL: Return indexed value
    
    Contributor

    @greptile-apps greptile-apps bot left a comment

    6 files reviewed, 1 comment

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
    Contributor

    @greptile-apps greptile-apps bot left a comment

    6 files reviewed, 1 comment

    Collaborator

    @rdspring1 rdspring1 left a comment

    LGTM.

    Greptile suggested changing the variable name from is_st_matrix to ld_st_matrix in csrc/index_compute.cpp.

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
    @naoyam
    Collaborator Author

    naoyam commented Dec 9, 2025

    !build

    Contributor

    @greptile-apps greptile-apps bot left a comment

    6 files reviewed, no comments

    @naoyam naoyam merged commit b8b33e9 into main Dec 9, 2025
    18 checks passed
    @naoyam naoyam deleted the fix_ld_st_matrix_indexing branch December 9, 2025 22:09