Skip to content

Commit

Permalink
wick eq i/o
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Oct 4, 2023
1 parent 23a902e commit 64c7ce8
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 13 deletions.
6 changes: 6 additions & 0 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4729,6 +4729,8 @@ def make_su2_open_shell(h1e, g2e, const_e, cidx, midx, iprint=1):
)
is_single = lambda x: (x & b.WickIndexTypes.Single) != b.WickIndexTypes.Nothing

assert cidx.dtype == bool

def ix(x):
p = {"I": cidx, "S": midx, "E": ~cidx & ~midx}
r = np.outer(p[x[0]], p[x[1]])
Expand Down Expand Up @@ -4836,6 +4838,8 @@ def make_sz(h1e, g2e, const_e, cidx, iprint=1):
cidxa, cidxb = cidx
else:
cidxa, cidxb = cidx, cidx
assert cidxa.dtype == bool
assert cidxb.dtype == bool

def ix(x):
p = {"i": cidxa, "e": ~cidxa, "I": cidxb, "E": ~cidxb}
Expand Down Expand Up @@ -4929,6 +4933,8 @@ def make_sgf(h1e, g2e, const_e, cidx, iprint=1):
lambda x: (x & b.WickIndexTypes.Inactive) != b.WickIndexTypes.Nothing
)

assert cidx.dtype == bool

def ix(x):
p = {"I": cidx, "E": ~cidx}
r = np.outer(p[x[0]], p[x[1]])
Expand Down
4 changes: 2 additions & 2 deletions pyblock2/uc/ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@

class CI(lib.StreamObject):
def __init__(self, mf, frozen=None, mo_coeff=None, mo_occ=None, ci_order=2):
from pyscf.scf import hf, addons
from pyscf.scf import hf, rohf, addons

if isinstance(mf, hf.KohnShamDFT):
raise RuntimeError("CI Warning: The first argument mf is a DFT object.")
if isinstance(mf, scf.rohf.ROHF):
if isinstance(mf, rohf.ROHF):
lib.logger.warn(mf, 'RCI method does not support ROHF method. ROHF object '
'is converted to UHF object and UCI method is called.')
mf = addons.convert_to_uhf(mf)
Expand Down
96 changes: 96 additions & 0 deletions src/ic/wick.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,19 @@ struct WickIndex {
sort(r.begin(), r.end());
return set<WickIndex>(r.begin(), r.end());
}
void save(ostream &ofs) const {
size_t lname = (size_t)name.length();
ofs.write((char *)&lname, sizeof(lname));
ofs.write((char *)&name[0], sizeof(char) * lname);
ofs.write((char *)&types, sizeof(types));
}
void load(istream &ifs) {
size_t lname = 0;
ifs.read((char *)&lname, sizeof(lname));
name = string(lname, ' ');
ifs.read((char *)&name[0], sizeof(char) * lname);
ifs.read((char *)&types, sizeof(types));
}
};

struct WickPermutation {
Expand Down Expand Up @@ -307,6 +320,19 @@ struct WickPermutation {
assert(ir == nr);
return r;
}
void save(ostream &ofs) const {
size_t ldata = (size_t)data.size();
ofs.write((char *)&ldata, sizeof(ldata));
ofs.write((char *)&data[0], sizeof(int16_t) * ldata);
ofs.write((char *)&negative, sizeof(negative));
}
void load(istream &ifs) {
size_t ldata = 0;
ifs.read((char *)&ldata, sizeof(ldata));
data.resize(ldata);
ifs.read((char *)&data[0], sizeof(int16_t) * ldata);
ifs.read((char *)&negative, sizeof(negative));
}
};

struct WickTensor {
Expand Down Expand Up @@ -742,6 +768,37 @@ struct WickTensor {
}
return r;
}
void save(ostream &ofs) const {
size_t lname = (size_t)name.length();
ofs.write((char *)&lname, sizeof(lname));
ofs.write((char *)&name[0], sizeof(char) * lname);
size_t lindices = (size_t)indices.size();
ofs.write((char *)&lindices, sizeof(lindices));
for (size_t i = 0; i < lindices; i++)
indices[i].save(ofs);
size_t lperms = (size_t)perms.size();
ofs.write((char *)&lperms, sizeof(lperms));
for (size_t i = 0; i < lperms; i++)
perms[i].save(ofs);
ofs.write((char *)&type, sizeof(type));
}
void load(istream &ifs) {
size_t lname = 0;
ifs.read((char *)&lname, sizeof(lname));
name = string(lname, ' ');
ifs.read((char *)&name[0], sizeof(char) * lname);
size_t lindices = 0;
ifs.read((char *)&lindices, sizeof(lindices));
indices.resize(lindices);
for (size_t i = 0; i < lindices; i++)
indices[i].load(ifs);
size_t lperms = 0;
ifs.read((char *)&lperms, sizeof(lperms));
perms.resize(lperms);
for (size_t i = 0; i < lperms; i++)
perms[i].load(ifs);
ifs.read((char *)&type, sizeof(type));
}
};

