Skip to content

Commit

Permalink
update: add rdm as optional input of mutual_info calc
Browse files Browse the repository at this point in the history
  • Loading branch information
ansatzX authored and liwt31 committed Jul 14, 2024
1 parent 1d5eb05 commit 394ad22
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion renormalizer/tn/tests/test_tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_rdm_entropy_holstein():
mps_idx1, mps_idx2 = 1, 3
dof1 = model.basis[mps_idx1].dof
dof2 = model.basis[mps_idx2].dof
ttns_mutual_infos = ttns.calc_2dof_mutual_info((dof1, dof2))
ttns_mutual_infos,_ = ttns.calc_2dof_mutual_info((dof1, dof2))
ttns_mutual_info = ttns_mutual_infos[(dof1, dof2)]
np.testing.assert_allclose(ttns_mutual_info, mps_mutual_info[mps_idx1, mps_idx2], atol=1e-4)

Expand Down
18 changes: 11 additions & 7 deletions renormalizer/tn/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,12 +1185,14 @@ def calc_2dof_rdm(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) ->
rdm_[dof_pair] = res
return rdm_

def calc_2dof_entropy(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) -> Dict[Tuple[Any, Any], float]:
rdm = self.calc_2dof_rdm(dofs)
def calc_2dof_entropy(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]], rdm: Dict[Any, np.ndarray]=None) -> Dict[Tuple[Any, Any], float]:
if rdm is None:
rdm = self.calc_2dof_rdm(dofs)

entropy = {key: calc_vn_entropy_dm(dm) for key, dm in rdm.items()}
return entropy

def calc_2dof_mutual_info(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) -> Dict[Tuple[Any, Any], float]:
def calc_2dof_mutual_info(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]], rdm_2dof: Dict[Any, np.ndarray]=None) -> Dict[Tuple[Any, Any], float]:
r"""
Calculate mutual information between two DOFs.
Expand All @@ -1214,14 +1216,16 @@ def calc_2dof_mutual_info(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any
for dof_pair in dofs:
dofs_lst.append(dof_pair[0])
dofs_lst.append(dof_pair[1])
entropy_1site = self.calc_1dof_entropy(dofs_lst)
entropy_2site = self.calc_2dof_entropy(dofs)
entropy_1dof = self.calc_1dof_entropy(dofs_lst)
entropy_2dof = self.calc_2dof_entropy(dofs, rdm_2dof)
for dof_pair in dofs:
dof1 = dof_pair[0]
dof2 = dof_pair[1]
mutual_info = (entropy_1site[dof1] + entropy_1site[dof2] - entropy_2site[dof_pair]) / 2
mutual_info = (entropy_1dof[dof1] + entropy_1dof[dof2] - entropy_2dof[dof_pair]) / 2
mutual_infos[dof_pair] = mutual_info
return mutual_infos

entropy_tuple = (entropy_1dof, entropy_2dof)
return mutual_infos, entropy_tuple

def calc_bond_entropy(self) -> np.ndarray:
r"""
Expand Down

0 comments on commit 394ad22

Please sign in to comment.