Skip to content

Commit

Permalink
Correct numpy tolerence
Browse files Browse the repository at this point in the history
- originally introduced in commit 254b981
  • Loading branch information
atxy-blip authored and liwt31 committed May 17, 2024
1 parent 07b639f commit 24626e1
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 42 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*.eps
*.png
*.lprof
*.diff

.idea/
.cache/
Expand Down
44 changes: 38 additions & 6 deletions renormalizer/mps/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def get_git_commit_hash():
USE_GPU, xp = try_import_cupy()


#USE_GPU = False
#xp = np
# USE_GPU = False
# xp = np

xpseed = 2019
npseed = 9012
Expand Down Expand Up @@ -166,10 +166,42 @@ def dtypes(self, target):

@property
def canonical_atol(self):
if self.is_32bits:
return 1e-4
else:
return 1e-5
'''
Absolute tolerence for use in matrix.check_lortho,
mp.check_left_canonical, mp.ensure_left_canonical
and their right counterparts
'''
return (
self._canonical_atol
if hasattr(self, "_canonical_atol")
else (1e-4 if self.is_32bits else 1e-8)
)

@property
def canonical_rtol(self):
'''
Relative tolerence for use in matrix.check_lortho,
mp.check_left_canonical, mp.ensure_left_canonical
and their right counterparts
'''
return (
self._canonical_rtol
if hasattr(self, "_canonical_rtol")
else (1e-2 if self.is_32bits else 1e-5)
)

@canonical_atol.setter
def canonical_atol(self, value):
self._canonical_atol = self._tol_checker(value)

@canonical_rtol.setter
def canonical_rtol(self, value):
self._canonical_rtol = self._tol_checker(value)

def _tol_checker(self, value):
if not isinstance(value, (int, float)) or value < 0:
raise ValueError("Tolerance must be a non-negative float number")
return value


backend = Backend()
12 changes: 8 additions & 4 deletions renormalizer/mps/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,25 +90,29 @@ def r_combine(self):
def l_combine(self):
return self.reshape(self.l_combine_shape)

def check_lortho(self, atol=None):
def check_lortho(self, rtol: float = None, atol: float = None):
"""
check L-orthogonal
"""
if atol is None:
atol = backend.canonical_atol
if rtol is None:
rtol = backend.canonical_rtol
tensm = asxp(self.array.reshape([np.prod(self.shape[:-1]), self.shape[-1]]))
s = tensm.T.conj() @ tensm
return xp.allclose(s, xp.eye(s.shape[0]), atol=atol)
return xp.allclose(s, xp.eye(s.shape[0]), rtol=rtol, atol=atol)

def check_rortho(self, atol=None):
def check_rortho(self, rtol: float = None, atol: float = None):
"""
check R-orthogonal
"""
if atol is None:
atol = backend.canonical_atol
if rtol is None:
rtol = backend.canonical_rtol
tensm = asxp(self.array.reshape([self.shape[0], np.prod(self.shape[1:])]))
s = tensm @ tensm.T.conj()
return xp.allclose(s, xp.eye(s.shape[0]), atol=atol)
return xp.allclose(s, xp.eye(s.shape[0]), rtol=rtol, atol=atol)

def to_complex(self):
# `xp.array` always creates new array, so to_complex means copy, which is
Expand Down
64 changes: 32 additions & 32 deletions renormalizer/mps/mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def bond_dims(self) -> List:
# return a list so that the logging result is more pretty
return bond_dims


vbond_list = vbond_dims = bond_list = bond_dims

@property
Expand Down Expand Up @@ -158,21 +157,21 @@ def move_qnidx(self, dstidx: int):
self.qn[idx] = self.qntot - self.qn[idx]
self.qnidx = dstidx

def check_left_canonical(self, atol=None):
def check_left_canonical(self, rtol: float = None, atol: float = None):
"""
check L-canonical
"""
for i in range(len(self)-1):
if not self[i].check_lortho(atol):
for i in range(len(self) - 1):
if not self[i].check_lortho(rtol, atol):
return False
return True

def check_right_canonical(self, atol=None):
def check_right_canonical(self, rtol: float = None, atol: float = None):
"""
check R-canonical
"""
for i in range(1, len(self)):
if not self[i].check_rortho(atol):
if not self[i].check_rortho(rtol, atol):
return False
return True

Expand All @@ -190,18 +189,24 @@ def is_right_canonical(self):
"""
return self.qnidx == 0

def ensure_left_canonical(self, atol=None):
if self.to_right or self.qnidx != self.site_num-1 or \
(not self.check_left_canonical(atol)):
def ensure_left_canonical(self, rtol: float = None, atol: float = None):
if (
self.to_right
or self.qnidx != self.site_num - 1
or (not self.check_left_canonical(rtol, atol))
):
self.move_qnidx(0)
self.to_right = True
return self.canonicalise()
else:
return self

def ensure_right_canonical(self, atol=None):
if (not self.to_right) or self.qnidx != 0 or \
(not self.check_right_canonical(atol)):
def ensure_right_canonical(self, rtol: float = None, atol: float = None):
if (
(not self.to_right)
or self.qnidx != 0
or (not self.check_right_canonical(rtol, atol))
):
self.move_qnidx(self.site_num - 1)
self.to_right = False
return self.canonicalise()
Expand Down Expand Up @@ -406,7 +411,7 @@ def add(self, other: "MatrixProduct"):
else:
assert False

#assert self.qnidx == other.qnidx
# assert self.qnidx == other.qnidx
new_mps.move_qnidx(other.qnidx)
new_mps.to_right = other.to_right
new_mps.qn = [np.concatenate([qn1, qn2]) for qn1, qn2 in zip(self.qn, other.qn)]
Expand Down Expand Up @@ -602,7 +607,9 @@ def variational_compress(self, mpo=None, guess=None):

mps_old = mps.copy()
else:
logger.warning("Variational compress is not converged! Please increase the procedure!")
logger.warning(
"Variational compress is not converged! Please increase the procedure!"
)

# remove the redundant bond dimension near the boundary of the MPS
mps.canonicalise()
Expand Down Expand Up @@ -643,7 +650,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
"""

