Skip to content

Commit

Permalink
Add functionality of MPO with block diagonal form (#154)
Browse files Browse the repository at this point in the history
* stack mpo

* add test

* add comments

* update

---------

Co-authored-by: JiaceSun <jsun3@caltech.edu>
  • Loading branch information
SUSYUSTC and JiaceSun authored Aug 8, 2023
1 parent 9706a29 commit 9583324
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 29 deletions.
2 changes: 1 addition & 1 deletion renormalizer/mps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from renormalizer.mps.backend import backend
from renormalizer.mps.mpo import Mpo
from renormalizer.mps.mpo import Mpo, StackedMpo
from renormalizer.mps.mps import Mps, BraKetPair
from renormalizer.mps.mpdm import MpDm
from renormalizer.mps.thermalprop import ThermalProp, load_thermal_state
Expand Down
99 changes: 78 additions & 21 deletions renormalizer/mps/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from renormalizer.mps.matrix import multi_tensor_contract, tensordot, asnumpy, asxp
from renormalizer.mps.hop_expr import hop_expr
from renormalizer.mps.svd_qn import get_qn_mask
from renormalizer.mps import Mpo, Mps
from renormalizer.mps import Mpo, Mps, StackedMpo
from renormalizer.mps.lib import Environ, cvec2cmat
from renormalizer.utils import Quantity, CompressConfig, CompressCriteria

Expand All @@ -45,14 +45,14 @@ def construct_mps_mpo(model, mmax, nexciton, offset=Quantity(0)):
return mps, mpo


def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
def optimize_mps(mps: Mps, mpo: Union[Mpo, StackedMpo], omega: float = None) -> Tuple[List, Mps]:
r"""DMRG ground state algorithm and state-averaged excited states algorithm
Parameters
----------
mps : renormalizer.mps.Mps
initial guess of mps. The MPS is overwritten during the optimization.
mpo : renormalizer.mps.Mpo
mpo : Union[renormalizer.mps.Mpo, renormalizer.mps.StackedMpo]
mpo of Hamiltonian
omega: float, optional
target the eigenpair near omega with special variational function
Expand All @@ -67,7 +67,7 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
mps : renormalizer.mps.Mps
optimized ground state MPS.
Note it's not the same with the overwritten input MPS.
See Also
--------
renormalizer.utils.configs.OptimizeConfig : The optimization configuration.
Expand Down Expand Up @@ -95,14 +95,19 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
env = "L"

compress_config_bk = mps.compress_config

# construct the environment matrix
if omega is not None:
if isinstance(mpo, StackedMpo):
raise NotImplementedError("StackedMPO + omega is not implemented yet")
identity = Mpo.identity(mpo.model)
mpo = mpo.add(identity.scale(-omega))
environ = Environ(mps, [mpo, mpo], env)
else:
environ = Environ(mps, mpo, env)
if isinstance(mpo, StackedMpo):
environ = [Environ(mps, item, env) for item in mpo.mpos]
else:
environ = Environ(mps, mpo, env)

macro_iteration_result = []
# Idx of the active site with lowest energy for each sweep
Expand All @@ -111,7 +116,7 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
res_mps: Union[Mps, List[Mps]] = None
for isweep, (compress_config, percent) in enumerate(mps.optimize_config.procedure):
logger.debug(f"isweep: {isweep}")

if isinstance(compress_config, CompressConfig):
mps.compress_config = compress_config
elif isinstance(compress_config, int):
Expand Down Expand Up @@ -156,19 +161,19 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
for res in res_mps:
res.compress_config = compress_config_bk
logger.info(f"{res_mps[0]}")

return macro_iteration_result, res_mps


def single_sweep(
mps: Mps,
mpo: Mpo,
mpo: Union[Mpo, StackedMpo],
environ: Environ,
omega: float,
percent: float,
last_opt_e_idx: int
):

method = mps.optimize_config.method
nroots = mps.optimize_config.nroots

Expand Down Expand Up @@ -210,18 +215,26 @@ def single_sweep(
if omega is None:
operator = mpo
else:
assert isinstance(mpo, Mpo)
operator = [mpo, mpo]

ltensor = environ.GetLR("L", lidx, mps, operator, itensor=None, method=lmethod)
rtensor = environ.GetLR("R", ridx, mps, operator, itensor=None, method=rmethod)
if isinstance(mpo, StackedMpo):
ltensor = [environ_item.GetLR("L", lidx, mps, operator_item, itensor=None, method=lmethod) for environ_item, operator_item in zip(environ, operator.mpos)]
rtensor = [environ_item.GetLR("R", ridx, mps, operator_item, itensor=None, method=rmethod) for environ_item, operator_item in zip(environ, operator.mpos)]
else:
ltensor = environ.GetLR("L", lidx, mps, operator, itensor=None, method=lmethod)
rtensor = environ.GetLR("R", ridx, mps, operator, itensor=None, method=rmethod)

# get the quantum number pattern
qnbigl, qnbigr, qnmat = mps._get_big_qn(cidx)
qn_mask = get_qn_mask(qnmat, mps.qntot)
cshape = qn_mask.shape

# center mo
cmo = [asxp(mpo[idx]) for idx in cidx]
if isinstance(mpo, StackedMpo):
cmo = [[asxp(mpo_item[idx]) for idx in cidx] for mpo_item in mpo.mpos]
else:
cmo = [asxp(mpo[idx]) for idx in cidx]

use_direct_eigh = np.prod(cshape) < 1000 or mps.optimize_config.algo == "direct"
if use_direct_eigh:
Expand Down Expand Up @@ -285,15 +298,15 @@ def single_sweep(
return micro_iteration_result, res_mps, mpo


def eigh_direct(
def get_ham_direct(
mps: Mps,
qn_mask: np.ndarray,
ltensor: xp.ndarray,
rtensor: xp.ndarray,
ltensor: Union[xp.ndarray, List[xp.ndarray]],
rtensor: Union[xp.ndarray, List[xp.ndarray]],
cmo: List[xp.ndarray],
omega: float,
):
logger.debug(f"use direct eigensolver")
logger.debug("use direct eigensolver")

# direct algorithm
if omega is None:
Expand Down Expand Up @@ -347,6 +360,23 @@ def eigh_direct(
)
ham = ham[:, :, :, :, qn_mask][qn_mask, :]

return ham


def eigh_direct(
mps: Mps,
qn_mask: np.ndarray,
ltensor: Union[xp.ndarray, List[xp.ndarray]],
rtensor: Union[xp.ndarray, List[xp.ndarray]],
cmo: List[xp.ndarray],
omega: float,
):
if isinstance(ltensor, list):
assert isinstance(rtensor, list)
assert len(ltensor) == len(rtensor)
ham = sum([get_ham_direct(mps, qn_mask, ltensor_item, rtensor_item, cmo_item, omega) for ltensor_item, rtensor_item, cmo_item in zip(ltensor, rtensor, cmo)])
else:
ham = get_ham_direct(mps, qn_mask, ltensor, rtensor, cmo, omega)
inverse = mps.optimize_config.inverse
w, v = scipy.linalg.eigh(asnumpy(ham) * inverse)

Expand All @@ -360,14 +390,13 @@ def eigh_direct(
return e, c


def eigh_iterative(
def get_ham_iterative(
mps: Mps,
qn_mask: np.ndarray,
ltensor: xp.ndarray,
rtensor: xp.ndarray,
ltensor: Union[xp.ndarray, List[xp.ndarray]],
rtensor: Union[xp.ndarray, List[xp.ndarray]],
cmo: List[xp.ndarray],
omega: float,
cguess: List[np.ndarray],
):
# iterative algorithm
method = mps.optimize_config.method
Expand Down Expand Up @@ -428,6 +457,34 @@ def eigh_iterative(
# contraction expression
cshape = qn_mask.shape
expr = hop_expr(ltensor, rtensor, cmo, cshape, omega is not None)
return hdiag, expr


def func_sum(funcs):
def new_func(*args, **kwargs):
return sum([func(*args, **kwargs) for func in funcs])
return new_func


def eigh_iterative(
mps: Mps,
qn_mask: np.ndarray,
ltensor: Union[xp.ndarray, List[xp.ndarray]],
rtensor: Union[xp.ndarray, List[xp.ndarray]],
cmo: List[xp.ndarray],
omega: float,
cguess: List[np.ndarray],
):
# iterative algorithm
inverse = mps.optimize_config.inverse
if isinstance(ltensor, list):
assert isinstance(rtensor, list)
assert len(ltensor) == len(rtensor)
ham = [get_ham_iterative(mps, qn_mask, ltensor_item, rtensor_item, cmo_item, omega) for ltensor_item, rtensor_item, cmo_item in zip(ltensor, rtensor, cmo)]
hdiag = sum([hdiag_item for hdiag_item, expr_item in ham])
expr = func_sum([expr_item for hdiag_item, expr_item in ham])
else:
hdiag, expr = get_ham_iterative(mps, qn_mask, ltensor, rtensor, cmo, omega)

count = 0

Expand Down
25 changes: 19 additions & 6 deletions renormalizer/mps/mpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def ph_onsite(cls, model: HolsteinModel, opera: str, mol_idx:int, ph_idx=0):
def intersite(cls, model: HolsteinModel, e_opera: dict, ph_opera: dict, scale:
Quantity=Quantity(1.)):
r""" construct the inter site MPO
Parameters
----------
model : HolsteinModel
Expand All @@ -142,7 +142,7 @@ def intersite(cls, model: HolsteinModel, e_opera: dict, ph_opera: dict, scale:
Note
-----
the operator index starts from 0,1,2...
"""

ops = []
Expand Down Expand Up @@ -330,7 +330,7 @@ def apply(self, mp: MatrixProduct, canonicalise: bool=False) -> MatrixProduct:
# todo: use meta copy to save time, could be subtle when complex type is involved
# todo: inplace version (saved memory and can be used in `hybrid_exact_propagator`)
# the model is the same as the mps.model

assert self.site_num == mp.site_num
new_mps = self.promote_mt_type(mp.copy())
if mp.is_mps:
Expand Down Expand Up @@ -388,14 +388,14 @@ def apply(self, mp: MatrixProduct, canonicalise: bool=False) -> MatrixProduct:

def contract(self, mps, algo="svd"):
r""" an approximation of mpo @ mps/mpdm/mpo
Parameters
----------
mps : `Mps`, `Mpo`, `MpDm`
algo: str, optional
The algorithm to compress mpo @ mps/mpdm/mpo. It could be ``svd``
(default) and ``variational``.
(default) and ``variational``.
Returns
-------
new_mps : `Mps`
Expand Down Expand Up @@ -476,3 +476,16 @@ def is_hermitian(self):
def __matmul__(self, other):
return self.apply(other)


class StackedMpo:
"""
An effective sparse representation of MPO in the block diagonal form.
When it enters into the optimization, the Hamiltonian is calculated as
the sum of Hamiltonians generated by each MPO, then the Hamiltonian is
diagonalized and the MPS is updated.
Usage:
optimize_mps(mps, StackedMpo([mpo1, mpo2, ...]))
"""
def __init__(self, mpos: List[Mpo]):
self.mpos = mpos
13 changes: 12 additions & 1 deletion renormalizer/mps/tests/test_gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from renormalizer.model import Model, h_qc
from renormalizer.mps.backend import primme
from renormalizer.mps.gs import construct_mps_mpo, optimize_mps
from renormalizer.mps import Mpo, Mps
from renormalizer.mps import Mpo, Mps, StackedMpo
from renormalizer.tests.parameter import holstein_model
from renormalizer.utils.configs import OFS
from renormalizer.mps.tests import cur_dir
Expand Down Expand Up @@ -136,3 +136,14 @@ def test_qc(with_ofs):
print(mpo)
gs_e = min(energies)
assert np.allclose(gs_e, fci_e, atol=5e-3)


def test_stackedmpo():
scheme = 1
method = '1site'
mps, mpo = construct_mps_mpo(holstein_model.switch_scheme(scheme), procedure[0][0], nexciton)
mps.optimize_config.procedure = procedure
mps.optimize_config.method = method
energies1, _ = optimize_mps(mps.copy(), mpo)
energies2, _ = optimize_mps(mps.copy(), StackedMpo([mpo, mpo]))
assert np.all(np.abs(np.array(energies2) - np.array(energies1) * 2) < 1e-8)

0 comments on commit 9583324

Please sign in to comment.