Skip to content

Commit

Permalink
Merge pull request #129 from BoothGroup/fix_uccsd_unpack_MPI_
Browse files Browse the repository at this point in the history
Bugfix: MPI communication - functionality for pack-unpack arrays needed for linearizing data added
  • Loading branch information
ghb24 authored Sep 5, 2023
2 parents 515afec + 3dc8ac5 commit 44c60bf
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions vayesta/core/types/wf/ccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from vayesta.core import spinalg
from vayesta.core.util import NotCalculatedError, Object, callif, einsum, dot
from vayesta.core.types import wf as wf_types
from vayesta.core.types.orbitals import SpatialOrbitals
from vayesta.core.types.orbitals import SpatialOrbitals, SpinOrbitals
from vayesta.core.types.wf.project import (
project_c1,
project_c2,
Expand Down Expand Up @@ -398,27 +398,29 @@ def rotate_ov(self, to, tv, inplace=False):
wf.l2 = transform_uc2(wf.l2, to, tv)
return wf

# def pack(self, dtype=float):
# """Pack into a single array of data type `dtype`.

# Useful for communication via MPI."""
# mo = self.mo.pack(dtype=dtype)
# l1 = self.l1 is not None else [None, None]
# l2 = self.l2 is not None else len(self.t2)*[None]
# projector = self.projector is not None else [None]
# data = (mo, *self.t1, *self.t2, *l1, *l2, *projector)
# pack = pack_arrays(*data, dtype=dtype)
# return pack

# @classmethod
# def unpack(cls, packed):
# """Unpack from a single array of data type `dtype`.

# Useful for communication via MPI."""
# mo, *unpacked = unpack_arrays(packed)
# mo = SpinOrbitals.unpack(mo)
# t1a, t1b, t2, l1, l2, projector =
# wf = cls(mo, t1, t2, l1=l1, l2=l2)
# if projector is not None:
# wf.projector = projector
# return wf
def pack(self, dtype=float):
"""Pack into a single array of data type `dtype`.
Useful for communication via MPI."""
mo = self.mo.pack(dtype=dtype)
l1 = self.l1 if self.l1 is not None else [None, None]
l2 = self.l2 if self.l2 is not None else len(self.t2)*[None]
projector=self.projector
data = (mo, *self.t1, *self.t2, *l1, *l2, *projector)
pack = pack_arrays(*data, dtype=dtype)
return pack

@classmethod
def unpack(cls, packed):
"""Unpack from a single array of data type `dtype`.
Useful for communication via MPI."""
mo, t1a, t1b, t2aa, t2ab, t2ba, t2bb, l1a, l1b, l2aa, l2ab, l2ba, l2bb, proja, projb = unpack_arrays(packed)
t1 = (t1a, t1b)
t2 = (t2aa, t2ab, t2ba, t2bb)
l1 = (l1a, l1b)
l2 = (l2aa, l2ab, l2ba, l2bb)
projector = (proja, projb)
mo = SpinOrbitals.unpack(mo)
wf = cls(mo, t1, t2, l1=l1, l2=l2)
if projector is not None:
wf.projector = projector
return wf

0 comments on commit 44c60bf

Please sign in to comment.