import enum
import time

from collections import defaultdict
from contextlib import contextmanager
from typing import List
from typing import Optional


class DFSNode(object):
    """Base class for nodes that can be traversed by ``depth_first_search``.

    ``id`` identifies a node during traversal (two nodes with the same ``id``
    are treated as the same vertex), while ``name`` is the key used to merge
    several nodes into one group once the traversal is finished.
    """

    def __init__(self, id, name):
        # ``id`` shadows the builtin, but it is kept because it is part of the
        # constructor's keyword interface.
        self.id = id
        self.name = name

    def reachable(self):
        """Return the nodes directly reachable from this one (none by default)."""
        return []

    def visit(self, parents):
        """Hook called once per node, in topological order, after traversal.

        ``parents`` is the list of nodes from which an edge to this node was
        followed or seen during the traversal. The default does nothing.
        """
        pass

    def __str__(self):
        return str(self.id)


class VisitedState(enum.Enum):
    """Traversal state of a node id inside ``dfs_visit``."""

    Unvisited = 0
    # The node has been entered but not all of its neighbors are processed yet.
    PartiallyVisited = 1
    Visited = 2


def depth_first_search(source, aggregator):
    """Traverse the graph rooted at ``source`` and aggregate nodes by name.

    The nodes reachable from ``source`` are ordered topologically (reverse
    DFS finishing order), ``visit`` is called on each of them with the nodes
    pointing at it, and nodes sharing the same ``name`` are grouped together.

    ``aggregator`` is called once per group as ``aggregator(nodes, children)``
    where ``nodes`` are the grouped nodes (in topological order) and
    ``children`` are all nodes reachable from any member of the group.

    Returns the list of aggregator results, ordered by the first occurrence
    of each name in the topological order.

    The solver uses it as
    ``dict(depth_first_search(PackageNode(root, packages),
    aggregate_package_nodes))`` to map every package to its depth.
    """
    back_edges = defaultdict(list)
    visited = {}
    topo_sorted_nodes = []

    dfs_visit(source, back_edges, visited, topo_sorted_nodes)

    # Combine the nodes by name
    combined_nodes = defaultdict(list)
    name_children = defaultdict(list)
    for node in topo_sorted_nodes:
        node.visit(back_edges[node.id])
        name_children[node.name].extend(node.reachable())
        combined_nodes[node.name].append(node)

    # Emit each group once, at the position of its first member; popping the
    # name makes later members of the same group a no-op.
    combined_topo_sorted_nodes = []
    for node in topo_sorted_nodes:
        if node.name in combined_nodes:
            combined_topo_sorted_nodes.append(combined_nodes.pop(node.name))

    return [
        aggregator(nodes, name_children[nodes[0].name])
        for nodes in combined_topo_sorted_nodes
    ]


def dfs_visit(node, back_edges, visited, sorted_nodes):
    """Depth-first traversal starting at ``node``.

    ``back_edges`` (a mapping that creates lists on demand, e.g.
    ``defaultdict(list)``) receives, for every edge followed or seen, the
    origin node appended under the target's ``id``. ``visited`` maps node ids
    to ``VisitedState`` and is updated in place. The nodes finished by this
    call are prepended to ``sorted_nodes`` in reverse finishing order, i.e.
    in topological order when the graph has no cycle.

    The traversal uses an explicit stack instead of recursion so that long
    dependency chains cannot exceed the interpreter's recursion limit.

    Always returns ``True``.
    """
    settled = (VisitedState.Visited, VisitedState.PartiallyVisited)
    if visited.get(node.id, VisitedState.Unvisited) in settled:
        # Either already handled, or we have a circular dependency.
        # Since the dependencies are resolved we can
        # simply skip it because we already have it
        return True

    finished = []  # nodes in DFS finishing (post) order
    visited[node.id] = VisitedState.PartiallyVisited
    stack = [(node, iter(node.reachable()))]
    while stack:
        current, pending = stack[-1]
        for neighbor in pending:
            back_edges[neighbor.id].append(current)
            if visited.get(neighbor.id, VisitedState.Unvisited) in settled:
                continue
            # Descend into the neighbor; the remaining neighbors of
            # ``current`` are resumed from ``pending`` once it is finished.
            visited[neighbor.id] = VisitedState.PartiallyVisited
            stack.append((neighbor, iter(neighbor.reachable())))
            break
        else:
            # Every neighbor of ``current`` has been processed.
            visited[current.id] = VisitedState.Visited
            finished.append(current)
            stack.pop()

    # A single prepend instead of one ``insert(0, ...)`` per node.
    sorted_nodes[:0] = reversed(finished)
    return True