struct WickString {
Expand Down Expand Up @@ -1461,6 +1518,32 @@ struct WickString {
xtensors.resize(xidxs.size());
return WickString(xtensors, xctr_indices, xfactor);
}
void save(ostream &ofs) const {
size_t ltensors = (size_t)tensors.size();
ofs.write((char *)&ltensors, sizeof(ltensors));
for (size_t i = 0; i < ltensors; i++)
tensors[i].save(ofs);
size_t lctr_indices = (size_t)ctr_indices.size();
ofs.write((char *)&lctr_indices, sizeof(lctr_indices));
for (const auto &wi : ctr_indices)
wi.save(ofs);
ofs.write((char *)&factor, sizeof(factor));
}
void load(istream &ifs) {
size_t ltensors = 0;
ifs.read((char *)&ltensors, sizeof(ltensors));
tensors.resize(ltensors);
for (size_t i = 0; i < ltensors; i++)
tensors[i].load(ifs);
size_t lctr_indices = 0;
ifs.read((char *)&lctr_indices, sizeof(lctr_indices));
for (size_t i = 0; i < lctr_indices; i++) {
WickIndex wi;
wi.load(ifs);
ctr_indices.insert(wi);
}
ifs.read((char *)&factor, sizeof(factor));
}
};

struct WickExpr {
Expand Down Expand Up @@ -2529,6 +2612,19 @@ struct WickExpr {
WickExpr simplify() const {
return simplify_delta().simplify_zero().simplify_merge();
}
void save(ostream &ofs) const {
size_t lterms = (size_t)terms.size();
ofs.write((char *)&lterms, sizeof(lterms));
for (size_t i = 0; i < lterms; i++)
terms[i].save(ofs);
}
void load(istream &ifs) {
size_t lterms = 0;
ifs.read((char *)&lterms, sizeof(lterms));
terms.resize(lterms);
for (size_t i = 0; i < lterms; i++)
terms[i].load(ifs);
}
};

inline WickExpr operator+(const WickString &a, const WickString &b) noexcept {
Expand Down
144 changes: 133 additions & 11 deletions src/pybind/pybind_ic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,31 @@ template <typename S = void> void bind_wick(py::module &m) {
.def_static("add_types", &WickIndex::add_types)
.def_static("parse_with_types", &WickIndex::parse_with_types)
.def_static("parse_set", &WickIndex::parse_set)
.def_static("parse_set_with_types", &WickIndex::parse_set_with_types);
.def_static("parse_set_with_types", &WickIndex::parse_set_with_types)
.def("save",
[](WickIndex *self, const string &filename) {
ofstream ofs(filename.c_str(), ios::binary);
if (!ofs.good())
throw runtime_error("WickIndex::save on '" + filename +
"' failed.");
self->save(ofs);
if (!ofs.good())
throw runtime_error("WickIndex::save on '" + filename +
"' failed.");
ofs.close();
})
.def("load", [](WickIndex *self, const string &filename) {
ifstream ifs(filename.c_str(), ios::binary);
if (!ifs.good())
throw runtime_error("WickIndex::load on '" + filename +
"' failed.");
self->load(ifs);
if (ifs.fail() || ifs.bad())
throw runtime_error("WickIndex::load on '" + filename +
"' failed.");
ifs.close();
return *self;
});

py::bind_vector<vector<WickIndex>>(m, "VectorWickIndex");
py::bind_set_block2<std::set<WickIndex>>(m, "SetWickIndex");
Expand Down Expand Up @@ -345,7 +369,31 @@ template <typename S = void> void bind_wick(py::module &m) {
&WickPermutation::pair_anti_symmetric)
.def_static("all", &WickPermutation::all)
.def_static("pair_symmetric", &WickPermutation::pair_symmetric,
py::arg("n"), py::arg("hermitian") = false);
py::arg("n"), py::arg("hermitian") = false)
.def("save",
[](WickPermutation *self, const string &filename) {
ofstream ofs(filename.c_str(), ios::binary);
if (!ofs.good())
throw runtime_error("WickPermutation::save on '" +
filename + "' failed.");
self->save(ofs);
if (!ofs.good())
throw runtime_error("WickPermutation::save on '" +
filename + "' failed.");
ofs.close();
})
.def("load", [](WickPermutation *self, const string &filename) {
ifstream ifs(filename.c_str(), ios::binary);
if (!ifs.good())
throw runtime_error("WickPermutation::load on '" + filename +
"' failed.");
self->load(ifs);
if (ifs.fail() || ifs.bad())
throw runtime_error("WickPermutation::load on '" + filename +
"' failed.");
ifs.close();
return *self;
});

py::bind_vector<vector<WickPermutation>>(m, "VectorWickPermutation");
py::bind_map<map<pair<string, int>, vector<WickPermutation>>>(
Expand Down Expand Up @@ -415,7 +463,31 @@ template <typename S = void> void bind_wick(py::module &m) {
.def("get_permutation_rules", &WickTensor::get_permutation_rules)
.def_static("get_index_map", &WickTensor::get_index_map)
.def_static("get_all_index_permutations",
&WickTensor::get_all_index_permutations);
&WickTensor::get_all_index_permutations)
.def("save",
[](WickTensor *self, const string &filename) {
ofstream ofs(filename.c_str(), ios::binary);
if (!ofs.good())
throw runtime_error("WickTensor::save on '" + filename +
"' failed.");
self->save(ofs);
if (!ofs.good())
throw runtime_error("WickTensor::save on '" + filename +
"' failed.");
ofs.close();
})
.def("load", [](WickTensor *self, const string &filename) {
ifstream ifs(filename.c_str(), ios::binary);
if (!ifs.good())
throw runtime_error("WickTensor::load on '" + filename +
"' failed.");
self->load(ifs);
if (ifs.fail() || ifs.bad())
throw runtime_error("WickTensor::load on '" + filename +
"' failed.");
ifs.close();
return *self;
});

py::bind_vector<vector<WickTensor>>(m, "VectorWickTensor");

Expand Down Expand Up @@ -449,10 +521,35 @@ template <typename S = void> void bind_wick(py::module &m) {
.def("simple_sort", &WickString::simple_sort)
.def("quick_sort", &WickString::quick_sort)
.def("simplify_delta", &WickString::simplify_delta)
.def("__repr__", [](WickString *self) {
stringstream ss;
ss << *self;
return ss.str();
.def("__repr__",
[](WickString *self) {
stringstream ss;
ss << *self;
return ss.str();
})
.def("save",
[](WickString *self, const string &filename) {
ofstream ofs(filename.c_str(), ios::binary);
if (!ofs.good())
throw runtime_error("WickString::save on '" + filename +
"' failed.");
self->save(ofs);
if (!ofs.good())
throw runtime_error("WickString::save on '" + filename +
"' failed.");
ofs.close();
})
.def("load", [](WickString *self, const string &filename) {
ifstream ifs(filename.c_str(), ios::binary);
if (!ifs.good())
throw runtime_error("WickString::load on '" + filename +
"' failed.");
self->load(ifs);
if (ifs.fail() || ifs.bad())
throw runtime_error("WickString::load on '" + filename +
"' failed.");
ifs.close();
return *self;
});

py::bind_vector<vector<WickString>>(m, "VectorWickString");
Expand Down Expand Up @@ -505,10 +602,35 @@ template <typename S = void> void bind_wick(py::module &m) {
.def("remove_inactive", &WickExpr::remove_inactive)
.def("add_spin_free_trans_symm", &WickExpr::add_spin_free_trans_symm)
.def("conjugate", &WickExpr::conjugate)
.def("__repr__", [](WickExpr *self) {
stringstream ss;
ss << *self;
return ss.str();
.def("__repr__",
[](WickExpr *self) {
stringstream ss;
ss << *self;
return ss.str();
})
.def("save",
[](WickExpr *self, const string &filename) {
ofstream ofs(filename.c_str(), ios::binary);
if (!ofs.good())
throw runtime_error("WickExpr::save on '" + filename +
"' failed.");
self->save(ofs);
if (!ofs.good())
throw runtime_error("WickExpr::save on '" + filename +
"' failed.");
ofs.close();
})
.def("load", [](WickExpr *self, const string &filename) {
ifstream ifs(filename.c_str(), ios::binary);
if (!ifs.good())
throw runtime_error("WickExpr::load on '" + filename +
"' failed.");
self->load(ifs);
if (ifs.fail() || ifs.bad())
throw runtime_error("WickExpr::load on '" + filename +
"' failed.");
ifs.close();
return *self;
});

py::bind_vector<vector<WickExpr>>(m, "VectorWickExpr");
Expand Down

0 comments on commit 64c7ce8

Please sign in to comment.