Safe resize vectorization#3906

Merged
naoyam merged 38 commits into main from traverse_all_paths
Mar 15, 2025

Collaborator

@naoyam naoyam commented Feb 16, 2025

Fixes #3640.

The current issue with resize vectorization is that, because the analysis traverses a spanning tree, not all paths are taken into consideration when determining vectorization factors. This PR addresses the issue by finding all Resize ops that have dependencies with vectorized IDs. To do so, it adds a new graph traversal class, FindAllExprs, based on BFS. Note that BFS itself only finds shortest paths, so it is not suitable when, for example, a Resize appears in a cycle, which is the case with RoPE.
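
To illustrate the difference between a spanning-tree traversal and an all-paths traversal, here is a minimal, self-contained sketch on a toy directed graph. It uses plain integers as nodes rather than nvFuser's ValGraph, and every name in it (`ToyGraph`, `bfsTreeEdges`, `allReachableEdges`) is hypothetical. A BFS that records only the edge through which each node is first reached drops the second path of a diamond, whereas a traversal that records every reachable edge keeps both.

```cpp
#include <cassert>
#include <deque>
#include <set>
#include <utility>
#include <vector>

using ToyEdge = std::pair<int, int>;

// Toy adjacency list: node -> successors. Hypothetical stand-in for a
// graph of iter domains; this is not the nvFuser ValGraph API.
using ToyGraph = std::vector<std::vector<int>>;

// BFS that keeps only the edge through which each node is first
// discovered, i.e., a spanning tree.
std::set<ToyEdge> bfsTreeEdges(const ToyGraph& g, int start) {
  std::set<ToyEdge> tree_edges;
  std::vector<bool> seen(g.size(), false);
  std::deque<int> queue{start};
  seen[start] = true;
  while (!queue.empty()) {
    int node = queue.front();
    queue.pop_front();
    for (int next : g[node]) {
      if (!seen[next]) {
        seen[next] = true;
        tree_edges.emplace(node, next);
        queue.push_back(next);
      }
    }
  }
  return tree_edges;
}

// Traversal that records every edge reachable from start. Each edge is
// visited once, but a node may be reached through several edges.
std::set<ToyEdge> allReachableEdges(const ToyGraph& g, int start) {
  std::set<ToyEdge> edges;
  std::deque<int> queue{start};
  while (!queue.empty()) {
    int node = queue.front();
    queue.pop_front();
    for (int next : g[node]) {
      if (edges.emplace(node, next).second) {
        queue.push_back(next);
      }
    }
  }
  return edges;
}

// Diamond: 0 -> 1 -> 3 and 0 -> 2 -> 3.
ToyGraph makeDiamond() {
  return {{1, 2}, {3}, {3}, {}};
}

void checkDiamond() {
  // The spanning tree reaches node 3 through node 1 only.
  assert(bfsTreeEdges(makeDiamond(), 0).count(ToyEdge{2, 3}) == 0);
  // The all-edges traversal also keeps the path through node 2.
  assert(allReachableEdges(makeDiamond(), 0).count(ToyEdge{2, 3}) == 1);
}
```

In this toy, a resize sitting on the edge 2 -> 3 would be invisible to the spanning-tree version, which is the kind of omission the PR description refers to.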

I was also thinking about overhauling the vectorization analysis to be based more on ID traversal, but we would still need to look at each tensor's allocation domain, so I don't think we can completely move away from tensor op traversals.

github-actions bot commented Feb 16, 2025

Review updated until commit 58dcb3a

Description

  • Added FindAllExprs class for comprehensive graph traversal.

  • Updated getResizeBasedOps to include more resize operations.

  • Introduced ResizeVectorizationFactors for vectorization analysis.

  • Enhanced getVectorizationFactor to consider all resize paths.


Changes walkthrough 📝

Relevant files

Enhancement (9 files)
  • registry.cpp: Add ResizeVectorizationFactors to HeuristicDataCacheEntry (+2/-0)
  • resize.cpp: Update getResizeBasedOps usage (+1/-1)
  • resize_utils.cpp: Implement getResizeBasedOps (+4/-0)
  • vectorize_helper.cpp: Introduce getResizeVectorizationFactors and update getVectorizationFactor (+113/-0)
  • val_graph_visitor.cpp: Implement getAllExprGroupsBetween (+28/-4)
  • graph_traversal.h: Introduce FindAllExprs class (+596/-0)
  • compile_time_info.h: Add ResizeVectorizationFactors entry type (+10/-0)
  • resize_utils.h: Declare getResizeBasedOps (+2/-0)
  • val_graph_visitor.h: Declare getAllExprGroupsBetween (+11/-0)

Tests (2 files)
  • test_bfs.cpp: Add tests for FindAllExprs (+357/-3)
  • test_resize.cpp: Add test for vectorization with multiple paths (+38/-5)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Impact

The new vectorization factor calculation may have performance implications. Ensure that the new approach does not introduce unnecessary overhead or regressions in performance.

    mappers.push_back(ContiguousInnerDimensionsMapper::map(ref, logical_dom)
                          .getTvToContigMergeOfInnerSizeMap());
    logical_dom.erase(logical_dom.begin());
  }
  return mappers;
}

// This is a WAR for vectorizing through resized iter domains. The
// spanning tree based analysis is not guaranteed to take all resize
// ops into consideration (issue
// https://github.com/NVIDIA/Fuser/issues/3640). To work around the
// limitation, grab all factors that must be divisible by a
// vectorization factor.
std::unordered_set<Val*> getResizeVectorizationFactors(
    TensorView* reference_tv,
    int64_t break_point) {
  Fusion* fusion = reference_tv->fusion();
  std::unordered_set<Val*> factors;
  const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion);

  if (resize_based_ops.empty()) {
    return factors;
  }

  IdModel id_model(reference_tv->fusion());
  const auto& graph = id_model.buildExactGraph();

  const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain());

  // For each of the resize-based tensor ops, find all resize ops
  // that exist between the vectorized reference IDs and the output
  // tensor.
  for (auto resize_based_op : resize_based_ops) {
    auto resize_out = resize_based_op->output(0)->as<TensorView>();
    NVF_ERROR(
        resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString());
    // getAllExprGroupsBetween finds exprs between IDs. To make sure
    // the resize op of this resize_based_op tensor op is found,
    // use both the root and logical domains as the traversal targets.
    ValGroups resize_inp_out;
    resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain()));
    resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain()));

    auto expr_path = getAllExprGroupsBetween(
                         graph,
                         ref_groups,
                         resize_inp_out,
                         /*require_all_to_visited=*/false)
                         .first;

    ValGroups vectorized_groups;
    for (auto it = reference_tv->getLogicalDomain().begin() + break_point;
         it != reference_tv->getLogicalDomain().end();
         ++it) {
      vectorized_groups.pushBack(graph.toGroup(*it));
    }

    // Find all resize exprs that appear in expr_path and depend on
    // vectorized_groups. Since expr_path is not guaranteed to be
    // topologically sorted, need to loop through the path until
    // converged.

    bool something_has_changed = true;
    while (something_has_changed) {
      something_has_changed = false;
      for (const auto& [expr_g, dir] : expr_path) {
        const auto inputs = getInputsOfExprGroup(graph, expr_g, dir);
        if (std::none_of(
                inputs.begin(), inputs.end(), [&](const ValGroup& inp) {
                  return vectorized_groups.has(inp);
                })) {
          continue;
        }

        if (vectorized_groups.pushBack(
                getOutputsOfExprGroup(graph, expr_g, dir))) {
          something_has_changed = true;
        }

        auto resize = dynamic_cast<Resize*>(expr_g->front());
        if (resize == nullptr) {
          continue;
        }

        // These three vals need to be divisible by the vectorization factor
        factors.emplace(resize->leftExpand());
        factors.emplace(resize->rightExpand());
        factors.emplace(
            dir == Direction::Forward ? resize->out()->extent()
                                      : resize->in()->extent());
      }
    }
  }

  return factors;
}
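
The convergence loop in the function above can be illustrated in isolation. The following is a simplified sketch, not the nvFuser implementation: ID groups are plain integers, each `ToyExpr` has exactly one input and one output, and `collectResizeFactors` and `isSafeVectorization` are hypothetical names. It assumes, following the comment on the function, that a vectorization factor is acceptable only if it divides every collected factor.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Toy expression between two ID groups. For resize exprs, the expand
// amounts and the output extent are recorded as plain integers.
struct ToyExpr {
  int in;
  int out;
  bool is_resize;
  int64_t left_expand;
  int64_t right_expand;
  int64_t out_extent;
};

// Propagate "depends on a vectorized ID" through exprs that may be
// listed in any order, repeating until nothing changes. Collect the
// factors of every resize expr reached this way.
std::set<int64_t> collectResizeFactors(
    const std::vector<ToyExpr>& exprs,
    std::set<int> vectorized) {
  std::set<int64_t> factors;
  bool changed = true;
  while (changed) {
    changed = false;
    for (const ToyExpr& expr : exprs) {
      if (vectorized.count(expr.in) == 0) {
        continue;
      }
      if (vectorized.insert(expr.out).second) {
        changed = true;
      }
      if (expr.is_resize) {
        factors.insert(expr.left_expand);
        factors.insert(expr.right_expand);
        factors.insert(expr.out_extent);
      }
    }
  }
  return factors;
}

// Assumed safety rule: the vectorization factor must divide every
// collected factor (an expand of 0 is divisible by anything).
bool isSafeVectorization(const std::set<int64_t>& factors, int64_t vec) {
  for (int64_t f : factors) {
    if (f % vec != 0) {
      return false;
    }
  }
  return true;
}

void checkResizeFactors() {
  // Exprs are deliberately listed out of order: 1 -> 2 comes before 0 -> 1.
  std::vector<ToyExpr> exprs = {
      {1, 2, true, -2, 0, 6}, // slice-like resize: drop 2 on the left
      {0, 1, false, 0, 0, 0},
      {5, 6, true, 1, 1, 10}, // not reachable from the vectorized group
  };
  auto factors = collectResizeFactors(exprs, {0});
  assert(factors == (std::set<int64_t>{-2, 0, 6}));
  assert(isSafeVectorization(factors, 2));
  assert(!isSafeVectorization(factors, 4));
}
```

Because the first expr only becomes relevant after the second one has been processed, a single pass over the list would miss it; the repeat-until-converged loop is what makes the unsorted order harmless.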
Complexity

The new FindAllExprs class is quite complex. Ensure that the algorithm is efficient and that the complexity is justified by the performance gains.

// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <bfs.h>

namespace nvfuser {

// Find all exprs reachable from from_nodes when traversing to to_nodes. Edges
// are visited only once, but nodes may be visited multiple times. Edges are
// always between ExprT and ValT and are directed, e.g., an edge from an
// ExprGroup to a ValGroup is differentiated from an edge from the ValGroup to
// the ExprGroup, and both of them may be visited.
//
// When there's a cycle, exprs in the cycle are also included. For
// example, given a graph like (each symbol represents an expr):
//
//   A -> B -> C -> D -> E
//        ^         |
//        +--- F ---+
//
// Exprs of {A_fwd, F_bwd, B_fwd, C_fwd, D_fwd, E_fwd} would be
// returned. Note that there's no guarantee of ordering, although it
// is at least partially sorted in a topological order.
//
// The overall traversal algorithm is to start from from_nodes and
// traverse edges in both directions or only in a specified
// direction. Unlike BFS, it keeps traversing even if all
// the to_nodes are reached and stops when no further progress is
// possible. At this point, we know all the edges reachable from
// from_nodes, but we are only interested in those that reach
// to_nodes. To find those edges, another traversal, this time from
// to_nodes, is
// done to mark all visited edges that are reachable from
// to_nodes. That gives us all the edges between from_nodes and
// to_nodes. Finally, ExprPath is returned based on the exprs of the
// edges.
//
// NOTE 1: The algorithm and the implementation are based on the BFS
// class. There are likely more efficient algorithms.
//
// NOTE 2: The returned expr path is not guaranteed to be
// topologically sorted, which is not possible for cyclic graphs.
template <
    typename ExprT,
    typename ValT,
    typename DefinitionT,
    typename UsesT,
    typename InputsT,
    typename OutputsT>
class FindAllExprs {
 public:
  using ExprType = ExprT;
  using ValType = ValT;
  using NodeType = std::variant<ExprT, ValT>;
  using ExprPath = std::vector<std::pair<ExprT, Direction>>;

  // Edge represents an edge in the graph. By definition, it must be
  // between an expr and a val.
  struct Edge {
    NodeType from;
    NodeType to;
    Edge(const ValT& from, const ExprT& to) : from(from), to(to) {}
    Edge(const ExprT& from, const ValT& to) : from(from), to(to) {}
    bool operator==(const Edge& other) const {
      return from == other.from && to == other.to;
    }
    std::string toString() const {
      std::stringstream ss;
      ss << "{" << nvfuser::toString(from) << " -> " << nvfuser::toString(to)
         << "}";
      return ss.str();
    }
  };

  struct EdgeHash {
    std::size_t operator()(const Edge& edge) const {
      return std::hash<NodeType>()(edge.from) ^ std::hash<NodeType>()(edge.to);
    }
  };

  using EdgeSet = std::unordered_set<Edge, EdgeHash>;

  virtual ~FindAllExprs() = default;

 public:
  FindAllExprs(
      DefinitionT definition,
      UsesT uses,
      InputsT inputs,
      OutputsT outputs,
      std::vector<NodeType> from,
      std::vector<NodeType> to,
      bool require_all_to_visited = true,
      Direction allowed_direction = Direction::Undefined)
      : definition_(std::move(definition)),
        uses_(std::move(uses)),
        inputs_(std::move(inputs)),
        outputs_(std::move(outputs)),
        from_nodes_(std::move(from)),
        to_nodes_(std::move(to)),
        require_all_to_visited_(require_all_to_visited),
        allowed_direction_(allowed_direction) {}

  virtual void traverseAllEdges() {
    std::deque<Edge> edges_to_visit;

    for (const auto& from_node : from_nodes_) {
      if (const ValT* from_val = std::get_if<ValT>(&from_node)) {
        for (const auto& use_expr : uses_(*from_val)) {
          Edge edge(*from_val, use_expr);
          setVisited(edge);
          for (const auto& next_edge :
               getConsumerEdges(edge, allowed_direction_)) {
            edges_to_visit.push_back(next_edge);
          }
        }
        for (const auto& def_expr : definition_(*from_val)) {
          Edge edge(*from_val, def_expr);
          setVisited(edge);
          for (const auto& next_edge :
               getConsumerEdges(edge, allowed_direction_)) {
            edges_to_visit.push_back(next_edge);
          }
        }
      } else {
        NVF_THROW(
            "Traversal from nodes are assumed to be all Vals but found: ",
            toString(from_node));
      }
    }

    bool something_was_processed = true;
    while (something_was_processed) {
      std::deque<Edge> not_ready;
      something_was_processed = false;

      while (!edges_to_visit.empty()) {
        const auto edge_to_visit = edges_to_visit.front();
        edges_to_visit.pop_front();

        // Don't visit edges multiple times even when traversing all paths
        if (isVisited(edge_to_visit)) {
          continue;
        }

        auto prev_edges = isReady(edge_to_visit);
        if (prev_edges.empty()) {
          // To stop an infinite loop, the not-ready edge is not moved
          // back to the edges_to_visit queue but kept in a separate
          // queue. This way, if all edges in edges_to_visit are not
          // ready, the queue eventually becomes empty, which then
          // breaks the inner while loop. The something_was_processed
          // flag is used to remember if there's any progress.
          not_ready.emplace_back(edge_to_visit);
          continue;
        }

        setVisited(edge_to_visit);
        for (const auto& next_edge :
             getConsumerEdges(edge_to_visit, allowed_direction_)) {
          edges_to_visit.push_back(next_edge);
        }
        something_was_processed = true;
      }

      // Requeue the not-ready edges. The outer loop retries them only if
      // something was processed in this round.
      edges_to_visit.insert(
          edges_to_visit.end(), not_ready.begin(), not_ready.end());
    }

    if (require_all_to_visited_ && !allToNodesVisited()) {
      auto visited_nodes = getVisitedNodes();
      std::stringstream ss;
      for (const auto& to : to_nodes_) {
        if (!visited_nodes.count(to)) {
          ss << " " << toString(to);
        }
      }
      ss << " (from: ";
      for (const auto& from : from_nodes_) {
        ss << " " << toString(from);
      }
      ss << ")";
      ss << ", visited: (";
      for (const auto& visited : visited_nodes) {
        if (const ValT* v = std::get_if<ValT>(&visited)) {
          ss << " " << toString(visited);
        }
      }
      ss << ")";
      NVF_THROW("BFS traversal could not visit some nodes: ", ss.str());
    }
  }

  // Check if an edge is ready to visit. If yes, return the prev edges
  // that should be visited before the given edge is visited. If not,
  // return an empty vector.
  virtual std::vector<Edge> isReady(const Edge& edge) const {
    Direction dir = getDirection(edge);

    // If a direction is specified, only that direction of edges are
    // allowed.
    if ((dir == Direction::Forward &&
         allowed_direction_ == Direction::Backward) ||
        (dir == Direction::Backward &&
         allowed_direction_ == Direction::Forward)) {
      return {};
    }

    if (const ExprT* e = std::get_if<ExprT>(&(edge.from))) {
      return isReady(*e, std::get<ValT>(edge.to), dir);
    } else if (const ValT* v = std::get_if<ValT>(&(edge.from))) {
      return isReady(*v, std::get<ExprT>(edge.to), dir);
    } else {
      NVF_THROW();
    }
  }

  // Check if an edge from an expr to a val is ready to visit. If this
  // is a forward edge, i.e., the val is an output of the expr, the
  // edge is ready to visit as long as all the inputs of the expr are
  // visited. If it's a backward edge, i.e., the val is an input of
  // the expr, it's ready if all of the outputs are visited. If ready,
  // the edges that this edge depends on are returned. For example, in
  // the case of a forward edge, all of the edges to from_expr are
  // returned.
  virtual std::vector<Edge> isReady(
      const ExprT& from_expr,
      const ValT& to_val,
      Direction dir) const {
    if (dir == Direction::Forward) {
      decltype(auto) inputs = inputs_(from_expr);
      if (std::all_of(
              inputs.begin(), inputs.end(), [&](const ValT& input) -> bool {
                return isVisited(Edge(input, from_expr));
              })) {
        std::vector<Edge> prev_edges;
        for (const ValT& input : inputs) {
          prev_edges.push_back(Edge(input, from_expr));
        }
        return prev_edges;
      }
    } else if (dir == Direction::Backward) {
      decltype(auto) outputs = outputs_(from_expr);
      if (std::all_of(
              outputs.begin(), outputs.end(), [&](const ValT& output) -> bool {
                return isVisited(Edge(output, from_expr));
              })) {
        std::vector<Edge> prev_edges;
        for (const ValT& output : outputs) {
          prev_edges.push_back(Edge(output, from_expr));
        }
        return prev_edges;
      }
    }

    return {};
  }

  // Check if an edge from a val to an expr is ready to visit. In the
  // case of a val, it is ready to visit as long as there's at least
  // one def or use expr that has been already visited. However, since
  // this is an edge to an expr, the edge from the same expr to this
  // val does not make this edge ready to visit. For example, even if
  // a merge producing i0 is visited, it should not automatically mean
  // the edge from i0 to the merge expr is ready to visit. Otherwise,
  // the traversal would just move back and forth.
  virtual std::vector<Edge> isReady(
      const ValT& from_val,
      const ExprT& to_expr,
      Direction dir) const {
    std::vector<Edge> prev_edges;

    // Check if any def is visited
    decltype(auto) def = definition_(from_val);
    if (!def.empty()) {
      for (const ExprT& def_e : def) {
        if (def_e != to_expr && isVisited(Edge(def_e, from_val))) {
          prev_edges.emplace_back(Edge(def_e, from_val));
        }
      }
    }

    decltype(auto) uses = uses_(from_val);
    for (const ExprT& use_e : uses) {
      if (use_e != to_expr && isVisited(Edge(use_e, from_val))) {
        prev_edges.emplace_back(Edge(use_e, from_val));
      }
    }

    return prev_edges;
  }

  // Check if a given edge is already visited
  virtual bool isVisited(const Edge& edge) const {
    return visited_edges_.find(edge) != visited_edges_.end();
  }

  virtual void setVisited(const Edge& edge) {
    if (visited_edges_.emplace(edge).second) {
      partially_ordered_visited_edges_.push_back(edge);
    }
  }

  // Get edges that are consumers or producers of a given edge. A
  // consumer edge of edge A->B is an edge that has node B as its from
  // node. A producer edge is an edge that has node A as its to node.
  virtual std::vector<Edge> getConsumerOrProducerEdges(
      const Edge& edge,
      bool is_consumer,
      Direction allowed_direction = Direction::Undefined) const {
    std::vector<Edge> neighbor_edges;

    auto add_to_neighbor_list = [&](const auto& from, const auto& to) -> void {
      Edge neighbor_edge(from, to);

      if (edge == neighbor_edge ||
          // Don't traverse back
          (edge.from == neighbor_edge.to && edge.to == neighbor_edge.from)) {
        return;
      }

      if (excludeFromTraversal(neighbor_edge)) {
        return;
      }

      auto neighbor_edge_dir = getDirection(neighbor_edge);
      if ((allowed_direction == Direction::Forward &&
           neighbor_edge_dir == Direction::Backward) ||
          (allowed_direction == Direction::Backward &&
           neighbor_edge_dir == Direction::Forward)) {
        return;
      }

      neighbor_edges.push_back(neighbor_edge);
    };

    Direction edge_dir = getDirection(edge);
    NVF_ERROR(
        edge_dir == Direction::Forward || edge_dir == Direction::Backward);

    const auto& node = is_consumer ? edge.to : edge.from;

    // When grabbing consumer edges, the node is the to node of the
    // edge. For a forward edge to an Expr, the edge is from an input Val
    // to its use Expr, so traverse from the use Expr to its outputs.
    // When grabbing producer edges, the node is the from node of the
    // edge. For a forward edge from an Expr, the edge is from a defining
    // Expr to one of its outputs, so grab the edges from the inputs of
    // the Expr.

    if (const ExprT* e = std::get_if<ExprT>(&node)) {
      // The from node must be a Val.

      // In the case of Expr, only consider edges of the same
      // direction
      if (edge_dir == Direction::Forward) {
        if (is_consumer) {
          // Grab consumer edges of the forward edge to the expr. The
          // edge represents a use expr of the from val. Consumers are
          // forward edges from the expr to its outputs.
          for (const auto& v : outputs_(*e)) {
            add_to_neighbor_list(*e, v);
          }
        } else {
          // Grab producer edges of the forward edge from the expr. The
          // edge represents a defining expr of the to val. Producers
          // are forward edges to the defining expr from its inputs.
          for (const auto& v : inputs_(*e)) {
            add_to_neighbor_list(v, *e);
          }
        }
      } else if (edge_dir == Direction::Backward) {
        if (is_consumer) {
          // Grab consumer edges of the backward edge to the expr. The
          // edge represents a defining expr of the from val. Consumers
          // are backward edges from the defining expr to its inputs.
          for (const auto& v : inputs_(*e)) {
            add_to_neighbor_list(*e, v);
          }
        } else {
          // Grab producer edges of the backward edge from the expr. The
          // edge represents a use expr of the from val. Producers
          // are backward edges to the use expr from its outputs.
          for (const auto& v : outputs_(*e)) {
            add_to_neighbor_list(v, *e);
          }
        }
      }
    } else if (const ValT* v = std::get_if<ValT>(&node)) {
      // The from node must be an Expr.

      // In the case of Val, no matter what direction this node is, it
      // should be valid to traverse both directions. Just don't
      // traverse back to the same node.

      for (const auto& e : uses_(*v)) {
        if (is_consumer) {
          // Uses of v are forward consumer edges of the edge to val v
          add_to_neighbor_list(*v, e);
        } else {
          // Uses of v are backward producer edges of the edge from val v
          add_to_neighbor_list(e, *v);
        }
      }

      for (const auto& e : definition_(*v)) {
        if (is_consumer) {
          // Defs of v are backward consumer edges of the edge to val v
          add_to_neighbor_list(*v, e);
        } else {
          // Defs of v are forward producer edges of the edge from val v
          add_to_neighbor_list(e, *v);
        }
      }
    }

    return neighbor_edges;
  }

  // Get edges that should be traversed from the to node of a given edge
  virtual std::vector<Edge> getConsumerEdges(
      const Edge& edge,
      Direction allowed_direction = Direction::Undefined) const {
    return getConsumerOrProducerEdges(
        edge, /*is_consumer=*/true, allowed_direction);
  }

  // Get edges that should be traversed before the from node of a given edge
  virtual std::vector<Edge> getProducerEdges(
      const Edge& edge,
      Direction allowed_direction = Direction::Undefined) const {
    return getConsumerOrProducerEdges(
        edge, /*is_consumer=*/false, allowed_direction);
  }

  // Check if all to_nodes_ are visited
  virtual bool allToNodesVisited() const {
    auto visited_nodes = getVisitedNodes();
    return std::all_of(
        to_nodes_.begin(), to_nodes_.end(), [&](const NodeType& node) -> bool {
          return visited_nodes.count(node);
        });
  };

  // Hook to exclude certain graph edges.
  virtual bool excludeFromTraversal(const Edge& edge) const {
    return false;
  }

  // If an edge is from a val to its use expr, it's a forward
  // edge. Similarly, it's also a forward edge if it's an expr to one
  // of its outputs. Otherwise, it's a backward edge.
  Direction getDirection(const Edge& edge) const {
    if (const ExprT* from_expr = std::get_if<ExprT>(&edge.from)) {
      const ValT& to_val = std::get<ValT>(edge.to);
      decltype(auto) inputs = inputs_(*from_expr);
      if (std::find(inputs.begin(), inputs.end(), to_val) != inputs.end()) {
        return Direction::Backward;
      }
      decltype(auto) outputs = outputs_(*from_expr);
      if (std::find(outputs.begin(), outputs.end(), to_val) != outputs.end()) {
        return Direction::Forward;
      }
      NVF_THROW();
    } else if (const ValT* from_val = std::get_if<ValT>(&edge.from)) {
      const ExprT& to_expr = std::get<ExprT>(edge.to);
      decltype(auto) inputs = inputs_(to_expr);
      if (std::find(inputs.begin(), inputs.end(), *from_val) != inputs.end()) {
        return Direction::Forward;
      }
      decltype(auto) outputs = outputs_(to_expr);
      if (std::find(outputs.begin(), outputs.end(), *from_val) !=
          outputs.end()) {
        return Direction::Backward;
      }
      NVF_THROW();
    } else {
      NVF_THROW();
    }
  }

  virtual std::unordered_set<NodeType> getVisitedNodes() const {
    std::unordered_set<NodeType> visited_nodes;
    visited_nodes.insert(from_nodes_.begin(), from_nodes_.end());
    for (const auto& visited_edge : visited_edges_) {
      visited_nodes.emplace(visited_edge.from);
      visited_nodes.emplace(visited_edge.to);
    }
    return visited_nodes;
  }

  // Grab all visited edges that are reachable from from_nodes and
  // to_nodes. traverseAllEdges must have been completed.
  virtual EdgeSet getUsedEdges() const {
    NVF_ERROR(
        !require_all_to_visited_ || allToNodesVisited(),
        "Traversal is either not done or failed");

    // Traverse back from to_ nodes to from_ nodes by traversing
    // through visited edges
    std::deque<Edge> to_visit;

    // Gather all visited edges to the to_ nodes. These edges are used
    // as initial edges for the traversal below
    for (const NodeType& to_node : to_nodes_) {
      if (const ValT* to_val = std::get_if<ValT>(&to_node)) {
        for (const ExprT& use_expr : uses_(*to_val)) {
          Edge e{use_expr, *to_val};
          if (isVisited(e)) {
            to_visit.emplace_back(e);
          }
        }
        for (const ExprT& def_expr : definition_(*to_val)) {
          Edge e{def_expr, *to_val};
          if (isVisited(e)) {
            to_visit.emplace_back(e);
          }
        }
      } else {
        NVF_THROW(
            "Traversal to nodes are assumed to be all Vals but found: ",
            toString(to_node));
      }
    }

    EdgeSet used_edges;

    while (!to_visit.empty()) {
      const auto edge_to_visit = to_visit.front();
      to_visit.pop_front();

      if (used_edges.count(edge_to_visit)) {
        continue;
      }

      auto producer_edges = getProducerEdges(edge_to_visit);
      for (const Edge& producer_edge : producer_edges) {
        if (isVisited(producer_edge)) {
          to_visit.emplace_back(producer_edge);
        }
      }

      used_edges.insert(edge_to_visit);
    }

    return used_edges;
  }

  // Return ExprPath consisting of all exprs appearing between
  // from_nodes and to_nodes. The exprs are partially topologically
  // sorted, but not completely. The ordering should be deterministic,
  // but do not assume any particular ordering.
  virtual std::pair<ExprPath, bool> getPartiallyOrderedExprs() const {
    const auto used_edges = getUsedEdges();

    VectorOfUniqueEntries<std::pair<ExprT, Direction>> expr_path;

    for (const Edge& ordered_visited_edge : partially_ordered_visited_edges_) {
      if (!used_edges.count(ordered_visited_edge)) {
        continue;
      }

      Direction edge_dir = getDirection(ordered_visited_edge);

      // Append the expr of this edge
      const ExprT& expr =
          std::get_if<ExprT>(&(ordered_visited_edge.from)) != nullptr
          ? std::get<ExprT>(ordered_visited_edge.from)
          : std::get<ExprT>(ordered_visited_edge.to);
      expr_path.pushBack(std::make_pair(expr, edge_dir));
    }

    return std::make_pair(expr_path.vector(), allToNodesVisited());
  }

 protected:
  const DefinitionT definition_;
  const UsesT uses_;
  const InputsT inputs_;
  const OutputsT outputs_;
  const std::vector<NodeType> from_nodes_;
  const std::vector<NodeType> to_nodes_;
  bool require_all_to_visited_ = true;
  Direction allowed_direction_ = Direction::Undefined;

  EdgeSet visited_edges_;
  std::vector<Edge> partially_ordered_visited_edges_;
};

} // namespace nvfuser
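
The two-phase idea described in the header comment (traverse everything reachable from from_nodes, then walk back from to_nodes over the visited edges) can be sketched on a much simpler graph. This is an illustrative toy, not the FindAllExprs implementation: nodes are integers, edges are one-directional, there is no Expr/Val distinction and no readiness check, and `edgesBetween` is a hypothetical name.

```cpp
#include <cassert>
#include <deque>
#include <set>
#include <utility>
#include <vector>

using ToyEdge = std::pair<int, int>;

// Return all edges that lie on some walk from `from` to `to`,
// including edges on cycles along the way.
std::set<ToyEdge> edgesBetween(
    const std::vector<ToyEdge>& edges,
    int from,
    int to) {
  // Phase 1: mark every edge reachable from `from`. Each edge is
  // visited once; nodes may be reached several times.
  std::set<ToyEdge> visited;
  std::deque<int> to_visit{from};
  while (!to_visit.empty()) {
    int node = to_visit.front();
    to_visit.pop_front();
    for (const ToyEdge& e : edges) {
      if (e.first == node && visited.insert(e).second) {
        to_visit.push_back(e.second);
      }
    }
  }

  // Phase 2: walk backward from `to`, keeping only visited edges.
  std::set<ToyEdge> used;
  std::deque<int> back{to};
  while (!back.empty()) {
    int node = back.front();
    back.pop_front();
    for (const ToyEdge& e : visited) {
      if (e.second == node && used.insert(e).second) {
        back.push_back(e.first);
      }
    }
  }
  return used;
}

void checkEdgesBetween() {
  // 0 -> 1 -> 2 -> 3 -> 4 with a cycle 3 -> 5 -> 1 and a dead end 2 -> 6.
  std::vector<ToyEdge> edges = {
      {0, 1}, {1, 2}, {2, 3}, {3, 4}, {3, 5}, {5, 1}, {2, 6}};
  auto used = edgesBetween(edges, 0, 4);
  assert(used.size() == 6);
  assert(used.count(ToyEdge{3, 5}) == 1); // cycle edge is kept
  assert(used.count(ToyEdge{2, 6}) == 0); // dead end is dropped
}
```

The first phase alone would also return the dead-end edge 2 -> 6; the backward phase is what restricts the result to edges that actually lead to the target.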
Naming Consistency

The function getResizeBasedOps currently returns both SliceOp and PadOp. Consider renaming the function to better reflect its purpose or adding a separate function for ResizeOp if needed.

std::vector<Expr*> getResizeBasedOps(Fusion* fusion) {
  return ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
}
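
For readers unfamiliar with the pattern, a multi-type filter in the spirit of `getOpsOfType<SliceOp, PadOp>` can be written with a C++17 fold expression. The sketch below uses hypothetical toy classes (`ToyExpr`, `ToySliceOp`, `ToyPadOp`, `ToyCatOp`) and a hypothetical `filterOpsOfType`; it is not the `ir_utils` implementation and makes no claim about how that helper is written.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct ToyExpr {
  virtual ~ToyExpr() = default;
};
struct ToySliceOp : ToyExpr {};
struct ToyPadOp : ToyExpr {};
struct ToyCatOp : ToyExpr {};

// Keep every expr whose dynamic type matches any of OpTypes.
template <typename... OpTypes>
std::vector<ToyExpr*> filterOpsOfType(
    const std::vector<std::unique_ptr<ToyExpr>>& exprs) {
  std::vector<ToyExpr*> result;
  for (const auto& expr : exprs) {
    // Fold over ||: true if any dynamic_cast succeeds.
    if ((... || (dynamic_cast<OpTypes*>(expr.get()) != nullptr))) {
      result.push_back(expr.get());
    }
  }
  return result;
}

void checkFilterOpsOfType() {
  std::vector<std::unique_ptr<ToyExpr>> exprs;
  exprs.push_back(std::make_unique<ToySliceOp>());
  exprs.push_back(std::make_unique<ToyCatOp>());
  exprs.push_back(std::make_unique<ToyPadOp>());
  auto resize_based = filterOpsOfType<ToySliceOp, ToyPadOp>(exprs);
  assert(resize_based.size() == 2);
  assert(resize_based[0] == exprs[0].get());
  assert(resize_based[1] == exprs[2].get());
}
```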

Collaborator Author

naoyam commented Feb 17, 2025

CC: @jjsjann123

Collaborator Author

naoyam commented Feb 26, 2025

!test --diff

Collaborator Author

naoyam commented Feb 26, 2025

!test --diff

Collaborator Author

naoyam commented Feb 26, 2025

!test --diff

Collaborator Author

naoyam commented Feb 26, 2025

!test --diff

Collaborator Author

naoyam commented Feb 28, 2025

!test --diff

Collaborator Author

naoyam commented Feb 28, 2025

!test --diff

Collaborator Author

naoyam commented Feb 28, 2025

!test --diff

Collaborator Author

naoyam commented Mar 1, 2025

!test

Collaborator Author

naoyam commented Mar 1, 2025

!test

@naoyam naoyam marked this pull request as ready for review March 2, 2025 01:32
@naoyam naoyam changed the title from "[WIP] Traverse all paths (including cycles)" to "Safe resize vectorization" Mar 2, 2025
@naoyam naoyam requested a review from jjsjann123 March 2, 2025 02:00
}

TEST_F(LinearHostIrTest, HostIrLinearOut) {
GTEST_SKIP();
Collaborator Author

Temporarily disabled as it prevents the binary tests running on H100.

Related: #3996

Collaborator Author

naoyam commented Mar 3, 2025

!test

Collaborator

@jjsjann123 jjsjann123 left a comment

some quick comments.

I still need to go over graph_traversal.h and tests.

}

auto resize_factors_entry =
HeuristicDataCacheEntry<HeuristicCompileTime::ResizeVectorizationFactors>(
Collaborator

👏

std::vector<NodeType> from,
std::vector<NodeType> to,
bool require_all_to_visited = true,
Direction allowed_direction = Direction::Undefined)
Collaborator

nitpick: I know this ship has sailed, but using Direction::Undefined to mean Direction::Forward_and_Backward is confusing to read.

Collaborator Author

Yeah, I was also planning to rename it to something like Direction::Unspecified.
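
As a reading aid for this exchange, the direction filter that the excerpt applies in both isReady and getConsumerOrProducerEdges can be restated as a small predicate. This is a sketch with hypothetical names (`ToyDirection`, `isDirectionAllowed`), and it assumes the edge direction itself is always Forward or Backward, as the NVF_ERROR check in the excerpt enforces.

```cpp
#include <cassert>

// Toy mirror of the three-valued direction used in the excerpt above.
enum class ToyDirection { Forward, Backward, Undefined };

// An edge is rejected only when a specific direction is requested and
// the edge goes the other way. Undefined therefore means "no
// restriction", i.e., both directions are allowed.
bool isDirectionAllowed(ToyDirection edge_dir, ToyDirection allowed) {
  if (allowed == ToyDirection::Undefined) {
    return true;
  }
  return edge_dir == allowed;
}

void checkDirectionFilter() {
  assert(isDirectionAllowed(ToyDirection::Forward, ToyDirection::Undefined));
  assert(isDirectionAllowed(ToyDirection::Backward, ToyDirection::Undefined));
  assert(isDirectionAllowed(ToyDirection::Forward, ToyDirection::Forward));
  assert(!isDirectionAllowed(ToyDirection::Backward, ToyDirection::Forward));
}
```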


Direction edge_dir = getDirection(edge);
NVF_ERROR(
edge_dir == Direction::Forward || edge_dir == Direction::Backward);
Collaborator

nitpick: we already have NVF_THROW() in getDirection, the check here seems redundant.

virtual std::vector<Edge> getConsumerOrProducerEdges(
const Edge& edge,
bool is_consumer,
Direction allowed_direction = Direction::Undefined) const {
Collaborator

naive question: it felt like is_consumer and allowed_direction need to agree here in order for this to traverse anywhere.

Is allowed_direction placed here in anticipation of some future expansion, similar to excludeFromTraversal being a placeholder?

NVM, I realized we are just trying to have a single function handle both directions of traversal.

outputs.begin(), outputs.end(), [&](const ValT& output) -> bool {
return isVisited(Edge(output, from_expr));
})) {
std::vector<Edge> prev_edges;
Collaborator

nitpick: why is it named prev_edges?

Collaborator Author

Because those are the edges that should have been visited previously.

}
}

return std::nullopt;
Collaborator

This should be an NVF_THROW() as well.

Collaborator Author

No, I don't think so, since some of the dependencies may not be satisfied yet.

continue;
}

auto prev_edges = isReady(edge_to_visit);
Collaborator

actually we didn't use prev_edges.value(), so isReady could just return a boolean instead.

Collaborator Author

I think I was actually using that information in some prior versions. I ended up not using it, but I'm not sure if that will always be the case, so I'll keep it as is for now.
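
For readers following this thread, here is a minimal sketch of the two options being discussed. It is not the PR's code: `Node`, the `deps` list, and the `visited` set are hypothetical stand-ins for the real edge and expr-group types. It only shows why an empty optional can mean "not ready yet, try again later" rather than an error, while still carrying the previously visited items when ready.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <vector>

// Hypothetical stand-in for a graph node; the real traversal works on edges.
using Node = int;

// Returns the dependencies that were visited earlier if all of them have been
// visited. Returns std::nullopt if at least one dependency is still pending.
// A pending dependency is not an error: the caller just retries later.
std::optional<std::vector<Node>> isReady(
    const std::vector<Node>& deps,
    const std::set<Node>& visited) {
  std::vector<Node> prev;
  for (Node d : deps) {
    if (visited.count(d) == 0) {
      return std::nullopt;
    }
    prev.push_back(d);
  }
  return prev;
}
```

A caller that only needs a yes/no answer can test `isReady(...).has_value()`, which is why a plain `bool` return would also work as long as nobody consumes the returned list.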

Collaborator

@jjsjann123 left a comment

Thanks for addressing my nitpick and questions.

LGTM

auto tv1 = reshape(tv0, {10}, {2, 5});
auto tv2 = reshape(tv1, {2, 5}, {10});
// Another cycle from tv0 to tv3 and then tv4
auto tv3 = reshape(tv0, {10}, {5, 2});
Collaborator

IIUC, since tv1 and tv3 are not combined by something like add(tv1, tv3), id_model does not map them into the same group, even though the reshape is a static split.

Collaborator Author

tv1's logical shape is {2, 5}, while tv3's is {5, 2}, so they are not mapped.

{graph.toGroup(tv4->getLogicalDomain().at(0)->definition()),
Direction::Forward},
{graph.toGroup(tv2->getLogicalDomain().at(0)->definition()),
Direction::Forward}};
Collaborator

Just confirming here, if we were to map tv1 and tv3 together, we wouldn't have 4 groups here, because in that scenario, they would be merged into only 2 distinct groups.

Collaborator Author

They are not mappable.

//
// A ----> B ----> D
// ^       |
// +-- C <-+
Collaborator

nitpick: I struggle with mapping {tvX} into (A, B, C, D)
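
As an illustration of the graph in the comment above (edges A->B, B->D, B->C, C->A), the sketch below collects every edge that lies on some walk from A to D, including the ones on the cycle. It is an assumption-laden toy, not the FindAllExprs implementation: nodes are plain indices, `reachable` and `allEdgesBetween` are hypothetical helper names, and the approach (forward reachability from the source intersected with backward reachability from the target) is a textbook technique chosen for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// Directed edge (from, to) between node indices.
using Edge = std::pair<std::size_t, std::size_t>;

// BFS reachability from `start`. If `reverse` is true, edges are followed
// backward, which yields the set of nodes that can reach `start`.
std::vector<bool> reachable(
    std::size_t num_nodes,
    const std::vector<Edge>& edges,
    std::size_t start,
    bool reverse) {
  assert(start < num_nodes);
  std::vector<bool> seen(num_nodes, false);
  std::queue<std::size_t> q;
  seen[start] = true;
  q.push(start);
  while (!q.empty()) {
    std::size_t cur = q.front();
    q.pop();
    for (const auto& [u, v] : edges) {
      std::size_t src = reverse ? v : u;
      std::size_t dst = reverse ? u : v;
      if (src == cur && !seen[dst]) {
        seen[dst] = true;
        q.push(dst);
      }
    }
  }
  return seen;
}

// Every edge on some walk from `from` to `to`, cycles included. An edge
// (u, v) qualifies when u is reachable from `from` and `to` is reachable
// from v.
std::vector<Edge> allEdgesBetween(
    std::size_t num_nodes,
    const std::vector<Edge>& edges,
    std::size_t from,
    std::size_t to) {
  std::vector<bool> fwd = reachable(num_nodes, edges, from, false);
  std::vector<bool> bwd = reachable(num_nodes, edges, to, true);
  std::vector<Edge> result;
  for (const Edge& e : edges) {
    if (fwd[e.first] && bwd[e.second]) {
      result.push_back(e);
    }
  }
  return result;
}
```

With A=0, B=1, C=2, D=3, a shortest-path search from A to D would report only A->B and B->D, whereas `allEdgesBetween` also returns B->C and C->A.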

{{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()},
{fusion.zeroVal(), IrBuilder::create<Val>(shape[1] / 2)}});

auto tv3 = sin(tv0);
Collaborator

nitpick: this isn't really RoPE, but that doesn't matter.

another nitpick/question: do the two sin ops add anything to the test? I think we can slice tv0 directly, which would test the same pattern.

Collaborator Author

It is not the entire RoPE but just the rotation part, which is what tends to matter most.

sin is there just to prevent the preseg passes from taking over the slice ops.

Collaborator Author

naoyam commented Mar 14, 2025

!build

naoyam merged commit 1ce3651 into main on Mar 15, 2025
15 of 16 checks passed
naoyam deleted the traverse_all_paths branch on March 15, 2025 00:08
naoyam added a commit that referenced this pull request Apr 25, 2025
This is a follow-up to #3906, which added a WAR (workaround) for #3640.
While it's safe, it turned out to be too conservative. For example, here's a
concat pattern appearing in the backward of Litgpt Llama RoPE:

```
Inputs:
  T0_g___bfloat[bS0{1}, iS1{8}, iS2{4}, iS3{8192}, iS4{128}]
  T1_g___bfloat[bS5{1}, iS6{8}, bS7{1}, iS8{8192}, iS9{128}]
  T2_g___bfloat[bS10{1}, iS11{8}, bS12{1}, iS13{8192}, iS14{128}]
Outputs:
  T8_g___bfloat[bS43{1}, iS44{8192}, iS52{6144}rf]

%kernel_math {
T3_l___bfloat[bS15{1}, iS16{8}, iS18{6}rf, iS19{8192}, iS20{128}]
   = pad( T0_g___bfloat[bS0{1}, iS1{8}, iS2{4}, iS3{8192}, iS4{128}], {0, 0, 0, 0, 0, 2, 0, 0, 0, 0} )
i31 = 0 + 4;
T4_l___bfloat[bS21{1}, iS22{8}, iS24{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS25{8192}, iS26{128}]
   = pad( T1_g___bfloat[bS5{1}, iS6{8}, bS7{1}, iS8{8192}, iS9{128}], {0, 0, 0, 0, i31, 1, 0, 0, 0, 0} )
i47 = i31 + 1;
T5_l___bfloat[bS27{1}, iS28{8}, iS30{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS31{8192}, iS32{128}]
   = pad( T2_g___bfloat[bS10{1}, iS11{8}, bS12{1}, iS13{8192}, iS14{128}], {0, 0, 0, 0, i47, 0, 0, 0, 0, 0} )
T6_l___bfloat[bS33{1}, iS34{8}, iS35{6}, iS36{8192}, iS37{128}]
   = cat( T3_l___bfloat[bS15{1}, iS16{8}, iS18{6}rf, iS19{8192}, iS20{128}], T4_l___bfloat[bS21{1}, iS22{8}, iS24{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS25{8192}, iS26{128}], T5_l___bfloat[bS27{1}, iS28{8}, iS30{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS31{8192}, iS32{128}], 2 )
T7_l___bfloat[bS38{1}, iS41{8192}, iS39{8}, iS40{6}, iS42{128}]
   = Set.Permute( T6_l___bfloat[bS33{1}, iS34{8}, iS35{6}, iS36{8192}, iS37{128}], cache_op=Streaming )
T8_g___bfloat[bS43{1}, iS44{8192}, iS52{6144}rf] = view( T7_l___bfloat[bS38{1}, iS41{8192}, iS39{8}, iS40{6}, iS42{128}] )
} // %kernel_math
```

This is currently taken by the pointwise scheduler, which attempts to
vectorize the innermost ID of the output (i.e., `iS52{6144}`). Since the
resize ops of the three pad ops are reachable from `iS52`, the WAR of
#3640 simply takes them into consideration by calculating gcd with the
left and right expand factors. In this case, since there's an expand
factor of 1, the resulting vectorization factor is also just 1, which is
clearly not what we want. Here, while the resized ID itself is not
vectorizable due to the expand factor of 1, all of the resized tensors
have large enough inner IDs that should allow the maximum vectorization.
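
The arithmetic can be sketched as follows. This is a toy model, not nvFuser code: `ResizeExpansion` and `conservativeVectorizationFactor` are hypothetical names, and the starting factor of 8 assumes 16-byte vector accesses with 2-byte bfloat elements.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical stand-in for a Resize op: its left and right expand factors.
struct ResizeExpansion {
  int64_t left;
  int64_t right;
};

// Folds every expand factor into the candidate factor with gcd. Since
// gcd(x, 0) == x, an expand factor of 0 does not constrain the result,
// while an expand factor of 1 forces the result down to 1.
int64_t conservativeVectorizationFactor(
    int64_t max_factor,
    const std::vector<ResizeExpansion>& resizes) {
  assert(max_factor > 0);
  int64_t factor = max_factor;
  for (const ResizeExpansion& r : resizes) {
    factor = std::gcd(factor, r.left);
    factor = std::gcd(factor, r.right);
  }
  return factor;
}
```

For the three pads above, the (left, right) expand factors on the concatenated dimension are (0, 2), (4, 1), and (5, 0), so the fold goes 8, 2, 1 and stays at 1.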

To make the WAR a little less conservative, this PR also checks if the
constraint by a Resize expr may be missed by the vectorization analysis.
In the above case, that should not happen as there's only one path
through each of the resize-based tensor ops.
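
One way to picture the "only one path" condition, purely as an illustration and not the check this PR implements, is to count producer-to-consumer paths in a small acyclic tensor graph. `Dag` and `countPaths` are hypothetical names, and the input is assumed to be acyclic.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical DAG of tensors: adj[i] lists the consumers of tensor i.
using Dag = std::vector<std::vector<std::size_t>>;

// Number of distinct paths from `from` to `to`. Assumes the graph is acyclic;
// the recursion would not terminate on a cycle.
std::size_t countPaths(const Dag& adj, std::size_t from, std::size_t to) {
  assert(from < adj.size() && to < adj.size());
  if (from == to) {
    return 1;
  }
  std::size_t count = 0;
  for (std::size_t next : adj[from]) {
    count += countPaths(adj, next, to);
  }
  return count;
}
```

Under this picture, a count of 1 means a single traversal sees every constraint on the way, while a count above 1 is the situation where a spanning-tree traversal could skip a path and the conservative gcd would still be needed.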

This change is still not able to eliminate false positives completely.
See one of the new tests that is currently disabled.

The codediff results all seem to make sense. http://nv/eFb. Previously
some of the tests did not have vectorization due to the WAR, which is
relaxed in this PR and allows some vectorization.
Development

Successfully merging this pull request may close these issues.

Vectorization analysis returns wrong Vectorization Factor

2 participants