Skip to content

Commit

Permalink
Add open batches to pulser-pasqal (#701)
Browse files Browse the repository at this point in the history
* rework sdk to not require batch_id as an argument

* rework sdk to not require batch_id as an argument

* rework sdk to not require batch_id as an argument

* rework sdk to not require batch_id as an argument

* rework sdk to not require batch_id as an argument

* change to context manager interface for open batches

* fix rebase

fix rebase, and linting

fix rebase, and linting

* fix rebase, and linting

fix rebase, and linting

fix rebase, and linting

fix rebase, and linting

fix rebase, and linting

* fix type

fix type

* complete test coverage for method calls

complete test coverage for method calls

* context management class, update tests

context management class, update tests

* inside return is ignored with _

* mr feedback

* boolean condition for open batch support

boolean condition for open batch support

boolean condition for open batch support

boolean condition for open batch support

* test coverage

* flake8

* MR feedback

MR feedback

* comment on arg name

* support complete -> open keyword change for batches

* support complete -> open keyword change for batches

* lint

lint

* Bump pasqal-cloud to v0.12

* Include only the new jobs in the RemoteResults of each call to submit()

* Give stored batch ID to get available results

* Submission -> Batch outside of RemoteResults

* Including backend specific kwargs to RemoteConnection.submit() when opening batch

* Fully deprecate 'submission' for 'batch'

* Relax `pasqal-cloud` requirement

* Consistency updates to the tutorial

---------

Co-authored-by: oliver.gordon <oliver.gordon@pasqal.com>
Co-authored-by: HGSilveri <henrique.silverio@tecnico.ulisboa.pt>
Co-authored-by: Henrique Silvério <29920212+HGSilveri@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 20, 2024
1 parent 7af1d2d commit c12306a
Show file tree
Hide file tree
Showing 9 changed files with 402 additions and 82 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ dist/
env*
*.egg-info/
__venv__/
venv
5 changes: 1 addition & 4 deletions pulser-core/pulser/backend/qpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ def run(
self.validate_job_params(
job_params or [], self._sequence.device.max_runs
)
results = self._connection.submit(
self._sequence, job_params=job_params, wait=wait
)
return cast(RemoteResults, results)
return cast(RemoteResults, super().run(job_params, wait))

@staticmethod
def validate_job_params(
Expand Down
221 changes: 188 additions & 33 deletions pulser-core/pulser/backend/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base classes for remote backend execution."""

from __future__ import annotations

import typing
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum, auto
from typing import Any, Mapping, TypedDict
from functools import wraps
from types import TracebackType
from typing import Any, Mapping, Type, TypedDict, TypeVar, cast

from pulser.backend.abc import Backend
from pulser.devices import Device
Expand All @@ -44,6 +49,21 @@ class SubmissionStatus(Enum):
PAUSED = auto()


class BatchStatus(Enum):
    """Status of a batch.

    Same as SubmissionStatus, needed because we renamed Submission -> Batch.
    """

    # NOTE: `RemoteResults.get_status()` converts a BatchStatus into a
    # SubmissionStatus by member name, so every name defined here must also
    # exist in SubmissionStatus.
    PENDING = auto()
    RUNNING = auto()
    DONE = auto()
    CANCELED = auto()
    TIMED_OUT = auto()
    ERROR = auto()
    PAUSED = auto()


class JobStatus(Enum):
"""Status of a remote job."""

Expand All @@ -61,34 +81,63 @@ class RemoteResultsError(Exception):
pass


F = TypeVar("F", bound=Callable)


def _deprecate_submission_id(func: F) -> F:
    """Lets a method accept the deprecated 'submission_id' keyword.

    The decorated callable must take 'batch_id' as its first argument after
    'self'. When the wrapped callable is invoked with 'submission_id' as a
    keyword argument, a DeprecationWarning is emitted and the value is
    forwarded under the 'batch_id' keyword instead.

    Args:
        func: The method to wrap.

    Returns:
        The wrapped method, typed like the original.

    Raises:
        ValueError: (From the wrapper) if 'submission_id' is given together
            with 'batch_id', either as a keyword or positionally.
    """

    @wraps(func)
    def wrapper(self: RemoteResults, *args: Any, **kwargs: Any) -> Any:
        if "submission_id" in kwargs:
            # 'batch_id' is the first positional arg so if len(args) > 0,
            # then it is being given
            if "batch_id" in kwargs or args:
                raise ValueError(
                    "'submission_id' and 'batch_id' cannot be simultaneously"
                    " specified. Please provide only the 'batch_id'."
                )
            warnings.warn(
                "'submission_id' has been deprecated and replaced by "
                "'batch_id'.",
                category=DeprecationWarning,
                # The warning is raised from this wrapper, so the frame
                # directly above it (level 2) is the user's call site.
                # A higher level skips past the caller and, for top-level
                # scripts, runs off the stack, hiding the warning under the
                # default filters.
                stacklevel=2,
            )
            kwargs["batch_id"] = kwargs.pop("submission_id")
        return func(self, *args, **kwargs)

    return cast(F, wrapper)


class RemoteResults(Results):
"""A collection of results obtained through a remote connection.
Warns:
DeprecationWarning: If 'submission_id' is given instead of 'batch_id'.
Args:
submission_id: The ID that identifies the submission linked to
the results.
connection: The remote connection over which to get the submission's
batch_id: The ID that identifies the batch linked to the results.
connection: The remote connection over which to get the batch's
status and fetch the results.
job_ids: If given, specifies which jobs within the submission should
job_ids: If given, specifies which jobs within the batch should
be included in the results and in what order. If left undefined,
all jobs are included.
"""

@_deprecate_submission_id
def __init__(
self,
submission_id: str,
batch_id: str,
connection: RemoteConnection,
job_ids: list[str] | None = None,
):
"""Instantiates a new collection of remote results."""
self._submission_id = submission_id
self._batch_id = batch_id
self._connection = connection
if job_ids is not None and not set(job_ids).issubset(
all_job_ids := self._connection._get_job_ids(self._submission_id)
all_job_ids := self._connection._get_job_ids(self._batch_id)
):
unknown_ids = [id_ for id_ in job_ids if id_ not in all_job_ids]
raise RuntimeError(
f"Submission {self._submission_id!r} does not contain jobs "
f"Batch {self._batch_id!r} does not contain jobs "
f"{unknown_ids}."
)
self._job_ids = job_ids
Expand All @@ -98,27 +147,53 @@ def results(self) -> tuple[Result, ...]:
"""The actual results, obtained after execution is done."""
return self._results

@property
def _submission_id(self) -> str:
"""The same as the batch ID, kept for backwards compatibility."""
warnings.warn(
"'RemoteResults._submission_id' has been deprecated, please use"
"'RemoteResults.batch_id' instead.",
category=DeprecationWarning,
stacklevel=2,
)
return self._batch_id

@property
def batch_id(self) -> str:
"""The ID of the batch containing these results."""
return self._submission_id
return self._batch_id

@property
def job_ids(self) -> list[str]:
"""The IDs of the jobs within this results submission."""
"""The IDs of the jobs within these results' batch."""
if self._job_ids is None:
return self._connection._get_job_ids(self._submission_id)
return self._connection._get_job_ids(self._batch_id)
return self._job_ids

def get_status(self) -> SubmissionStatus:
"""Gets the status of the remote submission."""
return self._connection._get_submission_status(self._submission_id)
"""Gets the status of the remote submission.
Warning:
This method has been deprecated, please use
`RemoteResults.get_batch_status()` instead.
"""
warnings.warn(
"'RemoteResults.get_status()' has been deprecated, please use"
"'RemoteResults.get_batch_status()' instead.",
category=DeprecationWarning,
stacklevel=2,
)
return SubmissionStatus[self.get_batch_status().name]

def get_available_results(self, submission_id: str) -> dict[str, Result]:
"""Returns the available results of a submission.
def get_batch_status(self) -> BatchStatus:
"""Gets the status of the batch linked to these results."""
return self._connection._get_batch_status(self._batch_id)

def get_available_results(self) -> dict[str, Result]:
"""Returns the available results.
Unlike the `results` property, this method does not raise an error if
some jobs associated to the submission do not have results.
some of the jobs do not have results.
Returns:
dict[str, Result]: A dictionary mapping the job ID to its results.
Expand All @@ -127,7 +202,7 @@ def get_available_results(self, submission_id: str) -> dict[str, Result]:
results = {
k: v[1]
for k, v in self._connection._query_job_progress(
submission_id
self.batch_id
).items()
if v[1] is not None
}
Expand All @@ -141,7 +216,7 @@ def __getattr__(self, name: str) -> Any:
try:
self._results = tuple(
self._connection._fetch_result(
self._submission_id, self._job_ids
self.batch_id, self._job_ids
)
)
return self._results
Expand All @@ -161,42 +236,43 @@ class RemoteConnection(ABC):

@abstractmethod
def submit(
self, sequence: Sequence, wait: bool = False, **kwargs: Any
self,
sequence: Sequence,
wait: bool = False,
open: bool = True,
batch_id: str | None = None,
**kwargs: Any,
) -> RemoteResults | tuple[RemoteResults, ...]:
"""Submit a job for execution."""
pass

@abstractmethod
def _fetch_result(
self, submission_id: str, job_ids: list[str] | None
self, batch_id: str, job_ids: list[str] | None
) -> typing.Sequence[Result]:
"""Fetches the results of a completed submission."""
"""Fetches the results of a completed batch."""
pass

@abstractmethod
def _query_job_progress(
self, submission_id: str
self, batch_id: str
) -> Mapping[str, tuple[JobStatus, Result | None]]:
"""Fetches the status and results of all the jobs in a submission.
"""Fetches the status and results of all the jobs in a batch.
Unlike `_fetch_result`, this method does not raise an error if some
jobs associated to the submission do not have results.
jobs in the batch do not have results.
It returns a dictionary mapping the job ID to its status and results.
"""
pass

@abstractmethod
def _get_submission_status(self, submission_id: str) -> SubmissionStatus:
"""Gets the status of a submission from its ID.
Not all SubmissionStatus values must be covered, but at least
SubmissionStatus.DONE is expected.
"""
def _get_batch_status(self, batch_id: str) -> BatchStatus:
"""Gets the status of a batch from its ID."""
pass

def _get_job_ids(self, submission_id: str) -> list[str]:
"""Gets all the job IDs within a submission."""
def _get_job_ids(self, batch_id: str) -> list[str]:
"""Gets all the job IDs within a batch."""
raise NotImplementedError(
"Unable to find job IDs through this remote connection."
)
Expand All @@ -208,6 +284,17 @@ def fetch_available_devices(self) -> dict[str, Device]:
"remote connection."
)

def _close_batch(self, batch_id: str) -> None:
"""Closes a batch using its ID."""
raise NotImplementedError( # pragma: no cover
"Unable to close batch through this remote connection"
)

@abstractmethod
def supports_open_batch(self) -> bool:
    """Flag to confirm this class can support creating an open batch.

    `RemoteBackend.open_batch()` calls this method and raises a
    NotImplementedError when it returns False.
    """
    pass


class RemoteBackend(Backend):
"""A backend for sequence execution through a remote connection.
Expand All @@ -234,6 +321,39 @@ def __init__(
"'connection' must be a valid RemoteConnection instance."
)
self._connection = connection
self._batch_id: str | None = None

def run(
self, job_params: list[JobParams] | None = None, wait: bool = False
) -> RemoteResults | tuple[RemoteResults, ...]:
"""Runs the sequence on the remote backend and returns the result.
Args:
job_params: A list of parameters for each job to execute. Each
mapping must contain a defined 'runs' field specifying
the number of times to run the same sequence. If the sequence
is parametrized, the values for all the variables necessary
to build the sequence must be given in it's own mapping, for
each job, under the 'variables' field.
wait: Whether to wait until the results of the jobs become
available. If set to False, the call is non-blocking and the
obtained results' status can be checked using their `status`
property.
Returns:
The results, which can be accessed once all sequences have been
successfully executed.
"""
return self._connection.submit(
self._sequence,
job_params=job_params,
wait=wait,
**self._submit_kwargs(),
)

def _submit_kwargs(self) -> dict[str, Any]:
"""Keyword arguments given to any call to RemoteConnection.submit()."""
return dict(batch_id=self._batch_id)

@staticmethod
def _type_check_job_params(job_params: list[JobParams] | None) -> None:
Expand All @@ -247,3 +367,38 @@ def _type_check_job_params(job_params: list[JobParams] | None) -> None:
"All elements of 'job_params' must be dictionaries; "
f"got {type(d)} instead."
)

def open_batch(self) -> _OpenBatchContextManager:
"""Creates an open batch within a context manager object."""
if not self._connection.supports_open_batch():
raise NotImplementedError(
"Unable to execute open_batch using this remote connection"
)
return _OpenBatchContextManager(self)


class _OpenBatchContextManager:
    """Context manager that keeps a batch open on a remote backend.

    On entry, the backend's sequence is submitted with ``open=True`` and the
    ID of the resulting batch is stored in the backend's `_batch_id`. On
    exit, the batch is closed through the backend's connection and the
    stored ID is cleared.

    Args:
        backend: The remote backend on which the batch is opened.
    """

    def __init__(self, backend: RemoteBackend) -> None:
        """Stores the backend whose batch is to be managed."""
        self.backend = backend

    def __enter__(self) -> _OpenBatchContextManager:
        """Opens a batch and records its ID on the backend."""
        batch = cast(
            RemoteResults,
            self.backend._connection.submit(
                self.backend._sequence,
                open=True,
                **self.backend._submit_kwargs(),
            ),
        )
        self.backend._batch_id = batch.batch_id
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Closes the open batch (if any) and clears the stored batch ID.

        Returns None, so an exception raised inside the `with` block is
        never suppressed.
        """
        batch_id = self.backend._batch_id
        if not batch_id:
            return
        try:
            self.backend._connection._close_batch(batch_id)
        finally:
            # Clear the ID even when closing fails: once the context is
            # left, later submissions must not silently keep targeting
            # this batch.
            self.backend._batch_id = None
12 changes: 7 additions & 5 deletions pulser-pasqal/pulser_pasqal/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

from dataclasses import fields
from typing import ClassVar
from typing import Any, ClassVar

import pasqal_cloud

Expand Down Expand Up @@ -88,12 +88,14 @@ def run(
"All elements of 'job_params' must specify 'runs'" + suffix
)

return self._connection.submit(
self._sequence,
job_params=job_params,
return super().run(job_params, wait)

def _submit_kwargs(self) -> dict[str, Any]:
"""Keyword arguments given to any call to RemoteConnection.submit()."""
return dict(
batch_id=self._batch_id,
emulator=self.emulator,
config=self._config,
wait=wait,
mimic_qpu=self._mimic_qpu,
)

Expand Down
Loading

0 comments on commit c12306a

Please sign in to comment.