Skip to content

Commit

Permalink
Support parallelization of conf filter (#268)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced batch processing capabilities for configuration checks,
improving efficiency when handling multiple frames.
- Added new filter classes (`BarFilter`, `BazFilter`) with specific
checks for frame coordinates.
  
- **Bug Fixes**
- Enhanced clarity and efficiency in the configuration filtering
process, streamlining logic and reducing complexity.

- **Tests**
- Updated test cases to reflect new filter logic and ensure accurate
validation of frame counts and coordinate values.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: zjgemi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zjgemi and pre-commit-ci[bot] authored Oct 23, 2024
1 parent 08d8d6e commit 35e0b97
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 89 deletions.
7 changes: 1 addition & 6 deletions dpgen2/exploration/render/traj_render_lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,5 @@ def get_confs(
ss = ss.sub_system(id_selected[ii])
ms.append(ss)
if conf_filters is not None:
ms2 = dpdata.MultiSystems(type_map=type_map)
for s in ms:
s2 = conf_filters.check(s)
if len(s2) > 0:
ms2.append(s2)
ms = ms2
ms = conf_filters.check(ms)
return ms
45 changes: 38 additions & 7 deletions dpgen2/exploration/selector/conf_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
ABC,
abstractmethod,
)
from typing import (
List,
)

import dpdata
import numpy as np
Expand All @@ -32,6 +35,25 @@ def check(
"""
pass

def batched_check(
self,
frames: List[dpdata.System],
) -> List[bool]:
"""Check if a list of configurations are valid.
Parameters
----------
frames : List[dpdata.System]
A list of dpdata.System each containing a single frame
Returns
-------
valid : List[bool]
`True` if the configuration is a valid configuration, else `False`.
"""
return list(map(self.check, frames))


class ConfFilters:
def __init__(
Expand All @@ -48,11 +70,20 @@ def add(

def check(
self,
conf: dpdata.System,
) -> dpdata.System:
natoms = sum(conf["atom_numbs"]) # type: ignore
selected_idx = np.arange(conf.get_nframes())
ms: dpdata.MultiSystems,
) -> dpdata.MultiSystems:
selected_idx = []
for i in range(len(ms)):
for j in range(ms[i].get_nframes()):
selected_idx.append((i, j))
for ff in self._filters:
fsel = np.where([ff.check(conf[ii]) for ii in range(conf.get_nframes())])[0]
selected_idx = np.intersect1d(selected_idx, fsel)
return conf.sub_system(selected_idx)
res = ff.batched_check([ms[i][j] for i, j in selected_idx])
selected_idx = [idx for i, idx in enumerate(selected_idx) if res[i]]
selected_idx_list = [[] for _ in range(len(ms))]
for i, j in selected_idx:
selected_idx_list[i].append(j)
ms2 = dpdata.MultiSystems(type_map=ms.atom_names)
for i in range(len(ms)):
if len(selected_idx_list[i]) > 0:
ms2.append(ms[i].sub_system(selected_idx_list[i]))
return ms2
75 changes: 72 additions & 3 deletions dpgen2/exploration/selector/distance_conf_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import logging
from concurrent.futures import (
ProcessPoolExecutor,
)
from copy import (
deepcopy,
)
Expand Down Expand Up @@ -133,7 +136,8 @@ def check_multiples(a, b, c, multiple):


class DistanceConfFilter(ConfFilter):
def __init__(self, custom_safe_dist=None, safe_dist_ratio=1.0):
def __init__(self, max_workers=None, custom_safe_dist=None, safe_dist_ratio=1.0):
self.max_workers = max_workers
self.custom_safe_dist = custom_safe_dist if custom_safe_dist is not None else {}
self.safe_dist_ratio = safe_dist_ratio

Expand Down Expand Up @@ -187,6 +191,16 @@ def check(

return True

def batched_check(
self,
frames: List[dpdata.System],
):
if self.max_workers == 1:
return list(map(self.check, frames))
else:
with ProcessPoolExecutor(self.max_workers) as executor:
return list(executor.map(self.check, frames))

@staticmethod
def args() -> List[dargs.Argument]:
r"""The argument definition of the `ConfFilter`.
Expand All @@ -197,9 +211,20 @@ def args() -> List[dargs.Argument]:
List of dargs.Argument defines the arguments of the `ConfFilter`.
"""

doc_max_workers = (
"The maximum number of processes used to filter configurations, "
+ "None represents as many as the processors of the machine, and 1 for serial"
)
doc_custom_safe_dist = "Custom safe distance (in unit of bohr) for each element"
doc_safe_dist_ratio = "The ratio multiplied to the safe distance"
return [
Argument(
"max_workers",
int,
optional=True,
default=None,
doc=doc_max_workers,
),
Argument(
"custom_safe_dist",
dict,
Expand All @@ -218,7 +243,8 @@ def args() -> List[dargs.Argument]:


class BoxSkewnessConfFilter(ConfFilter):
def __init__(self, theta=60.0):
def __init__(self, max_workers=None, theta=60.0):
self.max_workers = max_workers
self.theta = theta

def check(
Expand Down Expand Up @@ -251,6 +277,16 @@ def check(
return False
return True

def batched_check(
self,
frames: List[dpdata.System],
):
if self.max_workers == 1:
return list(map(self.check, frames))
else:
with ProcessPoolExecutor(self.max_workers) as executor:
return list(executor.map(self.check, frames))

@staticmethod
def args() -> List[dargs.Argument]:
r"""The argument definition of the `ConfFilter`.
Expand All @@ -261,8 +297,19 @@ def args() -> List[dargs.Argument]:
List of dargs.Argument defines the arguments of the `ConfFilter`.
"""

doc_max_workers = (
"The maximum number of processes used to filter configurations, "
+ "None represents as many as the processors of the machine, and 1 for serial"
)
doc_theta = "The threshold for angles between the edges of the cell. If all angles are larger than this value the check is passed"
return [
Argument(
"max_workers",
int,
optional=True,
default=None,
doc=doc_max_workers,
),
Argument(
"theta",
float,
Expand All @@ -274,7 +321,8 @@ def args() -> List[dargs.Argument]:


class BoxLengthFilter(ConfFilter):
def __init__(self, length_ratio=5.0):
def __init__(self, max_workers=None, length_ratio=5.0):
self.max_workers = max_workers
self.length_ratio = length_ratio

def check(
Expand Down Expand Up @@ -307,6 +355,16 @@ def check(
return False
return True

def batched_check(
self,
frames: List[dpdata.System],
):
if self.max_workers == 1:
return list(map(self.check, frames))
else:
with ProcessPoolExecutor(self.max_workers) as executor:
return list(executor.map(self.check, frames))

@staticmethod
def args() -> List[dargs.Argument]:
r"""The argument definition of the `ConfFilter`.
Expand All @@ -317,8 +375,19 @@ def args() -> List[dargs.Argument]:
List of dargs.Argument defines the arguments of the `ConfFilter`.
"""

doc_max_workers = (
"The maximum number of processes used to filter configurations, "
+ "None represents as many as the processors of the machine, and 1 for serial"
)
doc_length_ratio = "The threshold for the length ratio between the edges of the cell. If all length ratios are smaller than this value the check is passed"
return [
Argument(
"max_workers",
int,
optional=True,
default=None,
doc=doc_max_workers,
),
Argument(
"length_ratio",
float,
Expand Down
114 changes: 41 additions & 73 deletions tests/exploration/test_conf_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,110 +27,78 @@ def check(
self,
frame: dpdata.System,
) -> bool:
return True
return frame["coords"][0][0][0] > 0.0


class faked_filter:
myiter = -1
myret = [True]
class BarFilter(ConfFilter):
def check(
self,
frame: dpdata.System,
) -> bool:
return frame["coords"][0][0][1] > 0.0


@classmethod
def faked_check(cls, frame):
cls.myiter += 1
cls.myiter = cls.myiter % len(cls.myret)
return cls.myret[cls.myiter]
class BazFilter(ConfFilter):
def check(
self,
frame: dpdata.System,
) -> bool:
return frame["coords"][0][0][2] > 0.0


class TestConfFilter(unittest.TestCase):
@patch.object(FooFilter, "check", faked_filter.faked_check)
def test_filter_0(self):
faked_filter.myiter = -1
faked_filter.myret = [
True,
True,
False,
True,
False,
True,
True,
False,
True,
True,
False,
False,
]
faked_sys = fake_system(4, 3)
# expected only frame 1 is preseved.
faked_sys["coords"][1][0][0] = 1.0
faked_sys["coords"][1][0] = 1.0
faked_sys["coords"][0][0][0] = 2.0
faked_sys["coords"][2][0][1] = 3.0
faked_sys["coords"][3][0][2] = 4.0
filters = ConfFilters()
filters.add(FooFilter()).add(FooFilter()).add(FooFilter())
sel_sys = filters.check(faked_sys)
filters.add(FooFilter()).add(BarFilter()).add(BazFilter())
ms = dpdata.MultiSystems()
ms.append(faked_sys)
sel_sys = filters.check(ms)[0]
self.assertEqual(sel_sys.get_nframes(), 1)
self.assertAlmostEqual(sel_sys["coords"][0][0][0], 1)

@patch.object(FooFilter, "check", faked_filter.faked_check)
def test_filter_1(self):
faked_filter.myiter = -1
faked_filter.myret = [
True,
True,
False,
True,
False,
True,
True,
True,
True,
True,
False,
True,
]
faked_sys = fake_system(4, 3)
# expected frame 1 and 3 are preseved.
faked_sys["coords"][1][0][0] = 1.0
faked_sys["coords"][3][0][0] = 3.0
faked_sys["coords"][1][0] = 1.0
faked_sys["coords"][3][0] = 3.0
filters = ConfFilters()
filters.add(FooFilter()).add(FooFilter()).add(FooFilter())
sel_sys = filters.check(faked_sys)
filters.add(FooFilter()).add(BarFilter()).add(BazFilter())
ms = dpdata.MultiSystems()
ms.append(faked_sys)
sel_sys = filters.check(ms)[0]
self.assertEqual(sel_sys.get_nframes(), 2)
self.assertAlmostEqual(sel_sys["coords"][0][0][0], 1)
self.assertAlmostEqual(sel_sys["coords"][1][0][0], 3)

@patch.object(FooFilter, "check", faked_filter.faked_check)
def test_filter_all(self):
faked_filter.myiter = -1
faked_filter.myret = [
True,
True,
True,
True,
]
faked_sys = fake_system(4, 3)
# expected all frames are preseved.
faked_sys["coords"][0][0][0] = 0.5
faked_sys["coords"][1][0][0] = 1.0
faked_sys["coords"][2][0][0] = 2.0
faked_sys["coords"][3][0][0] = 3.0
faked_sys["coords"][0][0] = 0.5
faked_sys["coords"][1][0] = 1.0
faked_sys["coords"][2][0] = 2.0
faked_sys["coords"][3][0] = 3.0
filters = ConfFilters()
filters.add(FooFilter()).add(FooFilter()).add(FooFilter())
sel_sys = filters.check(faked_sys)
filters.add(FooFilter()).add(BarFilter()).add(BazFilter())
ms = dpdata.MultiSystems()
ms.append(faked_sys)
sel_sys = filters.check(ms)[0]
self.assertEqual(sel_sys.get_nframes(), 4)
self.assertAlmostEqual(sel_sys["coords"][0][0][0], 0.5)
self.assertAlmostEqual(sel_sys["coords"][1][0][0], 1)
self.assertAlmostEqual(sel_sys["coords"][2][0][0], 2)
self.assertAlmostEqual(sel_sys["coords"][3][0][0], 3)

@patch.object(FooFilter, "check", faked_filter.faked_check)
def test_filter_none(self):
faked_filter.myiter = -1
faked_filter.myret = [
False,
False,
False,
False,
]
faked_sys = fake_system(4, 3)
filters = ConfFilters()
filters.add(FooFilter()).add(FooFilter()).add(FooFilter())
sel_sys = filters.check(faked_sys)
self.assertEqual(sel_sys.get_nframes(), 0)
filters.add(FooFilter()).add(BarFilter()).add(BazFilter())
ms = dpdata.MultiSystems()
ms.append(faked_sys)
sel_ms = filters.check(ms)
self.assertEqual(sel_ms.get_nframes(), 0)

0 comments on commit 35e0b97

Please sign in to comment.