Skip to content

Commit

Permalink
!2280 fix overflow error & error of calc_energy
Browse files Browse the repository at this point in the history
Merge pull request !2280 from zengqg/r0.9
  • Loading branch information
donghufeng authored and gitee-org committed Jan 25, 2024
2 parents 0e75ccb + 6f8f8ca commit 668e46e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
19 changes: 17 additions & 2 deletions mindquantum/algorithm/qaia/QAIA.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class QAIA:
Args:
J (Union[numpy.array, csr_matrix]): The coupling matrix with shape :math:`(N x N)`.
h (numpy.array): The external field with shape :math:`(N, )`.
h (numpy.array): The external field with shape :math:`(N x 1)`.
x (numpy.array): The initialized spin value with shape :math:`(N x batch_size)`. Default: ``None``.
n_iter (int): The number of iterations. Default: ``1000``.
batch_size (int): The number of sampling. Default: ``1``.
Expand All @@ -35,6 +35,8 @@ class QAIA:
def __init__(self, J, h=None, x=None, n_iter=1000, batch_size=1):
"""Construct a QAIA algorithm."""
self.J = J
if h is not None and len(h.shape) < 2:
h = h[:, np.newaxis]
self.h = h
self.x = x
# The number of spins
Expand Down Expand Up @@ -76,4 +78,17 @@ def calc_energy(self, x=None):

if self.h is None:
return -0.5 * np.sum(self.J.dot(sign) * sign, axis=0)
return -0.5 * np.sum(self.J.dot(sign) * sign, axis=0, keepdims=True) - self.h.dot(sign)
return -0.5 * np.sum(self.J.dot(sign) * sign, axis=0, keepdims=True) - self.h.T.dot(sign)


class OverflowException(Exception):
r"""
Custom exception class for handling overflow errors in numerical calculations.
Args:
message: Exception message string, defaults to "Overflow error".
"""

def __init__(self, message="Overflow error"):
self.message = message
super().__init__(self.message)
5 changes: 4 additions & 1 deletion mindquantum/algorithm/qaia/SB.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
from scipy.sparse import csr_matrix

from .QAIA import QAIA
from .QAIA import QAIA, OverflowException


class SB(QAIA):
Expand Down Expand Up @@ -124,6 +124,9 @@ def update(self):
else:
self.y += self.xi * self.dt * (self.J.dot(self.x) + self.h)

if np.isnan(self.x).any():
raise OverflowException("Value is too large to handle due to large dt or xi.")


class BSB(SB): # noqa: N801
r"""
Expand Down
5 changes: 4 additions & 1 deletion mindquantum/algorithm/qaia/SFC.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
from scipy.sparse import csr_matrix

from .QAIA import QAIA
from .QAIA import QAIA, OverflowException


class SFC(QAIA):
Expand Down Expand Up @@ -85,3 +85,6 @@ def update(self):
f = np.tanh(self.c[i] * z)
self.x = self.x + (-self.x**3 + (self.p[i] - 1) * self.x - f - self.k * (z - self.e)) * self.dt
self.e = self.e + (-self.beta[i] * (self.e - z)) * self.dt

if np.isnan(self.x).any():
raise OverflowException("Value is too large to handle due to large dt or xi.")
2 changes: 1 addition & 1 deletion mindquantum/algorithm/qaia/SimCIM.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,5 @@ def update(self):
)
# gradient + momentum
self.dx = self.dx * self.momentum + newdc * (1 - self.momentum)
ind = (np.abs(self.x + self.dx) < 1.0).astype(np.int)
ind = (np.abs(self.x + self.dx) < 1.0).astype(np.int64)
self.x += self.dx * ind

0 comments on commit 668e46e

Please sign in to comment.