Skip to content

Commit

Permalink
fix the bug for dsdx
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Oct 3, 2024
1 parent 0d2a100 commit 22ddb12
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 16 deletions.
9 changes: 7 additions & 2 deletions pyxtal/lego/SO3.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def compute_dpdr_5d(self, atoms):
pair_ids = neigh_ids[self.neighbor_indices[:, 0] == i]
if len(pair_ids) > 0:
ctot = cs[pair_ids].sum(axis=0) #(n, l, m)
dctot = dcs[pair_ids].sum(axis=0)
# power spectrum P = c*c_conj
# eq_3 (n, n', l) eliminate m
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
Expand All @@ -287,12 +288,16 @@ def compute_dpdr_5d(self, atoms):
# (N_ijs, n, n', l, 3)
# dc * c_conj + c * dc_conj
dP = np.einsum('ijkn, ljk->iljn', dcs[pair_id], np.conj(ctot))
dP += np.conj(np.transpose(dP, axes=[1, 0, 2, 3]))
dP += np.einsum('ijkn, ljk->iljn', np.conj(dcs[pair_id]), ctot)
#dP += np.conj(np.transpose(dP, axes=[1, 0, 2, 3]))
#dP += np.einsum('ijkn, ljk->iljn', np.conj(dctot), cs[pair_id])
#dP += np.einsum('ijkn, ljk->iljn', dctot, np.conj(cs[pair_id]))

dP = dP.real[self.tril_indices].flatten().reshape(self.ncoefs, 3)
#print(cs[pair_id].shape, dcs[pair_id].shape, dP.shape)

dp_list[i, j, :, :, cell_id] += dP
dp_list[i, i, :, :, cell_id] -= dP
dp_list[i, i, :, :, 13] -= dP

return dp_list, p_list

Expand Down
128 changes: 114 additions & 14 deletions tests/test_SO3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,75 @@
from ase import Atoms
from ase.build import bulk, sort

def calculate_S(atoms, P_ref):
P = calculator.compute_p(atoms)
S = np.sum((P - P_ref)**2)
return S


def numerical_dSdx(x, xtal, P_ref, eps=1e-4):
if type(x) == list: x = np.array(x)

xtal.update_from_1d_rep(x)
atoms = xtal.to_ase() #* 2
S0 = calculate_S(atoms, P_ref)
dSdx = np.zeros(len(x))
for i in range(len(x)):
x0 = x.copy()
x0[i] += eps
xtal.update_from_1d_rep(x0)
atoms = xtal.to_ase() #* 2
S1 = calculate_S(atoms, P_ref)

x0 = x.copy()
x0[i] -= eps
xtal.update_from_1d_rep(x0)
atoms = xtal.to_ase() #* 2
S2 = calculate_S(atoms, P_ref)

dSdx[i] = 0.5*(S1-S2)/eps
return dSdx


def calculate_dSdx_supercell(x, xtal, P_ref, eps=1e-4):

xtal.update_from_1d_rep(x)
atoms = xtal.to_ase() #* 2

dPdr, P = calculator.compute_dpdr_5d(atoms)

# Compute dSdr [N, M] [N, N, M, 3, 27] => [N, 3, 27]
dSdr = np.einsum("ik, ijklm -> jlm", 2*(P - P_ref), dPdr)

# Get supercell positions
ref_pos = np.repeat(atoms.positions[:, :, np.newaxis], 27, axis=2)
for cell in range(27):
x1, y1, z1 = cell // 9 - 1, (cell // 3) % 3 - 1, cell % 3 - 1
ref_pos[:, :, cell] += np.array([x1, y1, z1]) @ atoms.cell

# Compute drdx via numerical func
drdx = np.zeros([len(atoms), 3, 27, len(x)])

xtal0 = xtal.copy()
for i in range(len(x)):
x0 = x.copy()
x0[i] += eps
xtal0.update_from_1d_rep(x0)
atoms = xtal0.to_ase()

# Get supercell positions
pos = np.repeat(atoms.positions[:, :, np.newaxis], 27, axis=2)
for cell in range(27):
x1, y1, z1 = cell // 9 - 1, (cell // 3) % 3 - 1, cell % 3 - 1
pos[:, :, cell] += np.array([x1, y1, z1]) @ atoms.cell

drdx[:, :, :, i] += (pos - ref_pos)/eps

# [N, 3, 27] [N, 3, 27, H] => H
dSdx = np.einsum("ijk, ijkl -> l", dSdr, drdx)
return dSdx


def get_rotated_cluster(struc, angle=0, axis='x'):
s_new = struc.copy()
s_new.rotate(angle, axis)
Expand All @@ -28,15 +97,15 @@ def get_perturbed_xtal(struc, p0, p1, eps):
p_struc.set_positions(pos)
return p_struc

def get_dPdR_xtal(xtal, nmax, lmax, rc, eps):
p0 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(xtal, derivative=True)
def get_dPdR_xtal(xtal, eps):
p0 = calculator.calculate(xtal, derivative=True)
shp = p0['x'].shape
array1 = p0['dxdr']

