Skip to content

Commit

Permalink
Remove qiskit.test from test/common.py (#1971)
Browse files Browse the repository at this point in the history
* remove qiskit.test from test/common.py

* fix format, add releasenote

---------

Co-authored-by: Hiroshi Horii <hhorii@users.noreply.github.com>
  • Loading branch information
doichanj and hhorii authored Dec 8, 2023
1 parent 096e1e3 commit 49667a0
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 3 deletions.
4 changes: 4 additions & 0 deletions releasenotes/notes/remove_qiskit_test-777882fa1591b6e7.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
deprecations:
- |
Removed importing qiskit.test from test and include some classes in Aer
141 changes: 138 additions & 3 deletions test/terra/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import os
import warnings
import unittest
from enum import Enum
from itertools import repeat
from math import pi
Expand All @@ -30,7 +31,7 @@
from qiskit_aer import __path__ as main_path
from qiskit.quantum_info import Operator, Statevector
from qiskit.quantum_info.operators.predicates import matrix_equal
from qiskit.test.base import FullQiskitTestCase
from .decorators import enforce_subclasses_call


class Path(Enum):
Expand All @@ -42,12 +43,146 @@ class Path(Enum):
EXAMPLES = os.path.join(MAIN, "../examples")


class QiskitAerTestCase(FullQiskitTestCase):
@enforce_subclasses_call(["setUp", "setUpClass", "tearDown", "tearDownClass"])
class BaseQiskitAerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__setup_called = False
self.__teardown_called = False

def setUp(self):
super().setUp()
if self.__setup_called:
raise ValueError(
"In File: %s\n"
"TestCase.setUp was already called. Do not explicitly call "
"setUp from your tests. In your own setUp, use super to call "
"the base setUp." % (sys.modules[self.__class__.__module__].__file__,)
)
self.__setup_called = True

def tearDown(self):
super().tearDown()
if self.__teardown_called:
raise ValueError(
"In File: %s\n"
"TestCase.tearDown was already called. Do not explicitly call "
"tearDown from your tests. In your own tearDown, use super to "
"call the base tearDown." % (sys.modules[self.__class__.__module__].__file__,)
)
self.__teardown_called = True

@staticmethod
def _get_resource_path(filename, path=Path.TEST):
"""Get the absolute path to a resource.
Args:
filename (string): filename or relative path to the resource.
path (Path): path used as relative to the filename.
Returns:
str: the absolute path to the resource.
"""
return os.path.normpath(os.path.join(path.value, filename))

def assertDictAlmostEqual(
self, dict1, dict2, delta=None, msg=None, places=None, default_value=0
):
"""Assert two dictionaries with numeric values are almost equal.
Fail if the two dictionaries are unequal as determined by
comparing that the difference between values with the same key are
not greater than delta (default 1e-8), or that difference rounded
to the given number of decimal places is not zero. If a key in one
dictionary is not in the other the default_value keyword argument
will be used for the missing value (default 0). If the two objects
compare equal then they will automatically compare almost equal.
Args:
dict1 (dict): a dictionary.
dict2 (dict): a dictionary.
delta (number): threshold for comparison (defaults to 1e-8).
msg (str): return a custom message on failure.
places (int): number of decimal places for comparison.
default_value (number): default value for missing keys.
Raises:
TypeError: if the arguments are not valid (both `delta` and
`places` are specified).
AssertionError: if the dictionaries are not almost equal.
"""

error_msg = dicts_almost_equal(dict1, dict2, delta, places, default_value)

if error_msg:
msg = self._formatMessage(msg, error_msg)
raise self.failureException(msg)


def dicts_almost_equal(dict1, dict2, delta=None, places=None, default_value=0):
"""Test if two dictionaries with numeric values are almost equal.
Fail if the two dictionaries are unequal as determined by
comparing that the difference between values with the same key are
not greater than delta (default 1e-8), or that difference rounded
to the given number of decimal places is not zero. If a key in one
dictionary is not in the other the default_value keyword argument
will be used for the missing value (default 0). If the two objects
compare equal then they will automatically compare almost equal.
Args:
dict1 (dict): a dictionary.
dict2 (dict): a dictionary.
delta (number): threshold for comparison (defaults to 1e-8).
places (int): number of decimal places for comparison.
default_value (number): default value for missing keys.
Raises:
TypeError: if the arguments are not valid (both `delta` and
`places` are specified).
Returns:
String: Empty string if dictionaries are almost equal. A description
of their difference if they are deemed not almost equal.
"""

def valid_comparison(value):
"""compare value to delta, within places accuracy"""
if places is not None:
return round(value, places) == 0
else:
return value < delta

# Check arguments.
if dict1 == dict2:
return ""
if places is not None:
if delta is not None:
raise TypeError("specify delta or places not both")
msg_suffix = " within %s places" % places
else:
delta = delta or 1e-8
msg_suffix = " within %s delta" % delta

# Compare all keys in both dicts, populating error_msg.
error_msg = ""
for key in set(dict1.keys()) | set(dict2.keys()):
val1 = dict1.get(key, default_value)
val2 = dict2.get(key, default_value)
if not valid_comparison(abs(val1 - val2)):
error_msg += f"({safe_repr(key)}: {safe_repr(val1)} != {safe_repr(val2)}), "

if error_msg:
return error_msg[:-2] + msg_suffix
else:
return ""


class QiskitAerTestCase(BaseQiskitAerTestCase):
"""Helper class that contains common functionality."""

def setUp(self):
super().setUp()
self.useFixture(fixtures.Timeout(240, gentle=False))

@classmethod
def setUpClass(cls):
Expand Down
99 changes: 99 additions & 0 deletions test/terra/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
import unittest

from qiskit import QuantumCircuit, execute
from qiskit.utils import wrap_method

from qiskit_aer import AerProvider, AerSimulator

from typing import Union, Callable, Type, Iterable


def is_method_available(backend, method):
"""Check if input method is available for the qasm simulator."""
Expand Down Expand Up @@ -94,3 +97,99 @@ def _deprecated_method(self, *args, **kwargs):
method(self, *args, **kwargs)

return _deprecated_method


def enforce_subclasses_call(
methods: Union[str, Iterable[str]], attr: str = "_enforce_subclasses_call_cache"
) -> Callable[[Type], Type]:
"""Class decorator which enforces that if any subclasses define on of the ``methods``, they must
call ``super().<method>()`` or face a ``ValueError`` at runtime.
This is unlikely to be useful for concrete test classes, who are not normally subclassed. It
should not be used on user-facing code, because it prevents subclasses from being free to
override parent-class behavior, even when the parent-class behavior is not needed.
This adds behavior to the ``__init__`` and ``__init_subclass__`` methods of the class, in
addition to the named methods of this class and all subclasses. The checks could be averted in
grandchildren if a child class overrides ``__init_subclass__`` without up-calling the decorated
class's method, though this would typically break inheritance principles.
Arguments:
methods:
Names of the methods to add the enforcement to. These do not necessarily need to be
defined in the class body, provided they are somewhere in the method-resolution tree.
attr:
The attribute which will be added to all instances of this class and subclasses, in
order to manage the call enforcement. This can be changed to avoid clashes.
Returns:
A decorator, which returns its input class with the class with the relevant methods modified
to include checks, and injection code in the ``__init_subclass__`` method.
"""

methods = {methods} if isinstance(methods, str) else set(methods)

def initialize_call_memory(self, *_args, **_kwargs):
"""Add the extra attribute used for tracking the method calls."""
setattr(self, attr, set())

def save_call_status(name):
"""Decorator, whose return saves the fact that the top-level method call occurred."""

def out(self, *_args, **_kwargs):
getattr(self, attr).add(name)

return out

def clear_call_status(name):
"""Decorator, whose return clears the call status of the method ``name``. This prepares the
call tracking for the child class's method call."""

def out(self, *_args, **_kwargs):
getattr(self, attr).discard(name)

return out

def enforce_call_occurred(name):
"""Decorator, whose return checks that the top-level method call occurred, and raises
``ValueError`` if not. Concretely, this is an assertion that ``save_call_status`` ran."""

def out(self, *_args, **_kwargs):
cache = getattr(self, attr)
if name not in cache:
classname = self.__name__ if isinstance(self, type) else type(self).__name__
raise ValueError(
f"Parent '{name}' method was not called by '{classname}.{name}'."
f" Ensure you have put in calls to 'super().{name}()'."
)

return out

def wrap_subclass_methods(cls):
"""Wrap all the ``methods`` of ``cls`` with the call-tracking assertions that the top-level
versions of the methods were called (likely via ``super()``)."""
# Only wrap methods who are directly defined in this class; if we're resolving to a method
# higher up the food chain, then it will already have been wrapped.
for name in set(cls.__dict__) & methods:
wrap_method(
cls,
name,
before=clear_call_status(name),
after=enforce_call_occurred(name),
)

def decorator(cls):
# Add a class-level memory on, so class methods will work as well. Instances will override
# this on instantiation, to keep the "namespace" of class- and instance-methods separate.
initialize_call_memory(cls)
# Do the extra bits after the main body of __init__ so we can check we're not overwriting
# anything, and after __init_subclass__ in case the decorated class wants to influence the
# creation of the subclass's methods before we get to them.
wrap_method(cls, "__init__", after=initialize_call_memory)
for name in methods:
wrap_method(cls, name, before=save_call_status(name))
wrap_method(cls, "__init_subclass__", after=wrap_subclass_methods)
return cls

return decorator

0 comments on commit 49667a0

Please sign in to comment.