Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use symengine for parameter and parameter expressions #6270

Merged
merged 32 commits into from
May 5, 2021
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
15faabb
Use symengine for parameter and parameter expressions
mtreinish Apr 20, 2021
b5c9bdb
Merge branch 'main' into use-symengine
mtreinish Apr 20, 2021
a6b6a8e
Add pickle support
mtreinish Apr 20, 2021
b7bef71
Merge branch 'use-symengine' of github.com:mtreinish/qiskit-core into…
mtreinish Apr 20, 2021
f95143b
Pickle fixes
mtreinish Apr 21, 2021
6ffbefd
Merge branch 'main' into use-symengine
mtreinish Apr 21, 2021
6efeea7
Convert to sympy expression for eq and str
mtreinish Apr 21, 2021
01383ed
Fix lint
mtreinish Apr 21, 2021
c3b6875
Fix vector pickle
mtreinish Apr 21, 2021
618dd49
Avoid runtime import for symengine
mtreinish Apr 21, 2021
6169c9d
Fix test failures due to precision differences
mtreinish Apr 21, 2021
2ad468c
Merge branch 'main' into use-symengine
mtreinish Apr 21, 2021
6a64a6a
More fixes from subtle behavior differences
mtreinish Apr 21, 2021
7bd6ce8
Merge branch 'use-symengine' of github.com:mtreinish/qiskit-core into…
mtreinish Apr 21, 2021
d73739d
Fix some of the gradient failures
mtreinish Apr 21, 2021
95eecad
Fix failure in sympy fallback path
mtreinish Apr 21, 2021
4cd6d7a
Fix gradient tests
mtreinish Apr 21, 2021
7077fd6
Adjust complex check logic
mtreinish Apr 21, 2021
f0c5142
Fix lint
mtreinish Apr 21, 2021
f57be79
Workaround pulse failure
mtreinish Apr 21, 2021
8cb1331
Merge branch 'main' into use-symengine
mtreinish Apr 21, 2021
ae7d19a
Adjust failing test to use np allclose
mtreinish Apr 22, 2021
89df1e9
Merge branch 'main' into use-symengine
mtreinish Apr 22, 2021
1fe5ddf
Add release note
mtreinish Apr 22, 2021
efdd964
Add is_real method to parameterexpression
mtreinish Apr 22, 2021
25f9c9a
Merge branch 'main' into use-symengine
mtreinish Apr 22, 2021
5c14d37
Merge remote-tracking branch 'origin/main' into use-symengine
mtreinish May 5, 2021
fa77157
Run black post-rebase
mtreinish May 5, 2021
3b4c5ec
Remove symengine/sympy usage from gradients tests
mtreinish May 5, 2021
d193f29
Remove unused import
mtreinish May 5, 2021
98f1130
Fix Delay instruction's new validate is_real check
mtreinish May 5, 2021
a872706
Merge branch 'main' into use-symengine
mtreinish May 5, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions qiskit/circuit/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
from scipy.linalg import schur

from qiskit.circuit.parameter import ParameterExpression
from qiskit.circuit.parameterexpression import ParameterExpression
from qiskit.circuit.exceptions import CircuitError
from .instruction import Instruction

Expand Down Expand Up @@ -249,10 +249,9 @@ def validate_parameter(self, parameter):
if isinstance(parameter, ParameterExpression):
if len(parameter.parameters) > 0:
return parameter # expression has free parameters, we cannot validate it
if not parameter._symbol_expr.is_real:
raise CircuitError(
"Bound parameter expression is complex in gate {}".format(self.name)
)
if not parameter.is_real():
msg = "Bound parameter expression is complex in gate {}".format(self.name)
raise CircuitError(msg)
return parameter # per default assume parameters must be real when bound
if isinstance(parameter, (int, float)):
return parameter
Expand Down
28 changes: 25 additions & 3 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@

from .parameterexpression import ParameterExpression

try:
import symengine

HAS_SYMENGINE = True
except ImportError:
HAS_SYMENGINE = False


