diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 1b05d80bccf..986a9169df9 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1420,8 +1420,14 @@ class MatmulPatternMatcher : IterVisitor { // the Fusion was segmented and casts to half precision were inserted at // the segmentation edge (see castInputOutputToLowerPrecision in // fusion_segmenter.cpp). - TensorView* ltv = getTensorviewPriorToCast(bop->lhs()->as<TensorView>()); - TensorView* rtv = getTensorviewPriorToCast(bop->rhs()->as<TensorView>()); + TensorView* ltv = dynamic_cast<TensorView*>(bop->lhs()); + TensorView* rtv = dynamic_cast<TensorView*>(bop->rhs()); + if (ltv == nullptr || rtv == nullptr) { + // Found a scalar input + return; + } + ltv = getTensorviewPriorToCast(ltv); + rtv = getTensorviewPriorToCast(rtv); std::vector<IterDomain*> lrf = TensorDomain::noReductions(ltv->getMaybeRFactorDomain());