Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): support DeepEval.eval_descriptor #4214

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,58 @@ def eval_typeebd(self) -> np.ndarray:
def get_model_def_script(self) -> str:
"""Get model defination script."""
return self.model_def_script

def eval_descriptor(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Any,
) -> np.ndarray:
"""Evaluate descriptors by using this DP.

Parameters
----------
coords
The coordinates of atoms.
The array should be of size nframes x natoms x 3
cells
The cell of the region.
If None then non-PBC is assumed, otherwise using PBC.
The array should be of size nframes x 9
atom_types
The atom types
The list should contain natoms ints
fparam
The frame parameter.
The array can be of size :
- nframes x dim_fparam.
- dim_fparam. Then all frames are assumed to be provided with the same fparam.
aparam
The atomic parameter
The array can be of size :
- nframes x natoms x dim_aparam.
- natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam.
- dim_aparam. Then all frames and atoms are provided with the same aparam.

Returns
-------
descriptor
Descriptors.
"""
model = self.dp.model["Default"]
model.set_eval_descriptor_hook(True)
self.eval(
coords,
cells,
atom_types,
atomic=False,
fparam=fparam,
aparam=aparam,
**kwargs,
)
descriptor = model.eval_descriptor()
model.set_eval_descriptor_hook(False)
return to_numpy_array(descriptor)
15 changes: 15 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ def __init__(
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
super().init_out_stat()
self.enable_eval_descriptor_hook = False
self.eval_descriptor_list = []

eval_descriptor_list: list[torch.Tensor]

njzjz marked this conversation as resolved.
Show resolved Hide resolved
def set_eval_descriptor_hook(self, enable: bool) -> None:
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
self.enable_eval_descriptor_hook = enable
self.eval_descriptor_list = []

def eval_descriptor(self) -> torch.Tensor:
"""Evaluate the descriptor."""
return torch.concat(self.eval_descriptor_list)
njzjz marked this conversation as resolved.
Show resolved Hide resolved

@torch.jit.export
def fitting_output_def(self) -> FittingOutputDef:
Expand Down Expand Up @@ -192,6 +205,8 @@ def forward_atomic(
comm_dict=comm_dict,
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
# energy, force
fit_ret = self.fitting_net(
descriptor,
Expand Down
12 changes: 12 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Optional,
)

import torch

from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
Expand Down Expand Up @@ -52,3 +54,13 @@ def get_fitting_net(self):
def get_descriptor(self):
"""Get the descriptor."""
return self.atomic_model.descriptor

@torch.jit.export
def set_eval_descriptor_hook(self, enable: bool) -> None:
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
self.atomic_model.set_eval_descriptor_hook(enable)

@torch.jit.export
def eval_descriptor(self) -> torch.Tensor:
"""Evaluate the descriptor."""
return self.atomic_model.eval_descriptor()
2 changes: 0 additions & 2 deletions source/tests/infer/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ def test_1frame_atm(self):

def test_descriptor(self):
_, extension = self.param
if extension == ".pth":
self.skipTest("eval_descriptor not supported for PyTorch models")
for ii, result in enumerate(self.case.results):
if result.descriptor is None:
continue
Expand Down