class Parameter(ParameterExpression):
"""Parameter Class for variable parameters."""
Expand Down Expand Up @@ -49,10 +56,12 @@ def __init__(self, name: str):
be any unicode string, e.g. "ϕ".
"""
self._name = name
if not HAS_SYMENGINE:
from sympy import Symbol

from sympy import Symbol

symbol = Symbol(name)
symbol = Symbol(name)
else:
symbol = symengine.Symbol(name)
super().__init__(symbol_map={self: symbol}, expr=symbol)

def subs(self, parameter_map: dict):
Expand Down Expand Up @@ -86,3 +95,16 @@ def __eq__(self, other):

def __hash__(self):
return self._hash

def __getstate__(self):
return {"name": self._name}

def __setstate__(self, state):
self._name = state["name"]
if not HAS_SYMENGINE:
from sympy import Symbol

symbol = Symbol(self._name)
else:
symbol = symengine.Symbol(self._name)
super().__init__(symbol_map={self: symbol}, expr=symbol)
151 changes: 122 additions & 29 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@

from qiskit.circuit.exceptions import CircuitError

try:
import symengine

HAS_SYMENGINE = True
except ImportError:
HAS_SYMENGINE = False


ParameterValueType = Union["ParameterExpression", float, int]


Expand Down Expand Up @@ -53,7 +61,12 @@ def parameters(self) -> Set:

def conjugate(self) -> "ParameterExpression":
"""Return the conjugate."""
conjugated = ParameterExpression(self._parameter_symbols, self._symbol_expr.conjugate())
if HAS_SYMENGINE:
conjugated = ParameterExpression(
self._parameter_symbols, symengine.conjugate(self._symbol_expr)
)
else:
conjugated = ParameterExpression(self._parameter_symbols, self._symbol_expr.conjugate())
return conjugated

def assign(self, parameter, value: ParameterValueType) -> "ParameterExpression":
Expand Down Expand Up @@ -110,7 +123,9 @@ def bind(self, parameter_values: Dict) -> "ParameterExpression":
p: s for p, s in self._parameter_symbols.items() if p in free_parameters
}

if bound_symbol_expr.is_infinite:
if (
hasattr(bound_symbol_expr, "is_infinite") and bound_symbol_expr.is_infinite
) or bound_symbol_expr == float("inf"):
raise ZeroDivisionError(
"Binding provided for expression "
"results in division by zero "
Expand Down Expand Up @@ -142,10 +157,12 @@ def subs(self, parameter_map: Dict) -> "ParameterExpression":

self._raise_if_passed_unknown_parameters(parameter_map.keys())
self._raise_if_parameter_names_conflict(inbound_parameters, parameter_map.keys())
if HAS_SYMENGINE:
new_parameter_symbols = {p: symengine.Symbol(p.name) for p in inbound_parameters}
else:
from sympy import Symbol

from sympy import Symbol

new_parameter_symbols = {p: Symbol(p.name) for p in inbound_parameters}
new_parameter_symbols = {p: Symbol(p.name) for p in inbound_parameters}

# Include existing parameters in self not set to be replaced.
new_parameter_symbols.update(
Expand Down Expand Up @@ -257,11 +274,14 @@ def gradient(self, param) -> Union["ParameterExpression", float]:
return 0.0

# Compute the gradient of the parameter expression w.r.t. param
import sympy as sy

key = self._parameter_symbols[param]
# TODO enable nth derivative
expr_grad = sy.Derivative(self._symbol_expr, key).doit()
if HAS_SYMENGINE:
expr_grad = symengine.Derivative(self._symbol_expr, key)
else:
# TODO enable nth derivative
from sympy import Derivative

expr_grad = Derivative(self._symbol_expr, key).doit()

# generate the new dictionary of symbols
# this needs to be done since in the derivative some symbols might disappear (e.g.
Expand Down Expand Up @@ -310,57 +330,83 @@ def _call(self, ufunc):

def sin(self):
"""Sine of a ParameterExpression"""
from sympy import sin as _sin
if HAS_SYMENGINE:
return self._call(symengine.sin)
else:
from sympy import sin as _sin

return self._call(_sin)
return self._call(_sin)

def cos(self):
"""Cosine of a ParameterExpression"""
from sympy import cos as _cos
if HAS_SYMENGINE:
return self._call(symengine.cos)
else:
from sympy import cos as _cos

return self._call(_cos)
return self._call(_cos)

def tan(self):
"""Tangent of a ParameterExpression"""
from sympy import tan as _tan
if HAS_SYMENGINE:
return self._call(symengine.tan)
else:
from sympy import tan as _tan

return self._call(_tan)
return self._call(_tan)

def arcsin(self):
"""Arcsin of a ParameterExpression"""
from sympy import asin as _asin
if HAS_SYMENGINE:
return self._call(symengine.asin)
else:
from sympy import asin as _asin

return self._call(_asin)
return self._call(_asin)

def arccos(self):
"""Arccos of a ParameterExpression"""
from sympy import acos as _acos
if HAS_SYMENGINE:
return self._call(symengine.acos)
else:
from sympy import acos as _acos

return self._call(_acos)
return self._call(_acos)

def arctan(self):
"""Arctan of a ParameterExpression"""
from sympy import atan as _atan
if HAS_SYMENGINE:
return self._call(symengine.atan)
else:
from sympy import atan as _atan

return self._call(_atan)
return self._call(_atan)

def exp(self):
"""Exponential of a ParameterExpression"""
from sympy import exp as _exp
if HAS_SYMENGINE:
return self._call(symengine.exp)
else:
from sympy import exp as _exp

return self._call(_exp)
return self._call(_exp)

def log(self):
"""Logarithm of a ParameterExpression"""
from sympy import log as _log
if HAS_SYMENGINE:
return self._call(symengine.log)
else:
from sympy import log as _log

return self._call(_log)
return self._call(_log)

def __repr__(self):
return "{}({})".format(self.__class__.__name__, str(self))

def __str__(self):
return str(self._symbol_expr)
from sympy import sympify
ewinston marked this conversation as resolved.
Show resolved Hide resolved

return str(sympify(self._symbol_expr))

def __float__(self):
if self.parameters:
Expand Down Expand Up @@ -405,9 +451,56 @@ def __eq__(self, other):
bool: result of the comparison
"""
if isinstance(other, ParameterExpression):
return self.parameters == other.parameters and self._symbol_expr.equals(
other._symbol_expr
)
if self.parameters != other.parameters:
return False
if HAS_SYMENGINE:
from sympy import sympify

