From 64bceed1c670fe9abd742469253d2065e0c33575 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 17:27:42 +0000 Subject: [PATCH] Do not assume mul-sum pattern inputs are TensorView This was a change I made to handle casts that wound up breaking some tests and benchmarks in #2272, leading to dynamic cast errors or segfaults. The solution is to test the type of the left and right hand sides before processing the pattern matching. --- csrc/scheduler/mma_utils.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 1b05d80bccf..986a9169df9 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1420,8 +1420,14 @@ class MatmulPatternMatcher : IterVisitor { // the Fusion was segmented and casts to half precision were inserted at // the segmentation edge (see castInputOutputToLowerPrecision in // fusion_segmenter.cpp). - TensorView* ltv = getTensorviewPriorToCast(bop->lhs()->as()); - TensorView* rtv = getTensorviewPriorToCast(bop->rhs()->as()); + TensorView* ltv = dynamic_cast(bop->lhs()); + TensorView* rtv = dynamic_cast(bop->rhs()); + if (ltv == nullptr || rtv == nullptr) { + // Found a scalar input + return; + } + ltv = getTensorviewPriorToCast(ltv); + rtv = getTensorviewPriorToCast(rtv); std::vector lrf = TensorDomain::noReductions(ltv->getMaybeRFactorDomain());