diff --git a/src/poetry/puzzle/solver.py b/src/poetry/puzzle/solver.py index 26810cc2aeb..539f7eeb6f1 100644 --- a/src/poetry/puzzle/solver.py +++ b/src/poetry/puzzle/solver.py @@ -40,7 +40,11 @@ from poetry.repositories import RepositoryPool from poetry.utils.env import Env - MarkerOriginDict = defaultdict[Package, defaultdict[Package, BaseMarker]] + # markers[child_package][parent_package][groups] -> BaseMarker + MarkerOriginDict = defaultdict[ + Package, + defaultdict[Package, defaultdict[frozenset[NormalizedName], BaseMarker]], + ] class Solver: @@ -238,7 +242,9 @@ def depth_first_search( source: PackageNode, ) -> tuple[list[list[PackageNode]], MarkerOriginDict]: back_edges: dict[DFSNodeID, list[PackageNode]] = defaultdict(list) - markers: MarkerOriginDict = defaultdict(lambda: defaultdict(EmptyMarker)) + markers: MarkerOriginDict = defaultdict( + lambda: defaultdict(lambda: defaultdict(EmptyMarker)) + ) visited: set[DFSNodeID] = set() topo_sorted_nodes: list[PackageNode] = [] @@ -272,12 +278,16 @@ def dfs_visit( for out_neighbor in node.reachable(): back_edges[out_neighbor.id].append(node) - marker = markers[out_neighbor.package][node.package] - markers[out_neighbor.package][node.package] = marker.union( + groups = out_neighbor.groups + prev_marker = markers[out_neighbor.package][node.package][groups] + new_marker = ( out_neighbor.marker if node.package.is_root() else out_neighbor.marker.without_extras() ) + markers[out_neighbor.package][node.package][groups] = prev_marker.union( + new_marker + ) dfs_visit(out_neighbor, back_edges, visited, sorted_nodes, markers) sorted_nodes.insert(0, node) @@ -398,20 +408,36 @@ def calculate_markers( transitive_marker: dict[NormalizedName, BaseMarker] = { group: EmptyMarker() for group in transitive_info.groups } - for parent, m in markers[package].items(): + for parent, group_markers in markers[package].items(): parent_info = packages[parent] if parent_info.groups: + # If parent has groups, we need to intersect its per-group + # markers with each edge marker and union into child's groups. if parent_info.groups != set(parent_info.markers): # there is a cycle -> we need one more iteration has_incomplete_markers = True continue for group in parent_info.groups: - transitive_marker[group] = transitive_marker[group].union( - parent_info.markers[group].intersect(m) - ) + for edge_marker in group_markers.values(): + transitive_marker[group] = transitive_marker[ + group + ].union( + parent_info.markers[group].intersect(edge_marker) + ) else: - for group in transitive_info.groups: - transitive_marker[group] = transitive_marker[group].union(m) + # Parent is the root (no groups). Edge markers specify which + # dependency groups the edge belongs to. We should only add + # the edge marker to the corresponding child groups. + for groups, edge_marker in group_markers.items(): + assert groups, ( + f"Package {package.name} at depth {depth} has no groups." + f" All dependencies except for the root package at depth -1 must have groups" + ) + for group in transitive_info.groups: + if group in groups: + transitive_marker[group] = transitive_marker[ + group + ].union(edge_marker) transitive_info.markers = transitive_marker diff --git a/tests/puzzle/test_solver_internals.py b/tests/puzzle/test_solver_internals.py index a7ed32ac4d8..79a8cb7f7c6 100644 --- a/tests/puzzle/test_solver_internals.py +++ b/tests/puzzle/test_solver_internals.py @@ -336,6 +336,50 @@ def test_propagate_markers_for_groups2(package: ProjectPackage, solver: Solver) } +def test_propagate_markers_for_groups_same_dep( + package: ProjectPackage, solver: Solver +) -> None: + a = Package("a", "1") + b = Package("b", "1") + package.add_dependency(dep("a", 'sys_platform == "win32"', groups=["main"])) + package.add_dependency(dep("a", 'sys_platform == "linux"', groups=["dev"])) + a.add_dependency(dep("b", 'python_version == "3.8"')) + + packages = [package, a, b] + result = solver._aggregate_solved_packages(packages) + + assert len(result) == len(packages) + assert result[package].groups == set() + assert result[a].groups == {"main", "dev"} + assert result[b].groups == {"main", "dev"} + assert tm(result[package]) == {} + assert tm(result[a]) == { + "main": 'sys_platform == "win32"', + "dev": 'sys_platform == "linux"', + } + assert tm(result[b]) == { + "main": 'sys_platform == "win32" and python_version == "3.8"', + "dev": 'sys_platform == "linux" and python_version == "3.8"', + } + + +def test_propagate_markers_for_groups_with_extra( + package: ProjectPackage, solver: Solver +) -> None: + a = Package("a", "1") + package.add_dependency(dep("a", groups=["main"], in_extras=["foo"])) + package.add_dependency(dep("a", groups=["dev"])) + + packages = [package, a] + result = solver._aggregate_solved_packages(packages) + + assert len(result) == len(packages) + assert result[package].groups == set() + assert result[a].groups == {"main", "dev"} + assert tm(result[package]) == {} + assert tm(result[a]) == {"main": 'extra == "foo"', "dev": ""} + + def test_propagate_markers_with_cycle(package: ProjectPackage, solver: Solver) -> None: a = Package("a", "1") b = Package("b", "1")