diff --git a/pulser-core/pulser/backend/remote.py b/pulser-core/pulser/backend/remote.py index 92809b62..6932e109 100644 --- a/pulser-core/pulser/backend/remote.py +++ b/pulser-core/pulser/backend/remote.py @@ -58,18 +58,47 @@ class RemoteResults(Results): the results. connection: The remote connection over which to get the submission's status and fetch the results. + job_ids: If given, specifies which jobs within the submission should + be included in the results and in what order. If left undefined, + all jobs are included. """ - def __init__(self, submission_id: str, connection: RemoteConnection): + def __init__( + self, + submission_id: str, + connection: RemoteConnection, + job_ids: list[str] | None = None, + ): """Instantiates a new collection of remote results.""" self._submission_id = submission_id self._connection = connection + if job_ids is not None and not set(job_ids).issubset( + all_job_ids := self._connection._get_job_ids(self._submission_id) + ): + unknown_ids = [id_ for id_ in job_ids if id_ not in all_job_ids] + raise RuntimeError( + f"Submission {self._submission_id!r} does not contain jobs " + f"{unknown_ids}." + ) + self._job_ids = job_ids @property def results(self) -> tuple[Result, ...]: """The actual results, obtained after execution is done.""" return self._results + @property + def batch_id(self) -> str: + """The ID of the batch containing these results.""" + return self._submission_id + + @property + def job_ids(self) -> list[str]: + """The IDs of the jobs within this results submission.""" + if self._job_ids is None: + return self._connection._get_job_ids(self._submission_id) + return self._job_ids + def get_status(self) -> SubmissionStatus: """Gets the status of the remote submission.""" return self._connection._get_submission_status(self._submission_id) @@ -79,7 +108,9 @@ def __getattr__(self, name: str) -> Any: status = self.get_status() if status == SubmissionStatus.DONE: self._results = tuple( - self._connection._fetch_result(self._submission_id) + self._connection._fetch_result( + self._submission_id, self._job_ids + ) ) return self._results raise RemoteResultsError( @@ -102,7 +133,9 @@ def submit( pass @abstractmethod - def _fetch_result(self, submission_id: str) -> typing.Sequence[Result]: + def _fetch_result( + self, submission_id: str, job_ids: list[str] | None + ) -> typing.Sequence[Result]: """Fetches the results of a completed submission.""" pass @@ -115,9 +148,15 @@ def _get_submission_status(self, submission_id: str) -> SubmissionStatus: """ pass + def _get_job_ids(self, submission_id: str) -> list[str]: + """Gets all the job IDs within a submission.""" + raise NotImplementedError( + "Unable to find job IDs through this remote connection." + ) + def fetch_available_devices(self) -> dict[str, Device]: """Fetches the devices available through this connection.""" - raise NotImplementedError( # pragma: no cover + raise NotImplementedError( "Unable to fetch the available devices through this " "remote connection." ) diff --git a/pulser-pasqal/pulser_pasqal/pasqal_cloud.py b/pulser-pasqal/pulser_pasqal/pasqal_cloud.py index b4c96397..e0faa009 100644 --- a/pulser-pasqal/pulser_pasqal/pasqal_cloud.py +++ b/pulser-pasqal/pulser_pasqal/pasqal_cloud.py @@ -179,7 +179,9 @@ def fetch_available_devices(self) -> dict[str, Device]: for name, dev_str in abstract_devices.items() } - def _fetch_result(self, submission_id: str) -> tuple[Result, ...]: + def _fetch_result( + self, submission_id: str, job_ids: list[str] | None + ) -> tuple[Result, ...]: # For now, the results are always sampled results get_batch_fn = backoff_decorator(self._sdk_connection.get_batch) batch = get_batch_fn(id=submission_id) @@ -189,7 +191,16 @@ def _fetch_result(self, submission_id: str) -> tuple[Result, ...]: meas_basis = seq_builder.get_measurement_basis() results = [] - for job in batch.ordered_jobs: + sdk_jobs = batch.ordered_jobs + if job_ids is not None: + ind_job_pairs = [ + (job_ids.index(job.id), job) + for job in sdk_jobs + if job.id in job_ids + ] + ind_job_pairs.sort() + sdk_jobs = [job for _, job in ind_job_pairs] + for job in sdk_jobs: vars = job.variables size: int | None = None if vars and "qubits" in vars: @@ -210,6 +221,12 @@ def _get_submission_status(self, submission_id: str) -> SubmissionStatus: batch = self._sdk_connection.get_batch(id=submission_id) return SubmissionStatus[batch.status] + @backoff_decorator + def _get_job_ids(self, submission_id: str) -> list[str]: + """Gets all the job IDs within a submission.""" + batch = self._sdk_connection.get_batch(id=submission_id) + return [job.id for job in batch.ordered_jobs] + def _convert_configuration( self, config: EmulatorConfig | None, diff --git a/tests/test_backend.py b/tests/test_backend.py index 4ebacb16..98320758 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -93,7 +93,9 @@ def __init__(self): def submit(self, sequence, wait: bool = False, **kwargs) -> RemoteResults: return RemoteResults("abcd", self) - def _fetch_result(self, submission_id: str) -> typing.Sequence[Result]: + def _fetch_result( + self, submission_id: str, job_ids: list[str] | None = None + ) -> typing.Sequence[Result]: return ( SampledResult( ("q0", "q1"), @@ -109,6 +111,18 @@ def _get_submission_status(self, submission_id: str) -> SubmissionStatus: return SubmissionStatus.DONE +def test_remote_connection(): + connection = _MockConnection() + + with pytest.raises(NotImplementedError, match="Unable to find job IDs"): + connection._get_job_ids("abc") + + with pytest.raises( + NotImplementedError, match="Unable to fetch the available devices" + ): + connection.fetch_available_devices() + + def test_qpu_backend(sequence): connection = _MockConnection() diff --git a/tests/test_pasqal.py b/tests/test_pasqal.py index 0fc950e0..76106194 100644 --- a/tests/test_pasqal.py +++ b/tests/test_pasqal.py @@ -15,6 +15,7 @@ import copy import dataclasses +import re from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch @@ -70,18 +71,22 @@ def seq(): return Sequence(reg, test_device) -@pytest.fixture -def mock_job(): - @dataclasses.dataclass - class MockJob: - runs = 10 - variables = {"t": 100, "qubits": {"q0": 1, "q1": 2, "q2": 4, "q3": 3}} - result = {"00": 5, "11": 5} +class _MockJob: + def __init__( + self, + runs=10, + variables={"t": 100, "qubits": {"q0": 1, "q1": 2, "q2": 4, "q3": 3}}, + result={"00": 5, "11": 5}, + ) -> None: + self.runs = runs + self.variables = variables + self.result = result + self.id = str(np.random.randint(10000)) - def __post_init__(self) -> None: - self.id = str(np.random.randint(10000)) - return MockJob() +@pytest.fixture +def mock_job(): + return _MockJob() @pytest.fixture @@ -94,7 +99,11 @@ def mock_batch(mock_job, seq): class MockBatch: id = "abcd" status = "DONE" - ordered_jobs = [mock_job] + ordered_jobs = [ + mock_job, + _MockJob(result={"00": 10}), + _MockJob(result={"11": 10}), + ] sequence_builder = seq_.to_abstract_repr() return MockBatch() @@ -132,12 +141,64 @@ def fixt(mock_batch): mock_cloud_sdk_class.assert_not_called() +@pytest.mark.parametrize("with_job_id", [False, True]) +def test_remote_results(fixt, mock_batch, with_job_id): + with pytest.raises( + RuntimeError, match=re.escape("does not contain jobs ['badjobid']") + ): + RemoteResults(mock_batch.id, fixt.pasqal_cloud, job_ids=["badjobid"]) + fixt.mock_cloud_sdk.get_batch.reset_mock() + + select_jobs = ( + mock_batch.ordered_jobs[::-1][:2] + if with_job_id + else mock_batch.ordered_jobs + ) + select_job_ids = [j.id for j in select_jobs] + + remote_results = RemoteResults( + mock_batch.id, + fixt.pasqal_cloud, + job_ids=select_job_ids if with_job_id else None, + ) + + assert remote_results.batch_id == mock_batch.id + assert remote_results.job_ids == select_job_ids + fixt.mock_cloud_sdk.get_batch.assert_called_once_with( + id=remote_results.batch_id + ) + fixt.mock_cloud_sdk.get_batch.reset_mock() + + assert remote_results.get_status() == SubmissionStatus.DONE + fixt.mock_cloud_sdk.get_batch.assert_called_once_with( + id=remote_results.batch_id + ) + + fixt.mock_cloud_sdk.get_batch.reset_mock() + results = remote_results.results + fixt.mock_cloud_sdk.get_batch.assert_called_with( + id=remote_results.batch_id + ) + assert results == tuple( + SampledResult( + atom_order=("q0", "q1", "q2", "q3"), + meas_basis="ground-rydberg", + bitstring_counts=job.result, + ) + for job in select_jobs + ) + + assert hasattr(remote_results, "_results") + + @pytest.mark.parametrize("mimic_qpu", [False, True]) @pytest.mark.parametrize( "emulator", [None, EmulatorType.EMU_TN, EmulatorType.EMU_FREE] ) @pytest.mark.parametrize("parametrized", [True, False]) -def test_submit(fixt, parametrized, emulator, mimic_qpu, seq, mock_job): +def test_submit( + fixt, parametrized, emulator, mimic_qpu, seq, mock_batch, mock_job +): with pytest.raises( ValueError, match="The measurement basis can't be implicitly determined for a " @@ -240,6 +301,8 @@ def test_submit(fixt, parametrized, emulator, mimic_qpu, seq, mock_job): config=config, mimic_qpu=mimic_qpu, ) + assert remote_results.batch_id == mock_batch.id + assert not seq.is_measured() seq.measure(basis="ground-rydberg") @@ -266,24 +329,6 @@ def test_submit(fixt, parametrized, emulator, mimic_qpu, seq, mock_job): ) assert isinstance(remote_results, RemoteResults) - assert remote_results.get_status() == SubmissionStatus.DONE - fixt.mock_cloud_sdk.get_batch.assert_called_once_with( - id=remote_results._submission_id - ) - - fixt.mock_cloud_sdk.get_batch.reset_mock() - results = remote_results.results - fixt.mock_cloud_sdk.get_batch.assert_called_with( - id=remote_results._submission_id - ) - assert results == ( - SampledResult( - atom_order=("q0", "q1", "q2", "q3"), - meas_basis="ground-rydberg", - bitstring_counts=mock_job.result, - ), - ) - assert hasattr(remote_results, "_results") @pytest.mark.parametrize("emu_cls", [EmuTNBackend, EmuFreeBackend])