diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index b188a14b47c..0c9ee7b061a 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -27,44 +27,6 @@ namespace nvfuser { -bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) { - if (!isResharding(expr)) { - return true; - } - if (!ir_utils::isTvOp(expr)) { - return false; - } - if (auto* reduction = dynamic_cast(expr)) { - if (!ignore_inner_resharding && !isCommunicationLayoutCompliant(expr)) { - return false; - } - auto in = reduction->in()->as(); - auto out = reduction->out()->as(); - // get the reduced axis - std::vector reduction_axis; - std::copy_if( - out->getLogicalDomain().begin(), - out->getLogicalDomain().end(), - std::back_inserter(reduction_axis), - [](IterDomain* id) { return id->isReduction(); }); - // check whether the reduction involves only one axis - if (reduction_axis.size() != 1) { - return false; - } - // We check whether the reduced axis is sharded on the input - const auto c2p_map = - PairwiseLogicalDomainMap(in, out).mapConsumerToProducer(); - auto c2p_map_it = c2p_map.find(reduction_axis.at(0)); - return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim(); - } else if (auto* ldst = dynamic_cast(expr)) { - if (!ignore_inner_resharding && !isCommunicationLayoutCompliant(expr)) { - return false; - } - return ldst->as()->opType() == LoadStoreOpType::Set; - } - return false; -} - bool HostIrLower::isLowerableAsStandaloneHostOp(Expr* expr) { if (expr->isOneOf< MatmulOp, diff --git a/csrc/host_ir/lower.h b/csrc/host_ir/lower.h index 8df156d4512..8bc4f37eca2 100644 --- a/csrc/host_ir/lower.h +++ b/csrc/host_ir/lower.h @@ -24,11 +24,6 @@ class HostIrLower { explicit HostIrLower(const HostIrLowerParams& params = HostIrLowerParams()) : params_(params) {} - // The flag `ignore_inner_resharding` is useful because the preseg passes - // `InsertReshardingsPass` and `ReorderShardedAxisPass` want different - // behaviors - static bool canLower(Expr* expr, bool ignore_inner_resharding = false); - // Lower a sharded Expr into a series of Communication. std::vector lower(Expr* c, DeviceIdxType my_device_index);