Skip to content

Commit

Permalink
add functionality in mp/mps/mpo
Browse files Browse the repository at this point in the history
  • Loading branch information
jjren committed Aug 5, 2023
1 parent 9706a29 commit 005d609
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 84 deletions.
5 changes: 3 additions & 2 deletions renormalizer/model/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def op_mat(self, op: Union[Op, str]):

# operators currently not having analytical matrix elements
else:
logger.warning("Note that the quadrature part is not fully tested!")
logger.debug("Note that the quadrature part is not fully tested!")
op_symbol = "*".join(op_symbol.split())

# potential operators
Expand Down Expand Up @@ -643,14 +643,15 @@ def quad(self, expr):
if s != "":
expr = sp.sympify(s)*expr
expr = expr.subs({sL:self.L, sxi:self.xi})
print(expr)
logger.debug(f"operator expr: {expr}")
expr = sp.lambdify([x, sibas, sjbas], expr, "numpy")

mat = np.zeros((self.nbas, self.nbas))
for ibas in range(self.nbas):
for jbas in range(self.nbas):
val, error = scipy.integrate.quad(lambda x: expr(x, ibas, jbas),
self.xi, self.xf)
logger.debug(f"quadrature value and error: {val}, {error}")
mat[ibas, jbas] = val
return mat

Expand Down
112 changes: 54 additions & 58 deletions renormalizer/mps/mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ def load(cls, model: Model, fname: str):
mp.dtype = backend.real_dtype
mp.append(mt)

mp.qn = []
for i in range(nsites+1):
subqn = npload[f"subqn_{i}"].astype(int).tolist()
mp.qn.append(subqn)

mp.qn = npload["qn"]
mp.qnidx = int(npload["qnidx"])
mp.qntot = npload["qntot"].astype(int)
mp.to_right = bool(npload["to_right"])
Expand Down Expand Up @@ -353,73 +349,73 @@ def mp_norm(self) -> float:
res = np.sqrt(res)

return float(res)

def add(self, other: "MatrixProduct"):
assert np.all(self.qntot == other.qntot)
assert self.site_num == other.site_num


def add(self, others: List["MatrixProduct"]):

logger.info(f"new_mps:{type(others)}")

if not isinstance(others, list):
others = [others]

new_mps = self.metacopy()
if other.dtype == backend.complex_dtype:
new_mps.dtype = backend.complex_dtype
if self.is_complex:
for other in others:
assert np.all(self.qntot == other.qntot)
assert self.site_num == other.site_num
if other.dtype == backend.complex_dtype:
new_mps.dtype = backend.complex_dtype

if new_mps.is_complex:
new_mps.to_complex(inplace=True)
new_mps.compress_config.update(self.compress_config)

if self.is_mps: # MPS
new_mps[0] = dstack([self[0], other[0]])
for i in range(1, self.site_num - 1):
mta = self[i]
mtb = other[i]
pdim = mta.shape[1]
assert pdim == mtb.shape[1]
new_ms = zeros(
[mta.shape[0] + mtb.shape[0], pdim, mta.shape[2] + mtb.shape[2]],
dtype=new_mps.dtype,
)
new_ms[: mta.shape[0], :, : mta.shape[2]] = mta
new_ms[mta.shape[0] :, :, mta.shape[2] :] = mtb
new_mps[i] = new_ms

new_mps[-1] = vstack([self[-1], other[-1]])
elif self.is_mpo or self.is_mpdm: # MPO
new_mps[0] = concatenate((self[0], other[0]), axis=3)
for i in range(1, self.site_num - 1):
mta = self[i]
mtb = other[i]
pdimu = mta.shape[1]
pdimd = mta.shape[2]
assert pdimu == mtb.shape[1]
assert pdimd == mtb.shape[2]

new_ms = zeros(
[
mta.shape[0] + mtb.shape[0],
pdimu,
pdimd,
mta.shape[3] + mtb.shape[3],
],
dtype=new_mps.dtype,
)
new_ms[: mta.shape[0], :, :, : mta.shape[3]] = mta[:, :, :, :]
new_ms[mta.shape[0] :, :, :, mta.shape[3] :] = mtb[:, :, :, :]
new_mps[i] = new_ms
new_mps[0] = concatenate([self[0]] + [other[0] for other in others], axis=-1)

for i in range(1, self.site_num - 1):
mts = [self[i]] + [other[i] for other in others]
pdim = self[i].shape[1:-1]
for mt in mts:
assert pdim == mt.shape[1:-1]
new_ms = zeros(
(sum([mt.shape[0] for mt in mts]),) + pdim +
(sum([mt.shape[-1] for mt in mts]),), dtype=new_mps.dtype,
)
start_first = 0
start_last = 0
for mt in mts:
if len(pdim) == 1:
new_ms[start_first:start_first+mt.shape[0], :, start_last:start_last+mt.shape[-1]] = mt
elif len(pdim) == 2:
new_ms[start_first:start_first+mt.shape[0], :, :, start_last:start_last+mt.shape[-1]] = mt
else:
assert False

new_mps[-1] = concatenate((self[-1], other[-1]), axis=0)
else:
assert False
start_first += mt.shape[0]
start_last += mt.shape[-1]
new_mps[i] = new_ms

new_mps[-1] = concatenate([self[-1]] + [other[-1] for other in others], axis=0)

