Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


class RepFlowArgs:
def __init__(
self,
n_dim: int = 128,
e_dim: int = 64,
a_dim: int = 64,
nlayers: int = 6,
e_rcut: float = 6.0,
e_rcut_smth: float = 5.0,
e_sel: int = 120,
a_rcut: float = 4.0,
a_rcut_smth: float = 3.5,
a_sel: int = 20,
a_compress_rate: int = 0,
axis_neuron: int = 4,
update_angle: bool = True,
update_style: str = "res_residual",
update_residual: float = 0.1,
update_residual_init: str = "const",
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.

Parameters
----------
n_dim : int, optional
The dimension of node representation.
e_dim : int, optional
The dimension of edge representation.
a_dim : int, optional
The dimension of angle representation.
nlayers : int, optional
Number of repflow layers.
e_rcut : float, optional
The edge cut-off radius.
e_rcut_smth : float, optional
Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth.
e_sel : int, optional
Maximally possible number of selected edge neighbors.
a_rcut : float, optional
The angle cut-off radius.
a_rcut_smth : float, optional
Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
a_sel : int, optional
Maximally possible number of selected angle neighbors.
a_compress_rate : int, optional
The compression rate for angular messages. The default value is 0, indicating no compression.
If a non-zero integer c is provided, the node and edge dimensions will be compressed
to n_dim/c and e_dim/2c, respectively, within the angular message.
axis_neuron : int, optional
The number of dimension of submatrix in the symmetrization ops.
update_angle : bool, optional
Where to update the angle rep. If not, only node and edge rep will be used.
update_style : str, optional
Style to update a representation.
Supported options are:
-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n)
-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)
-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n)
where `r1`, `r2` ... `r3` are residual weights defined by `update_residual`
and `update_residual_init`.
update_residual : float, optional
When update using residual mode, the initial std of residual vector weights.
update_residual_init : str, optional
When update using residual mode, the initialization mode of residual vector weights.
"""
self.n_dim = n_dim
self.e_dim = e_dim
self.a_dim = a_dim
self.nlayers = nlayers
self.e_rcut = e_rcut
self.e_rcut_smth = e_rcut_smth
self.e_sel = e_sel
self.a_rcut = a_rcut
self.a_rcut_smth = a_rcut_smth
self.a_sel = a_sel
self.a_compress_rate = a_compress_rate
self.axis_neuron = axis_neuron
self.update_angle = update_angle
self.update_style = update_style
self.update_residual = update_residual
self.update_residual_init = update_residual_init

def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
else:
raise KeyError(key)

def serialize(self) -> dict:
return {
"n_dim": self.n_dim,
"e_dim": self.e_dim,
"a_dim": self.a_dim,
"nlayers": self.nlayers,
"e_rcut": self.e_rcut,
"e_rcut_smth": self.e_rcut_smth,
"e_sel": self.e_sel,
"a_rcut": self.a_rcut,
"a_rcut_smth": self.a_rcut_smth,
"a_sel": self.a_sel,
"a_compress_rate": self.a_compress_rate,
"axis_neuron": self.axis_neuron,
"update_angle": self.update_angle,
"update_style": self.update_style,
"update_residual": self.update_residual,
"update_residual_init": self.update_residual_init,
}

@classmethod
def deserialize(cls, data: dict) -> "RepFlowArgs":
return cls(**data)
71 changes: 40 additions & 31 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,28 +187,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
)
# more_loss['log_keys'].append('rmse_e')
else: # use l1 and for all atoms
energy_pred = energy_pred * atom_norm
energy_label = energy_label * atom_norm
l1_ener_loss = F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="sum",
reduction="mean",
)
loss += pref_e * l1_ener_loss
more_loss["mae_e"] = self.display_if_exist(
F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="mean",
).detach(),
l1_ener_loss.detach(),
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
if mae:
mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
more_loss["mae_e_all"] = self.display_if_exist(
mae_e_all.detach(), find_energy
)
# if mae:
# mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
# more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
# mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
# more_loss["mae_e_all"] = self.display_if_exist(
# mae_e_all.detach(), find_energy
Comment on lines +203 to +208

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
# )

if (
(self.has_f or self.has_pf or self.relative_f or self.has_gf)
Expand Down Expand Up @@ -241,17 +239,17 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
rmse_f.detach(), find_force
)
else:
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="mean")
more_loss["mae_f"] = self.display_if_exist(
l1_force_loss.mean().detach(), find_force
l1_force_loss.detach(), find_force
)
l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
# l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if mae:
mae_f = torch.mean(torch.abs(diff_f))
more_loss["mae_f"] = self.display_if_exist(
mae_f.detach(), find_force
)
# if mae:
# mae_f = torch.mean(torch.abs(diff_f))
# more_loss["mae_f"] = self.display_if_exist(
# mae_f.detach(), find_force
Comment on lines +248 to +251

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
# )

if self.has_pf and "atom_pref" in label:
atom_pref = label["atom_pref"]
Expand Down Expand Up @@ -297,18 +295,29 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
if self.has_v and "virial" in model_pred and "virial" in label:
find_virial = label.get("find_virial", 0.0)
pref_v = pref_v * find_virial
virial_label = label["virial"]
virial_pred = model_pred["virial"].reshape(-1, 9)
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
if not self.use_l1_all:
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(
rmse_v.detach(), find_virial
)
else:
l1_virial_loss = F.l1_loss(virial_label, virial_pred, reduction="mean")
more_loss["mae_v"] = self.display_if_exist(
l1_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
loss += (pref_v * l1_virial_loss).to(GLOBAL_PT_FLOAT_PRECISION)
# if mae:
# mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
# more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
Comment on lines +318 to +320

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.

if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
atom_ener = model_pred["atom_energy"]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from .dpa2 import (
DescrptDPA2,
)
from .dpa3 import (
DescrptDPA3,
)
from .env_mat import (
prod_env_mat,
)
Expand Down Expand Up @@ -49,6 +52,7 @@
"DescrptBlockSeTTebd",
"DescrptDPA1",
"DescrptDPA2",
"DescrptDPA3",
"DescrptHybrid",
"DescrptSeA",
"DescrptSeAttenV2",
Expand Down
Loading