diff --git a/pyblock2/driver/core.py b/pyblock2/driver/core.py index 0326483b..6f04a0b1 100644 --- a/pyblock2/driver/core.py +++ b/pyblock2/driver/core.py @@ -565,6 +565,7 @@ def __init__( fp_codec_cutoff=1e-16, fp_codec_chunk=1024, min_mpo_mem=False, + seq_type=None, compressed_mps_storage=False, ): """ @@ -608,6 +609,8 @@ def __init__( Chunk size for compressed storage of renormalized operators. Default is 1024. min_mpo_mem : bool If True, will dynamically load/save MPO to save memory. Default is False. + seq_type : None or str + Shared-memory scheme type. Default is None ('Tasked'). compressed_mps_storage : bool Whether block-sparse tensor should be stored in compressed form to save storage (mainly for MPS). Default is False. @@ -638,7 +641,11 @@ def __init__( n_threads // n_mkl_threads, n_mkl_threads, ) - bw.b.Global.threading.seq_type = bw.b.SeqTypes.Tasked + if seq_type is None: + seq_type = bw.b.SeqTypes.Tasked + else: + seq_type = getattr(bw.b.SeqTypes, seq_type) + bw.b.Global.threading.seq_type = seq_type self.reorder_idx = None self.pg = "c1" self.orb_sym = None @@ -7146,8 +7153,8 @@ def get_spin_projection_mpo( it = np.ones((1, 1, 1, 1)) pympo = None for ixw, (xt, wt) in enumerate(zip(xts, wts)): - if not mpi_split or ( - mpi_split + if self.mpi is None or not mpi_split or ( + mpi_split and self.mpi.rank == min(ixw, len(wts) - 1 - ixw) % self.mpi.size ): ct = np.cos(xt / 2) * it diff --git a/src/dmrg/moving_environment.hpp b/src/dmrg/moving_environment.hpp index 4da37f9b..9aace3a2 100644 --- a/src/dmrg/moving_environment.hpp +++ b/src/dmrg/moving_environment.hpp @@ -2031,103 +2031,116 @@ template struct MovingEnvironment { frame_()->activate(0); mps->tensors[i] = old_wfn; } - static shared_ptr> - symm_context_convert_group(int i, const shared_ptr> &mps, - const shared_ptr> &cmps, int dot, - bool fuse_left, bool mask, bool forward, bool is_wfn, - bool infer_info, - const shared_ptr> &pket) { - shared_ptr> cpket = nullptr; - symm_context_convert_impl(i, mps, cmps, dot, fuse_left, mask, forward, - is_wfn, infer_info, pket, cpket); - return cpket; - } static shared_ptr> symm_context_convert(int i, const shared_ptr> &mps, const shared_ptr> &cmps, int dot, bool fuse_left, bool mask, bool forward, bool is_wfn, - bool infer_info) { - shared_ptr> cpket = nullptr; - return symm_context_convert_impl(i, mps, cmps, dot, fuse_left, mask, - forward, is_wfn, infer_info, nullptr, - cpket); + bool infer_info, + shared_ptr> ket = nullptr, + shared_ptr> cket = nullptr) { + return symm_context_convert_impl( + i, mps->info, cmps->info, dot, fuse_left, mask, forward, + is_wfn, infer_info, + ket == nullptr && !(!forward && infer_info) ? mps->tensors[i] + : ket, + cket == nullptr && !(forward && infer_info) + ? cmps->tensors[i] + : cket, + nullptr, nullptr) + .first; + } + static shared_ptr> symm_context_convert_group( + int i, const shared_ptr> &mps, + const shared_ptr> &cmps, int dot, bool fuse_left, bool mask, + bool forward, bool is_wfn, bool infer_info, + const shared_ptr> &pket) { + return symm_context_convert_impl(i, mps->info, cmps->info, dot, + fuse_left, mask, forward, is_wfn, + infer_info, mps->tensors[i], + cmps->tensors[i], pket, nullptr) + .second; } // forward = proj to high symmetry - static shared_ptr> - symm_context_convert_impl(int i, const shared_ptr> &mps, - const shared_ptr> &cmps, int dot, + static pair>, + shared_ptr>> + symm_context_convert_impl(int i, const shared_ptr> &info, + const shared_ptr> &cinfo, int dot, bool fuse_left, bool mask, bool forward, bool is_wfn, bool infer_info, - const shared_ptr> &pket, - shared_ptr> &cpket) { + shared_ptr> ket, + shared_ptr> cket, + shared_ptr> pket, + shared_ptr> cpket) { if (is_wfn || fuse_left) - mps->info->load_left_dims(i), cmps->info->load_left_dims(i); + info->load_left_dims(i), cinfo->load_left_dims(i); else - mps->info->load_right_dims(i), cmps->info->load_right_dims(i); + info->load_right_dims(i), cinfo->load_right_dims(i); if (is_wfn || !fuse_left) - mps->info->load_right_dims(i + dot), - cmps->info->load_right_dims(i + dot); + info->load_right_dims(i + dot), cinfo->load_right_dims(i + dot); else - mps->info->load_left_dims(i + dot), - cmps->info->load_left_dims(i + dot); - StateInfo l = is_wfn || fuse_left ? *mps->info->left_dims[i] - : *mps->info->right_dims[i], - ml = *mps->info->basis[i], - mr = *mps->info->basis[i + dot - 1], - r = is_wfn || !fuse_left ? *mps->info->right_dims[i + dot] - : *mps->info->left_dims[i + dot]; + info->load_left_dims(i + dot), cinfo->load_left_dims(i + dot); + StateInfo l = is_wfn || fuse_left ? *info->left_dims[i] + : *info->right_dims[i], + ml = dot == 0 ? StateInfo() : *info->basis[i], + mr = dot == 0 ? StateInfo() : *info->basis[i + dot - 1], + r = is_wfn || !fuse_left ? *info->right_dims[i + dot] + : *info->left_dims[i + dot]; shared_ptr> ll = - dot == 2 || fuse_left + dot == 2 || (dot != 0 && fuse_left) ? make_shared>(StateInfo::tensor_product( - l, ml, *mps->info->left_dims_fci[i + 1])) + l, ml, *info->left_dims_fci[i + 1])) : make_shared>(l); shared_ptr::ConnectionInfo> clm = - dot == 2 || fuse_left + dot == 2 || (dot != 0 && fuse_left) ? StateInfo::get_connection_info(l, ml, *ll) : nullptr; shared_ptr> rr = - dot == 2 || !fuse_left + dot == 2 || (dot != 0 && !fuse_left) ? make_shared>(StateInfo::tensor_product( - mr, r, *mps->info->right_dims_fci[i + dot - 1])) + mr, r, *info->right_dims_fci[i + dot - 1])) : make_shared>(r); shared_ptr::ConnectionInfo> cmr = - dot == 2 || !fuse_left + dot == 2 || (dot != 0 && !fuse_left) ? StateInfo::get_connection_info(mr, r, *rr) : nullptr; - StateInfo lu = is_wfn || fuse_left ? *cmps->info->left_dims[i] - : *cmps->info->right_dims[i], - mlu = *cmps->info->basis[i], - mru = *cmps->info->basis[i + dot - 1], - ru = is_wfn || !fuse_left - ? *cmps->info->right_dims[i + dot] - : *cmps->info->left_dims[i + dot]; + StateInfo lu = is_wfn || fuse_left ? *cinfo->left_dims[i] + : *cinfo->right_dims[i], + mlu = dot == 0 ? StateInfo() : *cinfo->basis[i], + mru = + dot == 0 ? StateInfo() : *cinfo->basis[i + dot - 1], + ru = is_wfn || !fuse_left ? *cinfo->right_dims[i + dot] + : *cinfo->left_dims[i + dot]; shared_ptr> llu = - dot == 2 || fuse_left + dot == 2 || (dot != 0 && fuse_left) ? make_shared>(StateInfo::tensor_product( - lu, mlu, *cmps->info->left_dims_fci[i + 1])) + lu, mlu, *cinfo->left_dims_fci[i + 1])) : make_shared>(lu); shared_ptr::ConnectionInfo> clmu = - dot == 2 || fuse_left + dot == 2 || (dot != 0 && fuse_left) ? StateInfo::get_connection_info(lu, mlu, *llu) : nullptr; shared_ptr> rru = - dot == 2 || !fuse_left + dot == 2 || (dot != 0 && !fuse_left) ? make_shared>(StateInfo::tensor_product( - mru, ru, *cmps->info->right_dims_fci[i + dot - 1])) + mru, ru, *cinfo->right_dims_fci[i + dot - 1])) : make_shared>(ru); shared_ptr::ConnectionInfo> cmru = - dot == 2 || !fuse_left + dot == 2 || (dot != 0 && !fuse_left) ? StateInfo::get_connection_info(mru, ru, *rru) : nullptr; shared_ptr> d_alloc = make_shared>(); shared_ptr> i_alloc = make_shared>(); - shared_ptr> mask_wfn = - make_shared>(d_alloc); - S ref = mps->info->vacuum; - S refu = cmps->info->vacuum; - if (pket != nullptr) { + const S ref = info->vacuum, refu = cinfo->vacuum; + const bool is_group = pket != nullptr || cpket != nullptr; + shared_ptr> r_wfn = + is_group ? nullptr : make_shared>(d_alloc); + shared_ptr> gr_wfn = + is_group ? make_shared>(d_alloc) + : nullptr; + if (is_group && infer_info) { + // FIXME: multi will have problem vector pket_dqs; for (int iw = 0; iw < pket->n; iw++) { S dq = pket->infos[iw]->delta_quantum; @@ -2147,16 +2160,15 @@ template struct MovingEnvironment { info->initialize(*llu, *rru, pket_dqs[j], false, true); infos.push_back(info); } - cpket = make_shared>(d_alloc); - cpket->allocate(infos); - cpket->clear(); - } else if (infer_info) { + gr_wfn->allocate(infos); + gr_wfn->clear(); + } else if (!is_group && infer_info) { shared_ptr> xinfo = make_shared>(i_alloc); shared_ptr> xll = forward ? llu : ll; shared_ptr> xrr = forward ? rru : rr; - S xdq = is_wfn ? (forward ? cmps->info->target : mps->info->target) - : (forward ? cmps->info->vacuum : mps->info->vacuum); + S xdq = is_wfn ? (forward ? cinfo->target : info->target) + : (forward ? cinfo->vacuum : info->vacuum); if (fuse_left) { xrr = forward ? TransStateInfo::forward(rr, refu) : TransStateInfo::forward(rru, ref); @@ -2173,13 +2185,16 @@ template struct MovingEnvironment { ll = xll, l = *xll; } xinfo->initialize(*xll, *xrr, xdq, false, is_wfn); - mask_wfn->allocate(xinfo); - } else - mask_wfn->allocate(forward ? cmps->tensors[i]->info - : mps->tensors[i]->info); - if (pket == nullptr) - mask_wfn->clear(); - S cptu = cmps->info->target, cpt = mps->info->target; + r_wfn->allocate(xinfo); + r_wfn->clear(); + } else if (is_group) { + gr_wfn->allocate(forward ? cpket->infos : pket->infos); + gr_wfn->clear(); + } else { + r_wfn->allocate(forward ? cket->info : ket->info); + r_wfn->clear(); + } + S cptu = cinfo->target, cpt = info->target; shared_ptr> cplu = is_wfn || fuse_left ? make_shared>(lu) : make_shared>( @@ -2199,24 +2214,27 @@ template struct MovingEnvironment { shared_ptr> conn_l = TransStateInfo::backward_connection(cplu, cpl); shared_ptr> conn_lm = - dot == 2 || fuse_left ? TransStateInfo::backward_connection( - make_shared>(mlu), - make_shared>(ml)) - : nullptr; + dot == 2 || (dot != 0 && fuse_left) + ? TransStateInfo::backward_connection( + make_shared>(mlu), + make_shared>(ml)) + : nullptr; shared_ptr> conn_mr = - dot == 2 || !fuse_left ? TransStateInfo::backward_connection( - make_shared>(mru), - make_shared>(mr)) - : nullptr; + dot == 2 || (dot != 0 && !fuse_left) + ? TransStateInfo::backward_connection( + make_shared>(mru), + make_shared>(mr)) + : nullptr; shared_ptr> conn_r = TransStateInfo::backward_connection(cpru, cpr); + map, pair> mp0; map, pair> mp; map, pair> mp2; - int nxw = pket == nullptr ? 1 : pket->n; + int nxw = is_group ? (forward ? pket->n : gr_wfn->n) : 1; for (int iw = 0; iw < nxw; iw++) { shared_ptr> xwfn = - forward ? (pket == nullptr ? mps->tensors[i] : (*pket)[iw]) - : mask_wfn; + forward ? (is_group ? (*pket)[iw] : ket) + : (is_group ? (*gr_wfn)[iw] : r_wfn); for (int k = 0; k < xwfn->info->n; k++) { S pln = xwfn->info->quanta[k].get_bra(xwfn->info->delta_quantum); @@ -2258,7 +2276,9 @@ template struct MovingEnvironment { } assert(p - xwfn->info->n_states_total[k] == xwfn->info->n_states_ket[k]); - } else if (dot == 2) { + } else if (dot == 0) + mp0[array{pln, prn}] = make_pair(xwfn->data, 0); + else if (dot == 2) { int ib = ll->find_state(pln), ik = rr->find_state(prn); int bbed = clm->acc_n_states[ib + 1], kked = cmr->acc_n_states[ik + 1]; @@ -2291,10 +2311,11 @@ template struct MovingEnvironment { assert(false); } } - nxw = pket == nullptr ? 1 : cpket->n; + nxw = is_group ? (forward ? gr_wfn->n : cpket->n) : 1; for (int iw = 0; iw < nxw; iw++) { shared_ptr> cwfn = - pket == nullptr ? cmps->tensors[i] : (*cpket)[iw]; + forward ? (is_group ? (*gr_wfn)[iw] : r_wfn) + : (is_group ? (*cpket)[iw] : cket); for (int k = 0; k < cwfn->info->n; k++) { S plu = cwfn->info->quanta[k].get_bra(cwfn->info->delta_quantum); @@ -2313,9 +2334,7 @@ template struct MovingEnvironment { cwfn->info->n_states_ket[k]; S pplu = lu.quanta[ibbau], ppmu = mlu.quanta[ibbbu], ppru = pru; - FLS *x = forward && pket == nullptr - ? mask_wfn->data + pu - : cwfn->data + pu; + FLS *x = cwfn->data + pu; pu += lpu; shared_ptr> mls = TransStateInfo::forward( @@ -2440,9 +2459,7 @@ template struct MovingEnvironment { (size_t)mru.n_states[ikkau] * ru.n_states[ikkbu]; S pplu = plu, ppmu = mru.quanta[ikkau], ppru = ru.quanta[ikkbu]; - FLS *x = forward && pket == nullptr - ? mask_wfn->data + pu - : cwfn->data + pu; + FLS *x = cwfn->data + pu; size_t xstr = cwfn->info->n_states_ket[k]; pu += lpu; shared_ptr> mls = @@ -2556,6 +2573,71 @@ template struct MovingEnvironment { } assert(pu - cwfn->info->n_states_total[k] == cwfn->info->n_states_ket[k]); + } else if (dot == 0) { + S pplu = plu, ppru = pru; + FLS *x = cwfn->data; + shared_ptr> mls = + TransStateInfo::forward( + make_shared>(pplu), ref); + shared_ptr> mrs = + TransStateInfo::forward( + make_shared>(ppru), ref); + S xpplu = is_wfn || fuse_left ? pplu : cptu - pplu; + S xppru = is_wfn || !fuse_left ? cptu - ppru : ppru; + for (int iln = 0; iln < mls->n; iln++) + for (int irn = 0; irn < mrs->n; irn++) { + S lqn = mls->quanta[iln], rqn = mrs->quanta[irn]; + if (!mp0.count(array{lqn, rqn})) + continue; + FLS *xr = mp0.at(array{lqn, rqn}).first; + lqn = is_wfn || fuse_left ? lqn : cpt - lqn; + rqn = is_wfn || !fuse_left ? cpt - rqn : rqn; + int il = cpl->find_state(lqn); + int ir = cpr->find_state(rqn); + MKL_INT zl = cpl->n_states[il], + zr = cpr->n_states[ir]; + int klst = conn_l->n_states[il]; + int krst = conn_r->n_states[ir]; + int kled = il == cpl->n - 1 + ? conn_l->n + : conn_l->n_states[il + 1]; + int kred = ir == cpr->n - 1 + ? conn_r->n + : conn_r->n_states[ir + 1]; + size_t lsh = 0, rsh = 0; + for (int ilp = klst; + ilp < kled && conn_l->quanta[ilp] != xpplu; + ilp++) + lsh += cplu->n_states[cplu->find_state( + conn_l->quanta[ilp])]; + for (int irp = krst; + irp < kred && conn_r->quanta[irp] != xppru; + irp++) + rsh += cpru->n_states[cpru->find_state( + conn_r->quanta[irp])]; + MKL_INT kl = + (MKL_INT) + cplu->n_states[cplu->find_state(xpplu)]; + MKL_INT kr = + (MKL_INT) + cpru->n_states[cpru->find_state(xppru)]; + if (mask) { + for (MKL_INT ikl = 0; ikl < kl; ikl++) + for (MKL_INT ikr = 0; ikr < kr; ikr++) + xr[(ikl + lsh) * zr + (ikr + rsh)] = + 1.0; + } else if (forward) { + for (MKL_INT ikl = 0; ikl < kl; ikl++) + for (MKL_INT ikr = 0; ikr < kr; ikr++) + x[(size_t)ikl * kr + (size_t)ikr] += + xr[(ikl + lsh) * zr + (ikr + rsh)]; + } else { + for (MKL_INT ikl = 0; ikl < kl; ikl++) + for (MKL_INT ikr = 0; ikr < kr; ikr++) + xr[(ikl + lsh) * zr + (ikr + rsh)] = + x[(size_t)ikl * kr + (size_t)ikr]; + } + } } else if (dot == 2) { int ibu = llu->find_state(plu), iku = rru->find_state(pru); int bbedu = clmu->acc_n_states[ibu + 1], @@ -2579,9 +2661,7 @@ template struct MovingEnvironment { ppru = ru.quanta[ikkbu]; size_t npmru = mru.n_states[ikkau], npru = ru.n_states[ikkbu]; - FLS *x = forward && pket == nullptr - ? mask_wfn->data + pu + iplu + ipru - : cwfn->data + pu + iplu + ipru; + FLS *x = cwfn->data + pu + iplu + ipru; size_t xstr = cwfn->info->n_states_ket[k]; ipru += (size_t)npmru * npru; S xppru = is_wfn || !fuse_left ? cptu - ppru : ppru; @@ -2800,7 +2880,7 @@ template struct MovingEnvironment { } } } - return mask_wfn; + return make_pair(r_wfn, gr_wfn); } // Contract two adjcent MPS tensors to one two-site MPS tensor static void contract_two_dot(int i, const shared_ptr> &mps, diff --git a/src/dmrg/sweep_algorithm.hpp b/src/dmrg/sweep_algorithm.hpp index 5f653d5f..331adce8 100644 --- a/src/dmrg/sweep_algorithm.hpp +++ b/src/dmrg/sweep_algorithm.hpp @@ -187,6 +187,8 @@ template struct DMRG { // state specific if (ext_mpss.size() != 0) mpss.insert(mpss.end(), ext_mpss.begin(), ext_mpss.end()); + if (context_ket != nullptr) + mpss.insert(mpss.begin(), context_ket); for (auto &mps : mpss) { if (mps->canonical_form[i] == 'C') { if (i == 0) @@ -218,10 +220,16 @@ template struct DMRG { prev_wfn->deallocate(); } } + if (context_ket != nullptr) + me->ket->tensors[i] = + MovingEnvironment::symm_context_convert( + i, me->ket, context_ket, 1, fuse_left, false, false, true, + false); int mmps = 0; FPS error = 0.0; tuple pdi; - shared_ptr> pket = nullptr; + shared_ptr> pket = nullptr, + context_pket = nullptr; shared_ptr> pdm = nullptr; bool skip_decomp = !decomp_last_site && @@ -241,7 +249,25 @@ template struct DMRG { !skip_decomp) { // change to fused form for splitting if (fuse_left != forward) { - shared_ptr> prev_wfn = me->ket->tensors[i]; + shared_ptr> prev_wfn = nullptr; + if (context_ket != nullptr) { + prev_wfn = context_ket->tensors[i]; + if (!fuse_left && forward) + context_ket->tensors[i] = + MovingEnvironment:: + swap_wfn_to_fused_left(i, context_ket->info, + prev_wfn, + me->mpo->tf->opf->cg); + else if (fuse_left && !forward) + context_ket->tensors[i] = + MovingEnvironment:: + swap_wfn_to_fused_right(i, context_ket->info, + prev_wfn, + me->mpo->tf->opf->cg); + prev_wfn->info->deallocate(); + prev_wfn->deallocate(); + } + prev_wfn = me->ket->tensors[i]; if (!fuse_left && forward) me->ket->tensors[i] = MovingEnvironment::swap_wfn_to_fused_left( @@ -283,6 +309,24 @@ template struct DMRG { } } } + shared_ptr> xpket = pket; + shared_ptr> xket = me->ket; + if (context_ket != nullptr) { + context_ket->tensors[i] = + MovingEnvironment::symm_context_convert( + i, me->ket, context_ket, 1, + !skip_decomp ? forward : fuse_left, false, true, true, + false); + xket = context_ket; + if (pket != nullptr) { + context_pket = + MovingEnvironment::symm_context_convert_group( + i, me->ket, context_ket, 1, + !skip_decomp ? forward : fuse_left, false, true, true, + true, pket); + xpket = context_pket; + } + } // state specific for (auto &mps : ext_mpss) { if (mps->info->bond_dim < bond_dim) @@ -322,16 +366,21 @@ template struct DMRG { _t.get_time(); assert(decomp_type == DecompositionTypes::DensityMatrix); pdm = MovingEnvironment::density_matrix( - me->ket->info->vacuum, me->ket->tensors[i], forward, + xket->info->vacuum, xket->tensors[i], forward, me->para_rule != nullptr ? noise / me->para_rule->comm->size : noise, - noise_type, 0.0, pket); + noise_type, 0.0, xpket); if (me->para_rule != nullptr) me->para_rule->comm->reduce_sum(pdm, me->para_rule->comm->root); tdm += _t.get_time(); } if (me->para_rule == nullptr || me->para_rule->is_root()) { if (skip_decomp) { + if (context_ket != nullptr) { + context_ket->save_tensor(i); + context_ket->unload_tensor(i); + context_ket->canonical_form[i] = forward ? 'S' : 'K'; + } me->ket->save_tensor(i); me->ket->unload_tensor(i); me->ket->canonical_form[i] = forward ? 'S' : 'K'; @@ -343,17 +392,21 @@ template struct DMRG { } else { // splitting of wavefunction shared_ptr> old_ket = me->ket->tensors[i], - old_bra = nullptr; + old_bra = nullptr, + old_context_ket = nullptr; if (me->bra != me->ket) old_bra = me->bra->tensors[i]; + if (context_ket != nullptr) + old_context_ket = context_ket->tensors[i]; shared_ptr> dm, dm_b, left_k, right_k, + context_left_k = nullptr, context_right_k = nullptr, left_b = nullptr, right_b = nullptr; if (decomp_type == DecompositionTypes::DensityMatrix) { _t.get_time(); const FPS factor = me->bra != me->ket ? (FPS)0.5 : (FPS)1.0; dm = MovingEnvironment::density_matrix( - me->ket->info->vacuum, me->ket->tensors[i], forward, - build_pdm ? 0.0 : noise, noise_type, factor, pket); + xket->info->vacuum, xket->tensors[i], forward, + build_pdm ? 0.0 : noise, noise_type, factor, xpket); if (me->bra != me->ket) MovingEnvironment::density_matrix_add_wfn( dm, me->bra->tensors[i], forward, 0.5); @@ -369,7 +422,7 @@ template struct DMRG { dm_b = dm->deep_copy(make_shared>()); error = MovingEnvironment::split_density_matrix( - dm, me->ket->tensors[i], (int)bond_dim, forward, true, + dm, xket->tensors[i], (int)bond_dim, forward, true, left_k, right_k, cutoff, store_wfn_spectra, wfn_spectra, trunc_type); // TODO: this may have some problem if small numerical @@ -392,31 +445,30 @@ template struct DMRG { if (noise != 0) { if (noise_type & NoiseTypes::Wavefunction) MovingEnvironment:: - wavefunction_add_noise(me->ket->tensors[i], - noise); + wavefunction_add_noise(xket->tensors[i], noise); else if (noise_type & NoiseTypes::Perturbative) MovingEnvironment:: scale_perturbative_noise(noise, noise_type, - pket); + xpket); } _t.get_time(); if (me->bra == me->ket) error = MovingEnvironment:: split_wavefunction_svd( - me->ket->info->vacuum, me->ket->tensors[i], + xket->info->vacuum, xket->tensors[i], (int)bond_dim, forward, true, left_k, right_k, cutoff, store_wfn_spectra, wfn_spectra, - trunc_type, decomp_type, pket); + trunc_type, decomp_type, xpket); else { vector weights = {(FPS)0.5, (FPS)0.5}; vector>> xwfns = { me->bra->tensors[i]}; error = MovingEnvironment:: split_wavefunction_svd( - me->ket->info->vacuum, me->ket->tensors[i], + xket->info->vacuum, xket->tensors[i], (int)bond_dim, forward, true, left_k, right_k, cutoff, store_wfn_spectra, wfn_spectra, - trunc_type, decomp_type, pket, xwfns, weights); + trunc_type, decomp_type, xpket, xwfns, weights); xwfns[0] = me->ket->tensors[i]; error += MovingEnvironment:: split_wavefunction_svd( @@ -429,8 +481,89 @@ template struct DMRG { tsvd += _t.get_time(); } else assert(false); + if (context_ket != nullptr) { + if (forward) { + context_ket->info->left_dims[i + 1] = + left_k->info->extract_state_info(forward); + context_ket->info->save_left_dims(i + 1); + } else { + context_ket->info->right_dims[i] = + right_k->info->extract_state_info(forward); + ; + context_ket->info->save_right_dims(i); + } + context_left_k = left_k, context_right_k = right_k; + left_k = + MovingEnvironment::symm_context_convert( + forward ? i : i, me->ket, context_ket, + forward ? 1 : 0, true, false, false, !forward, true, + nullptr, context_left_k); + right_k = + MovingEnvironment::symm_context_convert( + forward ? i + 1 : i, me->ket, context_ket, + forward ? 0 : 1, false, false, false, forward, true, + nullptr, context_right_k); + } shared_ptr> info = nullptr; + int context_mmps = 0; // propagation + if (context_ket != nullptr) { + if (forward) { + context_ket->tensors[i] = context_left_k; + context_ket->save_tensor(i); + info = + context_left_k->info->extract_state_info(forward); + context_mmps = (int)info->n_states_total; + context_ket->info->bond_dim = max( + context_ket->info->bond_dim, (ubond_t)context_mmps); + context_ket->info->left_dims[i + 1] = info; + context_ket->info->save_left_dims(i + 1); + info->deallocate(); + if (i != sweep_end_site - 1) { + MovingEnvironment::contract_one_dot( + i + 1, context_right_k, context_ket, forward); + context_ket->save_tensor(i + 1); + context_ket->unload_tensor(i + 1); + context_ket->canonical_form[i] = 'L'; + context_ket->canonical_form[i + 1] = 'S'; + } else { + context_ket->tensors[i] = + make_shared>(); + MovingEnvironment::contract_one_dot( + i, context_right_k, context_ket, !forward); + context_ket->save_tensor(i); + context_ket->unload_tensor(i); + context_ket->canonical_form[i] = 'K'; + } + } else { + context_ket->tensors[i] = context_right_k; + context_ket->save_tensor(i); + info = + context_right_k->info->extract_state_info(forward); + context_mmps = (int)info->n_states_total; + context_ket->info->bond_dim = max( + context_ket->info->bond_dim, (ubond_t)context_mmps); + context_ket->info->right_dims[i] = info; + context_ket->info->save_right_dims(i); + info->deallocate(); + if (i > sweep_start_site) { + MovingEnvironment::contract_one_dot( + i - 1, context_left_k, context_ket, forward); + context_ket->save_tensor(i - 1); + context_ket->unload_tensor(i - 1); + context_ket->canonical_form[i - 1] = 'K'; + context_ket->canonical_form[i] = 'R'; + } else { + context_ket->tensors[i] = + make_shared>(); + MovingEnvironment::contract_one_dot( + i, context_left_k, context_ket, !forward); + context_ket->save_tensor(i); + context_ket->unload_tensor(i); + context_ket->canonical_form[i] = 'S'; + } + } + } mpss = {me->ket}; if (me->bra != me->ket) mpss.insert(mpss.begin(), me->bra); @@ -494,6 +627,8 @@ template struct DMRG { } } } + if (context_ket != nullptr) + mmps = context_mmps; if (right_b != nullptr) { right_b->info->deallocate(); right_b->deallocate(); @@ -504,6 +639,12 @@ template struct DMRG { right_k->deallocate(); left_k->info->deallocate(); left_k->deallocate(); + if (context_ket != nullptr) { + context_right_k->info->deallocate(); + context_right_k->deallocate(); + context_left_k->info->deallocate(); + context_left_k->deallocate(); + } if (dm != nullptr) { dm->info->deallocate(); dm->deallocate(); @@ -518,10 +659,16 @@ template struct DMRG { } old_ket->info->deallocate(); old_ket->deallocate(); + if (old_context_ket != nullptr) { + old_context_ket->info->deallocate(); + old_context_ket->deallocate(); + } } me->ket->save_data(); if (me->bra != me->ket) me->bra->save_data(); + if (context_ket != nullptr) + context_ket->save_data(); } else { if (pdm != nullptr) { pdm->info->deallocate(); @@ -530,6 +677,8 @@ template struct DMRG { mpss = {me->ket}; if (me->bra != me->ket) mpss.insert(mpss.begin(), me->bra); + if (context_ket != nullptr) + mpss.insert(mpss.begin(), context_ket); for (auto &mps : mpss) { mps->unload_tensor(i); if (skip_decomp) @@ -602,6 +751,11 @@ template struct DMRG { fuse_left ? FuseTypes::FuseL : FuseTypes::FuseR, forward, true, me->bra->tensors[i], me->ket->tensors[i]); h_eff->eff_kernel = eff_kernel; + if (context_ket != nullptr) + h_eff->context_mask = + MovingEnvironment::symm_context_convert( + i, me->ket, context_ket, 1, fuse_left, true, false, true, + false); if (store_seq_data) { stringstream ss; ss << frame_()->save_dir << "/" << frame_()->prefix_distri @@ -710,7 +864,7 @@ template struct DMRG { context_pket = MovingEnvironment::symm_context_convert_group( i, me->ket, context_ket, 2, true, false, true, true, - false, pket); + true, pket); xpket = context_pket; } } @@ -853,14 +1007,15 @@ template struct DMRG { forward, true); } shared_ptr> info = nullptr; + int context_mmps = 0; // propagation if (context_ket != nullptr) { if (forward) { info = context_ket->tensors[i]->info->extract_state_info( forward); - mmps = (int)info->n_states_total; + context_mmps = (int)info->n_states_total; context_ket->info->bond_dim = - max(context_ket->info->bond_dim, (ubond_t)mmps); + max(context_ket->info->bond_dim, (ubond_t)context_mmps); context_ket->info->left_dims[i + 1] = info; context_ket->info->save_left_dims(i + 1); context_ket->canonical_form[i] = 'L'; @@ -869,9 +1024,9 @@ template struct DMRG { info = context_ket->tensors[i + 1]->info->extract_state_info( forward); - mmps = (int)info->n_states_total; + context_mmps = (int)info->n_states_total; context_ket->info->bond_dim = - max(context_ket->info->bond_dim, (ubond_t)mmps); + max(context_ket->info->bond_dim, (ubond_t)context_mmps); context_ket->info->right_dims[i + 1] = info; context_ket->info->save_right_dims(i + 1); context_ket->canonical_form[i] = 'C'; @@ -913,6 +1068,8 @@ template struct DMRG { mps->unload_tensor(i + 1); mps->unload_tensor(i); } + if (context_ket != nullptr) + mmps = context_mmps; if (dm != nullptr) { dm->info->deallocate(); dm->deallocate(); diff --git a/src/pybind/pybind_dmrg.hpp b/src/pybind/pybind_dmrg.hpp index 45edbe83..352a07e8 100644 --- a/src/pybind/pybind_dmrg.hpp +++ b/src/pybind/pybind_dmrg.hpp @@ -1034,7 +1034,8 @@ void bind_fl_moving_environment(py::module &m, const string &name) { py::arg("i"), py::arg("mps"), py::arg("cmps"), py::arg("dot"), py::arg("fuse_left"), py::arg("mask"), py::arg("forward"), py::arg("is_wfn"), - py::arg("infer_info")) + py::arg("infer_info"), py::arg("ket") = nullptr, + py::arg("cket") = nullptr) .def_static("symm_context_convert_group", &MovingEnvironment::symm_context_convert_group, py::arg("i"), py::arg("mps"), py::arg("cmps"),