Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions src/poetry/puzzle/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
from poetry.repositories import RepositoryPool
from poetry.utils.env import Env

MarkerOriginDict = defaultdict[Package, defaultdict[Package, BaseMarker]]
# markers[child_package][parent_package][groups] -> BaseMarker
MarkerOriginDict = defaultdict[
Package,
defaultdict[Package, defaultdict[frozenset[NormalizedName], BaseMarker]],
]


class Solver:
Expand Down Expand Up @@ -238,7 +242,9 @@ def depth_first_search(
source: PackageNode,
) -> tuple[list[list[PackageNode]], MarkerOriginDict]:
back_edges: dict[DFSNodeID, list[PackageNode]] = defaultdict(list)
markers: MarkerOriginDict = defaultdict(lambda: defaultdict(EmptyMarker))
markers: MarkerOriginDict = defaultdict(
lambda: defaultdict(lambda: defaultdict(EmptyMarker))
)
visited: set[DFSNodeID] = set()
topo_sorted_nodes: list[PackageNode] = []

Expand Down Expand Up @@ -272,12 +278,16 @@ def dfs_visit(

for out_neighbor in node.reachable():
back_edges[out_neighbor.id].append(node)
marker = markers[out_neighbor.package][node.package]
markers[out_neighbor.package][node.package] = marker.union(
groups = out_neighbor.groups
prev_marker = markers[out_neighbor.package][node.package][groups]
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
new_marker = (
out_neighbor.marker
if node.package.is_root()
else out_neighbor.marker.without_extras()
)
markers[out_neighbor.package][node.package][groups] = prev_marker.union(
new_marker
)
dfs_visit(out_neighbor, back_edges, visited, sorted_nodes, markers)
sorted_nodes.insert(0, node)

Expand Down Expand Up @@ -398,20 +408,36 @@ def calculate_markers(
transitive_marker: dict[NormalizedName, BaseMarker] = {
group: EmptyMarker() for group in transitive_info.groups
}
for parent, m in markers[package].items():
for parent, group_markers in markers[package].items():
parent_info = packages[parent]
if parent_info.groups:
# If parent has groups, we need to intersect its per-group
# markers with each edge marker and union into child's groups.
if parent_info.groups != set(parent_info.markers):
# there is a cycle -> we need one more iteration
has_incomplete_markers = True
continue
for group in parent_info.groups:
transitive_marker[group] = transitive_marker[group].union(
parent_info.markers[group].intersect(m)
)
for edge_marker in group_markers.values():
transitive_marker[group] = transitive_marker[
group
].union(
parent_info.markers[group].intersect(edge_marker)
)
else:
for group in transitive_info.groups:
transitive_marker[group] = transitive_marker[group].union(m)
# Parent is the root (no groups). Edge markers specify which
# dependency groups the edge belongs to. We should only add
# the edge marker to the corresponding child groups.
for groups, edge_marker in group_markers.items():
assert groups, (
f"Package {package.name} at depth {depth} has no groups."
f" All dependencies except for the root package at depth -1 must have groups"
)
for group in transitive_info.groups:
if group in groups:
transitive_marker[group] = transitive_marker[
group
].union(edge_marker)
transitive_info.markers = transitive_marker


Expand Down
44 changes: 44 additions & 0 deletions tests/puzzle/test_solver_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,50 @@ def test_propagate_markers_for_groups2(package: ProjectPackage, solver: Solver)
}


def test_propagate_markers_for_groups_same_dep(
package: ProjectPackage, solver: Solver
) -> None:
a = Package("a", "1")
b = Package("b", "1")
package.add_dependency(dep("a", 'sys_platform == "win32"', groups=["main"]))
package.add_dependency(dep("a", 'sys_platform == "linux"', groups=["dev"]))
a.add_dependency(dep("b", 'python_version == "3.8"'))

packages = [package, a, b]
result = solver._aggregate_solved_packages(packages)

assert len(result) == len(packages)
assert result[package].groups == set()
assert result[a].groups == {"main", "dev"}
assert result[b].groups == {"main", "dev"}
assert tm(result[package]) == {}
assert tm(result[a]) == {
"main": 'sys_platform == "win32"',
"dev": 'sys_platform == "linux"',
}
assert tm(result[b]) == {
"main": 'sys_platform == "win32" and python_version == "3.8"',
"dev": 'sys_platform == "linux" and python_version == "3.8"',
}


def test_propagate_markers_for_groups_with_extra(
package: ProjectPackage, solver: Solver
) -> None:
a = Package("a", "1")
package.add_dependency(dep("a", groups=["main"], in_extras=["foo"]))
package.add_dependency(dep("a", groups=["dev"]))

packages = [package, a]
result = solver._aggregate_solved_packages(packages)

assert len(result) == len(packages)
assert result[package].groups == set()
assert result[a].groups == {"main", "dev"}
assert tm(result[package]) == {}
assert tm(result[a]) == {"main": 'extra == "foo"', "dev": ""}


def test_propagate_markers_with_cycle(package: ProjectPackage, solver: Solver) -> None:
a = Package("a", "1")
b = Package("b", "1")
Expand Down