Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TDAx #73

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions momentGW/bse.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,21 @@ def __init__(self, gw, **kwargs):
# Attributes
self.gf = None

@property
def polarizability_name(self):
"""Get the polarizability name."""
return {
"drpa": "dRPA",
"drpa-exact": "dRPA",
"dtda": "dTDA",
"thc-dtda": "THC-dTDA",
"tdax": "TDAx",
}[self.polarizability.lower()]

@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-BSE"
return f"{self.polarizability_name}-BSE"

@logging.with_timer("Integral construction")
@logging.with_status("Constructing integrals")
Expand Down Expand Up @@ -526,8 +536,7 @@ def __init__(self, gw, **kwargs):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-cpBSE"
return f"{self.polarizability_name}-cpBSE"

@logging.with_timer("Dynamic polarizability moments")
@logging.with_status("Constructing dynamic polarizability moments")
Expand Down
3 changes: 1 addition & 2 deletions momentGW/evgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ class evGW(GW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-evG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"
return f"{self.polarizability_name}-evG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"

def check_convergence(self, mo_energy, mo_energy_prev, th, th_prev, tp, tp_prev):
"""Check for convergence, and print a summary of changes.
Expand Down
3 changes: 1 addition & 2 deletions momentGW/fsgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,4 @@ class fsGW(GW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-fsGW"
return f"{self.polarizability_name}-fsGW"
20 changes: 17 additions & 3 deletions momentGW/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from momentGW.fock import FockLoop, search_chempot
from momentGW.ints import Integrals
from momentGW.rpa import dRPA
from momentGW.tda import dTDA
from momentGW.tda import TDAx, dTDA


def kernel(
Expand Down Expand Up @@ -138,11 +138,21 @@ class GW(BaseGW):

_kernel = kernel

@property
def polarizability_name(self):
"""Get the polarizability name."""
return {
"drpa": "dRPA",
"drpa-exact": "dRPA",
"dtda": "dTDA",
"thc-dtda": "THC-dTDA",
"tdax": "TDAx",
}[self.polarizability.lower()]

@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-G0W0"
return f"{self.polarizability_name}-G0W0"

@logging.with_timer("Static self-energy")
@logging.with_status("Building static self-energy")
Expand Down Expand Up @@ -235,6 +245,10 @@ def build_se_moments(self, nmom_max, integrals, **kwargs):
tda = thc.dTDA(self, nmom_max, integrals, **kwargs)
return tda.kernel()

elif self.polarizability.lower() == "tdax":
tda = TDAx(self, nmom_max, integrals, **kwargs)
return tda.kernel()

else:
raise NotImplementedError

Expand Down
3 changes: 1 addition & 2 deletions momentGW/pbc/evgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class evKGW(KGW, evGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-evKG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"
return f"{self.polarizability_name}-evKG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"

def check_convergence(self, mo_energy, mo_energy_prev, th, th_prev, tp, tp_prev):
"""Check for convergence, and print a summary of changes.
Expand Down
3 changes: 1 addition & 2 deletions momentGW/pbc/fsgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,4 @@ class fsKGW(KGW, fsGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-fsKGW"
return f"{self.polarizability_name}-fsKGW"
8 changes: 5 additions & 3 deletions momentGW/pbc/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from momentGW.pbc.fock import FockLoop, search_chempot_unconstrained
from momentGW.pbc.ints import KIntegrals
from momentGW.pbc.rpa import dRPA
from momentGW.pbc.tda import dTDA
from momentGW.pbc.tda import TDAx, dTDA


class KGW(BaseKGW, GW):
Expand Down Expand Up @@ -66,8 +66,7 @@ class KGW(BaseKGW, GW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-KG0W0"
return f"{self.polarizability_name}-KG0W0"

@logging.with_timer("Static self-energy")
@logging.with_status("Building static self-energy")
Expand Down Expand Up @@ -120,6 +119,9 @@ def build_se_moments(self, nmom_max, integrals, **kwargs):
elif self.polarizability.lower() == "thc-dtda":
tda = thc.dTDA(self, nmom_max, integrals, **kwargs)
return tda.kernel()
elif self.polarizability.lower() == "tdax":
tda = TDAx(self, nmom_max, integrals, **kwargs)
return tda.kernel()
else:
raise NotImplementedError

Expand Down
3 changes: 1 addition & 2 deletions momentGW/pbc/qsgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ class qsKGW(KGW, qsGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-qsKGW"
return f"{self.polarizability_name}-qsKGW"

@staticmethod
def project_basis(matrix, ovlp, mo1, mo2):
Expand Down
3 changes: 1 addition & 2 deletions momentGW/pbc/scgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,4 @@ class scKGW(KGW, scGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-KG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"
return f"{self.polarizability_name}-KG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"
105 changes: 105 additions & 0 deletions momentGW/pbc/tda.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,108 @@ def kpts(self):
def nkpts(self):
"""Get the number of k-points."""
return self.gw.nkpts


class TDAx(dTDA):
"""
Compute the self-energy moments using TDA (with exchange) with
periodic boundary conditions.

Parameters
----------
gw : BaseKGW
GW object.
nmom_max : int
Maximum moment number to calculate.
integrals : KIntegrals
Density-fitted integrals at each k-point.
mo_energy : dict, optional
Molecular orbital energies at each k-point. Keys are "g" and
"w" for the Green's function and screened Coulomb interaction,
respectively. If `None`, use `gw.mo_energy` for both. Default
value is `None`.
mo_occ : dict, optional
Molecular orbital occupancies at each k-point. Keys are "g"
and "w" for the Green's function and screened Coulomb
interaction, respectively. If `None`, use `gw.mo_occ` for both.
Default value is `None`.
"""

@logging.with_timer("Self-energy moments")
@logging.with_status("Constructing self-energy moments")
def build_se_moments(self, moments_dd):
"""Build the moments of the self-energy via convolution.

Parameters
----------
moments_dd : numpy.ndarray
Moments of the density-density response at each k-point.

Returns
-------
moments_occ : numpy.ndarray
Moments of the occupied self-energy at each k-point.
moments_vir : numpy.ndarray
Moments of the virtual self-energy at each k-point.
"""

# Get the sizes
nocc = self.integrals.nocc
nvir = self.integrals.nvir
kpts = self.kpts

# Setup dependent on diagonal SE
if self.gw.diagonal_se:
pqchar = pchar = qchar = "p"
eta_shape = lambda k: (self.mo_energy_g[k].size, self.nmom_max + 1, self.nmo)
else:
pqchar, pchar, qchar = "pq", "p", "q"
eta_shape = lambda k: (self.mo_energy_g[k].size, self.nmom_max + 1, self.nmo, self.nmo)
eta = np.zeros((self.nkpts, self.nkpts), dtype=object)

# Get the moments
for n in range(self.nmom_max + 1):
for kp, kx, ki in kpts.loop(3):
ka = kpts.conserve(kp, kx, ki)
q = kpts.member(kpts.wrap_around(kpts[ki] - kpts[ka]))

if not isinstance(eta[kp, q], np.ndarray):
eta[kp, q] = np.zeros(eta_shape(kx), dtype=complex)

Lia = self.integrals.Lia[ki, ka]
Lia = Lia.reshape(Lia.shape[0], nocc[ki], nvir[ka])

Lxa = self.integrals.Lia[kx, ka]
Lxa = Lxa.reshape(Lxa.shape[0], nocc[kx], nvir[ka])

Lix = self.integrals.Lia[ki, kx]
Lix = Lix.reshape(Lix.shape[0], nocc[ki], nvir[kx])

Lpi = self.integrals.Lpx[kp, ki][:, :, : nocc[ki]]
Lpa = self.integrals.Lpx[kp, ka][:, :, nocc[ka] :]

moment = moments_dd[q, ka, n]
moment = moment.reshape(Lia.shape)

for x in range(self.mo_energy_g[kx].size):
Lp = self.integrals.Lpx[kp, kx][:, :, x]

v = util.einsum("Pia,Pq->iaq", Lia, Lp) * 2.0
if self.mo_occ_g[kx][x] > 0:
La = Lxa[:, x]
v -= util.einsum("Pa,Pqi->iaq", La, Lpi)
else:
Li = Lix[:, :, x - nocc[kx]]
v -= util.einsum("Pi,Pqa->iaq", Li, Lpa)

eta[kp, q][x, n] += util.einsum(
f"P{pchar},Pia,ia{qchar}->{pqchar}",
Lp,
moment,
v.conj(),
)

# Construct the self-energy moments
moments_occ, moments_vir = self.convolve(eta)

return moments_occ, moments_vir
3 changes: 1 addition & 2 deletions momentGW/pbc/uhf/evgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ class evKUGW(KUGW, evKGW, evUGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-evKUGW"
return f"{self.polarizability_name}-evKUG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"

def check_convergence(self, mo_energy, mo_energy_prev, th, th_prev, tp, tp_prev):
"""Check for convergence, and print a summary of changes.
Expand Down
3 changes: 1 addition & 2 deletions momentGW/pbc/uhf/fsgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,4 @@ class fsKUGW(KUGW, fsKGW, fsUGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-fsKUGW"
return f"{self.polarizability_name}-fsKUGW"
3 changes: 1 addition & 2 deletions momentGW/pbc/uhf/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ class KUGW(BaseKUGW, KGW, UGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-KUG0W0"
return f"{self.polarizability_name}-KUG0W0"

@logging.with_timer("Static self-energy")
@logging.with_status("Building static self-energy")
Expand Down
3 changes: 1 addition & 2 deletions momentGW/pbc/uhf/qsgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ class qsKUGW(KUGW, qsKGW, qsUGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-qsKUGW"
return f"{self.polarizability_name}-qsKUGW"

@staticmethod
def project_basis(matrix, ovlp, mo1, mo2):
Expand Down
3 changes: 1 addition & 2 deletions momentGW/pbc/uhf/scgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,4 @@ class scKUGW(KUGW, scKGW, scUGW):
@property
def name(self):
"""Get the method name."""
polarizability = self.polarizability.upper().replace("DTDA", "dTDA").replace("DRPA", "dRPA")
return f"{polarizability}-scKUG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"
return f"{self.polarizability_name}-scKUG{'0' if self.g0 else ''}W{'0' if self.w0 else ''}"
Loading
Loading