Skip to content

Commit

Permalink
Allow separate decoders (#597)
Browse files Browse the repository at this point in the history
* update decoder

* release note
  • Loading branch information
jyu00 authored Oct 28, 2022
1 parent b22b35e commit 5ef5711
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 27 deletions.
13 changes: 13 additions & 0 deletions qiskit_ibm_runtime/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@

from qiskit.providers.jobstatus import JobStatus

from .program.result_decoder import ResultDecoder
from .utils.estimator_result_decoder import EstimatorResultDecoder
from .utils.sampler_result_decoder import SamplerResultDecoder
from .utils.runner_result import RunnerResult


QISKIT_IBM_RUNTIME_API_URL = "https://auth.quantum-computing.ibm.com/api"

API_TO_JOB_STATUS = {
Expand All @@ -29,3 +35,10 @@
"CANCELLED - RAN TOO LONG": "Job {} ran longer than maximum execution time. "
"Job was cancelled:\n{}",
}

DEFAULT_DECODERS = {
"sampler": [ResultDecoder, SamplerResultDecoder],
"estimator": [ResultDecoder, EstimatorResultDecoder],
"circuit-runner": RunnerResult,
"qasm3-runner": RunnerResult,
}
6 changes: 3 additions & 3 deletions qiskit_ibm_runtime/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

# TODO import _circuit_key from terra once 0.23 is released
from .qiskit_runtime_service import QiskitRuntimeService
from .utils.estimator_result_decoder import EstimatorResultDecoder
from .runtime_job import RuntimeJob
from .utils.deprecation import (
deprecate_arguments,
Expand All @@ -37,6 +36,7 @@
from .ibm_backend import IBMBackend
from .session import get_default_session
from .options import Options
from .constants import DEFAULT_DECODERS

# pylint: disable=unused-import,cyclic-import
from .session import Session
Expand Down Expand Up @@ -315,7 +315,7 @@ def _run( # pylint: disable=arguments-differ
inputs=inputs,
options=Options._get_runtime_options(combined),
callback=combined.get("environment", {}).get("callback", None),
result_decoder=EstimatorResultDecoder,
result_decoder=DEFAULT_DECODERS.get(self._PROGRAM_ID),
)

def _call(
Expand Down Expand Up @@ -364,7 +364,7 @@ def _call(
program_id=self._PROGRAM_ID,
inputs=inputs,
options=Options._get_runtime_options(combined),
result_decoder=EstimatorResultDecoder,
result_decoder=DEFAULT_DECODERS.get(self._PROGRAM_ID),
).result()

@deprecate_function(
Expand Down
11 changes: 7 additions & 4 deletions qiskit_ibm_runtime/qiskit_runtime_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import warnings
from datetime import datetime
from collections import OrderedDict
from typing import Dict, Callable, Optional, Union, List, Any, Type
from typing import Dict, Callable, Optional, Union, List, Any, Type, Sequence

from qiskit.providers.backend import BackendV1 as Backend
from qiskit.providers.provider import ProviderV1 as Provider
Expand Down Expand Up @@ -851,7 +851,9 @@ def run(
inputs: Union[Dict, ParameterNamespace],
options: Optional[Union[RuntimeOptions, Dict]] = None,
callback: Optional[Callable] = None,
result_decoder: Optional[Type[ResultDecoder]] = None,
result_decoder: Optional[
Union[Type[ResultDecoder], Sequence[Type[ResultDecoder]]]
] = None,
instance: Optional[str] = None,
session_id: Optional[str] = None,
job_tags: Optional[List[str]] = None,
Expand All @@ -874,7 +876,9 @@ def run(
2. Job result.
result_decoder: A :class:`ResultDecoder` subclass used to decode job results.
``ResultDecoder`` is used if not specified.
If more than one decoder is specified, the first is used for interim results and
the second final results. If not specified, a program-specific decoder or the default
``ResultDecoder`` is used.
instance: (DEPRECATED) This is only supported for ``ibm_quantum`` runtime and is in the
hub/group/project format.
session_id: Job ID of the first job in a runtime session.
Expand Down Expand Up @@ -937,7 +941,6 @@ def run(
backend = hgp.backend(qrt_options.backend)
hgp_name = hgp.name

result_decoder = result_decoder or ResultDecoder
try:
response = self._api_client.program_run(
program_id=program_id,
Expand Down
32 changes: 17 additions & 15 deletions qiskit_ibm_runtime/runtime_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

"""Qiskit runtime job."""

from typing import Any, Optional, Callable, Dict, Type
from typing import Any, Optional, Callable, Dict, Type, Union, Sequence
import time
import json
import logging
Expand All @@ -24,11 +24,8 @@
from qiskit.providers.backend import Backend
from qiskit.providers.jobstatus import JobStatus, JOB_FINAL_STATES
from qiskit.providers.job import JobV1 as Job
from .utils.estimator_result_decoder import EstimatorResultDecoder
from .utils.sampler_result_decoder import SamplerResultDecoder


from .constants import API_TO_JOB_ERROR_MESSAGE, API_TO_JOB_STATUS
from .constants import API_TO_JOB_ERROR_MESSAGE, API_TO_JOB_STATUS, DEFAULT_DECODERS
from .exceptions import (
RuntimeJobFailureError,
RuntimeInvalidStateError,
Expand Down Expand Up @@ -95,7 +92,9 @@ def __init__(
params: Optional[Dict] = None,
creation_date: Optional[str] = None,
user_callback: Optional[Callable] = None,
result_decoder: Type[ResultDecoder] = ResultDecoder,
result_decoder: Optional[
Union[Type[ResultDecoder], Sequence[Type[ResultDecoder]]]
] = None,
image: Optional[str] = "",
) -> None:
"""RuntimeJob constructor.
Expand All @@ -122,10 +121,17 @@ def __init__(
self._status = JobStatus.INITIALIZING
self._reason: Optional[str] = None
self._error_message: Optional[str] = None
self._result_decoder = result_decoder
self._image = image
self._final_interim_results = False

decoder = (
result_decoder or DEFAULT_DECODERS.get(program_id, None) or ResultDecoder
)
if isinstance(decoder, Sequence):
self._interim_result_decoder, self._final_result_decoder = decoder
else:
self._interim_result_decoder = self._final_result_decoder = decoder

# Used for streaming
self._ws_client_future = None # type: Optional[futures.Future]
self._result_queue = queue.Queue() # type: queue.Queue
Expand Down Expand Up @@ -159,7 +165,7 @@ def interim_results(self, decoder: Optional[Type[ResultDecoder]] = None) -> Any:
RuntimeJobFailureError: If the job failed.
"""
if not self._final_interim_results:
_decoder = decoder or self._result_decoder
_decoder = decoder or self._interim_result_decoder
interim_results_raw = self._api_client.job_interim_results(
job_id=self.job_id()
)
Expand All @@ -186,12 +192,8 @@ def result( # pylint: disable=arguments-differ
RuntimeJobFailureError: If the job failed.
RuntimeJobMaxTimeoutError: If the job does not complete within given timeout.
"""
if self.program_id == "sampler":
self._result_decoder = SamplerResultDecoder
elif self.program_id == "estimator":
self._result_decoder = EstimatorResultDecoder
_decoder = decoder or self._result_decoder
if self._results is None or (_decoder != self._result_decoder):
_decoder = decoder or self._final_result_decoder
if self._results is None or (_decoder != self._final_result_decoder):
self.wait_for_final_state(timeout=timeout)
if self._status == JobStatus.ERROR:
error_message = self.error_message()
Expand Down Expand Up @@ -474,7 +476,7 @@ def _stream_results(
decoder: A :class:`ResultDecoder` (sub)class used to decode job results.
"""
logger.debug("Start result streaming for job %s", self.job_id())
_decoder = decoder or self._result_decoder
_decoder = decoder or self._interim_result_decoder
while True:
try:
response = result_queue.get()
Expand Down
4 changes: 2 additions & 2 deletions qiskit_ibm_runtime/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
# TODO import _circuit_key from terra once 0.23 released
from .qiskit_runtime_service import QiskitRuntimeService
from .options import Options
from .utils.sampler_result_decoder import SamplerResultDecoder
from .runtime_job import RuntimeJob
from .ibm_backend import IBMBackend
from .session import get_default_session
Expand All @@ -34,6 +33,7 @@
issue_deprecation_msg,
deprecate_function,
)
from .constants import DEFAULT_DECODERS

# pylint: disable=unused-import,cyclic-import
from .session import Session
Expand Down Expand Up @@ -269,7 +269,7 @@ def _run( # pylint: disable=arguments-differ
inputs=inputs,
options=Options._get_runtime_options(combined),
callback=combined.get("environment", {}).get("callback", None),
result_decoder=SamplerResultDecoder,
result_decoder=DEFAULT_DECODERS.get(self._PROGRAM_ID),
)

def _call(
Expand Down
7 changes: 7 additions & 0 deletions releasenotes/notes/separate-decoders-dde5bf7d051038e6.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
You can now specify a pair of result decoders for the ``result_decoder``
parameter of :meth:`qiskit_ibm_runtime.QiskitRuntimeService.run` method.
If a pair is specified, the first one is used to decode interim results
and the second the final results.
2 changes: 2 additions & 0 deletions test/ibm_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def _run_program(
max_execution_time=None,
session_id=None,
start_session=False,
sleep_per_iteration=0,
):
"""Run a program."""
self.log.debug("Running program on %s", service.channel)
Expand All @@ -197,6 +198,7 @@ def _run_program(
"iterations": iterations,
"interim_results": interim_results or {},
"final_result": final_result or {},
"sleep_per_iteration": sleep_per_iteration,
}
)
pid = program_id or self.program_ids[service.channel]
Expand Down
4 changes: 3 additions & 1 deletion test/integration/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,9 @@ def _callback(job_id_, result_):
circuits=[bell] * 40, observables=[obs] * 40, callback=_callback
)
result = job.result()
self.assertTrue((result.values == ws_result[-1].values).all())
self.assertIsInstance(ws_result[-1], dict)
ws_result_values = np.asarray(ws_result[-1]["values"])
self.assertTrue((result.values == ws_result_values).all())
self.assertEqual(len(job_ids), 1)
self.assertEqual(job.job_id(), job_ids.pop())

Expand Down
4 changes: 3 additions & 1 deletion test/integration/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def result_callback(job_id, result):
called_back_count += 1

called_back_count = 0
job = self._run_program(service, interim_results="foobar")
job = self._run_program(
service, interim_results="foobar", sleep_per_iteration=10
)
job.wait_for_final_state()
job._status = JobStatus.RUNNING # Allow stream_results()
job.stream_results(result_callback)
Expand Down
7 changes: 6 additions & 1 deletion test/integration/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from qiskit.circuit.library import RealAmplitudes
from qiskit.test.reference_circuits import ReferenceCircuits
from qiskit.primitives import BaseSampler, SamplerResult
from qiskit.result import QuasiDistribution

from qiskit_ibm_runtime import Sampler, Session
from qiskit_ibm_runtime.exceptions import RuntimeJobFailureError
Expand Down Expand Up @@ -270,6 +271,10 @@ def _callback(job_id_, result_):
job = sampler.run(circuits=[self.bell] * 20, callback=_callback)
result = job.result()

self.assertEqual(result.quasi_dists, ws_result[-1].quasi_dists)
self.assertIsInstance(ws_result[-1], dict)
ws_result_quasi = [
QuasiDistribution(quasi) for quasi in ws_result[-1]["quasi_dists"]
]
self.assertEqual(result.quasi_dists, ws_result_quasi)
self.assertEqual(len(job_ids), 1)
self.assertEqual(job.job_id(), job_ids.pop())

0 comments on commit 5ef5711

Please sign in to comment.