#assert self.qnidx == other.qnidx
new_mps.move_qnidx(other.qnidx)
new_mps.to_right = other.to_right
new_mps.qn = [np.concatenate([qn1, qn2]) for qn1, qn2 in zip(self.qn, other.qn)]
original_qnidx = []
for other in others:
original_qnidx.append(other.qnidx)
other.move_qnidx(self.qnidx)

new_mps.to_right = self.to_right
qn_all = [self.qn] + [other.qn for other in others]
new_mps.qn = [np.concatenate(qns) for qns in zip(*qn_all)]

for i, other in enumerate(others):
other.move_qnidx(original_qnidx[i])

# qn at the boundary should have dimension 1
new_mps.qn[0] = np.zeros((1, new_mps.qn[0].shape[1]), dtype=int)
new_mps.qn[-1] = np.zeros((1, new_mps.qn[0].shape[1]), dtype=int)
if self.compress_add:
new_mps.canonicalise()
new_mps.compress()
return new_mps

def compress(self, temp_m_trunc=None, ret_s=False):
"""
inp: canonicalise MPS (or MPO)
Expand Down
18 changes: 1 addition & 17 deletions renormalizer/mps/mpdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,9 @@ def ground_state(cls, model, max_entangled):

@classmethod
def from_mps(cls, mps: Mps):
mpo = cls()
mpo.model = mps.model
for ms in mps:
mo = np.zeros(tuple([ms.shape[0]] + [ms.shape[1]] * 2 + [ms.shape[2]]))
for iaxis in range(ms.shape[1]):
mo[:, iaxis, iaxis, :] = ms[:, iaxis, :].array
mpo.append(mo)

mpo = super().from_mps(mps)
mpo.coeff = mps.coeff

mpo.optimize_config = mps.optimize_config
mpo.evolve_config = mps.evolve_config
mpo.compress_add = mps.compress_add

mpo.qn = [qn.copy() for qn in mps.qn]
mpo.qntot = mps.qntot
mpo.qnidx = mps.qnidx
mpo.to_right = mps.to_right
mpo.compress_config = mps.compress_config.copy()
return mpo

@classmethod
Expand Down
25 changes: 25 additions & 0 deletions renormalizer/mps/mpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,31 @@ def identity(cls, model: Model):
mpo.build_empty_qn()
return mpo

@classmethod
def from_mps(cls, mps):
mpo = cls()
mpo.model = mps.model
for ms in mps:
mo = np.zeros(tuple([ms.shape[0]] + [ms.shape[1]] * 2 + [ms.shape[2]]))
for iaxis in range(ms.shape[1]):
mo[:, iaxis, iaxis, :] = ms[:, iaxis, :].array
mpo.append(mo)

mpo.optimize_config = mps.optimize_config
mpo.compress_add = mps.compress_add

if mpo.is_mpo:
assert np.allclose(mps.coeff, 1)
# currently, only used when qn is zeros
for qn in mps.qn:
assert np.allclose(qn, np.zeros_like(qn))
mpo.qn = [qn.copy() for qn in mps.qn]
mpo.qntot = mps.qntot
mpo.qnidx = mps.qnidx
mpo.to_right = mps.to_right
mpo.compress_config = mps.compress_config.copy()
return mpo

def __init__(self, model: Model = None, terms: Union[Op, List[Op]] = None, offset: Quantity = Quantity(0), ):

"""
Expand Down
43 changes: 36 additions & 7 deletions renormalizer/mps/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,28 @@ def from_dense(cls, model, wfn: np.ndarray):
mp.append(residual_wfn)
mp.build_empty_qn()
return mp

@classmethod
def from_mpo(cls, mpo: Mpo):
# convert diagonal mpo to mps, which usually happens in
# potential energy surface
mps = cls()
mps.model = mpo.model
for mo in mpo:
ms = np.zeros(tuple([mo.shape[0]] + [mo.shape[1]] + [mo.shape[3]]))
for iaxis in range(mo.shape[1]):
ms[:, iaxis, :] = mo[:, iaxis, iaxis, :].array
mps.append(ms)

mps.coeff = 1
if mpo.is_mpo:
logger.warning("Note that the qn part is directly inherited from mpo, make sure it is what you want!")
mps.qn = [qn.copy() for qn in mpo.qn]
mps.qntot = mpo.qntot
mps.qnidx = mpo.qnidx
mps.to_right = None
mps.compress_config = mpo.compress_config.copy()
return mps

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -1785,13 +1807,20 @@ def __setitem__(self, key, value):
return super().__setitem__(key, value)


def add(self, other):
if not np.allclose(self.coeff, other.coeff):
self.scale(self.coeff, inplace=True)
other.scale(other.coeff, inplace=True)
self.coeff = 1
other.coeff = 1
return super().add(other)
def add(self, others):
"""
support add many mpss together in a batch way
"""
if not isinstance(others, list):
others = [others]

for other in others:
if not np.allclose(self.coeff, other.coeff):
self.scale(self.coeff, inplace=True)
other.scale(other.coeff, inplace=True)
self.coeff = 1
other.coeff = 1
return super().add(others)

def distance(self, other) -> float:
if not np.allclose(self.coeff, other.coeff):
Expand Down

0 comments on commit 005d609

Please sign in to comment.