for j in range(shp[0]):
for k in range(3):
struc = get_perturbed_xtal(xtal, j, k, eps)
p1 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(struc)
p1 = calculator.calculate(struc)
array2 = (p1['x'] - p0['x'])/eps
#if np.linalg.norm(array2) > 1e-2: print(j, k, array2)
if not np.allclose(array1[:, j, :, k], array2, atol=1e-4):
Expand All @@ -45,21 +114,29 @@ def get_dPdR_xtal(xtal, nmax, lmax, rc, eps):

# Descriptors Parameters
eps = 1e-8
rc = 2.80
rc1 = 1.8
rc2 = 3.5
nmax, lmax = 2, 2
calculator = SO3(nmax=nmax, lmax=lmax, rcut=rc1)
calculator0 = SO3(nmax=nmax, lmax=lmax, rcut=rc2)

# NaCl cluster
cluster = bulk('NaCl', crystalstructure='rocksalt', a=5.691694, cubic=True)
cluster = sort(cluster, tags=[0, 4, 1, 5, 2, 6, 3, 7])
cluster.set_pbc((0,0,0))
cluster = get_rotated_cluster(cluster, angle=0.1) # Must rotate

xtal = pyxtal()
xtal.from_prototype('graphite')
atoms = xtal.to_ase()
P_ref = calculator.compute_p(atoms)[0]

# Diamond
class TestCluster(unittest.TestCase):
struc = get_rotated_cluster(cluster)
p0 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(struc, derivative=True)
p0 = calculator0.calculate(struc, derivative=True)
struc = get_rotated_cluster(cluster, 10, 'x')
p1 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(struc)
p1 = calculator0.calculate(struc)

def test_SO3_rotation_variance(self):
array1 = self.p0['x']
Expand All @@ -73,7 +150,7 @@ def test_dPdR_vs_numerical(self):
for j in range(shp[0]):
for k in range(3):
struc = get_perturbed_cluster(cluster, j, k, eps)
p2 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(struc)
p2 = calculator0.calculate(struc)
array2 = (p2['x'] - self.p0['x'])/eps
assert(np.allclose(array1[:,j,:,k], array2, atol=1e-3))

Expand All @@ -82,30 +159,53 @@ class TestXtal(unittest.TestCase):
def test_dPdR_diamond(self):
c = pyxtal()
c.from_prototype('diamond')
get_dPdR_xtal(c.to_ase(), nmax, lmax, rc, eps)
get_dPdR_xtal(c.to_ase(), eps)

def test_dPdR_graphite(self):
c = pyxtal()
c.from_prototype('graphite')
get_dPdR_xtal(c.to_ase(), nmax, lmax, rc, eps)
get_dPdR_xtal(c.to_ase(), eps)

def test_dPdR_random(self):
x = [ 7.952, 2.606, 0.592, 0.926, 0.608, 0.307]
c = pyxtal()
c.from_spg_wps_rep(179, ['6a', '6a', '6a', '6a'], x)
get_dPdR_xtal(c.to_ase(), nmax, lmax, rc, eps)
get_dPdR_xtal(c.to_ase(), eps)

def test_dPdR_random_P(self):
x = [ 7.952, 2.606, 0.592, 0.926, 0.608, 0.307]
c = pyxtal()
c.from_spg_wps_rep(179, ['6a', '6a', '6a', '6a'], x)
atoms = c.to_ase()
f = SO3(nmax=nmax, lmax=lmax, rcut=rc)
p0 = f.compute_p(atoms)
_, p1 = f.compute_dpdr(atoms)
_, p2 = f.compute_dpdr_5d(atoms)
p0 = calculator.compute_p(atoms)
_, p1 = calculator.compute_dpdr(atoms)
_, p2 = calculator.compute_dpdr_5d(atoms)
assert(np.allclose(p0, p1, atol=1e-3))
assert(np.allclose(p0, p2, atol=1e-3))

class TestSimilarity(unittest.TestCase):

def test_sim_diamond(self):
x = [3.0]
c = pyxtal()
c.from_spg_wps_rep(227, ['8a'], x, ['C'])
atoms = c.to_ase()
x = c.get_1d_rep_x()
dSdx1 = numerical_dSdx(x, c, P_ref)
dSdx2 = calculate_dSdx_supercell(x, c, P_ref)
#print(dSdx1, dSdx2)
assert(np.allclose(dSdx1, dSdx2, rtol=1e-1, atol=1e+1))

def test_dPdR_random(self):
#x = [ 7.952, 2.606, 0.592, 0.926, 0.608, 0.307]
x = [9.55, 2.60, 0.48, 0.88, 0.76, 0.36]
c = pyxtal()
c.from_spg_wps_rep(179, ['6a', '6a', '6a', '6a'], x)
atoms = c.to_ase()
dSdx1 = numerical_dSdx(x, c, P_ref)
dSdx2 = calculate_dSdx_supercell(x, c, P_ref)
#print(dSdx1, dSdx2)
assert(np.allclose(dSdx1, dSdx2, rtol=1e-1, atol=1e+1))

if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tests/test_lego.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_opt_xtal(self):
xtal = pyxtal()
xtal.from_spg_wps_rep(spg, wps, x, ['C']*len(wps))
xtal, sim, _ = builder1.optimize_xtal(xtal, add_db=False)
#print(xtal.get_1d_rep_x())
assert sim < 1e-2

def test_opt_xtal2(self):
Expand Down

0 comments on commit 22ddb12

Please sign in to comment.