Skip to content

Commit

Permalink
Make total_bottom_detuning mandatory (#728)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-corni authored Sep 13, 2024
1 parent 393526f commit 8550104
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 122 deletions.
16 changes: 2 additions & 14 deletions pulser-core/pulser/channels/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Defines the detuning map modulator."""
from __future__ import annotations

import warnings
from dataclasses import dataclass, field, fields
from typing import Any, Literal, Optional

Expand Down Expand Up @@ -91,23 +90,12 @@ def basis(self) -> Literal["ground-rydberg"]:
return "ground-rydberg"

def _undefined_fields(self) -> list[str]:
optional = [
"bottom_detuning",
"max_duration",
# TODO: "total_bottom_detuning"
]
optional = ["bottom_detuning", "max_duration", "total_bottom_detuning"]
return [field for field in optional if getattr(self, field) is None]

def is_virtual(self) -> bool:
"""Whether the channel is virtual (i.e. partially defined)."""
virtual_dmm = bool(self._undefined_fields())
if not virtual_dmm and self.total_bottom_detuning is None:
warnings.warn(
"From v0.18 and onwards, `total_bottom_detuning` must be"
" defined to define a physical DMM.",
DeprecationWarning,
)
return virtual_dmm
return bool(self._undefined_fields())

def validate_pulse(
self,
Expand Down
97 changes: 46 additions & 51 deletions pulser-core/pulser/devices/_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Examples of realistic devices."""
import dataclasses
import warnings

import numpy as np

Expand All @@ -22,55 +21,53 @@
from pulser.devices._device_datacls import Device
from pulser.register.special_layouts import TriangularLatticeLayout

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
DigitalAnalogDevice = Device(
name="DigitalAnalogDevice",
dimensions=2,
rydberg_level=70,
max_atom_num=100,
max_radial_distance=50,
min_atom_distance=4,
supports_slm_mask=True,
channel_objects=(
Rydberg.Global(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 2.5,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
Rydberg.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
Raman.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
DigitalAnalogDevice = Device(
name="DigitalAnalogDevice",
dimensions=2,
rydberg_level=70,
max_atom_num=100,
max_radial_distance=50,
min_atom_distance=4,
supports_slm_mask=True,
channel_objects=(
Rydberg.Global(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 2.5,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
dmm_objects=(
DMM(
clock_period=4,
min_duration=16,
max_duration=2**26,
bottom_detuning=-2 * np.pi * 20,
# TODO: total_bottom_detuning=-2 * np.pi * 2000
),
Rydberg.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
Raman.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
)
),
dmm_objects=(
DMM(
clock_period=4,
min_duration=16,
max_duration=2**26,
bottom_detuning=-2 * np.pi * 20,
total_bottom_detuning=-2 * np.pi * 2000,
),
),
)

AnalogDevice = Device(
name="AnalogDevice",
Expand Down Expand Up @@ -105,9 +102,7 @@

# Legacy devices (deprecated, should not be used in new sequences)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
Chadoq2 = dataclasses.replace(DigitalAnalogDevice, name="Chadoq2")
Chadoq2 = dataclasses.replace(DigitalAnalogDevice, name="Chadoq2")

IroiseMVP = Device(
name="IroiseMVP",
Expand Down
52 changes: 2 additions & 50 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,7 @@ def _roundtrip(abstract_device):
device = deserialize_device(json.dumps(abstract_device))
assert json.loads(device.to_abstract_repr()) == abstract_device

if abstract_device["name"] == "DigitalAnalogDevice":
with pytest.warns(
DeprecationWarning, match="From v0.18 and onwards"
):
_roundtrip(abstract_device)
else:
_roundtrip(abstract_device)
_roundtrip(abstract_device)

def test_exceptions(self, abstract_device):
def check_error_raised(
Expand All @@ -238,13 +232,7 @@ def check_error_raised(
assert re.search(re.escape(err_msg), str(cause)) is not None
return cause

if abstract_device["name"] == "DigitalAnalogDevice":
with pytest.warns(
DeprecationWarning, match="From v0.18 and onwards"
):
good_device = deserialize_device(json.dumps(abstract_device))
else:
good_device = deserialize_device(json.dumps(abstract_device))
good_device = deserialize_device(json.dumps(abstract_device))

check_error_raised(
abstract_device, TypeError, "'obj_str' must be a string"
Expand Down Expand Up @@ -1312,9 +1300,6 @@ def _get_expression(op: dict) -> Any:

class TestDeserialization:
@pytest.mark.parametrize("is_phys_Chadoq2", [True, False])
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_device_and_channels(self, is_phys_Chadoq2) -> None:
kwargs = {}
if is_phys_Chadoq2:
Expand All @@ -1336,9 +1321,6 @@ def test_deserialize_device_and_channels(self, is_phys_Chadoq2) -> None:
_coords = np.concatenate((_coords, -_coords))

@pytest.mark.parametrize("layout_coords", [None, _coords])
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_register(self, layout_coords):
if layout_coords is not None:
reg_layout = RegisterLayout(layout_coords)
Expand Down Expand Up @@ -1413,9 +1395,6 @@ def test_deserialize_register3D(self, layout_coords):
assert "layout" not in s
assert seq.register.layout is None

@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_mappable_register(self):
layout_coords = (5 * np.arange(8)).reshape((4, 2))
s = _get_serialized_seq(
Expand Down Expand Up @@ -1537,9 +1516,6 @@ def test_deserialize_seq_with_mag_field(self):
assert np.all(seq.magnetic_field == mag_field)

@pytest.mark.parametrize("without_default", [True, False])
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_variables(self, without_default):
s = _get_serialized_seq(
variables={
Expand Down Expand Up @@ -1677,9 +1653,6 @@ def test_deserialize_non_parametrized_op(self, op):
],
ids=_get_kind,
)
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_non_parametrized_waveform(self, wf_obj):
s = _get_serialized_seq(
operations=[
Expand Down Expand Up @@ -1759,9 +1732,6 @@ def test_deserialize_non_parametrized_waveform(self, wf_obj):
assert isinstance(wf, CustomWaveform)
assert np.array_equal(wf._samples, wf_obj["samples"])

@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_measurement(self):
s = _get_serialized_seq()
_check_roundtrip(s)
Expand Down Expand Up @@ -1849,9 +1819,6 @@ def test_deserialize_measurement(self):
],
ids=_get_op,
)
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_parametrized_op(self, op):
s = _get_serialized_seq(
operations=[op],
Expand Down Expand Up @@ -2001,9 +1968,6 @@ def test_deserialize_parametrized_op(self, op):
),
],
)
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_parametrized_pulse(self, op, pulse_cls):
s = _get_serialized_seq(
operations=[op],
Expand Down Expand Up @@ -2234,9 +2198,6 @@ def test_deserialize_eom_ops(self, correct_phase_drift, var_detuning_on):
],
ids=_get_kind,
)
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_parametrized_waveform(self, wf_obj):
# var1,2 = duration 1000, 2000
# var2,4 = value - 2, 5
Expand Down Expand Up @@ -2357,9 +2318,6 @@ def test_deserialize_parametrized_waveform(self, wf_obj):
],
ids=_get_expression,
)
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_deserialize_param(self, json_param):
s = _get_serialized_seq(
operations=[
Expand Down Expand Up @@ -2478,9 +2436,6 @@ def test_deserialize_param(self, json_param):
],
ids=["bad_var", "bad_param", "bad_exp"],
)
@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_param_exceptions(self, param, msg, patch_jsonschema):
s = _get_serialized_seq(
[
Expand All @@ -2503,9 +2458,6 @@ def test_param_exceptions(self, param, msg, patch_jsonschema):
with pytest.raises(std_error, **extra_params):
Sequence.from_abstract_repr(json.dumps(s))

@pytest.mark.filterwarnings(
"ignore:From v0.18 and onwards,.*:DeprecationWarning"
)
def test_unknow_waveform(self):
s = _get_serialized_seq(
[
Expand Down
17 changes: 10 additions & 7 deletions tests/test_dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,20 @@ def test_init(self, physical_dmm):
DMM.Local(None, None, bottom_detuning=1)

def test_validate_pulse(self, physical_dmm):
# both local and total bottom detuning must be defined to have a
# physical DMM
assert (virtual_local_dmm := DMM(bottom_detuning=-1)).is_virtual()
assert (virtual_dmm := DMM(total_bottom_detuning=-10)).is_virtual()
assert not physical_dmm.is_virtual()

# Detuning applied to DMM must be negative
pos_det_pulse = Pulse.ConstantPulse(100, 0, 1e-3, 0)
with pytest.raises(
ValueError, match="The detuning in a DMM must not be positive."
):
physical_dmm.validate_pulse(pos_det_pulse)

# Local detuning is given by Pulse.detuning * local_weight
too_low_pulse = Pulse.ConstantPulse(
100, 0, physical_dmm.bottom_detuning - 0.01, 0
)
Expand All @@ -311,8 +319,6 @@ def test_validate_pulse(self, physical_dmm):
physical_dmm.validate_pulse(too_low_pulse)

# Should be valid in a virtual DMM without local bottom detuning
virtual_dmm = DMM(total_bottom_detuning=-10)
assert virtual_dmm.is_virtual()
virtual_dmm.validate_pulse(too_low_pulse)

# Not too low if weights of detuning map are lower than 1
Expand All @@ -329,8 +335,5 @@ def test_validate_pulse(self, physical_dmm):
# local detunings match bottom_detuning, global don't
physical_dmm.validate_pulse(too_low_pulse, det_map)

# Should be valid in a physical DMM without global bottom detuning
physical_dmm = DMM(bottom_detuning=-1)
with pytest.warns(DeprecationWarning, match="From v0.18 and onwards"):
assert not physical_dmm.is_virtual()
physical_dmm.validate_pulse(too_low_pulse, det_map)
# Should be valid in a virtual DMM without total bottom detuning
virtual_local_dmm.validate_pulse(too_low_pulse, det_map)

0 comments on commit 8550104

Please sign in to comment.