return sympify(self._symbol_expr).equals(sympify(other._symbol_expr))
else:
return self._symbol_expr.equals(other._symbol_expr)
elif isinstance(other, numbers.Number):
return len(self.parameters) == 0 and complex(self._symbol_expr) == other
return False

def __getstate__(self):
if HAS_SYMENGINE:
from sympy import sympify

symbols = {k: sympify(v) for k, v in self._parameter_symbols.items()}
expr = sympify(self._symbol_expr)
return {"type": "symengine", "symbols": symbols, "expr": expr, "names": self._names}
else:
return {
"type": "sympy",
"symbols": self._parameter_symbols,
"expr": self._symbol_expr,
"names": self._names,
}

def __setstate__(self, state):
if state["type"] == "symengine":
self._symbol_expr = symengine.sympify(state["expr"])
self._parameter_symbols = {k: symengine.sympify(v) for k, v in state["symbols"].items()}
self._parameters = set(self._parameter_symbols)
else:
self._symbol_expr = state["expr"]
self._parameter_symbols = state["symbols"]
self._parameters = set(self._parameter_symbols)
self._names = state["names"]

def is_real(self):
"""Return whether the expression is real"""

if not self._symbol_expr.is_real and self._symbol_expr.is_real is not None:
# Symengine returns false for is_real on the expression if
# there is a imaginary component (even if that component is 0),
# but the parameter will evaluate as real. Check that if the
# expression's is_real attribute returns false that we have a
# non-zero imaginary
if HAS_SYMENGINE:
if self._symbol_expr.imag != 0.0:
return False
else:
return False
return True
15 changes: 15 additions & 0 deletions qiskit/circuit/parametervector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ def vector(self):
"""Get the parent vector instance."""
return self._vector

def __getstate__(self):
return {
"name": self._name,
"uuid": self._uuid,
"vector": self._vector,
"index": self._index,
}

def __setstate__(self, state):
self._name = state["name"]
self._uuid = state["uuid"]
self._vector = state["vector"]
self._index = state["index"]
super().__init__(self._name)


class ParameterVector:
"""ParameterVector class to quickly generate lists of parameters."""
Expand Down
8 changes: 7 additions & 1 deletion qiskit/circuit/tools/pi_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ def pi_check(inpt, eps=1e-6, output="text", ndigits=5):
"""
if isinstance(inpt, ParameterExpression):
param_str = str(inpt)
syms = inpt._symbol_expr.expr_free_symbols
if not hasattr(inpt._symbol_expr, "expr_free_symbols"):
from sympy import sympify

expr = sympify(inpt._symbol_expr)
else:
expr = inpt._symbol_expr
syms = expr.expr_free_symbols
for sym in syms:
if not sym.is_number:
continue
Expand Down
7 changes: 7 additions & 0 deletions qiskit/opflow/gradients/derivative_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@

OperatorType = Union[StateFn, PrimitiveOp, ListOp]

try:
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
import symengine

HAS_SYMENGINE = True
except ImportError:
HAS_SYMENGINE = False


class DerivativeBase(ConverterBase):
r"""Base class for differentiating opflow objects.
Expand Down
10 changes: 10 additions & 0 deletions releasenotes/notes/symengine-2fa0479fa7d9aa80.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
upgrade:
- |
A new requirement `symengine <https://pypi.org/project/symengine>`__ has
been added for Linux (on x86_64, aarch64, and ppc64le) and macOS users
(x86_64 and arm64). It is an optional dependency on Windows (and available
on PyPi as a precompiled package for 64bit Windows) and other
architectures. If it is installed it provides significantly improved
performance for the evaluation of :class:`~qiskit.circuit.Parameter` and
:class:`~qiskit.circuit.ParameterExpression` objects.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ dill>=0.3
fastjsonschema>=2.10
python-constraint>=1.4
python-dateutil>=2.8.0
symengine>0.7 ; platform_machine == 'x86_64' or platform_machine == 'aarch64' or platform_machine == 'ppc64le' or platform_machine == 'amd64' or platform_machine == 'arm64'
Loading