Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add from_abstract_repr to Device and VirtualDevice #727

Merged
merged 6 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions pulser-core/pulser/devices/_device_datacls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np
from scipy.spatial.distance import squareform

import pulser.json.abstract_repr as pulser_abstract_repr
import pulser.math as pm
from pulser.channels.base_channel import Channel, States, get_states_from_bases
from pulser.channels.dmm import DMM
Expand Down Expand Up @@ -726,6 +727,33 @@ def _to_abstract_repr(self) -> dict[str, Any]:
d["is_virtual"] = False
return d

@staticmethod
def from_abstract_repr(obj_str: str) -> Device:
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
"""Deserialize a Device from an abstract JSON object.

Warning:
Raises an error if the JSON string represents a VirtualDevice.
VirtualDevice.from_abstract_repr should be used for this case.

Args:
obj_str (str): the JSON string representing the Device
encoded in the abstract JSON format.
"""
if not isinstance(obj_str, str):
raise TypeError(
"The serialized Device must be given as a string. "
f"Instead, got object of type {type(obj_str)}."
)

# Avoids circular imports
device = pulser_abstract_repr.deserializer.deserialize_device(obj_str)
if not isinstance(device, Device):
raise TypeError(
"The given schema is not related to a Device, but to a"
f" {type(device).__name__}."
)
return device


@dataclass(frozen=True)
class VirtualDevice(BaseDevice):
Expand Down Expand Up @@ -807,3 +835,27 @@ def _to_abstract_repr(self) -> dict[str, Any]:
d = super()._to_abstract_repr()
d["is_virtual"] = True
return d

@staticmethod
def from_abstract_repr(obj_str: str) -> VirtualDevice:
"""Deserialize a VirtualDevice from an abstract JSON object.

Warning:
If the JSON string represents a Device, the Device is converted
into a VirtualDevice using the `Device.to_virtual` method.

Args:
obj_str (str): the JSON string representing the noise model
encoded in the abstract JSON format.
"""
if not isinstance(obj_str, str):
raise TypeError(
"The serialized VirtualDevice must be given as a string. "
f"Instead, got object of type {type(obj_str)}."
)

