Skip to content

Commit

Permalink
Fix MP2 RDM bootstrapping for immutable backend
Browse files Browse the repository at this point in the history
  • Loading branch information
obackhouse committed Sep 29, 2024
1 parent 25b78d1 commit 8973ad8
Showing 1 changed file with 37 additions and 48 deletions.
85 changes: 37 additions & 48 deletions ebcc/codegen/bootstrap_MPn.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,56 +214,45 @@ def get_postamble(n, spin, name="rdm{n}"):
# a second Norb^4 tensor
postamble += "\nrdm1 = make_rdm1_f(t2=t2, l2=l2)\n"
if spin == "uhf":
postamble += "nocc = Namespace(a=t2.aaaa.shape[0], b=t2.bbbb.shape[0])\n"
postamble += "rdm1.aa[np.diag_indices(nocc.a)] -= 1\n"
postamble += "rdm1.bb[np.diag_indices(nocc.b)] -= 1\n"
postamble += "for i in range(nocc.a):\n"
postamble += " rdm2.aaaa[i, i, :, :] += rdm1.aa.T\n"
postamble += " rdm2.aaaa[:, :, i, i] += rdm1.aa.T\n"
postamble += " rdm2.aaaa[:, i, i, :] -= rdm1.aa.T\n"
postamble += " rdm2.aaaa[i, :, :, i] -= rdm1.aa\n"
postamble += " rdm2.aabb[i, i, :, :] += rdm1.bb.T\n"
postamble += "for i in range(nocc.b):\n"
postamble += " rdm2.bbbb[i, i, :, :] += rdm1.bb.T\n"
postamble += " rdm2.bbbb[:, :, i, i] += rdm1.bb.T\n"
postamble += " rdm2.bbbb[:, i, i, :] -= rdm1.bb.T\n"
postamble += " rdm2.bbbb[i, :, :, i] -= rdm1.bb\n"
postamble += " rdm2.aabb[:, :, i, i] += rdm1.aa.T\n"
postamble += "for i in range(nocc.a):\n"
postamble += " for j in range(nocc.a):\n"
postamble += " rdm2.aaaa[i, i, j, j] += 1\n"
postamble += " rdm2.aaaa[i, j, j, i] -= 1\n"
postamble += "for i in range(nocc.b):\n"
postamble += " for j in range(nocc.b):\n"
postamble += " rdm2.bbbb[i, i, j, j] += 1\n"
postamble += " rdm2.bbbb[i, j, j, i] -= 1\n"
postamble += "for i in range(nocc.a):\n"
postamble += " for j in range(nocc.b):\n"
postamble += " rdm2.aabb[i, i, j, j] += 1"
postamble += "delta = Namespace(\n"
postamble += " aa=np.diag(np.concatenate([np.ones(t2.aaaa.shape[0]), np.zeros(t2.aaaa.shape[-1])])),\n"
postamble += " bb=np.diag(np.concatenate([np.ones(t2.bbbb.shape[0]), np.zeros(t2.bbbb.shape[-1])])),\n"
postamble += ")\n"
postamble += "rdm1.aa -= delta.aa\n"
postamble += "rdm1.bb -= delta.bb\n"
postamble += "rdm2.aaaa += einsum(delta.aa, (0, 1), rdm1.aa, (3, 2), (0, 1, 2, 3))\n"
postamble += "rdm2.aaaa += einsum(rdm1.aa, (1, 0), delta.aa, (2, 3), (0, 1, 2, 3))\n"
postamble += "rdm2.aaaa -= einsum(delta.aa, (0, 3), rdm1.aa, (2, 1), (0, 1, 2, 3))\n"
postamble += "rdm2.aaaa -= einsum(rdm1.aa, (0, 3), delta.aa, (1, 2), (0, 1, 2, 3))\n"
postamble += "rdm2.aaaa += einsum(delta.aa, (0, 1), delta.aa, (2, 3), (0, 1, 2, 3))\n"
postamble += "rdm2.aaaa -= einsum(delta.aa, (0, 3), delta.aa, (1, 2), (0, 1, 2, 3))\n"
postamble += "rdm2.bbbb += einsum(delta.bb, (0, 1), rdm1.bb, (3, 2), (0, 1, 2, 3))\n"
postamble += "rdm2.bbbb += einsum(rdm1.bb, (1, 0), delta.bb, (2, 3), (0, 1, 2, 3))\n"
postamble += "rdm2.bbbb -= einsum(delta.bb, (0, 3), rdm1.bb, (2, 1), (0, 1, 2, 3))\n"
postamble += "rdm2.bbbb -= einsum(rdm1.bb, (0, 3), delta.bb, (1, 2), (0, 1, 2, 3))\n"
postamble += "rdm2.bbbb += einsum(delta.bb, (0, 1), delta.bb, (2, 3), (0, 1, 2, 3))\n"
postamble += "rdm2.bbbb -= einsum(delta.bb, (0, 3), delta.bb, (1, 2), (0, 1, 2, 3))\n"
postamble += "rdm2.aabb += einsum(delta.aa, (0, 1), rdm1.bb, (3, 2), (0, 1, 2, 3))\n"
postamble += "rdm2.aabb += einsum(rdm1.aa, (1, 0), delta.bb, (2, 3), (0, 1, 2, 3))\n"
postamble += "rdm2.aabb += einsum(delta.aa, (0, 1), delta.bb, (2, 3), (0, 1, 2, 3))"
elif spin == "ghf":
postamble += "nocc = t2.shape[0]\n"
postamble += "rdm1[np.diag_indices(nocc)] -= 1\n"
postamble += "for i in range(nocc):\n"
postamble += " rdm2[i, i, :, :] += rdm1.T\n"
postamble += " rdm2[:, :, i, i] += rdm1.T\n"
postamble += " rdm2[:, i, i, :] -= rdm1.T\n"
postamble += " rdm2[i, :, :, i] -= rdm1\n"
postamble += "for i in range(nocc):\n"
postamble += " for j in range(nocc):\n"
postamble += " rdm2[i, i, j, j] += 1\n"
postamble += " rdm2[i, j, j, i] -= 1"
postamble += "delta = np.diag(np.concatenate([np.ones(t2.shape[0]), np.zeros(t2.shape[-1])]))\n"
postamble += "rdm1 -= delta\n"
postamble += "rdm2 += einsum(delta, (0, 1), rdm1, (3, 2), (0, 1, 2, 3))\n"
postamble += "rdm2 += einsum(rdm1, (1, 0), delta, (2, 3), (0, 1, 2, 3))\n"
postamble += "rdm2 -= einsum(delta, (0, 3), rdm1, (2, 1), (0, 1, 2, 3))\n"
postamble += "rdm2 -= einsum(rdm1, (0, 3), delta, (1, 2), (0, 1, 2, 3))\n"
postamble += "rdm2 += einsum(delta, (0, 1), delta, (2, 3), (0, 1, 2, 3))\n"
postamble += "rdm2 -= einsum(delta, (0, 3), delta, (1, 2), (0, 1, 2, 3))"
elif spin == "rhf":
postamble += "nocc = t2.shape[0]\n"
postamble += "rdm1[np.diag_indices(nocc)] -= 2\n"
postamble += "for i in range(nocc):\n"
postamble += " rdm2[i, i, :, :] += rdm1.T * 2\n"
postamble += " rdm2[:, :, i, i] += rdm1.T * 2\n"
postamble += " rdm2[:, i, i, :] -= rdm1.T\n"
postamble += " rdm2[i, :, :, i] -= rdm1\n"
postamble += "for i in range(nocc):\n"
postamble += " for j in range(nocc):\n"
postamble += " rdm2[i, i, j, j] += 4\n"
postamble += " rdm2[i, j, j, i] -= 2"
postamble += "delta = np.diag(np.concatenate([np.ones(t2.shape[0]), np.zeros(t2.shape[-1])]))\n"
postamble += "rdm1 -= delta * 2\n"
postamble += "rdm2 += einsum(delta, (0, 1), rdm1, (3, 2), (0, 1, 2, 3)) * 2\n"
postamble += "rdm2 += einsum(rdm1, (1, 0), delta, (2, 3), (0, 1, 2, 3)) * 2\n"
postamble += "rdm2 -= einsum(delta, (0, 3), rdm1, (2, 1), (0, 1, 2, 3))\n"
postamble += "rdm2 -= einsum(rdm1, (0, 3), delta, (1, 2), (0, 1, 2, 3))\n"
postamble += "rdm2 += einsum(delta, (0, 1), delta, (2, 3), (0, 1, 2, 3)) * 4\n"
postamble += "rdm2 -= einsum(delta, (0, 3), delta, (1, 2), (0, 1, 2, 3)) * 2"

return postamble

Expand Down

0 comments on commit 8973ad8

Please sign in to comment.