diff --git a/docs/environment.yaml b/docs/environment.yaml index 35dc9b74c..32ec3f6b3 100644 --- a/docs/environment.yaml +++ b/docs/environment.yaml @@ -6,6 +6,7 @@ dependencies: - autodoc-pydantic <2.0 - gitpython - kartograf >=1.0.0 +- konnektor >=0.2.0 - libsass - lomap2 >=3.0.0 - myst-parser diff --git a/environment.yml b/environment.yml index 142c9337a..1e3e9948d 100644 --- a/environment.yml +++ b/environment.yml @@ -7,6 +7,7 @@ dependencies: - coverage - duecredit<0.10 - kartograf>=1.0.0 + - konnektor~=0.2.0 - lomap2>=3.2.1 - networkx - numpy<2.0.0 diff --git a/news/konnektor_changes.rst b/news/konnektor_changes.rst new file mode 100644 index 000000000..4ff5dcd05 --- /dev/null +++ b/news/konnektor_changes.rst @@ -0,0 +1,25 @@ +**Added:** + +* Added optional ``n_processes`` (number of parallel processes to use when generating the network) arguments for network planners. +* Added optional ``progress`` (whether to show progress bar) for ``openfe.setup.ligand_network_planning.generate_radial_network`` (default=``False``, such that there is no default behavior change). + +**Changed:** + +* `konnektor _` is now used as the backend for all network generation. +* ``openfe.setup.ligand_network_planning.generate_maximal_network`` now returns the *best* mapping for each edge, rather than *all possible* mappings for each edge. If multiple mappers are passed but no scorer, the first mapper passed will be used, and a warning will be raised. + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/openfe/setup/ligand_network_planning.py b/openfe/setup/ligand_network_planning.py index 784c5501d..97bdeed09 100644 --- a/openfe/setup/ligand_network_planning.py +++ b/openfe/setup/ligand_network_planning.py @@ -3,13 +3,8 @@ from pathlib import Path from typing import Iterable, Callable, Optional, Union -import itertools -from collections import Counter -import functools -import warnings import networkx as nx -from tqdm.auto import tqdm from gufe import SmallMoleculeComponent, AtomMapper from openfe.setup import LigandNetwork @@ -19,6 +14,14 @@ from lomap import generate_lomap_network as generate_lomap_network from lomap import LomapAtomMapper from lomap.dbmol import _find_common_core +from konnektor.network_planners import ( + StarNetworkGenerator, + MaximalNetworkGenerator, + RedundantMinimalSpanningTreeNetworkGenerator, + MinimalSpanningTreeNetworkGenerator, + ExplicitNetworkGenerator, +) +from konnektor import network_analysis, network_planners, network_tools def _hasten_lomap(mapper, ligands): @@ -33,9 +36,12 @@ def _hasten_lomap(mapper, ligands): core = "" return LomapAtomMapper( - time=mapper.time, threed=mapper.threed, max3d=mapper.max3d, - element_change=mapper.element_change, seed=core, - shift=mapper.shift + time=mapper.time, + threed=mapper.threed, + max3d=mapper.max3d, + element_change=mapper.element_change, + seed=core, + shift=mapper.shift, ) @@ -44,6 +50,8 @@ def generate_radial_network( central_ligand: Union[SmallMoleculeComponent, str, int], mappers: Union[AtomMapper, Iterable[AtomMapper]], scorer: Optional[Callable[[LigandAtomMapping], float]] = None, + progress: bool = False, + n_processes: int = 1, ) -> LigandNetwork: """ Plan a radial network with all ligands connected to a central node. @@ -66,6 +74,10 @@ def generate_radial_network( a callable which returns a float for any LigandAtomMapping. Used to assign scores to potential mappings; higher scores indicate better mappings. + progress : bool + If True, show a tqdm progress bar. (default=True) + n_processes: int + number of cpu processes to use if parallelizing network generation. Raises ------ @@ -87,64 +99,61 @@ def generate_radial_network( mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) else m for m in mappers] + ligands = list(ligands) + # handle central_ligand arg possibilities # after this, central_ligand is resolved to a SmallMoleculeComponent if isinstance(central_ligand, int): - ligands = list(ligands) try: central_ligand = ligands[central_ligand] + ligands.remove(central_ligand) except IndexError: raise ValueError(f"index '{central_ligand}' out of bounds, there are " f"{len(ligands)} ligands") elif isinstance(central_ligand, str): - ligands = list(ligands) - possibles = [l for l in ligands if l.name == central_ligand] + possibles = [lig for lig in ligands if lig.name == central_ligand] if not possibles: raise ValueError(f"No ligand called '{central_ligand}' " - f"available: {', '.join(l.name for l in ligands)}") + f"available: {', '.join(lig.name for lig in ligands)}") if len(possibles) > 1: raise ValueError(f"Multiple ligands called '{central_ligand}'") central_ligand = possibles[0] + ligands.remove(central_ligand) + + # Construct network + network_planner = StarNetworkGenerator( + mappers=mappers, + scorer=scorer, + progress=progress, + n_processes=n_processes, + ) - edges = [] - - for ligand in ligands: - if ligand == central_ligand: - wmsg = (f"The central_ligand {ligand.name} was also found in " - "the list of ligands to arrange around the " - "central_ligand this will be ignored.") - warnings.warn(wmsg) - continue - best_score = 0.0 - best_mapping = None - - for mapping in itertools.chain.from_iterable( - mapper.suggest_mappings(central_ligand, ligand) - for mapper in mappers - ): - if not scorer: - best_mapping = mapping - break - - score = scorer(mapping) - mapping = mapping.with_annotations({"score": score}) + network = network_planner.generate_ligand_network( + components=ligands, central_component=central_ligand + ) - if score > best_score: - best_mapping = mapping - best_score = score + if network.is_connected(): + connected_nodes = network.nodes + else: + connected_nodes = max(nx.weakly_connected_components(network.graph), key=len) - if best_mapping is None: - raise ValueError(f"No mapping found for {ligand}") - edges.append(best_mapping) + # check for disconnected nodes + missing_nodes = set(ligands + [central_ligand]) - set(connected_nodes) + missing_node_names = [node.name for node in missing_nodes] + if missing_nodes: + raise RuntimeError( + f"ERROR: No mapping found between the central ligand ('{central_ligand.name}') and the following node(s): {missing_node_names}" + ) - return LigandNetwork(edges) + return network def generate_maximal_network( ligands: Iterable[SmallMoleculeComponent], mappers: Union[AtomMapper, Iterable[AtomMapper]], scorer: Optional[Callable[[LigandAtomMapping], float]] = None, - progress: Union[bool, Callable[[Iterable], Iterable]] = True, + progress: bool = True, + n_processes: int = 1, ) -> LigandNetwork: """ Plan a network with all possible proposed mappings. @@ -167,38 +176,27 @@ def generate_maximal_network( but many can be given. scorer : Scoring function any callable which takes a LigandAtomMapping and returns a float - progress : Union[bool, Callable[Iterable], Iterable] - progress bar: if False, no progress bar will be shown. If True, use a - tqdm progress bar that only appears after 1.5 seconds. You can also - provide a custom progress bar wrapper as a callable. + progress : bool + If True, show a tqdm progress bar. (default=True) + n_processes: int + number of cpu processes to use if parallelizing network generation. """ if isinstance(mappers, AtomMapper): mappers = [mappers] mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) else m for m in mappers] - nodes = list(ligands) - if progress is True: - # default is a tqdm progress bar - total = len(nodes) * (len(nodes) - 1) // 2 - progress = functools.partial(tqdm, total=total, delay=1.5) - elif progress is False: - def progress(x): return x - # otherwise, it should be a user-defined callable - - mapping_generator = itertools.chain.from_iterable( - mapper.suggest_mappings(molA, molB) - for molA, molB in progress(itertools.combinations(nodes, 2)) - for mapper in mappers + # Construct network + network_planner = MaximalNetworkGenerator( + mappers=mappers, + scorer=scorer, + progress=progress, + n_processes=n_processes, ) - if scorer: - mappings = [mapping.with_annotations({'score': scorer(mapping)}) - for mapping in mapping_generator] - else: - mappings = list(mapping_generator) - network = LigandNetwork(mappings, nodes=nodes) + network = network_planner.generate_ligand_network(nodes) + return network @@ -207,7 +205,8 @@ def generate_minimal_spanning_network( mappers: Union[AtomMapper, Iterable[AtomMapper]], # TODO: scorer is currently required, but not actually necessary. scorer: Callable[[LigandAtomMapping], float], - progress: Union[bool, Callable[[Iterable], Iterable]] = True, + progress: bool = True, + n_processes: int = 1, ) -> LigandNetwork: """ Plan a network with as few edges as possible with maximum total score @@ -222,44 +221,37 @@ def generate_minimal_spanning_network( highest score edges scorer : Scoring function any callable which takes a LigandAtomMapping and returns a float - progress : Union[bool, Callable[Iterable], Iterable] - progress bar: if False, no progress bar will be shown. If True, use a - tqdm progress bar that only appears after 1.5 seconds. You can also - provide a custom progress bar wrapper as a callable. + progress : bool + If True, show a tqdm progress bar. (default=True) + n_processes: int + number of cpu processes to use if parallelizing network generation. """ if isinstance(mappers, AtomMapper): mappers = [mappers] mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) else m for m in mappers] + nodes = list(ligands) - # First create a network with all the proposed mappings (scored) - network = generate_maximal_network(ligands, mappers, scorer, progress) - - # Flip network scores so we can use minimal algorithm - g2 = nx.MultiGraph() - for e1, e2, d in network.graph.edges(data=True): - g2.add_edge(e1, e2, weight=-d['score'], object=d['object']) - - # Next analyze that network to create minimal spanning network. Because - # we carry the original (directed) LigandAtomMapping, we don't lose - # direction information when converting to an undirected graph. - min_edges = nx.minimum_spanning_edges(g2) - min_mappings = [edge_data['object'] for _, _, _, edge_data in min_edges] - min_network = LigandNetwork(min_mappings) - missing_nodes = set(network.nodes) - set(min_network.nodes) - if missing_nodes: - raise RuntimeError("Unable to create edges to some nodes: " - f"{list(missing_nodes)}") + # Construct network + network_planner = MinimalSpanningTreeNetworkGenerator( + mappers=mappers, + scorer=scorer, + progress=progress, + n_processes=n_processes, + ) - return min_network + network = network_planner.generate_ligand_network(nodes) + + return network def generate_minimal_redundant_network( ligands: Iterable[SmallMoleculeComponent], mappers: Union[AtomMapper, Iterable[AtomMapper]], scorer: Callable[[LigandAtomMapping], float], - progress: Union[bool, Callable[[Iterable], Iterable]] = True, + progress: bool = True, mst_num: int = 2, + n_processes: int = 1, ) -> LigandNetwork: """ Plan a network with a specified amount of redundancy for each node @@ -278,54 +270,40 @@ def generate_minimal_redundant_network( highest score edges scorer : Scoring function any callable which takes a LigandAtomMapping and returns a float - progress : Union[bool, Callable[Iterable], Iterable] - progress bar: if False, no progress bar will be shown. If True, use a - tqdm progress bar that only appears after 1.5 seconds. You can also - provide a custom progress bar wrapper as a callable. - mst_num: int + progress : bool + If True, show a tqdm progress bar. (default=True) + mst_num : int Minimum Spanning Tree number: the number of minimum spanning trees to generate. If two, the second-best edges are included in the returned network. If three, the third-best edges are also included, etc. + n_processes: int + number of threads to use if parallelizing network generation + """ if isinstance(mappers, AtomMapper): mappers = [mappers] mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) else m for m in mappers] + nodes = list(ligands) - # First create a network with all the proposed mappings (scored) - network = generate_maximal_network(ligands, mappers, scorer, progress) - - # Flip network scores so we can use minimal algorithm - g2 = nx.MultiGraph() - for e1, e2, d in network.graph.edges(data=True): - g2.add_edge(e1, e2, weight=-d['score'], object=d['object']) - - # As in .generate_minimal_spanning_network(), use nx to get the minimal - # network. But now also remove those edges from the fully-connected - # network, then get the minimal network again. Add mappings from all - # minimal networks together. - mappings = [] - for _ in range(mst_num): # can increase range here for more redundancy - # get list from generator so that we don't adjust network by calling it: - current_best_edges = list(nx.minimum_spanning_edges(g2)) - - g2.remove_edges_from(current_best_edges) - for _, _, _, edge_data in current_best_edges: - mappings.append(edge_data['object']) - - redund_network = LigandNetwork(mappings) - missing_nodes = set(network.nodes) - set(redund_network.nodes) - if missing_nodes: - raise RuntimeError("Unable to create edges to some nodes: " - f"{list(missing_nodes)}") + # Construct network + network_planner = RedundantMinimalSpanningTreeNetworkGenerator( + mappers=mappers, + scorer=scorer, + progress=progress, + n_redundancy=mst_num, + n_processes=n_processes, + ) + + network = network_planner.generate_ligand_network(nodes) - return redund_network + return network def generate_network_from_names( - ligands: list[SmallMoleculeComponent], - mapper: AtomMapper, - names: list[tuple[str, str]], + ligands: list[SmallMoleculeComponent], + mapper: AtomMapper, + names: list[tuple[str, str]], ) -> LigandNetwork: """ Generate a :class:`.LigandNetwork` by specifying edges as tuples of names. @@ -353,29 +331,21 @@ def generate_network_from_names( if multiple molecules have the same name (this would otherwise be problematic) """ - nm2idx = {l.name: i for i, l in enumerate(ligands)} + nodes = list(ligands) - if len(nm2idx) < len(ligands): - dupes = Counter((l.name for l in ligands)) - dupe_names = [k for k, v in dupes.items() if v > 1] - raise ValueError(f"Duplicate names: {dupe_names}") + network_planner = ExplicitNetworkGenerator(mappers=mapper, scorer=None) - try: - ids = [(nm2idx[nm1], nm2idx[nm2]) for nm1, nm2 in names] - except KeyError: - badnames = [nm for nm in itertools.chain.from_iterable(names) - if nm not in nm2idx] - available = [ligand.name for ligand in ligands] - raise KeyError(f"Invalid name(s) requested {badnames}. " - f"Available: {available}") + network = network_planner.generate_network_from_names( + components=nodes, names=names + ) - return generate_network_from_indices(ligands, mapper, ids) + return network def generate_network_from_indices( - ligands: list[SmallMoleculeComponent], - mapper: AtomMapper, - indices: list[tuple[int, int]], + ligands: list[SmallMoleculeComponent], + mapper: AtomMapper, + indices: list[tuple[int, int]], ) -> LigandNetwork: """ Generate a :class:`.LigandNetwork` by specifying edges as tuples of indices. @@ -400,31 +370,19 @@ def generate_network_from_indices( IndexError if an invalid ligand index is requested """ - edges = [] - - for i, j in indices: - try: - m1, m2 = ligands[i], ligands[j] - except IndexError: - raise IndexError(f"Invalid ligand id, requested {i} {j} " - f"with {len(ligands)} available") - - mapping = next(mapper.suggest_mappings(m1, m2)) - - edges.append(mapping) - - network = LigandNetwork(edges=edges, nodes=ligands) - - if not network.is_connected(): - warnings.warn("Generated network is not fully connected") + nodes = list(ligands) + network_planner = ExplicitNetworkGenerator(mappers=mapper, scorer=None) + network = network_planner.generate_network_from_indices( + components=nodes, indices=indices + ) return network def load_orion_network( - ligands: list[SmallMoleculeComponent], - mapper: AtomMapper, - network_file: Union[str, Path], + ligands: list[SmallMoleculeComponent], + mapper: AtomMapper, + network_file: Union[str, Path], ) -> LigandNetwork: """Load a :class:`.LigandNetwork` from an Orion NES network file. @@ -460,13 +418,16 @@ def load_orion_network( names.append((entry[0], entry[2])) - return generate_network_from_names(ligands, mapper, names) + network_planner = ExplicitNetworkGenerator(mappers=mapper, scorer=None) + network = network_planner.generate_network_from_names(components=ligands, names=names) + + return network def load_fepplus_network( - ligands: list[SmallMoleculeComponent], - mapper: AtomMapper, - network_file: Union[str, Path], + ligands: list[SmallMoleculeComponent], + mapper: AtomMapper, + network_file: Union[str, Path], ) -> LigandNetwork: """Load a :class:`.LigandNetwork` from an FEP+ edges network file. @@ -496,10 +457,13 @@ def load_fepplus_network( for entry in network_lines: if len(entry) != 5 or entry[1] != '#' or entry[3] != '->': errmsg = ("line does not match expected format " - f"hash:hash # name -> name\n" - "line format: {entry}") + "hash:hash # name -> name\n" + f"line format: {entry}") raise KeyError(errmsg) names.append((entry[2], entry[4])) - return generate_network_from_names(ligands, mapper, names) + network_planner = ExplicitNetworkGenerator(mappers=mapper, scorer=None) + network = network_planner.generate_network_from_names(components=ligands, names=names) + + return network diff --git a/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py b/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py index a425c1fa1..20eb89c62 100644 --- a/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py +++ b/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py @@ -33,9 +33,7 @@ def test_rbfe_alchemical_network_planner_call(atom_mapping_basic_test_files, T4_ solvent=SolventComponent(), protein=T4_protein_component, ) - assert isinstance(alchem_network, AlchemicalNetwork) - edges = alchem_network.edges assert len(edges) == 14 # we build 2envs * (8 ligands - 1) = 14 relative edges. diff --git a/openfe/tests/setup/test_network_planning.py b/openfe/tests/setup/test_network_planning.py index 061de4725..747d08ec6 100644 --- a/openfe/tests/setup/test_network_planning.py +++ b/openfe/tests/setup/test_network_planning.py @@ -200,7 +200,7 @@ def test_radial_network_index_error(self, toluene_vs_others, lomap_old_mapper): ligands = [toluene] + others with pytest.raises(ValueError, match="index '2077' out of bounds, there are 8 ligands"): - openfe.setup.ligand_network_planning.generate_radial_network( + _ = openfe.setup.ligand_network_planning.generate_radial_network( ligands=ligands, central_ligand=2077, mappers=lomap_old_mapper, @@ -214,7 +214,7 @@ def test_radial_network_self_central(self, toluene_vs_others, lomap_old_mapper): toluene, others = toluene_vs_others ligands = [toluene] + others - with pytest.warns(UserWarning, match="The central_ligand toluene was also found in the list of ligands"): + with pytest.warns(UserWarning, match="The central component 'toluene' is present in the list of components"): network = openfe.setup.ligand_network_planning.generate_radial_network( ligands=ligands, central_ligand=toluene, @@ -233,6 +233,7 @@ def test_radial_network_self_central(self, toluene_vs_others, lomap_old_mapper): def test_radial_network_with_scorer(self, toluene_vs_others, lomap_old_mapper, simple_scorer): """Test that the scorer chooses the mapper with the best score (in this case, the LOMAP mapper).""" toluene, others = toluene_vs_others + mappers = [BadMapper(), lomap_old_mapper] scorer = simple_scorer @@ -290,7 +291,8 @@ def test_radial_network_no_mapping_failure(self, toluene_vs_others, lomap_old_ma # lomap cannot make a mapping to nimrod, and will return nothing for the (toluene, nimrod) pair nimrod = openfe.SmallMoleculeComponent(mol_from_smiles('N'), name='nimrod') - with pytest.raises(ValueError, match=r'No mapping found for SmallMoleculeComponent\(name=nimrod\)'): + err_str = r"No mapping found between the central ligand \('toluene'\) and the following node\(s\): \['nimrod'\]" + with pytest.raises(RuntimeError, match=err_str): _ = openfe.setup.ligand_network_planning.generate_radial_network( ligands=others + [nimrod], central_ligand=toluene, @@ -334,12 +336,7 @@ def test_generate_maximal_network( ligands_in_network = {mol.name for mol in network.nodes} assert ligands_in_network == expected_names - if extra_mapper: - # two edges per pair of nodes, one for each mapper - edge_count = len(expected_names) * (len(expected_names) - 1) - else: - # one edge per pair of nodes - edge_count = (len(expected_names) * (len(expected_names) - 1)) / 2 + edge_count = len(others) * (len(others) + 1) / 2 assert len(network.edges) == edge_count @@ -350,7 +347,7 @@ def test_generate_maximal_network( else: for edge in network.edges: assert "score" not in edge.annotations - + assert 'score' not in edge.annotations class TestMinimalSpanningNetworkGenerator: @pytest.mark.parametrize("multi_mappers", [False, True]) @@ -433,7 +430,7 @@ def test_minimal_spanning_network_unreachable(self, toluene_vs_others, lomap_old scorer = simple_scorer - with pytest.raises(RuntimeError, match=r"Unable to create edges to some nodes: \[SmallMoleculeComponent\(name=nimrod\)\]"): + with pytest.raises(RuntimeError, match=r"Unable to create edges for the following nodes: \[SmallMoleculeComponent\(name=nimrod\)\]"): _ = openfe.setup.ligand_network_planning.generate_minimal_spanning_network( ligands=others + [toluene, nimrod], mappers=[lomap_old_mapper], @@ -529,8 +526,9 @@ def test_minimal_redundant_network_unreachable(self, toluene_vs_others, lomap_ol scorer = simple_scorer - with pytest.raises(RuntimeError, match=r"Unable to create edges to some nodes: \[SmallMoleculeComponent\(name=nimrod\)\]"): - _ = openfe.setup.ligand_network_planning.generate_minimal_redundant_network( + err_str = r"ERROR: Unable to create edges for the following nodes: \[SmallMoleculeComponent\(name=nimrod\)\]" + with pytest.raises(RuntimeError, match=err_str): + _ = openfe.setup.ligand_network_planning.generate_minimal_spanning_network( ligands=others + [toluene, nimrod], mappers=[lomap_old_mapper], scorer=scorer @@ -540,26 +538,27 @@ class TestGenerateNetworkFromNames: def test_generate_network_from_names(self, atom_mapping_basic_test_files, lomap_old_mapper): ligands = list(atom_mapping_basic_test_files.values()) - requested = [ + requested_names = [ ('toluene', '2-naftanol'), ('2-methylnaphthalene', '2-naftanol'), ] network = openfe.setup.ligand_network_planning.generate_network_from_names( ligands=ligands, - names=requested, + names=requested_names, mapper=lomap_old_mapper, ) + assert len(network.nodes) == len(ligands) + assert len(network.edges) == 2 + expected_node_names = {c.name for c in ligands} actual_node_names = {n.name for n in network.nodes} - assert len(network.nodes) == len(ligands) assert actual_node_names == expected_node_names - assert len(network.edges) == 2 actual_edges = {(e.componentA.name, e.componentB.name) for e in network.edges} - assert set(requested) == actual_edges + assert set(requested_names) == actual_edges def test_generate_network_from_names_bad_name_error(self, atom_mapping_basic_test_files, lomap_old_mapper): ligands = list(atom_mapping_basic_test_files.values()) @@ -622,7 +621,7 @@ def test_network_from_indices_indexerror(self, atom_mapping_basic_test_files, lo requested = [(20, 1), (2, 3)] - with pytest.raises(IndexError, match="Invalid ligand id"): + with pytest.raises(IndexError, match="Invalid ligand index"): _ = openfe.setup.ligand_network_planning.generate_network_from_indices( ligands=ligands, indices=requested,