class PackageNode(DFSNode):
    """A resolved package seen through one particular dependency path.

    The node id is ``(package.name, category, optional)`` and its name is the
    package name, so the same package reached with different category or
    optional flags yields distinct nodes that are later merged by name.
    """

    def __init__(
        self,
        package,
        packages,
        previous=None,
        previous_dep=None,
        dep=None,
        is_activated=True,
    ):
        self.package = package
        self.packages = packages

        self.previous = previous
        self.previous_dep = previous_dep
        self.dep = dep
        self.depth = -1

        if not previous:
            # No parent node: start from the weakest tags.
            self.category = "dev"
            self.optional = True
        else:
            self.category = dep.category
            self.optional = dep.is_optional() and not dep.is_activated()
        if not is_activated:
            self.optional = True
        super(PackageNode, self).__init__(
            (package.name, self.category, self.optional), package.name
        )

    def reachable(self):
        """Return one child node per resolved package required by this one.

        Optional dependencies are skipped for non-root packages unless the
        dependency that led here requested extras; a kept optional dependency
        that no requested extra group activates yields a child flagged as
        optional. A new list of new nodes is built on every call.
        """
        children = []  # type: List[PackageNode]

        if (
            self.previous_dep
            and self.previous_dep is not self.dep
            and self.previous_dep.name == self.dep.name
        ):
            return []

        for dependency in self.package.all_requires:
            is_activated = True
            if dependency.is_optional():
                if not self.package.is_root() and (
                    not self.previous_dep or not self.previous_dep.extras
                ):
                    continue

                is_activated = False
                for group, extra_deps in self.package.extras.items():
                    if self.dep:
                        extras = self.previous_dep.extras
                    elif self.package.is_root():
                        extras = self.package.extras
                    else:
                        extras = []

                    if group in extras and dependency.name in (
                        d.name for d in self.package.extras[group]
                    ):
                        is_activated = True
                        break

            if self.previous and self.previous.package.name == dependency.name:
                # We have a circular dependency.
                # Since the dependencies are resolved we can
                # simply skip it because we already have it
                # N.B. this only catches cycles of length 2;
                # dependency cycles in general are handled by the DFS traversal
                continue

            for pkg in self.packages:
                if pkg.name == dependency.name and dependency.constraint.allows(
                    pkg.version
                ):
                    # If there is already a child with this name
                    # we merge the requirements
                    if any(
                        child.package.name == pkg.name
                        and child.category == dependency.category
                        for child in children
                    ):
                        continue

                    children.append(
                        PackageNode(
                            pkg,
                            self.packages,
                            self,
                            dependency,
                            self.dep or dependency,
                            is_activated=is_activated,
                        )
                    )

        return children

    def visit(self, parents):
        """Set ``depth`` to one more than the deepest of ``parents``."""
        # The root package, which has no parents, is defined as having depth -1
        # So that the root package's top-level dependencies have depth 0.
        self.depth = 1 + max([parent.depth for parent in parents] + [-2])


def aggregate_package_nodes(nodes, children):
    """Merge the nodes of one package and tag the package accordingly.

    ``nodes`` are all ``PackageNode`` instances sharing a package name and
    ``children`` the nodes reachable from them. The merged depth is the
    maximum node depth, the category is ``"main"`` if any node or child is
    ``"main"`` (``"dev"`` otherwise) and the package is optional only if every
    node and child is optional. The result is written back to each node and
    to the package itself.

    Returns a ``(package, depth)`` tuple.
    """
    package = nodes[0].package
    depth = max(node.depth for node in nodes)
    related = children + nodes
    category = "main" if any(node.category == "main" for node in related) else "dev"
    optional = all(node.optional for node in related)
    for node in nodes:
        node.depth = depth
        node.category = category
        node.optional = optional
    package.category = category
    package.optional = optional
    return package, depth