Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Soften switching device with strict conditions #724

Merged
merged 16 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
376 changes: 251 additions & 125 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from __future__ import annotations

import copy
import itertools
import json
import os
import warnings
Expand Down Expand Up @@ -785,10 +786,223 @@ def check_retarget(ch_obj: Channel) -> bool:
int, ch_obj.fixed_retarget_t
) < cast(int, ch_obj.min_retarget_interval)

def check_channels_match(
old_ch_name: str,
new_ch_obj: Channel,
active_eom_channels: list,
strict: bool,
) -> tuple[str, str]:
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
old_ch_obj = self.declared_channels[old_ch_name]
# We verify the channel class then
# check whether the addressing is Global or Local
type_match = type(old_ch_obj) is type(new_ch_obj)
basis_match = old_ch_obj.basis == new_ch_obj.basis
addressing_match = old_ch_obj.addressing == new_ch_obj.addressing
if not (type_match and basis_match and addressing_match):
# If there already is a message, keeps it
return (" with the right type, basis and addressing.", "")
if old_ch_name in active_eom_channels:
# Uses EOM mode, so the new device needs a matching
# EOM configuration
if new_ch_obj.eom_config is None:
return (" with an EOM configuration.", "")
if strict:
if (
not self.is_parametrized()
and new_ch_obj.eom_config.mod_bandwidth
!= cast(
RydbergEOM, old_ch_obj.eom_config
).mod_bandwidth
):
return (
"",
" with the same mod_bandwidth for the EOM.",
)
if (
self.is_parametrized()
a-corni marked this conversation as resolved.
Show resolved Hide resolved
and new_ch_obj.eom_config != old_ch_obj.eom_config
):
return ("", " with the same EOM configuration.")
if not strict:
return ("", "")

params_to_check = [
"mod_bandwidth",
"fixed_retarget_t",
"clock_period",
]
if check_retarget(old_ch_obj) or check_retarget(new_ch_obj):
params_to_check.append("min_retarget_interval")
for param_ in params_to_check:
if getattr(new_ch_obj, param_) != getattr(old_ch_obj, param_):
return ("", f" with the same {param_}.")
else:
return ("", "")

def is_good_match(
channel_match: dict[str, str],
reusable_channels: bool,
all_channels_new_device: dict[str, Channel],
active_eom_channels: list,
strict: bool,
) -> bool:
used_channels_new_device = list(channel_match.values())
if not reusable_channels and len(
set(used_channels_new_device)
) < len(used_channels_new_device):
return False
for old_ch_name, new_ch_name in channel_match.items():
if check_channels_match(
old_ch_name,
all_channels_new_device[new_ch_name],
active_eom_channels,
strict,
) != ("", ""):
return False
return True

def raise_error_non_matching_channel(
reusable_channels: bool,
all_channels_new_device: dict[str, Channel],
active_eom_channels: list,
strict: bool,
) -> None:
strict_error_message = ""
ch_match_err = ""
channel_match: dict[str, Any] = {}
for old_ch_name, old_ch_obj in self.declared_channels.items():
channel_match[old_ch_name] = None
base_msg = f"No match for channel {old_ch_name}"
# Find the corresponding channel on the new device
for new_ch_id, new_ch_obj in all_channels_new_device.items():
if (
not reusable_channels
and new_ch_id in channel_match.values()
):
# Channel already matched and can't be reused
continue
(ch_match_err_suffix, strict_error_message_suffix) = (
check_channels_match(
old_ch_name,
new_ch_obj,
active_eom_channels,
strict,
)
)
if (ch_match_err_suffix, strict_error_message_suffix) == (
"",
"",
):
channel_match[old_ch_name] = new_ch_id
# Found a match, clear match error msg for this channel
if ch_match_err.startswith(base_msg):
ch_match_err = ""
if strict_error_message.startswith(base_msg):
strict_error_message = ""
break
elif ch_match_err_suffix != "":
ch_match_err = (
ch_match_err
if (
"type, basis and addressing"
in ch_match_err_suffix
and ch_match_err
)
else base_msg + ch_match_err_suffix
)
a-corni marked this conversation as resolved.
Show resolved Hide resolved
else:
strict_error_message = (
base_msg + strict_error_message_suffix
)
assert None in channel_match.values()
if strict_error_message:
raise ValueError(strict_error_message)
raise TypeError(ch_match_err)