system = "L" if self.to_right else "R"

if self.compress_config.bonddim_should_set:
self.compress_config.set_bonddim(len(self)+1)

Expand Down Expand Up @@ -679,7 +686,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
)
entropy1 = calc_vn_entropy(SUset1**2)
entropy2 = calc_vn_entropy(SUset2**2)

# TODO: more general control according to
# CompressCriteria.thresh
assert self.compress_config.criteria == CompressCriteria.fixed
Expand Down Expand Up @@ -717,9 +724,6 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
# Need some additional testing at production level calculation
self.model: Model = Model(new_basis, self.model.ham_terms, self.model.dipole, self.model.output_ordering)
logger.debug(f"DOF ordering: {[b.dof for b in self.model.basis]}")




if self.to_right:
m_trunc = self.compress_config.compute_m_trunc(
Expand All @@ -735,7 +739,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
m_trunc = self.compress_config.compute_m_trunc(
SVset, cidx[-1], self.to_right
)

ms, msdim, msqn, compms = select_basis(
Vset, SVset, qnrnew, Uset, m_trunc, percent=percent
)
Expand Down Expand Up @@ -765,8 +769,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
Uset, Sset, qnnew = svd_qn.eigh_qn(
asnumpy(ddm), qnbigl, qnbigr, self.qntot, system=system
)



if self.to_right:
m_trunc = self.compress_config.compute_m_trunc(
Sset, cidx[0], self.to_right
Expand All @@ -775,7 +778,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
m_trunc = self.compress_config.compute_m_trunc(
Sset, cidx[-1], self.to_right
)

ms, msdim, msqn, compms = select_basis(
Uset, Sset, qnnew, None, m_trunc, percent=percent
)
Expand Down Expand Up @@ -853,7 +856,6 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
else:
return None


def _push_cano(self, idx):
# move the canonical center to the next site
# idx is the current canonical center
Expand Down Expand Up @@ -882,7 +884,7 @@ def canonicalise(self, stop_idx: int=None):
assert self.qnidx == self.site_num-1

for idx in self.iter_idx_list(full=False, stop_idx=stop_idx):
self._push_cano(idx)
self._push_cano(idx)
# can't iter to idx == 0 or idx == self.site_num - 1
if (not self.to_right and idx == 1) or (self.to_right and idx == self.site_num - 2):
self._switch_direction()
Expand Down Expand Up @@ -921,19 +923,19 @@ def dot(self, other: "MatrixProduct") -> complex:
assert False

return complex(e0[0, 0])

def dot_ob(self, other: "MatrixProduct") -> complex:
"""
dot product of two mps / mpo with open boundary, but the boundary of mps/mpo is larger than
1, different from the normal mps/mpo
"""

assert len(self) == len(other)

e0 = xp.eye(self[0].shape[0])
tmp = xp.eye(other[0].shape[0])
e0 = tensordot(e0, tmp, 0).transpose(0,2,1,3)

for mt1, mt2 in zip(self, other):
e0 = tensordot(e0, mt2.array, 1)
if mt1.ndim == 3:
Expand All @@ -945,8 +947,6 @@ def dot_ob(self, other: "MatrixProduct") -> complex:

return e0



def angle(self, other):
return abs(self.conj().dot(other))

Expand Down Expand Up @@ -1075,7 +1075,7 @@ def dump(self, fname, other_attrs=None):

for i in range(self.site_num+1):
data_dict[f"subqn_{i}"] = qn[i]

try:
np.savez(fname, **data_dict)
except Exception:
Expand Down Expand Up @@ -1180,7 +1180,7 @@ def __del__(self):
shutil.rmtree(dir_with_id)
except OSError:
logger.exception(f"Removing temperary dump dir {dir_with_id} failed")

@classmethod
def from_mp(cls, model, mplist):
# mps/mpo/mpdm from matrix product
Expand Down
34 changes: 34 additions & 0 deletions renormalizer/mps/tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Author: Yu Xiong <y-xiong22@mails.tsinghua.edu.cn>

import pytest

from renormalizer.mps import backend


def set_tolerance(tolerance_type, value):
setattr(backend, tolerance_type, value)


def get_tolerance(tolerance_type):
return getattr(backend, tolerance_type)


@pytest.mark.parametrize(
"tolerance_type, value",
[
("canonical_atol", 1e-5), # normal
("canonical_atol", -1e-7), # ValueError
("canonical_atol", "invalid"), # ValueError
("canonical_rtol", 1e-4), # normal
("canonical_rtol", -1e-6), # ValueError
("canonical_rtol", "invalid"), # ValueError
],
)
def test_tolerances(tolerance_type, value):
if isinstance(value, (int, float)) and value >= 0:
set_tolerance(tolerance_type, value)
assert get_tolerance(tolerance_type) == value
else:
with pytest.raises(ValueError):
set_tolerance(tolerance_type, value)

0 comments on commit 24626e1

Please sign in to comment.