Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct numpy tolerence #167

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*.eps
*.png
*.lprof
*.diff

.idea/
.cache/
Expand Down
44 changes: 38 additions & 6 deletions renormalizer/mps/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def get_git_commit_hash():
USE_GPU, xp = try_import_cupy()


#USE_GPU = False
#xp = np
# USE_GPU = False
# xp = np

xpseed = 2019
npseed = 9012
Expand Down Expand Up @@ -166,10 +166,42 @@ def dtypes(self, target):

@property
def canonical_atol(self):
if self.is_32bits:
return 1e-4
else:
return 1e-5
'''
Absolute tolerence for use in matrix.check_lortho,
mp.check_left_canonical, mp.ensure_left_canonical
and their right counterparts
'''
return (
self._canonical_atol
if hasattr(self, "_canonical_atol")
else (1e-4 if self.is_32bits else 1e-8)
)

@property
def canonical_rtol(self):
'''
Relative tolerence for use in matrix.check_lortho,
mp.check_left_canonical, mp.ensure_left_canonical
and their right counterparts
'''
return (
self._canonical_rtol
if hasattr(self, "_canonical_rtol")
else (1e-2 if self.is_32bits else 1e-5)
)

@canonical_atol.setter
def canonical_atol(self, value):
self._canonical_atol = self._tol_checker(value)

@canonical_rtol.setter
def canonical_rtol(self, value):
self._canonical_rtol = self._tol_checker(value)

def _tol_checker(self, value):
if not isinstance(value, (int, float)) or value < 0:
raise ValueError("Tolerance must be a non-negative float number")
return value


backend = Backend()
12 changes: 8 additions & 4 deletions renormalizer/mps/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,25 +90,29 @@ def r_combine(self):
def l_combine(self):
return self.reshape(self.l_combine_shape)

def check_lortho(self, atol=None):
def check_lortho(self, rtol: float = None, atol: float = None):
"""
check L-orthogonal
"""
if atol is None:
atol = backend.canonical_atol
if rtol is None:
rtol = backend.canonical_rtol
tensm = asxp(self.array.reshape([np.prod(self.shape[:-1]), self.shape[-1]]))
s = tensm.T.conj() @ tensm
return xp.allclose(s, xp.eye(s.shape[0]), atol=atol)
return xp.allclose(s, xp.eye(s.shape[0]), rtol=rtol, atol=atol)

def check_rortho(self, atol=None):
def check_rortho(self, rtol: float = None, atol: float = None):
"""
check R-orthogonal
"""
if atol is None:
atol = backend.canonical_atol
if rtol is None:
rtol = backend.canonical_rtol
tensm = asxp(self.array.reshape([self.shape[0], np.prod(self.shape[1:])]))
s = tensm @ tensm.T.conj()
return xp.allclose(s, xp.eye(s.shape[0]), atol=atol)
return xp.allclose(s, xp.eye(s.shape[0]), rtol=rtol, atol=atol)

def to_complex(self):
# `xp.array` always creates new array, so to_complex means copy, which is
Expand Down
64 changes: 32 additions & 32 deletions renormalizer/mps/mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def bond_dims(self) -> List:
# return a list so that the logging result is more pretty
return bond_dims


vbond_list = vbond_dims = bond_list = bond_dims

@property
Expand Down Expand Up @@ -158,21 +157,21 @@ def move_qnidx(self, dstidx: int):
self.qn[idx] = self.qntot - self.qn[idx]
self.qnidx = dstidx

def check_left_canonical(self, atol=None):
def check_left_canonical(self, rtol: float = None, atol: float = None):
"""
check L-canonical
"""
for i in range(len(self)-1):
if not self[i].check_lortho(atol):
for i in range(len(self) - 1):
if not self[i].check_lortho(rtol, atol):
return False
return True

def check_right_canonical(self, atol=None):
def check_right_canonical(self, rtol: float = None, atol: float = None):
"""
check R-canonical
"""
for i in range(1, len(self)):
if not self[i].check_rortho(atol):
if not self[i].check_rortho(rtol, atol):
return False
return True

Expand All @@ -190,18 +189,24 @@ def is_right_canonical(self):
"""
return self.qnidx == 0

def ensure_left_canonical(self, atol=None):
if self.to_right or self.qnidx != self.site_num-1 or \
(not self.check_left_canonical(atol)):
def ensure_left_canonical(self, rtol: float = None, atol: float = None):
if (
self.to_right
or self.qnidx != self.site_num - 1
or (not self.check_left_canonical(rtol, atol))
):
self.move_qnidx(0)
self.to_right = True
return self.canonicalise()
else:
return self

def ensure_right_canonical(self, atol=None):
if (not self.to_right) or self.qnidx != 0 or \
(not self.check_right_canonical(atol)):
def ensure_right_canonical(self, rtol: float = None, atol: float = None):
if (
(not self.to_right)
or self.qnidx != 0
or (not self.check_right_canonical(rtol, atol))
):
self.move_qnidx(self.site_num - 1)
self.to_right = False
return self.canonicalise()
Expand Down Expand Up @@ -406,7 +411,7 @@ def add(self, other: "MatrixProduct"):
else:
assert False

#assert self.qnidx == other.qnidx
# assert self.qnidx == other.qnidx
new_mps.move_qnidx(other.qnidx)
new_mps.to_right = other.to_right
new_mps.qn = [np.concatenate([qn1, qn2]) for qn1, qn2 in zip(self.qn, other.qn)]
Expand Down Expand Up @@ -602,7 +607,9 @@ def variational_compress(self, mpo=None, guess=None):

mps_old = mps.copy()
else:
logger.warning("Variational compress is not converged! Please increase the procedure!")
logger.warning(
"Variational compress is not converged! Please increase the procedure!"
)

# remove the redundant bond dimension near the boundary of the MPS
mps.canonicalise()
Expand Down Expand Up @@ -643,7 +650,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
"""

