Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow modification of the EOM setpoint without disabling EOM mode #708

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pulser-core/pulser/sequence/_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,14 @@ def enable_eom(
detuning_off: float,
switching_beams: tuple[RydbergBeam, ...] = (),
_skip_buffer: bool = False,
_skip_wait_for_fall: bool = False,
) -> None:
channel_obj = self[channel_id].channel_obj
# Adds a buffer unless the channel is empty or _skip_buffer = True
if not _skip_buffer and self.get_duration(channel_id):
# Wait for the last pulse to ramp down (if needed)
self.wait_for_fall(channel_id)
if not _skip_wait_for_fall:
# Wait for the last pulse to ramp down (if needed)
self.wait_for_fall(channel_id)
eom_buffer_time = self[channel_id].adjust_duration(
channel_obj._eom_buffer_time
)
Expand Down
190 changes: 146 additions & 44 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import pulser.sequence._decorators as seq_decorators
from pulser.channels.base_channel import Channel, States, get_states_from_bases
from pulser.channels.dmm import DMM, _dmm_id_from_name, _get_dmm_name
from pulser.channels.eom import RydbergEOM
from pulser.channels.eom import RydbergBeam, RydbergEOM
from pulser.devices._device_datacls import BaseDevice
from pulser.json.abstract_repr.deserializer import (
deserialize_abstract_sequence,
Expand Down Expand Up @@ -1139,54 +1139,35 @@ def enable_eom_mode(
raise RuntimeError(
f"The '{channel}' channel is already in EOM mode."
)

channel_obj = self.declared_channels[channel]
if not channel_obj.supports_eom():
raise TypeError(f"Channel '{channel}' does not have an EOM.")

on_pulse = Pulse.ConstantPulse(
channel_obj.min_duration, amp_on, detuning_on, 0.0
detuning_off, switching_beams = self._process_eom_parameters(
channel_obj, amp_on, detuning_on, optimal_detuning_off
)
stored_opt_detuning_off = optimal_detuning_off
if not isinstance(on_pulse, Parametrized):
channel_obj.validate_pulse(on_pulse)
amp_on = cast(float, amp_on)
detuning_on = cast(float, detuning_on)
eom_config = cast(RydbergEOM, channel_obj.eom_config)
if not isinstance(optimal_detuning_off, Parametrized):
(
detuning_off,
switching_beams,
) = eom_config.calculate_detuning_off(
amp_on,
detuning_on,
optimal_detuning_off,
return_switching_beams=True,
)
off_pulse = Pulse.ConstantPulse(
channel_obj.min_duration, 0.0, detuning_off, 0.0
)
channel_obj.validate_pulse(off_pulse)
# Update optimal_detuning_off to match the chosen detuning_off
# This minimizes the changes to the sequence when the device
# is switched
stored_opt_detuning_off = detuning_off

if not self.is_parametrized():
phase_drift_params = _PhaseDriftParams(
drift_rate=-detuning_off,
# enable_eom() calls wait for fall, so the block only
# starts after fall time
ti=self.get_duration(channel, include_fall_time=True),
)
self._schedule.enable_eom(
channel, amp_on, detuning_on, detuning_off, switching_beams
if not self.is_parametrized():
detuning_off = cast(float, detuning_off)
phase_drift_params = _PhaseDriftParams(
drift_rate=-detuning_off,
# enable_eom() calls wait for fall, so the block only
# starts after fall time
ti=self.get_duration(channel, include_fall_time=True),
)
self._schedule.enable_eom(
channel,
cast(float, amp_on),
cast(float, detuning_on),
detuning_off,
switching_beams,
)
if correct_phase_drift:
buffer_slot = self._last(channel)
drift = phase_drift_params.calc_phase_drift(buffer_slot.tf)
self._phase_shift(
-drift, *buffer_slot.targets, basis=channel_obj.basis
)
if correct_phase_drift:
buffer_slot = self._last(channel)
drift = phase_drift_params.calc_phase_drift(buffer_slot.tf)
self._phase_shift(
-drift, *buffer_slot.targets, basis=channel_obj.basis
)

# Manually store the call to "enable_eom_mode" so that the updated
# 'optimal_detuning_off' is stored
Expand All @@ -1201,7 +1182,7 @@ def enable_eom_mode(
channel=channel,
amp_on=amp_on,
detuning_on=detuning_on,
optimal_detuning_off=stored_opt_detuning_off,
optimal_detuning_off=detuning_off,
correct_phase_drift=correct_phase_drift,
),
)
Expand Down Expand Up @@ -1253,6 +1234,90 @@ def disable_eom_mode(
basis=ch_schedule.channel_obj.basis,
)

@seq_decorators.verify_parametrization
@seq_decorators.block_if_measured
def modify_eom_setpoint(
self,
channel: str,
amp_on: Union[float, Parametrized],
detuning_on: Union[float, Parametrized],
optimal_detuning_off: Union[float, Parametrized] = 0.0,
correct_phase_drift: bool = False,
) -> None:
"""Modifies the setpoint of an ongoing EOM mode operation.

Note:
Modifying the EOM setpoint will automatically enforce a buffer.
The detuning will go to the `detuning_off` value during
this buffer. This buffer will not wait for pulses on other
channels to finish, so calling `Sequence.align()` or
`Sequence.delay()` beforehand is necessary to avoid eventual
conflicts.

Args:
channel: The name of the channel currently in EOM mode.
amp_on: The new amplitude of the EOM pulses (in rad/µs).
detuning_on: The new detuning of the EOM pulses (in rad/µs).
optimal_detuning_off: The new optimal value of detuning (in rad/µs)
when there is no pulse being played. It will choose the closest
value among the existing options.
correct_phase_drift: Performs a phase shift to correct for the
phase drift incurred while modifying the EOM setpoint.
"""
if not self.is_in_eom_mode(channel):
raise RuntimeError(f"The '{channel}' channel is not in EOM mode.")

channel_obj = self.declared_channels[channel]
detuning_off, switching_beams = self._process_eom_parameters(
channel_obj, amp_on, detuning_on, optimal_detuning_off
)

if not self.is_parametrized():
detuning_off = cast(float, detuning_off)
self._schedule.disable_eom(channel, _skip_buffer=True)
old_phase_drift_params = self._get_last_eom_pulse_phase_drift(
channel
)
new_phase_drift_params = _PhaseDriftParams(
drift_rate=-detuning_off,
ti=self.get_duration(channel, include_fall_time=False),
)
self._schedule.enable_eom(
channel,
cast(float, amp_on),
cast(float, detuning_on),
detuning_off,
switching_beams,
_skip_wait_for_fall=True,
)
if correct_phase_drift:
buffer_slot = self._last(channel)
drift = old_phase_drift_params.calc_phase_drift(
buffer_slot.ti
) + new_phase_drift_params.calc_phase_drift(buffer_slot.tf)
self._phase_shift(
-drift, *buffer_slot.targets, basis=channel_obj.basis
)

# Manually store the call to "modify_eom_setpoint" so that the updated
# 'optimal_detuning_off' is stored
call_container = (
self._to_build_calls if self.is_parametrized() else self._calls
)
call_container.append(
_Call(
"modify_eom_setpoint",
(),
dict(
channel=channel,
amp_on=amp_on,
detuning_on=detuning_on,
optimal_detuning_off=detuning_off,
correct_phase_drift=correct_phase_drift,
),
)
)

@seq_decorators.store
@seq_decorators.mark_non_empty
@seq_decorators.block_if_measured
Expand Down Expand Up @@ -2389,6 +2454,43 @@ def _validate_add_protocol(self, protocol: str) -> None:
+ ", ".join(valid_protocols)
)

def _process_eom_parameters(
self,
channel_obj: Channel,
amp_on: Union[float, Parametrized],
detuning_on: Union[float, Parametrized],
optimal_detuning_off: Union[float, Parametrized],
) -> tuple[float | Parametrized, tuple[RydbergBeam, ...]]:
on_pulse = Pulse.ConstantPulse(
channel_obj.min_duration, amp_on, detuning_on, 0.0
)
stored_opt_detuning_off = optimal_detuning_off
switching_beams: tuple[RydbergBeam, ...] = ()
if not isinstance(on_pulse, Parametrized):
channel_obj.validate_pulse(on_pulse)
amp_on = cast(float, amp_on)
detuning_on = cast(float, detuning_on)
eom_config = cast(RydbergEOM, channel_obj.eom_config)
if not isinstance(optimal_detuning_off, Parametrized):
(
detuning_off,
switching_beams,
) = eom_config.calculate_detuning_off(
amp_on,
detuning_on,
optimal_detuning_off,
return_switching_beams=True,
)
off_pulse = Pulse.ConstantPulse(
channel_obj.min_duration, 0.0, detuning_off, 0.0
)
channel_obj.validate_pulse(off_pulse)
# Update optimal_detuning_off to match the chosen detuning_off
# This minimizes the changes to the sequence when the device
# is switched
stored_opt_detuning_off = detuning_off
return stored_opt_detuning_off, switching_beams

def _reset_parametrized(self) -> None:
"""Resets all attributes related to parametrization."""
# Signals the sequence as actively "building" ie not parametrized
Expand Down
63 changes: 63 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2408,6 +2408,69 @@ def test_eom_buffer(
)


@pytest.mark.parametrize("correct_phase_drift", [True, False])
@pytest.mark.parametrize("amp_diff", [0, -0.5, 0.5])
@pytest.mark.parametrize("det_diff", [0, -5, 10])
def test_modify_eom_setpoint(
reg, mod_device, amp_diff, det_diff, correct_phase_drift
):
seq = Sequence(reg, mod_device)
seq.declare_channel("ryd", "rydberg_global")
params = seq.declare_variable("params", dtype=float, size=2)
dt = 100
amp, det_on = params
with pytest.raises(
RuntimeError, match="The 'ryd' channel is not in EOM mode"
):
seq.modify_eom_setpoint("ryd", amp, det_on)
seq.enable_eom_mode("ryd", amp, det_on)
assert seq.is_in_eom_mode("ryd")
seq.add_eom_pulse("ryd", dt, 0.0)
seq.delay(dt, "ryd")

new_amp, new_det_on = amp + amp_diff, det_on + det_diff
seq.modify_eom_setpoint(
"ryd", new_amp, new_det_on, correct_phase_drift=correct_phase_drift
)
assert seq.is_in_eom_mode("ryd")
seq.add_eom_pulse("ryd", dt, 0.0)
seq.delay(dt, "ryd")

ryd_ch_obj = seq.declared_channels["ryd"]
eom_buffer_dt = ryd_ch_obj._eom_buffer_time
param_vals = [1.0, 0.0]
built_seq = seq.build(params=param_vals)
expected_duration = 4 * dt + eom_buffer_dt
assert built_seq.get_duration() == expected_duration

amp, det = param_vals
ch_samples = sample(built_seq).channel_samples["ryd"]
expected_amp = np.zeros(expected_duration)
expected_amp[:dt] = amp
expected_amp[-2 * dt : -dt] = amp + amp_diff
np.testing.assert_array_equal(expected_amp, ch_samples.amp)

det_off = ryd_ch_obj.eom_config.calculate_detuning_off(amp, det, 0.0)
new_det_off = ryd_ch_obj.eom_config.calculate_detuning_off(
amp + amp_diff, det + det_diff, 0.0
)
expected_det = np.zeros(expected_duration)
expected_det[:dt] = det
expected_det[dt : 2 * dt] = det_off
expected_det[2 * dt : 2 * dt + eom_buffer_dt] = new_det_off
expected_det[-2 * dt : -dt] = det + det_diff
expected_det[-dt:] = new_det_off
np.testing.assert_array_equal(expected_det, ch_samples.det)

final_phase = built_seq.current_phase_ref("q0", "ground-rydberg")
if not correct_phase_drift:
assert final_phase == 0.0
else:
assert final_phase != 0.0
np.testing.assert_array_equal(ch_samples.phase[: 2 * dt], 0.0)
np.testing.assert_array_equal(ch_samples.phase[-2 * dt :], final_phase)


def test_max_duration(reg, mod_device):
dev_ = dataclasses.replace(mod_device, max_sequence_duration=100)
seq = Sequence(reg, dev_)
Expand Down