From 005d6093cdd816c7a924694c24a6bbbb2bdf0283 Mon Sep 17 00:00:00 2001 From: jjren Date: Sat, 5 Aug 2023 14:07:18 +0800 Subject: [PATCH] add functionality in mp/mps/mpo --- renormalizer/model/basis.py | 5 +- renormalizer/mps/mp.py | 112 +++++++++++++++++------------------- renormalizer/mps/mpdm.py | 18 +----- renormalizer/mps/mpo.py | 25 ++++++++ renormalizer/mps/mps.py | 43 +++++++++++--- 5 files changed, 119 insertions(+), 84 deletions(-) diff --git a/renormalizer/model/basis.py b/renormalizer/model/basis.py index 9a5f1f83..57083798 100644 --- a/renormalizer/model/basis.py +++ b/renormalizer/model/basis.py @@ -588,7 +588,7 @@ def op_mat(self, op: Union[Op, str]): # operators currently not having analytical matrix elements else: - logger.warning("Note that the quadrature part is not fully tested!") + logger.debug("Note that the quadrature part is not fully tested!") op_symbol = "*".join(op_symbol.split()) # potential operators @@ -643,7 +643,7 @@ def quad(self, expr): if s != "": expr = sp.sympify(s)*expr expr = expr.subs({sL:self.L, sxi:self.xi}) - print(expr) + logger.debug(f"operator expr: {expr}") expr = sp.lambdify([x, sibas, sjbas], expr, "numpy") mat = np.zeros((self.nbas, self.nbas)) @@ -651,6 +651,7 @@ def quad(self, expr): for jbas in range(self.nbas): val, error = scipy.integrate.quad(lambda x: expr(x, ibas, jbas), self.xi, self.xf) + logger.debug(f"quadrature value and error: {val}, {error}") mat[ibas, jbas] = val return mat diff --git a/renormalizer/mps/mp.py b/renormalizer/mps/mp.py index cd0bfd64..502cb37f 100644 --- a/renormalizer/mps/mp.py +++ b/renormalizer/mps/mp.py @@ -47,11 +47,7 @@ def load(cls, model: Model, fname: str): mp.dtype = backend.real_dtype mp.append(mt) - mp.qn = [] - for i in range(nsites+1): - subqn = npload[f"subqn_{i}"].astype(int).tolist() - mp.qn.append(subqn) - + mp.qn = npload["qn"] mp.qnidx = int(npload["qnidx"]) mp.qntot = npload["qntot"].astype(int) mp.to_right = bool(npload["to_right"]) @@ -353,65 +349,65 @@ def mp_norm(self) -> float: res = np.sqrt(res) return float(res) - - def add(self, other: "MatrixProduct"): - assert np.all(self.qntot == other.qntot) - assert self.site_num == other.site_num - + + def add(self, others: List["MatrixProduct"]): + + logger.info(f"new_mps:{type(others)}") + + if not isinstance(others, list): + others = [others] + new_mps = self.metacopy() - if other.dtype == backend.complex_dtype: - new_mps.dtype = backend.complex_dtype - if self.is_complex: + for other in others: + assert np.all(self.qntot == other.qntot) + assert self.site_num == other.site_num + if other.dtype == backend.complex_dtype: + new_mps.dtype = backend.complex_dtype + + if new_mps.is_complex: new_mps.to_complex(inplace=True) new_mps.compress_config.update(self.compress_config) - if self.is_mps: # MPS - new_mps[0] = dstack([self[0], other[0]]) - for i in range(1, self.site_num - 1): - mta = self[i] - mtb = other[i] - pdim = mta.shape[1] - assert pdim == mtb.shape[1] - new_ms = zeros( - [mta.shape[0] + mtb.shape[0], pdim, mta.shape[2] + mtb.shape[2]], - dtype=new_mps.dtype, - ) - new_ms[: mta.shape[0], :, : mta.shape[2]] = mta - new_ms[mta.shape[0] :, :, mta.shape[2] :] = mtb - new_mps[i] = new_ms - - new_mps[-1] = vstack([self[-1], other[-1]]) - elif self.is_mpo or self.is_mpdm: # MPO - new_mps[0] = concatenate((self[0], other[0]), axis=3) - for i in range(1, self.site_num - 1): - mta = self[i] - mtb = other[i] - pdimu = mta.shape[1] - pdimd = mta.shape[2] - assert pdimu == mtb.shape[1] - assert pdimd == mtb.shape[2] - - new_ms = zeros( - [ - mta.shape[0] + mtb.shape[0], - pdimu, - pdimd, - mta.shape[3] + mtb.shape[3], - ], - dtype=new_mps.dtype, - ) - new_ms[: mta.shape[0], :, :, : mta.shape[3]] = mta[:, :, :, :] - new_ms[mta.shape[0] :, :, :, mta.shape[3] :] = mtb[:, :, :, :] - new_mps[i] = new_ms + new_mps[0] = concatenate([self[0]] + [other[0] for other in others], axis=-1) + + for i in range(1, self.site_num - 1): + mts = [self[i]] + [other[i] for other in others] + pdim = self[i].shape[1:-1] + for mt in mts: + assert pdim == mt.shape[1:-1] + new_ms = zeros( + (sum([mt.shape[0] for mt in mts]),) + pdim + + (sum([mt.shape[-1] for mt in mts]),), dtype=new_mps.dtype, + ) + start_first = 0 + start_last = 0 + for mt in mts: + if len(pdim) == 1: + new_ms[start_first:start_first+mt.shape[0], :, start_last:start_last+mt.shape[-1]] = mt + elif len(pdim) == 2: + new_ms[start_first:start_first+mt.shape[0], :, :, start_last:start_last+mt.shape[-1]] = mt + else: + assert False - new_mps[-1] = concatenate((self[-1], other[-1]), axis=0) - else: - assert False + start_first += mt.shape[0] + start_last += mt.shape[-1] + new_mps[i] = new_ms + + new_mps[-1] = concatenate([self[-1]] + [other[-1] for other in others], axis=0) #assert self.qnidx == other.qnidx - new_mps.move_qnidx(other.qnidx) - new_mps.to_right = other.to_right - new_mps.qn = [np.concatenate([qn1, qn2]) for qn1, qn2 in zip(self.qn, other.qn)] + original_qnidx = [] + for other in others: + original_qnidx.append(other.qnidx) + other.move_qnidx(self.qnidx) + + new_mps.to_right = self.to_right + qn_all = [self.qn] + [other.qn for other in others] + new_mps.qn = [np.concatenate(qns) for qns in zip(*qn_all)] + + for i, other in enumerate(others): + other.move_qnidx(original_qnidx[i]) + # qn at the boundary should have dimension 1 new_mps.qn[0] = np.zeros((1, new_mps.qn[0].shape[1]), dtype=int) new_mps.qn[-1] = np.zeros((1, new_mps.qn[0].shape[1]), dtype=int) @@ -419,7 +415,7 @@ def add(self, other: "MatrixProduct"): new_mps.canonicalise() new_mps.compress() return new_mps - + def compress(self, temp_m_trunc=None, ret_s=False): """ inp: canonicalise MPS (or MPO) diff --git a/renormalizer/mps/mpdm.py b/renormalizer/mps/mpdm.py index 14427814..7aff5955 100644 --- a/renormalizer/mps/mpdm.py +++ b/renormalizer/mps/mpdm.py @@ -26,25 +26,9 @@ def ground_state(cls, model, max_entangled): @classmethod def from_mps(cls, mps: Mps): - mpo = cls() - mpo.model = mps.model - for ms in mps: - mo = np.zeros(tuple([ms.shape[0]] + [ms.shape[1]] * 2 + [ms.shape[2]])) - for iaxis in range(ms.shape[1]): - mo[:, iaxis, iaxis, :] = ms[:, iaxis, :].array - mpo.append(mo) - + mpo = super().from_mps(mps) mpo.coeff = mps.coeff - - mpo.optimize_config = mps.optimize_config mpo.evolve_config = mps.evolve_config - mpo.compress_add = mps.compress_add - - mpo.qn = [qn.copy() for qn in mps.qn] - mpo.qntot = mps.qntot - mpo.qnidx = mps.qnidx - mpo.to_right = mps.to_right - mpo.compress_config = mps.compress_config.copy() return mpo @classmethod diff --git a/renormalizer/mps/mpo.py b/renormalizer/mps/mpo.py index 9879f57a..c15c6fe2 100644 --- a/renormalizer/mps/mpo.py +++ b/renormalizer/mps/mpo.py @@ -247,6 +247,31 @@ def identity(cls, model: Model): mpo.build_empty_qn() return mpo + @classmethod + def from_mps(cls, mps): + mpo = cls() + mpo.model = mps.model + for ms in mps: + mo = np.zeros(tuple([ms.shape[0]] + [ms.shape[1]] * 2 + [ms.shape[2]])) + for iaxis in range(ms.shape[1]): + mo[:, iaxis, iaxis, :] = ms[:, iaxis, :].array + mpo.append(mo) + + mpo.optimize_config = mps.optimize_config + mpo.compress_add = mps.compress_add + + if mpo.is_mpo: + assert np.allclose(mps.coeff, 1) + # currently, only used when qn is zeros + for qn in mps.qn: + assert np.allclose(qn, np.zeros_like(qn)) + mpo.qn = [qn.copy() for qn in mps.qn] + mpo.qntot = mps.qntot + mpo.qnidx = mps.qnidx + mpo.to_right = mps.to_right + mpo.compress_config = mps.compress_config.copy() + return mpo + def __init__(self, model: Model = None, terms: Union[Op, List[Op]] = None, offset: Quantity = Quantity(0), ): """ diff --git a/renormalizer/mps/mps.py b/renormalizer/mps/mps.py index 53c27890..905b21d1 100644 --- a/renormalizer/mps/mps.py +++ b/renormalizer/mps/mps.py @@ -390,6 +390,28 @@ def from_dense(cls, model, wfn: np.ndarray): mp.append(residual_wfn) mp.build_empty_qn() return mp + + @classmethod + def from_mpo(cls, mpo: Mpo): + # convert diagonal mpo to mps, which usually happens in + # potential energy surface + mps = cls() + mps.model = mpo.model + for mo in mpo: + ms = np.zeros(tuple([mo.shape[0]] + [mo.shape[1]] + [mo.shape[3]])) + for iaxis in range(mo.shape[1]): + ms[:, iaxis, :] = mo[:, iaxis, iaxis, :].array + mps.append(ms) + + mps.coeff = 1 + if mpo.is_mpo: + logger.warning("Note that the qn part is directly inherited from mpo, make sure it is what you want!") + mps.qn = [qn.copy() for qn in mpo.qn] + mps.qntot = mpo.qntot + mps.qnidx = mpo.qnidx + mps.to_right = None + mps.compress_config = mpo.compress_config.copy() + return mps def __init__(self): super().__init__() @@ -1785,13 +1807,20 @@ def __setitem__(self, key, value): return super().__setitem__(key, value) - def add(self, other): - if not np.allclose(self.coeff, other.coeff): - self.scale(self.coeff, inplace=True) - other.scale(other.coeff, inplace=True) - self.coeff = 1 - other.coeff = 1 - return super().add(other) + def add(self, others): + """ + support add many mpss together in a batch way + """ + if not isinstance(others, list): + others = [others] + + for other in others: + if not np.allclose(self.coeff, other.coeff): + self.scale(self.coeff, inplace=True) + other.scale(other.coeff, inplace=True) + self.coeff = 1 + other.coeff = 1 + return super().add(others) def distance(self, other) -> float: if not np.allclose(self.coeff, other.coeff):