From fa14170e95db5c0387b280ca297bd021ac48d57e Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 10 Jan 2023 17:52:38 +0800
Subject: [PATCH 01/14] init

---
 .../test_autochunk_openfold_codegen.py        | 113 ++++++++++++++++++
 1 file changed, 113 insertions(+)
 create mode 100644 tests/test_autochunk/test_autochunk_openfold_codegen.py

diff --git a/tests/test_autochunk/test_autochunk_openfold_codegen.py b/tests/test_autochunk/test_autochunk_openfold_codegen.py
new file mode 100644
index 000000000000..02fa07e2ca00
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_openfold_codegen.py
@@ -0,0 +1,113 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.fx
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+from tests.test_autochunk.evoformer.evoformer import evoformer_base
+
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
+
+
+def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
+    # for memory test
+    # torch.cuda.reset_peak_memory_stats()
+    # now_mem = torch.cuda.memory_allocated() / 1024**2
+    # with torch.no_grad():
+    #     node1 = node.clone()
+    #     pair1 = pair.clone()
+    #     gm(node1, pair1)
+    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    # print(
+    #     "autochunk now mem:%.2f max mem:%.2f"
+    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
+    # )
+
+    # test forward
+    with torch.no_grad():
+        non_fx_out = model(node, pair)
+        fx_out = gm(node, pair)
+
+    assert torch.allclose(non_fx_out[0], fx_out[0],
+                          atol=1e-4), "fx_out does not match the original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[0] - fx_out[0]))
+    assert torch.allclose(non_fx_out[1], fx_out[1],
+                          atol=1e-4), "fx_out does not match the original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[1] - fx_out[1]))
+
+
+def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
+    # launch colossalai
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=1,
+        host="localhost",
+        port=free_port(),
+        backend="nccl",
+    )
+
+    # build model and input
+    model = evoformer_base().cuda()
+    node = torch.randn(1, msa_len, pair_len, 256).cuda()
+    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
+
+    # trace the module and replace codegen
+    graph = ColoTracer().trace(
+        model,
+        meta_args={
+            "node": node.to(torch.device("meta")),
+            "pair": pair.to(torch.device("meta")),
+        },
+    )
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
+    interp = MetaInfoProp(gm_prop)
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
+
+    # run the propagation again to record meta info in the graph module (not strictly necessary)
+    gm = torch.fx.GraphModule(model, graph)
+    interp = MetaInfoProp(gm)
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
+
+    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph)
+    gm.recompile()
+
+    # assert that chunk code has been inserted
+    code = graph.python_code("self").src
+    assert "chunk_size" in code
+    # print(code)
+
+    _test_fwd(model, gm, node, pair)
+    gpc.destroy()
+
+
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
+@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
+@pytest.mark.parametrize("msa_len", [32])
+@pytest.mark.parametrize("pair_len", [64])
+def test_autochunk_codegen(msa_len, pair_len, max_memory):
+    run_func = partial(
+        _test_autochunk_codegen,
+        msa_len=msa_len,
+        pair_len=pair_len,
+        max_memory=max_memory,
+    )
+    mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+    _test_autochunk_codegen(0, 32, 64, 25)

From d6b69da1549fb6107e0d39395bdafbacd3b33c13 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Fri, 13 Jan 2023 20:16:24 +0800
Subject: [PATCH 02/14] init origin openfold

---
 .../origin_openfold/__init__.py               |    0
 .../test_autochunk/origin_openfold/dropout.py |   78 +
 .../origin_openfold/embedders.py              |  420 +++++
 .../origin_openfold/embedders_multimer.py     |  352 +++++
 .../origin_openfold/evoformer.py              |  626 ++++++++
 tests/test_autochunk/origin_openfold/heads.py |  231 +++
 tests/test_autochunk/origin_openfold/msa.py   |  384 +++++
 .../origin_openfold/outer_product_mean.py     |  124 ++
 .../origin_openfold/pair_transition.py        |  103 ++
 .../origin_openfold/primitives.py             |  544 +++++++
 .../origin_openfold/structure_module.py       |  914 +++++++++++
 .../origin_openfold/template.py               |  308 ++++
 .../origin_openfold/triangular_attention.py   |  130 ++
 .../triangular_multiplicative_update.py       |  129 ++
 .../utils/all_atom_multimer.py                |  415 +++++
 .../origin_openfold/utils/checkpointing.py    |   86 +
 .../origin_openfold/utils/feats.py            |  302 ++++
 .../utils/geometry/__init__.py                |   26 +
 .../utils/geometry/quat_rigid.py              |   42 +
 .../utils/geometry/rigid_matrix_vector.py     |  156 ++
 .../utils/geometry/rotation_matrix.py         |  162 ++
 .../utils/geometry/test_utils.py              |   86 +
 .../origin_openfold/utils/geometry/utils.py   |   22 +
 .../origin_openfold/utils/geometry/vector.py  |  253 +++
 .../origin_openfold/utils/loss.py             | 1403 +++++++++++++++++
 .../origin_openfold/utils/tensor_utils.py     |  384 +++++
 26 files changed, 7680 insertions(+)
 create mode 100644 tests/test_autochunk/origin_openfold/__init__.py
 create mode 100644 tests/test_autochunk/origin_openfold/dropout.py
 create mode 100644 tests/test_autochunk/origin_openfold/embedders.py
 create mode 100644 tests/test_autochunk/origin_openfold/embedders_multimer.py
 create mode 100644 tests/test_autochunk/origin_openfold/evoformer.py
 create mode 100644 tests/test_autochunk/origin_openfold/heads.py
 create mode 100644 tests/test_autochunk/origin_openfold/msa.py
 create mode 100644 tests/test_autochunk/origin_openfold/outer_product_mean.py
 create mode 100644 tests/test_autochunk/origin_openfold/pair_transition.py
 create mode 100644 tests/test_autochunk/origin_openfold/primitives.py
 create mode 100644 tests/test_autochunk/origin_openfold/structure_module.py
 create mode 100644 tests/test_autochunk/origin_openfold/template.py
 create mode 100644 tests/test_autochunk/origin_openfold/triangular_attention.py
 create mode 100644 tests/test_autochunk/origin_openfold/triangular_multiplicative_update.py
 create mode 100644 tests/test_autochunk/origin_openfold/utils/all_atom_multimer.py
 create mode 100644 tests/test_autochunk/origin_openfold/utils/checkpointing.py
 create mode 100644 tests/test_autochunk/origin_openfold/utils/feats.py
 create mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/__init__.py
 create mode 100644
tests/test_autochunk/origin_openfold/utils/geometry/quat_rigid.py create mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/rigid_matrix_vector.py create mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py create mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/test_utils.py create mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/utils.py create mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/vector.py create mode 100644 tests/test_autochunk/origin_openfold/utils/loss.py create mode 100644 tests/test_autochunk/origin_openfold/utils/tensor_utils.py diff --git a/tests/test_autochunk/origin_openfold/__init__.py b/tests/test_autochunk/origin_openfold/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_autochunk/origin_openfold/dropout.py b/tests/test_autochunk/origin_openfold/dropout.py new file mode 100644 index 000000000000..5e3f8d620498 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/dropout.py @@ -0,0 +1,78 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partialmethod +from typing import List, Union + +import torch +import torch.nn as nn + + +class Dropout(nn.Module): + """ + Implementation of dropout with the ability to share the dropout mask + along a particular dimension. + + If not in training mode, this module computes the identity function. + """ + + def __init__(self, r: float, batch_dim: Union[int, List[int]]): + """ + Args: + r: + Dropout rate + batch_dim: + Dimension(s) along which the dropout mask is shared + """ + super(Dropout, self).__init__() + + self.r = r + if type(batch_dim) == int: + batch_dim = [batch_dim] + self.batch_dim = batch_dim + self.dropout = nn.Dropout(self.r) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + Tensor to which dropout is applied. Can have any shape + compatible with self.batch_dim + """ + shape = list(x.shape) + if self.batch_dim is not None: + for bd in self.batch_dim: + shape[bd] = 1 + mask = x.new_ones(shape) + mask = self.dropout(mask) + x *= mask + return x + + +class DropoutRowwise(Dropout): + """ + Convenience class for rowwise dropout as described in subsection + 1.11.6. + """ + + __init__ = partialmethod(Dropout.__init__, batch_dim=-3) + + +class DropoutColumnwise(Dropout): + """ + Convenience class for columnwise dropout as described in subsection + 1.11.6. + """ + + __init__ = partialmethod(Dropout.__init__, batch_dim=-2) diff --git a/tests/test_autochunk/origin_openfold/embedders.py b/tests/test_autochunk/origin_openfold/embedders.py new file mode 100644 index 000000000000..daeddb084a90 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/embedders.py @@ -0,0 +1,420 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Dict, Tuple + +import torch +import torch.nn as nn + +from .primitives import LayerNorm, Linear +from .template import TemplatePairStack, TemplatePointwiseAttention +from .utils import all_atom_multimer, geometry +from .utils.feats import build_template_angle_feat, build_template_pair_feat +from .utils.tensor_utils import dict_multimap, one_hot, tensor_tree_map + + +class InputEmbedder(nn.Module): + """ + Embeds a subset of the input features. + + Implements Algorithms 3 (InputEmbedder) and 4 (relpos). + """ + + def __init__( + self, + tf_dim: int, + msa_dim: int, + c_z: int, + c_m: int, + relpos_k: int, + **kwargs, + ): + """ + Args: + tf_dim: + Final dimension of the target features + msa_dim: + Final dimension of the MSA features + c_z: + Pair embedding dimension + c_m: + MSA embedding dimension + relpos_k: + Window size used in relative positional encoding + """ + super(InputEmbedder, self).__init__() + + self.tf_dim = tf_dim + self.msa_dim = msa_dim + + self.c_z = c_z + self.c_m = c_m + + self.linear_tf_z_i = Linear(tf_dim, c_z) + self.linear_tf_z_j = Linear(tf_dim, c_z) + self.linear_tf_m = Linear(tf_dim, c_m) + self.linear_msa_m = Linear(msa_dim, c_m) + + # RPE stuff + self.relpos_k = relpos_k + self.no_bins = 2 * relpos_k + 1 + self.linear_relpos = Linear(self.no_bins, c_z) + + def relpos(self, ri: torch.Tensor): + """ + Computes relative positional encodings + + Implements Algorithm 4. + + Args: + ri: + "residue_index" features of shape [*, N] + """ + d = ri[..., None] - ri[..., None, :] + boundaries = torch.arange(start=-self.relpos_k, end=self.relpos_k + 1, device=d.device) + oh = one_hot(d, boundaries).type(ri.dtype) + return self.linear_relpos(oh) + + def forward( + self, + tf: torch.Tensor, + ri: torch.Tensor, + msa: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + tf: + "target_feat" features of shape [*, N_res, tf_dim] + ri: + "residue_index" features of shape [*, N_res] + msa: + "msa_feat" features of shape [*, N_clust, N_res, msa_dim] + Returns: + msa_emb: + [*, N_clust, N_res, C_m] MSA embedding + pair_emb: + [*, N_res, N_res, C_z] pair embedding + + """ + # [*, N_res, c_z] + tf_emb_i = self.linear_tf_z_i(tf) + tf_emb_j = self.linear_tf_z_j(tf) + + # [*, N_res, N_res, c_z] + pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] + pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype)) + + # [*, N_clust, N_res, c_m] + n_clust = msa.shape[-3] + tf_m = (self.linear_tf_m(tf).unsqueeze(-3).expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))) + msa_emb = self.linear_msa_m(msa) + tf_m + + return msa_emb, pair_emb + + +class RecyclingEmbedder(nn.Module): + """ + Embeds the output of an iteration of the model for recycling. + + Implements Algorithm 32. 
+ """ + + def __init__( + self, + c_m: int, + c_z: int, + min_bin: float, + max_bin: float, + no_bins: int, + inf: float = 1e8, + **kwargs, + ): + """ + Args: + c_m: + MSA channel dimension + c_z: + Pair embedding channel dimension + min_bin: + Smallest distogram bin (Angstroms) + max_bin: + Largest distogram bin (Angstroms) + no_bins: + Number of distogram bins + """ + super(RecyclingEmbedder, self).__init__() + + self.c_m = c_m + self.c_z = c_z + self.min_bin = min_bin + self.max_bin = max_bin + self.no_bins = no_bins + self.inf = inf + + self.linear = Linear(self.no_bins, self.c_z) + self.layer_norm_m = LayerNorm(self.c_m) + self.layer_norm_z = LayerNorm(self.c_z) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + m: + First row of the MSA embedding. [*, N_res, C_m] + z: + [*, N_res, N_res, C_z] pair embedding + x: + [*, N_res, 3] predicted C_beta coordinates + Returns: + m: + [*, N_res, C_m] MSA embedding update + z: + [*, N_res, N_res, C_z] pair embedding update + """ + bins = torch.linspace( + self.min_bin, + self.max_bin, + self.no_bins, + dtype=x.dtype, + device=x.device, + requires_grad=False, + ) + + # [*, N, C_m] + m_update = self.layer_norm_m(m) + + # This squared method might become problematic in FP16 mode. + # I'm using it because my homegrown method had a stubborn discrepancy I + # couldn't find in time. + squared_bins = bins**2 + upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1) + d = torch.sum((x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True) + + # [*, N, N, no_bins] + d = ((d > squared_bins) * (d < upper)).type(x.dtype) + + # [*, N, N, C_z] + d = self.linear(d) + z_update = d + self.layer_norm_z(z) + + return m_update, z_update + + +class TemplateEmbedder(nn.Module): + + def __init__(self, config): + super(TemplateEmbedder, self).__init__() + + self.config = config + self.template_angle_embedder = TemplateAngleEmbedder(**config["template_angle_embedder"],) + self.template_pair_embedder = TemplatePairEmbedder(**config["template_pair_embedder"],) + self.template_pair_stack = TemplatePairStack(**config["template_pair_stack"],) + self.template_pointwise_att = TemplatePointwiseAttention(**config["template_pointwise_attention"],) + + def forward(self, batch, z, pair_mask, templ_dim, chunk_size, _mask_trans=True): + # Embed the templates one at a time (with a poor man's vmap) + template_embeds = [] + n_templ = batch["template_aatype"].shape[templ_dim] + for i in range(n_templ): + idx = batch["template_aatype"].new_tensor(i) + single_template_feats = tensor_tree_map( + lambda t: torch.index_select(t, templ_dim, idx), + batch, + ) + + single_template_embeds = {} + if self.config.embed_angles: + template_angle_feat = build_template_angle_feat(single_template_feats,) + + # [*, S_t, N, C_m] + a = self.template_angle_embedder(template_angle_feat) + + single_template_embeds["angle"] = a + + # [*, S_t, N, N, C_t] + t = build_template_pair_feat( + single_template_feats, + use_unit_vector=self.config.use_unit_vector, + inf=self.config.inf, + eps=self.config.eps, + **self.config.distogram, + ).to(z.dtype) + t = self.template_pair_embedder(t) + + single_template_embeds.update({"pair": t}) + + template_embeds.append(single_template_embeds) + + template_embeds = dict_multimap( + partial(torch.cat, dim=templ_dim), + template_embeds, + ) + + # [*, S_t, N, N, C_z] + t = self.template_pair_stack( + template_embeds["pair"], + 
pair_mask.unsqueeze(-3).to(dtype=z.dtype), + chunk_size=chunk_size, + _mask_trans=_mask_trans, + ) + + # [*, N, N, C_z] + t = self.template_pointwise_att( + t, + z, + template_mask=batch["template_mask"].to(dtype=z.dtype), + chunk_size=chunk_size, + ) + t = t * (torch.sum(batch["template_mask"]) > 0) + + ret = {} + if self.config.embed_angles: + ret["template_single_embedding"] = template_embeds["angle"] + + ret.update({"template_pair_embedding": t}) + + return ret + + +class TemplateAngleEmbedder(nn.Module): + """ + Embeds the "template_angle_feat" feature. + + Implements Algorithm 2, line 7. + """ + + def __init__( + self, + c_in: int, + c_out: int, + **kwargs, + ): + """ + Args: + c_in: + Final dimension of "template_angle_feat" + c_out: + Output channel dimension + """ + super(TemplateAngleEmbedder, self).__init__() + + self.c_out = c_out + self.c_in = c_in + + self.linear_1 = Linear(self.c_in, self.c_out, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.c_out, self.c_out, init="relu") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: [*, N_templ, N_res, c_in] "template_angle_feat" features + Returns: + x: [*, N_templ, N_res, C_out] embedding + """ + x = self.linear_1(x) + x = self.relu(x) + x = self.linear_2(x) + + return x + + +class TemplatePairEmbedder(nn.Module): + """ + Embeds "template_pair_feat" features. + + Implements Algorithm 2, line 9. + """ + + def __init__( + self, + c_in: int, + c_out: int, + **kwargs, + ): + """ + Args: + c_in: + + c_out: + Output channel dimension + """ + super(TemplatePairEmbedder, self).__init__() + + self.c_in = c_in + self.c_out = c_out + + # Despite there being no relu nearby, the source uses that initializer + self.linear = Linear(self.c_in, self.c_out, init="relu") + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + [*, C_in] input tensor + Returns: + [*, C_out] output tensor + """ + x = self.linear(x) + + return x + + +class ExtraMSAEmbedder(nn.Module): + """ + Embeds unclustered MSA sequences. + + Implements Algorithm 2, line 15 + """ + + def __init__( + self, + c_in: int, + c_out: int, + **kwargs, + ): + """ + Args: + c_in: + Input channel dimension + c_out: + Output channel dimension + """ + super(ExtraMSAEmbedder, self).__init__() + + self.c_in = c_in + self.c_out = c_out + + self.linear = Linear(self.c_in, self.c_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + [*, N_extra_seq, N_res, C_in] "extra_msa_feat" features + Returns: + [*, N_extra_seq, N_res, C_out] embedding + """ + x = self.linear(x) + + return x diff --git a/tests/test_autochunk/origin_openfold/embedders_multimer.py b/tests/test_autochunk/origin_openfold/embedders_multimer.py new file mode 100644 index 000000000000..6bee17227457 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/embedders_multimer.py @@ -0,0 +1,352 @@ +from functools import partial +from typing import Dict, Tuple + +import torch +import torch.nn as nn + +from .primitives import LayerNorm, Linear +from .template import TemplatePairStack, TemplatePointwiseAttention +from .utils import all_atom_multimer, dgram_from_positions, geometry +from .utils.tensor_utils import dict_multimap, one_hot, tensor_tree_map + + +class InputEmbedderMultimer(nn.Module): + """ + Embeds a subset of the input features. + + Implements Algorithms 3 (InputEmbedder) and 4 (relpos). 
+ """ + + def __init__( + self, + tf_dim: int, + msa_dim: int, + c_z: int, + c_m: int, + max_relative_idx: int, + use_chain_relative: bool, + max_relative_chain: int, + **kwargs, + ): + """ + Args: + tf_dim: + Final dimension of the target features + msa_dim: + Final dimension of the MSA features + c_z: + Pair embedding dimension + c_m: + MSA embedding dimension + relpos_k: + Window size used in relative positional encoding + """ + super(InputEmbedderMultimer, self).__init__() + + self.tf_dim = tf_dim + self.msa_dim = msa_dim + + self.c_z = c_z + self.c_m = c_m + + self.linear_tf_z_i = Linear(tf_dim, c_z) + self.linear_tf_z_j = Linear(tf_dim, c_z) + self.linear_tf_m = Linear(tf_dim, c_m) + self.linear_msa_m = Linear(msa_dim, c_m) + + # RPE stuff + self.max_relative_idx = max_relative_idx + self.use_chain_relative = use_chain_relative + self.max_relative_chain = max_relative_chain + if self.use_chain_relative: + self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2 + else: + self.no_bins = 2 * max_relative_idx + 1 + self.linear_relpos = Linear(self.no_bins, c_z) + + def relpos(self, batch: Dict[str, torch.Tensor]): + pos = batch["residue_index"] + asym_id = batch["asym_id"] + asym_id_same = asym_id[..., None] == asym_id[..., None, :] + offset = pos[..., None] - pos[..., None, :] + + clipped_offset = torch.clamp(offset + self.max_relative_idx, 0, 2 * self.max_relative_idx) + + rel_feats = [] + if self.use_chain_relative: + final_offset = torch.where( + asym_id_same, + clipped_offset, + (2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset), + ) + + rel_pos = torch.nn.functional.one_hot( + final_offset, + 2 * self.max_relative_idx + 2, + ) + + rel_feats.append(rel_pos) + + entity_id = batch["entity_id"] + entity_id_same = entity_id[..., None] == entity_id[..., None, :] + rel_feats.append(entity_id_same[..., None]) + + sym_id = batch["sym_id"] + rel_sym_id = sym_id[..., None] - sym_id[..., None, :] + + max_rel_chain = self.max_relative_chain + clipped_rel_chain = torch.clamp( + rel_sym_id + max_rel_chain, + 0, + 2 * max_rel_chain, + ) + + final_rel_chain = torch.where( + entity_id_same, + clipped_rel_chain, + (2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain), + ) + + rel_chain = torch.nn.functional.one_hot( + final_rel_chain.long(), + 2 * max_rel_chain + 2, + ) + + rel_feats.append(rel_chain) + else: + rel_pos = torch.nn.functional.one_hot( + clipped_offset, + 2 * self.max_relative_idx + 1, + ) + rel_feats.append(rel_pos) + + rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype) + + return self.linear_relpos(rel_feat) + + def forward(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + tf = batch["target_feat"] + msa = batch["msa_feat"] + + # [*, N_res, c_z] + tf_emb_i = self.linear_tf_z_i(tf) + tf_emb_j = self.linear_tf_z_j(tf) + + # [*, N_res, N_res, c_z] + pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] + pair_emb = pair_emb + self.relpos(batch) + + # [*, N_clust, N_res, c_m] + n_clust = msa.shape[-3] + tf_m = (self.linear_tf_m(tf).unsqueeze(-3).expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))) + msa_emb = self.linear_msa_m(msa) + tf_m + + return msa_emb, pair_emb + + +class TemplatePairEmbedderMultimer(nn.Module): + + def __init__( + self, + c_z: int, + c_out: int, + c_dgram: int, + c_aatype: int, + ): + super().__init__() + + self.dgram_linear = Linear(c_dgram, c_out) + self.aatype_linear_1 = Linear(c_aatype, c_out) + self.aatype_linear_2 = Linear(c_aatype, c_out) + 
self.query_embedding_layer_norm = LayerNorm(c_z) + self.query_embedding_linear = Linear(c_z, c_out) + + self.pseudo_beta_mask_linear = Linear(1, c_out) + self.x_linear = Linear(1, c_out) + self.y_linear = Linear(1, c_out) + self.z_linear = Linear(1, c_out) + self.backbone_mask_linear = Linear(1, c_out) + + def forward( + self, + template_dgram: torch.Tensor, + aatype_one_hot: torch.Tensor, + query_embedding: torch.Tensor, + pseudo_beta_mask: torch.Tensor, + backbone_mask: torch.Tensor, + multichain_mask_2d: torch.Tensor, + unit_vector: geometry.Vec3Array, + ) -> torch.Tensor: + act = 0. + + pseudo_beta_mask_2d = (pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]) + pseudo_beta_mask_2d *= multichain_mask_2d + template_dgram *= pseudo_beta_mask_2d[..., None] + act += self.dgram_linear(template_dgram) + act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None]) + + aatype_one_hot = aatype_one_hot.to(template_dgram.dtype) + act += self.aatype_linear_1(aatype_one_hot[..., None, :, :]) + act += self.aatype_linear_2(aatype_one_hot[..., None, :]) + + backbone_mask_2d = (backbone_mask[..., None] * backbone_mask[..., None, :]) + backbone_mask_2d *= multichain_mask_2d + x, y, z = [coord * backbone_mask_2d for coord in unit_vector] + act += self.x_linear(x[..., None]) + act += self.y_linear(y[..., None]) + act += self.z_linear(z[..., None]) + + act += self.backbone_mask_linear(backbone_mask_2d[..., None]) + + query_embedding = self.query_embedding_layer_norm(query_embedding) + act += self.query_embedding_linear(query_embedding) + + return act + + +class TemplateSingleEmbedderMultimer(nn.Module): + + def __init__( + self, + c_in: int, + c_m: int, + ): + super().__init__() + self.template_single_embedder = Linear(c_in, c_m) + self.template_projector = Linear(c_m, c_m) + + def forward( + self, + batch, + atom_pos, + aatype_one_hot, + ): + out = {} + + template_chi_angles, template_chi_mask = (all_atom_multimer.compute_chi_angles( + atom_pos, + batch["template_all_atom_mask"], + batch["template_aatype"], + )) + + template_features = torch.cat( + [ + aatype_one_hot, + torch.sin(template_chi_angles) * template_chi_mask, + torch.cos(template_chi_angles) * template_chi_mask, + template_chi_mask, + ], + dim=-1, + ) + + template_mask = template_chi_mask[..., 0] + + template_activations = self.template_single_embedder(template_features) + template_activations = torch.nn.functional.relu(template_activations) + template_activations = self.template_projector(template_activations,) + + out["template_single_embedding"] = (template_activations) + out["template_mask"] = template_mask + + return out + + +class TemplateEmbedderMultimer(nn.Module): + + def __init__(self, config): + super(TemplateEmbedderMultimer, self).__init__() + + self.config = config + self.template_pair_embedder = TemplatePairEmbedderMultimer(**config["template_pair_embedder"],) + self.template_single_embedder = TemplateSingleEmbedderMultimer(**config["template_single_embedder"],) + self.template_pair_stack = TemplatePairStack(**config["template_pair_stack"],) + + self.linear_t = Linear(config.c_t, config.c_z) + + def forward( + self, + batch, + z, + padding_mask_2d, + templ_dim, + chunk_size, + multichain_mask_2d, + ): + template_embeds = [] + n_templ = batch["template_aatype"].shape[templ_dim] + for i in range(n_templ): + idx = batch["template_aatype"].new_tensor(i) + single_template_feats = tensor_tree_map( + lambda t: torch.index_select(t, templ_dim, idx), + batch, + ) + + single_template_embeds = {} + act = 0. 
+ + template_positions, pseudo_beta_mask = ( + single_template_feats["template_pseudo_beta"], + single_template_feats["template_pseudo_beta_mask"], + ) + + template_dgram = dgram_from_positions( + template_positions, + inf=self.config.inf, + **self.config.distogram, + ) + + aatype_one_hot = torch.nn.functional.one_hot( + single_template_feats["template_aatype"], + 22, + ) + + raw_atom_pos = single_template_feats["template_all_atom_positions"] + + atom_pos = geometry.Vec3Array.from_array(raw_atom_pos) + rigid, backbone_mask = all_atom_multimer.make_backbone_affine( + atom_pos, + single_template_feats["template_all_atom_mask"], + single_template_feats["template_aatype"], + ) + points = rigid.translation + rigid_vec = rigid[..., None].inverse().apply_to_point(points) + unit_vector = rigid_vec.normalized() + + pair_act = self.template_pair_embedder( + template_dgram, + aatype_one_hot, + z, + pseudo_beta_mask, + backbone_mask, + multichain_mask_2d, + unit_vector, + ) + + single_template_embeds["template_pair_embedding"] = pair_act + single_template_embeds.update( + self.template_single_embedder( + single_template_feats, + atom_pos, + aatype_one_hot, + )) + template_embeds.append(single_template_embeds) + + template_embeds = dict_multimap( + partial(torch.cat, dim=templ_dim), + template_embeds, + ) + + # [*, S_t, N, N, C_z] + t = self.template_pair_stack( + template_embeds["template_pair_embedding"], + padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype), + chunk_size=chunk_size, + _mask_trans=False, + ) + # [*, N, N, C_z] + t = torch.sum(t, dim=-4) / n_templ + t = torch.nn.functional.relu(t) + t = self.linear_t(t) + template_embeds["template_pair_embedding"] = t + + return template_embeds diff --git a/tests/test_autochunk/origin_openfold/evoformer.py b/tests/test_autochunk/origin_openfold/evoformer.py new file mode 100644 index 000000000000..ab1f03a950f7 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/evoformer.py @@ -0,0 +1,626 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from .dropout import DropoutColumnwise, DropoutRowwise +from .msa import MSAColumnAttention, MSAColumnGlobalAttention, MSARowAttentionWithPairBias +from .outer_product_mean import OuterProductMean +from .pair_transition import PairTransition +from .primitives import LayerNorm, Linear +from .triangular_attention import TriangleAttentionEndingNode, TriangleAttentionStartingNode +from .triangular_multiplicative_update import TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing +from .utils.checkpointing import checkpoint_blocks, get_checkpoint_fn +from .utils.tensor_utils import chunk_layer + + +class MSATransition(nn.Module): + """ + Feed-forward network applied to MSA activations after attention. 
+ + Implements Algorithm 9 + """ + + def __init__(self, c_m, n): + """ + Args: + c_m: + MSA channel dimension + n: + Factor multiplied to c_m to obtain the hidden channel + dimension + """ + super(MSATransition, self).__init__() + + self.c_m = c_m + self.n = n + + self.layer_norm = LayerNorm(self.c_m) + self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final") + + def _transition(self, m, mask): + m = self.linear_1(m) + m = self.relu(m) + m = self.linear_2(m) * mask + return m + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + { + "m": m, + "mask": mask + }, + chunk_size=chunk_size, + no_batch_dims=len(m.shape[:-2]), + ) + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA activation + mask: + [*, N_seq, N_res, C_m] MSA mask + Returns: + m: + [*, N_seq, N_res, C_m] MSA activation update + """ + + # DISCREPANCY: DeepMind forgets to apply the MSA mask here. + if mask is None: + mask = m.new_ones(m.shape[:-1]) + + # [*, N_seq, N_res, 1] + mask = mask.unsqueeze(-1) + + m = self.layer_norm(m) + + if chunk_size is not None: + m = self._chunk(m, mask, chunk_size) + else: + m = self._transition(m, mask) + + return m + + +class EvoformerBlockCore(nn.Module): + + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + transition_n: int, + pair_dropout: float, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + is_multimer: bool = False, + ): + super(EvoformerBlockCore, self).__init__() + self.is_multimer = is_multimer + self.msa_transition = MSATransition( + c_m=c_m, + n=transition_n, + ) + + self.outer_product_mean = OuterProductMean( + c_m, + c_z, + c_hidden_opm, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + c_z, + c_hidden_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + c_z, + c_hidden_mul, + ) + + self.tri_att_start = TriangleAttentionStartingNode( + c_z, + c_hidden_pair_att, + no_heads_pair, + inf=inf, + ) + self.tri_att_end = TriangleAttentionEndingNode( + c_z, + c_hidden_pair_att, + no_heads_pair, + inf=inf, + ) + + self.pair_transition = PairTransition( + c_z, + transition_n, + ) + + self.ps_dropout_row_layer = DropoutRowwise(pair_dropout) + self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: Optional[int] = None, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # DeepMind doesn't mask these transitions in the source, so _mask_trans + # should be disabled to better approximate the exact activations of + # the original. 
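+        # Note: every update below is residual -- each sub-module returns a
+        # delta that is added back onto m or z, so m keeps its
+        # [*, N_seq, N_res, C_m] shape and z keeps [*, N_res, N_res, C_z]
+        # throughout the block.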
+ msa_trans_mask = msa_mask if _mask_trans else None + pair_trans_mask = pair_mask if _mask_trans else None + + m = m + self.msa_transition(m, mask=msa_trans_mask, chunk_size=chunk_size) + z = z + self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size) + z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask)) + z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask)) + z = z + self.ps_dropout_row_layer(self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)) + z = z + self.ps_dropout_col_layer(self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)) + z = z + self.pair_transition(z, mask=pair_trans_mask, chunk_size=chunk_size) + + return m, z + + +class EvoformerBlock(nn.Module): + + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + inf: float, + eps: float, + is_multimer: bool, + ): + super(EvoformerBlock, self).__init__() + + self.msa_att_row = MSARowAttentionWithPairBias( + c_m=c_m, + c_z=c_z, + c_hidden=c_hidden_msa_att, + no_heads=no_heads_msa, + inf=inf, + ) + + self.msa_att_col = MSAColumnAttention( + c_m, + c_hidden_msa_att, + no_heads_msa, + inf=inf, + ) + + self.msa_dropout_layer = DropoutRowwise(msa_dropout) + + self.core = EvoformerBlockCore( + c_m=c_m, + c_z=c_z, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ) + + self.outer_product_mean = OuterProductMean( + c_m, + c_z, + c_hidden_opm, + ) + self.is_multimer = is_multimer + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: Optional[int] = None, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m = m + self.msa_dropout_layer(self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)) + m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) + m, z = self.core( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + chunk_size=chunk_size, + _mask_trans=_mask_trans, + ) + + return m, z + + +class ExtraMSABlock(nn.Module): + """ + Almost identical to the standard EvoformerBlock, except in that the + ExtraMSABlock uses GlobalAttention for MSA column attention and + requires more fine-grained control over checkpointing. Separated from + its twin to preserve the TorchScript-ability of the latter. 
+ """ + + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + inf: float, + eps: float, + ckpt: bool, + is_multimer: bool, + ): + super(ExtraMSABlock, self).__init__() + + self.ckpt = ckpt + + self.msa_att_row = MSARowAttentionWithPairBias( + c_m=c_m, + c_z=c_z, + c_hidden=c_hidden_msa_att, + no_heads=no_heads_msa, + inf=inf, + ) + + self.msa_att_col = MSAColumnGlobalAttention( + c_in=c_m, + c_hidden=c_hidden_msa_att, + no_heads=no_heads_msa, + inf=inf, + eps=eps, + ) + + self.msa_dropout_layer = DropoutRowwise(msa_dropout) + + self.core = EvoformerBlockCore( + c_m=c_m, + c_z=c_z, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ) + self.is_multimer = is_multimer + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: Optional[int] = None, + _chunk_logits: Optional[int] = 1024, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m = m + self.msa_dropout_layer( + self.msa_att_row( + m.clone(), + z=z.clone(), + mask=msa_mask, + chunk_size=chunk_size, + _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None, + _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False, + )) + + def fn(m, z): + m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) + m, z = self.core(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size) + + return m, z + + if (torch.is_grad_enabled() and self.ckpt): + checkpoint_fn = get_checkpoint_fn() + m, z = checkpoint_fn(fn, m, z) + else: + m, z = fn(m, z) + + return m, z + + +class EvoformerStack(nn.Module): + """ + Main Evoformer trunk. + + Implements Algorithm 6. + """ + + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + c_s: int, + no_heads_msa: int, + no_heads_pair: int, + no_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + blocks_per_ckpt: int, + inf: float, + eps: float, + clear_cache_between_blocks: bool = False, + is_multimer: bool = False, + **kwargs, + ): + """ + Args: + c_m: + MSA channel dimension + c_z: + Pair channel dimension + c_hidden_msa_att: + Hidden dimension in MSA attention + c_hidden_opm: + Hidden dimension in outer product mean module + c_hidden_mul: + Hidden dimension in multiplicative updates + c_hidden_pair_att: + Hidden dimension in triangular attention + c_s: + Channel dimension of the output "single" embedding + no_heads_msa: + Number of heads used for MSA attention + no_heads_pair: + Number of heads used for pair attention + no_blocks: + Number of Evoformer blocks in the stack + transition_n: + Factor by which to multiply c_m to obtain the MSATransition + hidden dimension + msa_dropout: + Dropout rate for MSA activations + pair_dropout: + Dropout used for pair activations + blocks_per_ckpt: + Number of Evoformer blocks in each activation checkpoint + clear_cache_between_blocks: + Whether to clear CUDA's GPU memory cache between blocks of the + stack. 
Slows down each block but can reduce fragmentation + """ + super(EvoformerStack, self).__init__() + + self.blocks_per_ckpt = blocks_per_ckpt + self.clear_cache_between_blocks = clear_cache_between_blocks + + self.blocks = nn.ModuleList() + + for _ in range(no_blocks): + block = EvoformerBlock( + c_m=c_m, + c_z=c_z, + c_hidden_msa_att=c_hidden_msa_att, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + is_multimer=is_multimer, + ) + self.blocks.append(block) + + self.linear = Linear(c_m, c_s) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: int, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + msa_mask: + [*, N_seq, N_res] MSA mask + pair_mask: + [*, N_res, N_res] pair mask + Returns: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + s: + [*, N_res, C_s] single embedding (or None if extra MSA stack) + """ + blocks = [ + partial( + b, + msa_mask=msa_mask, + pair_mask=pair_mask, + chunk_size=chunk_size, + _mask_trans=_mask_trans, + ) for b in self.blocks + ] + + if (self.clear_cache_between_blocks): + + def block_with_cache_clear(block, *args): + torch.cuda.empty_cache() + return block(*args) + + blocks = [partial(block_with_cache_clear, b) for b in blocks] + + m, z = checkpoint_blocks( + blocks, + args=(m, z), + blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, + ) + + s = self.linear(m[..., 0, :, :]) + + return m, z, s + + +class ExtraMSAStack(nn.Module): + """ + Implements Algorithm 18. 
+ """ + + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + no_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + inf: float, + eps: float, + ckpt: bool, + clear_cache_between_blocks: bool = False, + is_multimer: bool = False, + **kwargs, + ): + super(ExtraMSAStack, self).__init__() + + self.clear_cache_between_blocks = clear_cache_between_blocks + self.blocks = nn.ModuleList() + for _ in range(no_blocks): + block = ExtraMSABlock( + c_m=c_m, + c_z=c_z, + c_hidden_msa_att=c_hidden_msa_att, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ckpt=ckpt, + is_multimer=is_multimer, + ) + self.blocks.append(block) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + chunk_size: int, + msa_mask: Optional[torch.Tensor] = None, + pair_mask: Optional[torch.Tensor] = None, + _mask_trans: bool = True, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_extra, N_res, C_m] extra MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + msa_mask: + Optional [*, N_extra, N_res] MSA mask + pair_mask: + Optional [*, N_res, N_res] pair mask + Returns: + [*, N_res, N_res, C_z] pair update + """ + #checkpoint_fn = get_checkpoint_fn() + #blocks = [ + # partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks + #] + + #def dodo(b, *args): + # torch.cuda.empty_cache() + # return b(*args) + + #blocks = [partial(dodo, b) for b in blocks] + + #for b in blocks: + # if(torch.is_grad_enabled()): + # m, z = checkpoint_fn(b, *(m, z)) + # else: + # m, z = b(m, z) + + for b in self.blocks: + m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size) + + if (self.clear_cache_between_blocks): + torch.cuda.empty_cache() + + return z diff --git a/tests/test_autochunk/origin_openfold/heads.py b/tests/test_autochunk/origin_openfold/heads.py new file mode 100644 index 000000000000..57718d5d4a1f --- /dev/null +++ b/tests/test_autochunk/origin_openfold/heads.py @@ -0,0 +1,231 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn + +from .primitives import LayerNorm, Linear +from .utils.loss import compute_plddt, compute_predicted_aligned_error, compute_tm + + +class AuxiliaryHeads(nn.Module): + + def __init__(self, config): + super(AuxiliaryHeads, self).__init__() + + self.plddt = PerResidueLDDTCaPredictor(**config["lddt"],) + + self.distogram = DistogramHead(**config["distogram"],) + + self.masked_msa = MaskedMSAHead(**config["masked_msa"],) + + self.experimentally_resolved = ExperimentallyResolvedHead(**config["experimentally_resolved"],) + + if config.tm.enabled: + self.tm = TMScoreHead(**config.tm,) + + self.config = config + + def forward(self, outputs): + aux_out = {} + lddt_logits = self.plddt(outputs["sm"]["single"]) + aux_out["lddt_logits"] = lddt_logits + + # Required for relaxation later on + aux_out["plddt"] = compute_plddt(lddt_logits) + + distogram_logits = self.distogram(outputs["pair"]) + aux_out["distogram_logits"] = distogram_logits + + masked_msa_logits = self.masked_msa(outputs["msa"]) + aux_out["masked_msa_logits"] = masked_msa_logits + + experimentally_resolved_logits = self.experimentally_resolved(outputs["single"]) + aux_out["experimentally_resolved_logits"] = experimentally_resolved_logits + + if self.config.tm.enabled: + tm_logits = self.tm(outputs["pair"]) + aux_out["tm_logits"] = tm_logits + aux_out["predicted_tm_score"] = compute_tm(tm_logits, **self.config.tm) + aux_out.update(compute_predicted_aligned_error( + tm_logits, + **self.config.tm, + )) + + return aux_out + + +class PerResidueLDDTCaPredictor(nn.Module): + + def __init__(self, no_bins, c_in, c_hidden): + super(PerResidueLDDTCaPredictor, self).__init__() + + self.no_bins = no_bins + self.c_in = c_in + self.c_hidden = c_hidden + + self.layer_norm = LayerNorm(self.c_in) + + self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu") + self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu") + self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final") + + self.relu = nn.ReLU() + + def forward(self, s): + s = self.layer_norm(s) + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + return s + + +class DistogramHead(nn.Module): + """ + Computes a distogram probability distribution. 
+ + For use in computation of distogram loss, subsection 1.9.8 + """ + + def __init__(self, c_z, no_bins, **kwargs): + """ + Args: + c_z: + Input channel dimension + no_bins: + Number of distogram bins + """ + super(DistogramHead, self).__init__() + + self.c_z = c_z + self.no_bins = no_bins + + self.linear = Linear(self.c_z, self.no_bins, init="final") + + def forward(self, z): # [*, N, N, C_z] + """ + Args: + z: + [*, N_res, N_res, C_z] pair embedding + Returns: + [*, N, N, no_bins] distogram probability distribution + """ + # [*, N, N, no_bins] + logits = self.linear(z) + logits = logits + logits.transpose(-2, -3) + return logits + + +class TMScoreHead(nn.Module): + """ + For use in computation of TM-score, subsection 1.9.7 + """ + + def __init__(self, c_z, no_bins, **kwargs): + """ + Args: + c_z: + Input channel dimension + no_bins: + Number of bins + """ + super(TMScoreHead, self).__init__() + + self.c_z = c_z + self.no_bins = no_bins + + self.linear = Linear(self.c_z, self.no_bins, init="final") + + def forward(self, z): + """ + Args: + z: + [*, N_res, N_res, C_z] pairwise embedding + Returns: + [*, N_res, N_res, no_bins] prediction + """ + # [*, N, N, no_bins] + logits = self.linear(z) + return logits + + +class MaskedMSAHead(nn.Module): + """ + For use in computation of masked MSA loss, subsection 1.9.9 + """ + + def __init__(self, c_m, c_out, **kwargs): + """ + Args: + c_m: + MSA channel dimension + c_out: + Output channel dimension + """ + super(MaskedMSAHead, self).__init__() + + self.c_m = c_m + self.c_out = c_out + + self.linear = Linear(self.c_m, self.c_out, init="final") + + def forward(self, m): + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + Returns: + [*, N_seq, N_res, C_out] reconstruction + """ + # [*, N_seq, N_res, C_out] + logits = self.linear(m) + return logits + + +class ExperimentallyResolvedHead(nn.Module): + """ + For use in computation of "experimentally resolved" loss, subsection + 1.9.10 + """ + + def __init__(self, c_s, c_out, **kwargs): + """ + Args: + c_s: + Input channel dimension + c_out: + Number of distogram bins + """ + super(ExperimentallyResolvedHead, self).__init__() + + self.c_s = c_s + self.c_out = c_out + + self.linear = Linear(self.c_s, self.c_out, init="final") + + def forward(self, s): + """ + Args: + s: + [*, N_res, C_s] single embedding + Returns: + [*, N, C_out] logits + """ + # [*, N, C_out] + logits = self.linear(s) + return logits diff --git a/tests/test_autochunk/origin_openfold/msa.py b/tests/test_autochunk/origin_openfold/msa.py new file mode 100644 index 000000000000..4c7714ab73f1 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/msa.py @@ -0,0 +1,384 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
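+# The attention modules in this file bound peak memory by running their
+# inner computation over fixed-size slices of the batch-like dimensions
+# (chunk_layer from .utils.tensor_utils). A minimal sketch of the idea for
+# a single leading batch dimension (the real helper generalizes to several
+# leading dims and re-nests the outputs; chunk_layer_sketch is an
+# illustrative name, not part of this patch):
+#
+#     import torch
+#
+#     def chunk_layer_sketch(fn, inputs, chunk_size):
+#         # inputs: dict of tensors sharing the same leading dim
+#         n = next(iter(inputs.values())).shape[0]
+#         outs = []
+#         for i in range(0, n, chunk_size):
+#             sliced = {k: v[i:i + chunk_size] for k, v in inputs.items()}
+#             outs.append(fn(**sliced))
+#         return torch.cat(outs, dim=0)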
+ +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn + +from .primitives import Attention, GlobalAttention, LayerNorm, Linear, _attention_chunked_trainable +from .utils.checkpointing import get_checkpoint_fn +from .utils.tensor_utils import chunk_layer, flatten_final_dims, permute_final_dims + + +class MSAAttention(nn.Module): + + def __init__( + self, + c_in, + c_hidden, + no_heads, + pair_bias=False, + c_z=None, + inf=1e9, + ): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + pair_bias: + Whether to use pair embedding bias + c_z: + Pair embedding channel dimension. Ignored unless pair_bias + is true + inf: + A large number to be used in computing the attention mask + """ + super(MSAAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.pair_bias = pair_bias + self.c_z = c_z + self.inf = inf + + self.layer_norm_m = LayerNorm(self.c_in) + + self.layer_norm_z = None + self.linear_z = None + if self.pair_bias: + self.layer_norm_z = LayerNorm(self.c_z) + self.linear_z = Linear(self.c_z, self.no_heads, bias=False, init="normal") + + self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + biases: List[torch.Tensor], + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self.mha, + { + "q_x": m, + "kv_x": m, + "biases": biases + }, + chunk_size=chunk_size, + no_batch_dims=len(m.shape[:-2]), + ) + + def _prep_inputs(self, m: torch.Tensor, z: Optional[torch.Tensor], + mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, N_seq, N_res, C_m] + m = self.layer_norm_m(m) + + n_seq, n_res = m.shape[-3:-1] + if mask is None: + # [*, N_seq, N_res] + mask = m.new_ones(m.shape[:-3] + (n_seq, n_res),) + + # [*, N_seq, 1, 1, N_res] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # This step simply returns a larger view of the bias, and does not + # consume additional memory. 
+ # [*, N_seq, no_heads, N_res, N_res] + #bias = bias.expand( + # ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1) + #) + + if (self.pair_bias and z is not None and # For the + self.layer_norm_z is not None and # benefit of + self.linear_z is not None # TorchScript + ): + # [*, N_res, N_res, C_z] + z = self.layer_norm_z(z) + + # [*, N_res, N_res, no_heads] + z = self.linear_z(z) + + # [*, 1, no_heads, N_res, N_res] + z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) + + return m, mask_bias, z + + @torch.jit.ignore + def _chunked_msa_attn( + self, + m: torch.Tensor, + z: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + chunk_logits: int, + checkpoint: bool, + ) -> torch.Tensor: + MSA_DIM = -4 + + def _get_qkv(m, z): + m, mask_bias, z = self._prep_inputs(m, z, mask) + q, k, v = self.mha._prep_qkv(m, m) + return m, q, k, v, mask_bias, z + + checkpoint_fn = get_checkpoint_fn() + + if (torch.is_grad_enabled() and checkpoint): + m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z) + else: + m, q, k, v, mask_bias, z = _get_qkv(m, z) + + o = _attention_chunked_trainable( + query=q, + key=k, + value=v, + biases=[mask_bias, z], + chunk_size=chunk_logits, + chunk_dim=MSA_DIM, + checkpoint=checkpoint, + ) + + if (torch.is_grad_enabled() and checkpoint): + # Storing an additional m here is far from ideal + m = checkpoint_fn(self.mha._wrap_up, o, m) + else: + m = self.mha._wrap_up(o, m) + + return m + + def forward( + self, + m: torch.Tensor, + z: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + _chunk_logits: Optional[int] = None, + _checkpoint_chunks: Optional[bool] = None, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding. Required only if + pair_bias is True + mask: + [*, N_seq, N_res] MSA mask + chunk_size: + Size of chunks into which the inputs are split along their + batch dimensions. A low value decreases memory overhead at the + cost of slower execution. Chunking is not performed by default. + + """ + if (_chunk_logits is not None): + return self._chunked_msa_attn(m=m, + z=z, + mask=mask, + chunk_logits=_chunk_logits, + checkpoint=_checkpoint_chunks) + + m, mask_bias, z = self._prep_inputs(m, z, mask) + + biases = [mask_bias] + if (z is not None): + biases.append(z) + + if chunk_size is not None: + m = self._chunk(m, biases, chunk_size) + else: + m = self.mha(q_x=m, kv_x=m, biases=biases) + + return m + + +class MSARowAttentionWithPairBias(MSAAttention): + """ + Implements Algorithm 7. + """ + + def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9): + """ + Args: + c_m: + Input channel dimension + c_z: + Pair embedding channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + inf: + Large number used to construct attention masks + """ + super(MSARowAttentionWithPairBias, self).__init__( + c_m, + c_hidden, + no_heads, + pair_bias=True, + c_z=c_z, + inf=inf, + ) + + +class MSAColumnAttention(nn.Module): + """ + Implements Algorithm 8. + + By rights, this should also be a subclass of MSAAttention. Alas, + most inheritance isn't supported by TorchScript. 
+ """ + + def __init__(self, c_m, c_hidden, no_heads, inf=1e9): + """ + Args: + c_m: + MSA channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + inf: + Large number used to construct attention masks + """ + super(MSAColumnAttention, self).__init__() + + self.c_m = c_m + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + + self._msa_att = MSAAttention( + c_in=c_m, + c_hidden=c_hidden, + no_heads=no_heads, + pair_bias=False, + c_z=None, + inf=inf, + ) + + def forward(self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + mask: + [*, N_seq, N_res] MSA mask + chunk_size: + Size of chunks into which the inputs are split along their + batch dimensions. A low value decreases memory overhead at the + cost of slower execution. Chunking is not performed by default. + """ + # [*, N_res, N_seq, C_in] + m = m.transpose(-2, -3) + if mask is not None: + mask = mask.transpose(-1, -2) + + m = self._msa_att(m, mask=mask, chunk_size=chunk_size) + + # [*, N_seq, N_res, C_in] + m = m.transpose(-2, -3) + if mask is not None: + mask = mask.transpose(-1, -2) + + return m + + +class MSAColumnGlobalAttention(nn.Module): + + def __init__( + self, + c_in, + c_hidden, + no_heads, + inf=1e9, + eps=1e-10, + ): + super(MSAColumnGlobalAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + self.eps = eps + + self.layer_norm_m = nn.LayerNorm(c_in) + + self.global_attention = GlobalAttention( + c_in=c_in, + c_hidden=c_hidden, + no_heads=no_heads, + inf=inf, + eps=eps, + ) + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + mha_input = { + "m": m, + "mask": mask, + } + return chunk_layer( + self.global_attention, + mha_input, + chunk_size=chunk_size, + no_batch_dims=len(m.shape[:-2]), + ) + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + n_seq, n_res, c_in = m.shape[-3:] + + if mask is None: + # [*, N_seq, N_res] + mask = torch.ones( + m.shape[:-1], + dtype=m.dtype, + device=m.device, + ).detach() + + # [*, N_res, N_seq, C_in] + m = m.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, N_res, N_seq, C_in] + m = self.layer_norm_m(m) + + if chunk_size is not None: + m = self._chunk(m, mask, chunk_size) + else: + m = self.global_attention(m=m, mask=mask) + + # [*, N_seq, N_res, C_in] + m = m.transpose(-2, -3) + + return m diff --git a/tests/test_autochunk/origin_openfold/outer_product_mean.py b/tests/test_autochunk/origin_openfold/outer_product_mean.py new file mode 100644 index 000000000000..074555ad8a9a --- /dev/null +++ b/tests/test_autochunk/origin_openfold/outer_product_mean.py @@ -0,0 +1,124 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn + +from .primitives import Linear +from .utils.tensor_utils import chunk_layer + + +class OuterProductMean(nn.Module): + """ + Implements Algorithm 10. + """ + + def __init__(self, c_m, c_z, c_hidden, eps=1e-3): + """ + Args: + c_m: + MSA embedding channel dimension + c_z: + Pair embedding channel dimension + c_hidden: + Hidden channel dimension + """ + super(OuterProductMean, self).__init__() + + self.c_m = c_m + self.c_z = c_z + self.c_hidden = c_hidden + self.eps = eps + + self.layer_norm = nn.LayerNorm(c_m) + self.linear_1 = Linear(c_m, c_hidden) + self.linear_2 = Linear(c_m, c_hidden) + self.linear_out = Linear(c_hidden**2, c_z, init="final") + + def _opm(self, a, b): + # [*, N_res, N_res, C, C] + outer = torch.einsum("...bac,...dae->...bdce", a, b) + + # [*, N_res, N_res, C * C] + outer = outer.reshape(outer.shape[:-2] + (-1,)) + + # [*, N_res, N_res, C_z] + outer = self.linear_out(outer) + + return outer + + @torch.jit.ignore + def _chunk(self, a: torch.Tensor, b: torch.Tensor, chunk_size: int) -> torch.Tensor: + # Since the "batch dim" in this case is not a true batch dimension + # (in that the shape of the output depends on it), we need to + # iterate over it ourselves + a_reshape = a.reshape((-1,) + a.shape[-3:]) + b_reshape = b.reshape((-1,) + b.shape[-3:]) + out = [] + for a_prime, b_prime in zip(a_reshape, b_reshape): + outer = chunk_layer( + partial(self._opm, b=b_prime), + {"a": a_prime}, + chunk_size=chunk_size, + no_batch_dims=1, + ) + out.append(outer) + outer = torch.stack(out, dim=0) + outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) + + return outer + + def forward(self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + mask: + [*, N_seq, N_res] MSA mask + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + if mask is None: + mask = m.new_ones(m.shape[:-1]) + + # [*, N_seq, N_res, C_m] + m = self.layer_norm(m) + + # [*, N_seq, N_res, C] + mask = mask.unsqueeze(-1) + a = self.linear_1(m) * mask + b = self.linear_2(m) * mask + + a = a.transpose(-2, -3) + b = b.transpose(-2, -3) + + if chunk_size is not None: + outer = self._chunk(a, b, chunk_size) + else: + outer = self._opm(a, b) + + # [*, N_res, N_res, 1] + norm = torch.einsum("...abc,...adc->...bdc", mask, mask) + + # [*, N_res, N_res, C_z] + outer = outer / (self.eps + norm) + + return outer diff --git a/tests/test_autochunk/origin_openfold/pair_transition.py b/tests/test_autochunk/origin_openfold/pair_transition.py new file mode 100644 index 000000000000..9d32adb89b63 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/pair_transition.py @@ -0,0 +1,103 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +import torch.nn as nn + +from .primitives import LayerNorm, Linear +from .utils.tensor_utils import chunk_layer + + +class PairTransition(nn.Module): + """ + Implements Algorithm 15. + """ + + def __init__(self, c_z, n): + """ + Args: + c_z: + Pair transition channel dimension + n: + Factor by which c_z is multiplied to obtain hidden channel + dimension + """ + super(PairTransition, self).__init__() + + self.c_z = c_z + self.n = n + + self.layer_norm = LayerNorm(self.c_z) + self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") + + def _transition(self, z, mask): + # [*, N_res, N_res, C_hidden] + z = self.linear_1(z) + z = self.relu(z) + + # [*, N_res, N_res, C_z] + z = self.linear_2(z) * mask + + return z + + @torch.jit.ignore + def _chunk( + self, + z: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + { + "z": z, + "mask": mask + }, + chunk_size=chunk_size, + no_batch_dims=len(z.shape[:-2]), + ) + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + z: + [*, N_res, N_res, C_z] pair embedding + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + # DISCREPANCY: DeepMind forgets to apply the mask in this module. + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + # [*, N_res, N_res, 1] + mask = mask.unsqueeze(-1) + + # [*, N_res, N_res, C_z] + z = self.layer_norm(z) + + if chunk_size is not None: + z = self._chunk(z, mask, chunk_size) + else: + z = self._transition(z=z, mask=mask) + + return z diff --git a/tests/test_autochunk/origin_openfold/primitives.py b/tests/test_autochunk/origin_openfold/primitives.py new file mode 100644 index 000000000000..5b7556ee3d60 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/primitives.py @@ -0,0 +1,544 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
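+
+# Low-level building blocks shared across the model: Linear with
+# AlphaFold-style initializers, LayerNorm, a non-autocasting softmax,
+# plain/chunked/low-memory attention, and GlobalAttention.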
+ +import math +from functools import partial +from typing import Callable, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn as nn +from scipy.stats import truncnorm + +from .utils.checkpointing import get_checkpoint_fn +from .utils.tensor_utils import _chunk_slice, flatten_final_dims, permute_final_dims + + +def _prod(nums): + out = 1 + for n in nums: + out = out * n + return out + + +def _calculate_fan(linear_weight_shape, fan="fan_in"): + fan_out, fan_in = linear_weight_shape + + if fan == "fan_in": + f = fan_in + elif fan == "fan_out": + f = fan_out + elif fan == "fan_avg": + f = (fan_in + fan_out) / 2 + else: + raise ValueError("Invalid fan option") + + return f + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + f = _calculate_fan(shape, fan) + scale = scale / max(1, f) + a = -2 + b = 2 + std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) + size = _prod(shape) + samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) + samples = np.reshape(samples, shape) + with torch.no_grad(): + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def lecun_normal_init_(weights): + trunc_normal_init_(weights, scale=1.0) + + +def he_normal_init_(weights): + trunc_normal_init_(weights, scale=2.0) + + +def glorot_uniform_init_(weights): + nn.init.xavier_uniform_(weights, gain=1) + + +def final_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def gating_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def normal_init_(weights): + torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found + in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization + "relu": He initialization w/ truncated normal distribution + "glorot": Fan-average Glorot uniform initialization + "gating": Weights=0, Bias=1 + "normal": Normal initialization with std=1/sqrt(fan_in) + "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. + Overrides init if not None. 
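+
+        A small illustrative example of the nonstandard options (shapes are
+        arbitrary):
+
+            gate = Linear(64, 128, init="gating")  # weights 0, bias 1
+            zero = Linear(64, 128, init="final")   # weights 0, bias 0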
+        """
+        super(Linear, self).__init__(in_dim, out_dim, bias=bias)
+
+        if bias:
+            with torch.no_grad():
+                self.bias.fill_(0)
+
+        if init_fn is not None:
+            init_fn(self.weight, self.bias)
+        else:
+            if init == "default":
+                lecun_normal_init_(self.weight)
+            elif init == "relu":
+                he_normal_init_(self.weight)
+            elif init == "glorot":
+                glorot_uniform_init_(self.weight)
+            elif init == "gating":
+                gating_init_(self.weight)
+                if bias:
+                    with torch.no_grad():
+                        self.bias.fill_(1.0)
+            elif init == "normal":
+                normal_init_(self.weight)
+            elif init == "final":
+                final_init_(self.weight)
+            else:
+                raise ValueError("Invalid init string.")
+
+
+class LayerNorm(nn.Module):
+
+    def __init__(self, c_in, eps=1e-5):
+        super(LayerNorm, self).__init__()
+
+        self.c_in = (c_in,)
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(c_in))
+        self.bias = nn.Parameter(torch.zeros(c_in))
+
+    def forward(self, x):
+        out = nn.functional.layer_norm(
+            x,
+            self.c_in,
+            self.weight,
+            self.bias,
+            self.eps,
+        )
+
+        return out
+
+
+@torch.jit.ignore
+def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """
+    Softmax, but without automatic casting to fp32 when the input is of
+    type bfloat16
+    """
+    s = torch.nn.functional.softmax(t, dim=dim)
+
+    return s
+
+
+#@torch.jit.script
+def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
+    # [*, H, Q, C_hidden]
+    query = permute_final_dims(query, (1, 0, 2))
+
+    # [*, H, C_hidden, K]
+    key = permute_final_dims(key, (1, 2, 0))
+
+    # [*, H, V, C_hidden]
+    value = permute_final_dims(value, (1, 0, 2))
+
+    # [*, H, Q, K]
+    a = torch.matmul(query, key)
+
+    for b in biases:
+        a += b
+
+    a = softmax(a, -1)
+
+    # [*, H, Q, C_hidden]
+    a = torch.matmul(a, value)
+
+    # [*, Q, H, C_hidden]
+    a = a.transpose(-2, -3)
+
+    return a
+
+
+@torch.jit.ignore
+def _attention_chunked_trainable(
+    query,
+    key,
+    value,
+    biases,
+    chunk_size,
+    chunk_dim,
+    checkpoint,
+):
+    if (checkpoint and len(biases) > 2):
+        raise ValueError("Checkpointed version permits only two bias terms")
+
+    def _checkpointable_attention(q, k, v, b1, b2):
+        bs = [b for b in [b1, b2] if b is not None]
+        return _attention(q, k, v, bs)
+
+    o_chunks = []
+    checkpoint_fn = get_checkpoint_fn()
+    count = query.shape[chunk_dim]
+    for start in range(0, count, chunk_size):
+        end = start + chunk_size
+        idx = [slice(None)] * len(query.shape)
+        idx[chunk_dim] = slice(start, end)
+        idx_tup = tuple(idx)
+        q_chunk = query[idx_tup]
+        k_chunk = key[idx_tup]
+        v_chunk = value[idx_tup]
+
+        def _slice_bias(b):
+            idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None))
+            return b[tuple(idx)]
+
+        if (checkpoint):
+            bias_1_chunk, bias_2_chunk = [
+                _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2]
+            ]
+
+            o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk)
+        else:
+            bias_chunks = [_slice_bias(b) for b in biases]
+
+            o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+
+        o_chunks.append(o_chunk)
+
+    o = torch.cat(o_chunks, dim=chunk_dim)
+    return o
+
+
+class Attention(nn.Module):
+    """
+    Standard multi-head attention using AlphaFold's default layer
+    initialization. Allows multiple bias vectors.
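+
+    A minimal self-attention sketch (dimensions are assumptions):
+
+        mha = Attention(c_q=64, c_k=64, c_v=64, c_hidden=16, no_heads=4)
+        x = torch.randn(2, 10, 64)  # [*, Q, C_q]
+        out = mha(q_x=x, kv_x=x)    # [*, Q, C_q]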
+ """ + + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super(Attention, self).__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. + + self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final") + + self.linear_g = None + if self.gating: + self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating") + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: + if (self.linear_g is not None): + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_lma: bool = False, + q_chunk_size: Optional[int] = None, + kv_chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_lma: + Whether to use low-memory attention + q_chunk_size: + Query chunk size (for LMA) + kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if (biases is None): + biases = [] + if (use_lma and (q_chunk_size is None or kv_chunk_size is None)): + raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must " + "be provided") + + q, k, v = self._prep_qkv(q_x, kv_x) + + if (use_lma): + biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases] + + o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) + else: + o = _attention(q, k, v, biases) + + o = self._wrap_up(o, q_x) + + return o + + +class GlobalAttention(nn.Module): + + def __init__(self, c_in, c_hidden, no_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + self.eps = eps + + self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot") + + self.linear_k = Linear( + c_in, + c_hidden, + bias=False, + init="glorot", + ) + self.linear_v = Linear( + c_in, + c_hidden, + 
bias=False, + init="glorot", + ) + self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") + self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") + + self.sigmoid = nn.Sigmoid() + + def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # [*, N_res, C_in] + q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps) + + # [*, N_res, H * C_hidden] + q = self.linear_q(q) + q *= (self.c_hidden**(-0.5)) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, C_hidden] + k = self.linear_k(m) + v = self.linear_v(m) + + # [*, N_res, H, N_seq] + a = torch.matmul( + q, + k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] + ) + bias = (self.inf * (mask - 1))[..., :, None, :] + a += bias + a = softmax(a) + + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, + v, + ) + + # [*, N_res, N_seq, C_hidden] + g = self.sigmoid(self.linear_g(m)) + + # [*, N_res, N_seq, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, H, C_hidden] + o = o.unsqueeze(-3) * g + + # [*, N_res, N_seq, H * C_hidden] + o = o.reshape(o.shape[:-2] + (-1,)) + + # [*, N_res, N_seq, C_in] + m = self.linear_o(o) + + return m + + +def _lma( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + biases: List[torch.Tensor], + q_chunk_size: int, + kv_chunk_size: int, +): + no_q, no_kv = q.shape[-3], k.shape[-3] + + # [*, Q, H, C_hidden] + o = q.new_zeros(q.shape) + for q_s in range(0, no_q, q_chunk_size): + q_chunk = q[..., q_s:q_s + q_chunk_size, :, :] + large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases] + + maxes = [] + weights = [] + values = [] + for kv_s in range(0, no_kv, kv_chunk_size): + k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :] + v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :] + small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks] + + a = torch.einsum( + "...qhd,...khd->...hqk", + q_chunk, + k_chunk, + ) + + for b in small_bias_chunks: + a += b + + a = a.transpose(-2, -3) + + max_a = torch.max(a, dim=-1, keepdim=True)[0] + exp_a = torch.exp(a - max_a) + exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) + + maxes.append(max_a.detach().squeeze(-1)) + weights.append(torch.sum(exp_a, dim=-1)) + values.append(exp_v) + + chunk_max = torch.stack(maxes, dim=-3) + chunk_weights = torch.stack(weights, dim=-3) + chunk_values = torch.stack(values, dim=-4) + + global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= max_diffs.unsqueeze(-1) + chunk_weights *= max_diffs + + all_values = torch.sum(chunk_values, dim=-4) + all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) + + q_chunk_out = all_values / all_weights + + o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out + + return o diff --git a/tests/test_autochunk/origin_openfold/structure_module.py b/tests/test_autochunk/origin_openfold/structure_module.py new file mode 100644 index 000000000000..da3b98202a26 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/structure_module.py @@ -0,0 +1,914 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from fastfold.common.residue_constants import (
+    restype_atom14_mask,
+    restype_atom14_rigid_group_positions,
+    restype_atom14_to_rigid_group,
+    restype_rigid_group_default_frame,
+)
+from fastfold.utils.feats import frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames
+from fastfold.utils.geometry.quat_rigid import QuatRigid
+from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array
+from fastfold.utils.geometry.vector import Vec3Array
+from fastfold.utils.rigid_utils import Rigid, Rotation
+from fastfold.utils.tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims
+
+from .primitives import LayerNorm, Linear, ipa_point_weights_init_
+
+
+class AngleResnetBlock(nn.Module):
+
+    def __init__(self, c_hidden):
+        """
+        Args:
+            c_hidden:
+                Hidden channel dimension
+        """
+        super(AngleResnetBlock, self).__init__()
+
+        self.c_hidden = c_hidden
+
+        self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
+        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")
+
+        self.relu = nn.ReLU()
+
+    def forward(self, a: torch.Tensor) -> torch.Tensor:
+
+        s_initial = a
+
+        a = self.relu(a)
+        a = self.linear_1(a)
+        a = self.relu(a)
+        a = self.linear_2(a)
+
+        return a + s_initial
+
+
+class AngleResnet(nn.Module):
+    """
+    Implements Algorithm 20, lines 11-14
+    """
+
+    def __init__(self, c_in: int, c_hidden: int, no_blocks: int, no_angles: int, epsilon: float):
+        """
+        Args:
+            c_in:
+                Input channel dimension
+            c_hidden:
+                Hidden channel dimension
+            no_blocks:
+                Number of resnet blocks
+            no_angles:
+                Number of torsion angles to generate
+            epsilon:
+                Small constant for normalization
+        """
+        super(AngleResnet, self).__init__()
+
+        self.c_in = c_in
+        self.c_hidden = c_hidden
+        self.no_blocks = no_blocks
+        self.no_angles = no_angles
+        self.eps = epsilon
+
+        self.linear_in = Linear(self.c_in, self.c_hidden)
+        self.linear_initial = Linear(self.c_in, self.c_hidden)
+
+        self.layers = nn.ModuleList()
+        for _ in range(self.no_blocks):
+            layer = AngleResnetBlock(c_hidden=self.c_hidden)
+            self.layers.append(layer)
+
+        self.linear_out = Linear(self.c_hidden, self.no_angles * 2)
+
+        self.relu = nn.ReLU()
+
+    def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            s:
+                [*, C_hidden] single embedding
+            s_initial:
+                [*, C_hidden] single embedding as of the start of the
+                StructureModule
+        Returns:
+            A tuple of two [*, no_angles, 2] tensors: the raw (unnormalized)
+            predicted angle vectors and the same vectors normalized to the
+            unit circle
+        """
+        # NOTE: The ReLU's applied to the inputs are absent from the supplement
+        # pseudocode but present in the source. For maximal compatibility with
+        # the pretrained weights, I'm going with the source.
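+        # (Each of the no_angles outputs is an unnormalized 2-vector that the
+        # code below rescales onto the unit circle, so it can be read as a
+        # (sine, cosine)-style encoding of a torsion angle; both versions are
+        # returned.)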
+ + # [*, C_hidden] + s_initial = self.relu(s_initial) + s_initial = self.linear_initial(s_initial) + s = self.relu(s) + s = self.linear_in(s) + s = s + s_initial + + for l in self.layers: + s = l(s) + + s = self.relu(s) + + # [*, no_angles * 2] + s = self.linear_out(s) + + # [*, no_angles, 2] + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt(torch.clamp( + torch.sum(s**2, dim=-1, keepdim=True), + min=self.eps, + )) + s = s / norm_denom + + return unnormalized_s, s + + +class PointProjection(nn.Module): + + def __init__( + self, + c_hidden: int, + num_points: int, + no_heads: int, + return_local_points: bool = False, + ): + super().__init__() + self.return_local_points = return_local_points + self.no_heads = no_heads + + self.linear = Linear(c_hidden, no_heads * 3 * num_points) + + def forward( + self, + activations: torch.Tensor, + rigids: Rigid3Array, + ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]: + # TODO: Needs to run in high precision during training + points_local = self.linear(activations) + points_local = points_local.reshape( + *points_local.shape[:-1], + self.no_heads, + -1, + ) + points_local = torch.split(points_local, points_local.shape[-1] // 3, dim=-1) + points_local = Vec3Array(*points_local) + points_global = rigids[..., None, None].apply_to_point(points_local) + + if self.return_local_points: + return points_global, points_local + + return points_global + + +class InvariantPointAttention(nn.Module): + """ + Implements Algorithm 22. + """ + + def __init__( + self, + c_s: int, + c_z: int, + c_hidden: int, + no_heads: int, + no_qk_points: int, + no_v_points: int, + inf: float = 1e5, + eps: float = 1e-8, + is_multimer: bool = False, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_hidden: + Hidden channel dimension + no_heads: + Number of attention heads + no_qk_points: + Number of query/key points to generate + no_v_points: + Number of value points to generate + """ + super(InvariantPointAttention, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_hidden = c_hidden + self.no_heads = no_heads + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.inf = inf + self.eps = eps + self.is_multimer = is_multimer + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. 
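+        # The two branches below compute equivalent projections in different
+        # layouts: the monomer path fuses K and V into a single linear_kv and
+        # emits points as flat (..., H * P * 3) vectors, while the multimer
+        # path keeps separate K/V linears and uses PointProjection modules
+        # that return Vec3Array points directly.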
+ if not self.is_multimer: + hc = self.c_hidden * self.no_heads + self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer)) + self.linear_kv = Linear(self.c_s, 2 * hc) + + hpq = self.no_heads * self.no_qk_points * 3 + self.linear_q_points = Linear(self.c_s, hpq) + + hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3 + self.linear_kv_points = Linear(self.c_s, hpkv) + + # hpv = self.no_heads * self.no_v_points * 3 + + else: + hc = self.c_hidden * self.no_heads + self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer)) + self.linear_q_points = PointProjection(self.c_s, self.no_qk_points, self.no_heads) + + self.linear_k = Linear(self.c_s, hc, bias=False) + self.linear_v = Linear(self.c_s, hc, bias=False) + self.linear_k_points = PointProjection( + self.c_s, + self.no_qk_points, + self.no_heads, + ) + + self.linear_v_points = PointProjection( + self.c_s, + self.no_v_points, + self.no_heads, + ) + self.linear_b = Linear(self.c_z, self.no_heads) + + self.head_weights = nn.Parameter(torch.zeros((no_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = self.no_heads * (self.c_z + self.c_hidden + self.no_v_points * 4) + self.linear_out = Linear(concat_out_dim, self.c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: torch.Tensor, + r: Union[Rigid, Rigid3Array], + mask: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + ####################################### + # Generate scalar and point activations + ####################################### + + # The following two blocks are equivalent + # They're separated only to preserve compatibility with old AF weights + if self.is_multimer: + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, P_qk] + q_pts = self.linear_q_points(s, r) + # [*, N_res, H * C_hidden] + k = self.linear_k(s) + v = self.linear_v(s) + + # [*, N_res, H, C_hidden] + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, P_qk, 3] + k_pts = self.linear_k_points(s, r) + + # [*, N_res, H, P_v, 3] + v_pts = self.linear_v_points(s, r) + else: + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.c_hidden, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view(q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = r[..., None].apply(kv_pts) + + # [*, N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) 
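+
+            # kv_pts packs the query-key and value points together; the split
+            # below separates the first no_qk_points from the remaining
+            # no_v_points along the points dimension.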
+ + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split(kv_pts, [self.no_qk_points, self.no_v_points], dim=-2) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z) + + # [*, H, N_res, N_res] + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + a *= math.sqrt(1.0 / (3 * self.c_hidden)) + a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) + + if self.is_multimer: + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :] + # [*, N_res, N_res, H, P_q] + pt_att = sum([c**2 for c in pt_att]) + else: + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_att**2 + # [*, N_res, N_res, H, P_q] + pt_att = sum(torch.unbind(pt_att, dim=-1)) + + head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1))) + head_weights = head_weights * math.sqrt(1.0 / (3 * (self.no_qk_points * 9.0 / 2))) + pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # As DeepMind explains, this manual matmul ensures that the operation + # happens in float32. + if self.is_multimer: + # [*, N_res, H, P_v] + o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1) + o_pt = o_pt.sum(dim=-3) + + # [*, N_res, H, P_v] + o_pt = r[..., None, None].apply_inverse_to_point(o_pt) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,)) + + # [*, N_res, H * P_v] + o_pt_norm = o_pt.norm(self.eps) + else: + # [*, H, 3, N_res, P_v] + o_pt = torch.sum( + (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.eps), 2) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + + # [*, N_res, H, C_z] + o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + if self.is_multimer: + s = self.linear_out(torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype)) + else: + s = self.linear_out( + torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype)) + + return s + + +class BackboneUpdate(nn.Module): + """ + Implements part of Algorithm 23. 
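+
+    A shape-only sketch (c_s and lengths chosen arbitrarily):
+
+        bb = BackboneUpdate(c_s=384)
+        s = torch.randn(1, 64, 384)  # [*, N_res, C_s]
+        update = bb(s)               # [*, N_res, 6] rigid update vector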
+    """
+
+    def __init__(self, c_s: int):
+        """
+        Args:
+            c_s:
+                Single representation channel dimension
+        """
+        super(BackboneUpdate, self).__init__()
+
+        self.c_s = c_s
+
+        self.linear = Linear(self.c_s, 6, init="final")
+
+    def forward(self, s: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            s:
+                [*, N_res, C_s] single representation
+        Returns:
+            [*, N_res, 6] update vector
+        """
+        # [*, 6]
+        update = self.linear(s)
+
+        return update
+
+
+class StructureModuleTransitionLayer(nn.Module):
+
+    def __init__(self, c: int):
+        super(StructureModuleTransitionLayer, self).__init__()
+
+        self.c = c
+
+        self.linear_1 = Linear(self.c, self.c, init="relu")
+        self.linear_2 = Linear(self.c, self.c, init="relu")
+        self.linear_3 = Linear(self.c, self.c, init="final")
+
+        self.relu = nn.ReLU()
+
+    def forward(self, s: torch.Tensor):
+        s_initial = s
+        s = self.linear_1(s)
+        s = self.relu(s)
+        s = self.linear_2(s)
+        s = self.relu(s)
+        s = self.linear_3(s)
+
+        s = s + s_initial
+
+        return s
+
+
+class StructureModuleTransition(nn.Module):
+
+    def __init__(self, c: int, num_layers: int, dropout_rate: float):
+        super(StructureModuleTransition, self).__init__()
+
+        self.c = c
+        self.num_layers = num_layers
+        self.dropout_rate = dropout_rate
+
+        self.layers = nn.ModuleList()
+        for _ in range(self.num_layers):
+            l = StructureModuleTransitionLayer(self.c)
+            self.layers.append(l)
+
+        self.dropout = nn.Dropout(self.dropout_rate)
+        self.layer_norm = LayerNorm(self.c)
+
+    def forward(self, s: torch.Tensor) -> torch.Tensor:
+        for l in self.layers:
+            s = l(s)
+
+        s = self.dropout(s)
+        s = self.layer_norm(s)
+
+        return s
+
+
+class StructureModule(nn.Module):
+
+    def __init__(
+        self,
+        c_s: int,
+        c_z: int,
+        c_ipa: int,
+        c_resnet: int,
+        no_heads_ipa: int,
+        no_qk_points: int,
+        no_v_points: int,
+        dropout_rate: float,
+        no_blocks: int,
+        no_transition_layers: int,
+        no_resnet_blocks: int,
+        no_angles: int,
+        trans_scale_factor: float,
+        epsilon: float,
+        inf: float,
+        is_multimer: bool = False,
+        **kwargs,
+    ):
+        """
+        Args:
+            c_s:
+                Single representation channel dimension
+            c_z:
+                Pair representation channel dimension
+            c_ipa:
+                IPA hidden channel dimension
+            c_resnet:
+                Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
+            no_heads_ipa:
+                Number of IPA heads
+            no_qk_points:
+                Number of query/key points to generate during IPA
+            no_v_points:
+                Number of value points to generate during IPA
+            dropout_rate:
+                Dropout rate used throughout the layer
+            no_blocks:
+                Number of structure module blocks
+            no_transition_layers:
+                Number of layers in the single representation transition
+                (Alg. 
23 lines 8-9) + no_resnet_blocks: + Number of blocks in the angle resnet + no_angles: + Number of angles to generate in the angle resnet + trans_scale_factor: + Scale of single representation transition hidden dimension + epsilon: + Small number used in angle resnet normalization + inf: + Large number used for attention masking + is_multimer: + whether running under multimer mode + """ + super(StructureModule, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_ipa = c_ipa + self.c_resnet = c_resnet + self.no_heads_ipa = no_heads_ipa + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.dropout_rate = dropout_rate + self.no_blocks = no_blocks + self.no_transition_layers = no_transition_layers + self.no_resnet_blocks = no_resnet_blocks + self.no_angles = no_angles + self.trans_scale_factor = trans_scale_factor + self.epsilon = epsilon + self.inf = inf + self.is_multimer = is_multimer + + # To be lazily initialized later + self.default_frames = None + self.group_idx = None + self.atom_mask = None + self.lit_positions = None + + self.layer_norm_s = LayerNorm(self.c_s) + self.layer_norm_z = LayerNorm(self.c_z) + + self.linear_in = Linear(self.c_s, self.c_s) + + self.ipa = InvariantPointAttention( + self.c_s, + self.c_z, + self.c_ipa, + self.no_heads_ipa, + self.no_qk_points, + self.no_v_points, + inf=self.inf, + eps=self.epsilon, + is_multimer=self.is_multimer, + ) + + self.ipa_dropout = nn.Dropout(self.dropout_rate) + self.layer_norm_ipa = LayerNorm(self.c_s) + + self.transition = StructureModuleTransition( + self.c_s, + self.no_transition_layers, + self.dropout_rate, + ) + + if is_multimer: + self.bb_update = QuatRigid(self.c_s, full_quat=False) + else: + self.bb_update = BackboneUpdate(self.c_s) + + self.angle_resnet = AngleResnet( + self.c_s, + self.c_resnet, + self.no_resnet_blocks, + self.no_angles, + self.epsilon, + ) + + def _forward_monomer( + self, + s: torch.Tensor, + z: torch.Tensor, + aatype: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Dict[str, Any]: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(z) + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid.identity( + s.shape[:-1], + s.dtype, + s.device, + self.training, + fmt="quat", + ) + outputs = [] + for i in range(self.no_blocks): + # [*, N, C_s] + s = s + self.ipa(s, z, rigids, mask) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + rigids = rigids.compose_q_update_vec(self.bb_update(s)) + + # To hew as closely as possible to AlphaFold, we convert our + # quaternion-based transformations to rotation-matrix ones + # here + backb_to_global = Rigid( + Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None), + rigids.get_trans(), + ) + + backb_to_global = backb_to_global.scale_translation(self.trans_scale_factor) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames( + backb_to_global, + angles, + aatype, + ) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos( + all_frames_to_global, + aatype, + ) + + scaled_rigids = 
rigids.scale_translation(self.trans_scale_factor) + + preds = { + "frames": scaled_rigids.to_tensor_7(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz, + } + + outputs.append(preds) + + if i < (self.no_blocks - 1): + rigids = rigids.stop_rot_gradient() + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def _forward_multimer( + self, + s: torch.Tensor, + z: torch.Tensor, + aatype: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Dict[str, Any]: + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(z) + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid3Array.identity( + s.shape[:-1], + s.device, + ) + outputs = [] + for i in range(self.no_blocks): + # [*, N, C_s] + s = s + self.ipa(s, z, rigids, mask) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + rigids = rigids @ self.bb_update(s) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames( + rigids.scale_translation(self.trans_scale_factor), + angles, + aatype, + ) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos( + all_frames_to_global, + aatype, + ) + + preds = { + "frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz.to_tensor(), + } + + outputs.append(preds) + + if i < (self.no_blocks - 1): + rigids = rigids.stop_rot_gradient() + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def forward( + self, + s: torch.Tensor, + z: torch.Tensor, + aatype: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ): + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + if self.is_multimer: + outputs = self._forward_multimer(s, z, aatype, mask) + else: + outputs = self._forward_monomer(s, z, aatype, mask) + + return outputs + + def _init_residue_constants(self, float_dtype: torch.dtype, device: torch.device): + if self.default_frames is None: + self.default_frames = torch.tensor( + restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.group_idx is None: + self.group_idx = torch.tensor( + restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ) + if self.atom_mask is None: + self.atom_mask = torch.tensor( + restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.lit_positions is None: + self.lit_positions = torch.tensor( + restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + + def torsion_angles_to_frames(self, r: Union[Rigid, Rigid3Array], alpha: torch.Tensor, f): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(alpha.dtype, alpha.device) + # Separated purely to make testing less annoying + return torsion_angles_to_frames(r, alpha, f, self.default_frames) + + def frames_and_literature_positions_to_atom14_pos( + 
self,
+        r: Union[Rigid, Rigid3Array],  # [*, N, 8]
+        f,  # [*, N]
+    ):
+        # Lazily initialize the residue constants on the correct device
+        if type(r) == Rigid:
+            self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+        elif type(r) == Rigid3Array:
+            self._init_residue_constants(r.dtype, r.device)
+        else:
+            raise ValueError("Unknown rigid type")
+        return frames_and_literature_positions_to_atom14_pos(
+            r,
+            f,
+            self.default_frames,
+            self.group_idx,
+            self.atom_mask,
+            self.lit_positions,
+        )
diff --git a/tests/test_autochunk/origin_openfold/template.py b/tests/test_autochunk/origin_openfold/template.py
new file mode 100644
index 000000000000..6c4c0f42875e
--- /dev/null
+++ b/tests/test_autochunk/origin_openfold/template.py
@@ -0,0 +1,308 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from functools import partial
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from fastfold.model.nn.dropout import DropoutColumnwise, DropoutRowwise
+from fastfold.model.nn.pair_transition import PairTransition
+from fastfold.model.nn.primitives import Attention, LayerNorm, Linear
+from fastfold.model.nn.triangular_attention import TriangleAttentionEndingNode, TriangleAttentionStartingNode
+from fastfold.model.nn.triangular_multiplicative_update import (
+    TriangleMultiplicationIncoming,
+    TriangleMultiplicationOutgoing,
+)
+from fastfold.utils.checkpointing import checkpoint_blocks
+from fastfold.utils.tensor_utils import chunk_layer, flatten_final_dims, permute_final_dims
+
+
+class TemplatePointwiseAttention(nn.Module):
+    """
+    Implements Algorithm 17.
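+
+    A shape-level sketch (sizes are illustrative):
+
+        tpa = TemplatePointwiseAttention(c_t=64, c_z=128, c_hidden=16,
+                                         no_heads=4, inf=1e9)
+        t = torch.randn(1, 4, 32, 32, 64)  # [*, N_templ, N_res, N_res, C_t]
+        z = torch.randn(1, 32, 32, 128)    # [*, N_res, N_res, C_z]
+        z = tpa(t, z)                      # [*, N_res, N_res, C_z]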
+    """
+
+    def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
+        """
+        Args:
+            c_t:
+                Template embedding channel dimension
+            c_z:
+                Pair embedding channel dimension
+            c_hidden:
+                Hidden channel dimension
+        """
+        super(TemplatePointwiseAttention, self).__init__()
+
+        self.c_t = c_t
+        self.c_z = c_z
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.inf = inf
+
+        self.mha = Attention(
+            self.c_z,
+            self.c_t,
+            self.c_t,
+            self.c_hidden,
+            self.no_heads,
+            gating=False,
+        )
+
+    def _chunk(
+        self,
+        z: torch.Tensor,
+        t: torch.Tensor,
+        biases: List[torch.Tensor],
+        chunk_size: int,
+    ) -> torch.Tensor:
+        mha_inputs = {
+            "q_x": z,
+            "kv_x": t,
+            "biases": biases,
+        }
+        return chunk_layer(
+            self.mha,
+            mha_inputs,
+            chunk_size=chunk_size,
+            no_batch_dims=len(z.shape[:-2]),
+        )
+
+    def forward(self,
+                t: torch.Tensor,
+                z: torch.Tensor,
+                template_mask: Optional[torch.Tensor] = None,
+                chunk_size: Optional[int] = None) -> torch.Tensor:
+        """
+        Args:
+            t:
+                [*, N_templ, N_res, N_res, C_t] template embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding
+            template_mask:
+                [*, N_templ] template mask
+        Returns:
+            [*, N_res, N_res, C_z] pair embedding update
+        """
+        if template_mask is None:
+            template_mask = t.new_ones(t.shape[:-3])
+
+        bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
+
+        # [*, N_res, N_res, 1, C_z]
+        z = z.unsqueeze(-2)
+
+        # [*, N_res, N_res, N_templ, C_t]
+        t = permute_final_dims(t, (1, 2, 0, 3))
+
+        biases = [bias]
+        if chunk_size is not None:
+            z = self._chunk(z, t, biases, chunk_size)
+        else:
+            z = self.mha(q_x=z, kv_x=t, biases=biases)
+
+        # [*, N_res, N_res, C_z]
+        z = z.squeeze(-2)
+
+        return z
+
+
+class TemplatePairStackBlock(nn.Module):
+
+    def __init__(
+        self,
+        c_t: int,
+        c_hidden_tri_att: int,
+        c_hidden_tri_mul: int,
+        no_heads: int,
+        pair_transition_n: int,
+        dropout_rate: float,
+        inf: float,
+        is_multimer: bool = False,
+        **kwargs,
+    ):
+        super(TemplatePairStackBlock, self).__init__()
+
+        self.c_t = c_t
+        self.c_hidden_tri_att = c_hidden_tri_att
+        self.c_hidden_tri_mul = c_hidden_tri_mul
+        self.no_heads = no_heads
+        self.pair_transition_n = pair_transition_n
+        self.dropout_rate = dropout_rate
+        self.inf = inf
+        self.is_multimer = is_multimer
+
+        self.dropout_row = DropoutRowwise(self.dropout_rate)
+        self.dropout_col = DropoutColumnwise(self.dropout_rate)
+
+        self.tri_att_start = TriangleAttentionStartingNode(
+            self.c_t,
+            self.c_hidden_tri_att,
+            self.no_heads,
+            inf=inf,
+        )
+        self.tri_att_end = TriangleAttentionEndingNode(
+            self.c_t,
+            self.c_hidden_tri_att,
+            self.no_heads,
+            inf=inf,
+        )
+
+        self.tri_mul_out = TriangleMultiplicationOutgoing(
+            self.c_t,
+            self.c_hidden_tri_mul,
+        )
+        self.tri_mul_in = TriangleMultiplicationIncoming(
+            self.c_t,
+            self.c_hidden_tri_mul,
+        )
+
+        self.pair_transition = PairTransition(
+            self.c_t,
+            self.pair_transition_n,
+        )
+
+    def forward(self, z: torch.Tensor, mask: torch.Tensor, chunk_size: Optional[int] = None, _mask_trans: bool = True):
+        single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
+        single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
+        if not self.is_multimer:
+            for i in range(len(single_templates)):
+                single = single_templates[i]
+                single_mask = single_templates_masks[i]
+
+                single = single + self.dropout_row(self.tri_att_start(single, chunk_size=chunk_size, mask=single_mask))
+                single = single + self.dropout_col(self.tri_att_end(single, chunk_size=chunk_size, mask=single_mask))
+                single = single + self.dropout_row(self.tri_mul_out(single, mask=single_mask))
+                single = single + self.dropout_row(self.tri_mul_in(single, mask=single_mask))
+                single = single + self.pair_transition(
+                    single,
+                    mask=single_mask if _mask_trans else None,
+                    chunk_size=chunk_size,
+                )
+
+                single_templates[i] = single
+        else:
+            for i in range(len(single_templates)):
+                single = single_templates[i]
+                single_mask = single_templates_masks[i]
+
+                single = single + self.dropout_row(self.tri_att_start(single, chunk_size=chunk_size, mask=single_mask))
+                single = single + self.dropout_col(self.tri_att_end(single, chunk_size=chunk_size, mask=single_mask))
+                single = single + self.dropout_row(self.tri_mul_out(single, mask=single_mask))
+                single = single + self.dropout_row(self.tri_mul_in(single, mask=single_mask))
+                single = single + self.pair_transition(
+                    single,
+                    mask=single_mask if _mask_trans else None,
+                    chunk_size=chunk_size,
+                )
+                single_templates[i] = single
+
+        z = torch.cat(single_templates, dim=-4)
+
+        return z
+
+
+class TemplatePairStack(nn.Module):
+    """
+    Implements Algorithm 16.
+    """
+
+    def __init__(
+        self,
+        c_t,
+        c_hidden_tri_att,
+        c_hidden_tri_mul,
+        no_blocks,
+        no_heads,
+        pair_transition_n,
+        dropout_rate,
+        blocks_per_ckpt,
+        inf=1e9,
+        **kwargs,
+    ):
+        """
+        Args:
+            c_t:
+                Template embedding channel dimension
+            c_hidden_tri_att:
+                Per-head hidden dimension for triangular attention
+            c_hidden_tri_mul:
+                Hidden dimension for triangular multiplication
+            no_blocks:
+                Number of blocks in the stack
+            pair_transition_n:
+                Scale of pair transition (Alg. 15) hidden dimension
+            dropout_rate:
+                Dropout rate used throughout the stack
+            blocks_per_ckpt:
+                Number of blocks per activation checkpoint. None disables
+                activation checkpointing
+        """
+        super(TemplatePairStack, self).__init__()
+
+        self.blocks_per_ckpt = blocks_per_ckpt
+
+        self.blocks = nn.ModuleList()
+        for _ in range(no_blocks):
+            block = TemplatePairStackBlock(
+                c_t=c_t,
+                c_hidden_tri_att=c_hidden_tri_att,
+                c_hidden_tri_mul=c_hidden_tri_mul,
+                no_heads=no_heads,
+                pair_transition_n=pair_transition_n,
+                dropout_rate=dropout_rate,
+                inf=inf,
+            )
+            self.blocks.append(block)
+
+        self.layer_norm = LayerNorm(c_t)
+
+    def forward(
+        self,
+        t: torch.tensor,
+        mask: torch.tensor,
+        chunk_size: int,
+        _mask_trans: bool = True,
+    ):
+        """
+        Args:
+            t:
+                [*, N_templ, N_res, N_res, C_t] template embedding
+            mask:
+                [*, N_templ, N_res, N_res] mask
+        Returns:
+            [*, N_templ, N_res, N_res, C_t] template embedding update
+        """
+        if (mask.shape[-3] == 1):
+            expand_idx = list(mask.shape)
+            expand_idx[-3] = t.shape[-4]
+            mask = mask.expand(*expand_idx)
+
+        t, = checkpoint_blocks(
+            blocks=[partial(
+                b,
+                mask=mask,
+                chunk_size=chunk_size,
+                _mask_trans=_mask_trans,
+            ) for b in self.blocks],
+            args=(t,),
+            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
+        )
+
+        t = self.layer_norm(t)
+
+        return t
diff --git a/tests/test_autochunk/origin_openfold/triangular_attention.py b/tests/test_autochunk/origin_openfold/triangular_attention.py
new file mode 100644
index 000000000000..4d2e8efc419d
--- /dev/null
+++ b/tests/test_autochunk/origin_openfold/triangular_attention.py
@@ -0,0 +1,130 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import partial, partialmethod +from typing import List, Optional + +import torch +import torch.nn as nn + +from .primitives import Attention, LayerNorm, Linear +from .utils.tensor_utils import chunk_layer, flatten_final_dims, permute_final_dims + + +class TriangleAttention(nn.Module): + + def __init__(self, c_in, c_hidden, no_heads, starting, inf=1e9): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Overall hidden channel dimension (not per-head) + no_heads: + Number of attention heads + """ + super(TriangleAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.starting = starting + self.inf = inf + + self.layer_norm = LayerNorm(self.c_in) + + self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") + + self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + biases: List[torch.Tensor], + chunk_size: int, + ) -> torch.Tensor: + mha_inputs = { + "q_x": x, + "kv_x": x, + "biases": biases, + } + return chunk_layer( + partial(self.mha), + mha_inputs, + chunk_size=chunk_size, + no_batch_dims=len(x.shape[:-2]), + ) + + def forward(self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None) -> torch.Tensor: + """ + Args: + x: + [*, I, J, C_in] input tensor (e.g. the pair representation) + Returns: + [*, I, J, C_in] output tensor + """ + if mask is None: + # [*, I, J] + mask = x.new_ones(x.shape[:-1],) + + # Shape annotations assume self.starting. Else, I and J are flipped + if not self.starting: + x = x.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, I, J, C_in] + x = self.layer_norm(x) + + # [*, I, 1, 1, J] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # [*, H, I, J] + triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) + + # [*, 1, H, I, J] + triangle_bias = triangle_bias.unsqueeze(-4) + + biases = [mask_bias, triangle_bias] + + if chunk_size is not None: + x = self._chunk(x, biases, chunk_size) + else: + x = self.mha(q_x=x, kv_x=x, biases=biases) + + if not self.starting: + x = x.transpose(-2, -3) + + return x + + +class TriangleAttentionStartingNode(TriangleAttention): + """ + Implements Algorithm 13. + """ + + __init__ = partialmethod(TriangleAttention.__init__, starting=True) + + +class TriangleAttentionEndingNode(TriangleAttention): + """ + Implements Algorithm 14. + """ + + __init__ = partialmethod(TriangleAttention.__init__, starting=False) diff --git a/tests/test_autochunk/origin_openfold/triangular_multiplicative_update.py b/tests/test_autochunk/origin_openfold/triangular_multiplicative_update.py new file mode 100644 index 000000000000..f02e9033ae15 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/triangular_multiplicative_update.py @@ -0,0 +1,129 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partialmethod
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from .primitives import LayerNorm, Linear
+from .utils.tensor_utils import permute_final_dims
+
+
+class TriangleMultiplicativeUpdate(nn.Module):
+    """
+    Implements Algorithms 11 and 12.
+    """
+
+    def __init__(self, c_z, c_hidden, _outgoing=True):
+        """
+        Args:
+            c_z:
+                Input channel dimension
+            c_hidden:
+                Hidden channel dimension
+        """
+        super(TriangleMultiplicativeUpdate, self).__init__()
+        self.c_z = c_z
+        self.c_hidden = c_hidden
+        self._outgoing = _outgoing
+
+        self.linear_a_p = Linear(self.c_z, self.c_hidden)
+        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
+        self.linear_b_p = Linear(self.c_z, self.c_hidden)
+        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
+        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
+        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
+
+        self.layer_norm_in = LayerNorm(self.c_z)
+        self.layer_norm_out = LayerNorm(self.c_hidden)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def _combine_projections(
+        self,
+        a: torch.Tensor,
+        b: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError("This method needs to be overridden")
+
+    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Args:
+            z:
+                [*, N_res, N_res, C_z] input tensor
+            mask:
+                [*, N_res, N_res] input mask
+        Returns:
+            [*, N_res, N_res, C_z] output tensor
+        """
+        if mask is None:
+            mask = z.new_ones(z.shape[:-1])
+
+        mask = mask.unsqueeze(-1)
+
+        z = self.layer_norm_in(z)
+        a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
+        a = a * mask
+        b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
+        b = b * mask
+        x = self._combine_projections(a, b)
+        x = self.layer_norm_out(x)
+        x = self.linear_z(x)
+        g = self.sigmoid(self.linear_g(z))
+        z = x * g
+
+        return z
+
+
+class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
+    """
+    Implements Algorithm 11.
+    """
+
+    def _combine_projections(
+        self,
+        a: torch.Tensor,  # [*, N_i, N_k, C]
+        b: torch.Tensor,  # [*, N_j, N_k, C]
+    ):
+        # [*, C, N_i, N_j]
+        p = torch.matmul(
+            permute_final_dims(a, (2, 0, 1)),
+            permute_final_dims(b, (2, 1, 0)),
+        )
+
+        # [*, N_i, N_j, C]
+        return permute_final_dims(p, (1, 2, 0))
+
+
+class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
+    """
+    Implements Algorithm 12.
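+
+    Relative to the outgoing variant, the contraction here runs over the
+    first residue index: roughly p_ij = sum_k a_ki * b_kj, i.e. over edges
+    incoming to both i and j.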
+    """
+
+    def _combine_projections(
+        self,
+        a: torch.Tensor,    # [*, N_k, N_i, C]
+        b: torch.Tensor,    # [*, N_k, N_j, C]
+    ):
+        # [*, C, N_i, N_j]
+        p = torch.matmul(
+            permute_final_dims(a, (2, 1, 0)),
+            permute_final_dims(b, (2, 0, 1)),
+        )
+
+        # [*, N_i, N_j, C]
+        return permute_final_dims(p, (1, 2, 0))
diff --git a/tests/test_autochunk/origin_openfold/utils/all_atom_multimer.py b/tests/test_autochunk/origin_openfold/utils/all_atom_multimer.py
new file mode 100644
index 000000000000..0b3199698bbd
--- /dev/null
+++ b/tests/test_autochunk/origin_openfold/utils/all_atom_multimer.py
@@ -0,0 +1,415 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ops for all atom representations."""
+
+from functools import partial
+from typing import Dict, Text, Tuple
+
+import numpy as np
+import torch
+from fastfold.common import residue_constants as rc
+from fastfold.utils import geometry, tensor_utils
+
+
+def squared_difference(x, y):
+    # torch.square keeps the computation on the input tensor's device.
+    return torch.square(x - y)
+
+
+def get_rc_tensor(rc_np, aatype):
+    return torch.tensor(rc_np, device=aatype.device)[aatype]
+
+
+def atom14_to_atom37(
+        atom14_data: torch.Tensor,    # (*, N, 14, ...)
+        aatype: torch.Tensor    # (*, N)
+) -> torch.Tensor:    # (*, N, 37, ...)
+    """Convert atom14 to atom37 representation."""
+    idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype)
+    no_batch_dims = len(aatype.shape) - 1
+    atom37_data = tensor_utils.batched_gather(atom14_data,
+                                              idx_atom37_to_atom14,
+                                              dim=no_batch_dims + 1,
+                                              no_batch_dims=no_batch_dims + 1)
+    atom37_mask = get_rc_tensor(rc.RESTYPE_ATOM37_MASK, aatype)
+    if len(atom14_data.shape) == no_batch_dims + 2:
+        atom37_data *= atom37_mask
+    elif len(atom14_data.shape) == no_batch_dims + 3:
+        atom37_data *= atom37_mask[..., None].to(atom37_data.dtype)
+    else:
+        raise ValueError("Incorrectly shaped data")
+    return atom37_data
+
+
+def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
+    """Convert Atom37 positions to Atom14 positions."""
+    residx_atom14_to_atom37 = get_rc_tensor(rc.RESTYPE_ATOM14_TO_ATOM37, aatype)
+    no_batch_dims = len(aatype.shape)
+    atom14_mask = tensor_utils.batched_gather(
+        all_atom_mask,
+        residx_atom14_to_atom37,
+        dim=no_batch_dims + 1,
+        no_batch_dims=no_batch_dims + 1,
+    ).to(torch.float32)
+    # create a mask for known groundtruth positions
+    atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype)
+    # gather the groundtruth positions
+    atom14_positions = tensor_utils.batched_gather(
+        all_atom_pos,
+        residx_atom14_to_atom37,
+        dim=no_batch_dims + 1,
+        no_batch_dims=no_batch_dims + 1,
+    )
+    atom14_positions = atom14_mask * atom14_positions
+    return atom14_positions, atom14_mask
+
+
+def get_alt_atom14(aatype, positions: torch.Tensor, mask):
+    """Get alternative atom14 positions."""
+    # pick the transformation matrices for the given residue sequence
+    # shape (num_res, 14, 14)
+    renaming_transform = get_rc_tensor(rc.RENAMING_MATRICES, aatype)
+    alternative_positions = torch.sum(positions[..., None, :] * renaming_transform[..., None], dim=-2)
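+
+    # renaming_transform is a per-residue 14x14 permutation that swaps the
+    # coordinates of ambiguous, symmetry-equivalent atom pairs (e.g. ASP
+    # OD1/OD2) and leaves every other atom in place.
+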
+    # Create the mask for the alternative ground truth (differs from the
+    # ground truth mask, if only one of the atoms in an ambiguous pair has a
+    # ground truth position)
+    alternative_mask = torch.sum(mask[..., None] * renaming_transform, dim=-2)
+
+    return alternative_positions, alternative_mask
+
+
+def atom37_to_frames(
+        aatype: torch.Tensor,    # (...)
+        all_atom_positions: torch.Tensor,    # (..., 37)
+        all_atom_mask: torch.Tensor,    # (..., 37)
+) -> Dict[Text, torch.Tensor]:
+    """Computes the frames for the up to 8 rigid groups for each residue."""
+    # 0: 'backbone group',
+    # 1: 'pre-omega-group', (empty)
+    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
+    # 3: 'psi-group',
+    # 4,5,6,7: 'chi1,2,3,4-group'
+
+    no_batch_dims = len(aatype.shape) - 1
+
+    # Compute the gather indices for all residues in the chain.
+    # shape (N, 8, 3)
+    residx_rigidgroup_base_atom37_idx = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype)
+
+    # Gather the base atom positions for each rigid group.
+    base_atom_pos = tensor_utils.batched_gather(
+        all_atom_positions,
+        residx_rigidgroup_base_atom37_idx,
+        dim=no_batch_dims + 1,
+        no_batch_dims=no_batch_dims + 1,
+    )
+
+    # Compute the Rigids.
+    point_on_neg_x_axis = base_atom_pos[..., :, :, 0]
+    origin = base_atom_pos[..., :, :, 1]
+    point_on_xy_plane = base_atom_pos[..., :, :, 2]
+    gt_rotation = geometry.Rot3Array.from_two_vectors(origin - point_on_neg_x_axis, point_on_xy_plane - origin)
+
+    gt_frames = geometry.Rigid3Array(gt_rotation, origin)
+
+    # Compute a mask whether the group exists.
+    # (N, 8)
+    group_exists = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_MASK, aatype)
+
+    # Compute a mask whether ground truth exists for the group
+    gt_atoms_exist = tensor_utils.batched_gather(    # shape (N, 8, 3)
+        all_atom_mask.to(dtype=torch.float32),
+        residx_rigidgroup_base_atom37_idx,
+        dim=no_batch_dims + 1,
+        no_batch_dims=no_batch_dims + 1,
+    )
+    gt_exists = torch.min(gt_atoms_exist, dim=-1).values * group_exists    # (N, 8)
+
+    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
+    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
+    rots[0, 0, 0] = -1
+    rots[0, 2, 2] = -1
+    gt_frames = gt_frames.compose_rotation(geometry.Rot3Array.from_array(torch.tensor(rots, device=aatype.device)))
+
+    # The frames for ambiguous rigid groups are just rotated by 180 degree around
+    # the x-axis. The ambiguous group is always the last chi-group.
+    restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
+    restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])
+
+    for resname, _ in rc.residue_atom_renaming_swaps.items():
+        restype = rc.restype_order[rc.restype_3to1[resname]]
+        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
+        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
+        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
+        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1
+
+    # Gather the ambiguity information for each residue.
+    residx_rigidgroup_is_ambiguous = torch.tensor(
+        restype_rigidgroup_is_ambiguous,
+        device=aatype.device,
+    )[aatype]
+    ambiguity_rot = torch.tensor(
+        restype_rigidgroup_rots,
+        device=aatype.device,
+    )[aatype]
+    ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot)
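+
+    # The alternative frames correspond to the swapped naming of those
+    # ambiguous atom pairs, so the loss can score the prediction against
+    # whichever ground-truth naming fits better.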
+    # Create the alternative ground truth frames.
+    alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
+
+    fix_shape = lambda x: x.reshape(x.shape[:-2] + (8,))
+
+    # reshape back to original residue layout
+    gt_frames = fix_shape(gt_frames)
+    gt_exists = fix_shape(gt_exists)
+    group_exists = fix_shape(group_exists)
+    residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
+    alt_gt_frames = fix_shape(alt_gt_frames)
+
+    return {
+        'rigidgroups_gt_frames': gt_frames,    # Rigid (..., 8)
+        'rigidgroups_gt_exists': gt_exists,    # (..., 8)
+        'rigidgroups_group_exists': group_exists,    # (..., 8)
+        'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous,    # (..., 8)
+        'rigidgroups_alt_gt_frames': alt_gt_frames,    # Rigid (..., 8)
+    }
+
+
+def torsion_angles_to_frames(
+        aatype: torch.Tensor,    # (N)
+        backb_to_global: geometry.Rigid3Array,    # (N)
+        torsion_angles_sin_cos: torch.Tensor    # (N, 7, 2)
+) -> geometry.Rigid3Array:    # (N, 8)
+    """Compute rigid group frames from torsion angles."""
+    # Gather the default frames for all rigid groups.
+    # geometry.Rigid3Array with shape (N, 8)
+    m = get_rc_tensor(rc.restype_rigid_group_default_frame, aatype)
+    default_frames = geometry.Rigid3Array.from_array4x4(m)
+
+    # Create the rotation matrices according to the given angles (each frame is
+    # defined such that its rotation is around the x-axis).
+    sin_angles = torsion_angles_sin_cos[..., 0]
+    cos_angles = torsion_angles_sin_cos[..., 1]
+
+    # insert zero rotation for backbone group.
+    num_residues = aatype.shape[-1]
+    sin_angles = torch.cat([
+        torch.zeros_like(aatype).unsqueeze(-1),
+        sin_angles,
+    ], dim=-1)
+    cos_angles = torch.cat([torch.ones_like(aatype).unsqueeze(-1), cos_angles], dim=-1)
+    zeros = torch.zeros_like(sin_angles)
+    ones = torch.ones_like(sin_angles)
+
+    # all_rots are geometry.Rot3Array with shape (..., N, 8)
+    all_rots = geometry.Rot3Array(ones, zeros, zeros, zeros, cos_angles, -sin_angles, zeros, sin_angles, cos_angles)
+
+    # Apply rotations to the frames.
+    all_frames = default_frames.compose_rotation(all_rots)
+
+    # chi2, chi3, and chi4 frames do not transform to the backbone frame but to
+    # the previous frame. So chain them up accordingly.
+
+    chi1_frame_to_backb = all_frames[..., 4]
+    chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[..., 5]
+    chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[..., 6]
+    chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[..., 7]
+
+    all_frames_to_backb = geometry.Rigid3Array.cat(
+        [
+            all_frames[..., 0:5],
+            chi2_frame_to_backb[..., None],
+            chi3_frame_to_backb[..., None],
+            chi4_frame_to_backb[..., None],
+        ],
+        dim=-1,
+    )
+
+    # Create the global frames.
+    # shape (N, 8)
+    all_frames_to_global = backb_to_global[..., None] @ all_frames_to_backb
+
+    return all_frames_to_global
+
+
+def frames_and_literature_positions_to_atom14_pos(
+        aatype: torch.Tensor,    # (*, N)
+        all_frames_to_global: geometry.Rigid3Array    # (N, 8)
+) -> geometry.Vec3Array:    # (*, N, 14)
+    """Put atom literature positions (atom14 encoding) in each rigid group."""
+    # Pick the appropriate transform for every atom.
+    residx_to_group_idx = get_rc_tensor(rc.restype_atom14_to_rigid_group, aatype)
+    group_mask = torch.nn.functional.one_hot(residx_to_group_idx, num_classes=8)    # shape (*, N, 14, 8)
+
+    # geometry.Rigid3Array with shape (N, 14)
+    map_atoms_to_global = all_frames_to_global[..., None, :] * group_mask
+    map_atoms_to_global = map_atoms_to_global.map_tensor_fn(partial(torch.sum, dim=-1))
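+
+    # The one-hot group mask zeroes every frame except the one an atom belongs
+    # to, so the sum over the group dimension selects exactly one frame per
+    # atom.
+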
+    # Gather the literature atom positions for each residue.
+    # geometry.Vec3Array with shape (N, 14)
+    lit_positions = geometry.Vec3Array.from_array(get_rc_tensor(rc.restype_atom14_rigid_group_positions, aatype))
+
+    # Transform each atom from its local frame to the global frame.
+    # geometry.Vec3Array with shape (N, 14)
+    pred_positions = map_atoms_to_global.apply_to_point(lit_positions)
+
+    # Mask out non-existing atoms.
+    mask = get_rc_tensor(rc.restype_atom14_mask, aatype)
+    pred_positions = pred_positions * mask
+
+    return pred_positions
+
+
+def extreme_ca_ca_distance_violations(
+        positions: geometry.Vec3Array,    # (N, 37(14))
+        mask: torch.Tensor,    # (N, 37(14))
+        residue_index: torch.Tensor,    # (N)
+        max_angstrom_tolerance=1.5,
+        eps: float = 1e-6) -> torch.Tensor:
+    """Counts residues whose Ca is a large distance from its neighbor."""
+    this_ca_pos = positions[..., :-1, 1]    # (N - 1,)
+    this_ca_mask = mask[..., :-1, 1]    # (N - 1)
+    next_ca_pos = positions[..., 1:, 1]    # (N - 1,)
+    next_ca_mask = mask[..., 1:, 1]    # (N - 1)
+    has_no_gap_mask = ((residue_index[..., 1:] - residue_index[..., :-1]) == 1.0).to(torch.float32)
+    ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, eps)
+    violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance
+    mask = this_ca_mask * next_ca_mask * has_no_gap_mask
+    return tensor_utils.masked_mean(mask=mask, value=violations, dim=-1)
+
+
+def get_chi_atom_indices(device: torch.device):
+    """Returns atom indices needed to compute chi angles for all residue types.
+
+    Returns:
+      A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
+      in the order specified in rc.restypes + unknown residue type
+      at the end. For chi angles which are not defined on the residue, the
+      positions indices are by default set to 0.
+    """
+    chi_atom_indices = []
+    for residue_name in rc.restypes:
+        residue_name = rc.restype_1to3[residue_name]
+        residue_chi_angles = rc.chi_angles_atoms[residue_name]
+        atom_indices = []
+        for chi_angle in residue_chi_angles:
+            atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
+        for _ in range(4 - len(atom_indices)):
+            atom_indices.append([0, 0, 0, 0])    # For chi angles not defined on the AA.
+        chi_atom_indices.append(atom_indices)
+
+    chi_atom_indices.append([[0, 0, 0, 0]] * 4)    # For UNKNOWN residue.
+    return torch.tensor(chi_atom_indices, device=device)
+
+
+def compute_chi_angles(positions: geometry.Vec3Array, mask: torch.Tensor, aatype: torch.Tensor):
+    """Computes the chi angles given all atom positions and the amino acid type.
+
+    Args:
+      positions: A Vec3Array of shape
+        [num_res, rc.atom_type_num], with positions of
+        atoms needed to calculate chi angles. Supports up to 1 batch dimension.
+      mask: An optional tensor of shape
+        [num_res, rc.atom_type_num] that masks which atom
+        positions are set for each residue. If given, then the chi mask will be
+        set to 1 for a chi angle only if the amino acid has that chi angle and all
+        the chi atoms needed to calculate that chi angle are set. If not given
+        (set to None), the chi mask will be set to 1 for a chi angle if the amino
+        acid has that chi angle and whether the actual atoms needed to calculate
+        it were set will be ignored.
+      aatype: A tensor of shape [num_res] with amino acid type integer
+        code (0 to 21). Supports up to 1 batch dimension.
+
+    Returns:
+      A tuple of tensors (chi_angles, mask), where both have shape
+      [num_res, 4]. The mask masks out unused chi angles for amino acid
+      types that have less than 4 chi angles. 
If atom_positions_mask is set, the + chi mask will also mask out uncomputable chi angles. + """ + + # Don't assert on the num_res and batch dimensions as they might be unknown. + assert positions.shape[-1] == rc.atom_type_num + assert mask.shape[-1] == rc.atom_type_num + no_batch_dims = len(aatype.shape) - 1 + + # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. + chi_atom_indices = get_chi_atom_indices(aatype.device) + + # DISCREPANCY: DeepMind doesn't remove the gaps here. I don't know why + # theirs works. + aatype_gapless = torch.clamp(aatype, max=20) + + # Select atoms to compute chis. Shape: [*, num_res, chis=4, atoms=4]. + atom_indices = chi_atom_indices[aatype_gapless] + # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3]. + chi_angle_atoms = positions.map_tensor_fn( + partial(tensor_utils.batched_gather, inds=atom_indices, dim=-1, no_batch_dims=no_batch_dims + 1)) + + a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)] + + chi_angles = geometry.dihedral_angle(a, b, c, d) + + # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4]. + chi_angles_mask = list(rc.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = torch.tensor(chi_angles_mask, device=aatype.device) + # Compute the chi angle mask. Shape [num_res, chis=4]. + chi_mask = chi_angles_mask[aatype_gapless] + + # The chi_mask is set to 1 only when all necessary chi angle atoms were set. + # Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4]. + chi_angle_atoms_mask = tensor_utils.batched_gather(mask, atom_indices, dim=-1, no_batch_dims=no_batch_dims + 1) + # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4]. + chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1) + chi_mask = chi_mask * chi_angle_atoms_mask.to(torch.float32) + + return chi_angles, chi_mask + + +def make_transform_from_reference(a_xyz: geometry.Vec3Array, b_xyz: geometry.Vec3Array, + c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array: + """Returns rotation and translation matrices to convert from reference. + + Note that this method does not take care of symmetries. If you provide the + coordinates in the non-standard way, the A atom will end up in the negative + y-axis rather than in the positive y-axis. You need to take care of such + cases in your code. + + Args: + a_xyz: A Vec3Array. + b_xyz: A Vec3Array. + c_xyz: A Vec3Array. + + Returns: + A Rigid3Array which, when applied to coordinates in a canonicalized + reference frame, will give coordinates approximately equal + the original coordinates (in the global frame). 
+ """ + rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz, a_xyz - b_xyz) + return geometry.Rigid3Array(rotation, b_xyz) + + +def make_backbone_affine( + positions: geometry.Vec3Array, + mask: torch.Tensor, + aatype: torch.Tensor, +) -> Tuple[geometry.Rigid3Array, torch.Tensor]: + a = rc.atom_order['N'] + b = rc.atom_order['CA'] + c = rc.atom_order['C'] + + rigid_mask = (mask[..., a] * mask[..., b] * mask[..., c]) + + rigid = make_transform_from_reference( + a_xyz=positions[..., a], + b_xyz=positions[..., b], + c_xyz=positions[..., c], + ) + + return rigid, rigid_mask diff --git a/tests/test_autochunk/origin_openfold/utils/checkpointing.py b/tests/test_autochunk/origin_openfold/utils/checkpointing.py new file mode 100644 index 000000000000..bd8def5f63c7 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/utils/checkpointing.py @@ -0,0 +1,86 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.utils.checkpoint + +BLOCK_ARG = Any +BLOCK_ARGS = List[BLOCK_ARG] + + +def get_checkpoint_fn(): + checkpoint = torch.utils.checkpoint.checkpoint + + return checkpoint + + +@torch.jit.ignore +def checkpoint_blocks( + blocks: List[Callable], + args: BLOCK_ARGS, + blocks_per_ckpt: Optional[int], +) -> BLOCK_ARGS: + """ + Chunk a list of blocks and run each chunk with activation + checkpointing. We define a "block" as a callable whose only inputs are + the outputs of the previous block. + + Implements Subsection 1.11.8 + + Args: + blocks: + List of blocks + args: + Tuple of arguments for the first block. + blocks_per_ckpt: + Size of each chunk. A higher value corresponds to fewer + checkpoints, and trades memory for speed. If None, no checkpointing + is performed. 
+ Returns: + The output of the final block + """ + + def wrap(a): + return (a,) if type(a) is not tuple else a + + def exec(b, a): + for block in b: + a = wrap(block(*a)) + return a + + def chunker(s, e): + + def exec_sliced(*a): + return exec(blocks[s:e], a) + + return exec_sliced + + # Avoids mishaps when the blocks take just one argument + args = wrap(args) + + if blocks_per_ckpt is None: + return exec(blocks, args) + elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): + raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") + + checkpoint = get_checkpoint_fn() + + for s in range(0, len(blocks), blocks_per_ckpt): + e = s + blocks_per_ckpt + args = checkpoint(chunker(s, e), *args) + args = wrap(args) + + return args diff --git a/tests/test_autochunk/origin_openfold/utils/feats.py b/tests/test_autochunk/origin_openfold/utils/feats.py new file mode 100644 index 000000000000..04ae54d413b3 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/utils/feats.py @@ -0,0 +1,302 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, Optional, Tuple, Union + +import fastfold.common.residue_constants as rc +import numpy as np +import torch +import torch.nn as nn +from fastfold.common import protein +from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array +from fastfold.utils.geometry.rotation_matrix import Rot3Array +from fastfold.utils.rigid_utils import Rigid, Rotation +from fastfold.utils.tensor_utils import batched_gather, one_hot, tensor_tree_map, tree_map + + +def dgram_from_positions( + pos: torch.Tensor, + min_bin: float = 3.25, + max_bin: float = 50.75, + no_bins: float = 39, + inf: float = 1e8, +) -> torch.Tensor: + dgram = torch.sum((pos[..., None, :] - pos[..., None, :, :])**2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device)**2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + + return dgram + + +def pseudo_beta_fn(aatype, all_atom_positions: torch.Tensor, + all_atom_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + is_gly = aatype == rc.restype_order["G"] + ca_idx = rc.atom_order["CA"] + cb_idx = rc.atom_order["CB"] + pseudo_beta = torch.where( + is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_masks is not None: + pseudo_beta_mask = torch.where( + is_gly, + all_atom_masks[..., ca_idx], + all_atom_masks[..., cb_idx], + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta, None + + +def atom14_to_atom37(atom14, batch: Dict[str, Any]): + atom37_data = batched_gather( + atom14, + batch["residx_atom37_to_atom14"], + dim=-2, + no_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch["atom37_atom_exists"][..., None] + + return atom37_data + + 
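+# Illustrative usage of atom14_to_atom37 (shapes are hypothetical, not part
+# of this file): `batch` must already carry "residx_atom37_to_atom14" and
+# "atom37_atom_exists" produced by the data pipeline.
+#
+#     atom14 = torch.randn(1, n_res, 14, 3)
+#     atom37 = atom14_to_atom37(atom14, batch)    # (1, n_res, 37, 3)
+
+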
+def build_template_angle_feat(template_feats: Dict[str, Any]) -> torch.Tensor: + template_aatype = template_feats["template_aatype"] + torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] + alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"] + torsion_angles_mask = template_feats["template_torsion_angles_mask"] + template_angle_feat = torch.cat( + [ + nn.functional.one_hot(template_aatype, 22), + torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14), + alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14), + torsion_angles_mask, + ], + dim=-1, + ) + + return template_angle_feat + + +def build_template_pair_feat(batch: Dict[str, Any], + min_bin: float, + max_bin: float, + no_bins: int, + use_unit_vector: bool = False, + eps: float = 1e-20, + inf: float = 1e8, + chunk=None): + if chunk and 1 <= chunk <= 4: + for k, v in batch.items(): + batch[k] = v.cpu() + + template_mask = batch["template_pseudo_beta_mask"] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + # Compute distogram (this seems to differ slightly from Alg. 5) + tpb = batch["template_pseudo_beta"] + dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf) + + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = nn.functional.one_hot( + batch["template_aatype"], + rc.restype_num + 2, + ) + + n_res = batch["template_aatype"].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)) + + n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] + rigids = Rigid.make_transform_from_reference( + n_xyz=batch["template_all_atom_positions"][..., n, :], + ca_xyz=batch["template_all_atom_positions"][..., ca, :], + c_xyz=batch["template_all_atom_positions"][..., c, :], + eps=eps, + ) + points = rigids.get_trans()[..., None, :, :] + rigid_vec = rigids[..., None].invert_apply(points) + del rigids, points + + inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) + + t_aa_masks = batch["template_all_atom_mask"] + template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c] + del t_aa_masks, n, ca, c + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + inv_distance_scalar = inv_distance_scalar * template_mask_2d + unit_vector = rigid_vec * inv_distance_scalar[..., None] + + if not use_unit_vector: + unit_vector = unit_vector * 0.0 + + to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) + to_concat.append(template_mask_2d[..., None]) + del unit_vector, rigid_vec, inv_distance_scalar + + act = torch.cat(to_concat, dim=-1) + act = act * template_mask_2d[..., None] + + return act + + +def build_extra_msa_feat(batch: Dict[str, Any]) -> torch.Tensor: + msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23) + msa_feat = [ + msa_1hot, + batch["extra_has_deletion"].unsqueeze(-1), + batch["extra_deletion_value"].unsqueeze(-1), + ] + return torch.cat(msa_feat, dim=-1) + + +def torsion_angles_to_frames( + r: Union[Rigid3Array, Rigid], + alpha: torch.Tensor, + aatype: torch.Tensor, + rrgdf: torch.Tensor, +) -> Union[Rigid, Rigid3Array]: + # [*, N, 8, 4, 4] + default_4x4 = rrgdf[aatype, ...] + + # [*, N, 8] transformations, i.e. 
+ # One [*, N, 8, 3, 3] rotation matrix and + # One [*, N, 8, 3] translation matrix + default_r = r.from_tensor_4x4(default_4x4) + + bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + # [*, N, 8, 2] + alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) + + # [*, N, 8, 3, 3] + # Produces rotation matrices of the form: + # [ + # [1, 0 , 0 ], + # [0, a_2,-a_1], + # [0, a_1, a_2] + # ] + # This follows the original code rather than the supplement, which uses + # different indices. + if type(r) == Rigid3Array: + all_rots = alpha.new_zeros(default_r.shape + (3, 3)) + elif type(r) == Rigid: + all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) + else: + raise TypeError(f"Wrong type of Rigid: {type(r)}") + + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + if type(r) == Rigid3Array: + all_rots = Rot3Array.from_array(all_rots) + all_frames = default_r.compose_rotation(all_rots) + elif type(r) == Rigid: + all_rots = Rigid(Rotation(rot_mats=all_rots), None) + all_frames = default_r.compose(all_rots) + else: + raise TypeError(f"Wrong type of Rigid: {type(r)}") + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + if type(all_frames) == Rigid3Array: + all_frames_to_bb = Rigid3Array.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + elif type(all_frames) == Rigid: + all_frames_to_bb = Rigid.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = r[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + r: Union[Rigid3Array, Rigid], + aatype: torch.Tensor, + default_frames: torch.Tensor, + group_idx: torch.Tensor, + atom_mask: torch.Tensor, + lit_positions: torch.Tensor, +) -> torch.Tensor: + # [*, N, 14, 4, 4] + default_4x4 = default_frames[aatype, ...] + + # [*, N, 14] + group_mask = group_idx[aatype, ...] + + # [*, N, 14, 8] + if type(r) == Rigid3Array: + group_mask = nn.functional.one_hot( + group_mask.long(), + num_classes=default_frames.shape[-3], + ) + elif type(r) == Rigid: + group_mask = nn.functional.one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + else: + raise TypeError(f"Wrong type of Rigid: {type(r)}") + + # [*, N, 14, 8] + t_atoms_to_global = r[..., None, :] * group_mask + + # [*, N, 14] + t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) + + # [*, N, 14, 1] + if type(r) == Rigid: + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + elif type(r) == Rigid3Array: + atom_mask = atom_mask[aatype, ...] + + # [*, N, 14, 3] + lit_positions = lit_positions[aatype, ...] 
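+    # Map each residue's idealized (literature) atom coordinates through its
+    # selected rigid-group frame to obtain global coordinates.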
+ pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/__init__.py b/tests/test_autochunk/origin_openfold/utils/geometry/__init__.py new file mode 100644 index 000000000000..2abd731c6a31 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/utils/geometry/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Geometry Module.""" + +from fastfold.utils.geometry import rigid_matrix_vector, rotation_matrix, vector + +Rot3Array = rotation_matrix.Rot3Array +Rigid3Array = rigid_matrix_vector.Rigid3Array + +Vec3Array = vector.Vec3Array +square_euclidean_distance = vector.square_euclidean_distance +euclidean_distance = vector.euclidean_distance +dihedral_angle = vector.dihedral_angle +dot = vector.dot +cross = vector.cross diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/quat_rigid.py b/tests/test_autochunk/origin_openfold/utils/geometry/quat_rigid.py new file mode 100644 index 000000000000..fa72b4a7437f --- /dev/null +++ b/tests/test_autochunk/origin_openfold/utils/geometry/quat_rigid.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +from fastfold.model.nn.primitives import Linear +from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array +from fastfold.utils.geometry.rotation_matrix import Rot3Array +from fastfold.utils.geometry.vector import Vec3Array + + +class QuatRigid(nn.Module): + + def __init__(self, c_hidden, full_quat): + super().__init__() + self.full_quat = full_quat + if self.full_quat: + rigid_dim = 7 + else: + rigid_dim = 6 + + self.linear = Linear(c_hidden, rigid_dim) + + def forward(self, activations: torch.Tensor) -> Rigid3Array: + # NOTE: During training, this needs to be run in higher precision + rigid_flat = self.linear(activations.to(torch.float32)) + + rigid_flat = torch.unbind(rigid_flat, dim=-1) + if (self.full_quat): + qw, qx, qy, qz = rigid_flat[:4] + translation = rigid_flat[4:] + else: + qx, qy, qz = rigid_flat[:3] + qw = torch.ones_like(qx) + translation = rigid_flat[3:] + + rotation = Rot3Array.from_quaternion( + qw, + qx, + qy, + qz, + normalize=True, + ) + translation = Vec3Array(*translation) + return Rigid3Array(rotation, translation) diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/rigid_matrix_vector.py b/tests/test_autochunk/origin_openfold/utils/geometry/rigid_matrix_vector.py new file mode 100644 index 000000000000..7b97e1827c16 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/utils/geometry/rigid_matrix_vector.py @@ -0,0 +1,156 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rigid3Array Transformations represented by a Matrix and a Vector.""" + +from __future__ import annotations + +import dataclasses +from typing import List, Union + +import torch +from fastfold.utils.geometry import rotation_matrix, vector + +Float = Union[float, torch.Tensor] + + +@dataclasses.dataclass(frozen=True) +class Rigid3Array: + """Rigid Transformation, i.e. element of special euclidean group.""" + + rotation: rotation_matrix.Rot3Array + translation: vector.Vec3Array + + def __matmul__(self, other: Rigid3Array) -> Rigid3Array: + new_rotation = self.rotation @ other.rotation # __matmul__ + new_translation = self.apply_to_point(other.translation) + return Rigid3Array(new_rotation, new_translation) + + def __getitem__(self, index) -> Rigid3Array: + return Rigid3Array( + self.rotation[index], + self.translation[index], + ) + + def __mul__(self, other: torch.Tensor) -> Rigid3Array: + return Rigid3Array( + self.rotation * other, + self.translation * other, + ) + + def map_tensor_fn(self, fn) -> Rigid3Array: + return Rigid3Array( + self.rotation.map_tensor_fn(fn), + self.translation.map_tensor_fn(fn), + ) + + def inverse(self) -> Rigid3Array: + """Return Rigid3Array corresponding to inverse transform.""" + inv_rotation = self.rotation.inverse() + inv_translation = inv_rotation.apply_to_point(-self.translation) + return Rigid3Array(inv_rotation, inv_translation) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply Rigid3Array transform to point.""" + return self.rotation.apply_to_point(point) + self.translation + + def apply(self, point: torch.Tensor) -> vector.Vec3Array: + return self.apply_to_point(vector.Vec3Array.from_array(point)) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply inverse Rigid3Array transform to point.""" + new_point = point - self.translation + return self.rotation.apply_inverse_to_point(new_point) + + def compose_rotation(self, other_rotation): + rot = self.rotation @ other_rotation + return Rigid3Array(rot, self.translation.clone()) + + def compose(self, other_rigid): + return self @ other_rigid + + def unsqueeze(self, dim: int): + return Rigid3Array( + self.rotation.unsqueeze(dim), + self.translation.unsqueeze(dim), + ) + + @property + def shape(self) -> torch.Size: + return self.rotation.xx.shape + + @property + def dtype(self) -> torch.dtype: + return self.rotation.xx.dtype + + @property + def device(self) -> torch.device: + return self.rotation.xx.device + + @classmethod + def identity(cls, shape, device) -> Rigid3Array: + """Return identity Rigid3Array of given shape.""" + return cls(rotation_matrix.Rot3Array.identity(shape, device), vector.Vec3Array.zeros(shape, device)) + + @classmethod + def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array: + return cls( + rotation_matrix.Rot3Array.cat([r.rotation for r in rigids], dim=dim), + vector.Vec3Array.cat([r.translation for r in rigids], dim=dim), + ) + + def scale_translation(self, factor: Float) -> Rigid3Array: + """Scale translation in Rigid3Array by 'factor'.""" + return Rigid3Array(self.rotation, 
self.translation * factor)
+
+    def to_tensor(self) -> torch.Tensor:
+        rot_array = self.rotation.to_tensor()
+        vec_array = self.translation.to_tensor()
+        array = torch.zeros(rot_array.shape[:-2] + (4, 4), device=rot_array.device, dtype=rot_array.dtype)
+        array[..., :3, :3] = rot_array
+        array[..., :3, 3] = vec_array
+        array[..., 3, 3] = 1.
+        return array
+
+    def to_tensor_4x4(self) -> torch.Tensor:
+        return self.to_tensor()
+
+    def reshape(self, new_shape) -> Rigid3Array:
+        rots = self.rotation.reshape(new_shape)
+        trans = self.translation.reshape(new_shape)
+        return Rigid3Array(rots, trans)
+
+    def stop_rot_gradient(self) -> Rigid3Array:
+        return Rigid3Array(
+            self.rotation.stop_gradient(),
+            self.translation,
+        )
+
+    @classmethod
+    def from_array(cls, array):
+        rot = rotation_matrix.Rot3Array.from_array(array[..., :3, :3])
+        vec = vector.Vec3Array.from_array(array[..., :3, 3])
+        return cls(rot, vec)
+
+    @classmethod
+    def from_tensor_4x4(cls, array):
+        return cls.from_array(array)
+
+    @classmethod
+    def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
+        """Construct Rigid3Array from homogeneous 4x4 array."""
+        rotation = rotation_matrix.Rot3Array(array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], array[..., 1, 0],
+                                             array[..., 1, 1], array[..., 1, 2], array[..., 2, 0], array[..., 2, 1],
+                                             array[..., 2, 2])
+        translation = vector.Vec3Array(array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
+        return cls(rotation, translation)
diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py b/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py
new file mode 100644
index 000000000000..552128ac0c0c
--- /dev/null
+++ b/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py
@@ -0,0 +1,162 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Rot3Array Matrix Class.""" + +from __future__ import annotations + +import dataclasses + +import numpy as np +import torch +from fastfold.utils.geometry import utils, vector +from fastfold.utils.tensor_utils import tensor_tree_map + +COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] + + +@dataclasses.dataclass(frozen=True) +class Rot3Array: + """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" + xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) + xy: torch.Tensor + xz: torch.Tensor + yx: torch.Tensor + yy: torch.Tensor + yz: torch.Tensor + zx: torch.Tensor + zy: torch.Tensor + zz: torch.Tensor + + __array_ufunc__ = None + + def __getitem__(self, index): + field_names = utils.get_field_names(Rot3Array) + return Rot3Array(**{name: getattr(self, name)[index] for name in field_names}) + + def __mul__(self, other: torch.Tensor): + field_names = utils.get_field_names(Rot3Array) + return Rot3Array(**{name: getattr(self, name) * other for name in field_names}) + + def __matmul__(self, other: Rot3Array) -> Rot3Array: + """Composes two Rot3Arrays.""" + c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) + c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) + c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) + return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + def map_tensor_fn(self, fn) -> Rot3Array: + field_names = utils.get_field_names(Rot3Array) + return Rot3Array(**{name: fn(getattr(self, name)) for name in field_names}) + + def inverse(self) -> Rot3Array: + """Returns inverse of Rot3Array.""" + return Rot3Array(self.xx, self.yx, self.zx, self.xy, self.yy, self.zy, self.xz, self.yz, self.zz) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies Rot3Array to point.""" + return vector.Vec3Array(self.xx * point.x + self.xy * point.y + self.xz * point.z, + self.yx * point.x + self.yy * point.y + self.yz * point.z, + self.zx * point.x + self.zy * point.y + self.zz * point.z) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies inverse Rot3Array to point.""" + return self.inverse().apply_to_point(point) + + def unsqueeze(self, dim: int): + return Rot3Array(*tensor_tree_map(lambda t: t.unsqueeze(dim), [getattr(self, c) for c in COMPONENTS])) + + def stop_gradient(self) -> Rot3Array: + return Rot3Array(*[getattr(self, c).detach() for c in COMPONENTS]) + + @classmethod + def identity(cls, shape, device) -> Rot3Array: + """Returns identity of given shape.""" + ones = torch.ones(shape, dtype=torch.float32, device=device) + zeros = torch.zeros(shape, dtype=torch.float32, device=device) + return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) + + @classmethod + def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array) -> Rot3Array: + """Construct Rot3Array from two Vectors. + + Rot3Array is constructed such that in the corresponding frame 'e0' lies on + the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. + + Args: + e0: Vector + e1: Vector + Returns: + Rot3Array + """ + # Normalize the unit vector for the x-axis, e0. + e0 = e0.normalized() + # make e1 perpendicular to e0. + c = e1.dot(e0) + e1 = (e1 - c * e0).normalized() + # Compute e2 as cross product of e0 and e1. 
+ e2 = e0.cross(e1) + return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + @classmethod + def from_array(cls, array: torch.Tensor) -> Rot3Array: + """Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" + rows = torch.unbind(array, dim=-2) + rc = [torch.unbind(e, dim=-1) for e in rows] + return cls(*[e for row in rc for e in row]) + + def to_tensor(self) -> torch.Tensor: + """Convert Rot3Array to array of shape [..., 3, 3].""" + return torch.stack([ + torch.stack([self.xx, self.xy, self.xz], dim=-1), + torch.stack([self.yx, self.yy, self.yz], dim=-1), + torch.stack([self.zx, self.zy, self.zz], dim=-1) + ], + dim=-2) + + @classmethod + def from_quaternion(cls, + w: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + normalize: bool = True, + eps: float = 1e-6) -> Rot3Array: + """Construct Rot3Array from components of quaternion.""" + if normalize: + inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2) + w *= inv_norm + x *= inv_norm + y *= inv_norm + z *= inv_norm + xx = 1 - 2 * (y**2 + z**2) + xy = 2 * (x * y - w * z) + xz = 2 * (x * z + w * y) + yx = 2 * (x * y + w * z) + yy = 1 - 2 * (x**2 + z**2) + yz = 2 * (y * z - w * x) + zx = 2 * (x * z - w * y) + zy = 2 * (y * z + w * x) + zz = 1 - 2 * (x**2 + y**2) + return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) + + def reshape(self, new_shape): + field_names = utils.get_field_names(Rot3Array) + reshape_fn = lambda t: t.reshape(new_shape) + return Rot3Array(**{name: reshape_fn(getattr(self, name)) for name in field_names}) + + @classmethod + def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array: + field_names = utils.get_field_names(Rot3Array) + cat_fn = lambda l: torch.cat(l, dim=dim) + return cls(**{name: cat_fn([getattr(r, name) for r in rots]) for name in field_names}) diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/test_utils.py b/tests/test_autochunk/origin_openfold/utils/geometry/test_utils.py new file mode 100644 index 000000000000..a86cb6a864e6 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/utils/geometry/test_utils.py @@ -0,0 +1,86 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Shared utils for tests.""" + +import dataclasses + +import numpy as np +from fastfold.utils.geometry import rigid_matrix_vector, rotation_matrix, vector + + +def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array, matrix2: rotation_matrix.Rot3Array): + for field in dataclasses.fields(rotation_matrix.Rot3Array): + field = field.name + np.testing.assert_array_equal(getattr(matrix1, field), getattr(matrix2, field)) + + +def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array, mat2: rotation_matrix.Rot3Array): + np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6) + + +def assert_array_equal_to_rotation_matrix(array: np.ndarray, matrix: rotation_matrix.Rot3Array): + """Check that array and Matrix match.""" + np.testing.assert_array_equal(matrix.xx, array[..., 0, 0]) + np.testing.assert_array_equal(matrix.xy, array[..., 0, 1]) + np.testing.assert_array_equal(matrix.xz, array[..., 0, 2]) + np.testing.assert_array_equal(matrix.yx, array[..., 1, 0]) + np.testing.assert_array_equal(matrix.yy, array[..., 1, 1]) + np.testing.assert_array_equal(matrix.yz, array[..., 1, 2]) + np.testing.assert_array_equal(matrix.zx, array[..., 2, 0]) + np.testing.assert_array_equal(matrix.zy, array[..., 2, 1]) + np.testing.assert_array_equal(matrix.zz, array[..., 2, 2]) + + +def assert_array_close_to_rotation_matrix(array: np.ndarray, matrix: rotation_matrix.Rot3Array): + np.testing.assert_array_almost_equal(matrix.to_array(), array, 6) + + +def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + np.testing.assert_array_equal(vec1.x, vec2.x) + np.testing.assert_array_equal(vec1.y, vec2.y) + np.testing.assert_array_equal(vec1.z, vec2.z) + + +def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.) + np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.) + np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.) + + +def assert_array_close_to_vector(array: np.ndarray, vec: vector.Vec3Array): + np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.) 
+
+
+def assert_array_equal_to_vector(array: np.ndarray, vec: vector.Vec3Array):
+    np.testing.assert_array_equal(vec.to_tensor(), array)
+
+
+def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, rigid2: rigid_matrix_vector.Rigid3Array):
+    assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
+
+
+def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, rigid2: rigid_matrix_vector.Rigid3Array):
+    assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
+
+
+def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array, trans: vector.Vec3Array,
+                                    rigid: rigid_matrix_vector.Rigid3Array):
+    assert_rotation_matrix_equal(rot, rigid.rotation)
+    assert_vectors_equal(trans, rigid.translation)
+
+
+def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array, trans: vector.Vec3Array,
+                                    rigid: rigid_matrix_vector.Rigid3Array):
+    assert_rotation_matrix_close(rot, rigid.rotation)
+    assert_vectors_close(trans, rigid.translation)
diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/utils.py b/tests/test_autochunk/origin_openfold/utils/geometry/utils.py
new file mode 100644
index 000000000000..6c4d52ba9969
--- /dev/null
+++ b/tests/test_autochunk/origin_openfold/utils/geometry/utils.py
@@ -0,0 +1,22 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utils for geometry library."""
+
+import dataclasses
+
+
+def get_field_names(cls):
+    fields = dataclasses.fields(cls)
+    field_names = [f.name for f in fields]
+    return field_names
diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/vector.py b/tests/test_autochunk/origin_openfold/utils/geometry/vector.py
new file mode 100644
index 000000000000..4204e736c328
--- /dev/null
+++ b/tests/test_autochunk/origin_openfold/utils/geometry/vector.py
@@ -0,0 +1,253 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Vec3Array Class.""" + +from __future__ import annotations + +import dataclasses +from typing import List, Union + +import torch +from fastfold.utils.geometry import utils + +Float = Union[float, torch.Tensor] + + +@dataclasses.dataclass(frozen=True) +class Vec3Array: + x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) + y: torch.Tensor + z: torch.Tensor + + def __post_init__(self): + if hasattr(self.x, 'dtype'): + assert self.x.dtype == self.y.dtype + assert self.x.dtype == self.z.dtype + assert all([x == y for x, y in zip(self.x.shape, self.y.shape)]) + assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) + + def __add__(self, other: Vec3Array) -> Vec3Array: + return Vec3Array( + self.x + other.x, + self.y + other.y, + self.z + other.z, + ) + + def __sub__(self, other: Vec3Array) -> Vec3Array: + return Vec3Array( + self.x - other.x, + self.y - other.y, + self.z - other.z, + ) + + def __mul__(self, other: Float) -> Vec3Array: + return Vec3Array( + self.x * other, + self.y * other, + self.z * other, + ) + + def __rmul__(self, other: Float) -> Vec3Array: + return self * other + + def __truediv__(self, other: Float) -> Vec3Array: + return Vec3Array( + self.x / other, + self.y / other, + self.z / other, + ) + + def __neg__(self) -> Vec3Array: + return self * -1 + + def __pos__(self) -> Vec3Array: + return self * 1 + + def __getitem__(self, index) -> Vec3Array: + return Vec3Array( + self.x[index], + self.y[index], + self.z[index], + ) + + def __iter__(self): + return iter((self.x, self.y, self.z)) + + @property + def shape(self): + return self.x.shape + + def map_tensor_fn(self, fn) -> Vec3Array: + return Vec3Array( + fn(self.x), + fn(self.y), + fn(self.z), + ) + + def cross(self, other: Vec3Array) -> Vec3Array: + """Compute cross product between 'self' and 'other'.""" + new_x = self.y * other.z - self.z * other.y + new_y = self.z * other.x - self.x * other.z + new_z = self.x * other.y - self.y * other.x + return Vec3Array(new_x, new_y, new_z) + + def dot(self, other: Vec3Array) -> Float: + """Compute dot product between 'self' and 'other'.""" + return self.x * other.x + self.y * other.y + self.z * other.z + + def norm(self, epsilon: float = 1e-6) -> Float: + """Compute Norm of Vec3Array, clipped to epsilon.""" + # To avoid NaN on the backward pass, we must use maximum before the sqrt + norm2 = self.dot(self) + if epsilon: + norm2 = torch.clamp(norm2, min=epsilon**2) + return torch.sqrt(norm2) + + def norm2(self): + return self.dot(self) + + def normalized(self, epsilon: float = 1e-6) -> Vec3Array: + """Return unit vector with optional clipping.""" + return self / self.norm(epsilon) + + def clone(self) -> Vec3Array: + return Vec3Array( + self.x.clone(), + self.y.clone(), + self.z.clone(), + ) + + def reshape(self, new_shape) -> Vec3Array: + x = self.x.reshape(new_shape) + y = self.y.reshape(new_shape) + z = self.z.reshape(new_shape) + + return Vec3Array(x, y, z) + + def sum(self, dim: int) -> Vec3Array: + return Vec3Array( + torch.sum(self.x, dim=dim), + torch.sum(self.y, dim=dim), + torch.sum(self.z, dim=dim), + ) + + def unsqueeze(self, dim: int): + return Vec3Array( + self.x.unsqueeze(dim), + self.y.unsqueeze(dim), + self.z.unsqueeze(dim), + ) + + @classmethod + def zeros(cls, shape, device="cpu"): + """Return Vec3Array corresponding to zeros of given shape.""" + return cls(torch.zeros(shape, dtype=torch.float32, device=device), + torch.zeros(shape, dtype=torch.float32, device=device), + torch.zeros(shape, dtype=torch.float32, device=device)) + + def 
to_tensor(self) -> torch.Tensor:
+        return torch.stack([self.x, self.y, self.z], dim=-1)
+
+    @classmethod
+    def from_array(cls, tensor):
+        return cls(*torch.unbind(tensor, dim=-1))
+
+    @classmethod
+    def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
+        return cls(
+            torch.cat([v.x for v in vecs], dim=dim),
+            torch.cat([v.y for v in vecs], dim=dim),
+            torch.cat([v.z for v in vecs], dim=dim),
+        )
+
+
+def square_euclidean_distance(vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6) -> Float:
+    """Computes square of euclidean distance between 'vec1' and 'vec2'.
+
+    Args:
+      vec1: Vec3Array to compute distance to
+      vec2: Vec3Array to compute distance from, should be
+            broadcast compatible with 'vec1'
+      epsilon: distance is clipped from below to be at least epsilon
+
+    Returns:
+      Array of square euclidean distances;
+      shape will be result of broadcasting 'vec1' and 'vec2'
+    """
+    difference = vec1 - vec2
+    distance = difference.dot(difference)
+    if epsilon:
+        distance = torch.clamp(distance, min=epsilon)
+    return distance
+
+
+def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
+    return vector1.dot(vector2)
+
+
+def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
+    return vector1.cross(vector2)
+
+
+def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
+    return vector.norm(epsilon)
+
+
+def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
+    return vector.normalized(epsilon)
+
+
+def euclidean_distance(vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6) -> Float:
+    """Computes euclidean distance between 'vec1' and 'vec2'.
+
+    Args:
+      vec1: Vec3Array to compute euclidean distance to
+      vec2: Vec3Array to compute euclidean distance from, should be
+            broadcast compatible with 'vec1'
+      epsilon: distance is clipped from below to be at least epsilon
+
+    Returns:
+      Array of euclidean distances;
+      shape will be result of broadcasting 'vec1' and 'vec2'
+    """
+    distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
+    distance = torch.sqrt(distance_sq)
+    return distance
+
+
+def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array) -> Float:
+    """Computes torsion angle for a quadruple of points.
+
+    For points (a, b, c, d), this is the angle between the planes defined by
+    points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
+
+    Arguments:
+      a: A Vec3Array of coordinates.
+      b: A Vec3Array of coordinates.
+      c: A Vec3Array of coordinates.
+      d: A Vec3Array of coordinates.
+
+    Returns:
+      A tensor of angles in radians: [-pi, pi].
+    """
+    v1 = a - b
+    v2 = b - c
+    v3 = d - c
+
+    c1 = v1.cross(v2)
+    c2 = v3.cross(v2)
+    c3 = c2.cross(c1)
+
+    v2_mag = v2.norm()
+    return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
diff --git a/tests/test_autochunk/origin_openfold/utils/loss.py b/tests/test_autochunk/origin_openfold/utils/loss.py
new file mode 100644
index 000000000000..d39705c901f3
--- /dev/null
+++ b/tests/test_autochunk/origin_openfold/utils/loss.py
@@ -0,0 +1,1403 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import partial +from typing import Dict, Optional, Tuple + +import ml_collections +import numpy as np +import torch +import torch.nn as nn +from fastfold.common import residue_constants +from fastfold.utils import feats +from fastfold.utils.rigid_utils import Rigid, Rotation +from fastfold.utils.tensor_utils import batched_gather, masked_mean, permute_final_dims, tensor_tree_map, tree_map +from torch.distributions.bernoulli import Bernoulli + + +def softmax_cross_entropy(logits, labels): + loss = -1 * torch.sum( + labels * torch.nn.functional.log_softmax(logits, dim=-1), + dim=-1, + ) + return loss + + +def sigmoid_cross_entropy(logits, labels): + log_p = torch.log(torch.sigmoid(logits)) + log_not_p = torch.log(torch.sigmoid(-logits)) + loss = -labels * log_p - (1 - labels) * log_not_p + return loss + + +def torsion_angle_loss( + a, # [*, N, 7, 2] + a_gt, # [*, N, 7, 2] + a_alt_gt, # [*, N, 7, 2] +): + # [*, N, 7] + norm = torch.norm(a, dim=-1) + + # [*, N, 7, 2] + a = a / norm.unsqueeze(-1) + + # [*, N, 7] + diff_norm_gt = torch.norm(a - a_gt, dim=-1) + diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1) + min_diff = torch.minimum(diff_norm_gt**2, diff_norm_alt_gt**2) + + # [*] + l_torsion = torch.mean(min_diff, dim=(-1, -2)) + l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2)) + + an_weight = 0.02 + return l_torsion + an_weight * l_angle_norm + + +def compute_fape( + pred_frames: Rigid, + target_frames: Rigid, + frames_mask: torch.Tensor, + pred_positions: torch.Tensor, + target_positions: torch.Tensor, + positions_mask: torch.Tensor, + length_scale: float, + l1_clamp_distance: Optional[float] = None, + eps=1e-8, +) -> torch.Tensor: + """ + Computes FAPE loss. + + Args: + pred_frames: + [*, N_frames] Rigid object of predicted frames + target_frames: + [*, N_frames] Rigid object of ground truth frames + frames_mask: + [*, N_frames] binary mask for the frames + pred_positions: + [*, N_pts, 3] predicted atom positions + target_positions: + [*, N_pts, 3] ground truth positions + positions_mask: + [*, N_pts] positions mask + length_scale: + Length scale by which the loss is divided + l1_clamp_distance: + Cutoff above which distance errors are disregarded + eps: + Small value used to regularize denominators + Returns: + [*] loss tensor + """ + # [*, N_frames, N_pts, 3] + local_pred_pos = pred_frames.invert()[..., None].apply(pred_positions[..., None, :, :],) + local_target_pos = target_frames.invert()[..., None].apply(target_positions[..., None, :, :],) + + error_dist = torch.sqrt(torch.sum((local_pred_pos - local_target_pos)**2, dim=-1) + eps) + + if l1_clamp_distance is not None: + error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance) + + normed_error = error_dist / length_scale + normed_error = normed_error * frames_mask[..., None] + normed_error = normed_error * positions_mask[..., None, :] + + # FP16-friendly averaging. 
Roughly equivalent to: + # + # norm_factor = ( + # torch.sum(frames_mask, dim=-1) * + # torch.sum(positions_mask, dim=-1) + # ) + # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor) + # + # ("roughly" because eps is necessarily duplicated in the latter) + normed_error = torch.sum(normed_error, dim=-1) + normed_error = (normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]) + normed_error = torch.sum(normed_error, dim=-1) + normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1)) + + return normed_error + + +def backbone_loss( + backbone_rigid_tensor: torch.Tensor, + backbone_rigid_mask: torch.Tensor, + traj: torch.Tensor, + use_clamped_fape: Optional[torch.Tensor] = None, + clamp_distance: float = 10.0, + loss_unit_distance: float = 10.0, + eps: float = 1e-4, + **kwargs, +) -> torch.Tensor: + pred_aff = Rigid.from_tensor_7(traj) + pred_aff = Rigid( + Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None), + pred_aff.get_trans(), + ) + + # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of + # backbone tensor, normalizes it, and then turns it back to a rotation + # matrix. To avoid a potentially numerically unstable rotation matrix + # to quaternion conversion, we just use the original rotation matrix + # outright. This one hasn't been composed a bunch of times, though, so + # it might be fine. + gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor) + + fape_loss = compute_fape( + pred_aff, + gt_aff[None], + backbone_rigid_mask[None], + pred_aff.get_trans(), + gt_aff[None].get_trans(), + backbone_rigid_mask[None], + l1_clamp_distance=clamp_distance, + length_scale=loss_unit_distance, + eps=eps, + ) + if use_clamped_fape is not None: + unclamped_fape_loss = compute_fape( + pred_aff, + gt_aff[None], + backbone_rigid_mask[None], + pred_aff.get_trans(), + gt_aff[None].get_trans(), + backbone_rigid_mask[None], + l1_clamp_distance=None, + length_scale=loss_unit_distance, + eps=eps, + ) + + fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (1 - use_clamped_fape) + + # Average over the batch dimension + fape_loss = torch.mean(fape_loss) + + return fape_loss + + +def sidechain_loss( + sidechain_frames: torch.Tensor, + sidechain_atom_pos: torch.Tensor, + rigidgroups_gt_frames: torch.Tensor, + rigidgroups_alt_gt_frames: torch.Tensor, + rigidgroups_gt_exists: torch.Tensor, + renamed_atom14_gt_positions: torch.Tensor, + renamed_atom14_gt_exists: torch.Tensor, + alt_naming_is_better: torch.Tensor, + clamp_distance: float = 10.0, + length_scale: float = 10.0, + eps: float = 1e-4, + **kwargs, +) -> torch.Tensor: + renamed_gt_frames = (1.0 - alt_naming_is_better[..., None, None, None] + ) * rigidgroups_gt_frames + alt_naming_is_better[..., None, None, + None] * rigidgroups_alt_gt_frames + + # Steamroll the inputs + sidechain_frames = sidechain_frames[-1] + batch_dims = sidechain_frames.shape[:-4] + sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4) + sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames) + renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4) + renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames) + rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1) + sidechain_atom_pos = sidechain_atom_pos[-1] + sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3) + renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(*batch_dims, -1, 3) + renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1) + + fape = compute_fape( + 
sidechain_frames, + renamed_gt_frames, + rigidgroups_gt_exists, + sidechain_atom_pos, + renamed_atom14_gt_positions, + renamed_atom14_gt_exists, + l1_clamp_distance=clamp_distance, + length_scale=length_scale, + eps=eps, + ) + + return fape + + +def fape_loss( + out: Dict[str, torch.Tensor], + batch: Dict[str, torch.Tensor], + config: ml_collections.ConfigDict, +) -> torch.Tensor: + bb_loss = backbone_loss( + traj=out["sm"]["frames"], + **{ + **batch, + **config.backbone + }, + ) + + sc_loss = sidechain_loss( + out["sm"]["sidechain_frames"], + out["sm"]["positions"], + **{ + **batch, + **config.sidechain + }, + ) + + loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def supervised_chi_loss( + angles_sin_cos: torch.Tensor, + unnormalized_angles_sin_cos: torch.Tensor, + aatype: torch.Tensor, + seq_mask: torch.Tensor, + chi_mask: torch.Tensor, + chi_angles_sin_cos: torch.Tensor, + chi_weight: float, + angle_norm_weight: float, + eps=1e-6, + **kwargs, +) -> torch.Tensor: + """ + Implements Algorithm 27 (torsionAngleLoss) + + Args: + angles_sin_cos: + [*, N, 7, 2] predicted angles + unnormalized_angles_sin_cos: + The same angles, but unnormalized + aatype: + [*, N] residue indices + seq_mask: + [*, N] sequence mask + chi_mask: + [*, N, 7] angle mask + chi_angles_sin_cos: + [*, N, 7, 2] ground truth angles + chi_weight: + Weight for the angle component of the loss + angle_norm_weight: + Weight for the normalization component of the loss + Returns: + [*] loss tensor + """ + pred_angles = angles_sin_cos[..., 3:, :] + residue_type_one_hot = torch.nn.functional.one_hot( + aatype, + residue_constants.restype_num + 1, + ) + chi_pi_periodic = torch.einsum( + "...ij,jk->ik", + residue_type_one_hot.type(angles_sin_cos.dtype), + angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic), + ) + + true_chi = chi_angles_sin_cos[None] + + shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1) + true_chi_shifted = shifted_mask * true_chi + sq_chi_error = torch.sum((true_chi - pred_angles)**2, dim=-1) + sq_chi_error_shifted = torch.sum((true_chi_shifted - pred_angles)**2, dim=-1) + sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted) + # The ol' switcheroo + sq_chi_error = sq_chi_error.permute(*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1) + sq_chi_loss = masked_mean(chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)) + + loss = chi_weight * sq_chi_loss + + angle_norm = torch.sqrt(torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps) + norm_error = torch.abs(angle_norm - 1.0) + norm_error = norm_error.permute(*range(len(norm_error.shape))[1:-2], 0, -2, -1) + angle_norm_loss = masked_mean(seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)) + + loss = loss + angle_norm_weight * angle_norm_loss + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def compute_plddt(logits: torch.Tensor) -> torch.Tensor: + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bounds = torch.arange(start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device) + probs = torch.nn.functional.softmax(logits, dim=-1) + pred_lddt_ca = torch.sum( + probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape), + dim=-1, + ) + return pred_lddt_ca * 100 + + +def lddt( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> 
torch.Tensor: + n = all_atom_mask.shape[-2] + dmat_true = torch.sqrt(eps + torch.sum( + (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])**2, + dim=-1, + )) + + dmat_pred = torch.sqrt(eps + torch.sum( + (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :])**2, + dim=-1, + )) + dists_to_score = ((dmat_true < cutoff) * all_atom_mask * permute_final_dims(all_atom_mask, (1, 0)) * + (1.0 - torch.eye(n, device=all_atom_mask.device))) + + dist_l1 = torch.abs(dmat_true - dmat_pred) + + score = ((dist_l1 < 0.5).type(dist_l1.dtype) + (dist_l1 < 1.0).type(dist_l1.dtype) + + (dist_l1 < 2.0).type(dist_l1.dtype) + (dist_l1 < 4.0).type(dist_l1.dtype)) + score = score * 0.25 + + dims = (-1,) if per_residue else (-2, -1) + norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims)) + score = norm * (eps + torch.sum(dists_to_score * score, dim=dims)) + + return score + + +def lddt_ca( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos:(ca_pos + 1)] # keep dim + + return lddt( + all_atom_pred_pos, + all_atom_positions, + all_atom_mask, + cutoff=cutoff, + eps=eps, + per_residue=per_residue, + ) + + +def lddt_loss( + logits: torch.Tensor, + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + resolution: torch.Tensor, + cutoff: float = 15.0, + no_bins: int = 50, + min_resolution: float = 0.1, + max_resolution: float = 3.0, + eps: float = 1e-10, + **kwargs, +) -> torch.Tensor: + n = all_atom_mask.shape[-2] + + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos:(ca_pos + 1)] # keep dim + + score = lddt(all_atom_pred_pos, all_atom_positions, all_atom_mask, cutoff=cutoff, eps=eps) + + score = score.detach() + + bin_index = torch.floor(score * no_bins).long() + bin_index = torch.clamp(bin_index, max=(no_bins - 1)) + lddt_ca_one_hot = torch.nn.functional.one_hot(bin_index, num_classes=no_bins) + + errors = softmax_cross_entropy(logits, lddt_ca_one_hot) + all_atom_mask = all_atom_mask.squeeze(-1) + loss = torch.sum(errors * all_atom_mask, dim=-1) / (eps + torch.sum(all_atom_mask, dim=-1)) + + loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution)) + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def distogram_loss( + logits, + pseudo_beta, + pseudo_beta_mask, + min_bin=2.3125, + max_bin=21.6875, + no_bins=64, + eps=1e-6, + **kwargs, +): + boundaries = torch.linspace( + min_bin, + max_bin, + no_bins - 1, + device=logits.device, + ) + boundaries = boundaries**2 + + dists = torch.sum( + (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :])**2, + dim=-1, + keepdims=True, + ) + + true_bins = torch.sum(dists > boundaries, dim=-1) + + errors = softmax_cross_entropy( + logits, + torch.nn.functional.one_hot(true_bins, no_bins), + ) + + square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :] + + # FP16-friendly sum. 
Equivalent to: + # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) / + # (eps + torch.sum(square_mask, dim=(-1, -2)))) + denom = eps + torch.sum(square_mask, dim=(-1, -2)) + mean = errors * square_mask + mean = torch.sum(mean, dim=-1) + mean = mean / denom[..., None] + mean = torch.sum(mean, dim=-1) + + # Average over the batch dimensions + mean = torch.mean(mean) + + return mean + + +def _calculate_bin_centers(boundaries: torch.Tensor): + step = boundaries[1] - boundaries[0] + bin_centers = boundaries + step / 2 + bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: torch.Tensor, + aligned_distance_error_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + return ( + torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), + bin_centers[-1], + ) + + +def compute_predicted_aligned_error( + logits: torch.Tensor, + max_bin: int = 31, + no_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [*, num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + max_bin: Maximum bin value + no_bins: Number of bins + Returns: + aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [*, num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: [*] the maximum predicted error possible. + """ + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + + aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1) + ( + predicted_aligned_error, + max_predicted_aligned_error, + ) = _calculate_expected_aligned_error( + alignment_confidence_breaks=boundaries, + aligned_distance_error_probs=aligned_confidence_probs, + ) + + return { + "aligned_confidence_probs": aligned_confidence_probs, + "predicted_aligned_error": predicted_aligned_error, + "max_predicted_aligned_error": max_predicted_aligned_error, + } + + +def compute_tm( + logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + max_bin: int = 31, + no_bins: int = 64, + eps: float = 1e-8, + **kwargs, +) -> torch.Tensor: + if residue_weights is None: + residue_weights = logits.new_ones(logits.shape[-2]) + + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + + bin_centers = _calculate_bin_centers(boundaries) + torch.sum(residue_weights) + n = logits.shape[-2] + clipped_n = max(n, 19) + + d0 = 1.24 * (clipped_n - 15)**(1.0 / 3) - 1.8 + + probs = torch.nn.functional.softmax(logits, dim=-1) + + tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2)) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + normed_residue_mask = residue_weights / (eps + residue_weights.sum()) + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + weighted = per_alignment * residue_weights + argmax = (weighted == torch.max(weighted)).nonzero()[0] + return per_alignment[tuple(argmax)] + + +def tm_loss( + logits, + final_affine_tensor, + backbone_rigid_tensor, + backbone_rigid_mask, + resolution, + max_bin=31, + no_bins=64, + min_resolution: float = 0.1, + max_resolution: float = 3.0, + eps=1e-8, + **kwargs, +): + pred_affine = Rigid.from_tensor_7(final_affine_tensor) + backbone_rigid = 
Rigid.from_tensor_4x4(backbone_rigid_tensor) + + def _points(affine): + pts = affine.get_trans()[..., None, :, :] + return affine.invert()[..., None].apply(pts) + + sq_diff = torch.sum((_points(pred_affine) - _points(backbone_rigid))**2, dim=-1) + + sq_diff = sq_diff.detach() + + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + boundaries = boundaries**2 + true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1) + + errors = softmax_cross_entropy(logits, torch.nn.functional.one_hot(true_bins, no_bins)) + + square_mask = (backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]) + + loss = torch.sum(errors * square_mask, dim=-1) + scale = 0.5 # hack to help FP16 training along + denom = eps + torch.sum(scale * square_mask, dim=(-1, -2)) + loss = loss / denom[..., None] + loss = torch.sum(loss, dim=-1) + loss = loss * scale + + loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution)) + + # Average over the loss dimension + loss = torch.mean(loss) + + return loss + + +def between_residue_bond_loss( + pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3) + pred_atom_mask: torch.Tensor, # (*, N, 37/14) + residue_index: torch.Tensor, # (*, N) + aatype: torch.Tensor, # (*, N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0, + eps=1e-6, +) -> Dict[str, torch.Tensor]: + """Flat-bottom loss to penalize structural violations between residues. + + This is a loss penalizing any violation of the geometry around the peptide + bond between consecutive amino acids. This loss corresponds to + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + aatype: Amino acid type of given residue + tolerance_factor_soft: soft tolerance factor measured in standard deviations + of pdb distributions + tolerance_factor_hard: hard tolerance factor measured in standard deviations + of pdb distributions + + Returns: + Dict containing: + * 'c_n_loss_mean': Loss for peptide bond length violations + * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned + by CA, C, N + * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned + by C, N, CA + * 'per_residue_loss_sum': sum of all losses for each residue + * 'per_residue_violation_mask': mask denoting all residues with violation + present. + """ + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[..., :-1, 1, :] + this_ca_mask = pred_atom_mask[..., :-1, 1] + this_c_pos = pred_atom_positions[..., :-1, 2, :] + this_c_mask = pred_atom_mask[..., :-1, 2] + next_n_pos = pred_atom_positions[..., 1:, 0, :] + next_n_mask = pred_atom_mask[..., 1:, 0] + next_ca_pos = pred_atom_positions[..., 1:, 1, :] + next_ca_mask = pred_atom_mask[..., 1:, 1] + has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 + + # Compute loss for the C--N bond. + c_n_bond_length = torch.sqrt(eps + torch.sum((this_c_pos - next_n_pos)**2, dim=-1)) + + # The C-N bond to proline has slightly different length because of the ring. 
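+    # (Roughly 1.33 A for a generic residue versus roughly 1.34 A when the
+    # next residue is proline; index 0 of the constant arrays below holds the
+    # generic value and index 1 the proline value.)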
+ next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"] + gt_length = (~next_is_proline) * residue_constants.between_res_bond_length_c_n[ + 0] + next_is_proline * residue_constants.between_res_bond_length_c_n[1] + gt_stddev = (~next_is_proline) * residue_constants.between_res_bond_length_stddev_c_n[ + 0] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1] + c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length)**2) + c_n_loss_per_residue = torch.nn.functional.relu(c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) + c_n_violation_mask = mask * (c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. + ca_c_bond_length = torch.sqrt(eps + torch.sum((this_ca_pos - this_c_pos)**2, dim=-1)) + n_ca_bond_length = torch.sqrt(eps + torch.sum((next_n_pos - next_ca_pos)**2, dim=-1)) + + c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None] + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None] + n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None] + + ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = torch.sqrt(eps + (ca_c_n_cos_angle - gt_angle)**2) + ca_c_n_loss_per_residue = torch.nn.functional.relu(ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = torch.sqrt(eps + torch.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = torch.nn.functional.relu(c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) + c_n_ca_violation_mask = mask * (c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + per_residue_loss_sum = (c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue) + per_residue_loss_sum = 0.5 * (torch.nn.functional.pad(per_residue_loss_sum, (0, 1)) + + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))) + + # Compute hard violations. 
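+    # A residue is flagged when any of the three geometry checks (C-N bond
+    # length, CA-C-N angle, C-N-CA angle) exceeds the hard tolerance; the
+    # padding below attributes each inter-residue violation to both
+    # neighbouring residues.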
+ violation_mask = torch.max( + torch.stack( + [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask], + dim=-2, + ), + dim=-2, + )[0] + violation_mask = torch.maximum( + torch.nn.functional.pad(violation_mask, (0, 1)), + torch.nn.functional.pad(violation_mask, (1, 0)), + ) + + return { + "c_n_loss_mean": c_n_loss, + "ca_c_n_loss_mean": ca_c_n_loss, + "c_n_ca_loss_mean": c_n_ca_loss, + "per_residue_loss_sum": per_residue_loss_sum, + "per_residue_violation_mask": violation_mask, + } + + +def between_residue_clash_loss( + atom14_pred_positions: torch.Tensor, + atom14_atom_exists: torch.Tensor, + atom14_atom_radius: torch.Tensor, + residue_index: torch.Tensor, + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5, + eps=1e-10, +) -> Dict[str, torch.Tensor]: + """Loss to penalize steric clashes between residues. + + This is a loss penalizing any steric clashes due to non bonded atoms in + different peptides coming too close. This loss corresponds to the part with + different residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_atom_radius: Van der Waals radius for each atom. + residue_index: Residue index for given amino acid. + overlap_tolerance_soft: Soft tolerance factor. + overlap_tolerance_hard: Hard tolerance factor. + + Returns: + Dict containing: + * 'mean_loss': average clash loss + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (N, 14) + """ + fp_type = atom14_pred_positions.dtype + + # Create the distance matrix. + # (N, N, 14, 14) + dists = torch.sqrt(eps + torch.sum( + (atom14_pred_positions[..., :, None, :, None, :] - atom14_pred_positions[..., None, :, None, :, :])**2, + dim=-1, + )) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = (atom14_atom_exists[..., :, None, :, None] * atom14_atom_exists[..., None, :, None, :]).type(fp_type) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask = dists_mask * (residue_index[..., :, None, None, None] < residue_index[..., None, :, None, None]) + + # Backbone C--N bond between subsequent residues is no clash. + c_one_hot = torch.nn.functional.one_hot(residue_index.new_tensor(2), num_classes=14) + c_one_hot = c_one_hot.reshape(*((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape) + c_one_hot = c_one_hot.type(fp_type) + n_one_hot = torch.nn.functional.one_hot(residue_index.new_tensor(0), num_classes=14) + n_one_hot = n_one_hot.reshape(*((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape) + n_one_hot = n_one_hot.type(fp_type) + + neighbour_mask = (residue_index[..., :, None, None, None] + 1) == residue_index[..., None, :, None, None] + c_n_bonds = (neighbour_mask * c_one_hot[..., None, None, :, None] * n_one_hot[..., None, None, None, :]) + dists_mask = dists_mask * (1.0 - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. 
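+    # "SG" is the cysteine sulfur atom in the atom14 encoding; masking SG-SG
+    # pairs lets bonded cysteines sit closer than the sum of their Van der
+    # Waals radii without incurring a clash penalty.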
+ cys = residue_constants.restype_name_to_atom14_names["CYS"] + cys_sg_idx = cys.index("SG") + cys_sg_idx = residue_index.new_tensor(cys_sg_idx) + cys_sg_idx = cys_sg_idx.reshape(*((1,) * len(residue_index.shape[:-1])), 1).squeeze(-1) + cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = (cys_sg_one_hot[..., None, None, :, None] * cys_sg_one_hot[..., None, None, None, :]) + dists_mask = dists_mask * (1.0 - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * (atom14_atom_radius[..., :, None, :, None] + + atom14_atom_radius[..., None, :, None, :]) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * torch.nn.functional.relu(dists_lower_bound - overlap_tolerance_soft - dists) + + # Compute the mean loss. + # shape () + mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask)) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(dists_to_low_error, axis=(-3, -1)) + + # Compute the hard clash mask. + # shape (N, N, 14, 14) + clash_mask = dists_mask * (dists < (dists_lower_bound - overlap_tolerance_hard)) + + # Compute the per atom clash. + # shape (N, 14) + per_atom_clash_mask = torch.maximum( + torch.amax(clash_mask, axis=(-4, -2)), + torch.amax(clash_mask, axis=(-3, -1)), + ) + + return { + "mean_loss": mean_loss, # shape () + "per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14) + "per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14) + } + + +def within_residue_violations( + atom14_pred_positions: torch.Tensor, + atom14_atom_exists: torch.Tensor, + atom14_dists_lower_bound: torch.Tensor, + atom14_dists_upper_bound: torch.Tensor, + tighten_bounds_for_loss=0.0, + eps=1e-10, +) -> Dict[str, torch.Tensor]: + """Loss to penalize steric clashes within residues. + + This is a loss penalizing any steric violations or clashes of non-bonded atoms + in a given peptide. This loss corresponds to the part with + the same residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions ([*, N, 14, 3]): + Predicted positions of atoms in global prediction frame. + atom14_atom_exists ([*, N, 14]): + Mask denoting whether atom at positions exists for given + amino acid type + atom14_dists_lower_bound ([*, N, 14]): + Lower bound on allowed distances. + atom14_dists_upper_bound ([*, N, 14]): + Upper bound on allowed distances + tighten_bounds_for_loss ([*, N]): + Extra factor to tighten loss + + Returns: + Dict containing: + * 'per_atom_loss_sum' ([*, N, 14]): + sum of all clash losses per atom, shape + * 'per_atom_clash_mask' ([*, N, 14]): + mask whether atom clashes with any other atom shape + """ + # Compute the mask for each residue. + dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None] + dists_masks = dists_masks.reshape(*((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape) + dists_masks = (atom14_atom_exists[..., :, :, None] * atom14_atom_exists[..., :, None, :] * dists_masks) + + # Distance matrix + dists = torch.sqrt(eps + torch.sum( + (atom14_pred_positions[..., :, :, None, :] - atom14_pred_positions[..., :, None, :, :])**2, + dim=-1, + )) + + # Compute the loss. 
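+    # Flat-bottom penalty: zero anywhere inside [lower_bound, upper_bound]
+    # and growing linearly outside it; tighten_bounds_for_loss narrows the
+    # tolerated interval from both ends.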
+ dists_to_low_error = torch.nn.functional.relu(atom14_dists_lower_bound + tighten_bounds_for_loss - dists) + dists_to_high_error = torch.nn.functional.relu(dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)) + loss = dists_masks * (dists_to_low_error + dists_to_high_error) + + # Compute the per atom loss sum. + per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1) + + # Compute the violations mask. + violations = dists_masks * ((dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)) + + # Compute the per atom violations. + per_atom_violations = torch.maximum(torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]) + + return { + "per_atom_loss_sum": per_atom_loss_sum, + "per_atom_violations": per_atom_violations, + } + + +def find_structural_violations( + batch: Dict[str, torch.Tensor], + atom14_pred_positions: torch.Tensor, + violation_tolerance_factor: float, + clash_overlap_tolerance: float, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes several checks for structural violations.""" + + # Compute between residue backbone violations of bonds and angles. + connection_violations = between_residue_bond_loss( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch["atom14_atom_exists"], + residue_index=batch["residue_index"], + aatype=batch["aatype"], + tolerance_factor_soft=violation_tolerance_factor, + tolerance_factor_hard=violation_tolerance_factor, + ) + + # Compute the Van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # Shape: (N, 14). + atomtype_radius = [residue_constants.van_der_waals_radius[name[0]] for name in residue_constants.atom_types] + atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius) + atom14_atom_radius = (batch["atom14_atom_exists"] * atomtype_radius[batch["residx_atom14_to_atom37"]]) + + # Compute the between residue clash loss. + between_residue_clashes = between_residue_clash_loss( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch["atom14_atom_exists"], + atom14_atom_radius=atom14_atom_radius, + residue_index=batch["residue_index"], + overlap_tolerance_soft=clash_overlap_tolerance, + overlap_tolerance_hard=clash_overlap_tolerance, + ) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). + restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=clash_overlap_tolerance, + bond_length_tolerance_factor=violation_tolerance_factor, + ) + atom14_atom_exists = batch["atom14_atom_exists"] + atom14_dists_lower_bound = atom14_pred_positions.new_tensor(restype_atom14_bounds["lower_bound"])[batch["aatype"]] + atom14_dists_upper_bound = atom14_pred_positions.new_tensor(restype_atom14_bounds["upper_bound"])[batch["aatype"]] + residue_violations = within_residue_violations( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch["atom14_atom_exists"], + atom14_dists_lower_bound=atom14_dists_lower_bound, + atom14_dists_upper_bound=atom14_dists_upper_bound, + tighten_bounds_for_loss=0.0, + ) + + # Combine them to a single per-residue violation mask (used later for LDDT). 
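+    # The two per-atom masks are first reduced over the atom dimension, then
+    # combined with the bond-geometry mask via an elementwise max, so a
+    # residue is marked if any of the three checks fails for any of its atoms.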
+ per_residue_violations_mask = torch.max( + torch.stack( + [ + connection_violations["per_residue_violation_mask"], + torch.max(between_residue_clashes["per_atom_clash_mask"], dim=-1)[0], + torch.max(residue_violations["per_atom_violations"], dim=-1)[0], + ], + dim=-1, + ), + dim=-1, + )[0] + + return { + "between_residues": { + "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # () + "angles_ca_c_n_loss_mean": connection_violations["ca_c_n_loss_mean"], # () + "angles_c_n_ca_loss_mean": connection_violations["c_n_ca_loss_mean"], # () + "connections_per_residue_loss_sum": connection_violations["per_residue_loss_sum"], # (N) + "connections_per_residue_violation_mask": connection_violations["per_residue_violation_mask"], # (N) + "clashes_mean_loss": between_residue_clashes["mean_loss"], # () + "clashes_per_atom_loss_sum": between_residue_clashes["per_atom_loss_sum"], # (N, 14) + "clashes_per_atom_clash_mask": between_residue_clashes["per_atom_clash_mask"], # (N, 14) + }, + "within_residues": { + "per_atom_loss_sum": residue_violations["per_atom_loss_sum"], # (N, 14) + "per_atom_violations": residue_violations["per_atom_violations"], # (N, 14), + }, + "total_per_residue_violations_mask": per_residue_violations_mask, # (N) + } + + +def find_structural_violations_np( + batch: Dict[str, np.ndarray], + atom14_pred_positions: np.ndarray, + config: ml_collections.ConfigDict, +) -> Dict[str, np.ndarray]: + to_tensor = lambda x: torch.tensor(x) + batch = tree_map(to_tensor, batch, np.ndarray) + atom14_pred_positions = to_tensor(atom14_pred_positions) + + out = find_structural_violations(batch, atom14_pred_positions, **config) + + to_np = lambda x: np.array(x) + np_out = tensor_tree_map(to_np, out) + + return np_out + + +def extreme_ca_ca_distance_violations( + pred_atom_positions: torch.Tensor, # (N, 37(14), 3) + pred_atom_mask: torch.Tensor, # (N, 37(14)) + residue_index: torch.Tensor, # (N) + max_angstrom_tolerance=1.5, + eps=1e-6, +) -> torch.Tensor: + """Counts residues whose Ca is a large distance from its neighbour. + + Measures the fraction of CA-CA pairs between consecutive amino acids that are + more than 'max_angstrom_tolerance' apart. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + max_angstrom_tolerance: Maximum distance allowed to not count as violation. + Returns: + Fraction of consecutive CA-CA pairs with violation. 
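+
+    For example, in a hypothetical 100-residue chain where three of the 99
+    consecutive CA-CA distances exceed the ideal value (roughly 3.8 A) by
+    more than max_angstrom_tolerance, the returned fraction is about 3 / 99.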
+ """ + this_ca_pos = pred_atom_positions[..., :-1, 1, :] + this_ca_mask = pred_atom_mask[..., :-1, 1] + next_ca_pos = pred_atom_positions[..., 1:, 1, :] + next_ca_mask = pred_atom_mask[..., 1:, 1] + has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 + ca_ca_distance = torch.sqrt(eps + torch.sum((this_ca_pos - next_ca_pos)**2, dim=-1)) + violations = (ca_ca_distance - residue_constants.ca_ca) > max_angstrom_tolerance + mask = this_ca_mask * next_ca_mask * has_no_gap_mask + mean = masked_mean(mask, violations, -1) + return mean + + +def compute_violation_metrics( + batch: Dict[str, torch.Tensor], + atom14_pred_positions: torch.Tensor, # (N, 14, 3) + violations: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + """Compute several metrics to assess the structural violations.""" + ret = {} + extreme_ca_ca_violations = extreme_ca_ca_distance_violations( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch["atom14_atom_exists"], + residue_index=batch["residue_index"], + ) + ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations + ret["violations_between_residue_bond"] = masked_mean( + batch["seq_mask"], + violations["between_residues"]["connections_per_residue_violation_mask"], + dim=-1, + ) + ret["violations_between_residue_clash"] = masked_mean( + mask=batch["seq_mask"], + value=torch.max( + violations["between_residues"]["clashes_per_atom_clash_mask"], + dim=-1, + )[0], + dim=-1, + ) + ret["violations_within_residue"] = masked_mean( + mask=batch["seq_mask"], + value=torch.max(violations["within_residues"]["per_atom_violations"], dim=-1)[0], + dim=-1, + ) + ret["violations_per_residue"] = masked_mean( + mask=batch["seq_mask"], + value=violations["total_per_residue_violations_mask"], + dim=-1, + ) + return ret + + +def compute_violation_metrics_np( + batch: Dict[str, np.ndarray], + atom14_pred_positions: np.ndarray, + violations: Dict[str, np.ndarray], +) -> Dict[str, np.ndarray]: + to_tensor = lambda x: torch.tensor(x) + batch = tree_map(to_tensor, batch, np.ndarray) + atom14_pred_positions = to_tensor(atom14_pred_positions) + violations = tree_map(to_tensor, violations, np.ndarray) + + out = compute_violation_metrics(batch, atom14_pred_positions, violations) + + to_np = lambda x: np.array(x) + return tree_map(to_np, out, torch.Tensor) + + +def violation_loss( + violations: Dict[str, torch.Tensor], + atom14_atom_exists: torch.Tensor, + eps=1e-6, + **kwargs, +) -> torch.Tensor: + num_atoms = torch.sum(atom14_atom_exists) + l_clash = torch.sum(violations["between_residues"]["clashes_per_atom_loss_sum"] + + violations["within_residues"]["per_atom_loss_sum"]) + l_clash = l_clash / (eps + num_atoms) + loss = (violations["between_residues"]["bonds_c_n_loss_mean"] + + violations["between_residues"]["angles_ca_c_n_loss_mean"] + + violations["between_residues"]["angles_c_n_ca_loss_mean"] + l_clash) + + return loss + + +def compute_renamed_ground_truth( + batch: Dict[str, torch.Tensor], + atom14_pred_positions: torch.Tensor, + eps=1e-10, +) -> Dict[str, torch.Tensor]: + """ + Find optimal renaming of ground truth based on the predicted positions. + + Alg. 26 "renameSymmetricGroundTruthAtoms" + + This renamed ground truth is then used for all losses, + such that each loss moves the atoms in the same direction. + + Args: + batch: Dictionary containing: + * atom14_gt_positions: Ground truth positions. + * atom14_alt_gt_positions: Ground truth positions with renaming swaps. 
+ * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by + renaming swaps. + * atom14_gt_exists: Mask for which atoms exist in ground truth. + * atom14_alt_gt_exists: Mask for which atoms exist in ground truth + after renaming. + * atom14_atom_exists: Mask for whether each atom is part of the given + amino acid type. + atom14_pred_positions: Array of atom positions in global frame with shape + Returns: + Dictionary containing: + alt_naming_is_better: Array with 1.0 where alternative swap is better. + renamed_atom14_gt_positions: Array of optimal ground truth positions + after renaming swaps are performed. + renamed_atom14_gt_exists: Mask after renaming swap is performed. + """ + + pred_dists = torch.sqrt(eps + torch.sum( + (atom14_pred_positions[..., None, :, None, :] - atom14_pred_positions[..., None, :, None, :, :])**2, + dim=-1, + )) + + atom14_gt_positions = batch["atom14_gt_positions"] + gt_dists = torch.sqrt(eps + torch.sum( + (atom14_gt_positions[..., None, :, None, :] - atom14_gt_positions[..., None, :, None, :, :])**2, + dim=-1, + )) + + atom14_alt_gt_positions = batch["atom14_alt_gt_positions"] + alt_gt_dists = torch.sqrt(eps + torch.sum( + (atom14_alt_gt_positions[..., None, :, None, :] - atom14_alt_gt_positions[..., None, :, None, :, :])**2, + dim=-1, + )) + + lddt = torch.sqrt(eps + (pred_dists - gt_dists)**2) + alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists)**2) + + atom14_gt_exists = batch["atom14_gt_exists"] + atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"] + mask = (atom14_gt_exists[..., None, :, None] * atom14_atom_is_ambiguous[..., None, :, None] * + atom14_gt_exists[..., None, :, None, :] * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])) + + per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3)) + alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3)) + + fp_type = atom14_pred_positions.dtype + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type) + + renamed_atom14_gt_positions = (1.0 - alt_naming_is_better[ + ..., None, None]) * atom14_gt_positions + alt_naming_is_better[..., None, None] * atom14_alt_gt_positions + + renamed_atom14_gt_mask = (1.0 - alt_naming_is_better[..., None] + ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"] + + return { + "alt_naming_is_better": alt_naming_is_better, + "renamed_atom14_gt_positions": renamed_atom14_gt_positions, + "renamed_atom14_gt_exists": renamed_atom14_gt_mask, + } + + +def experimentally_resolved_loss( + logits: torch.Tensor, + atom37_atom_exists: torch.Tensor, + all_atom_mask: torch.Tensor, + resolution: torch.Tensor, + min_resolution: float, + max_resolution: float, + eps: float = 1e-8, + **kwargs, +) -> torch.Tensor: + errors = sigmoid_cross_entropy(logits, all_atom_mask) + loss = torch.sum(errors * atom37_atom_exists, dim=-1) + loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2))) + loss = torch.sum(loss, dim=-1) + + loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution)) + + loss = torch.mean(loss) + + return loss + + +def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs): + """ + Computes BERT-style masked MSA loss. Implements subsection 1.9.9. + + Args: + logits: [*, N_seq, N_res, 23] predicted residue distribution + true_msa: [*, N_seq, N_res] true MSA + bert_mask: [*, N_seq, N_res] MSA mask + Returns: + Masked MSA loss + """ + errors = softmax_cross_entropy(logits, torch.nn.functional.one_hot(true_msa, num_classes=23)) + + # FP16-friendly averaging. 
Equivalent to: + # loss = ( + # torch.sum(errors * bert_mask, dim=(-1, -2)) / + # (eps + torch.sum(bert_mask, dim=(-1, -2))) + # ) + loss = errors * bert_mask + loss = torch.sum(loss, dim=-1) + scale = 0.5 + denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2)) + loss = loss / denom[..., None] + loss = torch.sum(loss, dim=-1) + loss = loss * scale + + loss = torch.mean(loss) + + return loss + + +def compute_drmsd(structure_1, structure_2, mask=None): + if (mask is not None): + structure_1 = structure_1 * mask[..., None] + structure_2 = structure_2 * mask[..., None] + + d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :] + d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :] + + d1 = d1**2 + d2 = d2**2 + + d1 = torch.sqrt(torch.sum(d1, dim=-1)) + d2 = torch.sqrt(torch.sum(d2, dim=-1)) + + drmsd = d1 - d2 + drmsd = drmsd**2 + drmsd = torch.sum(drmsd, dim=(-1, -2)) + n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1) + drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.) + drmsd = torch.sqrt(drmsd) + + return drmsd + + +def compute_drmsd_np(structure_1, structure_2, mask=None): + structure_1 = torch.tensor(structure_1) + structure_2 = torch.tensor(structure_2) + if (mask is not None): + mask = torch.tensor(mask) + + return compute_drmsd(structure_1, structure_2, mask) + + +class AlphaFoldLoss(nn.Module): + """Aggregation of the various losses described in the supplement""" + + def __init__(self, config): + super(AlphaFoldLoss, self).__init__() + self.config = config + + def forward(self, out, batch, _return_breakdown=False): + if "violation" not in out.keys(): + out["violation"] = find_structural_violations( + batch, + out["sm"]["positions"][-1], + **self.config.violation, + ) + + if "renamed_atom14_gt_positions" not in out.keys(): + batch.update(compute_renamed_ground_truth( + batch, + out["sm"]["positions"][-1], + )) + + loss_fns = { + "distogram": + lambda: distogram_loss( + logits=out["distogram_logits"], + **{ + **batch, + **self.config.distogram + }, + ), + "experimentally_resolved": + lambda: experimentally_resolved_loss( + logits=out["experimentally_resolved_logits"], + **{ + **batch, + **self.config.experimentally_resolved + }, + ), + "fape": + lambda: fape_loss( + out, + batch, + self.config.fape, + ), + "lddt": + lambda: lddt_loss( + logits=out["lddt_logits"], + all_atom_pred_pos=out["final_atom_positions"], + **{ + **batch, + **self.config.lddt + }, + ), + "masked_msa": + lambda: masked_msa_loss( + logits=out["masked_msa_logits"], + **{ + **batch, + **self.config.masked_msa + }, + ), + "supervised_chi": + lambda: supervised_chi_loss( + out["sm"]["angles"], + out["sm"]["unnormalized_angles"], + **{ + **batch, + **self.config.supervised_chi + }, + ), + "violation": + lambda: violation_loss( + out["violation"], + **batch, + ), + } + + if (self.config.tm.enabled): + loss_fns["tm"] = lambda: tm_loss( + logits=out["tm_logits"], + **{ + **batch, + **out, + **self.config.tm + }, + ) + + cum_loss = 0. + losses = {} + for loss_name, loss_fn in loss_fns.items(): + weight = self.config[loss_name].weight + loss = loss_fn() + if (torch.isnan(loss) or torch.isinf(loss)): + logging.warning(f"{loss_name} loss is NaN. Skipping...") + loss = loss.new_tensor(0., requires_grad=True) + cum_loss = cum_loss + weight * loss + losses[loss_name] = loss.detach().clone() + + losses["unscaled_loss"] = cum_loss.detach().clone() + + # Scale the loss by the square root of the minimum of the crop size and + # the (average) sequence length. See subsection 1.9. 
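+        # For example (hypothetical numbers), with an average sequence length
+        # of 400 and a 256-residue crop, the multiplier is
+        # sqrt(min(400, 256)) = 16.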
+ seq_len = torch.mean(batch["seq_length"].float()) + crop_len = batch["aatype"].shape[-1] + cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len)) + + losses["loss"] = cum_loss.detach().clone() + + if (not _return_breakdown): + return cum_loss + + return cum_loss, losses diff --git a/tests/test_autochunk/origin_openfold/utils/tensor_utils.py b/tests/test_autochunk/origin_openfold/utils/tensor_utils.py new file mode 100644 index 000000000000..6ef3674c7226 --- /dev/null +++ b/tests/test_autochunk/origin_openfold/utils/tensor_utils.py @@ -0,0 +1,384 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def masked_mean(mask, value, dim, eps=1e-4): + mask = mask.expand(*value.shape) + return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): + boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device) + dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3))**2, dim=-1)) + return torch.bucketize(dists, boundaries) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if type(v) is dict: + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def one_hot(x, v_bins): + reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) + diffs = x[..., None] - reshaped_bins + am = torch.argmin(torch.abs(diffs), dim=-1) + return nn.functional.one_hot(am, num_classes=len(v_bins)).float() + + +def batched_gather(data, inds, dim=0, no_batch_dims=0): + ranges = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + return data[ranges] + + +# With tree_map, a poor man's JAX tree_map +def dict_map(fn, dic, leaf_type): + new_dict = {} + for k, v in dic.items(): + if type(v) is dict: + new_dict[k] = dict_map(fn, v, leaf_type) + else: + new_dict[k] = tree_map(fn, v, leaf_type) + + return new_dict + + +def tree_map(fn, tree, leaf_type): + if isinstance(tree, dict): + return dict_map(fn, tree, leaf_type) + elif isinstance(tree, list): + return [tree_map(fn, x, leaf_type) for x in tree] + elif isinstance(tree, tuple): + 
return tuple([tree_map(fn, x, leaf_type) for x in tree]) + elif isinstance(tree, leaf_type): + return fn(tree) + else: + print(type(tree)) + raise ValueError("Not supported") + + +tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) + + +def _fetch_dims(tree): + shapes = [] + tree_type = type(tree) + if tree_type is dict: + for v in tree.values(): + shapes.extend(_fetch_dims(v)) + elif tree_type is list or tree_type is tuple: + for t in tree: + shapes.extend(_fetch_dims(t)) + elif tree_type is torch.Tensor: + shapes.append(tree.shape) + else: + raise ValueError("Not supported") + + return shapes + + +@torch.jit.ignore +def _flat_idx_to_idx( + flat_idx: int, + dims: Tuple[int], +) -> Tuple[int]: + idx = [] + for d in reversed(dims): + idx.append(flat_idx % d) + flat_idx = flat_idx // d + + return tuple(reversed(idx)) + + +@torch.jit.ignore +def _get_minimal_slice_set( + start: Sequence[int], + end: Sequence[int], + dims: int, + start_edges: Optional[Sequence[bool]] = None, + end_edges: Optional[Sequence[bool]] = None, +) -> Sequence[Tuple[int]]: + """ + Produces an ordered sequence of tensor slices that, when used in + sequence on a tensor with shape dims, yields tensors that contain every + leaf in the contiguous range [start, end]. Care is taken to yield a + short sequence of slices, and perhaps even the shortest possible (I'm + pretty sure it's the latter). + + end is INCLUSIVE. + """ + + # start_edges and end_edges both indicate whether, starting from any given + # dimension, the start/end index is at the top/bottom edge of the + # corresponding tensor, modeled as a tree + def reduce_edge_list(l): + tally = 1 + for i in range(len(l)): + reversed_idx = -1 * (i + 1) + l[reversed_idx] *= tally + tally = l[reversed_idx] + + if (start_edges is None): + start_edges = [s == 0 for s in start] + reduce_edge_list(start_edges) + if (end_edges is None): + end_edges = [e == (d - 1) for e, d in zip(end, dims)] + reduce_edge_list(end_edges) + + # Base cases. 
Either start/end are empty and we're done, or the final, + # one-dimensional tensor can be simply sliced + if (len(start) == 0): + return [tuple()] + elif (len(start) == 1): + return [(slice(start[0], end[0] + 1),)] + + slices = [] + path = [] + + # Dimensions common to start and end can be selected directly + for s, e in zip(start, end): + if (s == e): + path.append(slice(s, s + 1)) + else: + break + + path = tuple(path) + divergence_idx = len(path) + + # start == end, and we're done + if (divergence_idx == len(dims)): + return [tuple(path)] + + def upper(): + sdi = start[divergence_idx] + return [ + path + (slice(sdi, sdi + 1),) + s + for s in _get_minimal_slice_set(start[divergence_idx + 1:], [d - 1 for d in dims[divergence_idx + 1:]], + dims[divergence_idx + 1:], + start_edges=start_edges[divergence_idx + 1:], + end_edges=[1 for _ in end_edges[divergence_idx + 1:]]) + ] + + def lower(): + edi = end[divergence_idx] + return [ + path + (slice(edi, edi + 1),) + s for s in _get_minimal_slice_set( + [0 for _ in start[divergence_idx + 1:]], + end[divergence_idx + 1:], + dims[divergence_idx + 1:], + start_edges=[1 for _ in start_edges[divergence_idx + 1:]], + end_edges=end_edges[divergence_idx + 1:], + ) + ] + + # If both start and end are at the edges of the subtree rooted at + # divergence_idx, we can just select the whole subtree at once + if (start_edges[divergence_idx] and end_edges[divergence_idx]): + slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),)) + # If just start is at the edge, we can grab almost all of the subtree, + # treating only the ragged bottom edge as an edge case + elif (start_edges[divergence_idx]): + slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),)) + slices.extend(lower()) + # Analogous to the previous case, but the top is ragged this time + elif (end_edges[divergence_idx]): + slices.extend(upper()) + slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)) + # If both sides of the range are ragged, we need to handle both sides + # separately. If there's contiguous meat in between them, we can index it + # in one big chunk + else: + slices.extend(upper()) + middle_ground = end[divergence_idx] - start[divergence_idx] + if (middle_ground > 1): + slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)) + slices.extend(lower()) + + return [tuple(s) for s in slices] + + +@torch.jit.ignore +def _chunk_slice( + t: torch.Tensor, + flat_start: int, + flat_end: int, + no_batch_dims: int, +) -> torch.Tensor: + """ + Equivalent to + + t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] + + but without the need for the initial reshape call, which can be + memory-intensive in certain situations. The only reshape operations + in this function are performed on sub-tensors that scale with + (flat_end - flat_start), the chunk size. 
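+
+    For example (hypothetical shapes): with t of shape (4, 5, 16),
+    no_batch_dims=2, flat_start=3 and flat_end=11, the result has shape
+    (8, 16) and equals t.reshape(20, 16)[3:11].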
+ """ + + batch_dims = t.shape[:no_batch_dims] + start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) + # _get_minimal_slice_set is inclusive + end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) + + # Get an ordered list of slices to perform + slices = _get_minimal_slice_set( + start_idx, + end_idx, + batch_dims, + ) + + sliced_tensors = [t[s] for s in slices] + + return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]) + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + no_batch_dims: int, + low_mem: bool = False, +) -> Any: + """ + Implements the "chunking" procedure described in section 1.11.8. + + Layer outputs and inputs are assumed to be simple "pytrees," + consisting only of (arbitrarily nested) lists, tuples, and dicts with + torch.Tensor leaves. + + Args: + layer: + The layer to be applied chunk-wise + inputs: + A (non-nested) dictionary of keyworded inputs. All leaves must + be tensors and must share the same batch dimensions. + chunk_size: + The number of sub-batches per chunk. If multiple batch + dimensions are specified, a "sub-batch" is defined as a single + indexing of all batch dimensions simultaneously (s.t. the + number of sub-batches is the product of the batch dimensions). + no_batch_dims: + How many of the initial dimensions of each input tensor can + be considered batch dimensions. + low_mem: + Avoids flattening potentially large input tensors. Unnecessary + in most cases, and is ever so slightly slower than the default + setting. + Returns: + The reassembled output of the layer on the inputs. + """ + if not (len(inputs) > 0): + raise ValueError("Must provide at least one input") + + initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + def _prep_inputs(t): + # TODO: make this more memory efficient. 
This sucks + if (not low_mem): + if not sum(t.shape[:no_batch_dims]) == no_batch_dims: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + t = t.reshape(-1, *t.shape[no_batch_dims:]) + else: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + return t + + prepped_inputs = tensor_tree_map(_prep_inputs, inputs) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + + no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0) + + i = 0 + out = None + for _ in range(no_chunks): + # Chunk the input + if (not low_mem): + select_chunk = (lambda t: t[i:i + chunk_size] if t.shape[0] != 1 else t) + else: + select_chunk = (partial(_chunk_slice, + flat_start=i, + flat_end=min(flat_batch_dim, i + chunk_size), + no_batch_dims=len(orig_batch_dims))) + + chunks = tensor_tree_map(select_chunk, prepped_inputs) + + # Run the layer on the chunk + output_chunk = layer(**chunks) + + # Allocate space for the output + if out is None: + allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]) + out = tensor_tree_map(allocate, output_chunk) + + # Put the chunk in its pre-allocated space + out_type = type(output_chunk) + if out_type is dict: + + def assign(d1, d2): + for k, v in d1.items(): + if type(v) is dict: + assign(v, d2[k]) + else: + v[i:i + chunk_size] = d2[k] + + assign(out, output_chunk) + elif out_type is tuple: + for x1, x2 in zip(out, output_chunk): + x1[i:i + chunk_size] = x2 + elif out_type is torch.Tensor: + out[i:i + chunk_size] = output_chunk + else: + raise ValueError("Not supported") + + i += chunk_size + + reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) + out = tensor_tree_map(reshape, out) + + return out From 75713c7c4c48acf0babbe6105428df7936a682b2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 16:07:24 +0800 Subject: [PATCH 03/14] openfold can run now --- colossalai/autochunk/trace_flow.py | 92 +++++-------------- colossalai/autochunk/trace_indice.py | 62 +++++++------ colossalai/autochunk/utils.py | 48 ++++++---- colossalai/fx/profiler/opcount.py | 39 ++------ .../test_autochunk_openfold_codegen.py | 89 +++++++++++++----- 5 files changed, 165 insertions(+), 165 deletions(-) diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index 1e2e6dc1258b..ec1e012beb17 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -10,6 +10,7 @@ class TraceFlow(object): + def __init__(self, trace_indice: TraceIndice) -> None: self.trace_indice = trace_indice @@ -28,9 +29,7 @@ def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list) end_node_trace = self.trace_indice._find_trace_from_node(end_node) end_node_trace_source = end_node_trace["source"][end_dim] - sorted_source = sorted( - end_node_trace_source.items(), key=lambda d: d[0], reverse=True - ) + sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True) for node_idx, node_dim in sorted_source: if node_idx == start_node_idx and start_dim in node_dim: return True @@ -70,10 +69,8 @@ def _find_inherit_dim(self, input_node, input_dim, node): input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) node_trace_source = self.trace_indice._find_source_trace_from_node(node) for node_dim in range(len(get_node_shape(node))): - if ( - input_node_idx in node_trace_source[node_dim] - and input_dim[0] in node_trace_source[node_dim][input_node_idx] - ): + if 
(input_node_idx in node_trace_source[node_dim] + and input_dim[0] in node_trace_source[node_dim][input_node_idx]): return node_dim return None @@ -81,15 +78,11 @@ def check_index_duplicate(self, chunk_infos, return_dim=False): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim( - input_node, v, self.trace_indice.node_list[k] - ) + inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k]) if inherit_dim: input_dim_after_node[k] = inherit_dim - for node in self.trace_indice.node_list[ - chunk_infos["region"][0] : chunk_infos["region"][1] + 1 - ]: + for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]: if is_non_compute_node_except_placeholder(node): continue count = 0 @@ -159,9 +152,7 @@ def _assgin_single_node_flow( if arg_node in all_node_info: if all_node_info[arg_node]["chunk_dim"] != arg_dim: return False - all_node_info[arg_node]["fix_dim"] = list( - set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) - ) + all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)) # else add it to list else: all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} @@ -170,9 +161,7 @@ def _assgin_single_node_flow( return True def _get_all_node_info(self, end_dim, start_idx, end_idx): - cur_node_list = [ - self.trace_indice.node_list[end_idx] - ] # start from the last node + cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} while len(cur_node_list) > 0: @@ -183,12 +172,8 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx): cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] if cur_node_chunk_dim: - cur_node_compute = self.trace_indice._find_compute_trace_from_node( - cur_node - ) - cur_node_source = self.trace_indice._find_source_trace_from_node( - cur_node - ) + cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node) + cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node) else: cur_node_compute = cur_node_source = None @@ -215,15 +200,9 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx): return None if len(arg_list) == 2: - if any(i in cur_node.name for i in ["add", "mul"]): + if any(i in cur_node.name for i in ["add", "mul", "truediv"]): for arg in arg_list: - if not ( - start_idx - <= find_idx_by_name( - arg.name, self.trace_indice.node_list - ) - < end_idx - ): + if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx): continue arg_chunk_dim = all_node_info[arg]["chunk_dim"] arg_fix_dim = all_node_info[arg]["fix_dim"] @@ -249,9 +228,7 @@ def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): remove_inputs = [] for input_node in inputs: input_dict = {} - input_node_idx = find_idx_by_name( - input_node.name, self.trace_indice.node_list - ) + input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) for user in input_node.users.keys(): if is_non_compute_node(user): continue @@ -259,9 +236,7 @@ def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): if start_idx <= user_idx <= end_idx: chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: - user_source = self.trace_indice._find_source_trace_from_node( - user 
- )[chunk_dim] + user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim] if input_node_idx in user_source: input_dict[user_idx] = user_source[input_node_idx] else: @@ -284,7 +259,7 @@ def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): maybe_prepose_nodes.sort( key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list), reverse=True, - ) # from last node to first node + ) # from last node to first node prepose_nodes = [] # set every node as root, search its args, if all legal, turn root and args as prepose nodes while len(maybe_prepose_nodes) > 0: @@ -305,13 +280,8 @@ def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): if type(cur_prepose_node_arg) != type(cur_prepose_node): continue # out of loop - if not ( - start_idx - <= find_idx_by_name( - cur_prepose_node_arg.name, self.trace_indice.node_list - ) - < end_idx - ): + if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) < + end_idx): continue # compute op in loop elif cur_prepose_node_arg in all_node_info: @@ -335,15 +305,13 @@ def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): if n in maybe_prepose_nodes: maybe_prepose_nodes.remove(n) # sort by index - prepose_nodes.sort( - key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list) - ) + prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)) return prepose_nodes def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1] + chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1] # also need to get some prepose node's arg out of non_chunk_inputs for n in chunk_info["args"]["prepose_nodes"]: chunk_node_list.remove(n) @@ -354,9 +322,7 @@ def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): return chunk_info def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.trace_indice.node_list[start_idx : end_idx + 1] - ) + inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1]) # only single ouput if len(outputs) > 1: return None @@ -367,9 +333,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): return None # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim( - inputs, start_idx, end_idx, all_node_info - ) + inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) if inputs is None: return None @@ -385,9 +349,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): } # move useless nodes ahead of loop - chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes( - all_node_info, start_idx, end_idx - ) + chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx) # find non chunk inputs chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) @@ -400,10 +362,8 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): def _reassgin_reshape_size(self, chunk_info): chunk_region = chunk_info["region"] reshape_size = {} - chunk_shape = get_node_shape(chunk_info["outputs"][0])[ - chunk_info["outputs_dim"] - ] - for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]: + chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] + for node in 
self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]: if any(i in node.name for i in ["reshape", "view"]): reshape_args = node.args[1:] reshape_log = self.trace_indice.indice_view_list[node] @@ -413,8 +373,6 @@ def _reassgin_reshape_size(self, chunk_info): if reshape_arg_dim in reshape_log["dim_to"]: continue if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = ( - "min(chunk_size, %d - chunk_idx)" % chunk_shape - ) + reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape) chunk_info["reshape_size"] = reshape_size return chunk_info diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 1e16ab9bdf35..5a5d15e0a1f4 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -3,7 +3,7 @@ from torch.fx.node import Node -from .utils import find_idx_by_name, get_node_shape +from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list class TraceIndice(object): @@ -79,9 +79,7 @@ def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim): node_from_trace = self._find_trace_from_node(node_from) node_to_trace = self._find_trace_from_node(node_to) node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim] - node_to_trace["compute"][node_to_dim] = copy.deepcopy( - node_from_trace["compute"][node_from_dim] - ) + node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim]) self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) def _inherit_all_computation(self, node_from, node_to): @@ -209,7 +207,7 @@ def _assign_indice_as_input(self, node, node_idx, input_node=None): node_idx (int) """ if input_node == None: - input_node = node.args[0] + input_node = find_first_tensor_arg(node) input_node_idx = find_idx_by_name(input_node.name, self.node_list) input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"] @@ -227,6 +225,8 @@ def _assign_all_indice(self, node, node_idx): node_idx (int) """ shape = node.meta["tensor_meta"].shape + if shape is None: + return new_trace = [] for _ in shape: new_trace.append(self._add_indice()) @@ -259,7 +259,7 @@ def _assign_permute_indice(self, node, node_idx): node (node) node_idx (int) """ - permute_dim = node.args[1:] + permute_dim = unflat_list(node.args[1:]) input_node = node.args[0] self._assign_indice_as_input(node, node_idx, input_node) @@ -359,6 +359,15 @@ def _assign_einsum_indice(self, node, idx): left, right = patterns.split("->") left = left.split(",") + if '...' 
in right: + replace_list = "!@#$%^&*" + target_len = len(get_node_shape(node)) + add_len = target_len - len(right) + 3 + replace_str = replace_list[:add_len] + right = right.replace("...", replace_str) + for ll in range(len(left)): + left[ll] = left[ll].replace("...", replace_str) + all_index = [] for i in left: for c in i: @@ -369,9 +378,7 @@ def _assign_einsum_indice(self, node, idx): for left_idx, left_str in enumerate(left): if right_indice in left_str: source_idx = left_str.index(right_indice) - self._inherit_indice( - input_nodes[left_idx], source_idx, node, right_idx - ) + self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx) def _assign_softmax_indice(self, node, idx): """ @@ -440,11 +447,12 @@ def _assign_view_reshape_indice(self, node, node_idx): origin_node = node.args[0] origin_shape = origin_node.meta["tensor_meta"].shape target_shape = [] - for i in range(1, len(node.args)): - if isinstance(node.args[i], int): - target_shape.append(node.args[i]) + unflated_args = unflat_list(node.args) + for i in range(1, len(unflated_args)): + if isinstance(unflated_args[i], int): + target_shape.append(unflated_args[i]) else: - target_shape.append(node.args[i].meta["fwd_out"][0]) + target_shape.append(unflated_args[i].meta["fwd_out"][0]) # compute the value of -1 if -1 in target_shape: @@ -472,13 +480,7 @@ def _assign_view_reshape_indice(self, node, node_idx): dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] self._del_dim(node_idx, -1) else: - raise NotImplementedError( - "shape" - + str(origin_shape) - + "and" - + str(target_shape) - + "view not implemented" - ) + raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented") # get new indice origin_trace = self._find_indice_trace_from_node(origin_node) @@ -521,6 +523,8 @@ def trace_indice(self): self._assign_unsqueeze_indice(node, idx) elif any(i in node.name for i in ["to", "contiguous"]): self._assgin_no_change_indice(node, idx) + elif "new_ones" in node.name: + self._assign_ones_like_indice(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == "call_function": @@ -530,7 +534,7 @@ def trace_indice(self): self._assign_matmul_indice(node, idx) elif "softmax" in node.name: self._assign_softmax_indice(node, idx) - elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]): + elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]): self._assign_elementwise_indice(node, idx) elif "ones_like" in node.name: self._assign_ones_like_indice(node, idx) @@ -538,21 +542,21 @@ def trace_indice(self): self._assign_dropout_indice(node, idx) elif "einsum" in node.name: self._assign_einsum_indice(node, idx) - elif "getattr" in node.name: - continue # get attr like shape - elif "getitem" in node.name: - continue # get item in list + elif "layer_norm" in node.name: + self._assign_layernorm_indice(node, idx) + elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]): + continue else: - raise NotImplementedError( - node.name, "function not implemented yet!" 
- ) + raise NotImplementedError(node.name, "function not implemented yet!") elif node.op == "call_module": if any(n in node.name for n in ["layernorm", "norm"]): self._assign_layernorm_indice(node, idx) + elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]): + self._assign_elementwise_indice(node, idx) else: raise NotImplementedError(node.name, "module not implemented yet!") elif node.op == "get_attr": - self._assign_all_indice(node, idx) # get param + self._assign_all_indice(node, idx) # get param elif node.op == "output": continue else: diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index b62a6600adc8..5f3ea3bf482d 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -3,10 +3,32 @@ from torch.fx.node import Node +def unflat_list(inputs): + """ + unflat a list by recursion + """ + res = [] + for i in inputs: + if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple): + res.extend(unflat_list(i)) + else: + res.append(i) + return res + + +def find_first_tensor_arg(node): + """ + Find the first input tensor arg for a node + """ + for arg in node.args: + if type(arg) == type(node): + return arg + raise RuntimeError() + + def is_non_compute_node(node): if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): + i in node.name for i in ["getitem", "getattr"]): return True return False @@ -18,17 +40,13 @@ def get_node_shape(node): def is_non_compute_node_except_placeholder(node): - if any(i in node.op for i in ["get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): + if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]): return True return False def is_non_compute_node_except_placeholder_output(node): - if any(i in node.op for i in ["get_attr"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): + if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]): return True return False @@ -74,22 +92,16 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]): # we treat that input node as the input of the checkpoint function for node in nodes: for input_node in node._input_nodes.keys(): - if ( - input_node not in nodes - and input_node not in input_nodes - and not is_non_compute_node_except_placeholder(input_node) - ): + if (input_node not in nodes and input_node not in input_nodes + and not is_non_compute_node_except_placeholder(input_node)): input_nodes.append(input_node) # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output for node in nodes: for output_node in node.users.keys(): - if ( - output_node not in nodes - and node not in output_nodes - and not is_non_compute_node_except_placeholder_output(output_node) - ): + if (output_node not in nodes and node not in output_nodes + and not is_non_compute_node_except_placeholder_output(output_node)): output_nodes.append(node) return input_nodes, output_nodes diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index 1c39dc247750..57280d2e260f 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -249,6 +249,8 @@ def zero_flop_jit(*args): aten.sum.default, aten.sum.dim_IntList, aten.mean.dim, + aten.sub.Tensor, + aten.sub_.Tensor, # activation op aten.hardswish.default, @@ -283,36 +285,13 @@ def zero_flop_jit(*args): # 
TODO: this will be removed in future zero_flop_aten = [ - aten.as_strided.default, - aten.as_strided_.default, - aten.bernoulli_.float, - aten.cat.default, - aten.clone.default, - aten.copy_.default, - aten.detach.default, - aten.expand.default, - aten.empty_like.default, - aten.new_empty.default, - aten.new_empty_strided.default, - aten.ones_like.default, - aten._reshape_alias.default, - aten.select.int, - aten.select_backward.default, - aten.squeeze.dim, - aten.slice.Tensor, - aten.slice_backward.default, - aten.split.Tensor, - aten.permute.default, - aten.t.default, - aten.transpose.int, - aten._to_copy.default, - aten.unsqueeze.default, - aten.unbind.int, - aten._unsafe_view.default, - aten.view.default, - aten.where.self, - aten.zero_.default, - aten.zeros_like.default, + aten.as_strided.default, aten.as_strided_.default, aten.bernoulli_.float, aten.cat.default, aten.clone.default, + aten.copy_.default, aten.detach.default, aten.expand.default, aten.empty_like.default, aten.new_empty.default, + aten.new_empty_strided.default, aten.ones_like.default, aten._reshape_alias.default, aten.select.int, + aten.select_backward.default, aten.squeeze.dim, aten.slice.Tensor, aten.slice_backward.default, + aten.split.Tensor, aten.permute.default, aten.t.default, aten.transpose.int, aten._to_copy.default, + aten.unsqueeze.default, aten.unbind.int, aten._unsafe_view.default, aten.view.default, aten.where.self, + aten.zero_.default, aten.zeros_like.default, aten.fill_.Scalar ] for op in zero_flop_aten: diff --git a/tests/test_autochunk/test_autochunk_openfold_codegen.py b/tests/test_autochunk/test_autochunk_openfold_codegen.py index 02fa07e2ca00..86aa447a9d2b 100644 --- a/tests/test_autochunk/test_autochunk_openfold_codegen.py +++ b/tests/test_autochunk/test_autochunk_openfold_codegen.py @@ -7,20 +7,21 @@ import colossalai from colossalai.core import global_context as gpc -from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace from colossalai.utils import free_port from tests.test_autochunk.evoformer.evoformer import evoformer_base +from tests.test_autochunk.origin_openfold.evoformer import EvoformerBlock if CODEGEN_AVAILABLE and is_compatible_with_meta(): from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.fx.profiler import MetaTensor -def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): +def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask): # for memory test # torch.cuda.reset_peak_memory_stats() # now_mem = torch.cuda.memory_allocated() / 1024**2 @@ -36,9 +37,10 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): # ) # test forward + model = model.cuda() with torch.no_grad(): - non_fx_out = model(node, pair) - fx_out = gm(node, pair) + non_fx_out = model(node, pair, node_mask, pair_mask) + fx_out = gm(node, pair, node_mask, pair_mask) assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( @@ -48,6 +50,26 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): torch.abs(non_fx_out[1] - fx_out[1])) +def _build_openfold(): + model = EvoformerBlock( + c_m=256, + c_z=128, + 
c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + is_multimer=False, + ).eval().cuda() + return model + + def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): # launch colossalai colossalai.launch( @@ -60,29 +82,51 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): ) # build model and input - model = evoformer_base().cuda() + model = _build_openfold() node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() - # trace the module and replace codegen - graph = ColoTracer().trace( + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( model, meta_args={ - "node": node.to(torch.device("meta")), - "pair": pair.to(torch.device("meta")), + "m": node.to(torch.device("meta")), + "z": pair.to(torch.device("meta")), + "msa_mask": node_mask.to(torch.device("meta")), + "pair_mask": pair_mask.to(torch.device("meta")), + }, + concrete_args={ + "chunk_size": None, + "_mask_trans": True, }, ) - gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace - interp = MetaInfoProp(gm_prop) - interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) - - # now run it twice to get meta info in graph module, not necessary - gm = torch.fx.GraphModule(model, graph) - interp = MetaInfoProp(gm) - interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) + interp = MetaInfoProp(meta_graph) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), + MetaTensor(pair, fake_device="cuda:0"), + MetaTensor(node_mask, fake_device="cuda:0"), + MetaTensor(pair_mask, fake_device="cuda:0"), + ) + codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory) - codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory) - graph.set_codegen(codegen) + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model, + meta_args={ + "m": node.to(torch.device("meta")), + "z": pair.to(torch.device("meta")), + "msa_mask": node_mask.to(torch.device("meta")), + "pair_mask": pair_mask.to(torch.device("meta")), + }, + concrete_args={ + "chunk_size": None, + "_mask_trans": True, + }, + ) + # graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile() @@ -91,11 +135,14 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): assert "chunk_size" in code # print(code) - _test_fwd(model, gm, node, pair) + _test_fwd(model, gm, node, pair, node_mask, pair_mask) gpc.destroy() -@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif( + not (CODEGEN_AVAILABLE and is_compatible_with_meta()), + reason="torch version is lower than 1.12.0", +) @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) From eab2775b63227d11c41936ac0c0b6a09a94db493 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 16:13:55 +0800 Subject: [PATCH 04/14] remove useless attr --- colossalai/autochunk/autochunk_codegen.py | 168 ++++++---------------- 1 file changed, 46 insertions(+), 122 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py 
b/colossalai/autochunk/autochunk_codegen.py index e8af9bde86d8..ceccb9a9fde2 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -48,9 +48,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> return new_shape -def _gen_loop_start( - chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2 -) -> str: +def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str: """ Generate chunk loop start @@ -72,9 +70,8 @@ def _gen_loop_start( out_shape = get_node_shape(chunk_output) out_str = str(list(out_shape)) context = ( - "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" - % (out_str, input_node.name, input_node.name, chunk_size) - ) + "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" % + (out_str, input_node.name, input_node.name, chunk_size)) context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) return context @@ -105,26 +102,17 @@ def _gen_loop_end( chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list) chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape - chunk_slice = _gen_chunk_slice_dim( - chunk_outputs_dim, "chunk_idx", chunk_output_shape - ) + chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape) context = " chunk_result%s = %s; %s = None\n" % ( chunk_slice, chunk_outputs_name, chunk_outputs_name, ) - context += ( - chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" - ) + context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None") # determine if its the last use for chunk input for chunk_input in chunk_inputs + chunk_non_compute_inputs: - if all( - [ - find_idx_by_name(user.name, node_list) <= chunk_outputs_idx - for user in chunk_input.users.keys() - ] - ): + if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]): context += "; %s = None" % chunk_input.name context += "\n" @@ -171,17 +159,10 @@ def _replace_ones_like( chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] if get_node_shape(meta_node)[chunk_dim] != 1: source_node = meta_node.args[0].args[0] - if ( - source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] - is None - ): - chunk_slice = _gen_chunk_slice_dim( - chunk_dim, "chunk_idx", get_node_shape(node) - ) - body[-1] = _replace_name( - body[-1], node.args[0].name, node.args[0].name + chunk_slice - ) + if (source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None): + chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node)) + body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice) return body @@ -198,12 +179,8 @@ def _replace_input_node( for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim( - dim[0], "chunk_idx", get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) + chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", 
get_node_shape(input_node)) + body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice) return body @@ -236,14 +213,10 @@ def emit_code_with_chunk( chunk_ends = [i["region"][1] for i in chunk_infos] # chunk inputs - chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk - chunk_inputs_non_chunk = [ - i["inputs_non_chunk"] for i in chunk_infos - ] # input without chunk - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim - chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ - j.name for i in chunk_inputs_non_chunk for j in i - ] + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] # chunk outputs chunk_outputs = [i["outputs"][0] for i in chunk_infos] @@ -267,23 +240,16 @@ def emit_code_with_chunk( chunk_outputs[region_idx], chunk_outputs_dim[region_idx], chunk_infos[region_idx]["chunk_size"], - ) - ) + )) if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - body = _replace_input_node( - chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body - ) + body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body) # ones like - body = _replace_ones_like( - search_chunk, chunk_infos, region_idx, node_idx, node, body - ) + body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body) # reassgin reshape size - body[-1] = _replace_reshape_size( - body[-1], node.name, chunk_infos[region_idx]["reshape_size"] - ) + body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"]) body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) else: @@ -300,8 +266,7 @@ def emit_code_with_chunk( chunk_outputs[region_idx], chunk_outputs_dim[region_idx], node_list, - ) - ) + )) within_chunk_region = False node_idx += 1 @@ -310,18 +275,14 @@ def emit_code_with_chunk( if CODEGEN_AVAILABLE: class AutoChunkCodeGen(CodeGen): + def __init__(self, meta_graph, max_memory=None, print_mem=False): super().__init__() - self.meta_graph = meta_graph - self.max_memory = max_memory - self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) self.chunk_infos = self.search_chunk.search_region() - def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace - ) -> PythonCode: + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} @@ -338,9 +299,7 @@ def add_global(name_hint: str, obj: Any): Returns: the global name that should be used to reference 'obj' in generated source. """ - if ( - _is_from_torch(obj) and obj != torch.device - ): # to support registering torch.device + if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. 
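
For orientation, the strings assembled by _gen_loop_start, _replace_input_node and _gen_loop_end above combine into a plain Python loop in the generated forward source. The following self-contained sketch reproduces roughly that shape under illustrative assumptions: a single region chunked along dim 1 of a [1, 64, 64, 128] output with chunk_size 2; input_1, output_1 and the region body are placeholders, not names from a real traced graph, and the exact slice string is produced by _gen_chunk_slice_dim, so it may differ in detail.

    import torch

    input_1 = torch.randn(1, 64, 64, 128)
    region = lambda x: x * 2.0    # stands in for the nodes inside the chunk region

    # emitted by _gen_loop_start: allocate the full output, then loop over chunks
    chunk_result = torch.empty([1, 64, 64, 128], dtype=input_1.dtype, device=input_1.device); chunk_size = 2
    for chunk_idx in range(0, 64, chunk_size):
        # chunked inputs are sliced along their chunk dim, as _replace_input_node does
        output_1 = region(input_1[:, chunk_idx:chunk_idx + chunk_size])
        # emitted by _gen_loop_end: write the chunk into its pre-allocated slot
        chunk_result[:, chunk_idx:chunk_idx + chunk_size] = output_1; output_1 = None
    output_1 = chunk_result; chunk_result = None; chunk_size = None
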
@@ -356,9 +315,7 @@ def add_global(name_hint: str, obj: Any): return global_name # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin( - "import colossalai", colossalai - ) + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) # Pre-fill the globals table with registered builtins. for name, (_, obj) in _custom_builtins.items(): @@ -394,9 +351,8 @@ def type_repr(o: Any): # Common case: this is a regular module name like 'foo.bar.baz' return add_global(typename, o) - def _format_args( - args: Tuple[Argument, ...], kwargs: Dict[str, Argument] - ) -> str: + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. if isinstance(arg, tuple) and hasattr(arg, "_fields"): @@ -444,26 +400,18 @@ def delete_unused_values(user: Node, body, to_keep=[]): nodes_to_delete = user_to_last_uses.get(user, []) nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] if len(nodes_to_delete): - to_delete_str = " = ".join( - [repr(n) for n in nodes_to_delete] + ["None"] - ) + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) body.append(f"; {to_delete_str}\n") else: body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = ( - "" if node.type is None else f" : {type_repr(node.type)}" - ) + maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}") if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = ( - "" if not node.args else f" = {repr(node.args[0])}" - ) - free_vars.append( - f"{node.target}{maybe_type_annotation}{maybe_default_arg}" - ) + maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}") + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") raw_name = node.target.replace("*", "") if raw_name != repr(node): body.append(f"{repr(node)} = {raw_name}\n") @@ -472,68 +420,46 @@ def emit_node(node: Node, body): assert isinstance(node.target, str) body.append( f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})" - ) + f"({_format_args(node.args[1:], node.kwargs)})") return elif node.op == "call_function": assert callable(node.target) # pretty print operators - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in magic_methods - ): + if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods): assert isinstance(node.args, tuple) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" - ) + body.append(f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}") return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in inplace_methods - ): - body.append( - f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" - ) + if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods): + 
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}") return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if ( - global_name == "getattr" - and isinstance(node.args, tuple) - and isinstance(node.args[1], str) - and node.args[1].isidentifier() - and len(node.args) == 2 - ): + if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str) + and node.args[1].isidentifier() and len(node.args) == 2): body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" - ) + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}") return body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" - ) + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})") if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return elif node.op == "call_module": assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" - ) + body.append(f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})") return elif node.op == "get_attr": assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" - ) + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return elif node.op == "output": if node.type is not None: @@ -564,9 +490,7 @@ def emit_node(node: Node, body): if len(wrapped_fns) > 0: wrap_name = add_global("wrap", torch.fx.wrap) - wrap_stmts = "\n".join( - [f'{wrap_name}("{name}")' for name in wrapped_fns] - ) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: wrap_stmts = "" From 4cf57bcf54b69605c80ac603d7ba7568e93471f3 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 16:20:31 +0800 Subject: [PATCH 05/14] reformat --- colossalai/fx/profiler/opcount.py | 40 ++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index 57280d2e260f..6bd612ad2fd1 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -285,14 +285,38 @@ def zero_flop_jit(*args): # TODO: this will be removed in future zero_flop_aten = [ - aten.as_strided.default, aten.as_strided_.default, aten.bernoulli_.float, aten.cat.default, aten.clone.default, - aten.copy_.default, aten.detach.default, aten.expand.default, aten.empty_like.default, aten.new_empty.default, - aten.new_empty_strided.default, aten.ones_like.default, aten._reshape_alias.default, aten.select.int, - aten.select_backward.default, aten.squeeze.dim, aten.slice.Tensor, aten.slice_backward.default, - aten.split.Tensor, aten.permute.default, aten.t.default, aten.transpose.int, aten._to_copy.default, - aten.unsqueeze.default, aten.unbind.int, aten._unsafe_view.default, aten.view.default, aten.where.self, - aten.zero_.default, aten.zeros_like.default, 
aten.fill_.Scalar - ] + aten.as_strided.default, + aten.as_strided_.default, + aten.bernoulli_.float, + aten.cat.default, + aten.clone.default, + aten.copy_.default, + aten.detach.default, + aten.expand.default, + aten.empty_like.default, + aten.new_empty.default, + aten.new_empty_strided.default, + aten.ones_like.default, + aten._reshape_alias.default, + aten.select.int, + aten.select_backward.default, + aten.squeeze.dim, + aten.slice.Tensor, + aten.slice_backward.default, + aten.split.Tensor, + aten.permute.default, + aten.t.default, + aten.transpose.int, + aten._to_copy.default, + aten.unsqueeze.default, + aten.unbind.int, + aten._unsafe_view.default, + aten.view.default, + aten.where.self, + aten.zero_.default, + aten.zeros_like.default, + aten.fill_.Scalar + ] # yapf: disable for op in zero_flop_aten: flop_mapping[op] = zero_flop_jit From 053f8bf33971b7c2a1939f0a36d3745adf0a6ad9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 16:38:30 +0800 Subject: [PATCH 06/14] code format --- .../origin_openfold/structure_module.py | 12 ++++++------ .../utils/geometry/rotation_matrix.py | 2 +- .../origin_openfold/utils/tensor_utils.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_autochunk/origin_openfold/structure_module.py b/tests/test_autochunk/origin_openfold/structure_module.py index da3b98202a26..051109515daa 100644 --- a/tests/test_autochunk/origin_openfold/structure_module.py +++ b/tests/test_autochunk/origin_openfold/structure_module.py @@ -124,8 +124,8 @@ def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tenso s = self.linear_in(s) s = s + s_initial - for l in self.layers: - s = l(s) + for ll in self.layers: + s = ll(s) s = self.relu(s) @@ -526,15 +526,15 @@ def __init__(self, c: int, num_layers: int, dropout_rate: float): self.layers = nn.ModuleList() for _ in range(self.num_layers): - l = StructureModuleTransitionLayer(self.c) - self.layers.append(l) + ll = StructureModuleTransitionLayer(self.c) + self.layers.append(ll) self.dropout = nn.Dropout(self.dropout_rate) self.layer_norm = LayerNorm(self.c) def forward(self, s: torch.Tensor) -> torch.Tensor: - for l in self.layers: - s = l(s) + for ll in self.layers: + s = ll(s) s = self.dropout(s) s = self.layer_norm(s) diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py b/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py index 552128ac0c0c..8d929524268a 100644 --- a/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py +++ b/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py @@ -158,5 +158,5 @@ def reshape(self, new_shape): @classmethod def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array: field_names = utils.get_field_names(Rot3Array) - cat_fn = lambda l: torch.cat(l, dim=dim) + cat_fn = lambda ll: torch.cat(ll, dim=dim) return cls(**{name: cat_fn([getattr(r, name) for r in rots]) for name in field_names}) diff --git a/tests/test_autochunk/origin_openfold/utils/tensor_utils.py b/tests/test_autochunk/origin_openfold/utils/tensor_utils.py index 6ef3674c7226..5d5b3c32b5c6 100644 --- a/tests/test_autochunk/origin_openfold/utils/tensor_utils.py +++ b/tests/test_autochunk/origin_openfold/utils/tensor_utils.py @@ -154,12 +154,12 @@ def _get_minimal_slice_set( # start_edges and end_edges both indicate whether, starting from any given # dimension, the start/end index is at the top/bottom edge of the # corresponding tensor, modeled as a tree - def reduce_edge_list(l): + def 
reduce_edge_list(ll): tally = 1 - for i in range(len(l)): + for i in range(len(ll)): reversed_idx = -1 * (i + 1) - l[reversed_idx] *= tally - tally = l[reversed_idx] + ll[reversed_idx] *= tally + tally = ll[reversed_idx] if (start_edges is None): start_edges = [s == 0 for s in start] From d68bfbf6ed271583dd386c5844a56be7d65580c5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 17:26:52 +0800 Subject: [PATCH 07/14] use repo for simple evoformer --- tests/test_autochunk/benchmark_autochunk.py | 50 ++++--------------- .../test_autochunk/test_autochunk_codegen.py | 4 +- 2 files changed, 13 insertions(+), 41 deletions(-) diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 6632ece61376..e9549f005259 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -2,14 +2,13 @@ import torch import torch.fx +from simple_evoformer import base_evoformer, openfold_evoformer from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor -from tests.test_autochunk.evoformer.evoformer import evoformer_base -from tests.test_autochunk.openfold.evoformer import EvoformerBlock def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): @@ -34,10 +33,7 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N time2 = time.time() new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - print( - "%s: time %.4fs, mem %dMB" - % (title, (time2 - time1) / loop, new_max_mem - now_mem) - ) + print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem)) def _build_autochunk(model, max_memory, node, pair): @@ -50,18 +46,14 @@ def _build_autochunk(model, max_memory, node, pair): }, ) - gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace + gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace interp = MetaInfoProp(gm_prop) - interp.propagate( - MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") - ) + interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) # now run it twice to get meta info in graph module, not necessary gm = torch.fx.GraphModule(model, graph) interp = MetaInfoProp(gm) - interp.propagate( - MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") - ) + interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) # set code_gen codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False) @@ -75,42 +67,22 @@ def _build_autochunk(model, max_memory, node, pair): return gm -def _build_openfold(): - model = EvoformerBlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - is_multimer=False, - ).cuda() - return model - - def benchmark_evoformer(): # init data and model - msa_len = 256 - pair_len = 512 + msa_len = 32 + pair_len = 64 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() - model = evoformer_base().cuda() + model = base_evoformer().cuda() # build autochunk model # max_memory = 1000 # MB, fit memory mode - max_memory = 
None # min memory mode - autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) + max_memory = None # min memory mode + autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair) # build openfold chunk_size = 64 - openfold = _build_openfold() + openfold = openfold_evoformer().cuda() # benchmark _benchmark_evoformer(model, node, pair, "base") diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index 02fa07e2ca00..9ac4521e7a77 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -4,6 +4,7 @@ import torch import torch.fx import torch.multiprocessing as mp +from simple_evoformer import base_evoformer import colossalai from colossalai.core import global_context as gpc @@ -13,7 +14,6 @@ from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.utils import free_port -from tests.test_autochunk.evoformer.evoformer import evoformer_base if CODEGEN_AVAILABLE and is_compatible_with_meta(): from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen @@ -60,7 +60,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): ) # build model and input - model = evoformer_base().cuda() + model = base_evoformer().cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() From ac40df0853a76071af8836232f948e0c4b7f5b8e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 17:27:46 +0800 Subject: [PATCH 08/14] rename --- ...hmark_autochunk.py => benchmark_simple_evoformer_autochunk.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/test_autochunk/{benchmark_autochunk.py => benchmark_simple_evoformer_autochunk.py} (100%) diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_simple_evoformer_autochunk.py similarity index 100% rename from tests/test_autochunk/benchmark_autochunk.py rename to tests/test_autochunk/benchmark_simple_evoformer_autochunk.py From 3e2532a9fbd3980c900261a24ae2af4771054b2d Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 17:30:58 +0800 Subject: [PATCH 09/14] rename --- ...ple_evoformer_autochunk.py => benchmark_simple_evoformer.py} | 0 tests/test_autochunk/test_autochunk_openfold_codegen.py | 1 - ...st_autochunk_codegen.py => test_simple_evoformer_codegen.py} | 0 ...test_autochunk_search.py => test_simple_evoformer_search.py} | 2 +- 4 files changed, 1 insertion(+), 2 deletions(-) rename tests/test_autochunk/{benchmark_simple_evoformer_autochunk.py => benchmark_simple_evoformer.py} (100%) rename tests/test_autochunk/{test_autochunk_codegen.py => test_simple_evoformer_codegen.py} (100%) rename tests/test_autochunk/{test_autochunk_search.py => test_simple_evoformer_search.py} (98%) diff --git a/tests/test_autochunk/benchmark_simple_evoformer_autochunk.py b/tests/test_autochunk/benchmark_simple_evoformer.py similarity index 100% rename from tests/test_autochunk/benchmark_simple_evoformer_autochunk.py rename to tests/test_autochunk/benchmark_simple_evoformer.py diff --git a/tests/test_autochunk/test_autochunk_openfold_codegen.py b/tests/test_autochunk/test_autochunk_openfold_codegen.py index 86aa447a9d2b..11bb6d79080a 100644 --- a/tests/test_autochunk/test_autochunk_openfold_codegen.py +++ b/tests/test_autochunk/test_autochunk_openfold_codegen.py @@ -13,7 +13,6 @@ from colossalai.fx.passes.meta_info_prop import 
MetaInfoProp from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace from colossalai.utils import free_port -from tests.test_autochunk.evoformer.evoformer import evoformer_base from tests.test_autochunk.origin_openfold.evoformer import EvoformerBlock if CODEGEN_AVAILABLE and is_compatible_with_meta(): diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_simple_evoformer_codegen.py similarity index 100% rename from tests/test_autochunk/test_autochunk_codegen.py rename to tests/test_autochunk/test_simple_evoformer_codegen.py diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_simple_evoformer_search.py similarity index 98% rename from tests/test_autochunk/test_autochunk_search.py rename to tests/test_autochunk/test_simple_evoformer_search.py index 371fce64fdf7..5c418b9edcce 100644 --- a/tests/test_autochunk/test_autochunk_search.py +++ b/tests/test_autochunk/test_simple_evoformer_search.py @@ -4,6 +4,7 @@ import torch import torch.fx import torch.multiprocessing as mp +from simple_evoformer import base_evoformer, openfold_evoformer import colossalai from colossalai.core import global_context as gpc @@ -11,7 +12,6 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.utils import free_port -from tests.test_autochunk.evoformer.evoformer import evoformer_base if CODEGEN_AVAILABLE and is_compatible_with_meta(): from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen From e4413d6a0bdae010a329acb8af70570309351355 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 17:36:09 +0800 Subject: [PATCH 10/14] detect repo in test --- .../test_autochunk/test_simple_evoformer_codegen.py | 10 ++++++++-- tests/test_autochunk/test_simple_evoformer_search.py | 12 +++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/test_autochunk/test_simple_evoformer_codegen.py b/tests/test_autochunk/test_simple_evoformer_codegen.py index 9ac4521e7a77..02c38505a0ac 100644 --- a/tests/test_autochunk/test_simple_evoformer_codegen.py +++ b/tests/test_autochunk/test_simple_evoformer_codegen.py @@ -4,7 +4,12 @@ import torch import torch.fx import torch.multiprocessing as mp -from simple_evoformer import base_evoformer + +try: + from simple_evoformer import base_evoformer + HAS_REPO = True +except: + HAS_REPO = False import colossalai from colossalai.core import global_context as gpc @@ -95,7 +100,8 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): gpc.destroy() -@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), + reason='torch version is lower than 1.12.0') @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) diff --git a/tests/test_autochunk/test_simple_evoformer_search.py b/tests/test_autochunk/test_simple_evoformer_search.py index 5c418b9edcce..e295558c88a0 100644 --- a/tests/test_autochunk/test_simple_evoformer_search.py +++ b/tests/test_autochunk/test_simple_evoformer_search.py @@ -4,7 +4,12 @@ import torch import torch.fx import torch.multiprocessing as mp -from simple_evoformer import base_evoformer, openfold_evoformer + +try: + from simple_evoformer import base_evoformer + HAS_REPO = True +except: + HAS_REPO = False 
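
The HAS_REPO guard added here is the standard optional-dependency pattern for tests: attempt the import once at module import time, record the outcome in a flag, and fold the flag into the existing skipif condition so collection never fails when the package is absent. A minimal self-contained sketch of the same idiom follows; some_pkg is a placeholder name, and catching ImportError rather than using a bare except is the conventional, narrower form.

    import pytest

    try:
        import some_pkg    # optional dependency; absence must not break collection
        HAS_PKG = True
    except ImportError:
        HAS_PKG = False


    @pytest.mark.skipif(not HAS_PKG, reason="some_pkg is not installed")
    def test_feature():
        assert some_pkg.__name__ == "some_pkg"
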
import colossalai from colossalai.core import global_context as gpc @@ -69,7 +74,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory): ) # build model and input - model = evoformer_base().cuda() + model = base_evoformer().cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() @@ -84,7 +89,8 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory): gpc.destroy() -@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0") +@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), + reason="torch version is lower than 1.12.0") @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) From a01f9679818945df3a81314119620d2a4a06c0fc Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 17:50:49 +0800 Subject: [PATCH 11/14] delete dirs --- .../benchmark_simple_evoformer.py | 4 +- tests/test_autochunk/evoformer/evoformer.py | 59 - tests/test_autochunk/evoformer/initializer.py | 29 - tests/test_autochunk/evoformer/kernel.py | 19 - tests/test_autochunk/evoformer/msa.py | 95 -- tests/test_autochunk/evoformer/ops.py | 176 --- tests/test_autochunk/evoformer/triangle.py | 192 --- .../test_autochunk/openfold/checkpointing.py | 84 - tests/test_autochunk/openfold/dropout.py | 78 - tests/test_autochunk/openfold/evoformer.py | 431 ----- tests/test_autochunk/openfold/msa.py | 331 ---- .../openfold/outer_product_mean.py | 129 -- .../openfold/pair_transition.py | 99 -- tests/test_autochunk/openfold/primitives.py | 529 ------- tests/test_autochunk/openfold/tensor_utils.py | 408 ----- .../openfold/triangular_attention.py | 139 -- .../triangular_multiplicative_update.py | 127 -- .../origin_openfold/__init__.py | 0 .../test_autochunk/origin_openfold/dropout.py | 78 - .../origin_openfold/embedders.py | 420 ----- .../origin_openfold/embedders_multimer.py | 352 ----- .../origin_openfold/evoformer.py | 626 -------- tests/test_autochunk/origin_openfold/heads.py | 231 --- tests/test_autochunk/origin_openfold/msa.py | 384 ----- .../origin_openfold/outer_product_mean.py | 124 -- .../origin_openfold/pair_transition.py | 103 -- .../origin_openfold/primitives.py | 544 ------- .../origin_openfold/structure_module.py | 914 ----------- .../origin_openfold/template.py | 308 ---- .../origin_openfold/triangular_attention.py | 130 -- .../triangular_multiplicative_update.py | 129 -- .../utils/all_atom_multimer.py | 415 ----- .../origin_openfold/utils/checkpointing.py | 86 - .../origin_openfold/utils/feats.py | 302 ---- .../utils/geometry/__init__.py | 26 - .../utils/geometry/quat_rigid.py | 42 - .../utils/geometry/rigid_matrix_vector.py | 156 -- .../utils/geometry/rotation_matrix.py | 162 -- .../utils/geometry/test_utils.py | 86 - .../origin_openfold/utils/geometry/utils.py | 22 - .../origin_openfold/utils/geometry/vector.py | 253 --- .../origin_openfold/utils/loss.py | 1403 ----------------- .../origin_openfold/utils/tensor_utils.py | 384 ----- ...d_codegen.py => test_evoformer_codegen.py} | 4 +- 44 files changed, 4 insertions(+), 10609 deletions(-) delete mode 100644 tests/test_autochunk/evoformer/evoformer.py delete mode 100755 tests/test_autochunk/evoformer/initializer.py delete mode 100644 tests/test_autochunk/evoformer/kernel.py delete mode 100644 tests/test_autochunk/evoformer/msa.py delete mode 100755 tests/test_autochunk/evoformer/ops.py delete mode 100644 
tests/test_autochunk/evoformer/triangle.py delete mode 100644 tests/test_autochunk/openfold/checkpointing.py delete mode 100644 tests/test_autochunk/openfold/dropout.py delete mode 100644 tests/test_autochunk/openfold/evoformer.py delete mode 100644 tests/test_autochunk/openfold/msa.py delete mode 100644 tests/test_autochunk/openfold/outer_product_mean.py delete mode 100644 tests/test_autochunk/openfold/pair_transition.py delete mode 100644 tests/test_autochunk/openfold/primitives.py delete mode 100644 tests/test_autochunk/openfold/tensor_utils.py delete mode 100644 tests/test_autochunk/openfold/triangular_attention.py delete mode 100644 tests/test_autochunk/openfold/triangular_multiplicative_update.py delete mode 100644 tests/test_autochunk/origin_openfold/__init__.py delete mode 100644 tests/test_autochunk/origin_openfold/dropout.py delete mode 100644 tests/test_autochunk/origin_openfold/embedders.py delete mode 100644 tests/test_autochunk/origin_openfold/embedders_multimer.py delete mode 100644 tests/test_autochunk/origin_openfold/evoformer.py delete mode 100644 tests/test_autochunk/origin_openfold/heads.py delete mode 100644 tests/test_autochunk/origin_openfold/msa.py delete mode 100644 tests/test_autochunk/origin_openfold/outer_product_mean.py delete mode 100644 tests/test_autochunk/origin_openfold/pair_transition.py delete mode 100644 tests/test_autochunk/origin_openfold/primitives.py delete mode 100644 tests/test_autochunk/origin_openfold/structure_module.py delete mode 100644 tests/test_autochunk/origin_openfold/template.py delete mode 100644 tests/test_autochunk/origin_openfold/triangular_attention.py delete mode 100644 tests/test_autochunk/origin_openfold/triangular_multiplicative_update.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/all_atom_multimer.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/checkpointing.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/feats.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/__init__.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/quat_rigid.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/rigid_matrix_vector.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/test_utils.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/utils.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/geometry/vector.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/loss.py delete mode 100644 tests/test_autochunk/origin_openfold/utils/tensor_utils.py rename tests/test_autochunk/{test_autochunk_openfold_codegen.py => test_evoformer_codegen.py} (97%) diff --git a/tests/test_autochunk/benchmark_simple_evoformer.py b/tests/test_autochunk/benchmark_simple_evoformer.py index e9549f005259..8b5d8a8bee77 100644 --- a/tests/test_autochunk/benchmark_simple_evoformer.py +++ b/tests/test_autochunk/benchmark_simple_evoformer.py @@ -69,8 +69,8 @@ def _build_autochunk(model, max_memory, node, pair): def benchmark_evoformer(): # init data and model - msa_len = 32 - pair_len = 64 + msa_len = 128 + pair_len = 256 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() model = base_evoformer().cuda() diff --git a/tests/test_autochunk/evoformer/evoformer.py b/tests/test_autochunk/evoformer/evoformer.py deleted file mode 100644 index 
cfd2bb2a2529..000000000000 --- a/tests/test_autochunk/evoformer/evoformer.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import torch.nn as nn - -from .msa import MSAStack -from .ops import OutProductMean -from .triangle import PairStack - - -def print_memory(init_mem, text=None): - now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem - max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem - print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) - torch.cuda.reset_peak_memory_stats() - - -class EvoformerBlock(nn.Module): - - def __init__(self, d_node, d_pair): - super(EvoformerBlock, self).__init__() - - self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) - self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) - self.pair_stack = PairStack(d_pair=d_pair) - - def forward(self, node, pair): - node = self.msa_stack(node, pair) - pair = pair + self.communication(node) - pair = self.pair_stack(pair) - return node, pair - - -class Evoformer(nn.Module): - - def __init__(self, d_node, d_pair): - super(Evoformer, self).__init__() - - self.blocks = nn.ModuleList() - for _ in range(1): - self.blocks.append(EvoformerBlock(d_node, d_pair)) - - def forward(self, node, pair): - for b in self.blocks: - node, pair = b(node, pair) - return node, pair - - -def evoformer_tiny(): - return Evoformer(d_node=64, d_pair=32) - - -def evoformer_base(): - return Evoformer(d_node=256, d_pair=128) - - -def evoformer_large(): - return Evoformer(d_node=512, d_pair=256) - - -__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large'] diff --git a/tests/test_autochunk/evoformer/initializer.py b/tests/test_autochunk/evoformer/initializer.py deleted file mode 100755 index c6ce0659e597..000000000000 --- a/tests/test_autochunk/evoformer/initializer.py +++ /dev/null @@ -1,29 +0,0 @@ -import math - -import numpy as np -import torch.nn as nn - - -def glorot_uniform_af(x, gain=1.0): - """ - initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different: - In PyTorch: - [feature_out, feature_in, n_head ...] - In Jax: - [... 
n_head, feature_in, feature_out] - However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like: - [feature_in, n_head, feature_out] - - In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors - """ - fan_in, fan_out = x.shape[-2:] - if len(x.shape) > 2: - receptive_field_size = np.prod(x.shape[:-2]) - fan_in *= receptive_field_size - fan_out *= receptive_field_size - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation - - nn.init.uniform_(x, -dev, dev) - - return x diff --git a/tests/test_autochunk/evoformer/kernel.py b/tests/test_autochunk/evoformer/kernel.py deleted file mode 100644 index 26ab5dc53261..000000000000 --- a/tests/test_autochunk/evoformer/kernel.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn.functional as F - - -def bias_sigmod_ele(y, bias, z): - return torch.sigmoid(y + bias) * z - - -def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, - residual: torch.Tensor, prob: float) -> torch.Tensor: - out = (x + bias) * F.dropout(dropmask, p=prob, training=False) - out = residual + out - return out - - -def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, - dropout_mask: torch.Tensor, Z_raw: torch.Tensor, - prob: float) -> torch.Tensor: - return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b)) \ No newline at end of file diff --git a/tests/test_autochunk/evoformer/msa.py b/tests/test_autochunk/evoformer/msa.py deleted file mode 100644 index cac456638a55..000000000000 --- a/tests/test_autochunk/evoformer/msa.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add -from .ops import SelfAttention, Transition - - -class MSARowAttentionWithPairBias(nn.Module): - - def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15): - super(MSARowAttentionWithPairBias, self).__init__() - self.d_node = d_node - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernormM = LayerNorm(d_node) - self.layernormZ = LayerNorm(d_pair) - - _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True) - - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True) - - def forward(self, M_raw, Z): - ## Input projections - M = self.layernormM(M_raw) - Z = self.layernormZ(Z) - b = F.linear(Z, self.linear_b_weights) - b = b.permute(0, 3, 1, 2) - # b = rearrange(b, 'b q k h -> b h q k') - - M = self.attention(M, b) - dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype) - - return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop) - - -class MSAColumnAttention(nn.Module): - - def __init__(self, d_node, c=32, n_head=8): - super(MSAColumnAttention, self).__init__() - self.d_node = d_node - self.c = c - self.n_head = n_head - - self.layernormM = LayerNorm(d_node) - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True) - - def forward(self, M_raw): - M = 
M_raw.transpose(-2, -3) - M = self.layernormM(M) - - M = self.attention(M) - - M = M.transpose(-2, -3) - return M_raw + M - - -class MSAStack(nn.Module): - - def __init__(self, d_node, d_pair, p_drop=0.15): - super(MSAStack, self).__init__() - - self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, - d_pair=d_pair, - p_drop=p_drop) - - self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) - self.MSATransition = Transition(d=d_node) - - def forward(self, node, pair): - node = self.MSARowAttentionWithPairBias(node, pair) - node = self.MSAColumnAttention(node) - node = self.MSATransition(node) - - return node diff --git a/tests/test_autochunk/evoformer/ops.py b/tests/test_autochunk/evoformer/ops.py deleted file mode 100755 index a56057522eaa..000000000000 --- a/tests/test_autochunk/evoformer/ops.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .initializer import glorot_uniform_af -from .kernel import bias_sigmod_ele - - -class DropoutRowwise(nn.Module): - - def __init__(self, p): - super(DropoutRowwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, 0:1, :, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class DropoutColumnwise(nn.Module): - - def __init__(self, p): - super(DropoutColumnwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, :, 0:1, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class Transition(nn.Module): - - def __init__(self, d, n=4): - super(Transition, self).__init__() - self.norm = LayerNorm(d) - self.linear1 = Linear(d, n * d, initializer='relu') - self.linear2 = Linear(n * d, d, initializer='zeros') - - def forward(self, src): - x = self.norm(src) - x = self.linear2(F.relu(self.linear1(x))) - return src + x - - -class OutProductMean(nn.Module): - - def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): - super(OutProductMean, self).__init__() - - self.layernormM = LayerNorm(n_feat) - self.linear_a = Linear(n_feat, n_feat_proj) - self.linear_b = Linear(n_feat, n_feat_proj) - - self.o_linear = Linear(n_feat_proj * n_feat_proj, - n_feat_out, - initializer='zero', - use_bias=True) - - def forward(self, M): - M = self.layernormM(M) - left_act = self.linear_a(M) - right_act = self.linear_b(M) - - o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() - # O = rearrange(O, 'b i j d e -> b i j (d e)') - o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1) - Z = self.o_linear(o) - - return Z - - -class Linear(nn.Linear): - """ - A Linear layer with built-in nonstandard initializations. Called just - like torch.nn.Linear. - Implements the initializers in 1.11.4, plus some additional ones found - in the code. 
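# Illustrative shape check (editor's sketch, not patch content) for the
# OutProductMean einsum above; the sizes are arbitrary stand-ins, not values
# used by the tests.
import torch

b, s, n, d = 1, 4, 8, 32    # batch, n_seq, n_res, projected channels
left_act = torch.randn(b, s, n, d)
right_act = torch.randn(b, s, n, d)

# contract the sequence dim s, keep both residue dims -> pairwise update
o = torch.einsum('bsid,bsje->bijde', left_act, right_act)
assert o.shape == (b, n, n, d, d)
o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)    # flatten (d, e)
assert o.shape == (b, n, n, d * d)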
- """ - - def __init__( - self, - feature_in: int, - feature_out: int, - initializer: str = 'linear', - use_bias: bool = True, - bias_init: float = 0., - ): - super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) - - self.use_bias = use_bias - if initializer == 'linear': - glorot_uniform_af(self.weight, gain=1.0) - elif initializer == 'relu': - glorot_uniform_af(self.weight, gain=2.0) - elif initializer == 'zeros': - nn.init.zeros_(self.weight) - if self.use_bias: - with torch.no_grad(): - self.bias.fill_(bias_init) - - -class SelfAttention(nn.Module): - """ - Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors - """ - - def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): - super(SelfAttention, self).__init__() - self.qkv_dim = qkv_dim - self.c = c - self.n_head = n_head - self.out_dim = out_dim - self.gating = gating - self.last_bias_fuse = last_bias_fuse - - self.scaling = self.c**(-0.5) - - # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') - self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - - if gating: - self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) - self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) - - self.o_linear = Linear(n_head * c, - out_dim, - initializer='zero', - use_bias=(not last_bias_fuse)) - - def forward(self, in_data, nonbatched_bias=None): - """ - :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] - :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] - :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] - """ - - # qkv = self.to_qkv(in_data).chunk(3, dim=-1) - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) - - q = self.to_q(in_data) - k = self.to_k(in_data) - v = self.to_v(in_data) - - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), - # [q, k, v]) - q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), - [q, k, v]) - - q = q * self.scaling - - logits = torch.matmul(q, k.transpose(-1, -2)) - - if nonbatched_bias is not None: - logits += nonbatched_bias.unsqueeze(1) - weights = torch.softmax(logits, dim=-1) - # weights = softmax(logits) - - weighted_avg = torch.matmul(weights, v) - # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') - weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) - weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) - - if self.gating: - gate_values = self.gating_linear(in_data) - weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) - - output = self.o_linear(weighted_avg) - return output diff --git a/tests/test_autochunk/evoformer/triangle.py b/tests/test_autochunk/evoformer/triangle.py deleted file mode 100644 index f479469c3836..000000000000 --- a/tests/test_autochunk/evoformer/triangle.py +++ /dev/null @@ -1,192 +0,0 @@ -import math - -import torch -import torch.nn as nn -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add, bias_ele_dropout_residual -from .ops import Linear, SelfAttention, Transition - - -def permute_final_dims(tensor, inds): - zero_index = -1 * len(inds) - first_inds = 
list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - - -class TriangleMultiplicationOutgoing(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationOutgoing, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) - self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 0, 1)), - # permute_final_dims(right_proj_act, (2, 1, 0)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleMultiplicationIncoming(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationIncoming, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) 
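# Illustrative contrast (editor's sketch) of the two triangle-multiplication
# einsums: "outgoing" contracts the shared end point k of edges (i,k) and
# (j,k), while "incoming" contracts the shared start point k of (k,i), (k,j).
import torch

b, n, d = 1, 8, 16
left_proj_act = torch.randn(b, n, n, d)
right_proj_act = torch.randn(b, n, n, d)

outgoing = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
incoming = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
assert outgoing.shape == incoming.shape == (b, n, n, d)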
- self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 1, 0)), - # permute_final_dims(right_proj_act, (2, 0, 1)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleAttentionStartingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionStartingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class TriangleAttentionEndingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionEndingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = Z_raw.transpose(-2, -3) - Z = self.layernorm1(Z) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - Z = Z.transpose(-2, -3) - dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class PairStack(nn.Module): - - def __init__(self, d_pair, p_drop=0.25): - super(PairStack, self).__init__() - - self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop) - self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop) - self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop) - self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop) - 
self.PairTransition = Transition(d=d_pair) - - def forward(self, pair): - pair = self.TriangleMultiplicationOutgoing(pair) - pair = self.TriangleMultiplicationIncoming(pair) - pair = self.TriangleAttentionStartingNode(pair) - pair = self.TriangleAttentionEndingNode(pair) - pair = self.PairTransition(pair) - return pair diff --git a/tests/test_autochunk/openfold/checkpointing.py b/tests/test_autochunk/openfold/checkpointing.py deleted file mode 100644 index 83e77c638ec1..000000000000 --- a/tests/test_autochunk/openfold/checkpointing.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.utils.checkpoint -from typing import Any, Tuple, List, Callable, Optional - - -BLOCK_ARG = Any -BLOCK_ARGS = List[BLOCK_ARG] - - -def get_checkpoint_fn(): - checkpoint = torch.utils.checkpoint.checkpoint - - return checkpoint - - -@torch.jit.ignore -def checkpoint_blocks( - blocks: List[Callable], - args: BLOCK_ARGS, - blocks_per_ckpt: Optional[int], -) -> BLOCK_ARGS: - """ - Chunk a list of blocks and run each chunk with activation - checkpointing. We define a "block" as a callable whose only inputs are - the outputs of the previous block. - - Implements Subsection 1.11.8 - - Args: - blocks: - List of blocks - args: - Tuple of arguments for the first block. - blocks_per_ckpt: - Size of each chunk. A higher value corresponds to fewer - checkpoints, and trades memory for speed. If None, no checkpointing - is performed. - Returns: - The output of the final block - """ - def wrap(a): - return (a,) if type(a) is not tuple else a - - def exec(b, a): - for block in b: - a = wrap(block(*a)) - return a - - def chunker(s, e): - def exec_sliced(*a): - return exec(blocks[s:e], a) - - return exec_sliced - - # Avoids mishaps when the blocks take just one argument - args = wrap(args) - - if blocks_per_ckpt is None: - return exec(blocks, args) - elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): - raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") - - checkpoint = get_checkpoint_fn() - - for s in range(0, len(blocks), blocks_per_ckpt): - e = s + blocks_per_ckpt - args = checkpoint(chunker(s, e), *args) - args = wrap(args) - - return args diff --git a/tests/test_autochunk/openfold/dropout.py b/tests/test_autochunk/openfold/dropout.py deleted file mode 100644 index 651b9775ef44..000000000000 --- a/tests/test_autochunk/openfold/dropout.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
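# Usage sketch for checkpoint_blocks above, with hypothetical toy blocks and
# assuming the function from checkpointing.py is in scope. Each block consumes
# the previous block's outputs; blocks_per_ckpt trades recomputation for
# activation memory.
import torch

def block_a(x, y):
    return x + 1, y * 2

def block_b(x, y):
    return x * y, y

x0 = torch.ones(3, requires_grad=True)
y0 = torch.full((3,), 2.0, requires_grad=True)

# blocks_per_ckpt=1 checkpoints every block; None disables checkpointing.
x1, y1 = checkpoint_blocks([block_a, block_b], (x0, y0), blocks_per_ckpt=1)
(x1.sum() + y1.sum()).backward()    # gradients flow through the checkpoints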
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn -from functools import partialmethod -from typing import Union, List - - -class Dropout(nn.Module): - """ - Implementation of dropout with the ability to share the dropout mask - along a particular dimension. - - If not in training mode, this module computes the identity function. - """ - - def __init__(self, r: float, batch_dim: Union[int, List[int]]): - """ - Args: - r: - Dropout rate - batch_dim: - Dimension(s) along which the dropout mask is shared - """ - super(Dropout, self).__init__() - - self.r = r - if type(batch_dim) == int: - batch_dim = [batch_dim] - self.batch_dim = batch_dim - self.dropout = nn.Dropout(self.r) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - Tensor to which dropout is applied. Can have any shape - compatible with self.batch_dim - """ - shape = list(x.shape) - if self.batch_dim is not None: - for bd in self.batch_dim: - shape[bd] = 1 - mask = x.new_ones(shape) - mask = self.dropout(mask) - x *= mask - return x - - -class DropoutRowwise(Dropout): - """ - Convenience class for rowwise dropout as described in subsection - 1.11.6. - """ - - __init__ = partialmethod(Dropout.__init__, batch_dim=-3) - - -class DropoutColumnwise(Dropout): - """ - Convenience class for columnwise dropout as described in subsection - 1.11.6. - """ - - __init__ = partialmethod(Dropout.__init__, batch_dim=-2) diff --git a/tests/test_autochunk/openfold/evoformer.py b/tests/test_autochunk/openfold/evoformer.py deleted file mode 100644 index b53ec1aa51e5..000000000000 --- a/tests/test_autochunk/openfold/evoformer.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import torch -import torch.nn as nn -from typing import Tuple, Optional -from functools import partial - -from .primitives import Linear, LayerNorm -from .dropout import DropoutRowwise, DropoutColumnwise -from .msa import ( - MSARowAttentionWithPairBias, - MSAColumnAttention, - MSAColumnGlobalAttention, -) -from .outer_product_mean import OuterProductMean -from .pair_transition import PairTransition -from .triangular_attention import ( - TriangleAttentionStartingNode, - TriangleAttentionEndingNode, -) -from .triangular_multiplicative_update import ( - TriangleMultiplicationOutgoing, - TriangleMultiplicationIncoming, -) -from .checkpointing import checkpoint_blocks, get_checkpoint_fn -from .tensor_utils import chunk_layer - - -class MSATransition(nn.Module): - """ - Feed-forward network applied to MSA activations after attention. 
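# Illustrative sketch of the shared-mask Dropout above: with batch_dim=-3
# (DropoutRowwise) one mask of shape [*, 1, N_res, C] is drawn and broadcast,
# so every MSA row sees the same dropped positions.
import torch
import torch.nn as nn

x = torch.ones(1, 4, 5, 8)      # [*, N_seq, N_res, C]
shape = list(x.shape)
shape[-3] = 1                   # share the mask along N_seq
mask = nn.Dropout(p=0.5)(x.new_ones(shape))
out = x * mask
assert torch.equal(out[:, 0], out[:, 1])    # identical pattern in every row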
- - Implements Algorithm 9 - """ - def __init__(self, c_m, n): - """ - Args: - c_m: - MSA channel dimension - n: - Factor multiplied to c_m to obtain the hidden channel - dimension - """ - super(MSATransition, self).__init__() - - self.c_m = c_m - self.n = n - - self.layer_norm = LayerNorm(self.c_m) - self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu") - self.relu = nn.ReLU() - self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final") - - def _transition(self, m, mask): - m = self.linear_1(m) - m = self.relu(m) - m = self.linear_2(m) * mask - return m - - @torch.jit.ignore - def _chunk(self, - m: torch.Tensor, - mask: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self._transition, - {"m": m, "mask": mask}, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def forward( - self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA activation - mask: - [*, N_seq, N_res, C_m] MSA mask - Returns: - m: - [*, N_seq, N_res, C_m] MSA activation update - """ - - # DISCREPANCY: DeepMind forgets to apply the MSA mask here. - if mask is None: - mask = m.new_ones(m.shape[:-1]) - - # [*, N_seq, N_res, 1] - mask = mask.unsqueeze(-1) - - m = self.layer_norm(m) - - if chunk_size is not None: - m = self._chunk(m, mask, chunk_size) - else: - m = self._transition(m, mask) - - return m - - -class EvoformerBlockCore(nn.Module): - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - transition_n: int, - pair_dropout: float, - inf: float, - eps: float, - _is_extra_msa_stack: bool = False, - is_multimer: bool = False, - ): - super(EvoformerBlockCore, self).__init__() - self.is_multimer = is_multimer - self.msa_transition = MSATransition( - c_m=c_m, - n=transition_n, - ) - - self.outer_product_mean = OuterProductMean( - c_m, - c_z, - c_hidden_opm, - ) - - self.tri_mul_out = TriangleMultiplicationOutgoing( - c_z, - c_hidden_mul, - ) - self.tri_mul_in = TriangleMultiplicationIncoming( - c_z, - c_hidden_mul, - ) - - self.tri_att_start = TriangleAttentionStartingNode( - c_z, - c_hidden_pair_att, - no_heads_pair, - inf=inf, - ) - self.tri_att_end = TriangleAttentionEndingNode( - c_z, - c_hidden_pair_att, - no_heads_pair, - inf=inf, - ) - - self.pair_transition = PairTransition( - c_z, - transition_n, - ) - - self.ps_dropout_row_layer = DropoutRowwise(pair_dropout) - self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout) - - def forward( - self, - m: torch.Tensor, - z: torch.Tensor, - chunk_size: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # DeepMind doesn't mask these transitions in the source, so _mask_trans - # should be disabled to better approximate the exact activations of - # the original. 
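# Toy stand-in (editor's sketch, not the real chunk_layer) for what the
# _chunk calls above achieve: apply a layer to slices of the batch dims and
# re-assemble, so peak memory scales with chunk_size instead of the full batch.
import torch

def chunked_apply(fn, x, chunk_size):
    outs = [fn(x[i:i + chunk_size]) for i in range(0, x.shape[0], chunk_size)]
    return torch.cat(outs, dim=0)

x = torch.randn(10, 16)
assert torch.equal(chunked_apply(torch.relu, x, chunk_size=4), torch.relu(x))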
- - m = m + self.msa_transition( - m, chunk_size=chunk_size - ) - z = z + self.outer_product_mean( - m, chunk_size=chunk_size - ) - z = z + self.ps_dropout_row_layer(self.tri_mul_out(z)) - z = z + self.ps_dropout_row_layer(self.tri_mul_in(z)) - z = z + self.ps_dropout_row_layer( - self.tri_att_start(z, chunk_size=chunk_size) - ) - z = z + self.ps_dropout_col_layer( - self.tri_att_end(z, chunk_size=chunk_size) - ) - z = z + self.pair_transition( - z, chunk_size=chunk_size - ) - - return m, z - - -class EvoformerBlock(nn.Module): - def __init__(self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - inf: float, - eps: float, - is_multimer: bool, - ): - super(EvoformerBlock, self).__init__() - - self.msa_att_row = MSARowAttentionWithPairBias( - c_m=c_m, - c_z=c_z, - c_hidden=c_hidden_msa_att, - no_heads=no_heads_msa, - inf=inf, - ) - - self.msa_att_col = MSAColumnAttention( - c_m, - c_hidden_msa_att, - no_heads_msa, - inf=inf, - ) - - self.msa_dropout_layer = DropoutRowwise(msa_dropout) - - self.core = EvoformerBlockCore( - c_m=c_m, - c_z=c_z, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - ) - - self.outer_product_mean = OuterProductMean( - c_m, - c_z, - c_hidden_opm, - ) - self.is_multimer = is_multimer - - def forward(self, - m: torch.Tensor, - z: torch.Tensor, - chunk_size: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m = m + self.msa_dropout_layer( - self.msa_att_row(m, z=z, chunk_size=chunk_size) - ) - m = m + self.msa_att_col(m, chunk_size=chunk_size) - m, z = self.core( - m, - z, - chunk_size=chunk_size, - ) - - return m, z - - -class EvoformerStack(nn.Module): - """ - Main Evoformer trunk. - - Implements Algorithm 6. - """ - - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - c_s: int, - no_heads_msa: int, - no_heads_pair: int, - no_blocks: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - blocks_per_ckpt: int, - inf: float, - eps: float, - clear_cache_between_blocks: bool = False, - is_multimer: bool = False, - **kwargs, - ): - """ - Args: - c_m: - MSA channel dimension - c_z: - Pair channel dimension - c_hidden_msa_att: - Hidden dimension in MSA attention - c_hidden_opm: - Hidden dimension in outer product mean module - c_hidden_mul: - Hidden dimension in multiplicative updates - c_hidden_pair_att: - Hidden dimension in triangular attention - c_s: - Channel dimension of the output "single" embedding - no_heads_msa: - Number of heads used for MSA attention - no_heads_pair: - Number of heads used for pair attention - no_blocks: - Number of Evoformer blocks in the stack - transition_n: - Factor by which to multiply c_m to obtain the MSATransition - hidden dimension - msa_dropout: - Dropout rate for MSA activations - pair_dropout: - Dropout used for pair activations - blocks_per_ckpt: - Number of Evoformer blocks in each activation checkpoint - clear_cache_between_blocks: - Whether to clear CUDA's GPU memory cache between blocks of the - stack. 
Slows down each block but can reduce fragmentation - """ - super(EvoformerStack, self).__init__() - - self.blocks_per_ckpt = blocks_per_ckpt - self.clear_cache_between_blocks = clear_cache_between_blocks - - self.blocks = nn.ModuleList() - - for _ in range(no_blocks): - block = EvoformerBlock( - c_m=c_m, - c_z=c_z, - c_hidden_msa_att=c_hidden_msa_att, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - msa_dropout=msa_dropout, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - is_multimer=is_multimer, - ) - self.blocks.append(block) - - self.linear = Linear(c_m, c_s) - - def forward(self, - m: torch.Tensor, - z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, - chunk_size: int, - _mask_trans: bool = True, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding - msa_mask: - [*, N_seq, N_res] MSA mask - pair_mask: - [*, N_res, N_res] pair mask - Returns: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding - s: - [*, N_res, C_s] single embedding (or None if extra MSA stack) - """ - blocks = [ - partial( - b, - msa_mask=msa_mask, - pair_mask=pair_mask, - chunk_size=chunk_size, - _mask_trans=_mask_trans, - ) - for b in self.blocks - ] - - if(self.clear_cache_between_blocks): - def block_with_cache_clear(block, *args): - torch.cuda.empty_cache() - return block(*args) - - blocks = [partial(block_with_cache_clear, b) for b in blocks] - - m, z = checkpoint_blocks( - blocks, - args=(m, z), - blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, - ) - - s = self.linear(m[..., 0, :, :]) - - return m, z, s diff --git a/tests/test_autochunk/openfold/msa.py b/tests/test_autochunk/openfold/msa.py deleted file mode 100644 index 7c137286feab..000000000000 --- a/tests/test_autochunk/openfold/msa.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import torch -import torch.nn as nn -from typing import Optional, List, Tuple - -from .primitives import ( - Linear, - LayerNorm, - Attention, - GlobalAttention, - _attention_chunked_trainable, -) -from .checkpointing import get_checkpoint_fn -from .tensor_utils import ( - chunk_layer, - permute_final_dims, - flatten_final_dims, -) - - -class MSAAttention(nn.Module): - def __init__( - self, - c_in, - c_hidden, - no_heads, - pair_bias=False, - c_z=None, - inf=1e9, - ): - """ - Args: - c_in: - Input channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - pair_bias: - Whether to use pair embedding bias - c_z: - Pair embedding channel dimension. 
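# Sketch of the pattern EvoformerStack.forward uses above: functools.partial
# binds the per-call arguments so every block becomes a callable of (m, z)
# only, which is what checkpoint_blocks expects. The block below is a toy
# stand-in for illustration.
import torch
from functools import partial

def block(m, z, chunk_size):
    return m + 1, z + m.mean()

blocks = [partial(block, chunk_size=None) for _ in range(2)]
m, z = torch.zeros(2, 3), torch.zeros(3, 3)
for b in blocks:
    m, z = b(m, z)
assert m.mean() == 2.0    # two blocks applied in sequence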
Ignored unless pair_bias - is true - inf: - A large number to be used in computing the attention mask - """ - super(MSAAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.pair_bias = pair_bias - self.c_z = c_z - self.inf = inf - - self.layer_norm_m = LayerNorm(self.c_in) - - self.layer_norm_z = None - self.linear_z = None - if self.pair_bias: - self.layer_norm_z = LayerNorm(self.c_z) - self.linear_z = Linear( - self.c_z, self.no_heads, bias=False, init="normal" - ) - - self.mha = Attention( - self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads - ) - - @torch.jit.ignore - def _chunk(self, - m: torch.Tensor, - biases: List[torch.Tensor], - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self.mha, - {"q_x": m, "kv_x": m, "biases": biases}, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def _prep_inputs(self, - m: torch.Tensor, - z: Optional[torch.Tensor], - mask: Optional[torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # [*, N_seq, N_res, C_m] - m = self.layer_norm_m(m) - - n_seq, n_res = m.shape[-3:-1] - if mask is None: - # [*, N_seq, N_res] - mask = m.new_ones( - m.shape[:-3] + (n_seq, n_res), - ) - - # [*, N_seq, 1, 1, N_res] - mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] - - # This step simply returns a larger view of the bias, and does not - # consume additional memory. - # [*, N_seq, no_heads, N_res, N_res] - #bias = bias.expand( - # ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1) - #) - - if (self.pair_bias and - z is not None and # For the - self.layer_norm_z is not None and # benefit of - self.linear_z is not None # TorchScript - ): - # [*, N_res, N_res, C_z] - z = self.layer_norm_z(z) - - # [*, N_res, N_res, no_heads] - z = self.linear_z(z) - - # [*, 1, no_heads, N_res, N_res] - z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) - - return m, mask_bias, z - - - def forward(self, - m: torch.Tensor, - z: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - _chunk_logits: Optional[int] = None, - _checkpoint_chunks: Optional[bool] = None, - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding. Required only if - pair_bias is True - mask: - [*, N_seq, N_res] MSA mask - chunk_size: - Size of chunks into which the inputs are split along their - batch dimensions. A low value decreases memory overhead at the - cost of slower execution. Chunking is not performed by default. - - """ - m, mask_bias, z = self._prep_inputs(m, z, mask) - - biases = [mask_bias] - if(z is not None): - biases.append(z) - - if chunk_size is not None: - m = self._chunk(m, biases, chunk_size) - else: - m = self.mha( - q_x=m, - kv_x=m, - biases=biases - ) - - return m - - -class MSARowAttentionWithPairBias(MSAAttention): - """ - Implements Algorithm 7. - """ - - def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9): - """ - Args: - c_m: - Input channel dimension - c_z: - Pair embedding channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - inf: - Large number used to construct attention masks - """ - super(MSARowAttentionWithPairBias, self).__init__( - c_m, - c_hidden, - no_heads, - pair_bias=True, - c_z=c_z, - inf=inf, - ) - - -class MSAColumnAttention(nn.Module): - """ - Implements Algorithm 8. - - By rights, this should also be a subclass of MSAAttention. 
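# Illustrative sketch of the additive mask bias built in _prep_inputs above,
# reduced to a single mask dim for brevity: zeros where the mask is 1, a large
# negative value where it is 0, so masked keys get ~zero weight after softmax.
import torch

inf = 1e9
mask = torch.tensor([[1., 1., 0.]])                  # [*, N_res]; 0 = padding
mask_bias = (inf * (mask - 1))[..., None, None, :]   # broadcasts to [*, H, Q, K]
logits = torch.zeros(1, 1, 2, 3) + mask_bias
probs = torch.softmax(logits, dim=-1)
assert probs[..., -1].max() == 0                     # the masked key is ignored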
Alas, - most inheritance isn't supported by TorchScript. - """ - - def __init__(self, c_m, c_hidden, no_heads, inf=1e9): - """ - Args: - c_m: - MSA channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - inf: - Large number used to construct attention masks - """ - super(MSAColumnAttention, self).__init__() - - self.c_m = c_m - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - - self._msa_att = MSAAttention( - c_in=c_m, - c_hidden=c_hidden, - no_heads=no_heads, - pair_bias=False, - c_z=None, - inf=inf, - ) - - def forward(self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - mask: - [*, N_seq, N_res] MSA mask - chunk_size: - Size of chunks into which the inputs are split along their - batch dimensions. A low value decreases memory overhead at the - cost of slower execution. Chunking is not performed by default. - """ - # [*, N_res, N_seq, C_in] - m = m.transpose(-2, -3) - - m = self._msa_att(m, chunk_size=chunk_size) - - # [*, N_seq, N_res, C_in] - m = m.transpose(-2, -3) - - return m - - -class MSAColumnGlobalAttention(nn.Module): - def __init__( - self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10, - ): - super(MSAColumnGlobalAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - self.eps = eps - - self.layer_norm_m = nn.LayerNorm(c_in) - - self.global_attention = GlobalAttention( - c_in=c_in, - c_hidden=c_hidden, - no_heads=no_heads, - inf=inf, - eps=eps, - ) - - @torch.jit.ignore - def _chunk(self, - m: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - mha_input = { - "m": m, - } - return chunk_layer( - self.global_attention, - mha_input, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def forward( - self, - m: torch.Tensor, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - n_seq, n_res, c_in = m.shape[-3:] - - # [*, N_res, N_seq, C_in] - m = m.transpose(-2, -3) - - # [*, N_res, N_seq, C_in] - m = self.layer_norm_m(m) - - if chunk_size is not None: - m = self._chunk(m, chunk_size) - else: - m = self.global_attention(m=m) - - # [*, N_seq, N_res, C_in] - m = m.transpose(-2, -3) - - return m diff --git a/tests/test_autochunk/openfold/outer_product_mean.py b/tests/test_autochunk/openfold/outer_product_mean.py deleted file mode 100644 index daadf1c272cf..000000000000 --- a/tests/test_autochunk/openfold/outer_product_mean.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import Linear -from .tensor_utils import chunk_layer - - -class OuterProductMean(nn.Module): - """ - Implements Algorithm 10. 
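# Sketch of the transpose trick used by MSAColumnAttention above: a module
# that attends along the second-to-last dim can serve both rows and columns
# by swapping N_seq and N_res around the call. The mixing fn is a toy
# stand-in for the real attention module.
import torch

def row_mix(m):    # attends/mixes along dim -2
    w = torch.softmax(m @ m.transpose(-1, -2), dim=-1)
    return w @ m

m = torch.randn(1, 4, 6, 8)    # [*, N_seq, N_res, C]
col_out = row_mix(m.transpose(-2, -3)).transpose(-2, -3)
assert col_out.shape == m.shape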
- """ - - def __init__(self, c_m, c_z, c_hidden, eps=1e-3): - """ - Args: - c_m: - MSA embedding channel dimension - c_z: - Pair embedding channel dimension - c_hidden: - Hidden channel dimension - """ - super(OuterProductMean, self).__init__() - - self.c_m = c_m - self.c_z = c_z - self.c_hidden = c_hidden - self.eps = eps - - self.layer_norm = nn.LayerNorm(c_m) - self.linear_1 = Linear(c_m, c_hidden) - self.linear_2 = Linear(c_m, c_hidden) - self.linear_out = Linear(c_hidden ** 2, c_z, init="final") - - def _opm(self, a, b): - # [*, N_res, N_res, C, C] - outer = torch.einsum("...bac,...dae->...bdce", a, b) - - # [*, N_res, N_res, C * C] - outer = outer.reshape(outer.shape[:-2] + (-1,)) - - # [*, N_res, N_res, C_z] - outer = self.linear_out(outer) - - return outer - - @torch.jit.ignore - def _chunk(self, - a: torch.Tensor, - b: torch.Tensor, - chunk_size: int - ) -> torch.Tensor: - # Since the "batch dim" in this case is not a true batch dimension - # (in that the shape of the output depends on it), we need to - # iterate over it ourselves - a_reshape = a.reshape((-1,) + a.shape[-3:]) - b_reshape = b.reshape((-1,) + b.shape[-3:]) - out = [] - for a_prime, b_prime in zip(a_reshape, b_reshape): - outer = chunk_layer( - partial(self._opm, b=b_prime), - {"a": a_prime}, - chunk_size=chunk_size, - no_batch_dims=1, - ) - out.append(outer) - outer = torch.stack(out, dim=0) - outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) - - return outer - - def forward(self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - mask: - [*, N_seq, N_res] MSA mask - Returns: - [*, N_res, N_res, C_z] pair embedding update - """ - if mask is None: - mask = m.new_ones(m.shape[:-1]) - - # [*, N_seq, N_res, C_m] - m = self.layer_norm(m) - - # [*, N_seq, N_res, C] - mask = mask.unsqueeze(-1) - a = self.linear_1(m) * mask - b = self.linear_2(m) * mask - - a = a.transpose(-2, -3) - b = b.transpose(-2, -3) - - if chunk_size is not None: - outer = self._chunk(a, b, chunk_size) - else: - outer = self._opm(a, b) - - # [*, N_res, N_res, 1] - norm = torch.einsum("...abc,...adc->...bdc", mask, mask) - - # [*, N_res, N_res, C_z] - outer = outer / (self.eps + norm) - - return outer diff --git a/tests/test_autochunk/openfold/pair_transition.py b/tests/test_autochunk/openfold/pair_transition.py deleted file mode 100644 index 7d09914dc3cc..000000000000 --- a/tests/test_autochunk/openfold/pair_transition.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import Linear, LayerNorm -from .tensor_utils import chunk_layer - - -class PairTransition(nn.Module): - """ - Implements Algorithm 15. 
- """ - - def __init__(self, c_z, n): - """ - Args: - c_z: - Pair transition channel dimension - n: - Factor by which c_z is multiplied to obtain hidden channel - dimension - """ - super(PairTransition, self).__init__() - - self.c_z = c_z - self.n = n - - self.layer_norm = LayerNorm(self.c_z) - self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") - self.relu = nn.ReLU() - self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") - - def _transition(self, z, mask): - # [*, N_res, N_res, C_hidden] - z = self.linear_1(z) - z = self.relu(z) - - # [*, N_res, N_res, C_z] - z = self.linear_2(z) * mask - - return z - - @torch.jit.ignore - def _chunk(self, - z: torch.Tensor, - mask: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self._transition, - {"z": z, "mask": mask}, - chunk_size=chunk_size, - no_batch_dims=len(z.shape[:-2]), - ) - - - def forward(self, - z: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - z: - [*, N_res, N_res, C_z] pair embedding - Returns: - [*, N_res, N_res, C_z] pair embedding update - """ - # DISCREPANCY: DeepMind forgets to apply the mask in this module. - if mask is None: - mask = z.new_ones(z.shape[:-1]) - - # [*, N_res, N_res, 1] - mask = mask.unsqueeze(-1) - - # [*, N_res, N_res, C_z] - z = self.layer_norm(z) - - if chunk_size is not None: - z = self._chunk(z, mask, chunk_size) - else: - z = self._transition(z=z, mask=mask) - - return z diff --git a/tests/test_autochunk/openfold/primitives.py b/tests/test_autochunk/openfold/primitives.py deleted file mode 100644 index 32a9d487c441..000000000000 --- a/tests/test_autochunk/openfold/primitives.py +++ /dev/null @@ -1,529 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import partial -import math -from typing import Optional, Callable, List, Tuple, Sequence -import numpy as np - -import torch -import torch.nn as nn - -from .checkpointing import get_checkpoint_fn -from .tensor_utils import ( - permute_final_dims, - flatten_final_dims, - _chunk_slice, -) - - -def _prod(nums): - out = 1 - for n in nums: - out = out * n - return out - - -def _calculate_fan(linear_weight_shape, fan="fan_in"): - fan_out, fan_in = linear_weight_shape - - if fan == "fan_in": - f = fan_in - elif fan == "fan_out": - f = fan_out - elif fan == "fan_avg": - f = (fan_in + fan_out) / 2 - else: - raise ValueError("Invalid fan option") - - return f - - -def glorot_uniform_init_(weights): - nn.init.xavier_uniform_(weights, gain=1) - - -def final_init_(weights): - with torch.no_grad(): - weights.fill_(0.0) - - -def gating_init_(weights): - with torch.no_grad(): - weights.fill_(0.0) - - -def normal_init_(weights): - torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") - - -def ipa_point_weights_init_(weights): - with torch.no_grad(): - softplus_inverse_1 = 0.541324854612918 - weights.fill_(softplus_inverse_1) - - -class Linear(nn.Linear): - """ - A Linear layer with built-in nonstandard initializations. Called just - like torch.nn.Linear. - - Implements the initializers in 1.11.4, plus some additional ones found - in the code. - """ - - def __init__( - self, - in_dim: int, - out_dim: int, - bias: bool = True, - init: str = "default", - init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, - ): - """ - Args: - in_dim: - The final dimension of inputs to the layer - out_dim: - The final dimension of layer outputs - bias: - Whether to learn an additive bias. True by default - init: - The initializer to use. Choose from: - - "default": LeCun fan-in truncated normal initialization - "relu": He initialization w/ truncated normal distribution - "glorot": Fan-average Glorot uniform initialization - "gating": Weights=0, Bias=1 - "normal": Normal initialization with std=1/sqrt(fan_in) - "final": Weights=0, Bias=0 - - Overridden by init_fn if the latter is not None. - init_fn: - A custom initializer taking weight and bias as inputs. - Overrides init if not None. 
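# Quick sanity sketch for the init options above, assuming this file's Linear
# is in scope: "final" zeroes the weights so residual branches start as
# no-ops, and "gating" zeroes the weights but sets bias=1 so sigmoid gates
# open at sigmoid(1) ~ 0.73 regardless of the input.
import torch

lin_final = Linear(8, 8, init="final")
assert torch.count_nonzero(lin_final.weight) == 0

lin_gate = Linear(8, 8, init="gating")
gate = torch.sigmoid(lin_gate(torch.randn(2, 8)))
assert torch.allclose(gate, torch.sigmoid(torch.ones_like(gate)))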
- """ - super(Linear, self).__init__(in_dim, out_dim, bias=bias) - - if bias: - with torch.no_grad(): - self.bias.fill_(0) - - if init_fn is not None: - init_fn(self.weight, self.bias) - else: - if init == "default": - normal_init_(self.weight) - elif init == "relu": - normal_init_(self.weight) - elif init == "glorot": - glorot_uniform_init_(self.weight) - elif init == "gating": - gating_init_(self.weight) - if bias: - with torch.no_grad(): - self.bias.fill_(1.0) - elif init == "normal": - normal_init_(self.weight) - elif init == "final": - final_init_(self.weight) - else: - raise ValueError("Invalid init string.") - - -class LayerNorm(nn.Module): - - def __init__(self, c_in, eps=1e-5): - super(LayerNorm, self).__init__() - - self.c_in = (c_in,) - self.eps = eps - - self.weight = nn.Parameter(torch.ones(c_in)) - self.bias = nn.Parameter(torch.zeros(c_in)) - - def forward(self, x): - out = nn.functional.layer_norm( - x, - self.c_in, - self.weight, - self.bias, - self.eps, - ) - - return out - - -@torch.jit.ignore -def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Softmax, but without automatic casting to fp32 when the input is of - type bfloat16 - """ - s = torch.nn.functional.softmax(t, dim=dim) - - return s - - -#@torch.jit.script -def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - biases: List[torch.Tensor]) -> torch.Tensor: - # [*, H, Q, C_hidden] - query = permute_final_dims(query, (1, 0, 2)) - - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 2, 0)) - - # [*, H, V, C_hidden] - value = permute_final_dims(value, (1, 0, 2)) - - # [*, H, Q, K] - a = torch.matmul(query, key) - - for b in biases: - a += b - - a = softmax(a, -1) - - # [*, H, Q, C_hidden] - a = torch.matmul(a, value) - - # [*, Q, H, C_hidden] - a = a.transpose(-2, -3) - - return a - - -@torch.jit.ignore -def _attention_chunked_trainable( - query, - key, - value, - biases, - chunk_size, - chunk_dim, - checkpoint, -): - if (checkpoint and len(biases) > 2): - raise ValueError("Checkpointed version permits only permits two bias terms") - - def _checkpointable_attention(q, k, v, b1, b2): - bs = [b for b in [b1, b2] if b is not None] - return _attention(q, k, v, bs) - - o_chunks = [] - checkpoint_fn = get_checkpoint_fn() - count = query.shape[chunk_dim] - for start in range(0, count, chunk_size): - end = start + chunk_size - idx = [slice(None)] * len(query.shape) - idx[chunk_dim] = slice(start, end) - idx_tup = tuple(idx) - q_chunk = query[idx_tup] - k_chunk = key[idx_tup] - v_chunk = value[idx_tup] - - def _slice_bias(b): - idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)) - return b[tuple(idx)] - - if (checkpoint): - bias_1_chunk, bias_2_chunk = [ - _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2] - ] - - o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk, - bias_1_chunk, bias_2_chunk) - else: - bias_chunks = [_slice_bias(b) for b in biases] - - o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks) - - o_chunks.append(o_chunk) - - o = torch.cat(o_chunks, dim=chunk_dim) - return o - - -class Attention(nn.Module): - """ - Standard multi-head attention using AlphaFold's default layer - initialization. Allows multiple bias vectors. 
- """ - - def __init__( - self, - c_q: int, - c_k: int, - c_v: int, - c_hidden: int, - no_heads: int, - gating: bool = True, - ): - """ - Args: - c_q: - Input dimension of query data - c_k: - Input dimension of key data - c_v: - Input dimension of value data - c_hidden: - Per-head hidden dimension - no_heads: - Number of attention heads - gating: - Whether the output should be gated using query data - """ - super(Attention, self).__init__() - - self.c_q = c_q - self.c_k = c_k - self.c_v = c_v - self.c_hidden = c_hidden - self.no_heads = no_heads - self.gating = gating - - # DISCREPANCY: c_hidden is not the per-head channel dimension, as - # stated in the supplement, but the overall channel dimension. - - self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final") - - self.linear_g = None - if self.gating: - self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating") - - self.sigmoid = nn.Sigmoid() - - def _prep_qkv(self, q_x: torch.Tensor, - kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # [*, Q/K/V, H * C_hidden] - q = self.linear_q(q_x) - k = self.linear_k(kv_x) - v = self.linear_v(kv_x) - - # [*, Q/K, H, C_hidden] - q = q.view(q.shape[:-1] + (self.no_heads, -1)) - k = k.view(k.shape[:-1] + (self.no_heads, -1)) - v = v.view(v.shape[:-1] + (self.no_heads, -1)) - - q /= math.sqrt(self.c_hidden) - - return q, k, v - - def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: - if (self.linear_g is not None): - g = self.sigmoid(self.linear_g(q_x)) - - # [*, Q, H, C_hidden] - g = g.view(g.shape[:-1] + (self.no_heads, -1)) - o = o * g - - # [*, Q, H * C_hidden] - o = flatten_final_dims(o, 2) - - # [*, Q, C_q] - o = self.linear_o(o) - - return o - - def forward( - self, - q_x: torch.Tensor, - kv_x: torch.Tensor, - biases: Optional[List[torch.Tensor]] = None, - use_lma: bool = False, - q_chunk_size: Optional[int] = None, - kv_chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - q_x: - [*, Q, C_q] query data - kv_x: - [*, K, C_k] key data - biases: - List of biases that broadcast to [*, H, Q, K] - use_lma: - Whether to use low-memory attention - q_chunk_size: - Query chunk size (for LMA) - kv_chunk_size: - Key/Value chunk size (for LMA) - Returns - [*, Q, C_q] attention update - """ - if (biases is None): - biases = [] - if (use_lma and (q_chunk_size is None or kv_chunk_size is None)): - raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must " - "be provided") - - q, k, v = self._prep_qkv(q_x, kv_x) - - if (use_lma): - biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases] - - o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) - else: - o = _attention(q, k, v, biases) - - o = self._wrap_up(o, q_x) - - return o - - -class GlobalAttention(nn.Module): - - def __init__(self, c_in, c_hidden, no_heads, inf, eps): - super(GlobalAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - self.eps = eps - - self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot") - - self.linear_k = Linear( - c_in, - c_hidden, - bias=False, - init="glorot", - ) - self.linear_v = Linear( - c_in, - c_hidden, - 
bias=False, - init="glorot", - ) - self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") - self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") - - self.sigmoid = nn.Sigmoid() - - def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - # [*, N_res, C_in] - q = torch.sum(m * mask.unsqueeze(-1), - dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps) - - # [*, N_res, H * C_hidden] - q = self.linear_q(q) - q *= (self.c_hidden**(-0.5)) - - # [*, N_res, H, C_hidden] - q = q.view(q.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, N_seq, C_hidden] - k = self.linear_k(m) - v = self.linear_v(m) - - # [*, N_res, H, N_seq] - a = torch.matmul( - q, - k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] - ) - bias = (self.inf * (mask - 1))[..., :, None, :] - a += bias - a = softmax(a) - - # [*, N_res, H, C_hidden] - o = torch.matmul( - a, - v, - ) - - # [*, N_res, N_seq, C_hidden] - g = self.sigmoid(self.linear_g(m)) - - # [*, N_res, N_seq, H, C_hidden] - g = g.view(g.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, N_seq, H, C_hidden] - o = o.unsqueeze(-3) * g - - # [*, N_res, N_seq, H * C_hidden] - o = o.reshape(o.shape[:-2] + (-1,)) - - # [*, N_res, N_seq, C_in] - m = self.linear_o(o) - - return m - - -def _lma( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - biases: List[torch.Tensor], - q_chunk_size: int, - kv_chunk_size: int, -): - no_q, no_kv = q.shape[-3], k.shape[-3] - - # [*, Q, H, C_hidden] - o = q.new_zeros(q.shape) - for q_s in range(0, no_q, q_chunk_size): - q_chunk = q[..., q_s:q_s + q_chunk_size, :, :] - large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases] - - maxes = [] - weights = [] - values = [] - for kv_s in range(0, no_kv, kv_chunk_size): - k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :] - v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :] - small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks] - - a = torch.einsum( - "...qhd,...khd->...hqk", - q_chunk, - k_chunk, - ) - - for b in small_bias_chunks: - a += b - - a = a.transpose(-2, -3) - - max_a = torch.max(a, dim=-1, keepdim=True)[0] - exp_a = torch.exp(a - max_a) - exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) - - maxes.append(max_a.detach().squeeze(-1)) - weights.append(torch.sum(exp_a, dim=-1)) - values.append(exp_v) - - chunk_max = torch.stack(maxes, dim=-3) - chunk_weights = torch.stack(weights, dim=-3) - chunk_values = torch.stack(values, dim=-4) - - global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] - max_diffs = torch.exp(chunk_max - global_max) - chunk_values *= max_diffs.unsqueeze(-1) - chunk_weights *= max_diffs - - all_values = torch.sum(chunk_values, dim=-4) - all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) - - q_chunk_out = all_values / all_weights - - o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out - - return o diff --git a/tests/test_autochunk/openfold/tensor_utils.py b/tests/test_autochunk/openfold/tensor_utils.py deleted file mode 100644 index 384a71fb5ffd..000000000000 --- a/tests/test_autochunk/openfold/tensor_utils.py +++ /dev/null @@ -1,408 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
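# Minimal numeric check (editor's sketch) of the trick behind _lma above:
# attention can be computed over key/value chunks by tracking per-chunk maxes
# and rescaling, so the full [Q, K] score matrix is never materialized at once.
import torch

q, k, v = torch.randn(6, 8), torch.randn(10, 8), torch.randn(10, 8)
full = torch.softmax(q @ k.T, dim=-1) @ v

chunk, maxes, weights, values = 5, [], [], []
for s in range(0, k.shape[0], chunk):
    a = q @ k[s:s + chunk].T                  # [Q, chunk] scores
    m = a.max(dim=-1, keepdim=True).values    # per-chunk running max
    e = torch.exp(a - m)
    maxes.append(m)
    weights.append(e.sum(dim=-1, keepdim=True))
    values.append(e @ v[s:s + chunk])

g = torch.cat(maxes, dim=-1).max(dim=-1, keepdim=True).values    # global max
num = sum(torch.exp(m - g) * val for m, val in zip(maxes, values))
den = sum(torch.exp(m - g) * w for m, w in zip(maxes, weights))
assert torch.allclose(full, num / den, atol=1e-5)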
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -import torch -import torch.nn as nn -from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional - - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - - -def flatten_final_dims(t: torch.Tensor, no_dims: int): - return t.reshape(t.shape[:-no_dims] + (-1,)) - - -def masked_mean(mask, value, dim, eps=1e-4): - mask = mask.expand(*value.shape) - return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) - - -def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): - boundaries = torch.linspace( - min_bin, max_bin, no_bins - 1, device=pts.device - ) - dists = torch.sqrt( - torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) - ) - return torch.bucketize(dists, boundaries) - - -def dict_multimap(fn, dicts): - first = dicts[0] - new_dict = {} - for k, v in first.items(): - all_v = [d[k] for d in dicts] - if type(v) is dict: - new_dict[k] = dict_multimap(fn, all_v) - else: - new_dict[k] = fn(all_v) - - return new_dict - - -def one_hot(x, v_bins): - reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) - diffs = x[..., None] - reshaped_bins - am = torch.argmin(torch.abs(diffs), dim=-1) - return nn.functional.one_hot(am, num_classes=len(v_bins)).float() - - -def batched_gather(data, inds, dim=0, no_batch_dims=0): - ranges = [] - for i, s in enumerate(data.shape[:no_batch_dims]): - r = torch.arange(s) - r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) - ranges.append(r) - - remaining_dims = [ - slice(None) for _ in range(len(data.shape) - no_batch_dims) - ] - remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds - ranges.extend(remaining_dims) - return data[ranges] - - -# With tree_map, a poor man's JAX tree_map -def dict_map(fn, dic, leaf_type): - new_dict = {} - for k, v in dic.items(): - if type(v) is dict: - new_dict[k] = dict_map(fn, v, leaf_type) - else: - new_dict[k] = tree_map(fn, v, leaf_type) - - return new_dict - - -def tree_map(fn, tree, leaf_type): - if isinstance(tree, dict): - return dict_map(fn, tree, leaf_type) - elif isinstance(tree, list): - return [tree_map(fn, x, leaf_type) for x in tree] - elif isinstance(tree, tuple): - return tuple([tree_map(fn, x, leaf_type) for x in tree]) - elif isinstance(tree, leaf_type): - return fn(tree) - else: - print(type(tree)) - raise ValueError("Not supported") - - -tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) - -def _fetch_dims(tree): - shapes = [] - tree_type = type(tree) - if tree_type is dict: - for v in tree.values(): - shapes.extend(_fetch_dims(v)) - elif tree_type is list or tree_type is tuple: - for t in tree: - shapes.extend(_fetch_dims(t)) - elif tree_type is torch.Tensor: - shapes.append(tree.shape) - else: - raise ValueError("Not supported") - - return shapes - - -@torch.jit.ignore -def _flat_idx_to_idx( - flat_idx: int, - dims: Tuple[int], -) -> Tuple[int]: - idx = [] - for d in reversed(dims): - 
idx.append(flat_idx % d) - flat_idx = flat_idx // d - - return tuple(reversed(idx)) - - -@torch.jit.ignore -def _get_minimal_slice_set( - start: Sequence[int], - end: Sequence[int], - dims: int, - start_edges: Optional[Sequence[bool]] = None, - end_edges: Optional[Sequence[bool]] = None, -) -> Sequence[Tuple[int]]: - """ - Produces an ordered sequence of tensor slices that, when used in - sequence on a tensor with shape dims, yields tensors that contain every - leaf in the contiguous range [start, end]. Care is taken to yield a - short sequence of slices, and perhaps even the shortest possible (I'm - pretty sure it's the latter). - - end is INCLUSIVE. - """ - # start_edges and end_edges both indicate whether, starting from any given - # dimension, the start/end index is at the top/bottom edge of the - # corresponding tensor, modeled as a tree - def reduce_edge_list(ll): - tally = 1 - for i in range(len(ll)): - reversed_idx = -1 * (i + 1) - ll[reversed_idx] *= tally - tally = ll[reversed_idx] - - if(start_edges is None): - start_edges = [s == 0 for s in start] - reduce_edge_list(start_edges) - if(end_edges is None): - end_edges = [e == (d - 1) for e,d in zip(end, dims)] - reduce_edge_list(end_edges) - - # Base cases. Either start/end are empty and we're done, or the final, - # one-dimensional tensor can be simply sliced - if(len(start) == 0): - return [tuple()] - elif(len(start) == 1): - return [(slice(start[0], end[0] + 1),)] - - slices = [] - path = [] - - # Dimensions common to start and end can be selected directly - for s,e in zip(start, end): - if(s == e): - path.append(slice(s, s + 1)) - else: - break - - path = tuple(path) - divergence_idx = len(path) - - # start == end, and we're done - if(divergence_idx == len(dims)): - return [tuple(path)] - - def upper(): - sdi = start[divergence_idx] - return [ - path + (slice(sdi, sdi + 1),) + s for s in - _get_minimal_slice_set( - start[divergence_idx + 1:], - [d - 1 for d in dims[divergence_idx + 1:]], - dims[divergence_idx + 1:], - start_edges=start_edges[divergence_idx + 1:], - end_edges=[1 for _ in end_edges[divergence_idx + 1:]] - ) - ] - - def lower(): - edi = end[divergence_idx] - return [ - path + (slice(edi, edi + 1),) + s for s in - _get_minimal_slice_set( - [0 for _ in start[divergence_idx + 1:]], - end[divergence_idx + 1:], - dims[divergence_idx + 1:], - start_edges=[1 for _ in start_edges[divergence_idx + 1:]], - end_edges=end_edges[divergence_idx + 1:], - ) - ] - - # If both start and end are at the edges of the subtree rooted at - # divergence_idx, we can just select the whole subtree at once - if(start_edges[divergence_idx] and end_edges[divergence_idx]): - slices.append( - path + (slice(start[divergence_idx], end[divergence_idx] + 1),) - ) - # If just start is at the edge, we can grab almost all of the subtree, - # treating only the ragged bottom edge as an edge case - elif(start_edges[divergence_idx]): - slices.append( - path + (slice(start[divergence_idx], end[divergence_idx]),) - ) - slices.extend(lower()) - # Analogous to the previous case, but the top is ragged this time - elif(end_edges[divergence_idx]): - slices.extend(upper()) - slices.append( - path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),) - ) - # If both sides of the range are ragged, we need to handle both sides - # separately. 
If there's contiguous meat in between them, we can index it - # in one big chunk - else: - slices.extend(upper()) - middle_ground = end[divergence_idx] - start[divergence_idx] - if(middle_ground > 1): - slices.append( - path + (slice(start[divergence_idx] + 1, end[divergence_idx]),) - ) - slices.extend(lower()) - - return [tuple(s) for s in slices] - - -@torch.jit.ignore -def _chunk_slice( - t: torch.Tensor, - flat_start: int, - flat_end: int, - no_batch_dims: int, -) -> torch.Tensor: - """ - Equivalent to - - t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] - - but without the need for the initial reshape call, which can be - memory-intensive in certain situations. The only reshape operations - in this function are performed on sub-tensors that scale with - (flat_end - flat_start), the chunk size. - """ - - batch_dims = t.shape[:no_batch_dims] - start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) - # _get_minimal_slice_set is inclusive - end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) - - # Get an ordered list of slices to perform - slices = _get_minimal_slice_set( - start_idx, - end_idx, - batch_dims, - ) - - sliced_tensors = [t[s] for s in slices] - - return torch.cat( - [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors] - ) - - -def chunk_layer( - layer: Callable, - inputs: Dict[str, Any], - chunk_size: int, - no_batch_dims: int, - low_mem: bool = False, -) -> Any: - """ - Implements the "chunking" procedure described in section 1.11.8. - - Layer outputs and inputs are assumed to be simple "pytrees," - consisting only of (arbitrarily nested) lists, tuples, and dicts with - torch.Tensor leaves. - - Args: - layer: - The layer to be applied chunk-wise - inputs: - A (non-nested) dictionary of keyworded inputs. All leaves must - be tensors and must share the same batch dimensions. - chunk_size: - The number of sub-batches per chunk. If multiple batch - dimensions are specified, a "sub-batch" is defined as a single - indexing of all batch dimensions simultaneously (s.t. the - number of sub-batches is the product of the batch dimensions). - no_batch_dims: - How many of the initial dimensions of each input tensor can - be considered batch dimensions. - low_mem: - Avoids flattening potentially large input tensors. Unnecessary - in most cases, and is ever so slightly slower than the default - setting. - Returns: - The reassembled output of the layer on the inputs. - """ - if not (len(inputs) > 0): - raise ValueError("Must provide at least one input") - - initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] - orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) - - def _prep_inputs(t): - # TODO: make this more memory efficient. 
This sucks - if(not low_mem): - if not sum(t.shape[:no_batch_dims]) == no_batch_dims: - t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) - t = t.reshape(-1, *t.shape[no_batch_dims:]) - else: - t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) - return t - - prepped_inputs = tensor_tree_map(_prep_inputs, inputs) - - flat_batch_dim = 1 - for d in orig_batch_dims: - flat_batch_dim *= d - - no_chunks = flat_batch_dim // chunk_size + ( - flat_batch_dim % chunk_size != 0 - ) - - i = 0 - out = None - for _ in range(no_chunks): - # Chunk the input - if(not low_mem): - select_chunk = ( - lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t - ) - else: - select_chunk = ( - partial( - _chunk_slice, - flat_start=i, - flat_end=min(flat_batch_dim, i + chunk_size), - no_batch_dims=len(orig_batch_dims) - ) - ) - - chunks = tensor_tree_map(select_chunk, prepped_inputs) - - # Run the layer on the chunk - output_chunk = layer(**chunks) - - # Allocate space for the output - if out is None: - allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]) - out = tensor_tree_map(allocate, output_chunk) - - # Put the chunk in its pre-allocated space - out_type = type(output_chunk) - if out_type is dict: - def assign(d1, d2): - for k, v in d1.items(): - if type(v) is dict: - assign(v, d2[k]) - else: - v[i : i + chunk_size] = d2[k] - - assign(out, output_chunk) - elif out_type is tuple: - for x1, x2 in zip(out, output_chunk): - x1[i : i + chunk_size] = x2 - elif out_type is torch.Tensor: - out[i : i + chunk_size] = output_chunk - else: - raise ValueError("Not supported") - - i += chunk_size - - reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) - out = tensor_tree_map(reshape, out) - - return out diff --git a/tests/test_autochunk/openfold/triangular_attention.py b/tests/test_autochunk/openfold/triangular_attention.py deleted file mode 100644 index 12d09c502daf..000000000000 --- a/tests/test_autochunk/openfold/triangular_attention.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import partialmethod, partial -import math -from typing import Optional, List - -import torch -import torch.nn as nn - -from .primitives import Linear, LayerNorm, Attention -from .tensor_utils import ( - chunk_layer, - permute_final_dims, - flatten_final_dims, -) - - -class TriangleAttention(nn.Module): - def __init__( - self, c_in, c_hidden, no_heads, starting, inf=1e9 - ): - """ - Args: - c_in: - Input channel dimension - c_hidden: - Overall hidden channel dimension (not per-head) - no_heads: - Number of attention heads - """ - super(TriangleAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.starting = starting - self.inf = inf - - self.layer_norm = LayerNorm(self.c_in) - - self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") - - self.mha = Attention( - self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads - ) - - @torch.jit.ignore - def _chunk(self, - x: torch.Tensor, - biases: List[torch.Tensor], - chunk_size: int, - ) -> torch.Tensor: - mha_inputs = { - "q_x": x, - "kv_x": x, - "biases": biases, - } - return chunk_layer( - partial(self.mha), - mha_inputs, - chunk_size=chunk_size, - no_batch_dims=len(x.shape[:-2]), - ) - - def forward(self, - x: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None - ) -> torch.Tensor: - """ - Args: - x: - [*, I, J, C_in] input tensor (e.g. the pair representation) - Returns: - [*, I, J, C_in] output tensor - """ - if mask is None: - # [*, I, J] - mask = x.new_ones( - x.shape[:-1], - ) - - # Shape annotations assume self.starting. Else, I and J are flipped - if not self.starting: - x = x.transpose(-2, -3) - mask = mask.transpose(-1, -2) - - # [*, I, J, C_in] - x = self.layer_norm(x) - - # [*, I, 1, 1, J] - mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] - - # [*, H, I, J] - triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) - - # [*, 1, H, I, J] - triangle_bias = triangle_bias.unsqueeze(-4) - - biases = [mask_bias, triangle_bias] - - if chunk_size is not None: - x = self._chunk(x, biases, chunk_size) - else: - x = self.mha(q_x=x, kv_x=x, biases=biases) - - if not self.starting: - x = x.transpose(-2, -3) - - return x - - -class TriangleAttentionStartingNode(TriangleAttention): - """ - Implements Algorithm 13. - """ - - __init__ = partialmethod(TriangleAttention.__init__, starting=True) - - -class TriangleAttentionEndingNode(TriangleAttention): - """ - Implements Algorithm 14. - """ - - __init__ = partialmethod(TriangleAttention.__init__, starting=False) diff --git a/tests/test_autochunk/openfold/triangular_multiplicative_update.py b/tests/test_autochunk/openfold/triangular_multiplicative_update.py deleted file mode 100644 index 29f7062c3212..000000000000 --- a/tests/test_autochunk/openfold/triangular_multiplicative_update.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import partialmethod -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import Linear, LayerNorm -from .tensor_utils import permute_final_dims - - -class TriangleMultiplicativeUpdate(nn.Module): - """ - Implements Algorithms 11 and 12. - """ - def __init__(self, c_z, c_hidden, _outgoing=True): - """ - Args: - c_z: - Input channel dimension - c: - Hidden channel dimension - """ - super(TriangleMultiplicativeUpdate, self).__init__() - self.c_z = c_z - self.c_hidden = c_hidden - self._outgoing = _outgoing - - self.linear_a_p = Linear(self.c_z, self.c_hidden) - self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating") - self.linear_b_p = Linear(self.c_z, self.c_hidden) - self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating") - self.linear_g = Linear(self.c_z, self.c_z, init="gating") - self.linear_z = Linear(self.c_hidden, self.c_z, init="final") - - self.layer_norm_in = LayerNorm(self.c_z) - self.layer_norm_out = LayerNorm(self.c_hidden) - - self.sigmoid = nn.Sigmoid() - - def _combine_projections(self, - a: torch.Tensor, - b: torch.Tensor, - ) -> torch.Tensor: - raise NotImplementedError("This method needs to be overridden") - - def forward(self, - z: torch.Tensor, - mask: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ - Args: - x: - [*, N_res, N_res, C_z] input tensor - mask: - [*, N_res, N_res] input mask - Returns: - [*, N_res, N_res, C_z] output tensor - """ - if mask is None: - mask = z.new_ones(z.shape[:-1]) - - mask = mask.unsqueeze(-1) - - z = self.layer_norm_in(z) - a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z)) - a = a * mask - b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z)) - b = b * mask - x = self._combine_projections(a, b) - x = self.layer_norm_out(x) - x = self.linear_z(x) - g = self.sigmoid(self.linear_g(z)) - z = x * g - - return z - - -class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate): - """ - Implements Algorithm 11. - """ - def _combine_projections(self, - a: torch.Tensor, # [*, N_i, N_k, C] - b: torch.Tensor, # [*, N_j, N_k, C] - ): - # [*, C, N_i, N_j] - p = torch.matmul( - permute_final_dims(a, (2, 0, 1)), - permute_final_dims(b, (2, 1, 0)), - ) - - # [*, N_i, N_j, C] - return permute_final_dims(p, (1, 2, 0)) - - -class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate): - """ - Implements Algorithm 12. - """ - def _combine_projections(self, - a: torch.Tensor, # [*, N_k, N_i, C] - b: torch.Tensor, # [*, N_k, N_j, C] - ): - # [*, C, N_i, N_j] - p = torch.matmul( - permute_final_dims(a, (2, 1, 0)), - permute_final_dims(b, (2, 0, 1)), - ) - - # [*, N_i, N_j, C] - return permute_final_dims(p, (1, 2, 0)) - diff --git a/tests/test_autochunk/origin_openfold/__init__.py b/tests/test_autochunk/origin_openfold/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/test_autochunk/origin_openfold/dropout.py b/tests/test_autochunk/origin_openfold/dropout.py deleted file mode 100644 index 5e3f8d620498..000000000000 --- a/tests/test_autochunk/origin_openfold/dropout.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partialmethod -from typing import List, Union - -import torch -import torch.nn as nn - - -class Dropout(nn.Module): - """ - Implementation of dropout with the ability to share the dropout mask - along a particular dimension. - - If not in training mode, this module computes the identity function. - """ - - def __init__(self, r: float, batch_dim: Union[int, List[int]]): - """ - Args: - r: - Dropout rate - batch_dim: - Dimension(s) along which the dropout mask is shared - """ - super(Dropout, self).__init__() - - self.r = r - if type(batch_dim) == int: - batch_dim = [batch_dim] - self.batch_dim = batch_dim - self.dropout = nn.Dropout(self.r) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - Tensor to which dropout is applied. Can have any shape - compatible with self.batch_dim - """ - shape = list(x.shape) - if self.batch_dim is not None: - for bd in self.batch_dim: - shape[bd] = 1 - mask = x.new_ones(shape) - mask = self.dropout(mask) - x *= mask - return x - - -class DropoutRowwise(Dropout): - """ - Convenience class for rowwise dropout as described in subsection - 1.11.6. - """ - - __init__ = partialmethod(Dropout.__init__, batch_dim=-3) - - -class DropoutColumnwise(Dropout): - """ - Convenience class for columnwise dropout as described in subsection - 1.11.6. - """ - - __init__ = partialmethod(Dropout.__init__, batch_dim=-2) diff --git a/tests/test_autochunk/origin_openfold/embedders.py b/tests/test_autochunk/origin_openfold/embedders.py deleted file mode 100644 index daeddb084a90..000000000000 --- a/tests/test_autochunk/origin_openfold/embedders.py +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Dict, Tuple - -import torch -import torch.nn as nn - -from .primitives import LayerNorm, Linear -from .template import TemplatePairStack, TemplatePointwiseAttention -from .utils import all_atom_multimer, geometry -from .utils.feats import build_template_angle_feat, build_template_pair_feat -from .utils.tensor_utils import dict_multimap, one_hot, tensor_tree_map - - -class InputEmbedder(nn.Module): - """ - Embeds a subset of the input features. - - Implements Algorithms 3 (InputEmbedder) and 4 (relpos). 
- """ - - def __init__( - self, - tf_dim: int, - msa_dim: int, - c_z: int, - c_m: int, - relpos_k: int, - **kwargs, - ): - """ - Args: - tf_dim: - Final dimension of the target features - msa_dim: - Final dimension of the MSA features - c_z: - Pair embedding dimension - c_m: - MSA embedding dimension - relpos_k: - Window size used in relative positional encoding - """ - super(InputEmbedder, self).__init__() - - self.tf_dim = tf_dim - self.msa_dim = msa_dim - - self.c_z = c_z - self.c_m = c_m - - self.linear_tf_z_i = Linear(tf_dim, c_z) - self.linear_tf_z_j = Linear(tf_dim, c_z) - self.linear_tf_m = Linear(tf_dim, c_m) - self.linear_msa_m = Linear(msa_dim, c_m) - - # RPE stuff - self.relpos_k = relpos_k - self.no_bins = 2 * relpos_k + 1 - self.linear_relpos = Linear(self.no_bins, c_z) - - def relpos(self, ri: torch.Tensor): - """ - Computes relative positional encodings - - Implements Algorithm 4. - - Args: - ri: - "residue_index" features of shape [*, N] - """ - d = ri[..., None] - ri[..., None, :] - boundaries = torch.arange(start=-self.relpos_k, end=self.relpos_k + 1, device=d.device) - oh = one_hot(d, boundaries).type(ri.dtype) - return self.linear_relpos(oh) - - def forward( - self, - tf: torch.Tensor, - ri: torch.Tensor, - msa: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - tf: - "target_feat" features of shape [*, N_res, tf_dim] - ri: - "residue_index" features of shape [*, N_res] - msa: - "msa_feat" features of shape [*, N_clust, N_res, msa_dim] - Returns: - msa_emb: - [*, N_clust, N_res, C_m] MSA embedding - pair_emb: - [*, N_res, N_res, C_z] pair embedding - - """ - # [*, N_res, c_z] - tf_emb_i = self.linear_tf_z_i(tf) - tf_emb_j = self.linear_tf_z_j(tf) - - # [*, N_res, N_res, c_z] - pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] - pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype)) - - # [*, N_clust, N_res, c_m] - n_clust = msa.shape[-3] - tf_m = (self.linear_tf_m(tf).unsqueeze(-3).expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))) - msa_emb = self.linear_msa_m(msa) + tf_m - - return msa_emb, pair_emb - - -class RecyclingEmbedder(nn.Module): - """ - Embeds the output of an iteration of the model for recycling. - - Implements Algorithm 32. - """ - - def __init__( - self, - c_m: int, - c_z: int, - min_bin: float, - max_bin: float, - no_bins: int, - inf: float = 1e8, - **kwargs, - ): - """ - Args: - c_m: - MSA channel dimension - c_z: - Pair embedding channel dimension - min_bin: - Smallest distogram bin (Angstroms) - max_bin: - Largest distogram bin (Angstroms) - no_bins: - Number of distogram bins - """ - super(RecyclingEmbedder, self).__init__() - - self.c_m = c_m - self.c_z = c_z - self.min_bin = min_bin - self.max_bin = max_bin - self.no_bins = no_bins - self.inf = inf - - self.linear = Linear(self.no_bins, self.c_z) - self.layer_norm_m = LayerNorm(self.c_m) - self.layer_norm_z = LayerNorm(self.c_z) - - def forward( - self, - m: torch.Tensor, - z: torch.Tensor, - x: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - m: - First row of the MSA embedding. 
[*, N_res, C_m] - z: - [*, N_res, N_res, C_z] pair embedding - x: - [*, N_res, 3] predicted C_beta coordinates - Returns: - m: - [*, N_res, C_m] MSA embedding update - z: - [*, N_res, N_res, C_z] pair embedding update - """ - bins = torch.linspace( - self.min_bin, - self.max_bin, - self.no_bins, - dtype=x.dtype, - device=x.device, - requires_grad=False, - ) - - # [*, N, C_m] - m_update = self.layer_norm_m(m) - - # This squared method might become problematic in FP16 mode. - # I'm using it because my homegrown method had a stubborn discrepancy I - # couldn't find in time. - squared_bins = bins**2 - upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1) - d = torch.sum((x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True) - - # [*, N, N, no_bins] - d = ((d > squared_bins) * (d < upper)).type(x.dtype) - - # [*, N, N, C_z] - d = self.linear(d) - z_update = d + self.layer_norm_z(z) - - return m_update, z_update - - -class TemplateEmbedder(nn.Module): - - def __init__(self, config): - super(TemplateEmbedder, self).__init__() - - self.config = config - self.template_angle_embedder = TemplateAngleEmbedder(**config["template_angle_embedder"],) - self.template_pair_embedder = TemplatePairEmbedder(**config["template_pair_embedder"],) - self.template_pair_stack = TemplatePairStack(**config["template_pair_stack"],) - self.template_pointwise_att = TemplatePointwiseAttention(**config["template_pointwise_attention"],) - - def forward(self, batch, z, pair_mask, templ_dim, chunk_size, _mask_trans=True): - # Embed the templates one at a time (with a poor man's vmap) - template_embeds = [] - n_templ = batch["template_aatype"].shape[templ_dim] - for i in range(n_templ): - idx = batch["template_aatype"].new_tensor(i) - single_template_feats = tensor_tree_map( - lambda t: torch.index_select(t, templ_dim, idx), - batch, - ) - - single_template_embeds = {} - if self.config.embed_angles: - template_angle_feat = build_template_angle_feat(single_template_feats,) - - # [*, S_t, N, C_m] - a = self.template_angle_embedder(template_angle_feat) - - single_template_embeds["angle"] = a - - # [*, S_t, N, N, C_t] - t = build_template_pair_feat( - single_template_feats, - use_unit_vector=self.config.use_unit_vector, - inf=self.config.inf, - eps=self.config.eps, - **self.config.distogram, - ).to(z.dtype) - t = self.template_pair_embedder(t) - - single_template_embeds.update({"pair": t}) - - template_embeds.append(single_template_embeds) - - template_embeds = dict_multimap( - partial(torch.cat, dim=templ_dim), - template_embeds, - ) - - # [*, S_t, N, N, C_z] - t = self.template_pair_stack( - template_embeds["pair"], - pair_mask.unsqueeze(-3).to(dtype=z.dtype), - chunk_size=chunk_size, - _mask_trans=_mask_trans, - ) - - # [*, N, N, C_z] - t = self.template_pointwise_att( - t, - z, - template_mask=batch["template_mask"].to(dtype=z.dtype), - chunk_size=chunk_size, - ) - t = t * (torch.sum(batch["template_mask"]) > 0) - - ret = {} - if self.config.embed_angles: - ret["template_single_embedding"] = template_embeds["angle"] - - ret.update({"template_pair_embedding": t}) - - return ret - - -class TemplateAngleEmbedder(nn.Module): - """ - Embeds the "template_angle_feat" feature. - - Implements Algorithm 2, line 7. 
- """ - - def __init__( - self, - c_in: int, - c_out: int, - **kwargs, - ): - """ - Args: - c_in: - Final dimension of "template_angle_feat" - c_out: - Output channel dimension - """ - super(TemplateAngleEmbedder, self).__init__() - - self.c_out = c_out - self.c_in = c_in - - self.linear_1 = Linear(self.c_in, self.c_out, init="relu") - self.relu = nn.ReLU() - self.linear_2 = Linear(self.c_out, self.c_out, init="relu") - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: [*, N_templ, N_res, c_in] "template_angle_feat" features - Returns: - x: [*, N_templ, N_res, C_out] embedding - """ - x = self.linear_1(x) - x = self.relu(x) - x = self.linear_2(x) - - return x - - -class TemplatePairEmbedder(nn.Module): - """ - Embeds "template_pair_feat" features. - - Implements Algorithm 2, line 9. - """ - - def __init__( - self, - c_in: int, - c_out: int, - **kwargs, - ): - """ - Args: - c_in: - - c_out: - Output channel dimension - """ - super(TemplatePairEmbedder, self).__init__() - - self.c_in = c_in - self.c_out = c_out - - # Despite there being no relu nearby, the source uses that initializer - self.linear = Linear(self.c_in, self.c_out, init="relu") - - def forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - x: - [*, C_in] input tensor - Returns: - [*, C_out] output tensor - """ - x = self.linear(x) - - return x - - -class ExtraMSAEmbedder(nn.Module): - """ - Embeds unclustered MSA sequences. - - Implements Algorithm 2, line 15 - """ - - def __init__( - self, - c_in: int, - c_out: int, - **kwargs, - ): - """ - Args: - c_in: - Input channel dimension - c_out: - Output channel dimension - """ - super(ExtraMSAEmbedder, self).__init__() - - self.c_in = c_in - self.c_out = c_out - - self.linear = Linear(self.c_in, self.c_out) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - [*, N_extra_seq, N_res, C_in] "extra_msa_feat" features - Returns: - [*, N_extra_seq, N_res, C_out] embedding - """ - x = self.linear(x) - - return x diff --git a/tests/test_autochunk/origin_openfold/embedders_multimer.py b/tests/test_autochunk/origin_openfold/embedders_multimer.py deleted file mode 100644 index 6bee17227457..000000000000 --- a/tests/test_autochunk/origin_openfold/embedders_multimer.py +++ /dev/null @@ -1,352 +0,0 @@ -from functools import partial -from typing import Dict, Tuple - -import torch -import torch.nn as nn - -from .primitives import LayerNorm, Linear -from .template import TemplatePairStack, TemplatePointwiseAttention -from .utils import all_atom_multimer, dgram_from_positions, geometry -from .utils.tensor_utils import dict_multimap, one_hot, tensor_tree_map - - -class InputEmbedderMultimer(nn.Module): - """ - Embeds a subset of the input features. - - Implements Algorithms 3 (InputEmbedder) and 4 (relpos). 
- """ - - def __init__( - self, - tf_dim: int, - msa_dim: int, - c_z: int, - c_m: int, - max_relative_idx: int, - use_chain_relative: bool, - max_relative_chain: int, - **kwargs, - ): - """ - Args: - tf_dim: - Final dimension of the target features - msa_dim: - Final dimension of the MSA features - c_z: - Pair embedding dimension - c_m: - MSA embedding dimension - relpos_k: - Window size used in relative positional encoding - """ - super(InputEmbedderMultimer, self).__init__() - - self.tf_dim = tf_dim - self.msa_dim = msa_dim - - self.c_z = c_z - self.c_m = c_m - - self.linear_tf_z_i = Linear(tf_dim, c_z) - self.linear_tf_z_j = Linear(tf_dim, c_z) - self.linear_tf_m = Linear(tf_dim, c_m) - self.linear_msa_m = Linear(msa_dim, c_m) - - # RPE stuff - self.max_relative_idx = max_relative_idx - self.use_chain_relative = use_chain_relative - self.max_relative_chain = max_relative_chain - if self.use_chain_relative: - self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2 - else: - self.no_bins = 2 * max_relative_idx + 1 - self.linear_relpos = Linear(self.no_bins, c_z) - - def relpos(self, batch: Dict[str, torch.Tensor]): - pos = batch["residue_index"] - asym_id = batch["asym_id"] - asym_id_same = asym_id[..., None] == asym_id[..., None, :] - offset = pos[..., None] - pos[..., None, :] - - clipped_offset = torch.clamp(offset + self.max_relative_idx, 0, 2 * self.max_relative_idx) - - rel_feats = [] - if self.use_chain_relative: - final_offset = torch.where( - asym_id_same, - clipped_offset, - (2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset), - ) - - rel_pos = torch.nn.functional.one_hot( - final_offset, - 2 * self.max_relative_idx + 2, - ) - - rel_feats.append(rel_pos) - - entity_id = batch["entity_id"] - entity_id_same = entity_id[..., None] == entity_id[..., None, :] - rel_feats.append(entity_id_same[..., None]) - - sym_id = batch["sym_id"] - rel_sym_id = sym_id[..., None] - sym_id[..., None, :] - - max_rel_chain = self.max_relative_chain - clipped_rel_chain = torch.clamp( - rel_sym_id + max_rel_chain, - 0, - 2 * max_rel_chain, - ) - - final_rel_chain = torch.where( - entity_id_same, - clipped_rel_chain, - (2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain), - ) - - rel_chain = torch.nn.functional.one_hot( - final_rel_chain.long(), - 2 * max_rel_chain + 2, - ) - - rel_feats.append(rel_chain) - else: - rel_pos = torch.nn.functional.one_hot( - clipped_offset, - 2 * self.max_relative_idx + 1, - ) - rel_feats.append(rel_pos) - - rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype) - - return self.linear_relpos(rel_feat) - - def forward(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - tf = batch["target_feat"] - msa = batch["msa_feat"] - - # [*, N_res, c_z] - tf_emb_i = self.linear_tf_z_i(tf) - tf_emb_j = self.linear_tf_z_j(tf) - - # [*, N_res, N_res, c_z] - pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] - pair_emb = pair_emb + self.relpos(batch) - - # [*, N_clust, N_res, c_m] - n_clust = msa.shape[-3] - tf_m = (self.linear_tf_m(tf).unsqueeze(-3).expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))) - msa_emb = self.linear_msa_m(msa) + tf_m - - return msa_emb, pair_emb - - -class TemplatePairEmbedderMultimer(nn.Module): - - def __init__( - self, - c_z: int, - c_out: int, - c_dgram: int, - c_aatype: int, - ): - super().__init__() - - self.dgram_linear = Linear(c_dgram, c_out) - self.aatype_linear_1 = Linear(c_aatype, c_out) - self.aatype_linear_2 = Linear(c_aatype, c_out) - 
self.query_embedding_layer_norm = LayerNorm(c_z) - self.query_embedding_linear = Linear(c_z, c_out) - - self.pseudo_beta_mask_linear = Linear(1, c_out) - self.x_linear = Linear(1, c_out) - self.y_linear = Linear(1, c_out) - self.z_linear = Linear(1, c_out) - self.backbone_mask_linear = Linear(1, c_out) - - def forward( - self, - template_dgram: torch.Tensor, - aatype_one_hot: torch.Tensor, - query_embedding: torch.Tensor, - pseudo_beta_mask: torch.Tensor, - backbone_mask: torch.Tensor, - multichain_mask_2d: torch.Tensor, - unit_vector: geometry.Vec3Array, - ) -> torch.Tensor: - act = 0. - - pseudo_beta_mask_2d = (pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]) - pseudo_beta_mask_2d *= multichain_mask_2d - template_dgram *= pseudo_beta_mask_2d[..., None] - act += self.dgram_linear(template_dgram) - act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None]) - - aatype_one_hot = aatype_one_hot.to(template_dgram.dtype) - act += self.aatype_linear_1(aatype_one_hot[..., None, :, :]) - act += self.aatype_linear_2(aatype_one_hot[..., None, :]) - - backbone_mask_2d = (backbone_mask[..., None] * backbone_mask[..., None, :]) - backbone_mask_2d *= multichain_mask_2d - x, y, z = [coord * backbone_mask_2d for coord in unit_vector] - act += self.x_linear(x[..., None]) - act += self.y_linear(y[..., None]) - act += self.z_linear(z[..., None]) - - act += self.backbone_mask_linear(backbone_mask_2d[..., None]) - - query_embedding = self.query_embedding_layer_norm(query_embedding) - act += self.query_embedding_linear(query_embedding) - - return act - - -class TemplateSingleEmbedderMultimer(nn.Module): - - def __init__( - self, - c_in: int, - c_m: int, - ): - super().__init__() - self.template_single_embedder = Linear(c_in, c_m) - self.template_projector = Linear(c_m, c_m) - - def forward( - self, - batch, - atom_pos, - aatype_one_hot, - ): - out = {} - - template_chi_angles, template_chi_mask = (all_atom_multimer.compute_chi_angles( - atom_pos, - batch["template_all_atom_mask"], - batch["template_aatype"], - )) - - template_features = torch.cat( - [ - aatype_one_hot, - torch.sin(template_chi_angles) * template_chi_mask, - torch.cos(template_chi_angles) * template_chi_mask, - template_chi_mask, - ], - dim=-1, - ) - - template_mask = template_chi_mask[..., 0] - - template_activations = self.template_single_embedder(template_features) - template_activations = torch.nn.functional.relu(template_activations) - template_activations = self.template_projector(template_activations,) - - out["template_single_embedding"] = (template_activations) - out["template_mask"] = template_mask - - return out - - -class TemplateEmbedderMultimer(nn.Module): - - def __init__(self, config): - super(TemplateEmbedderMultimer, self).__init__() - - self.config = config - self.template_pair_embedder = TemplatePairEmbedderMultimer(**config["template_pair_embedder"],) - self.template_single_embedder = TemplateSingleEmbedderMultimer(**config["template_single_embedder"],) - self.template_pair_stack = TemplatePairStack(**config["template_pair_stack"],) - - self.linear_t = Linear(config.c_t, config.c_z) - - def forward( - self, - batch, - z, - padding_mask_2d, - templ_dim, - chunk_size, - multichain_mask_2d, - ): - template_embeds = [] - n_templ = batch["template_aatype"].shape[templ_dim] - for i in range(n_templ): - idx = batch["template_aatype"].new_tensor(i) - single_template_feats = tensor_tree_map( - lambda t: torch.index_select(t, templ_dim, idx), - batch, - ) - - single_template_embeds = {} - act = 0. 
- - template_positions, pseudo_beta_mask = ( - single_template_feats["template_pseudo_beta"], - single_template_feats["template_pseudo_beta_mask"], - ) - - template_dgram = dgram_from_positions( - template_positions, - inf=self.config.inf, - **self.config.distogram, - ) - - aatype_one_hot = torch.nn.functional.one_hot( - single_template_feats["template_aatype"], - 22, - ) - - raw_atom_pos = single_template_feats["template_all_atom_positions"] - - atom_pos = geometry.Vec3Array.from_array(raw_atom_pos) - rigid, backbone_mask = all_atom_multimer.make_backbone_affine( - atom_pos, - single_template_feats["template_all_atom_mask"], - single_template_feats["template_aatype"], - ) - points = rigid.translation - rigid_vec = rigid[..., None].inverse().apply_to_point(points) - unit_vector = rigid_vec.normalized() - - pair_act = self.template_pair_embedder( - template_dgram, - aatype_one_hot, - z, - pseudo_beta_mask, - backbone_mask, - multichain_mask_2d, - unit_vector, - ) - - single_template_embeds["template_pair_embedding"] = pair_act - single_template_embeds.update( - self.template_single_embedder( - single_template_feats, - atom_pos, - aatype_one_hot, - )) - template_embeds.append(single_template_embeds) - - template_embeds = dict_multimap( - partial(torch.cat, dim=templ_dim), - template_embeds, - ) - - # [*, S_t, N, N, C_z] - t = self.template_pair_stack( - template_embeds["template_pair_embedding"], - padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype), - chunk_size=chunk_size, - _mask_trans=False, - ) - # [*, N, N, C_z] - t = torch.sum(t, dim=-4) / n_templ - t = torch.nn.functional.relu(t) - t = self.linear_t(t) - template_embeds["template_pair_embedding"] = t - - return template_embeds diff --git a/tests/test_autochunk/origin_openfold/evoformer.py b/tests/test_autochunk/origin_openfold/evoformer.py deleted file mode 100644 index ab1f03a950f7..000000000000 --- a/tests/test_autochunk/origin_openfold/evoformer.py +++ /dev/null @@ -1,626 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from functools import partial -from typing import Optional, Tuple - -import torch -import torch.nn as nn - -from .dropout import DropoutColumnwise, DropoutRowwise -from .msa import MSAColumnAttention, MSAColumnGlobalAttention, MSARowAttentionWithPairBias -from .outer_product_mean import OuterProductMean -from .pair_transition import PairTransition -from .primitives import LayerNorm, Linear -from .triangular_attention import TriangleAttentionEndingNode, TriangleAttentionStartingNode -from .triangular_multiplicative_update import TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing -from .utils.checkpointing import checkpoint_blocks, get_checkpoint_fn -from .utils.tensor_utils import chunk_layer - - -class MSATransition(nn.Module): - """ - Feed-forward network applied to MSA activations after attention. 
- - Implements Algorithm 9 - """ - - def __init__(self, c_m, n): - """ - Args: - c_m: - MSA channel dimension - n: - Factor multiplied to c_m to obtain the hidden channel - dimension - """ - super(MSATransition, self).__init__() - - self.c_m = c_m - self.n = n - - self.layer_norm = LayerNorm(self.c_m) - self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu") - self.relu = nn.ReLU() - self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final") - - def _transition(self, m, mask): - m = self.linear_1(m) - m = self.relu(m) - m = self.linear_2(m) * mask - return m - - @torch.jit.ignore - def _chunk( - self, - m: torch.Tensor, - mask: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self._transition, - { - "m": m, - "mask": mask - }, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def forward( - self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA activation - mask: - [*, N_seq, N_res, C_m] MSA mask - Returns: - m: - [*, N_seq, N_res, C_m] MSA activation update - """ - - # DISCREPANCY: DeepMind forgets to apply the MSA mask here. - if mask is None: - mask = m.new_ones(m.shape[:-1]) - - # [*, N_seq, N_res, 1] - mask = mask.unsqueeze(-1) - - m = self.layer_norm(m) - - if chunk_size is not None: - m = self._chunk(m, mask, chunk_size) - else: - m = self._transition(m, mask) - - return m - - -class EvoformerBlockCore(nn.Module): - - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - transition_n: int, - pair_dropout: float, - inf: float, - eps: float, - _is_extra_msa_stack: bool = False, - is_multimer: bool = False, - ): - super(EvoformerBlockCore, self).__init__() - self.is_multimer = is_multimer - self.msa_transition = MSATransition( - c_m=c_m, - n=transition_n, - ) - - self.outer_product_mean = OuterProductMean( - c_m, - c_z, - c_hidden_opm, - ) - - self.tri_mul_out = TriangleMultiplicationOutgoing( - c_z, - c_hidden_mul, - ) - self.tri_mul_in = TriangleMultiplicationIncoming( - c_z, - c_hidden_mul, - ) - - self.tri_att_start = TriangleAttentionStartingNode( - c_z, - c_hidden_pair_att, - no_heads_pair, - inf=inf, - ) - self.tri_att_end = TriangleAttentionEndingNode( - c_z, - c_hidden_pair_att, - no_heads_pair, - inf=inf, - ) - - self.pair_transition = PairTransition( - c_z, - transition_n, - ) - - self.ps_dropout_row_layer = DropoutRowwise(pair_dropout) - self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout) - - def forward( - self, - m: torch.Tensor, - z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, - chunk_size: Optional[int] = None, - _mask_trans: bool = True, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # DeepMind doesn't mask these transitions in the source, so _mask_trans - # should be disabled to better approximate the exact activations of - # the original. 
- msa_trans_mask = msa_mask if _mask_trans else None - pair_trans_mask = pair_mask if _mask_trans else None - - m = m + self.msa_transition(m, mask=msa_trans_mask, chunk_size=chunk_size) - z = z + self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size) - z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask)) - z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask)) - z = z + self.ps_dropout_row_layer(self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)) - z = z + self.ps_dropout_col_layer(self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)) - z = z + self.pair_transition(z, mask=pair_trans_mask, chunk_size=chunk_size) - - return m, z - - -class EvoformerBlock(nn.Module): - - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - inf: float, - eps: float, - is_multimer: bool, - ): - super(EvoformerBlock, self).__init__() - - self.msa_att_row = MSARowAttentionWithPairBias( - c_m=c_m, - c_z=c_z, - c_hidden=c_hidden_msa_att, - no_heads=no_heads_msa, - inf=inf, - ) - - self.msa_att_col = MSAColumnAttention( - c_m, - c_hidden_msa_att, - no_heads_msa, - inf=inf, - ) - - self.msa_dropout_layer = DropoutRowwise(msa_dropout) - - self.core = EvoformerBlockCore( - c_m=c_m, - c_z=c_z, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - ) - - self.outer_product_mean = OuterProductMean( - c_m, - c_z, - c_hidden_opm, - ) - self.is_multimer = is_multimer - - def forward( - self, - m: torch.Tensor, - z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, - chunk_size: Optional[int] = None, - _mask_trans: bool = True, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m = m + self.msa_dropout_layer(self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)) - m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) - m, z = self.core( - m, - z, - msa_mask=msa_mask, - pair_mask=pair_mask, - chunk_size=chunk_size, - _mask_trans=_mask_trans, - ) - - return m, z - - -class ExtraMSABlock(nn.Module): - """ - Almost identical to the standard EvoformerBlock, except in that the - ExtraMSABlock uses GlobalAttention for MSA column attention and - requires more fine-grained control over checkpointing. Separated from - its twin to preserve the TorchScript-ability of the latter. 
- """ - - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - inf: float, - eps: float, - ckpt: bool, - is_multimer: bool, - ): - super(ExtraMSABlock, self).__init__() - - self.ckpt = ckpt - - self.msa_att_row = MSARowAttentionWithPairBias( - c_m=c_m, - c_z=c_z, - c_hidden=c_hidden_msa_att, - no_heads=no_heads_msa, - inf=inf, - ) - - self.msa_att_col = MSAColumnGlobalAttention( - c_in=c_m, - c_hidden=c_hidden_msa_att, - no_heads=no_heads_msa, - inf=inf, - eps=eps, - ) - - self.msa_dropout_layer = DropoutRowwise(msa_dropout) - - self.core = EvoformerBlockCore( - c_m=c_m, - c_z=c_z, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - ) - self.is_multimer = is_multimer - - def forward( - self, - m: torch.Tensor, - z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, - chunk_size: Optional[int] = None, - _chunk_logits: Optional[int] = 1024, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m = m + self.msa_dropout_layer( - self.msa_att_row( - m.clone(), - z=z.clone(), - mask=msa_mask, - chunk_size=chunk_size, - _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None, - _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False, - )) - - def fn(m, z): - m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) - m, z = self.core(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size) - - return m, z - - if (torch.is_grad_enabled() and self.ckpt): - checkpoint_fn = get_checkpoint_fn() - m, z = checkpoint_fn(fn, m, z) - else: - m, z = fn(m, z) - - return m, z - - -class EvoformerStack(nn.Module): - """ - Main Evoformer trunk. - - Implements Algorithm 6. - """ - - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - c_s: int, - no_heads_msa: int, - no_heads_pair: int, - no_blocks: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - blocks_per_ckpt: int, - inf: float, - eps: float, - clear_cache_between_blocks: bool = False, - is_multimer: bool = False, - **kwargs, - ): - """ - Args: - c_m: - MSA channel dimension - c_z: - Pair channel dimension - c_hidden_msa_att: - Hidden dimension in MSA attention - c_hidden_opm: - Hidden dimension in outer product mean module - c_hidden_mul: - Hidden dimension in multiplicative updates - c_hidden_pair_att: - Hidden dimension in triangular attention - c_s: - Channel dimension of the output "single" embedding - no_heads_msa: - Number of heads used for MSA attention - no_heads_pair: - Number of heads used for pair attention - no_blocks: - Number of Evoformer blocks in the stack - transition_n: - Factor by which to multiply c_m to obtain the MSATransition - hidden dimension - msa_dropout: - Dropout rate for MSA activations - pair_dropout: - Dropout used for pair activations - blocks_per_ckpt: - Number of Evoformer blocks in each activation checkpoint - clear_cache_between_blocks: - Whether to clear CUDA's GPU memory cache between blocks of the - stack. 
Slows down each block but can reduce fragmentation - """ - super(EvoformerStack, self).__init__() - - self.blocks_per_ckpt = blocks_per_ckpt - self.clear_cache_between_blocks = clear_cache_between_blocks - - self.blocks = nn.ModuleList() - - for _ in range(no_blocks): - block = EvoformerBlock( - c_m=c_m, - c_z=c_z, - c_hidden_msa_att=c_hidden_msa_att, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - msa_dropout=msa_dropout, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - is_multimer=is_multimer, - ) - self.blocks.append(block) - - self.linear = Linear(c_m, c_s) - - def forward( - self, - m: torch.Tensor, - z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, - chunk_size: int, - _mask_trans: bool = True, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding - msa_mask: - [*, N_seq, N_res] MSA mask - pair_mask: - [*, N_res, N_res] pair mask - Returns: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding - s: - [*, N_res, C_s] single embedding (or None if extra MSA stack) - """ - blocks = [ - partial( - b, - msa_mask=msa_mask, - pair_mask=pair_mask, - chunk_size=chunk_size, - _mask_trans=_mask_trans, - ) for b in self.blocks - ] - - if (self.clear_cache_between_blocks): - - def block_with_cache_clear(block, *args): - torch.cuda.empty_cache() - return block(*args) - - blocks = [partial(block_with_cache_clear, b) for b in blocks] - - m, z = checkpoint_blocks( - blocks, - args=(m, z), - blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, - ) - - s = self.linear(m[..., 0, :, :]) - - return m, z, s - - -class ExtraMSAStack(nn.Module): - """ - Implements Algorithm 18. 
- """ - - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - no_blocks: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - inf: float, - eps: float, - ckpt: bool, - clear_cache_between_blocks: bool = False, - is_multimer: bool = False, - **kwargs, - ): - super(ExtraMSAStack, self).__init__() - - self.clear_cache_between_blocks = clear_cache_between_blocks - self.blocks = nn.ModuleList() - for _ in range(no_blocks): - block = ExtraMSABlock( - c_m=c_m, - c_z=c_z, - c_hidden_msa_att=c_hidden_msa_att, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - msa_dropout=msa_dropout, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - ckpt=ckpt, - is_multimer=is_multimer, - ) - self.blocks.append(block) - - def forward( - self, - m: torch.Tensor, - z: torch.Tensor, - chunk_size: int, - msa_mask: Optional[torch.Tensor] = None, - pair_mask: Optional[torch.Tensor] = None, - _mask_trans: bool = True, - ) -> torch.Tensor: - """ - Args: - m: - [*, N_extra, N_res, C_m] extra MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding - msa_mask: - Optional [*, N_extra, N_res] MSA mask - pair_mask: - Optional [*, N_res, N_res] pair mask - Returns: - [*, N_res, N_res, C_z] pair update - """ - #checkpoint_fn = get_checkpoint_fn() - #blocks = [ - # partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks - #] - - #def dodo(b, *args): - # torch.cuda.empty_cache() - # return b(*args) - - #blocks = [partial(dodo, b) for b in blocks] - - #for b in blocks: - # if(torch.is_grad_enabled()): - # m, z = checkpoint_fn(b, *(m, z)) - # else: - # m, z = b(m, z) - - for b in self.blocks: - m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size) - - if (self.clear_cache_between_blocks): - torch.cuda.empty_cache() - - return z diff --git a/tests/test_autochunk/origin_openfold/heads.py b/tests/test_autochunk/origin_openfold/heads.py deleted file mode 100644 index 57718d5d4a1f..000000000000 --- a/tests/test_autochunk/origin_openfold/heads.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch -import torch.nn as nn - -from .primitives import LayerNorm, Linear -from .utils.loss import compute_plddt, compute_predicted_aligned_error, compute_tm - - -class AuxiliaryHeads(nn.Module): - - def __init__(self, config): - super(AuxiliaryHeads, self).__init__() - - self.plddt = PerResidueLDDTCaPredictor(**config["lddt"],) - - self.distogram = DistogramHead(**config["distogram"],) - - self.masked_msa = MaskedMSAHead(**config["masked_msa"],) - - self.experimentally_resolved = ExperimentallyResolvedHead(**config["experimentally_resolved"],) - - if config.tm.enabled: - self.tm = TMScoreHead(**config.tm,) - - self.config = config - - def forward(self, outputs): - aux_out = {} - lddt_logits = self.plddt(outputs["sm"]["single"]) - aux_out["lddt_logits"] = lddt_logits - - # Required for relaxation later on - aux_out["plddt"] = compute_plddt(lddt_logits) - - distogram_logits = self.distogram(outputs["pair"]) - aux_out["distogram_logits"] = distogram_logits - - masked_msa_logits = self.masked_msa(outputs["msa"]) - aux_out["masked_msa_logits"] = masked_msa_logits - - experimentally_resolved_logits = self.experimentally_resolved(outputs["single"]) - aux_out["experimentally_resolved_logits"] = experimentally_resolved_logits - - if self.config.tm.enabled: - tm_logits = self.tm(outputs["pair"]) - aux_out["tm_logits"] = tm_logits - aux_out["predicted_tm_score"] = compute_tm(tm_logits, **self.config.tm) - aux_out.update(compute_predicted_aligned_error( - tm_logits, - **self.config.tm, - )) - - return aux_out - - -class PerResidueLDDTCaPredictor(nn.Module): - - def __init__(self, no_bins, c_in, c_hidden): - super(PerResidueLDDTCaPredictor, self).__init__() - - self.no_bins = no_bins - self.c_in = c_in - self.c_hidden = c_hidden - - self.layer_norm = LayerNorm(self.c_in) - - self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu") - self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu") - self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final") - - self.relu = nn.ReLU() - - def forward(self, s): - s = self.layer_norm(s) - s = self.linear_1(s) - s = self.relu(s) - s = self.linear_2(s) - s = self.relu(s) - s = self.linear_3(s) - - return s - - -class DistogramHead(nn.Module): - """ - Computes a distogram probability distribution. 
- - For use in computation of distogram loss, subsection 1.9.8 - """ - - def __init__(self, c_z, no_bins, **kwargs): - """ - Args: - c_z: - Input channel dimension - no_bins: - Number of distogram bins - """ - super(DistogramHead, self).__init__() - - self.c_z = c_z - self.no_bins = no_bins - - self.linear = Linear(self.c_z, self.no_bins, init="final") - - def forward(self, z): # [*, N, N, C_z] - """ - Args: - z: - [*, N_res, N_res, C_z] pair embedding - Returns: - [*, N, N, no_bins] distogram probability distribution - """ - # [*, N, N, no_bins] - logits = self.linear(z) - logits = logits + logits.transpose(-2, -3) - return logits - - -class TMScoreHead(nn.Module): - """ - For use in computation of TM-score, subsection 1.9.7 - """ - - def __init__(self, c_z, no_bins, **kwargs): - """ - Args: - c_z: - Input channel dimension - no_bins: - Number of bins - """ - super(TMScoreHead, self).__init__() - - self.c_z = c_z - self.no_bins = no_bins - - self.linear = Linear(self.c_z, self.no_bins, init="final") - - def forward(self, z): - """ - Args: - z: - [*, N_res, N_res, C_z] pairwise embedding - Returns: - [*, N_res, N_res, no_bins] prediction - """ - # [*, N, N, no_bins] - logits = self.linear(z) - return logits - - -class MaskedMSAHead(nn.Module): - """ - For use in computation of masked MSA loss, subsection 1.9.9 - """ - - def __init__(self, c_m, c_out, **kwargs): - """ - Args: - c_m: - MSA channel dimension - c_out: - Output channel dimension - """ - super(MaskedMSAHead, self).__init__() - - self.c_m = c_m - self.c_out = c_out - - self.linear = Linear(self.c_m, self.c_out, init="final") - - def forward(self, m): - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - Returns: - [*, N_seq, N_res, C_out] reconstruction - """ - # [*, N_seq, N_res, C_out] - logits = self.linear(m) - return logits - - -class ExperimentallyResolvedHead(nn.Module): - """ - For use in computation of "experimentally resolved" loss, subsection - 1.9.10 - """ - - def __init__(self, c_s, c_out, **kwargs): - """ - Args: - c_s: - Input channel dimension - c_out: - Number of distogram bins - """ - super(ExperimentallyResolvedHead, self).__init__() - - self.c_s = c_s - self.c_out = c_out - - self.linear = Linear(self.c_s, self.c_out, init="final") - - def forward(self, s): - """ - Args: - s: - [*, N_res, C_s] single embedding - Returns: - [*, N, C_out] logits - """ - # [*, N, C_out] - logits = self.linear(s) - return logits diff --git a/tests/test_autochunk/origin_openfold/msa.py b/tests/test_autochunk/origin_openfold/msa.py deleted file mode 100644 index 4c7714ab73f1..000000000000 --- a/tests/test_autochunk/origin_openfold/msa.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
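DistogramHead above symmetrizes its logits by adding their transpose over the two residue axes, so the predicted distogram is independent of residue ordering. A minimal sketch with toy sizes (not the real config values) that checks the resulting symmetry:

import torch
import torch.nn as nn

c_z, no_bins, n_res = 16, 64, 5  # toy sizes
linear = nn.Linear(c_z, no_bins)
z = torch.randn(2, n_res, n_res, c_z)

logits = linear(z)
logits = logits + logits.transpose(-2, -3)  # symmetrize over residue axes
assert torch.allclose(logits, logits.transpose(-2, -3))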
- -import math -from typing import List, Optional, Tuple - -import torch -import torch.nn as nn - -from .primitives import Attention, GlobalAttention, LayerNorm, Linear, _attention_chunked_trainable -from .utils.checkpointing import get_checkpoint_fn -from .utils.tensor_utils import chunk_layer, flatten_final_dims, permute_final_dims - - -class MSAAttention(nn.Module): - - def __init__( - self, - c_in, - c_hidden, - no_heads, - pair_bias=False, - c_z=None, - inf=1e9, - ): - """ - Args: - c_in: - Input channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - pair_bias: - Whether to use pair embedding bias - c_z: - Pair embedding channel dimension. Ignored unless pair_bias - is true - inf: - A large number to be used in computing the attention mask - """ - super(MSAAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.pair_bias = pair_bias - self.c_z = c_z - self.inf = inf - - self.layer_norm_m = LayerNorm(self.c_in) - - self.layer_norm_z = None - self.linear_z = None - if self.pair_bias: - self.layer_norm_z = LayerNorm(self.c_z) - self.linear_z = Linear(self.c_z, self.no_heads, bias=False, init="normal") - - self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) - - @torch.jit.ignore - def _chunk( - self, - m: torch.Tensor, - biases: List[torch.Tensor], - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self.mha, - { - "q_x": m, - "kv_x": m, - "biases": biases - }, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def _prep_inputs(self, m: torch.Tensor, z: Optional[torch.Tensor], - mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # [*, N_seq, N_res, C_m] - m = self.layer_norm_m(m) - - n_seq, n_res = m.shape[-3:-1] - if mask is None: - # [*, N_seq, N_res] - mask = m.new_ones(m.shape[:-3] + (n_seq, n_res),) - - # [*, N_seq, 1, 1, N_res] - mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] - - # This step simply returns a larger view of the bias, and does not - # consume additional memory. 
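-        # (expand() would only create a stride-0 broadcast view of the bias
-        # here, sharing storage with the original tensor rather than copying.)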
- # [*, N_seq, no_heads, N_res, N_res] - #bias = bias.expand( - # ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1) - #) - - if (self.pair_bias and z is not None and # For the - self.layer_norm_z is not None and # benefit of - self.linear_z is not None # TorchScript - ): - # [*, N_res, N_res, C_z] - z = self.layer_norm_z(z) - - # [*, N_res, N_res, no_heads] - z = self.linear_z(z) - - # [*, 1, no_heads, N_res, N_res] - z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) - - return m, mask_bias, z - - @torch.jit.ignore - def _chunked_msa_attn( - self, - m: torch.Tensor, - z: Optional[torch.Tensor], - mask: Optional[torch.Tensor], - chunk_logits: int, - checkpoint: bool, - ) -> torch.Tensor: - MSA_DIM = -4 - - def _get_qkv(m, z): - m, mask_bias, z = self._prep_inputs(m, z, mask) - q, k, v = self.mha._prep_qkv(m, m) - return m, q, k, v, mask_bias, z - - checkpoint_fn = get_checkpoint_fn() - - if (torch.is_grad_enabled() and checkpoint): - m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z) - else: - m, q, k, v, mask_bias, z = _get_qkv(m, z) - - o = _attention_chunked_trainable( - query=q, - key=k, - value=v, - biases=[mask_bias, z], - chunk_size=chunk_logits, - chunk_dim=MSA_DIM, - checkpoint=checkpoint, - ) - - if (torch.is_grad_enabled() and checkpoint): - # Storing an additional m here is far from ideal - m = checkpoint_fn(self.mha._wrap_up, o, m) - else: - m = self.mha._wrap_up(o, m) - - return m - - def forward( - self, - m: torch.Tensor, - z: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - _chunk_logits: Optional[int] = None, - _checkpoint_chunks: Optional[bool] = None, - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding. Required only if - pair_bias is True - mask: - [*, N_seq, N_res] MSA mask - chunk_size: - Size of chunks into which the inputs are split along their - batch dimensions. A low value decreases memory overhead at the - cost of slower execution. Chunking is not performed by default. - - """ - if (_chunk_logits is not None): - return self._chunked_msa_attn(m=m, - z=z, - mask=mask, - chunk_logits=_chunk_logits, - checkpoint=_checkpoint_chunks) - - m, mask_bias, z = self._prep_inputs(m, z, mask) - - biases = [mask_bias] - if (z is not None): - biases.append(z) - - if chunk_size is not None: - m = self._chunk(m, biases, chunk_size) - else: - m = self.mha(q_x=m, kv_x=m, biases=biases) - - return m - - -class MSARowAttentionWithPairBias(MSAAttention): - """ - Implements Algorithm 7. - """ - - def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9): - """ - Args: - c_m: - Input channel dimension - c_z: - Pair embedding channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - inf: - Large number used to construct attention masks - """ - super(MSARowAttentionWithPairBias, self).__init__( - c_m, - c_hidden, - no_heads, - pair_bias=True, - c_z=c_z, - inf=inf, - ) - - -class MSAColumnAttention(nn.Module): - """ - Implements Algorithm 8. - - By rights, this should also be a subclass of MSAAttention. Alas, - most inheritance isn't supported by TorchScript. 
- """ - - def __init__(self, c_m, c_hidden, no_heads, inf=1e9): - """ - Args: - c_m: - MSA channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - inf: - Large number used to construct attention masks - """ - super(MSAColumnAttention, self).__init__() - - self.c_m = c_m - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - - self._msa_att = MSAAttention( - c_in=c_m, - c_hidden=c_hidden, - no_heads=no_heads, - pair_bias=False, - c_z=None, - inf=inf, - ) - - def forward(self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - mask: - [*, N_seq, N_res] MSA mask - chunk_size: - Size of chunks into which the inputs are split along their - batch dimensions. A low value decreases memory overhead at the - cost of slower execution. Chunking is not performed by default. - """ - # [*, N_res, N_seq, C_in] - m = m.transpose(-2, -3) - if mask is not None: - mask = mask.transpose(-1, -2) - - m = self._msa_att(m, mask=mask, chunk_size=chunk_size) - - # [*, N_seq, N_res, C_in] - m = m.transpose(-2, -3) - if mask is not None: - mask = mask.transpose(-1, -2) - - return m - - -class MSAColumnGlobalAttention(nn.Module): - - def __init__( - self, - c_in, - c_hidden, - no_heads, - inf=1e9, - eps=1e-10, - ): - super(MSAColumnGlobalAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - self.eps = eps - - self.layer_norm_m = nn.LayerNorm(c_in) - - self.global_attention = GlobalAttention( - c_in=c_in, - c_hidden=c_hidden, - no_heads=no_heads, - inf=inf, - eps=eps, - ) - - @torch.jit.ignore - def _chunk( - self, - m: torch.Tensor, - mask: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - mha_input = { - "m": m, - "mask": mask, - } - return chunk_layer( - self.global_attention, - mha_input, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def forward( - self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - n_seq, n_res, c_in = m.shape[-3:] - - if mask is None: - # [*, N_seq, N_res] - mask = torch.ones( - m.shape[:-1], - dtype=m.dtype, - device=m.device, - ).detach() - - # [*, N_res, N_seq, C_in] - m = m.transpose(-2, -3) - mask = mask.transpose(-1, -2) - - # [*, N_res, N_seq, C_in] - m = self.layer_norm_m(m) - - if chunk_size is not None: - m = self._chunk(m, mask, chunk_size) - else: - m = self.global_attention(m=m, mask=mask) - - # [*, N_seq, N_res, C_in] - m = m.transpose(-2, -3) - - return m diff --git a/tests/test_autochunk/origin_openfold/outer_product_mean.py b/tests/test_autochunk/origin_openfold/outer_product_mean.py deleted file mode 100644 index 074555ad8a9a..000000000000 --- a/tests/test_autochunk/origin_openfold/outer_product_mean.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import Linear -from .utils.tensor_utils import chunk_layer - - -class OuterProductMean(nn.Module): - """ - Implements Algorithm 10. - """ - - def __init__(self, c_m, c_z, c_hidden, eps=1e-3): - """ - Args: - c_m: - MSA embedding channel dimension - c_z: - Pair embedding channel dimension - c_hidden: - Hidden channel dimension - """ - super(OuterProductMean, self).__init__() - - self.c_m = c_m - self.c_z = c_z - self.c_hidden = c_hidden - self.eps = eps - - self.layer_norm = nn.LayerNorm(c_m) - self.linear_1 = Linear(c_m, c_hidden) - self.linear_2 = Linear(c_m, c_hidden) - self.linear_out = Linear(c_hidden**2, c_z, init="final") - - def _opm(self, a, b): - # [*, N_res, N_res, C, C] - outer = torch.einsum("...bac,...dae->...bdce", a, b) - - # [*, N_res, N_res, C * C] - outer = outer.reshape(outer.shape[:-2] + (-1,)) - - # [*, N_res, N_res, C_z] - outer = self.linear_out(outer) - - return outer - - @torch.jit.ignore - def _chunk(self, a: torch.Tensor, b: torch.Tensor, chunk_size: int) -> torch.Tensor: - # Since the "batch dim" in this case is not a true batch dimension - # (in that the shape of the output depends on it), we need to - # iterate over it ourselves - a_reshape = a.reshape((-1,) + a.shape[-3:]) - b_reshape = b.reshape((-1,) + b.shape[-3:]) - out = [] - for a_prime, b_prime in zip(a_reshape, b_reshape): - outer = chunk_layer( - partial(self._opm, b=b_prime), - {"a": a_prime}, - chunk_size=chunk_size, - no_batch_dims=1, - ) - out.append(outer) - outer = torch.stack(out, dim=0) - outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) - - return outer - - def forward(self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - mask: - [*, N_seq, N_res] MSA mask - Returns: - [*, N_res, N_res, C_z] pair embedding update - """ - if mask is None: - mask = m.new_ones(m.shape[:-1]) - - # [*, N_seq, N_res, C_m] - m = self.layer_norm(m) - - # [*, N_seq, N_res, C] - mask = mask.unsqueeze(-1) - a = self.linear_1(m) * mask - b = self.linear_2(m) * mask - - a = a.transpose(-2, -3) - b = b.transpose(-2, -3) - - if chunk_size is not None: - outer = self._chunk(a, b, chunk_size) - else: - outer = self._opm(a, b) - - # [*, N_res, N_res, 1] - norm = torch.einsum("...abc,...adc->...bdc", mask, mask) - - # [*, N_res, N_res, C_z] - outer = outer / (self.eps + norm) - - return outer diff --git a/tests/test_autochunk/origin_openfold/pair_transition.py b/tests/test_autochunk/origin_openfold/pair_transition.py deleted file mode 100644 index 9d32adb89b63..000000000000 --- a/tests/test_autochunk/origin_openfold/pair_transition.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import LayerNorm, Linear -from .utils.tensor_utils import chunk_layer - - -class PairTransition(nn.Module): - """ - Implements Algorithm 15. - """ - - def __init__(self, c_z, n): - """ - Args: - c_z: - Pair transition channel dimension - n: - Factor by which c_z is multiplied to obtain hidden channel - dimension - """ - super(PairTransition, self).__init__() - - self.c_z = c_z - self.n = n - - self.layer_norm = LayerNorm(self.c_z) - self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") - self.relu = nn.ReLU() - self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") - - def _transition(self, z, mask): - # [*, N_res, N_res, C_hidden] - z = self.linear_1(z) - z = self.relu(z) - - # [*, N_res, N_res, C_z] - z = self.linear_2(z) * mask - - return z - - @torch.jit.ignore - def _chunk( - self, - z: torch.Tensor, - mask: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self._transition, - { - "z": z, - "mask": mask - }, - chunk_size=chunk_size, - no_batch_dims=len(z.shape[:-2]), - ) - - def forward( - self, - z: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - z: - [*, N_res, N_res, C_z] pair embedding - Returns: - [*, N_res, N_res, C_z] pair embedding update - """ - # DISCREPANCY: DeepMind forgets to apply the mask in this module. - if mask is None: - mask = z.new_ones(z.shape[:-1]) - - # [*, N_res, N_res, 1] - mask = mask.unsqueeze(-1) - - # [*, N_res, N_res, C_z] - z = self.layer_norm(z) - - if chunk_size is not None: - z = self._chunk(z, mask, chunk_size) - else: - z = self._transition(z=z, mask=mask) - - return z diff --git a/tests/test_autochunk/origin_openfold/primitives.py b/tests/test_autochunk/origin_openfold/primitives.py deleted file mode 100644 index 5b7556ee3d60..000000000000 --- a/tests/test_autochunk/origin_openfold/primitives.py +++ /dev/null @@ -1,544 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
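PairTransition above defers its batching to chunk_layer. A simplified stand-in for the idea is sketched below: flatten the leading batch dimensions, apply the function slice by slice, and re-stitch. This is only an assumption-laden illustration; the real chunk_layer additionally handles dicts of inputs and broadcasting:

import torch

def chunked_apply(fn, x, chunk_size):
    # Flatten the leading batch dims, run fn over fixed-size slices, re-stitch.
    flat = x.reshape(-1, *x.shape[-2:])
    pieces = [fn(flat[i:i + chunk_size]) for i in range(0, flat.shape[0], chunk_size)]
    out = torch.cat(pieces, dim=0)
    return out.reshape(*x.shape[:-2], *out.shape[1:])

x = torch.randn(4, 6, 8, 16)
assert torch.allclose(torch.relu(x), chunked_apply(torch.relu, x, chunk_size=5))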
- -import math -from functools import partial -from typing import Callable, List, Optional, Sequence, Tuple - -import numpy as np -import torch -import torch.nn as nn -from scipy.stats import truncnorm - -from .utils.checkpointing import get_checkpoint_fn -from .utils.tensor_utils import _chunk_slice, flatten_final_dims, permute_final_dims - - -def _prod(nums): - out = 1 - for n in nums: - out = out * n - return out - - -def _calculate_fan(linear_weight_shape, fan="fan_in"): - fan_out, fan_in = linear_weight_shape - - if fan == "fan_in": - f = fan_in - elif fan == "fan_out": - f = fan_out - elif fan == "fan_avg": - f = (fan_in + fan_out) / 2 - else: - raise ValueError("Invalid fan option") - - return f - - -def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): - shape = weights.shape - f = _calculate_fan(shape, fan) - scale = scale / max(1, f) - a = -2 - b = 2 - std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) - size = _prod(shape) - samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) - samples = np.reshape(samples, shape) - with torch.no_grad(): - weights.copy_(torch.tensor(samples, device=weights.device)) - - -def lecun_normal_init_(weights): - trunc_normal_init_(weights, scale=1.0) - - -def he_normal_init_(weights): - trunc_normal_init_(weights, scale=2.0) - - -def glorot_uniform_init_(weights): - nn.init.xavier_uniform_(weights, gain=1) - - -def final_init_(weights): - with torch.no_grad(): - weights.fill_(0.0) - - -def gating_init_(weights): - with torch.no_grad(): - weights.fill_(0.0) - - -def normal_init_(weights): - torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") - - -def ipa_point_weights_init_(weights): - with torch.no_grad(): - softplus_inverse_1 = 0.541324854612918 - weights.fill_(softplus_inverse_1) - - -class Linear(nn.Linear): - """ - A Linear layer with built-in nonstandard initializations. Called just - like torch.nn.Linear. - - Implements the initializers in 1.11.4, plus some additional ones found - in the code. - """ - - def __init__( - self, - in_dim: int, - out_dim: int, - bias: bool = True, - init: str = "default", - init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, - ): - """ - Args: - in_dim: - The final dimension of inputs to the layer - out_dim: - The final dimension of layer outputs - bias: - Whether to learn an additive bias. True by default - init: - The initializer to use. Choose from: - - "default": LeCun fan-in truncated normal initialization - "relu": He initialization w/ truncated normal distribution - "glorot": Fan-average Glorot uniform initialization - "gating": Weights=0, Bias=1 - "normal": Normal initialization with std=1/sqrt(fan_in) - "final": Weights=0, Bias=0 - - Overridden by init_fn if the latter is not None. - init_fn: - A custom initializer taking weight and bias as inputs. - Overrides init if not None. 
- """ - super(Linear, self).__init__(in_dim, out_dim, bias=bias) - - if bias: - with torch.no_grad(): - self.bias.fill_(0) - - if init_fn is not None: - init_fn(self.weight, self.bias) - else: - if init == "default": - lecun_normal_init_(self.weight) - elif init == "relu": - he_normal_init_(self.weight) - elif init == "glorot": - glorot_uniform_init_(self.weight) - elif init == "gating": - gating_init_(self.weight) - if bias: - with torch.no_grad(): - self.bias.fill_(1.0) - elif init == "normal": - normal_init_(self.weight) - elif init == "final": - final_init_(self.weight) - else: - raise ValueError("Invalid init string.") - - -class LayerNorm(nn.Module): - - def __init__(self, c_in, eps=1e-5): - super(LayerNorm, self).__init__() - - self.c_in = (c_in,) - self.eps = eps - - self.weight = nn.Parameter(torch.ones(c_in)) - self.bias = nn.Parameter(torch.zeros(c_in)) - - def forward(self, x): - out = nn.functional.layer_norm( - x, - self.c_in, - self.weight, - self.bias, - self.eps, - ) - - return out - - -@torch.jit.ignore -def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Softmax, but without automatic casting to fp32 when the input is of - type bfloat16 - """ - s = torch.nn.functional.softmax(t, dim=dim) - - return s - - -#@torch.jit.script -def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor: - # [*, H, Q, C_hidden] - query = permute_final_dims(query, (1, 0, 2)) - - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 2, 0)) - - # [*, H, V, C_hidden] - value = permute_final_dims(value, (1, 0, 2)) - - # [*, H, Q, K] - a = torch.matmul(query, key) - - for b in biases: - a += b - - a = softmax(a, -1) - - # [*, H, Q, C_hidden] - a = torch.matmul(a, value) - - # [*, Q, H, C_hidden] - a = a.transpose(-2, -3) - - return a - - -@torch.jit.ignore -def _attention_chunked_trainable( - query, - key, - value, - biases, - chunk_size, - chunk_dim, - checkpoint, -): - if (checkpoint and len(biases) > 2): - raise ValueError("Checkpointed version permits only permits two bias terms") - - def _checkpointable_attention(q, k, v, b1, b2): - bs = [b for b in [b1, b2] if b is not None] - return _attention(q, k, v, bs) - - o_chunks = [] - checkpoint_fn = get_checkpoint_fn() - count = query.shape[chunk_dim] - for start in range(0, count, chunk_size): - end = start + chunk_size - idx = [slice(None)] * len(query.shape) - idx[chunk_dim] = slice(start, end) - idx_tup = tuple(idx) - q_chunk = query[idx_tup] - k_chunk = key[idx_tup] - v_chunk = value[idx_tup] - - def _slice_bias(b): - idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)) - return b[tuple(idx)] - - if (checkpoint): - bias_1_chunk, bias_2_chunk = [ - _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2] - ] - - o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk) - else: - bias_chunks = [_slice_bias(b) for b in biases] - - o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks) - - o_chunks.append(o_chunk) - - o = torch.cat(o_chunks, dim=chunk_dim) - return o - - -class Attention(nn.Module): - """ - Standard multi-head attention using AlphaFold's default layer - initialization. Allows multiple bias vectors. 
- """ - - def __init__( - self, - c_q: int, - c_k: int, - c_v: int, - c_hidden: int, - no_heads: int, - gating: bool = True, - ): - """ - Args: - c_q: - Input dimension of query data - c_k: - Input dimension of key data - c_v: - Input dimension of value data - c_hidden: - Per-head hidden dimension - no_heads: - Number of attention heads - gating: - Whether the output should be gated using query data - """ - super(Attention, self).__init__() - - self.c_q = c_q - self.c_k = c_k - self.c_v = c_v - self.c_hidden = c_hidden - self.no_heads = no_heads - self.gating = gating - - # DISCREPANCY: c_hidden is not the per-head channel dimension, as - # stated in the supplement, but the overall channel dimension. - - self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final") - - self.linear_g = None - if self.gating: - self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating") - - self.sigmoid = nn.Sigmoid() - - def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # [*, Q/K/V, H * C_hidden] - q = self.linear_q(q_x) - k = self.linear_k(kv_x) - v = self.linear_v(kv_x) - - # [*, Q/K, H, C_hidden] - q = q.view(q.shape[:-1] + (self.no_heads, -1)) - k = k.view(k.shape[:-1] + (self.no_heads, -1)) - v = v.view(v.shape[:-1] + (self.no_heads, -1)) - - q /= math.sqrt(self.c_hidden) - - return q, k, v - - def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: - if (self.linear_g is not None): - g = self.sigmoid(self.linear_g(q_x)) - - # [*, Q, H, C_hidden] - g = g.view(g.shape[:-1] + (self.no_heads, -1)) - o = o * g - - # [*, Q, H * C_hidden] - o = flatten_final_dims(o, 2) - - # [*, Q, C_q] - o = self.linear_o(o) - - return o - - def forward( - self, - q_x: torch.Tensor, - kv_x: torch.Tensor, - biases: Optional[List[torch.Tensor]] = None, - use_lma: bool = False, - q_chunk_size: Optional[int] = None, - kv_chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - q_x: - [*, Q, C_q] query data - kv_x: - [*, K, C_k] key data - biases: - List of biases that broadcast to [*, H, Q, K] - use_lma: - Whether to use low-memory attention - q_chunk_size: - Query chunk size (for LMA) - kv_chunk_size: - Key/Value chunk size (for LMA) - Returns - [*, Q, C_q] attention update - """ - if (biases is None): - biases = [] - if (use_lma and (q_chunk_size is None or kv_chunk_size is None)): - raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must " - "be provided") - - q, k, v = self._prep_qkv(q_x, kv_x) - - if (use_lma): - biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases] - - o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) - else: - o = _attention(q, k, v, biases) - - o = self._wrap_up(o, q_x) - - return o - - -class GlobalAttention(nn.Module): - - def __init__(self, c_in, c_hidden, no_heads, inf, eps): - super(GlobalAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - self.eps = eps - - self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot") - - self.linear_k = Linear( - c_in, - c_hidden, - bias=False, - init="glorot", - ) - self.linear_v = Linear( - c_in, - c_hidden, - 
bias=False, - init="glorot", - ) - self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") - self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") - - self.sigmoid = nn.Sigmoid() - - def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - # [*, N_res, C_in] - q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps) - - # [*, N_res, H * C_hidden] - q = self.linear_q(q) - q *= (self.c_hidden**(-0.5)) - - # [*, N_res, H, C_hidden] - q = q.view(q.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, N_seq, C_hidden] - k = self.linear_k(m) - v = self.linear_v(m) - - # [*, N_res, H, N_seq] - a = torch.matmul( - q, - k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] - ) - bias = (self.inf * (mask - 1))[..., :, None, :] - a += bias - a = softmax(a) - - # [*, N_res, H, C_hidden] - o = torch.matmul( - a, - v, - ) - - # [*, N_res, N_seq, C_hidden] - g = self.sigmoid(self.linear_g(m)) - - # [*, N_res, N_seq, H, C_hidden] - g = g.view(g.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, N_seq, H, C_hidden] - o = o.unsqueeze(-3) * g - - # [*, N_res, N_seq, H * C_hidden] - o = o.reshape(o.shape[:-2] + (-1,)) - - # [*, N_res, N_seq, C_in] - m = self.linear_o(o) - - return m - - -def _lma( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - biases: List[torch.Tensor], - q_chunk_size: int, - kv_chunk_size: int, -): - no_q, no_kv = q.shape[-3], k.shape[-3] - - # [*, Q, H, C_hidden] - o = q.new_zeros(q.shape) - for q_s in range(0, no_q, q_chunk_size): - q_chunk = q[..., q_s:q_s + q_chunk_size, :, :] - large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases] - - maxes = [] - weights = [] - values = [] - for kv_s in range(0, no_kv, kv_chunk_size): - k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :] - v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :] - small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks] - - a = torch.einsum( - "...qhd,...khd->...hqk", - q_chunk, - k_chunk, - ) - - for b in small_bias_chunks: - a += b - - a = a.transpose(-2, -3) - - max_a = torch.max(a, dim=-1, keepdim=True)[0] - exp_a = torch.exp(a - max_a) - exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) - - maxes.append(max_a.detach().squeeze(-1)) - weights.append(torch.sum(exp_a, dim=-1)) - values.append(exp_v) - - chunk_max = torch.stack(maxes, dim=-3) - chunk_weights = torch.stack(weights, dim=-3) - chunk_values = torch.stack(values, dim=-4) - - global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] - max_diffs = torch.exp(chunk_max - global_max) - chunk_values *= max_diffs.unsqueeze(-1) - chunk_weights *= max_diffs - - all_values = torch.sum(chunk_values, dim=-4) - all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) - - q_chunk_out = all_values / all_weights - - o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out - - return o diff --git a/tests/test_autochunk/origin_openfold/structure_module.py b/tests/test_autochunk/origin_openfold/structure_module.py deleted file mode 100644 index 051109515daa..000000000000 --- a/tests/test_autochunk/origin_openfold/structure_module.py +++ /dev/null @@ -1,914 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn -from fastfold.common.residue_constants import ( - restype_atom14_mask, - restype_atom14_rigid_group_positions, - restype_atom14_to_rigid_group, - restype_rigid_group_default_frame, -) -from fastfold.utils.feats import frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames -from fastfold.utils.geometry.quat_rigid import QuatRigid -from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array -from fastfold.utils.geometry.vector import Vec3Array -from fastfold.utils.rigid_utils import Rigid, Rotation -from fastfold.utils.tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims - -from .primitives import LayerNorm, Linear, ipa_point_weights_init_ - - -class AngleResnetBlock(nn.Module): - - def __init__(self, c_hidden): - """ - Args: - c_hidden: - Hidden channel dimension - """ - super(AngleResnetBlock, self).__init__() - - self.c_hidden = c_hidden - - self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu") - self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final") - - self.relu = nn.ReLU() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - - s_initial = a - - a = self.relu(a) - a = self.linear_1(a) - a = self.relu(a) - a = self.linear_2(a) - - return a + s_initial - - -class AngleResnet(nn.Module): - """ - Implements Algorithm 20, lines 11-14 - """ - - def __init__(self, c_in: int, c_hidden: int, no_blocks: int, no_angles: int, epsilon: float): - """ - Args: - c_in: - Input channel dimension - c_hidden: - Hidden channel dimension - no_blocks: - Number of resnet blocks - no_angles: - Number of torsion angles to generate - epsilon: - Small constant for normalization - """ - super(AngleResnet, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_blocks = no_blocks - self.no_angles = no_angles - self.eps = epsilon - - self.linear_in = Linear(self.c_in, self.c_hidden) - self.linear_initial = Linear(self.c_in, self.c_hidden) - - self.layers = nn.ModuleList() - for _ in range(self.no_blocks): - layer = AngleResnetBlock(c_hidden=self.c_hidden) - self.layers.append(layer) - - self.linear_out = Linear(self.c_hidden, self.no_angles * 2) - - self.relu = nn.ReLU() - - def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - s: - [*, C_hidden] single embedding - s_initial: - [*, C_hidden] single embedding as of the start of the - StructureModule - Returns: - [*, no_angles, 2] predicted angles - """ - # NOTE: The ReLU's applied to the inputs are absent from the supplement - # pseudocode but present in the source. For maximal compatibility with - # the pretrained weights, I'm going with the source. 
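-        # Both streams are first projected to the resnet's hidden width; the
-        # initial embedding is then fused additively before the blocks below.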
- - # [*, C_hidden] - s_initial = self.relu(s_initial) - s_initial = self.linear_initial(s_initial) - s = self.relu(s) - s = self.linear_in(s) - s = s + s_initial - - for ll in self.layers: - s = ll(s) - - s = self.relu(s) - - # [*, no_angles * 2] - s = self.linear_out(s) - - # [*, no_angles, 2] - s = s.view(s.shape[:-1] + (-1, 2)) - - unnormalized_s = s - norm_denom = torch.sqrt(torch.clamp( - torch.sum(s**2, dim=-1, keepdim=True), - min=self.eps, - )) - s = s / norm_denom - - return unnormalized_s, s - - -class PointProjection(nn.Module): - - def __init__( - self, - c_hidden: int, - num_points: int, - no_heads: int, - return_local_points: bool = False, - ): - super().__init__() - self.return_local_points = return_local_points - self.no_heads = no_heads - - self.linear = Linear(c_hidden, no_heads * 3 * num_points) - - def forward( - self, - activations: torch.Tensor, - rigids: Rigid3Array, - ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]: - # TODO: Needs to run in high precision during training - points_local = self.linear(activations) - points_local = points_local.reshape( - *points_local.shape[:-1], - self.no_heads, - -1, - ) - points_local = torch.split(points_local, points_local.shape[-1] // 3, dim=-1) - points_local = Vec3Array(*points_local) - points_global = rigids[..., None, None].apply_to_point(points_local) - - if self.return_local_points: - return points_global, points_local - - return points_global - - -class InvariantPointAttention(nn.Module): - """ - Implements Algorithm 22. - """ - - def __init__( - self, - c_s: int, - c_z: int, - c_hidden: int, - no_heads: int, - no_qk_points: int, - no_v_points: int, - inf: float = 1e5, - eps: float = 1e-8, - is_multimer: bool = False, - ): - """ - Args: - c_s: - Single representation channel dimension - c_z: - Pair representation channel dimension - c_hidden: - Hidden channel dimension - no_heads: - Number of attention heads - no_qk_points: - Number of query/key points to generate - no_v_points: - Number of value points to generate - """ - super(InvariantPointAttention, self).__init__() - - self.c_s = c_s - self.c_z = c_z - self.c_hidden = c_hidden - self.no_heads = no_heads - self.no_qk_points = no_qk_points - self.no_v_points = no_v_points - self.inf = inf - self.eps = eps - self.is_multimer = is_multimer - - # These linear layers differ from their specifications in the - # supplement. There, they lack bias and use Glorot initialization. - # Here as in the official source, they have bias and use the default - # Lecun initialization. 
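-        # The monomer branch below fuses the k/v (and point) projections into
-        # single linear layers and splits them afterwards; the multimer branch
-        # uses separate projections plus PointProjection modules instead.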
- if not self.is_multimer: - hc = self.c_hidden * self.no_heads - self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer)) - self.linear_kv = Linear(self.c_s, 2 * hc) - - hpq = self.no_heads * self.no_qk_points * 3 - self.linear_q_points = Linear(self.c_s, hpq) - - hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3 - self.linear_kv_points = Linear(self.c_s, hpkv) - - # hpv = self.no_heads * self.no_v_points * 3 - - else: - hc = self.c_hidden * self.no_heads - self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer)) - self.linear_q_points = PointProjection(self.c_s, self.no_qk_points, self.no_heads) - - self.linear_k = Linear(self.c_s, hc, bias=False) - self.linear_v = Linear(self.c_s, hc, bias=False) - self.linear_k_points = PointProjection( - self.c_s, - self.no_qk_points, - self.no_heads, - ) - - self.linear_v_points = PointProjection( - self.c_s, - self.no_v_points, - self.no_heads, - ) - self.linear_b = Linear(self.c_z, self.no_heads) - - self.head_weights = nn.Parameter(torch.zeros((no_heads))) - ipa_point_weights_init_(self.head_weights) - - concat_out_dim = self.no_heads * (self.c_z + self.c_hidden + self.no_v_points * 4) - self.linear_out = Linear(concat_out_dim, self.c_s, init="final") - - self.softmax = nn.Softmax(dim=-1) - self.softplus = nn.Softplus() - - def forward( - self, - s: torch.Tensor, - z: torch.Tensor, - r: Union[Rigid, Rigid3Array], - mask: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - s: - [*, N_res, C_s] single representation - z: - [*, N_res, N_res, C_z] pair representation - r: - [*, N_res] transformation object - mask: - [*, N_res] mask - Returns: - [*, N_res, C_s] single representation update - """ - ####################################### - # Generate scalar and point activations - ####################################### - - # The following two blocks are equivalent - # They're separated only to preserve compatibility with old AF weights - if self.is_multimer: - # [*, N_res, H * C_hidden] - q = self.linear_q(s) - - # [*, N_res, H, C_hidden] - q = q.view(q.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, H, P_qk] - q_pts = self.linear_q_points(s, r) - # [*, N_res, H * C_hidden] - k = self.linear_k(s) - v = self.linear_v(s) - - # [*, N_res, H, C_hidden] - k = k.view(k.shape[:-1] + (self.no_heads, -1)) - v = v.view(v.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, H, P_qk, 3] - k_pts = self.linear_k_points(s, r) - - # [*, N_res, H, P_v, 3] - v_pts = self.linear_v_points(s, r) - else: - # [*, N_res, H * C_hidden] - q = self.linear_q(s) - kv = self.linear_kv(s) - - # [*, N_res, H, C_hidden] - q = q.view(q.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, H, 2 * C_hidden] - kv = kv.view(kv.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, H, C_hidden] - k, v = torch.split(kv, self.c_hidden, dim=-1) - - # [*, N_res, H * P_q * 3] - q_pts = self.linear_q_points(s) - - # This is kind of clunky, but it's how the original does it - # [*, N_res, H * P_q, 3] - q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) - q_pts = torch.stack(q_pts, dim=-1) - q_pts = r[..., None].apply(q_pts) - - # [*, N_res, H, P_q, 3] - q_pts = q_pts.view(q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)) - - # [*, N_res, H * (P_q + P_v) * 3] - kv_pts = self.linear_kv_points(s) - - # [*, N_res, H * (P_q + P_v), 3] - kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) - kv_pts = torch.stack(kv_pts, dim=-1) - kv_pts = r[..., None].apply(kv_pts) - - # [*, N_res, H, (P_q + P_v), 3] - kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) 
- - # [*, N_res, H, P_q/P_v, 3] - k_pts, v_pts = torch.split(kv_pts, [self.no_qk_points, self.no_v_points], dim=-2) - - ########################## - # Compute attention scores - ########################## - # [*, N_res, N_res, H] - b = self.linear_b(z) - - # [*, H, N_res, N_res] - a = torch.matmul( - permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] - permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] - ) - a *= math.sqrt(1.0 / (3 * self.c_hidden)) - a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) - - if self.is_multimer: - # [*, N_res, N_res, H, P_q, 3] - pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :] - # [*, N_res, N_res, H, P_q] - pt_att = sum([c**2 for c in pt_att]) - else: - # [*, N_res, N_res, H, P_q, 3] - pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) - pt_att = pt_att**2 - # [*, N_res, N_res, H, P_q] - pt_att = sum(torch.unbind(pt_att, dim=-1)) - - head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1))) - head_weights = head_weights * math.sqrt(1.0 / (3 * (self.no_qk_points * 9.0 / 2))) - pt_att = pt_att * head_weights - - # [*, N_res, N_res, H] - pt_att = torch.sum(pt_att, dim=-1) * (-0.5) - # [*, N_res, N_res] - square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) - square_mask = self.inf * (square_mask - 1) - - # [*, H, N_res, N_res] - pt_att = permute_final_dims(pt_att, (2, 0, 1)) - a = a + pt_att - a = a + square_mask.unsqueeze(-3) - a = self.softmax(a) - - ################ - # Compute output - ################ - # [*, N_res, H, C_hidden] - o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3) - - # [*, N_res, H * C_hidden] - o = flatten_final_dims(o, 2) - - # As DeepMind explains, this manual matmul ensures that the operation - # happens in float32. - if self.is_multimer: - # [*, N_res, H, P_v] - o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1) - o_pt = o_pt.sum(dim=-3) - - # [*, N_res, H, P_v] - o_pt = r[..., None, None].apply_inverse_to_point(o_pt) - - # [*, N_res, H * P_v, 3] - o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,)) - - # [*, N_res, H * P_v] - o_pt_norm = o_pt.norm(self.eps) - else: - # [*, H, 3, N_res, P_v] - o_pt = torch.sum( - (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), - dim=-2, - ) - - # [*, N_res, H, P_v, 3] - o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) - o_pt = r[..., None, None].invert_apply(o_pt) - - # [*, N_res, H * P_v] - o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.eps), 2) - - # [*, N_res, H * P_v, 3] - o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) - - # [*, N_res, H, C_z] - o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype)) - - # [*, N_res, H * C_z] - o_pair = flatten_final_dims(o_pair, 2) - - # [*, N_res, C_s] - if self.is_multimer: - s = self.linear_out(torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype)) - else: - s = self.linear_out( - torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype)) - - return s - - -class BackboneUpdate(nn.Module): - """ - Implements part of Algorithm 23. 
- """ - - def __init__(self, c_s: int): - """ - Args: - c_s: - Single representation channel dimension - """ - super(BackboneUpdate, self).__init__() - - self.c_s = c_s - - self.linear = Linear(self.c_s, 6, init="final") - - def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - [*, N_res, C_s] single representation - Returns: - [*, N_res, 6] update vector - """ - # [*, 6] - update = self.linear(s) - - return update - - -class StructureModuleTransitionLayer(nn.Module): - - def __init__(self, c: int): - super(StructureModuleTransitionLayer, self).__init__() - - self.c = c - - self.linear_1 = Linear(self.c, self.c, init="relu") - self.linear_2 = Linear(self.c, self.c, init="relu") - self.linear_3 = Linear(self.c, self.c, init="final") - - self.relu = nn.ReLU() - - def forward(self, s: torch.Tensor): - s_initial = s - s = self.linear_1(s) - s = self.relu(s) - s = self.linear_2(s) - s = self.relu(s) - s = self.linear_3(s) - - s = s + s_initial - - return s - - -class StructureModuleTransition(nn.Module): - - def __init__(self, c: int, num_layers: int, dropout_rate: float): - super(StructureModuleTransition, self).__init__() - - self.c = c - self.num_layers = num_layers - self.dropout_rate = dropout_rate - - self.layers = nn.ModuleList() - for _ in range(self.num_layers): - ll = StructureModuleTransitionLayer(self.c) - self.layers.append(ll) - - self.dropout = nn.Dropout(self.dropout_rate) - self.layer_norm = LayerNorm(self.c) - - def forward(self, s: torch.Tensor) -> torch.Tensor: - for ll in self.layers: - s = ll(s) - - s = self.dropout(s) - s = self.layer_norm(s) - - return s - - -class StructureModule(nn.Module): - - def __init__( - self, - c_s: int, - c_z: int, - c_ipa: int, - c_resnet: int, - no_heads_ipa: int, - no_qk_points: int, - no_v_points: int, - dropout_rate: float, - no_blocks: int, - no_transition_layers: int, - no_resnet_blocks: int, - no_angles: int, - trans_scale_factor: float, - epsilon: float, - inf: float, - is_multimer: bool = False, - **kwargs, - ): - """ - Args: - c_s: - Single representation channel dimension - c_z: - Pair representation channel dimension - c_ipa: - IPA hidden channel dimension - c_resnet: - Angle resnet (Alg. 23 lines 11-14) hidden channel dimension - no_heads_ipa: - Number of IPA heads - no_qk_points: - Number of query/key points to generate during IPA - no_v_points: - Number of value points to generate during IPA - dropout_rate: - Dropout rate used throughout the layer - no_blocks: - Number of structure module blocks - no_transition_layers: - Number of layers in the single representation transition - (Alg. 
23 lines 8-9) - no_resnet_blocks: - Number of blocks in the angle resnet - no_angles: - Number of angles to generate in the angle resnet - trans_scale_factor: - Scale of single representation transition hidden dimension - epsilon: - Small number used in angle resnet normalization - inf: - Large number used for attention masking - is_multimer: - whether running under multimer mode - """ - super(StructureModule, self).__init__() - - self.c_s = c_s - self.c_z = c_z - self.c_ipa = c_ipa - self.c_resnet = c_resnet - self.no_heads_ipa = no_heads_ipa - self.no_qk_points = no_qk_points - self.no_v_points = no_v_points - self.dropout_rate = dropout_rate - self.no_blocks = no_blocks - self.no_transition_layers = no_transition_layers - self.no_resnet_blocks = no_resnet_blocks - self.no_angles = no_angles - self.trans_scale_factor = trans_scale_factor - self.epsilon = epsilon - self.inf = inf - self.is_multimer = is_multimer - - # To be lazily initialized later - self.default_frames = None - self.group_idx = None - self.atom_mask = None - self.lit_positions = None - - self.layer_norm_s = LayerNorm(self.c_s) - self.layer_norm_z = LayerNorm(self.c_z) - - self.linear_in = Linear(self.c_s, self.c_s) - - self.ipa = InvariantPointAttention( - self.c_s, - self.c_z, - self.c_ipa, - self.no_heads_ipa, - self.no_qk_points, - self.no_v_points, - inf=self.inf, - eps=self.epsilon, - is_multimer=self.is_multimer, - ) - - self.ipa_dropout = nn.Dropout(self.dropout_rate) - self.layer_norm_ipa = LayerNorm(self.c_s) - - self.transition = StructureModuleTransition( - self.c_s, - self.no_transition_layers, - self.dropout_rate, - ) - - if is_multimer: - self.bb_update = QuatRigid(self.c_s, full_quat=False) - else: - self.bb_update = BackboneUpdate(self.c_s) - - self.angle_resnet = AngleResnet( - self.c_s, - self.c_resnet, - self.no_resnet_blocks, - self.no_angles, - self.epsilon, - ) - - def _forward_monomer( - self, - s: torch.Tensor, - z: torch.Tensor, - aatype: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> Dict[str, Any]: - """ - Args: - s: - [*, N_res, C_s] single representation - z: - [*, N_res, N_res, C_z] pair representation - aatype: - [*, N_res] amino acid indices - mask: - Optional [*, N_res] sequence mask - Returns: - A dictionary of outputs - """ - if mask is None: - # [*, N] - mask = s.new_ones(s.shape[:-1]) - - # [*, N, C_s] - s = self.layer_norm_s(s) - - # [*, N, N, C_z] - z = self.layer_norm_z(z) - - # [*, N, C_s] - s_initial = s - s = self.linear_in(s) - - # [*, N] - rigids = Rigid.identity( - s.shape[:-1], - s.dtype, - s.device, - self.training, - fmt="quat", - ) - outputs = [] - for i in range(self.no_blocks): - # [*, N, C_s] - s = s + self.ipa(s, z, rigids, mask) - s = self.ipa_dropout(s) - s = self.layer_norm_ipa(s) - s = self.transition(s) - - # [*, N] - rigids = rigids.compose_q_update_vec(self.bb_update(s)) - - # To hew as closely as possible to AlphaFold, we convert our - # quaternion-based transformations to rotation-matrix ones - # here - backb_to_global = Rigid( - Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None), - rigids.get_trans(), - ) - - backb_to_global = backb_to_global.scale_translation(self.trans_scale_factor) - - # [*, N, 7, 2] - unnormalized_angles, angles = self.angle_resnet(s, s_initial) - - all_frames_to_global = self.torsion_angles_to_frames( - backb_to_global, - angles, - aatype, - ) - - pred_xyz = self.frames_and_literature_positions_to_atom14_pos( - all_frames_to_global, - aatype, - ) - - scaled_rigids = 
rigids.scale_translation(self.trans_scale_factor) - - preds = { - "frames": scaled_rigids.to_tensor_7(), - "sidechain_frames": all_frames_to_global.to_tensor_4x4(), - "unnormalized_angles": unnormalized_angles, - "angles": angles, - "positions": pred_xyz, - } - - outputs.append(preds) - - if i < (self.no_blocks - 1): - rigids = rigids.stop_rot_gradient() - - outputs = dict_multimap(torch.stack, outputs) - outputs["single"] = s - - return outputs - - def _forward_multimer( - self, - s: torch.Tensor, - z: torch.Tensor, - aatype: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> Dict[str, Any]: - if mask is None: - # [*, N] - mask = s.new_ones(s.shape[:-1]) - - # [*, N, C_s] - s = self.layer_norm_s(s) - - # [*, N, N, C_z] - z = self.layer_norm_z(z) - - # [*, N, C_s] - s_initial = s - s = self.linear_in(s) - - # [*, N] - rigids = Rigid3Array.identity( - s.shape[:-1], - s.device, - ) - outputs = [] - for i in range(self.no_blocks): - # [*, N, C_s] - s = s + self.ipa(s, z, rigids, mask) - s = self.ipa_dropout(s) - s = self.layer_norm_ipa(s) - s = self.transition(s) - - # [*, N] - rigids = rigids @ self.bb_update(s) - - # [*, N, 7, 2] - unnormalized_angles, angles = self.angle_resnet(s, s_initial) - - all_frames_to_global = self.torsion_angles_to_frames( - rigids.scale_translation(self.trans_scale_factor), - angles, - aatype, - ) - - pred_xyz = self.frames_and_literature_positions_to_atom14_pos( - all_frames_to_global, - aatype, - ) - - preds = { - "frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(), - "sidechain_frames": all_frames_to_global.to_tensor_4x4(), - "unnormalized_angles": unnormalized_angles, - "angles": angles, - "positions": pred_xyz.to_tensor(), - } - - outputs.append(preds) - - if i < (self.no_blocks - 1): - rigids = rigids.stop_rot_gradient() - - outputs = dict_multimap(torch.stack, outputs) - outputs["single"] = s - - return outputs - - def forward( - self, - s: torch.Tensor, - z: torch.Tensor, - aatype: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ): - """ - Args: - s: - [*, N_res, C_s] single representation - z: - [*, N_res, N_res, C_z] pair representation - aatype: - [*, N_res] amino acid indices - mask: - Optional [*, N_res] sequence mask - Returns: - A dictionary of outputs - """ - if self.is_multimer: - outputs = self._forward_multimer(s, z, aatype, mask) - else: - outputs = self._forward_monomer(s, z, aatype, mask) - - return outputs - - def _init_residue_constants(self, float_dtype: torch.dtype, device: torch.device): - if self.default_frames is None: - self.default_frames = torch.tensor( - restype_rigid_group_default_frame, - dtype=float_dtype, - device=device, - requires_grad=False, - ) - if self.group_idx is None: - self.group_idx = torch.tensor( - restype_atom14_to_rigid_group, - device=device, - requires_grad=False, - ) - if self.atom_mask is None: - self.atom_mask = torch.tensor( - restype_atom14_mask, - dtype=float_dtype, - device=device, - requires_grad=False, - ) - if self.lit_positions is None: - self.lit_positions = torch.tensor( - restype_atom14_rigid_group_positions, - dtype=float_dtype, - device=device, - requires_grad=False, - ) - - def torsion_angles_to_frames(self, r: Union[Rigid, Rigid3Array], alpha: torch.Tensor, f): - # Lazily initialize the residue constants on the correct device - self._init_residue_constants(alpha.dtype, alpha.device) - # Separated purely to make testing less annoying - return torsion_angles_to_frames(r, alpha, f, self.default_frames) - - def frames_and_literature_positions_to_atom14_pos( - 
self, - r: Union[Rigid, Rigid3Array], - f # [*, N, 8] # [*, N] - ): - # Lazily initialize the residue constants on the correct device - if type(r) == Rigid: - self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) - elif type(r) == Rigid3Array: - self._init_residue_constants(r.dtype, r.device) - else: - raise ValueError("Unknown rigid type") - return frames_and_literature_positions_to_atom14_pos( - r, - f, - self.default_frames, - self.group_idx, - self.atom_mask, - self.lit_positions, - ) diff --git a/tests/test_autochunk/origin_openfold/template.py b/tests/test_autochunk/origin_openfold/template.py deleted file mode 100644 index 6c4c0f42875e..000000000000 --- a/tests/test_autochunk/origin_openfold/template.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -from functools import partial -from typing import List, Optional - -import torch -import torch.nn as nn -from fastfold.model.nn.dropout import DropoutColumnwise, DropoutRowwise -from fastfold.model.nn.pair_transition import PairTransition -from fastfold.model.nn.primitives import Attention, LayerNorm, Linear -from fastfold.model.nn.triangular_attention import TriangleAttentionEndingNode, TriangleAttentionStartingNode -from fastfold.model.nn.triangular_multiplicative_update import ( - TriangleMultiplicationIncoming, - TriangleMultiplicationOutgoing, -) -from fastfold.utils.checkpointing import checkpoint_blocks -from fastfold.utils.tensor_utils import chunk_layer, flatten_final_dims, permute_final_dims - - -class TemplatePointwiseAttention(nn.Module): - """ - Implements Algorithm 17. 
- """ - - def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs): - """ - Args: - c_t: - Template embedding channel dimension - c_z: - Pair embedding channel dimension - c_hidden: - Hidden channel dimension - """ - super(TemplatePointwiseAttention, self).__init__() - - self.c_t = c_t - self.c_z = c_z - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - - self.mha = Attention( - self.c_z, - self.c_t, - self.c_t, - self.c_hidden, - self.no_heads, - gating=False, - ) - - def _chunk( - self, - z: torch.Tensor, - t: torch.Tensor, - biases: List[torch.Tensor], - chunk_size: int, - ) -> torch.Tensor: - mha_inputs = { - "q_x": z, - "kv_x": t, - "biases": biases, - } - return chunk_layer( - self.mha, - mha_inputs, - chunk_size=chunk_size, - no_batch_dims=len(z.shape[:-2]), - ) - - def forward(self, - t: torch.Tensor, - z: torch.Tensor, - template_mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None) -> torch.Tensor: - """ - Args: - t: - [*, N_templ, N_res, N_res, C_t] template embedding - z: - [*, N_res, N_res, C_t] pair embedding - template_mask: - [*, N_templ] template mask - Returns: - [*, N_res, N_res, C_z] pair embedding update - """ - if template_mask is None: - template_mask = t.new_ones(t.shape[:-3]) - - bias = self.inf * (template_mask[..., None, None, None, None, :] - 1) - - # [*, N_res, N_res, 1, C_z] - z = z.unsqueeze(-2) - - # [*, N_res, N_res, N_temp, C_t] - t = permute_final_dims(t, (1, 2, 0, 3)) - - # [*, N_res, N_res, 1, C_z] - biases = [bias] - if chunk_size is not None: - z = self._chunk(z, t, biases, chunk_size) - else: - z = self.mha(q_x=z, kv_x=t, biases=biases) - - # [*, N_res, N_res, C_z] - z = z.squeeze(-2) - - return z - - -class TemplatePairStackBlock(nn.Module): - - def __init__( - self, - c_t: int, - c_hidden_tri_att: int, - c_hidden_tri_mul: int, - no_heads: int, - pair_transition_n: int, - dropout_rate: float, - inf: float, - is_multimer: bool = False, - **kwargs, - ): - super(TemplatePairStackBlock, self).__init__() - - self.c_t = c_t - self.c_hidden_tri_att = c_hidden_tri_att - self.c_hidden_tri_mul = c_hidden_tri_mul - self.no_heads = no_heads - self.pair_transition_n = pair_transition_n - self.dropout_rate = dropout_rate - self.inf = inf - self.is_multimer = is_multimer - - self.dropout_row = DropoutRowwise(self.dropout_rate) - self.dropout_col = DropoutColumnwise(self.dropout_rate) - - self.tri_att_start = TriangleAttentionStartingNode( - self.c_t, - self.c_hidden_tri_att, - self.no_heads, - inf=inf, - ) - self.tri_att_end = TriangleAttentionEndingNode( - self.c_t, - self.c_hidden_tri_att, - self.no_heads, - inf=inf, - ) - - self.tri_mul_out = TriangleMultiplicationOutgoing( - self.c_t, - self.c_hidden_tri_mul, - ) - self.tri_mul_in = TriangleMultiplicationIncoming( - self.c_t, - self.c_hidden_tri_mul, - ) - - self.pair_transition = PairTransition( - self.c_t, - self.pair_transition_n, - ) - - def forward(self, z: torch.Tensor, mask: torch.Tensor, chunk_size: Optional[int] = None, _mask_trans: bool = True): - single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)] - single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)] - if not self.is_multimer: - for i in range(len(single_templates)): - single = single_templates[i] - single_mask = single_templates_masks[i] - - single = single + self.dropout_row(self.tri_att_start(single, chunk_size=chunk_size, mask=single_mask)) - single = single + self.dropout_col(self.tri_att_end(single, chunk_size=chunk_size, mask=single_mask)) - single = 
single + self.dropout_row(self.tri_mul_out(single, mask=single_mask)) - single = single + self.dropout_row(self.tri_mul_in(single, mask=single_mask)) - single = single + self.pair_transition( - single, - mask=single_mask if _mask_trans else None, - chunk_size=chunk_size, - ) - - single_templates[i] = single - else: - for i in range(len(single_templates)): - single = single_templates[i] - single_mask = single_templates_masks[i] - - single = single + self.dropout_row(self.tri_att_start(single, chunk_size=chunk_size, mask=single_mask)) - single = single + self.dropout_col(self.tri_att_end(single, chunk_size=chunk_size, mask=single_mask)) - single = single + self.dropout_row(self.tri_mul_out(single, mask=single_mask)) - single = single + self.dropout_row(self.tri_mul_in(single, mask=single_mask)) - single = single + self.pair_transition( - single, - mask=single_mask if _mask_trans else None, - chunk_size=chunk_size, - ) - single_templates[i] = single - - z = torch.cat(single_templates, dim=-4) - - return z - - -class TemplatePairStack(nn.Module): - """ - Implements Algorithm 16. - """ - - def __init__( - self, - c_t, - c_hidden_tri_att, - c_hidden_tri_mul, - no_blocks, - no_heads, - pair_transition_n, - dropout_rate, - blocks_per_ckpt, - inf=1e9, - **kwargs, - ): - """ - Args: - c_t: - Template embedding channel dimension - c_hidden_tri_att: - Per-head hidden dimension for triangular attention - c_hidden_tri_att: - Hidden dimension for triangular multiplication - no_blocks: - Number of blocks in the stack - pair_transition_n: - Scale of pair transition (Alg. 15) hidden dimension - dropout_rate: - Dropout rate used throughout the stack - blocks_per_ckpt: - Number of blocks per activation checkpoint. None disables - activation checkpointing - """ - super(TemplatePairStack, self).__init__() - - self.blocks_per_ckpt = blocks_per_ckpt - - self.blocks = nn.ModuleList() - for _ in range(no_blocks): - block = TemplatePairStackBlock( - c_t=c_t, - c_hidden_tri_att=c_hidden_tri_att, - c_hidden_tri_mul=c_hidden_tri_mul, - no_heads=no_heads, - pair_transition_n=pair_transition_n, - dropout_rate=dropout_rate, - inf=inf, - ) - self.blocks.append(block) - - self.layer_norm = LayerNorm(c_t) - - def forward( - self, - t: torch.tensor, - mask: torch.tensor, - chunk_size: int, - _mask_trans: bool = True, - ): - """ - Args: - t: - [*, N_templ, N_res, N_res, C_t] template embedding - mask: - [*, N_templ, N_res, N_res] mask - Returns: - [*, N_templ, N_res, N_res, C_t] template embedding update - """ - if (mask.shape[-3] == 1): - expand_idx = list(mask.shape) - expand_idx[-3] = t.shape[-4] - mask = mask.expand(*expand_idx) - - t, = checkpoint_blocks( - blocks=[partial( - b, - mask=mask, - chunk_size=chunk_size, - _mask_trans=_mask_trans, - ) for b in self.blocks], - args=(t,), - blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, - ) - - t = self.layer_norm(t) - - return t diff --git a/tests/test_autochunk/origin_openfold/triangular_attention.py b/tests/test_autochunk/origin_openfold/triangular_attention.py deleted file mode 100644 index 4d2e8efc419d..000000000000 --- a/tests/test_autochunk/origin_openfold/triangular_attention.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from functools import partial, partialmethod -from typing import List, Optional - -import torch -import torch.nn as nn - -from .primitives import Attention, LayerNorm, Linear -from .utils.tensor_utils import chunk_layer, flatten_final_dims, permute_final_dims - - -class TriangleAttention(nn.Module): - - def __init__(self, c_in, c_hidden, no_heads, starting, inf=1e9): - """ - Args: - c_in: - Input channel dimension - c_hidden: - Overall hidden channel dimension (not per-head) - no_heads: - Number of attention heads - """ - super(TriangleAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.starting = starting - self.inf = inf - - self.layer_norm = LayerNorm(self.c_in) - - self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") - - self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) - - @torch.jit.ignore - def _chunk( - self, - x: torch.Tensor, - biases: List[torch.Tensor], - chunk_size: int, - ) -> torch.Tensor: - mha_inputs = { - "q_x": x, - "kv_x": x, - "biases": biases, - } - return chunk_layer( - partial(self.mha), - mha_inputs, - chunk_size=chunk_size, - no_batch_dims=len(x.shape[:-2]), - ) - - def forward(self, - x: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None) -> torch.Tensor: - """ - Args: - x: - [*, I, J, C_in] input tensor (e.g. the pair representation) - Returns: - [*, I, J, C_in] output tensor - """ - if mask is None: - # [*, I, J] - mask = x.new_ones(x.shape[:-1],) - - # Shape annotations assume self.starting. Else, I and J are flipped - if not self.starting: - x = x.transpose(-2, -3) - mask = mask.transpose(-1, -2) - - # [*, I, J, C_in] - x = self.layer_norm(x) - - # [*, I, 1, 1, J] - mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] - - # [*, H, I, J] - triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) - - # [*, 1, H, I, J] - triangle_bias = triangle_bias.unsqueeze(-4) - - biases = [mask_bias, triangle_bias] - - if chunk_size is not None: - x = self._chunk(x, biases, chunk_size) - else: - x = self.mha(q_x=x, kv_x=x, biases=biases) - - if not self.starting: - x = x.transpose(-2, -3) - - return x - - -class TriangleAttentionStartingNode(TriangleAttention): - """ - Implements Algorithm 13. - """ - - __init__ = partialmethod(TriangleAttention.__init__, starting=True) - - -class TriangleAttentionEndingNode(TriangleAttention): - """ - Implements Algorithm 14. - """ - - __init__ = partialmethod(TriangleAttention.__init__, starting=False) diff --git a/tests/test_autochunk/origin_openfold/triangular_multiplicative_update.py b/tests/test_autochunk/origin_openfold/triangular_multiplicative_update.py deleted file mode 100644 index f02e9033ae15..000000000000 --- a/tests/test_autochunk/origin_openfold/triangular_multiplicative_update.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partialmethod -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import LayerNorm, Linear -from .utils.tensor_utils import permute_final_dims - - -class TriangleMultiplicativeUpdate(nn.Module): - """ - Implements Algorithms 11 and 12. - """ - - def __init__(self, c_z, c_hidden, _outgoing=True): - """ - Args: - c_z: - Input channel dimension - c: - Hidden channel dimension - """ - super(TriangleMultiplicativeUpdate, self).__init__() - self.c_z = c_z - self.c_hidden = c_hidden - self._outgoing = _outgoing - - self.linear_a_p = Linear(self.c_z, self.c_hidden) - self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating") - self.linear_b_p = Linear(self.c_z, self.c_hidden) - self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating") - self.linear_g = Linear(self.c_z, self.c_z, init="gating") - self.linear_z = Linear(self.c_hidden, self.c_z, init="final") - - self.layer_norm_in = LayerNorm(self.c_z) - self.layer_norm_out = LayerNorm(self.c_hidden) - - self.sigmoid = nn.Sigmoid() - - def _combine_projections( - self, - a: torch.Tensor, - b: torch.Tensor, - ) -> torch.Tensor: - raise NotImplementedError("This method needs to be overridden") - - def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - Args: - x: - [*, N_res, N_res, C_z] input tensor - mask: - [*, N_res, N_res] input mask - Returns: - [*, N_res, N_res, C_z] output tensor - """ - if mask is None: - mask = z.new_ones(z.shape[:-1]) - - mask = mask.unsqueeze(-1) - - z = self.layer_norm_in(z) - a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z)) - a = a * mask - b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z)) - b = b * mask - x = self._combine_projections(a, b) - x = self.layer_norm_out(x) - x = self.linear_z(x) - g = self.sigmoid(self.linear_g(z)) - z = x * g - - return z - - -class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate): - """ - Implements Algorithm 11. - """ - - def _combine_projections( - self, - a: torch.Tensor, # [*, N_i, N_k, C] - b: torch.Tensor, # [*, N_j, N_k, C] - ): - # [*, C, N_i, N_j] - p = torch.matmul( - permute_final_dims(a, (2, 0, 1)), - permute_final_dims(b, (2, 1, 0)), - ) - - # [*, N_i, N_j, C] - return permute_final_dims(p, (1, 2, 0)) - - -class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate): - """ - Implements Algorithm 12. 
- """ - - def _combine_projections( - self, - a: torch.Tensor, # [*, N_k, N_i, C] - b: torch.Tensor, # [*, N_k, N_j, C] - ): - # [*, C, N_i, N_j] - p = torch.matmul( - permute_final_dims(a, (2, 1, 0)), - permute_final_dims(b, (2, 0, 1)), - ) - - # [*, N_i, N_j, C] - return permute_final_dims(p, (1, 2, 0)) diff --git a/tests/test_autochunk/origin_openfold/utils/all_atom_multimer.py b/tests/test_autochunk/origin_openfold/utils/all_atom_multimer.py deleted file mode 100644 index 0b3199698bbd..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/all_atom_multimer.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Ops for all atom representations.""" - -from functools import partial -from typing import Dict, Text, Tuple - -import numpy as np -import torch -from fastfold.common import residue_constants as rc -from fastfold.utils import geometry, tensor_utils - - -def squared_difference(x, y): - return np.square(x - y) - - -def get_rc_tensor(rc_np, aatype): - return torch.tensor(rc_np, device=aatype.device)[aatype] - - -def atom14_to_atom37( - atom14_data: torch.Tensor, # (*, N, 14, ...) - aatype: torch.Tensor # (*, N) -) -> torch.Tensor: # (*, N, 37, ...) - """Convert atom14 to atom37 representation.""" - idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype) - no_batch_dims = len(aatype.shape) - 1 - atom37_data = tensor_utils.batched_gather(atom14_data, - idx_atom37_to_atom14, - dim=no_batch_dims + 1, - no_batch_dims=no_batch_dims + 1) - atom37_mask = get_rc_tensor(rc.RESTYPE_ATOM37_MASK, aatype) - if len(atom14_data.shape) == no_batch_dims + 2: - atom37_data *= atom37_mask - elif len(atom14_data.shape) == no_batch_dims + 3: - atom37_data *= atom37_mask[..., None].astype(atom37_data.dtype) - else: - raise ValueError("Incorrectly shaped data") - return atom37_data - - -def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): - """Convert Atom37 positions to Atom14 positions.""" - residx_atom14_to_atom37 = get_rc_tensor(rc.RESTYPE_ATOM14_TO_ATOM37, aatype) - no_batch_dims = len(aatype.shape) - atom14_mask = tensor_utils.batched_gather( - all_atom_mask, - residx_atom14_to_atom37, - dim=no_batch_dims + 1, - no_batch_dims=no_batch_dims + 1, - ).to(torch.float32) - # create a mask for known groundtruth positions - atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype) - # gather the groundtruth positions - atom14_positions = tensor_utils.batched_gather( - all_atom_pos, - residx_atom14_to_atom37, - dim=no_batch_dims + 1, - no_batch_dims=no_batch_dims + 1, - ), - atom14_positions = atom14_mask * atom14_positions - return atom14_positions, atom14_mask - - -def get_alt_atom14(aatype, positions: torch.Tensor, mask): - """Get alternative atom14 positions.""" - # pick the transformation matrices for the given residue sequence - # shape (num_res, 14, 14) - renaming_transform = get_rc_tensor(rc.RENAMING_MATRICES, aatype) - alternative_positions = torch.sum(positions[..., None, :] * 
renaming_transform[..., None], dim=-2) - - # Create the mask for the alternative ground truth (differs from the - # ground truth mask, if only one of the atoms in an ambiguous pair has a - # ground truth position) - alternative_mask = torch.sum(mask[..., None] * renaming_transform, dim=-2) - - return alternative_positions, alternative_mask - - -def atom37_to_frames( - aatype: torch.Tensor, # (...) - all_atom_positions: torch.Tensor, # (..., 37) - all_atom_mask: torch.Tensor, # (..., 37) -) -> Dict[Text, torch.Tensor]: - """Computes the frames for the up to 8 rigid groups for each residue.""" - # 0: 'backbone group', - # 1: 'pre-omega-group', (empty) - # 2: 'phi-group', (currently empty, because it defines only hydrogens) - # 3: 'psi-group', - # 4,5,6,7: 'chi1,2,3,4-group' - - no_batch_dims = len(aatype.shape) - 1 - - # Compute the gather indices for all residues in the chain. - # shape (N, 8, 3) - residx_rigidgroup_base_atom37_idx = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype) - - # Gather the base atom positions for each rigid group. - base_atom_pos = tensor_utils.batched_gather( - all_atom_positions, - residx_rigidgroup_base_atom37_idx, - dim=no_batch_dims + 1, - batch_dims=no_batch_dims + 1, - ) - - # Compute the Rigids. - point_on_neg_x_axis = base_atom_pos[..., :, :, 0] - origin = base_atom_pos[..., :, :, 1] - point_on_xy_plane = base_atom_pos[..., :, :, 2] - gt_rotation = geometry.Rot3Array.from_two_vectors(origin - point_on_neg_x_axis, point_on_xy_plane - origin) - - gt_frames = geometry.Rigid3Array(gt_rotation, origin) - - # Compute a mask whether the group exists. - # (N, 8) - group_exists = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_MASK, aatype) - - # Compute a mask whether ground truth exists for the group - gt_atoms_exist = tensor_utils.batched_gather( # shape (N, 8, 3) - all_atom_mask.to(dtype=torch.float32), - residx_rigidgroup_base_atom37_idx, - batch_dims=no_batch_dims + 1, - ) - gt_exists = torch.min(gt_atoms_exist, dim=-1) * group_exists # (N, 8) - - # Adapt backbone frame to old convention (mirror x-axis and z-axis). - rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) - rots[0, 0, 0] = -1 - rots[0, 2, 2] = -1 - gt_frames = gt_frames.compose_rotation(geometry.Rot3Array.from_array(torch.tensor(rots, device=aatype.device))) - - # The frames for ambiguous rigid groups are just rotated by 180 degree around - # the x-axis. The ambiguous group is always the last chi-group. - restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32) - restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1]) - - for resname, _ in rc.residue_atom_renaming_swaps.items(): - restype = rc.restype_order[rc.restype_3to1[resname]] - chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1) - restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1 - restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1 - restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1 - - # Gather the ambiguity information for each residue. - residx_rigidgroup_is_ambiguous = torch.tensor( - restype_rigidgroup_is_ambiguous, - device=aatype.device, - )[aatype] - ambiguity_rot = torch.tensor( - restype_rigidgroup_rots, - device=aatype.device, - )[aatype] - ambiguity_rot = geometry.Rot3Array.from_array(torch.Tensor(ambiguity_rot, device=aatype.device)) - - # Create the alternative ground truth frames. 
- alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot) - - fix_shape = lambda x: x.reshape(x.shape[:-2] + (8,)) - - # reshape back to original residue layout - gt_frames = fix_shape(gt_frames) - gt_exists = fix_shape(gt_exists) - group_exists = fix_shape(group_exists) - residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous) - alt_gt_frames = fix_shape(alt_gt_frames) - - return { - 'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8) - 'rigidgroups_gt_exists': gt_exists, # (..., 8) - 'rigidgroups_group_exists': group_exists, # (..., 8) - 'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous, # (..., 8) - 'rigidgroups_alt_gt_frames': alt_gt_frames, # Rigid (..., 8) - } - - -def torsion_angles_to_frames( - aatype: torch.Tensor, # (N) - backb_to_global: geometry.Rigid3Array, # (N) - torsion_angles_sin_cos: torch.Tensor # (N, 7, 2) -) -> geometry.Rigid3Array: # (N, 8) - """Compute rigid group frames from torsion angles.""" - # Gather the default frames for all rigid groups. - # geometry.Rigid3Array with shape (N, 8) - m = get_rc_tensor(rc.restype_rigid_group_default_frame, aatype) - default_frames = geometry.Rigid3Array.from_array4x4(m) - - # Create the rotation matrices according to the given angles (each frame is - # defined such that its rotation is around the x-axis). - sin_angles = torsion_angles_sin_cos[..., 0] - cos_angles = torsion_angles_sin_cos[..., 1] - - # insert zero rotation for backbone group. - num_residues = aatype.shape[-1] - sin_angles = torch.cat([ - torch.zeros_like(aatype).unsqueeze(), - sin_angles, - ], dim=-1) - cos_angles = torch.cat([torch.ones_like(aatype).unsqueeze(), cos_angles], dim=-1) - zeros = torch.zeros_like(sin_angles) - ones = torch.ones_like(sin_angles) - - # all_rots are geometry.Rot3Array with shape (..., N, 8) - all_rots = geometry.Rot3Array(ones, zeros, zeros, zeros, cos_angles, -sin_angles, zeros, sin_angles, cos_angles) - - # Apply rotations to the frames. - all_frames = default_frames.compose_rotation(all_rots) - - # chi2, chi3, and chi4 frames do not transform to the backbone frame but to - # the previous frame. So chain them up accordingly. - - chi1_frame_to_backb = all_frames[..., 4] - chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[..., 5] - chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[..., 6] - chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[..., 7] - - all_frames_to_backb = Rigid3Array.cat([ - all_frames[..., 0:5], chi2_frame_to_backb[..., None], chi3_frame_to_backb[..., None], chi4_frame_to_backb[..., - None] - ], - dim=-1) - - # Create the global frames. - # shape (N, 8) - all_frames_to_global = backb_to_global[..., None] @ all_frames_to_backb - - return all_frames_to_global - - -def frames_and_literature_positions_to_atom14_pos( - aatype: torch.Tensor, # (*, N) - all_frames_to_global: geometry.Rigid3Array # (N, 8) -) -> geometry.Vec3Array: # (*, N, 14) - """Put atom literature positions (atom14 encoding) in each rigid group.""" - # Pick the appropriate transform for every atom. - residx_to_group_idx = get_rc_tensor(rc.restype_atom14_to_rigid_group, aatype) - group_mask = torch.nn.functional.one_hot(residx_to_group_idx, num_classes=8) # shape (*, N, 14, 8) - - # geometry.Rigid3Array with shape (N, 14) - map_atoms_to_global = all_frames_to_global[..., None, :] * group_mask - map_atoms_to_global = map_atoms_to_global.map_tensor_fn(partial(torch.sum, dim=-1)) - - # Gather the literature atom positions for each residue. 
- # geometry.Vec3Array with shape (N, 14) - lit_positions = geometry.Vec3Array.from_array(get_rc_tensor(rc.restype_atom14_rigid_group_positions, aatype)) - - # Transform each atom from its local frame to the global frame. - # geometry.Vec3Array with shape (N, 14) - pred_positions = map_atoms_to_global.apply_to_point(lit_positions) - - # Mask out non-existing atoms. - mask = get_rc_tensor(rc.restype_atom14_mask, aatype) - pred_positions = pred_positions * mask - - return pred_positions - - -def extreme_ca_ca_distance_violations( - positions: geometry.Vec3Array, # (N, 37(14)) - mask: torch.Tensor, # (N, 37(14)) - residue_index: torch.Tensor, # (N) - max_angstrom_tolerance=1.5, - eps: float = 1e-6) -> torch.Tensor: - """Counts residues whose Ca is a large distance from its neighbor.""" - this_ca_pos = positions[..., :-1, 1] # (N - 1,) - this_ca_mask = mask[..., :-1, 1] # (N - 1) - next_ca_pos = positions[..., 1:, 1] # (N - 1,) - next_ca_mask = mask[..., 1:, 1] # (N - 1) - has_no_gap_mask = ((residue_index[..., 1:] - residue_index[..., :-1]) == 1.0).astype(torch.float32) - ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, eps) - violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance - mask = this_ca_mask * next_ca_mask * has_no_gap_mask - return tensor_utils.masked_mean(mask=mask, value=violations, dim=-1) - - -def get_chi_atom_indices(device: torch.device): - """Returns atom indices needed to compute chi angles for all residue types. - - Returns: - A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are - in the order specified in rc.restypes + unknown residue type - at the end. For chi angles which are not defined on the residue, the - positions indices are by default set to 0. - """ - chi_atom_indices = [] - for residue_name in rc.restypes: - residue_name = rc.restype_1to3[residue_name] - residue_chi_angles = rc.chi_angles_atoms[residue_name] - atom_indices = [] - for chi_angle in residue_chi_angles: - atom_indices.append([rc.atom_order[atom] for atom in chi_angle]) - for _ in range(4 - len(atom_indices)): - atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. - chi_atom_indices.append(atom_indices) - - chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. - return torch.tensor(chi_atom_indices, device=device) - - -def compute_chi_angles(positions: geometry.Vec3Array, mask: torch.Tensor, aatype: torch.Tensor): - """Computes the chi angles given all atom positions and the amino acid type. - - Args: - positions: A Vec3Array of shape - [num_res, rc.atom_type_num], with positions of - atoms needed to calculate chi angles. Supports up to 1 batch dimension. - mask: An optional tensor of shape - [num_res, rc.atom_type_num] that masks which atom - positions are set for each residue. If given, then the chi mask will be - set to 1 for a chi angle only if the amino acid has that chi angle and all - the chi atoms needed to calculate that chi angle are set. If not given - (set to None), the chi mask will be set to 1 for a chi angle if the amino - acid has that chi angle and whether the actual atoms needed to calculate - it were set will be ignored. - aatype: A tensor of shape [num_res] with amino acid type integer - code (0 to 21). Supports up to 1 batch dimension. - - Returns: - A tuple of tensors (chi_angles, mask), where both have shape - [num_res, 4]. The mask masks out unused chi angles for amino acid - types that have less than 4 chi angles. 
If atom_positions_mask is set, the - chi mask will also mask out uncomputable chi angles. - """ - - # Don't assert on the num_res and batch dimensions as they might be unknown. - assert positions.shape[-1] == rc.atom_type_num - assert mask.shape[-1] == rc.atom_type_num - no_batch_dims = len(aatype.shape) - 1 - - # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. - chi_atom_indices = get_chi_atom_indices(aatype.device) - - # DISCREPANCY: DeepMind doesn't remove the gaps here. I don't know why - # theirs works. - aatype_gapless = torch.clamp(aatype, max=20) - - # Select atoms to compute chis. Shape: [*, num_res, chis=4, atoms=4]. - atom_indices = chi_atom_indices[aatype_gapless] - # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3]. - chi_angle_atoms = positions.map_tensor_fn( - partial(tensor_utils.batched_gather, inds=atom_indices, dim=-1, no_batch_dims=no_batch_dims + 1)) - - a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)] - - chi_angles = geometry.dihedral_angle(a, b, c, d) - - # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4]. - chi_angles_mask = list(rc.chi_angles_mask) - chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) - chi_angles_mask = torch.tensor(chi_angles_mask, device=aatype.device) - # Compute the chi angle mask. Shape [num_res, chis=4]. - chi_mask = chi_angles_mask[aatype_gapless] - - # The chi_mask is set to 1 only when all necessary chi angle atoms were set. - # Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4]. - chi_angle_atoms_mask = tensor_utils.batched_gather(mask, atom_indices, dim=-1, no_batch_dims=no_batch_dims + 1) - # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4]. - chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1) - chi_mask = chi_mask * chi_angle_atoms_mask.to(torch.float32) - - return chi_angles, chi_mask - - -def make_transform_from_reference(a_xyz: geometry.Vec3Array, b_xyz: geometry.Vec3Array, - c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array: - """Returns rotation and translation matrices to convert from reference. - - Note that this method does not take care of symmetries. If you provide the - coordinates in the non-standard way, the A atom will end up in the negative - y-axis rather than in the positive y-axis. You need to take care of such - cases in your code. - - Args: - a_xyz: A Vec3Array. - b_xyz: A Vec3Array. - c_xyz: A Vec3Array. - - Returns: - A Rigid3Array which, when applied to coordinates in a canonicalized - reference frame, will give coordinates approximately equal - the original coordinates (in the global frame). 
- """ - rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz, a_xyz - b_xyz) - return geometry.Rigid3Array(rotation, b_xyz) - - -def make_backbone_affine( - positions: geometry.Vec3Array, - mask: torch.Tensor, - aatype: torch.Tensor, -) -> Tuple[geometry.Rigid3Array, torch.Tensor]: - a = rc.atom_order['N'] - b = rc.atom_order['CA'] - c = rc.atom_order['C'] - - rigid_mask = (mask[..., a] * mask[..., b] * mask[..., c]) - - rigid = make_transform_from_reference( - a_xyz=positions[..., a], - b_xyz=positions[..., b], - c_xyz=positions[..., c], - ) - - return rigid, rigid_mask diff --git a/tests/test_autochunk/origin_openfold/utils/checkpointing.py b/tests/test_autochunk/origin_openfold/utils/checkpointing.py deleted file mode 100644 index bd8def5f63c7..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/checkpointing.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Callable, List, Optional, Tuple - -import torch -import torch.utils.checkpoint - -BLOCK_ARG = Any -BLOCK_ARGS = List[BLOCK_ARG] - - -def get_checkpoint_fn(): - checkpoint = torch.utils.checkpoint.checkpoint - - return checkpoint - - -@torch.jit.ignore -def checkpoint_blocks( - blocks: List[Callable], - args: BLOCK_ARGS, - blocks_per_ckpt: Optional[int], -) -> BLOCK_ARGS: - """ - Chunk a list of blocks and run each chunk with activation - checkpointing. We define a "block" as a callable whose only inputs are - the outputs of the previous block. - - Implements Subsection 1.11.8 - - Args: - blocks: - List of blocks - args: - Tuple of arguments for the first block. - blocks_per_ckpt: - Size of each chunk. A higher value corresponds to fewer - checkpoints, and trades memory for speed. If None, no checkpointing - is performed. 
- Returns: - The output of the final block - """ - - def wrap(a): - return (a,) if type(a) is not tuple else a - - def exec(b, a): - for block in b: - a = wrap(block(*a)) - return a - - def chunker(s, e): - - def exec_sliced(*a): - return exec(blocks[s:e], a) - - return exec_sliced - - # Avoids mishaps when the blocks take just one argument - args = wrap(args) - - if blocks_per_ckpt is None: - return exec(blocks, args) - elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): - raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") - - checkpoint = get_checkpoint_fn() - - for s in range(0, len(blocks), blocks_per_ckpt): - e = s + blocks_per_ckpt - args = checkpoint(chunker(s, e), *args) - args = wrap(args) - - return args diff --git a/tests/test_autochunk/origin_openfold/utils/feats.py b/tests/test_autochunk/origin_openfold/utils/feats.py deleted file mode 100644 index 04ae54d413b3..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/feats.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Any, Dict, Optional, Tuple, Union - -import fastfold.common.residue_constants as rc -import numpy as np -import torch -import torch.nn as nn -from fastfold.common import protein -from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array -from fastfold.utils.geometry.rotation_matrix import Rot3Array -from fastfold.utils.rigid_utils import Rigid, Rotation -from fastfold.utils.tensor_utils import batched_gather, one_hot, tensor_tree_map, tree_map - - -def dgram_from_positions( - pos: torch.Tensor, - min_bin: float = 3.25, - max_bin: float = 50.75, - no_bins: float = 39, - inf: float = 1e8, -) -> torch.Tensor: - dgram = torch.sum((pos[..., None, :] - pos[..., None, :, :])**2, dim=-1, keepdim=True) - lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device)**2 - upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) - dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) - - return dgram - - -def pseudo_beta_fn(aatype, all_atom_positions: torch.Tensor, - all_atom_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - is_gly = aatype == rc.restype_order["G"] - ca_idx = rc.atom_order["CA"] - cb_idx = rc.atom_order["CB"] - pseudo_beta = torch.where( - is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3), - all_atom_positions[..., ca_idx, :], - all_atom_positions[..., cb_idx, :], - ) - - if all_atom_masks is not None: - pseudo_beta_mask = torch.where( - is_gly, - all_atom_masks[..., ca_idx], - all_atom_masks[..., cb_idx], - ) - return pseudo_beta, pseudo_beta_mask - else: - return pseudo_beta, None - - -def atom14_to_atom37(atom14, batch: Dict[str, Any]): - atom37_data = batched_gather( - atom14, - batch["residx_atom37_to_atom14"], - dim=-2, - no_batch_dims=len(atom14.shape[:-2]), - ) - - atom37_data = atom37_data * batch["atom37_atom_exists"][..., None] - - return atom37_data - 
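
At its core, the atom14_to_atom37 conversion above is one gather along the atom axis followed by a mask multiply. The following is a minimal sketch of that pattern, not part of the patch: the shapes are toy values, and the random index map and mask stand in for the real residue-constant tables carried in batch["residx_atom37_to_atom14"] and batch["atom37_atom_exists"].

import torch

n_res = 3
atom14 = torch.randn(n_res, 14, 3)                        # per-residue atom14 coordinates
idx_37_to_14 = torch.randint(0, 14, (n_res, 37))          # toy stand-in for residx_atom37_to_atom14
atom37_exists = torch.randint(0, 2, (n_res, 37)).float()  # toy stand-in for atom37_atom_exists

# Gather along the atom dimension; the index is expanded so it covers xyz.
atom37 = torch.gather(atom14, 1, idx_37_to_14[..., None].expand(-1, -1, 3))
atom37 = atom37 * atom37_exists[..., None]                # zero out slots that do not exist
print(atom37.shape)                                       # torch.Size([3, 37, 3])
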
- -def build_template_angle_feat(template_feats: Dict[str, Any]) -> torch.Tensor: - template_aatype = template_feats["template_aatype"] - torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] - alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"] - torsion_angles_mask = template_feats["template_torsion_angles_mask"] - template_angle_feat = torch.cat( - [ - nn.functional.one_hot(template_aatype, 22), - torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14), - alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14), - torsion_angles_mask, - ], - dim=-1, - ) - - return template_angle_feat - - -def build_template_pair_feat(batch: Dict[str, Any], - min_bin: float, - max_bin: float, - no_bins: int, - use_unit_vector: bool = False, - eps: float = 1e-20, - inf: float = 1e8, - chunk=None): - if chunk and 1 <= chunk <= 4: - for k, v in batch.items(): - batch[k] = v.cpu() - - template_mask = batch["template_pseudo_beta_mask"] - template_mask_2d = template_mask[..., None] * template_mask[..., None, :] - - # Compute distogram (this seems to differ slightly from Alg. 5) - tpb = batch["template_pseudo_beta"] - dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf) - - to_concat = [dgram, template_mask_2d[..., None]] - - aatype_one_hot = nn.functional.one_hot( - batch["template_aatype"], - rc.restype_num + 2, - ) - - n_res = batch["template_aatype"].shape[-1] - to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1)) - to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)) - - n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] - rigids = Rigid.make_transform_from_reference( - n_xyz=batch["template_all_atom_positions"][..., n, :], - ca_xyz=batch["template_all_atom_positions"][..., ca, :], - c_xyz=batch["template_all_atom_positions"][..., c, :], - eps=eps, - ) - points = rigids.get_trans()[..., None, :, :] - rigid_vec = rigids[..., None].invert_apply(points) - del rigids, points - - inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) - - t_aa_masks = batch["template_all_atom_mask"] - template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c] - del t_aa_masks, n, ca, c - template_mask_2d = template_mask[..., None] * template_mask[..., None, :] - - inv_distance_scalar = inv_distance_scalar * template_mask_2d - unit_vector = rigid_vec * inv_distance_scalar[..., None] - - if not use_unit_vector: - unit_vector = unit_vector * 0.0 - - to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) - to_concat.append(template_mask_2d[..., None]) - del unit_vector, rigid_vec, inv_distance_scalar - - act = torch.cat(to_concat, dim=-1) - act = act * template_mask_2d[..., None] - - return act - - -def build_extra_msa_feat(batch: Dict[str, Any]) -> torch.Tensor: - msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23) - msa_feat = [ - msa_1hot, - batch["extra_has_deletion"].unsqueeze(-1), - batch["extra_deletion_value"].unsqueeze(-1), - ] - return torch.cat(msa_feat, dim=-1) - - -def torsion_angles_to_frames( - r: Union[Rigid3Array, Rigid], - alpha: torch.Tensor, - aatype: torch.Tensor, - rrgdf: torch.Tensor, -) -> Union[Rigid, Rigid3Array]: - # [*, N, 8, 4, 4] - default_4x4 = rrgdf[aatype, ...] - - # [*, N, 8] transformations, i.e. 
- # One [*, N, 8, 3, 3] rotation matrix and - # One [*, N, 8, 3] translation matrix - default_r = r.from_tensor_4x4(default_4x4) - - bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) - bb_rot[..., 1] = 1 - - # [*, N, 8, 2] - alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) - - # [*, N, 8, 3, 3] - # Produces rotation matrices of the form: - # [ - # [1, 0 , 0 ], - # [0, a_2,-a_1], - # [0, a_1, a_2] - # ] - # This follows the original code rather than the supplement, which uses - # different indices. - if type(r) == Rigid3Array: - all_rots = alpha.new_zeros(default_r.shape + (3, 3)) - elif type(r) == Rigid: - all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) - else: - raise TypeError(f"Wrong type of Rigid: {type(r)}") - - all_rots[..., 0, 0] = 1 - all_rots[..., 1, 1] = alpha[..., 1] - all_rots[..., 1, 2] = -alpha[..., 0] - all_rots[..., 2, 1:] = alpha - - if type(r) == Rigid3Array: - all_rots = Rot3Array.from_array(all_rots) - all_frames = default_r.compose_rotation(all_rots) - elif type(r) == Rigid: - all_rots = Rigid(Rotation(rot_mats=all_rots), None) - all_frames = default_r.compose(all_rots) - else: - raise TypeError(f"Wrong type of Rigid: {type(r)}") - - chi2_frame_to_frame = all_frames[..., 5] - chi3_frame_to_frame = all_frames[..., 6] - chi4_frame_to_frame = all_frames[..., 7] - - chi1_frame_to_bb = all_frames[..., 4] - chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) - chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) - chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) - - if type(all_frames) == Rigid3Array: - all_frames_to_bb = Rigid3Array.cat( - [ - all_frames[..., :5], - chi2_frame_to_bb.unsqueeze(-1), - chi3_frame_to_bb.unsqueeze(-1), - chi4_frame_to_bb.unsqueeze(-1), - ], - dim=-1, - ) - elif type(all_frames) == Rigid: - all_frames_to_bb = Rigid.cat( - [ - all_frames[..., :5], - chi2_frame_to_bb.unsqueeze(-1), - chi3_frame_to_bb.unsqueeze(-1), - chi4_frame_to_bb.unsqueeze(-1), - ], - dim=-1, - ) - - all_frames_to_global = r[..., None].compose(all_frames_to_bb) - - return all_frames_to_global - - -def frames_and_literature_positions_to_atom14_pos( - r: Union[Rigid3Array, Rigid], - aatype: torch.Tensor, - default_frames: torch.Tensor, - group_idx: torch.Tensor, - atom_mask: torch.Tensor, - lit_positions: torch.Tensor, -) -> torch.Tensor: - # [*, N, 14, 4, 4] - default_4x4 = default_frames[aatype, ...] - - # [*, N, 14] - group_mask = group_idx[aatype, ...] - - # [*, N, 14, 8] - if type(r) == Rigid3Array: - group_mask = nn.functional.one_hot( - group_mask.long(), - num_classes=default_frames.shape[-3], - ) - elif type(r) == Rigid: - group_mask = nn.functional.one_hot( - group_mask, - num_classes=default_frames.shape[-3], - ) - else: - raise TypeError(f"Wrong type of Rigid: {type(r)}") - - # [*, N, 14, 8] - t_atoms_to_global = r[..., None, :] * group_mask - - # [*, N, 14] - t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) - - # [*, N, 14, 1] - if type(r) == Rigid: - atom_mask = atom_mask[aatype, ...].unsqueeze(-1) - elif type(r) == Rigid3Array: - atom_mask = atom_mask[aatype, ...] - - # [*, N, 14, 3] - lit_positions = lit_positions[aatype, ...] 
- pred_positions = t_atoms_to_global.apply(lit_positions) - pred_positions = pred_positions * atom_mask - - return pred_positions diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/__init__.py b/tests/test_autochunk/origin_openfold/utils/geometry/__init__.py deleted file mode 100644 index 2abd731c6a31..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/geometry/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Geometry Module.""" - -from fastfold.utils.geometry import rigid_matrix_vector, rotation_matrix, vector - -Rot3Array = rotation_matrix.Rot3Array -Rigid3Array = rigid_matrix_vector.Rigid3Array - -Vec3Array = vector.Vec3Array -square_euclidean_distance = vector.square_euclidean_distance -euclidean_distance = vector.euclidean_distance -dihedral_angle = vector.dihedral_angle -dot = vector.dot -cross = vector.cross diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/quat_rigid.py b/tests/test_autochunk/origin_openfold/utils/geometry/quat_rigid.py deleted file mode 100644 index fa72b4a7437f..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/geometry/quat_rigid.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -import torch.nn as nn -from fastfold.model.nn.primitives import Linear -from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array -from fastfold.utils.geometry.rotation_matrix import Rot3Array -from fastfold.utils.geometry.vector import Vec3Array - - -class QuatRigid(nn.Module): - - def __init__(self, c_hidden, full_quat): - super().__init__() - self.full_quat = full_quat - if self.full_quat: - rigid_dim = 7 - else: - rigid_dim = 6 - - self.linear = Linear(c_hidden, rigid_dim) - - def forward(self, activations: torch.Tensor) -> Rigid3Array: - # NOTE: During training, this needs to be run in higher precision - rigid_flat = self.linear(activations.to(torch.float32)) - - rigid_flat = torch.unbind(rigid_flat, dim=-1) - if (self.full_quat): - qw, qx, qy, qz = rigid_flat[:4] - translation = rigid_flat[4:] - else: - qx, qy, qz = rigid_flat[:3] - qw = torch.ones_like(qx) - translation = rigid_flat[3:] - - rotation = Rot3Array.from_quaternion( - qw, - qx, - qy, - qz, - normalize=True, - ) - translation = Vec3Array(*translation) - return Rigid3Array(rotation, translation) diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/rigid_matrix_vector.py b/tests/test_autochunk/origin_openfold/utils/geometry/rigid_matrix_vector.py deleted file mode 100644 index 7b97e1827c16..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/geometry/rigid_matrix_vector.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Rigid3Array Transformations represented by a Matrix and a Vector.""" - -from __future__ import annotations - -import dataclasses -from typing import List, Union - -import torch -from fastfold.utils.geometry import rotation_matrix, vector - -Float = Union[float, torch.Tensor] - - -@dataclasses.dataclass(frozen=True) -class Rigid3Array: - """Rigid Transformation, i.e. element of special euclidean group.""" - - rotation: rotation_matrix.Rot3Array - translation: vector.Vec3Array - - def __matmul__(self, other: Rigid3Array) -> Rigid3Array: - new_rotation = self.rotation @ other.rotation # __matmul__ - new_translation = self.apply_to_point(other.translation) - return Rigid3Array(new_rotation, new_translation) - - def __getitem__(self, index) -> Rigid3Array: - return Rigid3Array( - self.rotation[index], - self.translation[index], - ) - - def __mul__(self, other: torch.Tensor) -> Rigid3Array: - return Rigid3Array( - self.rotation * other, - self.translation * other, - ) - - def map_tensor_fn(self, fn) -> Rigid3Array: - return Rigid3Array( - self.rotation.map_tensor_fn(fn), - self.translation.map_tensor_fn(fn), - ) - - def inverse(self) -> Rigid3Array: - """Return Rigid3Array corresponding to inverse transform.""" - inv_rotation = self.rotation.inverse() - inv_translation = inv_rotation.apply_to_point(-self.translation) - return Rigid3Array(inv_rotation, inv_translation) - - def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: - """Apply Rigid3Array transform to point.""" - return self.rotation.apply_to_point(point) + self.translation - - def apply(self, point: torch.Tensor) -> vector.Vec3Array: - return self.apply_to_point(vector.Vec3Array.from_array(point)) - - def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: - """Apply inverse Rigid3Array transform to point.""" - new_point = point - self.translation - return self.rotation.apply_inverse_to_point(new_point) - - def compose_rotation(self, other_rotation): - rot = self.rotation @ other_rotation - return Rigid3Array(rot, self.translation.clone()) - - def compose(self, other_rigid): - return self @ other_rigid - - def unsqueeze(self, dim: int): - return Rigid3Array( - self.rotation.unsqueeze(dim), - self.translation.unsqueeze(dim), - ) - - @property - def shape(self) -> torch.Size: - return self.rotation.xx.shape - - @property - def dtype(self) -> torch.dtype: - return self.rotation.xx.dtype - - @property - def device(self) -> torch.device: - return self.rotation.xx.device - - @classmethod - def identity(cls, shape, device) -> Rigid3Array: - """Return identity Rigid3Array of given shape.""" - return cls(rotation_matrix.Rot3Array.identity(shape, device), vector.Vec3Array.zeros(shape, device)) - - @classmethod - def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array: - return cls( - rotation_matrix.Rot3Array.cat([r.rotation for r in rigids], dim=dim), - vector.Vec3Array.cat([r.translation for r in rigids], dim=dim), - ) - - def scale_translation(self, factor: Float) -> Rigid3Array: - """Scale translation in Rigid3Array by 'factor'.""" - return Rigid3Array(self.rotation, 
self.translation * factor)
-
-    def to_tensor(self) -> torch.Tensor:
-        rot_array = self.rotation.to_tensor()
-        vec_array = self.translation.to_tensor()
-        array = torch.zeros(rot_array.shape[:-2] + (4, 4), device=rot_array.device, dtype=rot_array.dtype)
-        array[..., :3, :3] = rot_array
-        array[..., :3, 3] = vec_array
-        array[..., 3, 3] = 1.
-        return array
-
-    def to_tensor_4x4(self) -> torch.Tensor:
-        return self.to_tensor()
-
-    def reshape(self, new_shape) -> Rigid3Array:
-        rots = self.rotation.reshape(new_shape)
-        trans = self.translation.reshape(new_shape)
-        return Rigid3Array(rots, trans)
-
-    def stop_rot_gradient(self) -> Rigid3Array:
-        return Rigid3Array(
-            self.rotation.stop_gradient(),
-            self.translation,
-        )
-
-    @classmethod
-    def from_array(cls, array):
-        rot = rotation_matrix.Rot3Array.from_array(array[..., :3, :3])
-        vec = vector.Vec3Array.from_array(array[..., :3, 3])
-        return cls(rot, vec)
-
-    @classmethod
-    def from_tensor_4x4(cls, array):
-        return cls.from_array(array)
-
-    @classmethod
-    def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
-        """Construct Rigid3Array from homogeneous 4x4 array."""
-        rotation = rotation_matrix.Rot3Array(array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], array[..., 1, 0],
-                                             array[..., 1, 1], array[..., 1, 2], array[..., 2, 0], array[..., 2, 1],
-                                             array[..., 2, 2])
-        translation = vector.Vec3Array(array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
-        return cls(rotation, translation)
diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py b/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py
deleted file mode 100644
index 8d929524268a..000000000000
--- a/tests/test_autochunk/origin_openfold/utils/geometry/rotation_matrix.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# Copyright 2021 DeepMind Technologies Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Rot3Array Matrix Class.""" - -from __future__ import annotations - -import dataclasses - -import numpy as np -import torch -from fastfold.utils.geometry import utils, vector -from fastfold.utils.tensor_utils import tensor_tree_map - -COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] - - -@dataclasses.dataclass(frozen=True) -class Rot3Array: - """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" - xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) - xy: torch.Tensor - xz: torch.Tensor - yx: torch.Tensor - yy: torch.Tensor - yz: torch.Tensor - zx: torch.Tensor - zy: torch.Tensor - zz: torch.Tensor - - __array_ufunc__ = None - - def __getitem__(self, index): - field_names = utils.get_field_names(Rot3Array) - return Rot3Array(**{name: getattr(self, name)[index] for name in field_names}) - - def __mul__(self, other: torch.Tensor): - field_names = utils.get_field_names(Rot3Array) - return Rot3Array(**{name: getattr(self, name) * other for name in field_names}) - - def __matmul__(self, other: Rot3Array) -> Rot3Array: - """Composes two Rot3Arrays.""" - c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) - c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) - c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) - return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) - - def map_tensor_fn(self, fn) -> Rot3Array: - field_names = utils.get_field_names(Rot3Array) - return Rot3Array(**{name: fn(getattr(self, name)) for name in field_names}) - - def inverse(self) -> Rot3Array: - """Returns inverse of Rot3Array.""" - return Rot3Array(self.xx, self.yx, self.zx, self.xy, self.yy, self.zy, self.xz, self.yz, self.zz) - - def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: - """Applies Rot3Array to point.""" - return vector.Vec3Array(self.xx * point.x + self.xy * point.y + self.xz * point.z, - self.yx * point.x + self.yy * point.y + self.yz * point.z, - self.zx * point.x + self.zy * point.y + self.zz * point.z) - - def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: - """Applies inverse Rot3Array to point.""" - return self.inverse().apply_to_point(point) - - def unsqueeze(self, dim: int): - return Rot3Array(*tensor_tree_map(lambda t: t.unsqueeze(dim), [getattr(self, c) for c in COMPONENTS])) - - def stop_gradient(self) -> Rot3Array: - return Rot3Array(*[getattr(self, c).detach() for c in COMPONENTS]) - - @classmethod - def identity(cls, shape, device) -> Rot3Array: - """Returns identity of given shape.""" - ones = torch.ones(shape, dtype=torch.float32, device=device) - zeros = torch.zeros(shape, dtype=torch.float32, device=device) - return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) - - @classmethod - def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array) -> Rot3Array: - """Construct Rot3Array from two Vectors. - - Rot3Array is constructed such that in the corresponding frame 'e0' lies on - the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. - - Args: - e0: Vector - e1: Vector - Returns: - Rot3Array - """ - # Normalize the unit vector for the x-axis, e0. - e0 = e0.normalized() - # make e1 perpendicular to e0. - c = e1.dot(e0) - e1 = (e1 - c * e0).normalized() - # Compute e2 as cross product of e0 and e1. 
- e2 = e0.cross(e1) - return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) - - @classmethod - def from_array(cls, array: torch.Tensor) -> Rot3Array: - """Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" - rows = torch.unbind(array, dim=-2) - rc = [torch.unbind(e, dim=-1) for e in rows] - return cls(*[e for row in rc for e in row]) - - def to_tensor(self) -> torch.Tensor: - """Convert Rot3Array to array of shape [..., 3, 3].""" - return torch.stack([ - torch.stack([self.xx, self.xy, self.xz], dim=-1), - torch.stack([self.yx, self.yy, self.yz], dim=-1), - torch.stack([self.zx, self.zy, self.zz], dim=-1) - ], - dim=-2) - - @classmethod - def from_quaternion(cls, - w: torch.Tensor, - x: torch.Tensor, - y: torch.Tensor, - z: torch.Tensor, - normalize: bool = True, - eps: float = 1e-6) -> Rot3Array: - """Construct Rot3Array from components of quaternion.""" - if normalize: - inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2) - w *= inv_norm - x *= inv_norm - y *= inv_norm - z *= inv_norm - xx = 1 - 2 * (y**2 + z**2) - xy = 2 * (x * y - w * z) - xz = 2 * (x * z + w * y) - yx = 2 * (x * y + w * z) - yy = 1 - 2 * (x**2 + z**2) - yz = 2 * (y * z - w * x) - zx = 2 * (x * z - w * y) - zy = 2 * (y * z + w * x) - zz = 1 - 2 * (x**2 + y**2) - return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) - - def reshape(self, new_shape): - field_names = utils.get_field_names(Rot3Array) - reshape_fn = lambda t: t.reshape(new_shape) - return Rot3Array(**{name: reshape_fn(getattr(self, name)) for name in field_names}) - - @classmethod - def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array: - field_names = utils.get_field_names(Rot3Array) - cat_fn = lambda ll: torch.cat(ll, dim=dim) - return cls(**{name: cat_fn([getattr(r, name) for r in rots]) for name in field_names}) diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/test_utils.py b/tests/test_autochunk/origin_openfold/utils/geometry/test_utils.py deleted file mode 100644 index a86cb6a864e6..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/geometry/test_utils.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Shared utils for tests.""" - -import dataclasses - -import numpy as np -from fastfold.utils.geometry import rigid_matrix_vector, rotation_matrix, vector - - -def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array, matrix2: rotation_matrix.Rot3Array): - for field in dataclasses.fields(rotation_matrix.Rot3Array): - field = field.name - np.testing.assert_array_equal(getattr(matrix1, field), getattr(matrix2, field)) - - -def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array, mat2: rotation_matrix.Rot3Array): - np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6) - - -def assert_array_equal_to_rotation_matrix(array: np.ndarray, matrix: rotation_matrix.Rot3Array): - """Check that array and Matrix match.""" - np.testing.assert_array_equal(matrix.xx, array[..., 0, 0]) - np.testing.assert_array_equal(matrix.xy, array[..., 0, 1]) - np.testing.assert_array_equal(matrix.xz, array[..., 0, 2]) - np.testing.assert_array_equal(matrix.yx, array[..., 1, 0]) - np.testing.assert_array_equal(matrix.yy, array[..., 1, 1]) - np.testing.assert_array_equal(matrix.yz, array[..., 1, 2]) - np.testing.assert_array_equal(matrix.zx, array[..., 2, 0]) - np.testing.assert_array_equal(matrix.zy, array[..., 2, 1]) - np.testing.assert_array_equal(matrix.zz, array[..., 2, 2]) - - -def assert_array_close_to_rotation_matrix(array: np.ndarray, matrix: rotation_matrix.Rot3Array): - np.testing.assert_array_almost_equal(matrix.to_array(), array, 6) - - -def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array): - np.testing.assert_array_equal(vec1.x, vec2.x) - np.testing.assert_array_equal(vec1.y, vec2.y) - np.testing.assert_array_equal(vec1.z, vec2.z) - - -def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array): - np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.) - np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.) - np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.) - - -def assert_array_close_to_vector(array: np.ndarray, vec: vector.Vec3Array): - np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.) 
- - -def assert_array_equal_to_vector(array: np.ndarray, vec: vector.Vec3Array): - np.testing.assert_array_equal(vec.to_array(), array) - - -def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, rigid2: rigid_matrix_vector.Rigid3Array): - assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2) - - -def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, rigid2: rigid_matrix_vector.Rigid3Array): - assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2) - - -def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array, trans: vector.Vec3Array, - rigid: rigid_matrix_vector.Rigid3Array): - assert_rotation_matrix_equal(rot, rigid.rotation) - assert_vectors_equal(trans, rigid.translation) - - -def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array, trans: vector.Vec3Array, - rigid: rigid_matrix_vector.Rigid3Array): - assert_rotation_matrix_close(rot, rigid.rotation) - assert_vectors_close(trans, rigid.translation) diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/utils.py b/tests/test_autochunk/origin_openfold/utils/geometry/utils.py deleted file mode 100644 index 6c4d52ba9969..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/geometry/utils.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utils for geometry library.""" - -import dataclasses - - -def get_field_names(cls): - fields = dataclasses.fields(cls) - field_names = [f.name for f in fields] - return field_names diff --git a/tests/test_autochunk/origin_openfold/utils/geometry/vector.py b/tests/test_autochunk/origin_openfold/utils/geometry/vector.py deleted file mode 100644 index 4204e736c328..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/geometry/vector.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Vec3Array Class.""" - -from __future__ import annotations - -import dataclasses -from typing import List, Union - -import torch -from fastfold.utils.geometry import utils - -Float = Union[float, torch.Tensor] - - -@dataclasses.dataclass(frozen=True) -class Vec3Array: - x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) - y: torch.Tensor - z: torch.Tensor - - def __post_init__(self): - if hasattr(self.x, 'dtype'): - assert self.x.dtype == self.y.dtype - assert self.x.dtype == self.z.dtype - assert all([x == y for x, y in zip(self.x.shape, self.y.shape)]) - assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) - - def __add__(self, other: Vec3Array) -> Vec3Array: - return Vec3Array( - self.x + other.x, - self.y + other.y, - self.z + other.z, - ) - - def __sub__(self, other: Vec3Array) -> Vec3Array: - return Vec3Array( - self.x - other.x, - self.y - other.y, - self.z - other.z, - ) - - def __mul__(self, other: Float) -> Vec3Array: - return Vec3Array( - self.x * other, - self.y * other, - self.z * other, - ) - - def __rmul__(self, other: Float) -> Vec3Array: - return self * other - - def __truediv__(self, other: Float) -> Vec3Array: - return Vec3Array( - self.x / other, - self.y / other, - self.z / other, - ) - - def __neg__(self) -> Vec3Array: - return self * -1 - - def __pos__(self) -> Vec3Array: - return self * 1 - - def __getitem__(self, index) -> Vec3Array: - return Vec3Array( - self.x[index], - self.y[index], - self.z[index], - ) - - def __iter__(self): - return iter((self.x, self.y, self.z)) - - @property - def shape(self): - return self.x.shape - - def map_tensor_fn(self, fn) -> Vec3Array: - return Vec3Array( - fn(self.x), - fn(self.y), - fn(self.z), - ) - - def cross(self, other: Vec3Array) -> Vec3Array: - """Compute cross product between 'self' and 'other'.""" - new_x = self.y * other.z - self.z * other.y - new_y = self.z * other.x - self.x * other.z - new_z = self.x * other.y - self.y * other.x - return Vec3Array(new_x, new_y, new_z) - - def dot(self, other: Vec3Array) -> Float: - """Compute dot product between 'self' and 'other'.""" - return self.x * other.x + self.y * other.y + self.z * other.z - - def norm(self, epsilon: float = 1e-6) -> Float: - """Compute Norm of Vec3Array, clipped to epsilon.""" - # To avoid NaN on the backward pass, we must use maximum before the sqrt - norm2 = self.dot(self) - if epsilon: - norm2 = torch.clamp(norm2, min=epsilon**2) - return torch.sqrt(norm2) - - def norm2(self): - return self.dot(self) - - def normalized(self, epsilon: float = 1e-6) -> Vec3Array: - """Return unit vector with optional clipping.""" - return self / self.norm(epsilon) - - def clone(self) -> Vec3Array: - return Vec3Array( - self.x.clone(), - self.y.clone(), - self.z.clone(), - ) - - def reshape(self, new_shape) -> Vec3Array: - x = self.x.reshape(new_shape) - y = self.y.reshape(new_shape) - z = self.z.reshape(new_shape) - - return Vec3Array(x, y, z) - - def sum(self, dim: int) -> Vec3Array: - return Vec3Array( - torch.sum(self.x, dim=dim), - torch.sum(self.y, dim=dim), - torch.sum(self.z, dim=dim), - ) - - def unsqueeze(self, dim: int): - return Vec3Array( - self.x.unsqueeze(dim), - self.y.unsqueeze(dim), - self.z.unsqueeze(dim), - ) - - @classmethod - def zeros(cls, shape, device="cpu"): - """Return Vec3Array corresponding to zeros of given shape.""" - return cls(torch.zeros(shape, dtype=torch.float32, device=device), - torch.zeros(shape, dtype=torch.float32, device=device), - torch.zeros(shape, dtype=torch.float32, device=device)) - - def 
to_tensor(self) -> torch.Tensor: - return torch.stack([self.x, self.y, self.z], dim=-1) - - @classmethod - def from_array(cls, tensor): - return cls(*torch.unbind(tensor, dim=-1)) - - @classmethod - def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array: - return cls( - torch.cat([v.x for v in vecs], dim=dim), - torch.cat([v.y for v in vecs], dim=dim), - torch.cat([v.z for v in vecs], dim=dim), - ) - - -def square_euclidean_distance(vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6) -> Float: - """Computes square of euclidean distance between 'vec1' and 'vec2'. - - Args: - vec1: Vec3Array to compute distance to - vec2: Vec3Array to compute distance from, should be - broadcast compatible with 'vec1' - epsilon: distance is clipped from below to be at least epsilon - - Returns: - Array of square euclidean distances; - shape will be result of broadcasting 'vec1' and 'vec2' - """ - difference = vec1 - vec2 - distance = difference.dot(difference) - if epsilon: - distance = torch.maximum(distance, epsilon) - return distance - - -def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: - return vector1.dot(vector2) - - -def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: - return vector1.cross(vector2) - - -def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: - return vector.norm(epsilon) - - -def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: - return vector.normalized(epsilon) - - -def euclidean_distance(vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6) -> Float: - """Computes euclidean distance between 'vec1' and 'vec2'. - - Args: - vec1: Vec3Array to compute euclidean distance to - vec2: Vec3Array to compute euclidean distance from, should be - broadcast compatible with 'vec1' - epsilon: distance is clipped from below to be at least epsilon - - Returns: - Array of euclidean distances; - shape will be result of broadcasting 'vec1' and 'vec2' - """ - distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2) - distance = torch.sqrt(distance_sq) - return distance - - -def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array) -> Float: - """Computes torsion angle for a quadruple of points. - - For points (a, b, c, d), this is the angle between the planes defined by - points (a, b, c) and (b, c, d). It is also known as the dihedral angle. - - Arguments: - a: A Vec3Array of coordinates. - b: A Vec3Array of coordinates. - c: A Vec3Array of coordinates. - d: A Vec3Array of coordinates. - - Returns: - A tensor of angles in radians: [-pi, pi]. - """ - v1 = a - b - v2 = b - c - v3 = d - c - - c1 = v1.cross(v2) - c2 = v3.cross(v2) - c3 = c2.cross(c1) - - v2_mag = v2.norm() - return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2)) diff --git a/tests/test_autochunk/origin_openfold/utils/loss.py b/tests/test_autochunk/origin_openfold/utils/loss.py deleted file mode 100644 index d39705c901f3..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/loss.py +++ /dev/null @@ -1,1403 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from functools import partial -from typing import Dict, Optional, Tuple - -import ml_collections -import numpy as np -import torch -import torch.nn as nn -from fastfold.common import residue_constants -from fastfold.utils import feats -from fastfold.utils.rigid_utils import Rigid, Rotation -from fastfold.utils.tensor_utils import batched_gather, masked_mean, permute_final_dims, tensor_tree_map, tree_map -from torch.distributions.bernoulli import Bernoulli - - -def softmax_cross_entropy(logits, labels): - loss = -1 * torch.sum( - labels * torch.nn.functional.log_softmax(logits, dim=-1), - dim=-1, - ) - return loss - - -def sigmoid_cross_entropy(logits, labels): - log_p = torch.log(torch.sigmoid(logits)) - log_not_p = torch.log(torch.sigmoid(-logits)) - loss = -labels * log_p - (1 - labels) * log_not_p - return loss - - -def torsion_angle_loss( - a, # [*, N, 7, 2] - a_gt, # [*, N, 7, 2] - a_alt_gt, # [*, N, 7, 2] -): - # [*, N, 7] - norm = torch.norm(a, dim=-1) - - # [*, N, 7, 2] - a = a / norm.unsqueeze(-1) - - # [*, N, 7] - diff_norm_gt = torch.norm(a - a_gt, dim=-1) - diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1) - min_diff = torch.minimum(diff_norm_gt**2, diff_norm_alt_gt**2) - - # [*] - l_torsion = torch.mean(min_diff, dim=(-1, -2)) - l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2)) - - an_weight = 0.02 - return l_torsion + an_weight * l_angle_norm - - -def compute_fape( - pred_frames: Rigid, - target_frames: Rigid, - frames_mask: torch.Tensor, - pred_positions: torch.Tensor, - target_positions: torch.Tensor, - positions_mask: torch.Tensor, - length_scale: float, - l1_clamp_distance: Optional[float] = None, - eps=1e-8, -) -> torch.Tensor: - """ - Computes FAPE loss. - - Args: - pred_frames: - [*, N_frames] Rigid object of predicted frames - target_frames: - [*, N_frames] Rigid object of ground truth frames - frames_mask: - [*, N_frames] binary mask for the frames - pred_positions: - [*, N_pts, 3] predicted atom positions - target_positions: - [*, N_pts, 3] ground truth positions - positions_mask: - [*, N_pts] positions mask - length_scale: - Length scale by which the loss is divided - l1_clamp_distance: - Cutoff above which distance errors are disregarded - eps: - Small value used to regularize denominators - Returns: - [*] loss tensor - """ - # [*, N_frames, N_pts, 3] - local_pred_pos = pred_frames.invert()[..., None].apply(pred_positions[..., None, :, :],) - local_target_pos = target_frames.invert()[..., None].apply(target_positions[..., None, :, :],) - - error_dist = torch.sqrt(torch.sum((local_pred_pos - local_target_pos)**2, dim=-1) + eps) - - if l1_clamp_distance is not None: - error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance) - - normed_error = error_dist / length_scale - normed_error = normed_error * frames_mask[..., None] - normed_error = normed_error * positions_mask[..., None, :] - - # FP16-friendly averaging. 
Roughly equivalent to: - # - # norm_factor = ( - # torch.sum(frames_mask, dim=-1) * - # torch.sum(positions_mask, dim=-1) - # ) - # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor) - # - # ("roughly" because eps is necessarily duplicated in the latter) - normed_error = torch.sum(normed_error, dim=-1) - normed_error = (normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]) - normed_error = torch.sum(normed_error, dim=-1) - normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1)) - - return normed_error - - -def backbone_loss( - backbone_rigid_tensor: torch.Tensor, - backbone_rigid_mask: torch.Tensor, - traj: torch.Tensor, - use_clamped_fape: Optional[torch.Tensor] = None, - clamp_distance: float = 10.0, - loss_unit_distance: float = 10.0, - eps: float = 1e-4, - **kwargs, -) -> torch.Tensor: - pred_aff = Rigid.from_tensor_7(traj) - pred_aff = Rigid( - Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None), - pred_aff.get_trans(), - ) - - # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of - # backbone tensor, normalizes it, and then turns it back to a rotation - # matrix. To avoid a potentially numerically unstable rotation matrix - # to quaternion conversion, we just use the original rotation matrix - # outright. This one hasn't been composed a bunch of times, though, so - # it might be fine. - gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor) - - fape_loss = compute_fape( - pred_aff, - gt_aff[None], - backbone_rigid_mask[None], - pred_aff.get_trans(), - gt_aff[None].get_trans(), - backbone_rigid_mask[None], - l1_clamp_distance=clamp_distance, - length_scale=loss_unit_distance, - eps=eps, - ) - if use_clamped_fape is not None: - unclamped_fape_loss = compute_fape( - pred_aff, - gt_aff[None], - backbone_rigid_mask[None], - pred_aff.get_trans(), - gt_aff[None].get_trans(), - backbone_rigid_mask[None], - l1_clamp_distance=None, - length_scale=loss_unit_distance, - eps=eps, - ) - - fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (1 - use_clamped_fape) - - # Average over the batch dimension - fape_loss = torch.mean(fape_loss) - - return fape_loss - - -def sidechain_loss( - sidechain_frames: torch.Tensor, - sidechain_atom_pos: torch.Tensor, - rigidgroups_gt_frames: torch.Tensor, - rigidgroups_alt_gt_frames: torch.Tensor, - rigidgroups_gt_exists: torch.Tensor, - renamed_atom14_gt_positions: torch.Tensor, - renamed_atom14_gt_exists: torch.Tensor, - alt_naming_is_better: torch.Tensor, - clamp_distance: float = 10.0, - length_scale: float = 10.0, - eps: float = 1e-4, - **kwargs, -) -> torch.Tensor: - renamed_gt_frames = (1.0 - alt_naming_is_better[..., None, None, None] - ) * rigidgroups_gt_frames + alt_naming_is_better[..., None, None, - None] * rigidgroups_alt_gt_frames - - # Steamroll the inputs - sidechain_frames = sidechain_frames[-1] - batch_dims = sidechain_frames.shape[:-4] - sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4) - sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames) - renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4) - renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames) - rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1) - sidechain_atom_pos = sidechain_atom_pos[-1] - sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3) - renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(*batch_dims, -1, 3) - renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1) - - fape = compute_fape( - 
sidechain_frames, - renamed_gt_frames, - rigidgroups_gt_exists, - sidechain_atom_pos, - renamed_atom14_gt_positions, - renamed_atom14_gt_exists, - l1_clamp_distance=clamp_distance, - length_scale=length_scale, - eps=eps, - ) - - return fape - - -def fape_loss( - out: Dict[str, torch.Tensor], - batch: Dict[str, torch.Tensor], - config: ml_collections.ConfigDict, -) -> torch.Tensor: - bb_loss = backbone_loss( - traj=out["sm"]["frames"], - **{ - **batch, - **config.backbone - }, - ) - - sc_loss = sidechain_loss( - out["sm"]["sidechain_frames"], - out["sm"]["positions"], - **{ - **batch, - **config.sidechain - }, - ) - - loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss - - # Average over the batch dimension - loss = torch.mean(loss) - - return loss - - -def supervised_chi_loss( - angles_sin_cos: torch.Tensor, - unnormalized_angles_sin_cos: torch.Tensor, - aatype: torch.Tensor, - seq_mask: torch.Tensor, - chi_mask: torch.Tensor, - chi_angles_sin_cos: torch.Tensor, - chi_weight: float, - angle_norm_weight: float, - eps=1e-6, - **kwargs, -) -> torch.Tensor: - """ - Implements Algorithm 27 (torsionAngleLoss) - - Args: - angles_sin_cos: - [*, N, 7, 2] predicted angles - unnormalized_angles_sin_cos: - The same angles, but unnormalized - aatype: - [*, N] residue indices - seq_mask: - [*, N] sequence mask - chi_mask: - [*, N, 7] angle mask - chi_angles_sin_cos: - [*, N, 7, 2] ground truth angles - chi_weight: - Weight for the angle component of the loss - angle_norm_weight: - Weight for the normalization component of the loss - Returns: - [*] loss tensor - """ - pred_angles = angles_sin_cos[..., 3:, :] - residue_type_one_hot = torch.nn.functional.one_hot( - aatype, - residue_constants.restype_num + 1, - ) - chi_pi_periodic = torch.einsum( - "...ij,jk->ik", - residue_type_one_hot.type(angles_sin_cos.dtype), - angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic), - ) - - true_chi = chi_angles_sin_cos[None] - - shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1) - true_chi_shifted = shifted_mask * true_chi - sq_chi_error = torch.sum((true_chi - pred_angles)**2, dim=-1) - sq_chi_error_shifted = torch.sum((true_chi_shifted - pred_angles)**2, dim=-1) - sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted) - # The ol' switcheroo - sq_chi_error = sq_chi_error.permute(*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1) - sq_chi_loss = masked_mean(chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)) - - loss = chi_weight * sq_chi_loss - - angle_norm = torch.sqrt(torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps) - norm_error = torch.abs(angle_norm - 1.0) - norm_error = norm_error.permute(*range(len(norm_error.shape))[1:-2], 0, -2, -1) - angle_norm_loss = masked_mean(seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)) - - loss = loss + angle_norm_weight * angle_norm_loss - - # Average over the batch dimension - loss = torch.mean(loss) - - return loss - - -def compute_plddt(logits: torch.Tensor) -> torch.Tensor: - num_bins = logits.shape[-1] - bin_width = 1.0 / num_bins - bounds = torch.arange(start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device) - probs = torch.nn.functional.softmax(logits, dim=-1) - pred_lddt_ca = torch.sum( - probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape), - dim=-1, - ) - return pred_lddt_ca * 100 - - -def lddt( - all_atom_pred_pos: torch.Tensor, - all_atom_positions: torch.Tensor, - all_atom_mask: torch.Tensor, - cutoff: float = 15.0, - eps: float = 1e-10, - per_residue: bool = True, -) -> 
torch.Tensor: - n = all_atom_mask.shape[-2] - dmat_true = torch.sqrt(eps + torch.sum( - (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])**2, - dim=-1, - )) - - dmat_pred = torch.sqrt(eps + torch.sum( - (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :])**2, - dim=-1, - )) - dists_to_score = ((dmat_true < cutoff) * all_atom_mask * permute_final_dims(all_atom_mask, (1, 0)) * - (1.0 - torch.eye(n, device=all_atom_mask.device))) - - dist_l1 = torch.abs(dmat_true - dmat_pred) - - score = ((dist_l1 < 0.5).type(dist_l1.dtype) + (dist_l1 < 1.0).type(dist_l1.dtype) + - (dist_l1 < 2.0).type(dist_l1.dtype) + (dist_l1 < 4.0).type(dist_l1.dtype)) - score = score * 0.25 - - dims = (-1,) if per_residue else (-2, -1) - norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims)) - score = norm * (eps + torch.sum(dists_to_score * score, dim=dims)) - - return score - - -def lddt_ca( - all_atom_pred_pos: torch.Tensor, - all_atom_positions: torch.Tensor, - all_atom_mask: torch.Tensor, - cutoff: float = 15.0, - eps: float = 1e-10, - per_residue: bool = True, -) -> torch.Tensor: - ca_pos = residue_constants.atom_order["CA"] - all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] - all_atom_positions = all_atom_positions[..., ca_pos, :] - all_atom_mask = all_atom_mask[..., ca_pos:(ca_pos + 1)] # keep dim - - return lddt( - all_atom_pred_pos, - all_atom_positions, - all_atom_mask, - cutoff=cutoff, - eps=eps, - per_residue=per_residue, - ) - - -def lddt_loss( - logits: torch.Tensor, - all_atom_pred_pos: torch.Tensor, - all_atom_positions: torch.Tensor, - all_atom_mask: torch.Tensor, - resolution: torch.Tensor, - cutoff: float = 15.0, - no_bins: int = 50, - min_resolution: float = 0.1, - max_resolution: float = 3.0, - eps: float = 1e-10, - **kwargs, -) -> torch.Tensor: - n = all_atom_mask.shape[-2] - - ca_pos = residue_constants.atom_order["CA"] - all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] - all_atom_positions = all_atom_positions[..., ca_pos, :] - all_atom_mask = all_atom_mask[..., ca_pos:(ca_pos + 1)] # keep dim - - score = lddt(all_atom_pred_pos, all_atom_positions, all_atom_mask, cutoff=cutoff, eps=eps) - - score = score.detach() - - bin_index = torch.floor(score * no_bins).long() - bin_index = torch.clamp(bin_index, max=(no_bins - 1)) - lddt_ca_one_hot = torch.nn.functional.one_hot(bin_index, num_classes=no_bins) - - errors = softmax_cross_entropy(logits, lddt_ca_one_hot) - all_atom_mask = all_atom_mask.squeeze(-1) - loss = torch.sum(errors * all_atom_mask, dim=-1) / (eps + torch.sum(all_atom_mask, dim=-1)) - - loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution)) - - # Average over the batch dimension - loss = torch.mean(loss) - - return loss - - -def distogram_loss( - logits, - pseudo_beta, - pseudo_beta_mask, - min_bin=2.3125, - max_bin=21.6875, - no_bins=64, - eps=1e-6, - **kwargs, -): - boundaries = torch.linspace( - min_bin, - max_bin, - no_bins - 1, - device=logits.device, - ) - boundaries = boundaries**2 - - dists = torch.sum( - (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :])**2, - dim=-1, - keepdims=True, - ) - - true_bins = torch.sum(dists > boundaries, dim=-1) - - errors = softmax_cross_entropy( - logits, - torch.nn.functional.one_hot(true_bins, no_bins), - ) - - square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :] - - # FP16-friendly sum. 
Equivalent to: - # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) / - # (eps + torch.sum(square_mask, dim=(-1, -2)))) - denom = eps + torch.sum(square_mask, dim=(-1, -2)) - mean = errors * square_mask - mean = torch.sum(mean, dim=-1) - mean = mean / denom[..., None] - mean = torch.sum(mean, dim=-1) - - # Average over the batch dimensions - mean = torch.mean(mean) - - return mean - - -def _calculate_bin_centers(boundaries: torch.Tensor): - step = boundaries[1] - boundaries[0] - bin_centers = boundaries + step / 2 - bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0) - return bin_centers - - -def _calculate_expected_aligned_error( - alignment_confidence_breaks: torch.Tensor, - aligned_distance_error_probs: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - bin_centers = _calculate_bin_centers(alignment_confidence_breaks) - return ( - torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), - bin_centers[-1], - ) - - -def compute_predicted_aligned_error( - logits: torch.Tensor, - max_bin: int = 31, - no_bins: int = 64, - **kwargs, -) -> Dict[str, torch.Tensor]: - """Computes aligned confidence metrics from logits. - - Args: - logits: [*, num_res, num_res, num_bins] the logits output from - PredictedAlignedErrorHead. - max_bin: Maximum bin value - no_bins: Number of bins - Returns: - aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted - aligned error probabilities over bins for each residue pair. - predicted_aligned_error: [*, num_res, num_res] the expected aligned distance - error for each pair of residues. - max_predicted_aligned_error: [*] the maximum predicted error possible. - """ - boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) - - aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1) - ( - predicted_aligned_error, - max_predicted_aligned_error, - ) = _calculate_expected_aligned_error( - alignment_confidence_breaks=boundaries, - aligned_distance_error_probs=aligned_confidence_probs, - ) - - return { - "aligned_confidence_probs": aligned_confidence_probs, - "predicted_aligned_error": predicted_aligned_error, - "max_predicted_aligned_error": max_predicted_aligned_error, - } - - -def compute_tm( - logits: torch.Tensor, - residue_weights: Optional[torch.Tensor] = None, - max_bin: int = 31, - no_bins: int = 64, - eps: float = 1e-8, - **kwargs, -) -> torch.Tensor: - if residue_weights is None: - residue_weights = logits.new_ones(logits.shape[-2]) - - boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) - - bin_centers = _calculate_bin_centers(boundaries) - torch.sum(residue_weights) - n = logits.shape[-2] - clipped_n = max(n, 19) - - d0 = 1.24 * (clipped_n - 15)**(1.0 / 3) - 1.8 - - probs = torch.nn.functional.softmax(logits, dim=-1) - - tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2)) - predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) - - normed_residue_mask = residue_weights / (eps + residue_weights.sum()) - per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) - weighted = per_alignment * residue_weights - argmax = (weighted == torch.max(weighted)).nonzero()[0] - return per_alignment[tuple(argmax)] - - -def tm_loss( - logits, - final_affine_tensor, - backbone_rigid_tensor, - backbone_rigid_mask, - resolution, - max_bin=31, - no_bins=64, - min_resolution: float = 0.1, - max_resolution: float = 3.0, - eps=1e-8, - **kwargs, -): - pred_affine = Rigid.from_tensor_7(final_affine_tensor) - backbone_rigid = 
Rigid.from_tensor_4x4(backbone_rigid_tensor) - - def _points(affine): - pts = affine.get_trans()[..., None, :, :] - return affine.invert()[..., None].apply(pts) - - sq_diff = torch.sum((_points(pred_affine) - _points(backbone_rigid))**2, dim=-1) - - sq_diff = sq_diff.detach() - - boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) - boundaries = boundaries**2 - true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1) - - errors = softmax_cross_entropy(logits, torch.nn.functional.one_hot(true_bins, no_bins)) - - square_mask = (backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]) - - loss = torch.sum(errors * square_mask, dim=-1) - scale = 0.5 # hack to help FP16 training along - denom = eps + torch.sum(scale * square_mask, dim=(-1, -2)) - loss = loss / denom[..., None] - loss = torch.sum(loss, dim=-1) - loss = loss * scale - - loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution)) - - # Average over the loss dimension - loss = torch.mean(loss) - - return loss - - -def between_residue_bond_loss( - pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3) - pred_atom_mask: torch.Tensor, # (*, N, 37/14) - residue_index: torch.Tensor, # (*, N) - aatype: torch.Tensor, # (*, N) - tolerance_factor_soft=12.0, - tolerance_factor_hard=12.0, - eps=1e-6, -) -> Dict[str, torch.Tensor]: - """Flat-bottom loss to penalize structural violations between residues. - - This is a loss penalizing any violation of the geometry around the peptide - bond between consecutive amino acids. This loss corresponds to - Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45. - - Args: - pred_atom_positions: Atom positions in atom37/14 representation - pred_atom_mask: Atom mask in atom37/14 representation - residue_index: Residue index for given amino acid, this is assumed to be - monotonically increasing. - aatype: Amino acid type of given residue - tolerance_factor_soft: soft tolerance factor measured in standard deviations - of pdb distributions - tolerance_factor_hard: hard tolerance factor measured in standard deviations - of pdb distributions - - Returns: - Dict containing: - * 'c_n_loss_mean': Loss for peptide bond length violations - * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned - by CA, C, N - * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned - by C, N, CA - * 'per_residue_loss_sum': sum of all losses for each residue - * 'per_residue_violation_mask': mask denoting all residues with violation - present. - """ - # Get the positions of the relevant backbone atoms. - this_ca_pos = pred_atom_positions[..., :-1, 1, :] - this_ca_mask = pred_atom_mask[..., :-1, 1] - this_c_pos = pred_atom_positions[..., :-1, 2, :] - this_c_mask = pred_atom_mask[..., :-1, 2] - next_n_pos = pred_atom_positions[..., 1:, 0, :] - next_n_mask = pred_atom_mask[..., 1:, 0] - next_ca_pos = pred_atom_positions[..., 1:, 1, :] - next_ca_mask = pred_atom_mask[..., 1:, 1] - has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 - - # Compute loss for the C--N bond. - c_n_bond_length = torch.sqrt(eps + torch.sum((this_c_pos - next_n_pos)**2, dim=-1)) - - # The C-N bond to proline has slightly different length because of the ring. 
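# (Illustrative aside on the flat-bottom penalty computed below: for a
#  measured bond length d with target gt_length and spread gt_stddev,
#      err  = sqrt(eps + (d - gt_length)**2)
#      loss = relu(err - tolerance_factor_soft * gt_stddev)
#  so any bond inside the soft tolerance band costs exactly zero, and the
#  hard violation mask flags bonds with
#  err > tolerance_factor_hard * gt_stddev.)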
- next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"] - gt_length = (~next_is_proline) * residue_constants.between_res_bond_length_c_n[ - 0] + next_is_proline * residue_constants.between_res_bond_length_c_n[1] - gt_stddev = (~next_is_proline) * residue_constants.between_res_bond_length_stddev_c_n[ - 0] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1] - c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length)**2) - c_n_loss_per_residue = torch.nn.functional.relu(c_n_bond_length_error - tolerance_factor_soft * gt_stddev) - mask = this_c_mask * next_n_mask * has_no_gap_mask - c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) - c_n_violation_mask = mask * (c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) - - # Compute loss for the angles. - ca_c_bond_length = torch.sqrt(eps + torch.sum((this_ca_pos - this_c_pos)**2, dim=-1)) - n_ca_bond_length = torch.sqrt(eps + torch.sum((next_n_pos - next_ca_pos)**2, dim=-1)) - - c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None] - c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None] - n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None] - - ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1) - gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] - gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] - ca_c_n_cos_angle_error = torch.sqrt(eps + (ca_c_n_cos_angle - gt_angle)**2) - ca_c_n_loss_per_residue = torch.nn.functional.relu(ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) - mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask - ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) - ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)) - - c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1) - gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] - gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] - c_n_ca_cos_angle_error = torch.sqrt(eps + torch.square(c_n_ca_cos_angle - gt_angle)) - c_n_ca_loss_per_residue = torch.nn.functional.relu(c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) - mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask - c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) - c_n_ca_violation_mask = mask * (c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) - - # Compute a per residue loss (equally distribute the loss to both - # neighbouring residues). - per_residue_loss_sum = (c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue) - per_residue_loss_sum = 0.5 * (torch.nn.functional.pad(per_residue_loss_sum, (0, 1)) + - torch.nn.functional.pad(per_residue_loss_sum, (1, 0))) - - # Compute hard violations. 
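# (Sketch of the padding idiom used above and below: the bonds between
#  residues i and i+1 form a length-(N-1) vector v;
#  torch.nn.functional.pad(v, (0, 1)) aligns each entry with residue i and
#  pad(v, (1, 0)) with residue i+1, so averaging the two shares each bond
#  loss between its neighbours, while taking torch.maximum of the two marks
#  both residues of a violating bond.)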
- violation_mask = torch.max( - torch.stack( - [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask], - dim=-2, - ), - dim=-2, - )[0] - violation_mask = torch.maximum( - torch.nn.functional.pad(violation_mask, (0, 1)), - torch.nn.functional.pad(violation_mask, (1, 0)), - ) - - return { - "c_n_loss_mean": c_n_loss, - "ca_c_n_loss_mean": ca_c_n_loss, - "c_n_ca_loss_mean": c_n_ca_loss, - "per_residue_loss_sum": per_residue_loss_sum, - "per_residue_violation_mask": violation_mask, - } - - -def between_residue_clash_loss( - atom14_pred_positions: torch.Tensor, - atom14_atom_exists: torch.Tensor, - atom14_atom_radius: torch.Tensor, - residue_index: torch.Tensor, - overlap_tolerance_soft=1.5, - overlap_tolerance_hard=1.5, - eps=1e-10, -) -> Dict[str, torch.Tensor]: - """Loss to penalize steric clashes between residues. - - This is a loss penalizing any steric clashes due to non bonded atoms in - different peptides coming too close. This loss corresponds to the part with - different residues of - Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. - - Args: - atom14_pred_positions: Predicted positions of atoms in - global prediction frame - atom14_atom_exists: Mask denoting whether atom at positions exists for given - amino acid type - atom14_atom_radius: Van der Waals radius for each atom. - residue_index: Residue index for given amino acid. - overlap_tolerance_soft: Soft tolerance factor. - overlap_tolerance_hard: Hard tolerance factor. - - Returns: - Dict containing: - * 'mean_loss': average clash loss - * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) - * 'per_atom_clash_mask': mask whether atom clashes with any other atom - shape (N, 14) - """ - fp_type = atom14_pred_positions.dtype - - # Create the distance matrix. - # (N, N, 14, 14) - dists = torch.sqrt(eps + torch.sum( - (atom14_pred_positions[..., :, None, :, None, :] - atom14_pred_positions[..., None, :, None, :, :])**2, - dim=-1, - )) - - # Create the mask for valid distances. - # shape (N, N, 14, 14) - dists_mask = (atom14_atom_exists[..., :, None, :, None] * atom14_atom_exists[..., None, :, None, :]).type(fp_type) - - # Mask out all the duplicate entries in the lower triangular matrix. - # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms - # are handled separately. - dists_mask = dists_mask * (residue_index[..., :, None, None, None] < residue_index[..., None, :, None, None]) - - # Backbone C--N bond between subsequent residues is no clash. - c_one_hot = torch.nn.functional.one_hot(residue_index.new_tensor(2), num_classes=14) - c_one_hot = c_one_hot.reshape(*((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape) - c_one_hot = c_one_hot.type(fp_type) - n_one_hot = torch.nn.functional.one_hot(residue_index.new_tensor(0), num_classes=14) - n_one_hot = n_one_hot.reshape(*((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape) - n_one_hot = n_one_hot.type(fp_type) - - neighbour_mask = (residue_index[..., :, None, None, None] + 1) == residue_index[..., None, :, None, None] - c_n_bonds = (neighbour_mask * c_one_hot[..., None, None, :, None] * n_one_hot[..., None, None, None, :]) - dists_mask = dists_mask * (1.0 - c_n_bonds) - - # Disulfide bridge between two cysteines is no clash. 
- cys = residue_constants.restype_name_to_atom14_names["CYS"] - cys_sg_idx = cys.index("SG") - cys_sg_idx = residue_index.new_tensor(cys_sg_idx) - cys_sg_idx = cys_sg_idx.reshape(*((1,) * len(residue_index.shape[:-1])), 1).squeeze(-1) - cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14) - disulfide_bonds = (cys_sg_one_hot[..., None, None, :, None] * cys_sg_one_hot[..., None, None, None, :]) - dists_mask = dists_mask * (1.0 - disulfide_bonds) - - # Compute the lower bound for the allowed distances. - # shape (N, N, 14, 14) - dists_lower_bound = dists_mask * (atom14_atom_radius[..., :, None, :, None] + - atom14_atom_radius[..., None, :, None, :]) - - # Compute the error. - # shape (N, N, 14, 14) - dists_to_low_error = dists_mask * torch.nn.functional.relu(dists_lower_bound - overlap_tolerance_soft - dists) - - # Compute the mean loss. - # shape () - mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask)) - - # Compute the per atom loss sum. - # shape (N, 14) - per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(dists_to_low_error, axis=(-3, -1)) - - # Compute the hard clash mask. - # shape (N, N, 14, 14) - clash_mask = dists_mask * (dists < (dists_lower_bound - overlap_tolerance_hard)) - - # Compute the per atom clash. - # shape (N, 14) - per_atom_clash_mask = torch.maximum( - torch.amax(clash_mask, axis=(-4, -2)), - torch.amax(clash_mask, axis=(-3, -1)), - ) - - return { - "mean_loss": mean_loss, # shape () - "per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14) - "per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14) - } - - -def within_residue_violations( - atom14_pred_positions: torch.Tensor, - atom14_atom_exists: torch.Tensor, - atom14_dists_lower_bound: torch.Tensor, - atom14_dists_upper_bound: torch.Tensor, - tighten_bounds_for_loss=0.0, - eps=1e-10, -) -> Dict[str, torch.Tensor]: - """Loss to penalize steric clashes within residues. - - This is a loss penalizing any steric violations or clashes of non-bonded atoms - in a given peptide. This loss corresponds to the part with - the same residues of - Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. - - Args: - atom14_pred_positions ([*, N, 14, 3]): - Predicted positions of atoms in global prediction frame. - atom14_atom_exists ([*, N, 14]): - Mask denoting whether atom at positions exists for given - amino acid type - atom14_dists_lower_bound ([*, N, 14]): - Lower bound on allowed distances. - atom14_dists_upper_bound ([*, N, 14]): - Upper bound on allowed distances - tighten_bounds_for_loss ([*, N]): - Extra factor to tighten loss - - Returns: - Dict containing: - * 'per_atom_loss_sum' ([*, N, 14]): - sum of all clash losses per atom, shape - * 'per_atom_clash_mask' ([*, N, 14]): - mask whether atom clashes with any other atom shape - """ - # Compute the mask for each residue. - dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None] - dists_masks = dists_masks.reshape(*((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape) - dists_masks = (atom14_atom_exists[..., :, :, None] * atom14_atom_exists[..., :, None, :] * dists_masks) - - # Distance matrix - dists = torch.sqrt(eps + torch.sum( - (atom14_pred_positions[..., :, :, None, :] - atom14_pred_positions[..., :, None, :, :])**2, - dim=-1, - )) - - # Compute the loss. 
- dists_to_low_error = torch.nn.functional.relu(atom14_dists_lower_bound + tighten_bounds_for_loss - dists) - dists_to_high_error = torch.nn.functional.relu(dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)) - loss = dists_masks * (dists_to_low_error + dists_to_high_error) - - # Compute the per atom loss sum. - per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1) - - # Compute the violations mask. - violations = dists_masks * ((dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)) - - # Compute the per atom violations. - per_atom_violations = torch.maximum(torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]) - - return { - "per_atom_loss_sum": per_atom_loss_sum, - "per_atom_violations": per_atom_violations, - } - - -def find_structural_violations( - batch: Dict[str, torch.Tensor], - atom14_pred_positions: torch.Tensor, - violation_tolerance_factor: float, - clash_overlap_tolerance: float, - **kwargs, -) -> Dict[str, torch.Tensor]: - """Computes several checks for structural violations.""" - - # Compute between residue backbone violations of bonds and angles. - connection_violations = between_residue_bond_loss( - pred_atom_positions=atom14_pred_positions, - pred_atom_mask=batch["atom14_atom_exists"], - residue_index=batch["residue_index"], - aatype=batch["aatype"], - tolerance_factor_soft=violation_tolerance_factor, - tolerance_factor_hard=violation_tolerance_factor, - ) - - # Compute the Van der Waals radius for every atom - # (the first letter of the atom name is the element type). - # Shape: (N, 14). - atomtype_radius = [residue_constants.van_der_waals_radius[name[0]] for name in residue_constants.atom_types] - atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius) - atom14_atom_radius = (batch["atom14_atom_exists"] * atomtype_radius[batch["residx_atom14_to_atom37"]]) - - # Compute the between residue clash loss. - between_residue_clashes = between_residue_clash_loss( - atom14_pred_positions=atom14_pred_positions, - atom14_atom_exists=batch["atom14_atom_exists"], - atom14_atom_radius=atom14_atom_radius, - residue_index=batch["residue_index"], - overlap_tolerance_soft=clash_overlap_tolerance, - overlap_tolerance_hard=clash_overlap_tolerance, - ) - - # Compute all within-residue violations (clashes, - # bond length and angle violations). - restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( - overlap_tolerance=clash_overlap_tolerance, - bond_length_tolerance_factor=violation_tolerance_factor, - ) - atom14_atom_exists = batch["atom14_atom_exists"] - atom14_dists_lower_bound = atom14_pred_positions.new_tensor(restype_atom14_bounds["lower_bound"])[batch["aatype"]] - atom14_dists_upper_bound = atom14_pred_positions.new_tensor(restype_atom14_bounds["upper_bound"])[batch["aatype"]] - residue_violations = within_residue_violations( - atom14_pred_positions=atom14_pred_positions, - atom14_atom_exists=batch["atom14_atom_exists"], - atom14_dists_lower_bound=atom14_dists_lower_bound, - atom14_dists_upper_bound=atom14_dists_upper_bound, - tighten_bounds_for_loss=0.0, - ) - - # Combine them to a single per-residue violation mask (used later for LDDT). 
- per_residue_violations_mask = torch.max( - torch.stack( - [ - connection_violations["per_residue_violation_mask"], - torch.max(between_residue_clashes["per_atom_clash_mask"], dim=-1)[0], - torch.max(residue_violations["per_atom_violations"], dim=-1)[0], - ], - dim=-1, - ), - dim=-1, - )[0] - - return { - "between_residues": { - "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # () - "angles_ca_c_n_loss_mean": connection_violations["ca_c_n_loss_mean"], # () - "angles_c_n_ca_loss_mean": connection_violations["c_n_ca_loss_mean"], # () - "connections_per_residue_loss_sum": connection_violations["per_residue_loss_sum"], # (N) - "connections_per_residue_violation_mask": connection_violations["per_residue_violation_mask"], # (N) - "clashes_mean_loss": between_residue_clashes["mean_loss"], # () - "clashes_per_atom_loss_sum": between_residue_clashes["per_atom_loss_sum"], # (N, 14) - "clashes_per_atom_clash_mask": between_residue_clashes["per_atom_clash_mask"], # (N, 14) - }, - "within_residues": { - "per_atom_loss_sum": residue_violations["per_atom_loss_sum"], # (N, 14) - "per_atom_violations": residue_violations["per_atom_violations"], # (N, 14), - }, - "total_per_residue_violations_mask": per_residue_violations_mask, # (N) - } - - -def find_structural_violations_np( - batch: Dict[str, np.ndarray], - atom14_pred_positions: np.ndarray, - config: ml_collections.ConfigDict, -) -> Dict[str, np.ndarray]: - to_tensor = lambda x: torch.tensor(x) - batch = tree_map(to_tensor, batch, np.ndarray) - atom14_pred_positions = to_tensor(atom14_pred_positions) - - out = find_structural_violations(batch, atom14_pred_positions, **config) - - to_np = lambda x: np.array(x) - np_out = tensor_tree_map(to_np, out) - - return np_out - - -def extreme_ca_ca_distance_violations( - pred_atom_positions: torch.Tensor, # (N, 37(14), 3) - pred_atom_mask: torch.Tensor, # (N, 37(14)) - residue_index: torch.Tensor, # (N) - max_angstrom_tolerance=1.5, - eps=1e-6, -) -> torch.Tensor: - """Counts residues whose Ca is a large distance from its neighbour. - - Measures the fraction of CA-CA pairs between consecutive amino acids that are - more than 'max_angstrom_tolerance' apart. - - Args: - pred_atom_positions: Atom positions in atom37/14 representation - pred_atom_mask: Atom mask in atom37/14 representation - residue_index: Residue index for given amino acid, this is assumed to be - monotonically increasing. - max_angstrom_tolerance: Maximum distance allowed to not count as violation. - Returns: - Fraction of consecutive CA-CA pairs with violation. 
- """ - this_ca_pos = pred_atom_positions[..., :-1, 1, :] - this_ca_mask = pred_atom_mask[..., :-1, 1] - next_ca_pos = pred_atom_positions[..., 1:, 1, :] - next_ca_mask = pred_atom_mask[..., 1:, 1] - has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 - ca_ca_distance = torch.sqrt(eps + torch.sum((this_ca_pos - next_ca_pos)**2, dim=-1)) - violations = (ca_ca_distance - residue_constants.ca_ca) > max_angstrom_tolerance - mask = this_ca_mask * next_ca_mask * has_no_gap_mask - mean = masked_mean(mask, violations, -1) - return mean - - -def compute_violation_metrics( - batch: Dict[str, torch.Tensor], - atom14_pred_positions: torch.Tensor, # (N, 14, 3) - violations: Dict[str, torch.Tensor], -) -> Dict[str, torch.Tensor]: - """Compute several metrics to assess the structural violations.""" - ret = {} - extreme_ca_ca_violations = extreme_ca_ca_distance_violations( - pred_atom_positions=atom14_pred_positions, - pred_atom_mask=batch["atom14_atom_exists"], - residue_index=batch["residue_index"], - ) - ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations - ret["violations_between_residue_bond"] = masked_mean( - batch["seq_mask"], - violations["between_residues"]["connections_per_residue_violation_mask"], - dim=-1, - ) - ret["violations_between_residue_clash"] = masked_mean( - mask=batch["seq_mask"], - value=torch.max( - violations["between_residues"]["clashes_per_atom_clash_mask"], - dim=-1, - )[0], - dim=-1, - ) - ret["violations_within_residue"] = masked_mean( - mask=batch["seq_mask"], - value=torch.max(violations["within_residues"]["per_atom_violations"], dim=-1)[0], - dim=-1, - ) - ret["violations_per_residue"] = masked_mean( - mask=batch["seq_mask"], - value=violations["total_per_residue_violations_mask"], - dim=-1, - ) - return ret - - -def compute_violation_metrics_np( - batch: Dict[str, np.ndarray], - atom14_pred_positions: np.ndarray, - violations: Dict[str, np.ndarray], -) -> Dict[str, np.ndarray]: - to_tensor = lambda x: torch.tensor(x) - batch = tree_map(to_tensor, batch, np.ndarray) - atom14_pred_positions = to_tensor(atom14_pred_positions) - violations = tree_map(to_tensor, violations, np.ndarray) - - out = compute_violation_metrics(batch, atom14_pred_positions, violations) - - to_np = lambda x: np.array(x) - return tree_map(to_np, out, torch.Tensor) - - -def violation_loss( - violations: Dict[str, torch.Tensor], - atom14_atom_exists: torch.Tensor, - eps=1e-6, - **kwargs, -) -> torch.Tensor: - num_atoms = torch.sum(atom14_atom_exists) - l_clash = torch.sum(violations["between_residues"]["clashes_per_atom_loss_sum"] + - violations["within_residues"]["per_atom_loss_sum"]) - l_clash = l_clash / (eps + num_atoms) - loss = (violations["between_residues"]["bonds_c_n_loss_mean"] + - violations["between_residues"]["angles_ca_c_n_loss_mean"] + - violations["between_residues"]["angles_c_n_ca_loss_mean"] + l_clash) - - return loss - - -def compute_renamed_ground_truth( - batch: Dict[str, torch.Tensor], - atom14_pred_positions: torch.Tensor, - eps=1e-10, -) -> Dict[str, torch.Tensor]: - """ - Find optimal renaming of ground truth based on the predicted positions. - - Alg. 26 "renameSymmetricGroundTruthAtoms" - - This renamed ground truth is then used for all losses, - such that each loss moves the atoms in the same direction. - - Args: - batch: Dictionary containing: - * atom14_gt_positions: Ground truth positions. - * atom14_alt_gt_positions: Ground truth positions with renaming swaps. 
- * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by - renaming swaps. - * atom14_gt_exists: Mask for which atoms exist in ground truth. - * atom14_alt_gt_exists: Mask for which atoms exist in ground truth - after renaming. - * atom14_atom_exists: Mask for whether each atom is part of the given - amino acid type. - atom14_pred_positions: Array of atom positions in global frame with shape - Returns: - Dictionary containing: - alt_naming_is_better: Array with 1.0 where alternative swap is better. - renamed_atom14_gt_positions: Array of optimal ground truth positions - after renaming swaps are performed. - renamed_atom14_gt_exists: Mask after renaming swap is performed. - """ - - pred_dists = torch.sqrt(eps + torch.sum( - (atom14_pred_positions[..., None, :, None, :] - atom14_pred_positions[..., None, :, None, :, :])**2, - dim=-1, - )) - - atom14_gt_positions = batch["atom14_gt_positions"] - gt_dists = torch.sqrt(eps + torch.sum( - (atom14_gt_positions[..., None, :, None, :] - atom14_gt_positions[..., None, :, None, :, :])**2, - dim=-1, - )) - - atom14_alt_gt_positions = batch["atom14_alt_gt_positions"] - alt_gt_dists = torch.sqrt(eps + torch.sum( - (atom14_alt_gt_positions[..., None, :, None, :] - atom14_alt_gt_positions[..., None, :, None, :, :])**2, - dim=-1, - )) - - lddt = torch.sqrt(eps + (pred_dists - gt_dists)**2) - alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists)**2) - - atom14_gt_exists = batch["atom14_gt_exists"] - atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"] - mask = (atom14_gt_exists[..., None, :, None] * atom14_atom_is_ambiguous[..., None, :, None] * - atom14_gt_exists[..., None, :, None, :] * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])) - - per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3)) - alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3)) - - fp_type = atom14_pred_positions.dtype - alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type) - - renamed_atom14_gt_positions = (1.0 - alt_naming_is_better[ - ..., None, None]) * atom14_gt_positions + alt_naming_is_better[..., None, None] * atom14_alt_gt_positions - - renamed_atom14_gt_mask = (1.0 - alt_naming_is_better[..., None] - ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"] - - return { - "alt_naming_is_better": alt_naming_is_better, - "renamed_atom14_gt_positions": renamed_atom14_gt_positions, - "renamed_atom14_gt_exists": renamed_atom14_gt_mask, - } - - -def experimentally_resolved_loss( - logits: torch.Tensor, - atom37_atom_exists: torch.Tensor, - all_atom_mask: torch.Tensor, - resolution: torch.Tensor, - min_resolution: float, - max_resolution: float, - eps: float = 1e-8, - **kwargs, -) -> torch.Tensor: - errors = sigmoid_cross_entropy(logits, all_atom_mask) - loss = torch.sum(errors * atom37_atom_exists, dim=-1) - loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2))) - loss = torch.sum(loss, dim=-1) - - loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution)) - - loss = torch.mean(loss) - - return loss - - -def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs): - """ - Computes BERT-style masked MSA loss. Implements subsection 1.9.9. - - Args: - logits: [*, N_seq, N_res, 23] predicted residue distribution - true_msa: [*, N_seq, N_res] true MSA - bert_mask: [*, N_seq, N_res] MSA mask - Returns: - Masked MSA loss - """ - errors = softmax_cross_entropy(logits, torch.nn.functional.one_hot(true_msa, num_classes=23)) - - # FP16-friendly averaging. 
Equivalent to: - # loss = ( - # torch.sum(errors * bert_mask, dim=(-1, -2)) / - # (eps + torch.sum(bert_mask, dim=(-1, -2))) - # ) - loss = errors * bert_mask - loss = torch.sum(loss, dim=-1) - scale = 0.5 - denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2)) - loss = loss / denom[..., None] - loss = torch.sum(loss, dim=-1) - loss = loss * scale - - loss = torch.mean(loss) - - return loss - - -def compute_drmsd(structure_1, structure_2, mask=None): - if (mask is not None): - structure_1 = structure_1 * mask[..., None] - structure_2 = structure_2 * mask[..., None] - - d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :] - d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :] - - d1 = d1**2 - d2 = d2**2 - - d1 = torch.sqrt(torch.sum(d1, dim=-1)) - d2 = torch.sqrt(torch.sum(d2, dim=-1)) - - drmsd = d1 - d2 - drmsd = drmsd**2 - drmsd = torch.sum(drmsd, dim=(-1, -2)) - n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1) - drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.) - drmsd = torch.sqrt(drmsd) - - return drmsd - - -def compute_drmsd_np(structure_1, structure_2, mask=None): - structure_1 = torch.tensor(structure_1) - structure_2 = torch.tensor(structure_2) - if (mask is not None): - mask = torch.tensor(mask) - - return compute_drmsd(structure_1, structure_2, mask) - - -class AlphaFoldLoss(nn.Module): - """Aggregation of the various losses described in the supplement""" - - def __init__(self, config): - super(AlphaFoldLoss, self).__init__() - self.config = config - - def forward(self, out, batch, _return_breakdown=False): - if "violation" not in out.keys(): - out["violation"] = find_structural_violations( - batch, - out["sm"]["positions"][-1], - **self.config.violation, - ) - - if "renamed_atom14_gt_positions" not in out.keys(): - batch.update(compute_renamed_ground_truth( - batch, - out["sm"]["positions"][-1], - )) - - loss_fns = { - "distogram": - lambda: distogram_loss( - logits=out["distogram_logits"], - **{ - **batch, - **self.config.distogram - }, - ), - "experimentally_resolved": - lambda: experimentally_resolved_loss( - logits=out["experimentally_resolved_logits"], - **{ - **batch, - **self.config.experimentally_resolved - }, - ), - "fape": - lambda: fape_loss( - out, - batch, - self.config.fape, - ), - "lddt": - lambda: lddt_loss( - logits=out["lddt_logits"], - all_atom_pred_pos=out["final_atom_positions"], - **{ - **batch, - **self.config.lddt - }, - ), - "masked_msa": - lambda: masked_msa_loss( - logits=out["masked_msa_logits"], - **{ - **batch, - **self.config.masked_msa - }, - ), - "supervised_chi": - lambda: supervised_chi_loss( - out["sm"]["angles"], - out["sm"]["unnormalized_angles"], - **{ - **batch, - **self.config.supervised_chi - }, - ), - "violation": - lambda: violation_loss( - out["violation"], - **batch, - ), - } - - if (self.config.tm.enabled): - loss_fns["tm"] = lambda: tm_loss( - logits=out["tm_logits"], - **{ - **batch, - **out, - **self.config.tm - }, - ) - - cum_loss = 0. - losses = {} - for loss_name, loss_fn in loss_fns.items(): - weight = self.config[loss_name].weight - loss = loss_fn() - if (torch.isnan(loss) or torch.isinf(loss)): - logging.warning(f"{loss_name} loss is NaN. Skipping...") - loss = loss.new_tensor(0., requires_grad=True) - cum_loss = cum_loss + weight * loss - losses[loss_name] = loss.detach().clone() - - losses["unscaled_loss"] = cum_loss.detach().clone() - - # Scale the loss by the square root of the minimum of the crop size and - # the (average) sequence length. See subsection 1.9. 
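# (Worked example: with an average true sequence length of 400 residues and
#  a training crop of 256, the factor is sqrt(min(400, 256)) = 16; for a
#  100-residue chain under the same crop it is sqrt(100) = 10, so the scale
#  grows with crop size but is capped by the real sequence length.)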
- seq_len = torch.mean(batch["seq_length"].float()) - crop_len = batch["aatype"].shape[-1] - cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len)) - - losses["loss"] = cum_loss.detach().clone() - - if (not _return_breakdown): - return cum_loss - - return cum_loss, losses diff --git a/tests/test_autochunk/origin_openfold/utils/tensor_utils.py b/tests/test_autochunk/origin_openfold/utils/tensor_utils.py deleted file mode 100644 index 5d5b3c32b5c6..000000000000 --- a/tests/test_autochunk/origin_openfold/utils/tensor_utils.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple - -import torch -import torch.nn as nn - - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - - -def flatten_final_dims(t: torch.Tensor, no_dims: int): - return t.reshape(t.shape[:-no_dims] + (-1,)) - - -def masked_mean(mask, value, dim, eps=1e-4): - mask = mask.expand(*value.shape) - return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) - - -def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): - boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device) - dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3))**2, dim=-1)) - return torch.bucketize(dists, boundaries) - - -def dict_multimap(fn, dicts): - first = dicts[0] - new_dict = {} - for k, v in first.items(): - all_v = [d[k] for d in dicts] - if type(v) is dict: - new_dict[k] = dict_multimap(fn, all_v) - else: - new_dict[k] = fn(all_v) - - return new_dict - - -def one_hot(x, v_bins): - reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) - diffs = x[..., None] - reshaped_bins - am = torch.argmin(torch.abs(diffs), dim=-1) - return nn.functional.one_hot(am, num_classes=len(v_bins)).float() - - -def batched_gather(data, inds, dim=0, no_batch_dims=0): - ranges = [] - for i, s in enumerate(data.shape[:no_batch_dims]): - r = torch.arange(s) - r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) - ranges.append(r) - - remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] - remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds - ranges.extend(remaining_dims) - return data[ranges] - - -# With tree_map, a poor man's JAX tree_map -def dict_map(fn, dic, leaf_type): - new_dict = {} - for k, v in dic.items(): - if type(v) is dict: - new_dict[k] = dict_map(fn, v, leaf_type) - else: - new_dict[k] = tree_map(fn, v, leaf_type) - - return new_dict - - -def tree_map(fn, tree, leaf_type): - if isinstance(tree, dict): - return dict_map(fn, tree, leaf_type) - elif isinstance(tree, list): - return [tree_map(fn, x, leaf_type) for x in tree] - elif isinstance(tree, 
tuple): - return tuple([tree_map(fn, x, leaf_type) for x in tree]) - elif isinstance(tree, leaf_type): - return fn(tree) - else: - print(type(tree)) - raise ValueError("Not supported") - - -tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) - - -def _fetch_dims(tree): - shapes = [] - tree_type = type(tree) - if tree_type is dict: - for v in tree.values(): - shapes.extend(_fetch_dims(v)) - elif tree_type is list or tree_type is tuple: - for t in tree: - shapes.extend(_fetch_dims(t)) - elif tree_type is torch.Tensor: - shapes.append(tree.shape) - else: - raise ValueError("Not supported") - - return shapes - - -@torch.jit.ignore -def _flat_idx_to_idx( - flat_idx: int, - dims: Tuple[int], -) -> Tuple[int]: - idx = [] - for d in reversed(dims): - idx.append(flat_idx % d) - flat_idx = flat_idx // d - - return tuple(reversed(idx)) - - -@torch.jit.ignore -def _get_minimal_slice_set( - start: Sequence[int], - end: Sequence[int], - dims: int, - start_edges: Optional[Sequence[bool]] = None, - end_edges: Optional[Sequence[bool]] = None, -) -> Sequence[Tuple[int]]: - """ - Produces an ordered sequence of tensor slices that, when used in - sequence on a tensor with shape dims, yields tensors that contain every - leaf in the contiguous range [start, end]. Care is taken to yield a - short sequence of slices, and perhaps even the shortest possible (I'm - pretty sure it's the latter). - - end is INCLUSIVE. - """ - - # start_edges and end_edges both indicate whether, starting from any given - # dimension, the start/end index is at the top/bottom edge of the - # corresponding tensor, modeled as a tree - def reduce_edge_list(ll): - tally = 1 - for i in range(len(ll)): - reversed_idx = -1 * (i + 1) - ll[reversed_idx] *= tally - tally = ll[reversed_idx] - - if (start_edges is None): - start_edges = [s == 0 for s in start] - reduce_edge_list(start_edges) - if (end_edges is None): - end_edges = [e == (d - 1) for e, d in zip(end, dims)] - reduce_edge_list(end_edges) - - # Base cases. 
Either start/end are empty and we're done, or the final, - # one-dimensional tensor can be simply sliced - if (len(start) == 0): - return [tuple()] - elif (len(start) == 1): - return [(slice(start[0], end[0] + 1),)] - - slices = [] - path = [] - - # Dimensions common to start and end can be selected directly - for s, e in zip(start, end): - if (s == e): - path.append(slice(s, s + 1)) - else: - break - - path = tuple(path) - divergence_idx = len(path) - - # start == end, and we're done - if (divergence_idx == len(dims)): - return [tuple(path)] - - def upper(): - sdi = start[divergence_idx] - return [ - path + (slice(sdi, sdi + 1),) + s - for s in _get_minimal_slice_set(start[divergence_idx + 1:], [d - 1 for d in dims[divergence_idx + 1:]], - dims[divergence_idx + 1:], - start_edges=start_edges[divergence_idx + 1:], - end_edges=[1 for _ in end_edges[divergence_idx + 1:]]) - ] - - def lower(): - edi = end[divergence_idx] - return [ - path + (slice(edi, edi + 1),) + s for s in _get_minimal_slice_set( - [0 for _ in start[divergence_idx + 1:]], - end[divergence_idx + 1:], - dims[divergence_idx + 1:], - start_edges=[1 for _ in start_edges[divergence_idx + 1:]], - end_edges=end_edges[divergence_idx + 1:], - ) - ] - - # If both start and end are at the edges of the subtree rooted at - # divergence_idx, we can just select the whole subtree at once - if (start_edges[divergence_idx] and end_edges[divergence_idx]): - slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),)) - # If just start is at the edge, we can grab almost all of the subtree, - # treating only the ragged bottom edge as an edge case - elif (start_edges[divergence_idx]): - slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),)) - slices.extend(lower()) - # Analogous to the previous case, but the top is ragged this time - elif (end_edges[divergence_idx]): - slices.extend(upper()) - slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)) - # If both sides of the range are ragged, we need to handle both sides - # separately. If there's contiguous meat in between them, we can index it - # in one big chunk - else: - slices.extend(upper()) - middle_ground = end[divergence_idx] - start[divergence_idx] - if (middle_ground > 1): - slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)) - slices.extend(lower()) - - return [tuple(s) for s in slices] - - -@torch.jit.ignore -def _chunk_slice( - t: torch.Tensor, - flat_start: int, - flat_end: int, - no_batch_dims: int, -) -> torch.Tensor: - """ - Equivalent to - - t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] - - but without the need for the initial reshape call, which can be - memory-intensive in certain situations. The only reshape operations - in this function are performed on sub-tensors that scale with - (flat_end - flat_start), the chunk size. 
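    A worked example (shapes chosen here purely for illustration): for
    t of shape (4, 5, 7) and no_batch_dims=2, the flattened view has
    20 rows of shape (7,), and

        _chunk_slice(t, flat_start=6, flat_end=11, no_batch_dims=2)

    returns the same five rows as t.reshape(20, 7)[6:11], assembled
    from a few narrow slices of t rather than a full reshape.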
- """ - - batch_dims = t.shape[:no_batch_dims] - start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) - # _get_minimal_slice_set is inclusive - end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) - - # Get an ordered list of slices to perform - slices = _get_minimal_slice_set( - start_idx, - end_idx, - batch_dims, - ) - - sliced_tensors = [t[s] for s in slices] - - return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]) - - -def chunk_layer( - layer: Callable, - inputs: Dict[str, Any], - chunk_size: int, - no_batch_dims: int, - low_mem: bool = False, -) -> Any: - """ - Implements the "chunking" procedure described in section 1.11.8. - - Layer outputs and inputs are assumed to be simple "pytrees," - consisting only of (arbitrarily nested) lists, tuples, and dicts with - torch.Tensor leaves. - - Args: - layer: - The layer to be applied chunk-wise - inputs: - A (non-nested) dictionary of keyworded inputs. All leaves must - be tensors and must share the same batch dimensions. - chunk_size: - The number of sub-batches per chunk. If multiple batch - dimensions are specified, a "sub-batch" is defined as a single - indexing of all batch dimensions simultaneously (s.t. the - number of sub-batches is the product of the batch dimensions). - no_batch_dims: - How many of the initial dimensions of each input tensor can - be considered batch dimensions. - low_mem: - Avoids flattening potentially large input tensors. Unnecessary - in most cases, and is ever so slightly slower than the default - setting. - Returns: - The reassembled output of the layer on the inputs. - """ - if not (len(inputs) > 0): - raise ValueError("Must provide at least one input") - - initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] - orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) - - def _prep_inputs(t): - # TODO: make this more memory efficient. 
This sucks - if (not low_mem): - if not sum(t.shape[:no_batch_dims]) == no_batch_dims: - t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) - t = t.reshape(-1, *t.shape[no_batch_dims:]) - else: - t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) - return t - - prepped_inputs = tensor_tree_map(_prep_inputs, inputs) - - flat_batch_dim = 1 - for d in orig_batch_dims: - flat_batch_dim *= d - - no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0) - - i = 0 - out = None - for _ in range(no_chunks): - # Chunk the input - if (not low_mem): - select_chunk = (lambda t: t[i:i + chunk_size] if t.shape[0] != 1 else t) - else: - select_chunk = (partial(_chunk_slice, - flat_start=i, - flat_end=min(flat_batch_dim, i + chunk_size), - no_batch_dims=len(orig_batch_dims))) - - chunks = tensor_tree_map(select_chunk, prepped_inputs) - - # Run the layer on the chunk - output_chunk = layer(**chunks) - - # Allocate space for the output - if out is None: - allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]) - out = tensor_tree_map(allocate, output_chunk) - - # Put the chunk in its pre-allocated space - out_type = type(output_chunk) - if out_type is dict: - - def assign(d1, d2): - for k, v in d1.items(): - if type(v) is dict: - assign(v, d2[k]) - else: - v[i:i + chunk_size] = d2[k] - - assign(out, output_chunk) - elif out_type is tuple: - for x1, x2 in zip(out, output_chunk): - x1[i:i + chunk_size] = x2 - elif out_type is torch.Tensor: - out[i:i + chunk_size] = output_chunk - else: - raise ValueError("Not supported") - - i += chunk_size - - reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) - out = tensor_tree_map(reshape, out) - - return out diff --git a/tests/test_autochunk/test_autochunk_openfold_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py similarity index 97% rename from tests/test_autochunk/test_autochunk_openfold_codegen.py rename to tests/test_autochunk/test_evoformer_codegen.py index 11bb6d79080a..b1cd127cbd83 100644 --- a/tests/test_autochunk/test_autochunk_openfold_codegen.py +++ b/tests/test_autochunk/test_evoformer_codegen.py @@ -4,6 +4,7 @@ import torch import torch.fx import torch.multiprocessing as mp +from fastfold.model.nn.evoformer import EvoformerBlock import colossalai from colossalai.core import global_context as gpc @@ -13,7 +14,6 @@ from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace from colossalai.utils import free_port -from tests.test_autochunk.origin_openfold.evoformer import EvoformerBlock if CODEGEN_AVAILABLE and is_compatible_with_meta(): from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen @@ -108,7 +108,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), ) - codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory) + # codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory) # trace and recompile # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer From 45d25e466854de2edddd0a9f5303e2acb4ba93cd Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 17:54:11 +0800 Subject: [PATCH 12/14] rename tests --- tests/test_autochunk/test_evoformer_codegen.py | 8 ++++---- tests/test_autochunk/test_simple_evoformer_codegen.py | 8 ++++---- tests/test_autochunk/test_simple_evoformer_search.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git 
a/tests/test_autochunk/test_evoformer_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py index b1cd127cbd83..3579fba467ef 100644 --- a/tests/test_autochunk/test_evoformer_codegen.py +++ b/tests/test_autochunk/test_evoformer_codegen.py @@ -69,7 +69,7 @@ def _build_openfold(): return model -def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): +def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): # launch colossalai colossalai.launch( config={}, @@ -145,9 +145,9 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) -def test_autochunk_codegen(msa_len, pair_len, max_memory): +def test_evoformer_codegen(msa_len, pair_len, max_memory): run_func = partial( - _test_autochunk_codegen, + _test_evoformer_codegen, msa_len=msa_len, pair_len=pair_len, max_memory=max_memory, @@ -156,4 +156,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_autochunk_codegen(0, 32, 64, 25) + _test_evoformer_codegen(0, 32, 64, 25) diff --git a/tests/test_autochunk/test_simple_evoformer_codegen.py b/tests/test_autochunk/test_simple_evoformer_codegen.py index 02c38505a0ac..f1272330fcd9 100644 --- a/tests/test_autochunk/test_simple_evoformer_codegen.py +++ b/tests/test_autochunk/test_simple_evoformer_codegen.py @@ -53,7 +53,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): torch.abs(non_fx_out[1] - fx_out[1])) -def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): +def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory): # launch colossalai colossalai.launch( config={}, @@ -105,9 +105,9 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) -def test_autochunk_codegen(msa_len, pair_len, max_memory): +def test_simple_evoformer_codegen(msa_len, pair_len, max_memory): run_func = partial( - _test_autochunk_codegen, + _test_simple_evoformer_codegen, msa_len=msa_len, pair_len=pair_len, max_memory=max_memory, @@ -116,4 +116,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_autochunk_codegen(0, 32, 64, 25) + _test_simple_evoformer_codegen(0, 32, 64, 25) diff --git a/tests/test_autochunk/test_simple_evoformer_search.py b/tests/test_autochunk/test_simple_evoformer_search.py index e295558c88a0..04fb514fbf44 100644 --- a/tests/test_autochunk/test_simple_evoformer_search.py +++ b/tests/test_autochunk/test_simple_evoformer_search.py @@ -62,7 +62,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): ) -def _test_autochunk_search(rank, msa_len, pair_len, max_memory): +def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory): # launch colossalai colossalai.launch( config={}, @@ -94,9 +94,9 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory): @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) -def test_autochunk_search(msa_len, pair_len, max_memory): +def test_simple_evoformer_search(msa_len, pair_len, max_memory): run_func = partial( - _test_autochunk_search, + _test_simple_evoformer_search, msa_len=msa_len, pair_len=pair_len, max_memory=max_memory, @@ -105,4 +105,4 @@ def test_autochunk_search(msa_len, 
pair_len, max_memory): if __name__ == "__main__": - _test_autochunk_search(0, 32, 64, 20) + _test_simple_evoformer_search(0, 32, 64, 20) From 83b87e67b71b215d552a3724bb96cbc4836f44a5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 17:59:33 +0800 Subject: [PATCH 13/14] add test for evoformer --- tests/test_autochunk/test_evoformer_codegen.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_autochunk/test_evoformer_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py index 3579fba467ef..75b1ec55ba71 100644 --- a/tests/test_autochunk/test_evoformer_codegen.py +++ b/tests/test_autochunk/test_evoformer_codegen.py @@ -4,7 +4,12 @@ import torch import torch.fx import torch.multiprocessing as mp -from fastfold.model.nn.evoformer import EvoformerBlock + +try: + from fastfold.model.nn.evoformer import EvoformerBlock + HAS_REPO = True +except: + HAS_REPO = False import colossalai from colossalai.core import global_context as gpc @@ -139,7 +144,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): @pytest.mark.skipif( - not (CODEGEN_AVAILABLE and is_compatible_with_meta()), + not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), reason="torch version is lower than 1.12.0", ) @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) From 4e91ac93913a7abc69767d5e982972e6051f253b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 16 Jan 2023 18:48:59 +0800 Subject: [PATCH 14/14] optimize import --- tests/test_autochunk/test_evoformer_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autochunk/test_evoformer_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py index 75b1ec55ba71..1273bf2fecbf 100644 --- a/tests/test_autochunk/test_evoformer_codegen.py +++ b/tests/test_autochunk/test_evoformer_codegen.py @@ -17,12 +17,12 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace from colossalai.utils import free_port if CODEGEN_AVAILABLE and is_compatible_with_meta(): from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
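For quick iteration outside the pytest matrix, each renamed test keeps its
__main__ entry point, so the single-process path can be driven directly. A
minimal session (a sketch, assuming the repo root is on PYTHONPATH, a single
CUDA device, and the guarded dependencies above are available — torch >= 1.12
for meta-tensor support, plus fastfold only for the full evoformer variant):

    # Direct invocation with the same (rank, msa_len, pair_len, max_memory)
    # values the tests' own __main__ blocks use; max_memory is the autochunk
    # budget passed through to AutoChunkCodeGen.
    from tests.test_autochunk.test_simple_evoformer_codegen import _test_simple_evoformer_codegen
    from tests.test_autochunk.test_simple_evoformer_search import _test_simple_evoformer_search

    _test_simple_evoformer_codegen(0, 32, 64, 25)
    _test_simple_evoformer_search(0, 32, 64, 20)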