Skip to content

Commit

Permalink
optimize mpo disk/sany construct
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Jun 23, 2024
1 parent 75e381d commit 2a0e7ab
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 8 deletions.
7 changes: 7 additions & 0 deletions pyblock2/algebra/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,13 @@ def to_block2(mpo, basis, tag="PYMPO", add_ident=True):
sr = lop[mat.indices[ig][1]].q_label
sm = mat.data[ig].q_label
assert sl + sm == sr
for ii in range(0, bmpo.n_sites):
bmpo.save_tensor(ii)
bmpo.unload_tensor(ii)
bmpo.save_left_operators(ii)
bmpo.unload_left_operators(ii)
bmpo.save_right_operators(ii)
bmpo.unload_right_operators(ii)
bmpo = bs.SimplifiedMPO(bmpo, bs.Rule(), False, False)
if add_ident:
bmpo = bs.IdentityAddedMPO(bmpo)
Expand Down
15 changes: 15 additions & 0 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def __init__(
stack_mem_ratio=0.4,
fp_codec_cutoff=1e-16,
fp_codec_chunk=1024,
min_mpo_mem=False,
compressed_mps_storage=False,
):
"""
Expand Down Expand Up @@ -605,6 +606,8 @@ def __init__(
Default is 1E-16.
fp_codec_chunk : int
Chunk size for compressed storage of renormalized operators. Default is 1024.
min_mpo_mem : bool
If True, will dynamically load/save MPO to save memory. Default is False.
compressed_mps_storage : bool
Whether block-sparse tensor should be stored in compressed form to save storage (mainly for MPS).
Default is False.
Expand All @@ -621,6 +624,7 @@ def __init__(
self.stack_mem_ratio = stack_mem_ratio
self.fp_codec_cutoff = fp_codec_cutoff
self.fp_codec_chunk = fp_codec_chunk
self.min_mpo_mem = min_mpo_mem
self.compressed_mps_storage = compressed_mps_storage
self.symm_type = symm_type
self.clean_scratch = clean_scratch
Expand Down Expand Up @@ -721,6 +725,7 @@ def set_symm_type(self, symm_type, reset_frame=True):
self.frame.minimal_disk_usage = True
self.frame.use_main_stack = False
self.frame.compressed_sparse_tensor_storage = self.compressed_mps_storage
self.frame.minimal_memory_usage = self.min_mpo_mem

if self.mpi:
self.mpi = bw.brs.MPICommunicator()
Expand Down Expand Up @@ -3039,6 +3044,7 @@ def get_qc_mpo(
disjoint_all_blocks=False,
disjoint_multiplier=1.0,
block_max_length=False,
fast_no_orb_dep_op=False,
add_ident=True,
esptein_nesbet_partition=False,
ancilla=False,
Expand Down Expand Up @@ -3193,6 +3199,9 @@ def get_qc_mpo(
``MPOAlgorithmTypes.Bipartite`` appears in ``algo_type``.
If True, will separate the SVD or Bipartite for one- and two-electron integrals.
Default is False.
fast_no_orb_dep_op : bool
If the operator quantum number does not depend on orbital index,
one can set this True to save MPO construction time. Default is False.
add_ident : bool
If True, the hidden identity operator will be added into the MPO.
This is required when ``ecore`` is not zero and ``DMRGDriver.expectation``
Expand Down Expand Up @@ -3551,6 +3560,7 @@ def get_qc_mpo(
disjoint_all_blocks=disjoint_all_blocks,
disjoint_multiplier=disjoint_multiplier,
block_max_length=block_max_length,
fast_no_orb_dep_op=fast_no_orb_dep_op,
add_ident=add_ident,
ancilla=ancilla,
)
Expand All @@ -3571,6 +3581,7 @@ def get_mpo(
disjoint_all_blocks=False,
disjoint_multiplier=1.0,
block_max_length=False,
fast_no_orb_dep_op=False,
add_ident=True,
ancilla=False,
):
Expand Down Expand Up @@ -3626,6 +3637,9 @@ def get_mpo(
``MPOAlgorithmTypes.Bipartite`` appears in ``algo_type``.
If True, will separate the SVD or Bipartite for one- and two-electron integrals.
Default is False.
fast_no_orb_dep_op : bool
If the operator quantum number does not depend on orbital index,
one can set this True to save MPO construction time. Default is False.
add_ident : bool
If True, the hidden identity operator will be added into the MPO.
This is required when ``ecore`` is not zero and ``DMRGDriver.expectation``
Expand Down Expand Up @@ -3663,6 +3677,7 @@ def get_mpo(
mpo.disjoint_all_blocks = disjoint_all_blocks
mpo.disjoint_multiplier = disjoint_multiplier
mpo.block_max_length = block_max_length
mpo.fast_no_orb_dep_op = fast_no_orb_dep_op
mpo.build()

if iprint:
Expand Down
24 changes: 17 additions & 7 deletions src/dmrg/general_mpo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ template <typename S, typename FL> struct GeneralMPO : MPO<S, FL> {
vector<FP> disjoint_levels;
bool disjoint_all_blocks = false;
FP disjoint_multiplier = (FP)1.0;
bool block_max_length = false; // separate 1e/2e terms
bool block_max_length = false; // separate 1e/2e terms
bool fast_no_orb_dep_op = false; // fast mode for no orb_sym case
static inline size_t expr_index_hash(const string &expr,
const uint16_t *terms, int n,
const uint16_t init = 0) noexcept {
Expand Down Expand Up @@ -440,9 +441,14 @@ template <typename S, typename FL> struct GeneralMPO : MPO<S, FL> {
sub_exprs[ix][make_pair(0, term_l[ix])] =
GeneralHamiltonian<S, FL>::get_sub_expr(
afd->exprs[ix], 0, term_l[ix]);
pair<S, S> pq = hamil->get_string_quanta(
quanta_ref[ix], afd->exprs[ix],
&afd->indices[ix][itt], 0);
pair<S, S> pq =
fast_no_orb_dep_op
? make_pair(quanta_ref[ix][0],
quanta_ref[ix].back() -
quanta_ref[ix][0])
: hamil->get_string_quanta(
quanta_ref[ix], afd->exprs[ix],
&afd->indices[ix][itt], 0);
q_map[make_pair(make_pair(0, make_pair(0, 0)),
qh.combine(pq.first, -pq.second))] = 0;
map_ls.emplace_back();
Expand Down Expand Up @@ -502,9 +508,13 @@ template <typename S, typename FL> struct GeneralMPO : MPO<S, FL> {
lstr, afd->indices[ix].data() + itt + ik, k - ik, ip);
size_t hr = expr_index_hash(
rstr, afd->indices[ix].data() + itt + k, kmax - k, 1);
pair<S, S> pq =
hamil->get_string_quanta(quanta_ref[ix], afd->exprs[ix],
&afd->indices[ix][itt], k);
pair<S, S> pq = fast_no_orb_dep_op
? make_pair(quanta_ref[ix][k],
quanta_ref[ix].back() -
quanta_ref[ix][k])
: hamil->get_string_quanta(
quanta_ref[ix], afd->exprs[ix],
&afd->indices[ix][itt], k);
S qq = qh.combine(pq.first, -pq.second);
// possible error here due to unsymmetrized integral
assert(qq != S(S::invalid));
Expand Down
18 changes: 18 additions & 0 deletions src/dmrg/mpo_fusing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ template <typename S, typename FL> struct StackedMPO : MPO<S, FL> {
}
shared_ptr<OperatorTensor<S, FL>> opt =
make_shared<OperatorTensor<S, FL>>();
mpoa->load_tensor(m);
mpoa->load_left_operators(m);
mpoa->load_right_operators(m);
mpob->load_tensor(m);
mpob->load_left_operators(m);
mpob->load_right_operators(m);
const int xm =
mpoa->tensors[m]->lmat->m * mpob->tensors[m]->lmat->m;
const int xn =
Expand Down Expand Up @@ -406,6 +412,18 @@ template <typename S, typename FL> struct StackedMPO : MPO<S, FL> {
if (iprint)
cout << "Mmpo = " << setw(10) << xn << " T = " << fixed
<< setprecision(3) << tsite << endl;
mpoa->unload_tensor(m);
mpoa->unload_left_operators(m);
mpoa->unload_right_operators(m);
mpob->unload_tensor(m);
mpob->unload_left_operators(m);
mpob->unload_right_operators(m);
this->save_tensor(m);
this->unload_tensor(m);
this->save_left_operators(m);
this->unload_left_operators(m);
this->save_right_operators(m);
this->unload_right_operators(m);
}
if (iprint) {
cout << "Ttotal = " << fixed << setprecision(3) << setw(10)
Expand Down
2 changes: 1 addition & 1 deletion src/dmrg/mpo_simplification.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ template <typename S, typename FL> struct SimplifiedMPO : MPO<S, FL> {
OpNamesSet intermediate_ops = OpNamesSet::all_ops(),
const string &tag = "", bool check_indirect_ref = true)
: prim_mpo(mpo), rule(rule),
MPO<S, FL>(mpo->n_sites, tag == "" ? mpo->tag : tag),
MPO<S, FL>(mpo->n_sites, tag == "" ? "SMP-" + mpo->tag : tag),
collect_terms(collect_terms), use_intermediate(use_intermediate),
intermediate_ops(intermediate_ops),
check_indirect_ref(check_indirect_ref) {
Expand Down
2 changes: 2 additions & 0 deletions src/pybind/pybind_dmrg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2122,6 +2122,8 @@ template <typename S, typename FL> void bind_fl_general(py::module &m) {
.def_readwrite("disjoint_multiplier",
&GeneralMPO<S, FL>::disjoint_multiplier)
.def_readwrite("block_max_length", &GeneralMPO<S, FL>::block_max_length)
.def_readwrite("fast_no_orb_dep_op",
&GeneralMPO<S, FL>::fast_no_orb_dep_op)
.def(py::init<const shared_ptr<GeneralHamiltonian<S, FL>> &,
const shared_ptr<GeneralFCIDUMP<FL>> &,
MPOAlgorithmTypes>(),
Expand Down

0 comments on commit 2a0e7ab

Please sign in to comment.