diff --git a/pulser-core/pulser/sequence/helpers/__init__.py b/pulser-core/pulser/sequence/helpers/__init__.py new file mode 100644 index 00000000..d456a930 --- /dev/null +++ b/pulser-core/pulser/sequence/helpers/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module containing helpers of the sequence class definition.""" diff --git a/pulser-core/pulser/sequence/_seq_str.py b/pulser-core/pulser/sequence/helpers/_seq_str.py similarity index 100% rename from pulser-core/pulser/sequence/_seq_str.py rename to pulser-core/pulser/sequence/helpers/_seq_str.py diff --git a/pulser-core/pulser/sequence/helpers/_switch_device.py b/pulser-core/pulser/sequence/helpers/_switch_device.py new file mode 100644 index 00000000..1e18f6bf --- /dev/null +++ b/pulser-core/pulser/sequence/helpers/_switch_device.py @@ -0,0 +1,386 @@ +# Copyright 2024 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Function to switch the Device in a Sequence.""" +from __future__ import annotations + +import dataclasses +import itertools +import warnings +from typing import TYPE_CHECKING, Any, cast + +import numpy as np + +from pulser.channels.base_channel import Channel +from pulser.channels.dmm import _get_dmm_name +from pulser.channels.eom import RydbergEOM +from pulser.devices._device_datacls import BaseDevice + +if TYPE_CHECKING: + from pulser.sequence.sequence import Sequence + + +def switch_device( + seq: Sequence, new_device: BaseDevice, strict: bool = False +) -> Sequence: + """Replicate the sequence with a different device. + + This method is designed to replicate the sequence with as few changes + to the original contents as possible. + If the `strict` option is chosen, the device switch will fail whenever + it cannot guarantee that the new sequence's contents will not be + modified in the process. + + Args: + seq: The Sequence whose device should be switched. + new_device: The target device instance. + strict: Enforce a strict match between devices and channels to + guarantee the pulse sequence is left unchanged. + + Returns: + The sequence on the new device, using the match channels of + the former device declared in the sequence. + """ + # Check if the device is new or not + + if seq.device == new_device: + warnings.warn( + "Switching a sequence to the same device" + + " returns the sequence unchanged.", + stacklevel=2, + ) + return seq + + if seq._in_xy: + interaction_param = "interaction_coeff_xy" + name_in_msg = "XY interaction coefficient" + else: + interaction_param = "rydberg_level" + name_in_msg = "Rydberg level" + + if getattr(new_device, interaction_param) != getattr( + seq._device, interaction_param + ): + if strict: + raise ValueError( + "Strict device match failed because the" + f" devices have different {name_in_msg}s." + ) + warnings.warn( + f"Switching to a device with a different {name_in_msg}," + " check that the expected interactions still hold.", + stacklevel=2, + ) + + def check_retarget(ch_obj: Channel) -> bool: + # Check the min_retarget_interval when it is is not + # fully covered by the fixed_retarget_t + return ch_obj.addressing == "Local" and cast( + int, ch_obj.fixed_retarget_t + ) < cast(int, ch_obj.min_retarget_interval) + + def check_channels_match( + old_ch_name: str, + new_ch_obj: Channel, + active_eom_channels: list, + strict: bool, + ) -> tuple[str, str]: + """Check whether two channels match. + + Returns a tuple that contains a non-strict error message and a + strict error message. If the channel matches, the two error + messages are empty strings. If strict=False, only non-strict + conditions are checked, and only the non-strict error message + will eventually be filled. If strict=True, all the conditions are + checked - the returned error can either be non-strict or strict. + """ + old_ch_obj = seq.declared_channels[old_ch_name] + # We verify the channel class then + # check whether the addressing is Global or Local + type_match = type(old_ch_obj) is type(new_ch_obj) + basis_match = old_ch_obj.basis == new_ch_obj.basis + addressing_match = old_ch_obj.addressing == new_ch_obj.addressing + if not (type_match and basis_match and addressing_match): + # If there already is a message, keeps it + return (" with the right type, basis and addressing.", "") + if old_ch_name in active_eom_channels: + # Uses EOM mode, so the new device needs a matching + # EOM configuration + if new_ch_obj.eom_config is None: + return (" with an EOM configuration.", "") + if strict: + if not seq.is_parametrized(): + if ( + new_ch_obj.eom_config.mod_bandwidth + != cast( + RydbergEOM, old_ch_obj.eom_config + ).mod_bandwidth + ): + return ( + "", + " with the same mod_bandwidth for the EOM.", + ) + else: + # Eom configs have to match is Sequence is parametrized + new_eom_config = dataclasses.asdict( + cast(RydbergEOM, new_ch_obj.eom_config) + ) + old_eom_config = dataclasses.asdict( + cast(RydbergEOM, old_ch_obj.eom_config) + ) + # However, multiple_beam_control only matters when + # the two beams are controlled + if len(old_eom_config["controlled_beams"]) == 1: + new_eom_config.pop("multiple_beam_control") + old_eom_config.pop("multiple_beam_control") + # Controlled beams only matter when only one beam + # is controlled by the new eom + if len(new_eom_config["controlled_beams"]) > 1: + new_eom_config.pop("controlled_beams") + old_eom_config.pop("controlled_beams") + # Controlled_beams doesn't matter if the two EOMs + # control two beams + elif set(new_eom_config["controlled_beams"]) == set( + old_eom_config["controlled_beams"] + ): + new_eom_config.pop("controlled_beams") + old_eom_config.pop("controlled_beams") + + # And custom_buffer_time doesn't have to match as long + # as `Channel_eom_buffer_time`` does + if ( + new_ch_obj._eom_buffer_time + == old_ch_obj._eom_buffer_time + ): + new_eom_config.pop("custom_buffer_time") + old_eom_config.pop("custom_buffer_time") + if new_eom_config != old_eom_config: + return ("", " with the same EOM configuration.") + if not strict: + return ("", "") + + params_to_check = [ + "mod_bandwidth", + "fixed_retarget_t", + "clock_period", + ] + if check_retarget(old_ch_obj) or check_retarget(new_ch_obj): + params_to_check.append("min_retarget_interval") + for param_ in params_to_check: + if getattr(new_ch_obj, param_) != getattr(old_ch_obj, param_): + return ("", f" with the same {param_}.") + else: + return ("", "") + + def is_good_match( + channel_match: dict[str, str], + reusable_channels: bool, + all_channels_new_device: dict[str, Channel], + active_eom_channels: list, + strict: bool, + ) -> bool: + used_channels_new_device = list(channel_match.values()) + if not reusable_channels and len(set(used_channels_new_device)) < len( + used_channels_new_device + ): + return False + for old_ch_name, new_ch_name in channel_match.items(): + if check_channels_match( + old_ch_name, + all_channels_new_device[new_ch_name], + active_eom_channels, + strict, + ) != ("", ""): + return False + return True + + def raise_error_non_matching_channel( + reusable_channels: bool, + all_channels_new_device: dict[str, Channel], + active_eom_channels: list, + strict: bool, + ) -> None: + strict_error_message = "" + ch_match_err = "" + channel_match: dict[str, Any] = {} + for old_ch_name, old_ch_obj in seq.declared_channels.items(): + channel_match[old_ch_name] = None + base_msg = f"No match for channel {old_ch_name}" + # Find the corresponding channel on the new device + for new_ch_id, new_ch_obj in all_channels_new_device.items(): + if ( + not reusable_channels + and new_ch_id in channel_match.values() + ): + # Channel already matched and can't be reused + continue + (ch_match_err_suffix, strict_error_message_suffix) = ( + check_channels_match( + old_ch_name, + new_ch_obj, + active_eom_channels, + strict, + ) + ) + if (ch_match_err_suffix, strict_error_message_suffix) == ( + "", + "", + ): + channel_match[old_ch_name] = new_ch_id + # Found a match, clear match error msg for this channel + if ch_match_err.startswith(base_msg): + ch_match_err = "" + if strict_error_message.startswith(base_msg): + strict_error_message = "" + break + elif ch_match_err_suffix != "": + ch_match_err = ( + ch_match_err or base_msg + ch_match_err_suffix + ) + else: + strict_error_message = ( + base_msg + strict_error_message_suffix + ) + assert None in channel_match.values() + if strict_error_message: + raise ValueError(strict_error_message) + raise TypeError(ch_match_err) + + def build_sequence_from_matching( + new_device: BaseDevice, + channel_match: dict[str, str], + active_eom_channels: list, + strict: bool, + ) -> Sequence: + # Initialize the new sequence (works for Sequence subclasses too) + new_seq = type(seq)(register=seq._register, device=new_device) + dmm_calls: list[str] = [] + # Copy the variables to the new sequence + new_seq._variables = seq.declared_variables + for call in seq._calls[1:] + seq._to_build_calls: + # Switch the old id with the correct id + sw_channel_args = list(call.args) + sw_channel_kw_args = call.kwargs.copy() + if not ( + call.name == "declare_channel" + or call.name == "config_detuning_map" + or call.name == "config_slm_mask" + or call.name == "add_dmm_detuning" + ): + pass + # if calling declare_channel + elif "name" in sw_channel_kw_args: # pragma: no cover + sw_channel_kw_args["channel_id"] = channel_match[ + sw_channel_kw_args["name"] + ] + elif "channel_id" in sw_channel_kw_args: # pragma: no cover + sw_channel_kw_args["channel_id"] = channel_match[ + sw_channel_args[0] + ] + elif call.name == "declare_channel": + sw_channel_args[1] = channel_match[sw_channel_args[0]] + # if adding a detuning waveform to the dmm + elif "dmm_name" in sw_channel_kw_args: # program: no cover + sw_channel_kw_args["dmm_name"] = channel_match[ + sw_channel_kw_args["dmm_name"] + ] + elif call.name == "add_dmm_detuning": + sw_channel_args[1] = channel_match[sw_channel_args[1]] + # if configuring a detuning map or an SLM mask + else: + assert ( + call.name == "config_detuning_map" + or call.name == "config_slm_mask" + ) + if "dmm_id" in sw_channel_kw_args: # pragma: no cover + dmm_called = _get_dmm_name( + sw_channel_kw_args["dmm_id"], dmm_calls + ) + sw_channel_kw_args["dmm_id"] = channel_match[dmm_called] + else: + dmm_called = _get_dmm_name(sw_channel_args[1], dmm_calls) + sw_channel_args[1] = channel_match[dmm_called] + dmm_calls.append(dmm_called) + channel_match[dmm_called] = _get_dmm_name( + channel_match[dmm_called], + list(new_seq.declared_channels.keys()), + ) + getattr(new_seq, call.name)(*sw_channel_args, **sw_channel_kw_args) + + if strict: + for eom_channel in active_eom_channels: + current_samples = seq._schedule[eom_channel].get_samples() + new_samples = new_seq._schedule[eom_channel].get_samples() + if ( + not np.all( + np.isclose(current_samples.amp, new_samples.amp) + ) + or not np.all( + np.isclose(current_samples.det, new_samples.det) + ) + or not np.all( + np.isclose(current_samples.phase, new_samples.phase) + ) + ): + raise ValueError( + f"No match for channel {eom_channel} with an" + " EOM configuration that does not change the" + " samples." + ) + return new_seq + + # Channel match + active_eom_channels = [ + {**dict(zip(("channel",), call.args)), **call.kwargs}["channel"] + for call in seq._calls + seq._to_build_calls + if call.name == "enable_eom_mode" + ] + all_channels_new_device = { + **new_device.channels, + **new_device.dmm_channels, + } + possible_channel_match: list[dict[str, str]] = [] + for channels_comb in itertools.product( + all_channels_new_device, repeat=len(seq.declared_channels) + ): + channel_match = dict(zip(seq.declared_channels, channels_comb)) + if is_good_match( + channel_match, + new_device.reusable_channels, + all_channels_new_device, + active_eom_channels, + strict, + ): + possible_channel_match.append(channel_match) + if not possible_channel_match: + raise_error_non_matching_channel( + new_device.reusable_channels, + all_channels_new_device, + active_eom_channels, + strict, + ) + err_channel_match = {} + for channel_match in possible_channel_match: + try: + return build_sequence_from_matching( + new_device, channel_match, active_eom_channels, strict + ) + except ValueError as e: + err_channel_match[tuple(channel_match.items())] = e.args + continue + raise ValueError( + "No matching found between declared channels and channels in the " + "new device that does not modify the samples of the Sequence. " + "Here is a list of matchings tested and their associated errors: " + f"{err_channel_match}" + ) diff --git a/pulser-core/pulser/sequence/sequence.py b/pulser-core/pulser/sequence/sequence.py index 5d3166d8..f05c74b0 100644 --- a/pulser-core/pulser/sequence/sequence.py +++ b/pulser-core/pulser/sequence/sequence.py @@ -69,7 +69,8 @@ _TimeSlot, ) from pulser.sequence._seq_drawer import Figure, draw_sequence -from pulser.sequence._seq_str import seq_to_str +from pulser.sequence.helpers._seq_str import seq_to_str +from pulser.sequence.helpers._switch_device import switch_device from pulser.waveforms import Waveform DeviceType = TypeVar("DeviceType", bound=BaseDevice) @@ -748,180 +749,7 @@ def switch_device( The sequence on the new device, using the match channels of the former device declared in the sequence. """ - # Check if the device is new or not - - if self._device == new_device: - warnings.warn( - "Switching a sequence to the same device" - + " returns the sequence unchanged.", - stacklevel=2, - ) - return self - - if self._in_xy: - interaction_param = "interaction_coeff_xy" - name_in_msg = "XY interaction coefficient" - else: - interaction_param = "rydberg_level" - name_in_msg = "Rydberg level" - - if getattr(new_device, interaction_param) != getattr( - self._device, interaction_param - ): - if strict: - raise ValueError( - "Strict device match failed because the" - f" devices have different {name_in_msg}s." - ) - warnings.warn( - f"Switching to a device with a different {name_in_msg}," - " check that the expected interactions still hold.", - stacklevel=2, - ) - - def check_retarget(ch_obj: Channel) -> bool: - # Check the min_retarget_interval when it is is not - # fully covered by the fixed_retarget_t - return ch_obj.addressing == "Local" and cast( - int, ch_obj.fixed_retarget_t - ) < cast(int, ch_obj.min_retarget_interval) - - # Channel match - channel_match: dict[str, Any] = {} - strict_error_message = "" - ch_match_err = "" - active_eom_channels = [ - {**dict(zip(("channel",), call.args)), **call.kwargs}["channel"] - for call in self._calls + self._to_build_calls - if call.name == "enable_eom_mode" - ] - all_channels_new_device = { - **new_device.channels, - **new_device.dmm_channels, - } - - for old_ch_name, old_ch_obj in self.declared_channels.items(): - channel_match[old_ch_name] = None - base_msg = f"No match for channel {old_ch_name}" - # Find the corresponding channel on the new device - for new_ch_id, new_ch_obj in all_channels_new_device.items(): - if ( - not new_device.reusable_channels - and new_ch_id in channel_match.values() - ): - # Channel already matched and can't be reused - continue - - # We verify the channel class then - # check whether the addressing is Global or Local - type_match = type(old_ch_obj) is type(new_ch_obj) - basis_match = old_ch_obj.basis == new_ch_obj.basis - addressing_match = ( - old_ch_obj.addressing == new_ch_obj.addressing - ) - if not (type_match and basis_match and addressing_match): - # If there already is a message, keeps it - ch_match_err = ch_match_err or ( - base_msg - + " with the right type, basis and addressing." - ) - continue - if old_ch_name in active_eom_channels: - # Uses EOM mode, so the new device needs a matching - # EOM configuration - if new_ch_obj.eom_config is None: - ch_match_err = base_msg + " with an EOM configuration." - continue - if ( - # TODO: Improvements to this check: - # 1. multiple_beam_control doesn't matter when there - # is only one beam - # 2. custom_buffer_time doesn't have to match as long - # as `Channel_eom_buffer_time`` does - new_ch_obj.eom_config != old_ch_obj.eom_config - and strict - ): - strict_error_message = ( - base_msg + " with the same EOM configuration." - ) - continue - if not strict: - channel_match[old_ch_name] = new_ch_id - # Found a match, clear match error msg for this channel - if ch_match_err.startswith(base_msg): - ch_match_err = "" - break - - params_to_check = [ - "mod_bandwidth", - "fixed_retarget_t", - "clock_period", - ] - if isinstance(old_ch_obj, DMM): - params_to_check.append("bottom_detuning") - params_to_check.append("total_bottom_detuning") - if check_retarget(old_ch_obj) or check_retarget(new_ch_obj): - params_to_check.append("min_retarget_interval") - for param_ in params_to_check: - if getattr(new_ch_obj, param_) != getattr( - old_ch_obj, param_ - ): - strict_error_message = ( - base_msg + f" with the same {param_}." - ) - break - else: - # Only reached if all checks passed - channel_match[old_ch_name] = new_ch_id - # Found a match, clear match error msgs for this channel - if ch_match_err.startswith(base_msg): - ch_match_err = "" - if strict_error_message.startswith(base_msg): - strict_error_message = "" - break - - if None in channel_match.values(): - if strict_error_message: - raise ValueError(strict_error_message) - else: - raise TypeError(ch_match_err) - # Initialize the new sequence (works for Sequence subclasses too) - new_seq = type(self)(register=self._register, device=new_device) - dmm_calls: list[str] = [] - # Copy the variables to the new sequence - new_seq._variables = self.declared_variables - for call in self._calls[1:] + self._to_build_calls: - # Switch the old id with the correct id - sw_channel_args = list(call.args) - sw_channel_kw_args = call.kwargs.copy() - if not ( - call.name == "declare_channel" - or call.name == "config_detuning_map" - or call.name == "config_slm_mask" - ): - pass - elif "name" in sw_channel_kw_args: # pragma: no cover - sw_channel_kw_args["channel_id"] = channel_match[ - sw_channel_kw_args["name"] - ] - elif "channel_id" in sw_channel_kw_args: # pragma: no cover - sw_channel_kw_args["channel_id"] = channel_match[ - sw_channel_args[0] - ] - elif "dmm_id" in sw_channel_kw_args: # pragma: no cover - sw_channel_kw_args["dmm_id"] = channel_match[ - _get_dmm_name(sw_channel_kw_args["dmm_id"], dmm_calls) - ] - dmm_calls.append(sw_channel_kw_args["dmm_id"]) - elif call.name == "declare_channel": - sw_channel_args[1] = channel_match[sw_channel_args[0]] - else: - sw_channel_args[1] = channel_match[ - _get_dmm_name(sw_channel_args[1], dmm_calls) - ] - dmm_calls.append(sw_channel_args[1]) - getattr(new_seq, call.name)(*sw_channel_args, **sw_channel_kw_args) - return new_seq + return switch_device(self, new_device, strict) @seq_decorators.block_if_measured def declare_channel( diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 5979c464..6254ea50 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -17,7 +17,8 @@ import dataclasses import itertools import json -from typing import Any +import re +from typing import Any, cast from unittest.mock import patch import numpy as np @@ -27,6 +28,7 @@ from pulser import Pulse, Register, Register3D, Sequence from pulser.channels import Raman, Rydberg from pulser.channels.dmm import DMM +from pulser.channels.eom import RydbergBeam, RydbergEOM from pulser.devices import AnalogDevice, DigitalAnalogDevice, MockDevice from pulser.devices._device_datacls import Device, VirtualDevice from pulser.register.base_register import BaseRegister @@ -357,8 +359,9 @@ def devices(): clock_period=4, min_duration=16, max_duration=2**26, - bottom_detuning=-2 * np.pi * 20, - total_bottom_detuning=-2 * np.pi * 2000, + # Better than DMM of DigitalAnalogDevice + bottom_detuning=-2 * np.pi * 40, + total_bottom_detuning=-2 * np.pi * 4000, ), ), ) @@ -742,40 +745,82 @@ def test_switch_device_down( ): # Can't find a match for the 2nd dmm_0 seq.switch_device(phys_Chadoq2) - # Strict switch imposes to have same bottom detuning for DMMs - with pytest.raises( - ValueError, - match="No match for channel dmm_0_1 with the same bottom_detuning.", - ): - # Can't find a match for the 1st dmm_0 + # There is no need to have same bottom detuning to have a strict switch + dmm_down = dataclasses.replace( + phys_Chadoq2.dmm_channels["dmm_0"], bottom_detuning=-10 + ) + new_seq = seq.switch_device( + dataclasses.replace(phys_Chadoq2, dmm_objects=(dmm_down, dmm_down)), + strict=True, + ) + assert list(new_seq.declared_channels.keys()) == [ + "global", + "dmm_0", + "dmm_1", + ] + seq.add_dmm_detuning(ConstantWaveform(100, -20), "dmm_0_1") + seq.add_dmm_detuning(ConstantWaveform(100, -20), dmm_name="dmm_0_1") + # Still works with reusable channels + new_seq = seq.switch_device( + dataclasses.replace( + phys_Chadoq2.to_virtual(), + reusable_channels=True, + dmm_objects=(dataclasses.replace(dmm_down, bottom_detuning=-20),), + ), + strict=True, + ) + assert list(new_seq.declared_channels.keys()) == [ + "global", + "dmm_0", + "dmm_0_1", + ] + # Still one compatible configuration + new_seq = seq.switch_device( + dataclasses.replace( + phys_Chadoq2, + dmm_objects=(phys_Chadoq2.dmm_channels["dmm_0"], dmm_down), + ), + strict=True, + ) + assert list(new_seq.declared_channels.keys()) == [ + "global", + "dmm_1", + "dmm_0", + ] + # No compatible configuration + error_msg = ( + "No matching found between declared channels and channels in the " + "new device that does not modify the samples of the Sequence. " + "Here is a list of matchings tested and their associated errors: " + "{(('global', 'rydberg_global'), ('dmm_0', 'dmm_0'), ('dmm_0_1', " + "'dmm_1')): ('The detunings on some atoms go below the local bottom " + "detuning of the DMM (-10 rad/µs).',), (('global', 'rydberg_global'), " + "('dmm_0', 'dmm_1'), ('dmm_0_1', 'dmm_0')): ('The detunings on some " + "atoms go below the local bottom detuning of the DMM (-10 rad/µs).',)}" + ) + with pytest.raises(ValueError, match=re.escape(error_msg)): seq.switch_device( dataclasses.replace( - phys_Chadoq2, - dmm_objects=( - phys_Chadoq2.dmm_channels["dmm_0"], - dataclasses.replace( - phys_Chadoq2.dmm_channels["dmm_0"], bottom_detuning=-10 - ), - ), + phys_Chadoq2, dmm_objects=(dmm_down, dmm_down) ), strict=True, ) - with pytest.raises( - ValueError, - match="No match for channel dmm_0_1 with the same " - "total_bottom_detuning.", - ): - # Can't find a match for the 1st dmm_0 + dmm_down = dataclasses.replace( + phys_Chadoq2.dmm_channels["dmm_0"], + bottom_detuning=-10, + total_bottom_detuning=-10, + ) + seq.switch_device( + dataclasses.replace( + phys_Chadoq2, + dmm_objects=(phys_Chadoq2.dmm_channels["dmm_0"], dmm_down), + ), + strict=True, + ) + with pytest.raises(ValueError, match=re.escape(error_msg)): seq.switch_device( dataclasses.replace( - phys_Chadoq2, - dmm_objects=( - phys_Chadoq2.dmm_channels["dmm_0"], - dataclasses.replace( - phys_Chadoq2.dmm_channels["dmm_0"], - total_bottom_detuning=-500, - ), - ), + phys_Chadoq2, dmm_objects=(dmm_down, dmm_down) ), strict=True, ) @@ -1029,17 +1074,38 @@ def test_switch_device_up( assert "digital" in seq.switch_device(devices[1], True).declared_channels +extended_eom = dataclasses.replace( + cast(RydbergEOM, AnalogDevice.channels["rydberg_global"].eom_config), + controlled_beams=tuple(RydbergBeam), + multiple_beam_control=True, + custom_buffer_time=None, +) +extended_eom_channel = dataclasses.replace( + AnalogDevice.channels["rydberg_global"], eom_config=extended_eom +) +extended_eom_device = dataclasses.replace( + AnalogDevice, channel_objects=(extended_eom_channel,) +) + + +@pytest.mark.parametrize("device", [AnalogDevice, extended_eom_device]) @pytest.mark.parametrize("mappable_reg", [False, True]) @pytest.mark.parametrize("parametrized", [False, True]) -def test_switch_device_eom(reg, mappable_reg, parametrized, patch_plt_show): +@pytest.mark.parametrize( + "extension_arg", ["amp", "control", "2control", "buffer_time"] +) +def test_switch_device_eom( + reg, device, mappable_reg, parametrized, extension_arg, patch_plt_show +): # Sequence with EOM blocks seq = init_seq( reg, - dataclasses.replace(AnalogDevice, max_atom_num=28), + dataclasses.replace(device, max_atom_num=28), "rydberg", "rydberg_global", [], parametrized=parametrized, + mappable_reg=mappable_reg, ) seq.enable_eom_mode("rydberg", amp_on=2.0, detuning_on=0.0) seq.add_eom_pulse("rydberg", 100, 0.0) @@ -1057,30 +1123,196 @@ def test_switch_device_eom(reg, mappable_reg, parametrized, patch_plt_show): seq.switch_device(DigitalAnalogDevice) ch_obj = seq.declared_channels["rydberg"] + wrong_eom_config = dataclasses.replace(ch_obj.eom_config, mod_bandwidth=20) + wrong_ch_obj = dataclasses.replace(ch_obj, eom_config=wrong_eom_config) + wrong_analog = dataclasses.replace( + device, channel_objects=(wrong_ch_obj,), max_atom_num=28 + ) + if parametrized: + # Can't switch if the two EOM configurations don't match + # If the modulation bandwidth is different + with pytest.raises( + ValueError, match=err_base + "with the same EOM configuration." + ): + seq.switch_device(wrong_analog, strict=True) + down_eom_configs = { + # If the amplitude is different + "amp": dataclasses.replace( + ch_obj.eom_config, max_limiting_amp=10 * 2 * np.pi + ), + # If less controlled beams/the controlled beam is not the same + "control": dataclasses.replace( + ch_obj.eom_config, + controlled_beams=(RydbergBeam.RED,), + multiple_beam_control=False, + ), + # If the multiple_beam_control is not the same + "2control": dataclasses.replace( + ch_obj.eom_config, + controlled_beams=( + tuple(RydbergBeam) + if device == extended_eom_device + else (RydbergBeam.RED,) + ), + multiple_beam_control=False, + ), + # If the buffer time is different + "buffer_time": dataclasses.replace( + ch_obj.eom_config, + custom_buffer_time=300, + ), + } + wrong_ch_obj = dataclasses.replace( + ch_obj, eom_config=down_eom_configs[extension_arg] + ) + wrong_analog = dataclasses.replace( + device, channel_objects=(wrong_ch_obj,), max_atom_num=28 + ) + with pytest.raises( + ValueError, match=err_base + "with the same EOM configuration." + ): + seq.switch_device(wrong_analog, strict=True) + else: + # Can't switch to eom if the modulation bandwidth doesn't match + with pytest.raises( + ValueError, + match=err_base + "with the same mod_bandwidth for the EOM.", + ): + seq.switch_device(wrong_analog, strict=True) + # Can if one Channel has a correct EOM configuration + new_seq = seq.switch_device( + dataclasses.replace( + wrong_analog, + channel_objects=(wrong_ch_obj, ch_obj), + channel_ids=("wrong_eom", "good_eom"), + ), + strict=True, + ) + assert new_seq.declared_channels == {"rydberg": ch_obj} + # Can if eom extends current eom + up_eom_configs = { + # Still raises for max_amplitude in parametrized Sequence + "amp": dataclasses.replace( + ch_obj.eom_config, max_limiting_amp=40 * 2 * np.pi + ), + # With one controlled beam, don't care about multiple_beam_control + # Raises an error if device is extended_eom_device (less options) + "control": dataclasses.replace( + ch_obj.eom_config, + controlled_beams=(RydbergBeam.BLUE,), + multiple_beam_control=False, + ), + # Using 2 controlled beams + # Raises an error if device is extended_eom_device (less options) + "2control": dataclasses.replace( + ch_obj.eom_config, + controlled_beams=tuple(RydbergBeam), + multiple_beam_control=False, + ), + # If custom buffer time is None + # Raises an error if device is extended_eom_device + "buffer_time": dataclasses.replace( + ch_obj.eom_config, + custom_buffer_time=None, + ), + } + up_eom_config = up_eom_configs[extension_arg] + up_ch_obj = dataclasses.replace(ch_obj, eom_config=up_eom_config) + up_analog = dataclasses.replace( + device, channel_objects=(up_ch_obj,), max_atom_num=28 + ) + if ( + (parametrized and extension_arg == "amp") + or ( + parametrized + and extension_arg in ["control", "2control"] + and device == extended_eom_device + ) + or ( + parametrized + and extension_arg == "buffer_time" + and device == AnalogDevice + ) + ): + with pytest.raises( + ValueError, + match=err_base + "with the same EOM configuration.", + ): + seq.switch_device(up_analog, strict=True) + return + if device == extended_eom_device: + if extension_arg in ["control", "2control"]: + with pytest.raises( + ValueError, + match="No match for channel rydberg with an EOM configuration", + ): + seq.switch_device(up_analog, strict=True) + return + elif extension_arg == "buffer_time": + with pytest.warns( + UserWarning, match="Switching a sequence to the same device" + ): + up_seq = seq.switch_device(up_analog, strict=True) + else: + up_seq = seq.switch_device(up_analog, strict=True) + else: + up_seq = seq.switch_device(up_analog, strict=True) + build_kwargs = {} + if parametrized: + build_kwargs["delay"] = 120 + if mappable_reg: + build_kwargs["qubits"] = {"q0": 0} + og_eom_block = ( + (seq.build(**build_kwargs) if build_kwargs else seq) + ._schedule["rydberg"] + .eom_blocks[0] + ) + up_eom_block = ( + (up_seq.build(**build_kwargs) if build_kwargs else up_seq) + ._schedule["rydberg"] + .eom_blocks[0] + ) + assert og_eom_block.detuning_on == up_eom_block.detuning_on + assert og_eom_block.rabi_freq == up_eom_block.rabi_freq + assert og_eom_block.detuning_off == up_eom_block.detuning_off + + # Some parameters might modify the samples mod_eom_config = dataclasses.replace( - ch_obj.eom_config, max_limiting_amp=10 * 2 * np.pi + ch_obj.eom_config, max_limiting_amp=5 * 2 * np.pi ) mod_ch_obj = dataclasses.replace(ch_obj, eom_config=mod_eom_config) mod_analog = dataclasses.replace( - AnalogDevice, channel_objects=(mod_ch_obj,), max_atom_num=28 + device, channel_objects=(mod_ch_obj,), max_atom_num=28 ) - with pytest.raises( - ValueError, match=err_base + "with the same EOM configuration." - ): + err_msg = ( + "No matching found between declared channels and channels in " + "the new device that does not modify the samples of the " + "Sequence. Here is a list of matchings tested and their " + "associated errors: {(('rydberg', 'rydberg_global'),): ('No " + "match for channel rydberg with an EOM configuration that " + "does not change the samples." + ) + if parametrized: + with pytest.raises( + ValueError, + match=err_base + "with the same EOM configuration.", + ): + seq.switch_device(mod_analog, strict=True) + return + with pytest.raises(ValueError, match=re.escape(err_msg)): seq.switch_device(mod_analog, strict=True) - mod_seq = seq.switch_device(mod_analog, strict=False) - if parametrized: - seq = seq.build(delay=120) - mod_seq = mod_seq.build(delay=120) - og_eom_block = seq._schedule["rydberg"].eom_blocks[0] - mod_eom_block = mod_seq._schedule["rydberg"].eom_blocks[0] + mod_eom_block = ( + (mod_seq.build(**build_kwargs) if build_kwargs else mod_seq) + ._schedule["rydberg"] + .eom_blocks[0] + ) assert og_eom_block.detuning_on == mod_eom_block.detuning_on assert og_eom_block.rabi_freq == mod_eom_block.rabi_freq assert og_eom_block.detuning_off != mod_eom_block.detuning_off # Test drawing in eom mode - seq.draw() + (seq.build(**build_kwargs) if build_kwargs else seq).draw() def test_target(reg, device):