Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ def compute_output_stats(

model_pred_g = (
{
kk: [vv[idx] for idx in global_sampled_idx[kk]]
kk: [
np.sum(vv[idx], axis=1) for idx in global_sampled_idx[kk]
] # sum atomic dim
for kk, vv in model_pred.items()
}
if model_pred
Expand All @@ -328,7 +330,7 @@ def compute_output_stats(
else None
)

# concat all frames within those systmes
# concat all frames within those systems
model_pred_g = (
{
kk: np.concatenate(model_pred_g[kk])
Expand Down Expand Up @@ -460,7 +462,6 @@ def compute_output_stats_global(
else:
# subtract the model bias and output the delta bias

model_pred = {kk: np.sum(model_pred[kk], axis=1) for kk in keys}
stats_input = {
kk: merged_output[kk] - model_pred[kk] for kk in keys if kk in merged_output
}
Expand Down
31 changes: 31 additions & 0 deletions source/tests/pt/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,22 @@ def test_finetune_change_out_bias(self):
data.dataloaders,
nbatches=1,
)
# make sampled of multiple frames with different atom numbs
numb_atom = sampled[0]["atype"].shape[1]
small_numb_atom = numb_atom // 2
small_atom_data = deepcopy(sampled[0])
atomic_key = ["coord", "atype"]
for kk in atomic_key:
small_atom_data[kk] = small_atom_data[kk][:, :small_numb_atom]
scale_pref = float(small_numb_atom / numb_atom)
small_atom_data[self.testkey] *= scale_pref
small_atom_data["natoms"][:, :2] = small_numb_atom
small_atom_data["natoms"][:, 2:] = torch.bincount(
small_atom_data["atype"][0],
minlength=small_atom_data["natoms"].shape[1] - 2,
)
sampled = [sampled[0], small_atom_data]

# get model
model = get_model(self.config["model"]).to(env.DEVICE)
atomic_model = model.atomic_model
Expand Down Expand Up @@ -144,12 +160,27 @@ def test_finetune_change_out_bias(self):
np.bincount(to_numpy_array(sampled[0]["atype"][0]))[idx_type_map],
(ntest, 1),
)
atom_nums_small = np.tile(
np.bincount(to_numpy_array(sampled[1]["atype"][0]))[idx_type_map],
(ntest, 1),
)
atom_nums = np.concatenate([atom_nums, atom_nums_small], axis=0)

energy = dp.eval(
to_numpy_array(sampled[0]["coord"][:ntest]),
to_numpy_array(sampled[0]["box"][:ntest]),
to_numpy_array(sampled[0]["atype"][0]),
)[0]
energy_small = dp.eval(
to_numpy_array(sampled[1]["coord"][:ntest]),
to_numpy_array(sampled[1]["box"][:ntest]),
to_numpy_array(sampled[1]["atype"][0]),
)[0]
energy_diff = to_numpy_array(sampled[0][self.testkey][:ntest]) - energy
energy_diff_small = (
to_numpy_array(sampled[1][self.testkey][:ntest]) - energy_small
)
energy_diff = np.concatenate([energy_diff, energy_diff_small], axis=0)
finetune_shift = (
energy_bias_after[idx_type_map] - energy_bias_before[idx_type_map]
).ravel()
Expand Down