# Avoids circular imports
device = pulser_abstract_repr.deserializer.deserialize_device(obj_str)
if isinstance(device, Device):
return device.to_virtual()
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
return device
149 changes: 146 additions & 3 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
DigitalAnalogDevice,
IroiseMVP,
MockDevice,
VirtualDevice,
)
from pulser.json.abstract_repr.deserializer import (
VARIABLE_TYPE_MAP,
Expand Down Expand Up @@ -222,21 +223,45 @@ def _roundtrip(abstract_device):

def test_exceptions(self, abstract_device):
def check_error_raised(
obj_str: str, original_err: Type[Exception], err_msg: str = ""
obj_str: str,
original_err: Type[Exception],
err_msg: str = "",
func: Callable = deserialize_device,
) -> Exception:
with pytest.raises(DeserializeDeviceError) as exc_info:
deserialize_device(obj_str)
func(obj_str)

cause = exc_info.value.__cause__
assert isinstance(cause, original_err)
assert re.search(re.escape(err_msg), str(cause)) is not None
return cause

dev_str = json.dumps(abstract_device)
good_device = deserialize_device(json.dumps(abstract_device))

deser_device = type(good_device).from_abstract_repr(dev_str)
assert good_device == deser_device
if isinstance(good_device, Device):
deser_device = VirtualDevice.from_abstract_repr(dev_str)
assert good_device.to_virtual() == deser_device
else:
with pytest.raises(
TypeError,
match="The given schema is not related to a Device, but to "
"a VirtualDevice.",
):
Device.from_abstract_repr(dev_str)
check_error_raised(
abstract_device, TypeError, "'obj_str' must be a string"
)
with pytest.raises(
TypeError, match="The serialized Device must be given as a string."
):
Device.from_abstract_repr(abstract_device)
with pytest.raises(
TypeError,
match="The serialized VirtualDevice must be given as a string.",
):
VirtualDevice.from_abstract_repr(abstract_device)

# JSONDecodeError from json.loads()
bad_str = "\ufeff"
Expand All @@ -246,6 +271,15 @@ def check_error_raised(
json.loads(bad_str)
err_msg = str(err.value)
check_error_raised(bad_str, json.JSONDecodeError, err_msg)
check_error_raised(
bad_str, json.JSONDecodeError, err_msg, Device.from_abstract_repr
)
check_error_raised(
bad_str,
json.JSONDecodeError,
err_msg,
VirtualDevice.from_abstract_repr,
)

# jsonschema.exceptions.ValidationError from jsonschema
invalid_dev = abstract_device.copy()
Expand All @@ -257,6 +291,18 @@ def check_error_raised(
jsonschema.exceptions.ValidationError,
str(err.value),
)
check_error_raised(
json.dumps(invalid_dev),
jsonschema.exceptions.ValidationError,
str(err.value),
Device.from_abstract_repr,
)
check_error_raised(
json.dumps(invalid_dev),
jsonschema.exceptions.ValidationError,
str(err.value),
VirtualDevice.from_abstract_repr,
)

# AbstractReprError from invalid RydbergEOM configuration
if good_device.channels["rydberg_global"].eom_config:
Expand All @@ -266,6 +312,20 @@ def check_error_raised(
assert "max_limiting_amp" in ch_dict["eom_config"]
ch_dict["eom_config"]["max_limiting_amp"] = 0.0
break
prev_err = check_error_raised(
json.dumps(bad_eom_dev),
AbstractReprError,
"RydbergEOM deserialization failed.",
Device.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, ValueError)
prev_err = check_error_raised(
json.dumps(bad_eom_dev),
AbstractReprError,
"RydbergEOM deserialization failed.",
VirtualDevice.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, ValueError)
prev_err = check_error_raised(
json.dumps(bad_eom_dev),
AbstractReprError,
Expand All @@ -276,6 +336,20 @@ def check_error_raised(
# AbstractReprError from ValueError in channel creation
bad_ch_dev1 = deepcopy(abstract_device)
bad_ch_dev1["channels"][0]["min_duration"] = -1
prev_err = check_error_raised(
json.dumps(bad_ch_dev1),
AbstractReprError,
"Channel deserialization failed.",
Device.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, ValueError)
prev_err = check_error_raised(
json.dumps(bad_ch_dev1),
AbstractReprError,
"Channel deserialization failed.",
VirtualDevice.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, ValueError)
prev_err = check_error_raised(
json.dumps(bad_ch_dev1),
AbstractReprError,
Expand All @@ -286,6 +360,20 @@ def check_error_raised(
# AbstractReprError from NotImplementedError in channel creation
bad_ch_dev2 = deepcopy(abstract_device)
bad_ch_dev2["channels"][0]["mod_bandwidth"] = 1000
prev_err = check_error_raised(
json.dumps(bad_ch_dev2),
AbstractReprError,
"Channel deserialization failed.",
Device.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, NotImplementedError)
prev_err = check_error_raised(
json.dumps(bad_ch_dev2),
AbstractReprError,
"Channel deserialization failed.",
VirtualDevice.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, NotImplementedError)
prev_err = check_error_raised(
json.dumps(bad_ch_dev2),
AbstractReprError,
Expand All @@ -299,6 +387,20 @@ def check_error_raised(
# Identical coords fail
bad_layout_obj = {"coordinates": [[0, 0], [0.0, 0.0]]}
bad_layout_dev["pre_calibrated_layouts"] = [bad_layout_obj]
prev_err = check_error_raised(
json.dumps(bad_layout_dev),
AbstractReprError,
"Register layout deserialization failed.",
Device.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, ValueError)
prev_err = check_error_raised(
json.dumps(bad_layout_dev),
AbstractReprError,
"Register layout deserialization failed.",
VirtualDevice.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, ValueError)
prev_err = check_error_raised(
json.dumps(bad_layout_dev),
AbstractReprError,
Expand All @@ -310,6 +412,20 @@ def check_error_raised(
if "XY" in good_device.supported_bases:
bad_xy_coeff_dev = abstract_device.copy()
bad_xy_coeff_dev["interaction_coeff_xy"] = None
prev_err = check_error_raised(
json.dumps(bad_xy_coeff_dev),
AbstractReprError,
"Device deserialization failed.",
Device.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, TypeError)
prev_err = check_error_raised(
json.dumps(bad_xy_coeff_dev),
AbstractReprError,
"Device deserialization failed.",
VirtualDevice.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, TypeError)
prev_err = check_error_raised(
json.dumps(bad_xy_coeff_dev),
AbstractReprError,
Expand All @@ -320,6 +436,20 @@ def check_error_raised(
# AbstractReprError from ValueError in device init
bad_dev = abstract_device.copy()
bad_dev["min_atom_distance"] = -1
prev_err = check_error_raised(
json.dumps(bad_dev),
AbstractReprError,
"Device deserialization failed.",
Device.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, ValueError)
prev_err = check_error_raised(
json.dumps(bad_dev),
AbstractReprError,
"Device deserialization failed.",
VirtualDevice.from_abstract_repr,
)
assert isinstance(prev_err.__cause__, ValueError)
prev_err = check_error_raised(
json.dumps(bad_dev),
AbstractReprError,
Expand All @@ -341,6 +471,18 @@ def test_optional_device_fields(self, og_device, field, value):
device = replace(og_device, **{field: value})
dev_str = device.to_abstract_repr()
assert device == deserialize_device(dev_str)
assert device == type(og_device).from_abstract_repr(dev_str)
if isinstance(og_device, Device):
assert device.to_virtual() == VirtualDevice.from_abstract_repr(
dev_str
)
return
with pytest.raises(
TypeError,
match="The given schema is not related to a Device, but to a "
"VirtualDevice.",
):
Device.from_abstract_repr(dev_str)

@pytest.mark.parametrize(
"ch_obj",
Expand Down Expand Up @@ -406,6 +548,7 @@ def test_optional_channel_fields(self, ch_obj):
)
dev_str = device.to_abstract_repr()
assert device == deserialize_device(dev_str)
assert device == VirtualDevice.from_abstract_repr(dev_str)


def validate_schema(instance):
Expand Down