Skip to content

Commit

Permalink
Allow specification of job IDs in RemoteResults (#718)
Browse files Browse the repository at this point in the history
* Allow specification of job IDs in RemoteResults

* Rename submission_id to batch_id
  • Loading branch information
HGSilveri authored Sep 16, 2024
1 parent 8550104 commit 6c12156
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 37 deletions.
47 changes: 43 additions & 4 deletions pulser-core/pulser/backend/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,47 @@ class RemoteResults(Results):
the results.
connection: The remote connection over which to get the submission's
status and fetch the results.
job_ids: If given, specifies which jobs within the submission should
be included in the results and in what order. If left undefined,
all jobs are included.
"""

def __init__(self, submission_id: str, connection: RemoteConnection):
def __init__(
self,
submission_id: str,
connection: RemoteConnection,
job_ids: list[str] | None = None,
):
"""Instantiates a new collection of remote results."""
self._submission_id = submission_id
self._connection = connection
if job_ids is not None and not set(job_ids).issubset(
all_job_ids := self._connection._get_job_ids(self._submission_id)
):
unknown_ids = [id_ for id_ in job_ids if id_ not in all_job_ids]
raise RuntimeError(
f"Submission {self._submission_id!r} does not contain jobs "
f"{unknown_ids}."
)
self._job_ids = job_ids

@property
def results(self) -> tuple[Result, ...]:
"""The actual results, obtained after execution is done."""
return self._results

@property
def batch_id(self) -> str:
"""The ID of the batch containing these results."""
return self._submission_id

@property
def job_ids(self) -> list[str]:
"""The IDs of the jobs within this results submission."""
if self._job_ids is None:
return self._connection._get_job_ids(self._submission_id)
return self._job_ids

def get_status(self) -> SubmissionStatus:
"""Gets the status of the remote submission."""
return self._connection._get_submission_status(self._submission_id)
Expand All @@ -79,7 +108,9 @@ def __getattr__(self, name: str) -> Any:
status = self.get_status()
if status == SubmissionStatus.DONE:
self._results = tuple(
self._connection._fetch_result(self._submission_id)
self._connection._fetch_result(
self._submission_id, self._job_ids
)
)
return self._results
raise RemoteResultsError(
Expand All @@ -102,7 +133,9 @@ def submit(
pass

@abstractmethod
def _fetch_result(self, submission_id: str) -> typing.Sequence[Result]:
def _fetch_result(
self, submission_id: str, job_ids: list[str] | None
) -> typing.Sequence[Result]:
"""Fetches the results of a completed submission."""
pass

Expand All @@ -115,9 +148,15 @@ def _get_submission_status(self, submission_id: str) -> SubmissionStatus:
"""
pass

def _get_job_ids(self, submission_id: str) -> list[str]:
"""Gets all the job IDs within a submission."""
raise NotImplementedError(
"Unable to find job IDs through this remote connection."
)

def fetch_available_devices(self) -> dict[str, Device]:
"""Fetches the devices available through this connection."""
raise NotImplementedError( # pragma: no cover
raise NotImplementedError(
"Unable to fetch the available devices through this "
"remote connection."
)
Expand Down
21 changes: 19 additions & 2 deletions pulser-pasqal/pulser_pasqal/pasqal_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def fetch_available_devices(self) -> dict[str, Device]:
for name, dev_str in abstract_devices.items()
}

def _fetch_result(self, submission_id: str) -> tuple[Result, ...]:
def _fetch_result(
self, submission_id: str, job_ids: list[str] | None
) -> tuple[Result, ...]:
# For now, the results are always sampled results
get_batch_fn = backoff_decorator(self._sdk_connection.get_batch)
batch = get_batch_fn(id=submission_id)
Expand All @@ -189,7 +191,16 @@ def _fetch_result(self, submission_id: str) -> tuple[Result, ...]:
meas_basis = seq_builder.get_measurement_basis()

results = []
for job in batch.ordered_jobs:
sdk_jobs = batch.ordered_jobs
if job_ids is not None:
ind_job_pairs = [
(job_ids.index(job.id), job)
for job in sdk_jobs
if job.id in job_ids
]
ind_job_pairs.sort()
sdk_jobs = [job for _, job in ind_job_pairs]
for job in sdk_jobs:
vars = job.variables
size: int | None = None
if vars and "qubits" in vars:
Expand All @@ -210,6 +221,12 @@ def _get_submission_status(self, submission_id: str) -> SubmissionStatus:
batch = self._sdk_connection.get_batch(id=submission_id)
return SubmissionStatus[batch.status]

@backoff_decorator
def _get_job_ids(self, submission_id: str) -> list[str]:
"""Gets all the job IDs within a submission."""
batch = self._sdk_connection.get_batch(id=submission_id)
return [job.id for job in batch.ordered_jobs]

def _convert_configuration(
self,
config: EmulatorConfig | None,
Expand Down
16 changes: 15 additions & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def __init__(self):
def submit(self, sequence, wait: bool = False, **kwargs) -> RemoteResults:
return RemoteResults("abcd", self)

def _fetch_result(self, submission_id: str) -> typing.Sequence[Result]:
def _fetch_result(
self, submission_id: str, job_ids: list[str] | None = None
) -> typing.Sequence[Result]:
return (
SampledResult(
("q0", "q1"),
Expand All @@ -109,6 +111,18 @@ def _get_submission_status(self, submission_id: str) -> SubmissionStatus:
return SubmissionStatus.DONE


def test_remote_connection():
connection = _MockConnection()

with pytest.raises(NotImplementedError, match="Unable to find job IDs"):
connection._get_job_ids("abc")

with pytest.raises(
NotImplementedError, match="Unable to fetch the available devices"
):
connection.fetch_available_devices()


def test_qpu_backend(sequence):
connection = _MockConnection()

Expand Down
105 changes: 75 additions & 30 deletions tests/test_pasqal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import copy
import dataclasses
import re
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -70,18 +71,22 @@ def seq():
return Sequence(reg, test_device)


@pytest.fixture
def mock_job():
@dataclasses.dataclass
class MockJob:
runs = 10
variables = {"t": 100, "qubits": {"q0": 1, "q1": 2, "q2": 4, "q3": 3}}
result = {"00": 5, "11": 5}
class _MockJob:
def __init__(
self,
runs=10,
variables={"t": 100, "qubits": {"q0": 1, "q1": 2, "q2": 4, "q3": 3}},
result={"00": 5, "11": 5},
) -> None:
self.runs = runs
self.variables = variables
self.result = result
self.id = str(np.random.randint(10000))

def __post_init__(self) -> None:
self.id = str(np.random.randint(10000))

return MockJob()
@pytest.fixture
def mock_job():
return _MockJob()


@pytest.fixture
Expand All @@ -94,7 +99,11 @@ def mock_batch(mock_job, seq):
class MockBatch:
id = "abcd"
status = "DONE"
ordered_jobs = [mock_job]
ordered_jobs = [
mock_job,
_MockJob(result={"00": 10}),
_MockJob(result={"11": 10}),
]
sequence_builder = seq_.to_abstract_repr()

return MockBatch()
Expand Down Expand Up @@ -132,12 +141,64 @@ def fixt(mock_batch):
mock_cloud_sdk_class.assert_not_called()


@pytest.mark.parametrize("with_job_id", [False, True])
def test_remote_results(fixt, mock_batch, with_job_id):
with pytest.raises(
RuntimeError, match=re.escape("does not contain jobs ['badjobid']")
):
RemoteResults(mock_batch.id, fixt.pasqal_cloud, job_ids=["badjobid"])
fixt.mock_cloud_sdk.get_batch.reset_mock()

select_jobs = (
mock_batch.ordered_jobs[::-1][:2]
if with_job_id
else mock_batch.ordered_jobs
)
select_job_ids = [j.id for j in select_jobs]

remote_results = RemoteResults(
mock_batch.id,
fixt.pasqal_cloud,
job_ids=select_job_ids if with_job_id else None,
)

assert remote_results.batch_id == mock_batch.id
assert remote_results.job_ids == select_job_ids
fixt.mock_cloud_sdk.get_batch.assert_called_once_with(
id=remote_results.batch_id
)
fixt.mock_cloud_sdk.get_batch.reset_mock()

assert remote_results.get_status() == SubmissionStatus.DONE
fixt.mock_cloud_sdk.get_batch.assert_called_once_with(
id=remote_results.batch_id
)

fixt.mock_cloud_sdk.get_batch.reset_mock()
results = remote_results.results
fixt.mock_cloud_sdk.get_batch.assert_called_with(
id=remote_results.batch_id
)
assert results == tuple(
SampledResult(
atom_order=("q0", "q1", "q2", "q3"),
meas_basis="ground-rydberg",
bitstring_counts=job.result,
)
for job in select_jobs
)

assert hasattr(remote_results, "_results")


@pytest.mark.parametrize("mimic_qpu", [False, True])
@pytest.mark.parametrize(
"emulator", [None, EmulatorType.EMU_TN, EmulatorType.EMU_FREE]
)
@pytest.mark.parametrize("parametrized", [True, False])
def test_submit(fixt, parametrized, emulator, mimic_qpu, seq, mock_job):
def test_submit(
fixt, parametrized, emulator, mimic_qpu, seq, mock_batch, mock_job
):
with pytest.raises(
ValueError,
match="The measurement basis can't be implicitly determined for a "
Expand Down Expand Up @@ -240,6 +301,8 @@ def test_submit(fixt, parametrized, emulator, mimic_qpu, seq, mock_job):
config=config,
mimic_qpu=mimic_qpu,
)
assert remote_results.batch_id == mock_batch.id

assert not seq.is_measured()
seq.measure(basis="ground-rydberg")

Expand All @@ -266,24 +329,6 @@ def test_submit(fixt, parametrized, emulator, mimic_qpu, seq, mock_job):
)

assert isinstance(remote_results, RemoteResults)
assert remote_results.get_status() == SubmissionStatus.DONE
fixt.mock_cloud_sdk.get_batch.assert_called_once_with(
id=remote_results._submission_id
)

fixt.mock_cloud_sdk.get_batch.reset_mock()
results = remote_results.results
fixt.mock_cloud_sdk.get_batch.assert_called_with(
id=remote_results._submission_id
)
assert results == (
SampledResult(
atom_order=("q0", "q1", "q2", "q3"),
meas_basis="ground-rydberg",
bitstring_counts=mock_job.result,
),
)
assert hasattr(remote_results, "_results")


@pytest.mark.parametrize("emu_cls", [EmuTNBackend, EmuFreeBackend])
Expand Down

0 comments on commit 6c12156

Please sign in to comment.