system = "L" if self.to_right else "R"

if self.compress_config.bonddim_should_set:
self.compress_config.set_bonddim(len(self)+1)

Expand Down Expand Up @@ -679,7 +686,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
)
entropy1 = calc_vn_entropy(SUset1**2)
entropy2 = calc_vn_entropy(SUset2**2)

# TODO: more general control according to
# CompressCriteria.thresh
assert self.compress_config.criteria == CompressCriteria.fixed
Expand Down Expand Up @@ -717,9 +724,6 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
# Need some additional testing at production level calculation
self.model: Model = Model(new_basis, self.model.ham_terms, self.model.dipole, self.model.output_ordering)
logger.debug(f"DOF ordering: {[b.dof for b in self.model.basis]}")




if self.to_right:
m_trunc = self.compress_config.compute_m_trunc(
Expand All @@ -735,7 +739,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
m_trunc = self.compress_config.compute_m_trunc(
SVset, cidx[-1], self.to_right
)

ms, msdim, msqn, compms = select_basis(
Vset, SVset, qnrnew, Uset, m_trunc, percent=percent
)
Expand Down Expand Up @@ -765,8 +769,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
Uset, Sset, qnnew = svd_qn.eigh_qn(
asnumpy(ddm), qnbigl, qnbigr, self.qntot, system=system
)



if self.to_right:
m_trunc = self.compress_config.compute_m_trunc(
Sset, cidx[0], self.to_right
Expand All @@ -775,7 +778,7 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
m_trunc = self.compress_config.compute_m_trunc(
Sset, cidx[-1], self.to_right
)

ms, msdim, msqn, compms = select_basis(
Uset, Sset, qnnew, None, m_trunc, percent=percent
)
Expand Down Expand Up @@ -853,7 +856,6 @@ def _update_mps(self, cstruct, cidx, qnbigl, qnbigr, percent=0):
else:
return None


def _push_cano(self, idx):
# move the canonical center to the next site
# idx is the current canonical center
Expand Down Expand Up @@ -882,7 +884,7 @@ def canonicalise(self, stop_idx: int=None):
assert self.qnidx == self.site_num-1

for idx in self.iter_idx_list(full=False, stop_idx=stop_idx):
self._push_cano(idx)
self._push_cano(idx)
# can't iter to idx == 0 or idx == self.site_num - 1
if (not self.to_right and idx == 1) or (self.to_right and idx == self.site_num - 2):
self._switch_direction()
Expand Down Expand Up @@ -921,19 +923,19 @@ def dot(self, other: "MatrixProduct") -> complex:
assert False

return complex(e0[0, 0])

def dot_ob(self, other: "MatrixProduct") -> complex:
"""
dot product of two mps / mpo with open boundary, but the boundary of mps/mpo is larger than
1, different from the normal mps/mpo
"""

assert len(self) == len(other)

e0 = xp.eye(self[0].shape[0])
tmp = xp.eye(other[0].shape[0])
e0 = tensordot(e0, tmp, 0).transpose(0,2,1,3)

for mt1, mt2 in zip(self, other):
e0 = tensordot(e0, mt2.array, 1)
if mt1.ndim == 3:
Expand All @@ -945,8 +947,6 @@ def dot_ob(self, other: "MatrixProduct") -> complex:

return e0



def angle(self, other):
return abs(self.conj().dot(other))

Expand Down Expand Up @@ -1075,7 +1075,7 @@ def dump(self, fname, other_attrs=None):

for i in range(self.site_num+1):
data_dict[f"subqn_{i}"] = qn[i]

try:
np.savez(fname, **data_dict)
except Exception:
Expand Down Expand Up @@ -1180,7 +1180,7 @@ def __del__(self):
shutil.rmtree(dir_with_id)
except OSError:
logger.exception(f"Removing temperary dump dir {dir_with_id} failed")

@classmethod
def from_mp(cls, model, mplist):
# mps/mpo/mpdm from matrix product
Expand Down
34 changes: 34 additions & 0 deletions renormalizer/mps/tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Author: Yu Xiong <y-xiong22@mails.tsinghua.edu.cn>

import pytest

from renormalizer.mps import backend


def set_tolerance(tolerance_type, value):
setattr(backend, tolerance_type, value)


def get_tolerance(tolerance_type):
return getattr(backend, tolerance_type)


@pytest.mark.parametrize(
"tolerance_type, value",
[
("canonical_atol", 1e-5), # normal
("canonical_atol", -1e-7), # ValueError
("canonical_atol", "invalid"), # ValueError
("canonical_rtol", 1e-4), # normal
("canonical_rtol", -1e-6), # ValueError
("canonical_rtol", "invalid"), # ValueError
],
)
def test_tolerances(tolerance_type, value):
if isinstance(value, (int, float)) and value >= 0:
set_tolerance(tolerance_type, value)
assert get_tolerance(tolerance_type) == value
else:
with pytest.raises(ValueError):
set_tolerance(tolerance_type, value)
Loading