def build_sequence_from_matching(
new_device: DeviceType,
channel_match: dict[str, str],
active_eom_channels: list,
strict: bool,
) -> Sequence:
# Initialize the new sequence (works for Sequence subclasses too)
new_seq = type(self)(register=self._register, device=new_device)
dmm_calls: list[str] = []
# Copy the variables to the new sequence
new_seq._variables = self.declared_variables
for call in self._calls[1:] + self._to_build_calls:
# Switch the old id with the correct id
sw_channel_args = list(call.args)
sw_channel_kw_args = call.kwargs.copy()
if not (
call.name == "declare_channel"
or call.name == "config_detuning_map"
or call.name == "config_slm_mask"
or call.name == "add_dmm_detuning"
):
pass
# if calling declare_channel
elif "name" in sw_channel_kw_args: # pragma: no cover
sw_channel_kw_args["channel_id"] = channel_match[
sw_channel_kw_args["name"]
]
elif "channel_id" in sw_channel_kw_args: # pragma: no cover
sw_channel_kw_args["channel_id"] = channel_match[
sw_channel_args[0]
]
elif call.name == "declare_channel":
sw_channel_args[1] = channel_match[sw_channel_args[0]]
# if adding a detuning waveform to the dmm
elif "dmm_name" in sw_channel_kw_args: # program: no cover
sw_channel_kw_args["dmm_name"] = channel_match[
sw_channel_kw_args["dmm_name"]
]
elif call.name == "add_dmm_detuning":
sw_channel_args[1] = channel_match[sw_channel_args[1]]
# if configuring a detuning map or an SLM mask
else:
assert (
call.name == "config_detuning_map"
or call.name == "config_slm_mask"
)
if "dmm_id" in sw_channel_kw_args: # pragma: no cover
dmm_called = _get_dmm_name(
sw_channel_kw_args["dmm_id"], dmm_calls
)
sw_channel_kw_args["dmm_id"] = channel_match[
dmm_called
]
else:
dmm_called = _get_dmm_name(
sw_channel_args[1], dmm_calls
)
sw_channel_args[1] = channel_match[dmm_called]
dmm_calls.append(dmm_called)
channel_match[dmm_called] = _get_dmm_name(
channel_match[dmm_called],
list(new_seq.declared_channels.keys()),
)
getattr(new_seq, call.name)(
*sw_channel_args, **sw_channel_kw_args
)

if strict:
for eom_channel in active_eom_channels:
current_samples = self._schedule[eom_channel].get_samples()
new_samples = new_seq._schedule[eom_channel].get_samples()
if (
np.any(current_samples.amp != new_samples.amp)
or np.any(current_samples.det != new_samples.det)
or np.any(current_samples.phase != new_samples.phase)
a-corni marked this conversation as resolved.
Show resolved Hide resolved
):
raise ValueError(
f"No match for channel {eom_channel} with an"
" EOM configuration that does not change the"
" samples."
)
return new_seq

# Channel match
channel_match: dict[str, Any] = {}
strict_error_message = ""
ch_match_err = ""
active_eom_channels = [
{**dict(zip(("channel",), call.args)), **call.kwargs}["channel"]
for call in self._calls + self._to_build_calls
Expand All @@ -798,129 +1012,41 @@ def check_retarget(ch_obj: Channel) -> bool:
**new_device.channels,
**new_device.dmm_channels,
}

for old_ch_name, old_ch_obj in self.declared_channels.items():
channel_match[old_ch_name] = None
base_msg = f"No match for channel {old_ch_name}"
# Find the corresponding channel on the new device
for new_ch_id, new_ch_obj in all_channels_new_device.items():
if (
not new_device.reusable_channels
and new_ch_id in channel_match.values()
):
# Channel already matched and can't be reused
continue

