Skip to content

Commit

Permalink
Enhance mps/mpo reproducibility; Print git hash for easier debugging;…
Browse files Browse the repository at this point in the history
… fix mp dump/load
  • Loading branch information
jiangtong1000 committed Sep 13, 2023
1 parent 1576c3c commit 1cfa0ec
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 12 deletions.
26 changes: 21 additions & 5 deletions renormalizer/mps/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import logging
import random
import subprocess

import numpy as np

Expand Down Expand Up @@ -49,24 +50,39 @@ def try_import_cupy():
logger.info(f"Using GPU: {GPU_ID}")
return True, cp

def get_git_commit_hash():
try:
commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
return commit_hash
except subprocess.CalledProcessError:
return "Unknown"


USE_GPU, xp = try_import_cupy()


#USE_GPU = False
#xp = np

xpseed = 2019
npseed = 9012
randomseed = 1092

xp.random.seed(xpseed)
np.random.seed(npseed)
random.seed(randomseed)


if not USE_GPU:
logger.info("Use NumPy as backend")
logger.info(f"numpy random seed is {npseed}")
OE_BACKEND = "numpy"
else:
logger.info("Use CuPy as backend")
logger.info(f"cupy random seed is {xpseed}")
OE_BACKEND = "cupy"


xp.random.seed(2019)
np.random.seed(9012)
random.seed(1092)
logger.info(f"random seed is {randomseed}")
logger.info("Git Commit Hash: %s", get_git_commit_hash())


class Backend:
Expand Down
9 changes: 6 additions & 3 deletions renormalizer/mps/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,10 @@ def eigh_direct(
nroots = mps.optimize_config.nroots
if nroots == 1:
e = w[0]
c = v[:, 0]
c = v[:, 0] / np.sign(np.max(v[:, 0]))
else:
e = w[:nroots]
c = [v[:, iroot] for iroot in range(min(nroots, v.shape[1]))]
c = [v[:, iroot] / np.sign(np.max(v[:, iroot])) for iroot in range(min(nroots, v.shape[1]))]
return e, c


Expand Down Expand Up @@ -556,4 +556,7 @@ def hop(x):
else:
assert False
logger.debug(f"use {algo}, HC hops: {count}")
return e, c
if nroots == 1:
return e, c/np.sign(np.max(c))
else:
return e, c/np.sign(np.max(c, axis=0))
5 changes: 4 additions & 1 deletion renormalizer/mps/mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def load(cls, model: Model, fname: str):

mp.qn = []
for i in range(nsites+1):
subqn = npload[f"subqn_{i}"].astype(int).tolist()
subqn = npload[f"subqn_{i}"].astype(int).tolist()
mp.qn.append(subqn)

mp.qnidx = int(npload["qnidx"])
Expand Down Expand Up @@ -1084,6 +1084,9 @@ def dump(self, fname, other_attrs=None):
arr = np.empty(len(qn), object)
arr[:] = qn
data_dict['qn'] = arr

for i in range(self.site_num+1):
data_dict[f"subqn_{i}"] = qn[i]

try:
np.savez(fname, **data_dict)
Expand Down
7 changes: 4 additions & 3 deletions renormalizer/mps/symbolic_mpo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import logging
import itertools
from collections import namedtuple
from collections import namedtuple, OrderedDict
from typing import List, Set, Tuple, Dict

import numpy as np
Expand Down Expand Up @@ -136,10 +136,11 @@ def construct_symbolic_mpo(table, factor, algo="Hopcroft-Karp"):

# translate the symbolic operator table to an easy to manipulate numpy array
table = np.array(table)

# unique operators with DoF names taken into consideration
# The inclusion of DoF names is necessary for multi-dof basis.
unique_op: Set[Op] = set(table.ravel())

unique_op = OrderedDict.fromkeys(table.ravel())
unique_op = list(unique_op.keys())
# check the index of different operators could be represented with np.uint16
assert len(unique_op) < max_uint16

Expand Down

0 comments on commit 1cfa0ec

Please sign in to comment.