# We verify the channel class then
# check whether the addressing is Global or Local
type_match = type(old_ch_obj) is type(new_ch_obj)
basis_match = old_ch_obj.basis == new_ch_obj.basis
addressing_match = (
old_ch_obj.addressing == new_ch_obj.addressing
)
if not (type_match and basis_match and addressing_match):
# If there already is a message, keeps it
ch_match_err = ch_match_err or (
base_msg
+ " with the right type, basis and addressing."
)
continue
if old_ch_name in active_eom_channels:
# Uses EOM mode, so the new device needs a matching
# EOM configuration
if new_ch_obj.eom_config is None:
ch_match_err = base_msg + " with an EOM configuration."
continue
if (
# TODO: Improvements to this check:
# 1. multiple_beam_control doesn't matter when there
# is only one beam
# 2. custom_buffer_time doesn't have to match as long
# as `Channel_eom_buffer_time`` does
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
new_ch_obj.eom_config != old_ch_obj.eom_config
and strict
):
strict_error_message = (
base_msg + " with the same EOM configuration."
)
continue
if not strict:
channel_match[old_ch_name] = new_ch_id
# Found a match, clear match error msg for this channel
if ch_match_err.startswith(base_msg):
ch_match_err = ""
break

params_to_check = [
"mod_bandwidth",
"fixed_retarget_t",
"clock_period",
]
if isinstance(old_ch_obj, DMM):
params_to_check.append("bottom_detuning")
params_to_check.append("total_bottom_detuning")
if check_retarget(old_ch_obj) or check_retarget(new_ch_obj):
params_to_check.append("min_retarget_interval")
for param_ in params_to_check:
if getattr(new_ch_obj, param_) != getattr(
old_ch_obj, param_
):
strict_error_message = (
base_msg + f" with the same {param_}."
)
break
else:
# Only reached if all checks passed
channel_match[old_ch_name] = new_ch_id
# Found a match, clear match error msgs for this channel
if ch_match_err.startswith(base_msg):
ch_match_err = ""
if strict_error_message.startswith(base_msg):
strict_error_message = ""
break

if None in channel_match.values():
if strict_error_message:
raise ValueError(strict_error_message)
else:
raise TypeError(ch_match_err)
# Initialize the new sequence (works for Sequence subclasses too)
new_seq = type(self)(register=self._register, device=new_device)
dmm_calls: list[str] = []
# Copy the variables to the new sequence
new_seq._variables = self.declared_variables
for call in self._calls[1:] + self._to_build_calls:
# Switch the old id with the correct id
sw_channel_args = list(call.args)
sw_channel_kw_args = call.kwargs.copy()
if not (
call.name == "declare_channel"
or call.name == "config_detuning_map"
or call.name == "config_slm_mask"
possible_channel_match: list[dict[str, str]] = []
for channels_comb in itertools.product(
all_channels_new_device, repeat=len(self.declared_channels)
):
channel_match = dict(zip(self.declared_channels, channels_comb))
if is_good_match(
channel_match,
new_device.reusable_channels,
all_channels_new_device,
active_eom_channels,
strict,
):
pass
elif "name" in sw_channel_kw_args: # pragma: no cover
sw_channel_kw_args["channel_id"] = channel_match[
sw_channel_kw_args["name"]
]
elif "channel_id" in sw_channel_kw_args: # pragma: no cover
sw_channel_kw_args["channel_id"] = channel_match[
sw_channel_args[0]
]
elif "dmm_id" in sw_channel_kw_args: # pragma: no cover
sw_channel_kw_args["dmm_id"] = channel_match[
_get_dmm_name(sw_channel_kw_args["dmm_id"], dmm_calls)
]
dmm_calls.append(sw_channel_kw_args["dmm_id"])
elif call.name == "declare_channel":
sw_channel_args[1] = channel_match[sw_channel_args[0]]
else:
sw_channel_args[1] = channel_match[
_get_dmm_name(sw_channel_args[1], dmm_calls)
]
dmm_calls.append(sw_channel_args[1])
getattr(new_seq, call.name)(*sw_channel_args, **sw_channel_kw_args)
return new_seq
possible_channel_match.append(channel_match)
if not possible_channel_match:
raise_error_non_matching_channel(
new_device.reusable_channels,
all_channels_new_device,
active_eom_channels,
strict,
)
err_channel_match = {}
for channel_match in possible_channel_match:
try:
return build_sequence_from_matching(
new_device, channel_match, active_eom_channels, strict
)
except ValueError as e:
err_channel_match[tuple(channel_match.items())] = e.args
continue
raise ValueError(
"No matching found between declared channels and channels in the "
"new device that does not modify the samples of the Sequence. "
"Here is a list of matching tested and their associated errors: "
f"{err_channel_match}"
)

@seq_decorators.block_if_measured
def declare_channel(
Expand Down
Loading