Skip to content

Commit

Permalink
Deprecate usage of int in QubitIds (#774)
Browse files Browse the repository at this point in the history
* Deprecate usage of int

* Fix tutorials

* Fix targets

* Update warning message

* Fix tutorials

* Revert "Fix tutorials"

This reverts commit a99d615.

* Fix tutorial

* Fix typing with qubitID=str

* Fix tests

* Fix typing

* Fix tests

* Fix built in values for register generation

* Fix tests

* Address review comments
  • Loading branch information
a-corni authored Dec 11, 2024
1 parent 9a32a2b commit 4e73021
Show file tree
Hide file tree
Showing 23 changed files with 1,523 additions and 1,468 deletions.
6 changes: 3 additions & 3 deletions docs/source/intro_rydberg_blockade.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"outputs": [],
"source": [
"layers = 3\n",
"reg = pulser.Register.hexagon(layers)\n",
"reg = pulser.Register.hexagon(layers, prefix=\"q\")\n",
"reg.draw(with_labels=False)"
]
},
Expand Down Expand Up @@ -221,7 +221,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "pulserenv",
"language": "python",
"name": "python3"
},
Expand All @@ -235,7 +235,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions pulser-core/pulser/devices/_device_datacls.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def _validate_coords(
),
kind: Literal["atoms", "traps"] = "atoms",
) -> None:
ids = list(coords_dict.keys())
ids = [str(id) for id in list(coords_dict.keys())]
coords = list(map(pm.AbstractArray, coords_dict.values()))
if kind == "atoms" and not (
"max_atom_num" in self._optional_parameters
Expand Down Expand Up @@ -763,7 +763,7 @@ def _specs(self, for_docs: bool = False) -> str:
(
"\t"
+ r"- Maximum :math:`\Omega`:"
+ f" {float(cast(float,ch.max_amp)):.4g} rad/µs"
+ f" {float(cast(float, ch.max_amp)):.4g} rad/µs"
),
(
(
Expand All @@ -776,7 +776,7 @@ def _specs(self, for_docs: bool = False) -> str:
else (
"\t"
+ r"- Bottom :math:`|\delta|`:"
+ f" {float(cast(float,ch.bottom_detuning)):.4g}"
+ f" {float(cast(float, ch.bottom_detuning)):.4g}"
+ " rad/µs"
)
),
Expand Down
4 changes: 2 additions & 2 deletions pulser-core/pulser/register/_reg_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def _draw_2D(

if dmm_qubits:
dmm_pos = []
for i, c in zip(ids, pos):
if i in dmm_qubits.keys():
for id, c in zip(ids, pos):
if id in dmm_qubits.keys():
dmm_pos.append(c)
dmm_arr = np.array(dmm_pos)
max_weight = max(dmm_qubits.values())
Expand Down
13 changes: 12 additions & 1 deletion pulser-core/pulser/register/base_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from __future__ import annotations

import json
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from collections.abc import Sequence as abcSequence
Expand Down Expand Up @@ -44,7 +45,7 @@
from pulser.register.register_layout import RegisterLayout

T = TypeVar("T", bound="BaseRegister")
QubitId = Union[int, str]
QubitId = str


class _LayoutInfo(NamedTuple):
Expand Down Expand Up @@ -77,6 +78,16 @@ def __init__(
[pm.AbstractArray(v, dtype=float) for v in qubits.values()]
)
self._ids: tuple[QubitId, ...] = tuple(qubits.keys())
if any(not isinstance(id, str) for id in self._ids):
warnings.simplefilter("always")
warnings.warn(
"Usage of `int`s or any non-`str`types as `QubitId`s will be "
"deprecated. Define your `QubitId`s as `str`s, prefer setting "
"`prefix='q'` when using classmethods, as that will become the"
" new default once `int` qubit IDs become invalid.",
DeprecationWarning,
stacklevel=2,
)
self._layout_info: Optional[_LayoutInfo] = None
self._init_kwargs(**kwargs)

Expand Down
2 changes: 1 addition & 1 deletion pulser-core/pulser/register/register_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def draw(
draw_graph=draw_graph,
draw_half_radius=draw_half_radius,
)
ids = list(range(self.number_of_traps))
ids = [str(i) for i in range(self.number_of_traps)]
if self.dimensionality == 2:
fig, ax = self._initialize_fig_axes(
coords,
Expand Down
4 changes: 3 additions & 1 deletion pulser-core/pulser/register/weight_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def draw(
pos = self.trap_coordinates
custom_ax = custom_ax or cast(Axes, self._initialize_fig_axes(pos)[1])

labels_ = labels if labels is not None else list(range(len(pos)))
labels_ = (
labels if labels is not None else [str(i) for i in range(len(pos))]
)

super()._draw_2D(
custom_ax,
Expand Down
8 changes: 4 additions & 4 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ def target_index(
Args:
qubits: The new target for this channel. Must correspond to a
qubit index or an collection of qubit indices, when multi-qubit
qubit index or a collection of qubit indices, when multi-qubit
addressing is possible.
A qubit index is a number between 0 and the number of qubits.
It is then converted to a Qubit ID using the order in which
Expand Down Expand Up @@ -2057,7 +2057,7 @@ def _add(
@seq_decorators.block_if_measured
def _target(
self,
qubits: Union[Collection[QubitId], QubitId, Parametrized],
qubits: Union[Collection[QubitId | int], QubitId | int, Parametrized],
channel: str,
_index: bool = False,
) -> None:
Expand Down Expand Up @@ -2105,7 +2105,7 @@ def _target(
self._schedule.add_target(qubit_ids_set, channel)

def _check_qubits_give_ids(
self, *qubits: Union[QubitId, Parametrized], _index: bool = False
self, *qubits: Union[QubitId, int, Parametrized], _index: bool = False
) -> set[QubitId]:
if _index:
if self.is_parametrized():
Expand Down Expand Up @@ -2158,7 +2158,7 @@ def _delay(
def _phase_shift(
self,
phi: float | Parametrized,
*targets: QubitId | Parametrized,
*targets: QubitId | int | Parametrized,
basis: str,
_index: bool = False,
) -> None:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ filterwarnings = [
"error",
# Except these particular warnings, which are ignored
'ignore:A duration of \d+ ns is not a multiple of:UserWarning',
'ignore:Usage of `int`s or any non-`str`types as `QubitId`s:DeprecationWarning',
]

[build-system]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,9 @@ def test_exceptions(self, sequence):
UserWarning, match="converts all qubit ID's to strings"
), pytest.raises(
AbstractReprError, match="Name collisions encountered"
), pytest.warns(
DeprecationWarning,
match="Usage of `int`s or any non-`str`types as `QubitId`s",
):
Register({"0": (0, 0), 0: (20, 20)})._to_abstract_repr()

Expand Down
40 changes: 32 additions & 8 deletions tests/test_dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

import re
from typing import cast
from typing import Union, cast
from unittest.mock import patch

import numpy as np
Expand All @@ -37,7 +37,9 @@ def layout(self) -> RegisterLayout:

@pytest.fixture
def register(self, layout: RegisterLayout) -> BaseRegister:
return layout.define_register(0, 1, 2, 3, qubit_ids=(0, 1, 2, 3))
return layout.define_register(
0, 1, 2, 3, qubit_ids=("0", "1", "2", "3")
)

@pytest.fixture
def map_reg(self, layout: RegisterLayout) -> MappableRegister:
Expand All @@ -63,7 +65,7 @@ def slm_map(
) -> DetuningMap:
return layout.define_detuning_map(slm_dict)

@pytest.mark.parametrize("bad_key", [{"1": 1.0}, {4: 1.0}])
@pytest.mark.parametrize("bad_key", [{1: 1.0}, {"4": 1.0}])
def test_define_detuning_map(
self,
layout: RegisterLayout,
Expand All @@ -72,6 +74,13 @@ def test_define_detuning_map(
bad_key: dict,
):
for reg in (layout, map_reg):
if type(list(bad_key.keys())[0]) == int:
with pytest.raises(
ValueError,
match="'trap_coordinates' must be an array or list",
):
reg.define_detuning_map(bad_key) # type: ignore
continue
with pytest.raises(
ValueError,
match=re.escape(
Expand All @@ -91,7 +100,7 @@ def test_define_detuning_map(

def test_qubit_weight_map(self, register):
# Purposefully unsorted
qid_weight_map = {1: 1.0, 0: 0.1, 3: 0.4}
qid_weight_map = {"1": 1.0, "0": 0.1, "3": 0.4}
sorted_qids = sorted(qid_weight_map)
det_map = register.define_detuning_map(qid_weight_map)
qubits = register.qubits
Expand All @@ -104,7 +113,7 @@ def test_qubit_weight_map(self, register):
# We recover the original qid_weight_map (and undefined qids show as 0)
assert det_map.get_qubit_weight_map(qubits) == {
**qid_weight_map,
2: 0.0,
"2": 0.0,
}

tri_layout = TriangularLatticeLayout(100, spacing=5)
Expand Down Expand Up @@ -172,8 +181,12 @@ def test_detuning_map_bad_init(
):
DetuningMap([(0, 0), (1, 0)], [0])

bad_weights = {0: -1.0, 1: 1.0, 2: 1.0}
for reg in (layout, map_reg, register):
bad_weights: dict[int | str, float]
if reg == register:
bad_weights = {"0": -1.0, "1": 1.0, "2": 1.0}
else:
bad_weights = {0: -1.0, 1: 1.0, 2: 1.0}
with pytest.raises(
ValueError, match="All weights must be between 0 and 1."
):
Expand All @@ -187,11 +200,22 @@ def test_init(
det_dict: dict[int, float],
slm_dict: dict[int, float],
):

for reg in (layout, map_reg, register):
for detuning_map_dict in (det_dict, slm_dict):
reg_det_map_dict: dict[int | str, float]
if reg == register:
reg_det_map_dict = {
str(id): weight
for (id, weight) in detuning_map_dict.items()
}
else:
reg_det_map_dict = cast(
dict[Union[int, str], float], detuning_map_dict
)
detuning_map = cast(
DetuningMap,
reg.define_detuning_map(detuning_map_dict), # type: ignore
reg.define_detuning_map(reg_det_map_dict), # type: ignore
)
assert np.all(
[
Expand Down Expand Up @@ -227,7 +251,7 @@ def test_draw(self, det_map, slm_map, patch_plt_show, with_labels):
)[1],
np.array(slm_map.trap_coordinates),
[
i
str(i)
for i, _ in enumerate(cast(list, slm_map.trap_coordinates))
],
with_labels=True,
Expand Down
19 changes: 14 additions & 5 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,24 @@ def test_detuning_map():


@pytest.mark.parametrize(
"reg",
"reg_dict",
[
Register(dict(enumerate([(2, 3), (5, 1), (10, 0)]))),
Register3D({3: (2, 3, 4), 4: (3, 4, 5), 2: (4, 5, 7)}),
dict(enumerate([(2, 3), (5, 1), (10, 0)])),
{3: (2, 3, 4), 4: (3, 4, 5), 2: (4, 5, 7)},
],
)
def test_register_numbered_keys(reg):
def test_register_numbered_keys(reg_dict):
with pytest.warns(
DeprecationWarning,
match="Usage of `int`s or any non-`str`types as `QubitId`s",
):
reg = (Register if len(reg_dict[2]) == 2 else Register3D)(reg_dict)
j = json.dumps(reg, cls=PulserEncoder)
decoded_reg = json.loads(j, cls=PulserDecoder)
with pytest.warns(
DeprecationWarning,
match="Usage of `int`s or any non-`str`types as `QubitId`s",
):
decoded_reg = json.loads(j, cls=PulserDecoder)
assert reg == decoded_reg
assert all([type(i) is int for i in decoded_reg.qubit_ids])

Expand Down
18 changes: 10 additions & 8 deletions tests/test_paramseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pulser.parametrized.variable import VariableItem
from pulser.waveforms import BlackmanWaveform

reg = Register.rectangle(4, 3)
reg = Register.rectangle(4, 3, prefix="q")
device = DigitalAnalogDevice


Expand All @@ -50,10 +50,10 @@ def test_parametrized_channel_initial_target():
var = sb.declare_variable("var")
sb.declare_channel("ch1", "rydberg_local")
sb.target_index(var, "ch1")
sb.declare_channel("ch0", "raman_local", initial_target=0)
sb.declare_channel("ch0", "raman_local", initial_target="q0")
assert sb._calls[-1].name == "declare_channel"
assert sb._to_build_calls[-1].name == "target"
assert sb._to_build_calls[-1].args == (0, "ch0")
assert sb._to_build_calls[-1].args == ("q0", "ch0")


def test_stored_calls():
Expand Down Expand Up @@ -125,7 +125,9 @@ def test_stored_calls():
sb.target_index(q_var, "ch1")

sb2 = Sequence(reg, MockDevice)
sb2.declare_channel("ch1", "rydberg_local", initial_target={3, 4, 5})
sb2.declare_channel(
"ch1", "rydberg_local", initial_target={"q3", "q4", "q5"}
)
q_var2 = sb2.declare_variable("q_var2", size=5, dtype=int)
var2 = sb2.declare_variable("var2")
assert sb2._building
Expand Down Expand Up @@ -229,7 +231,7 @@ def test_str():
def test_screen():
sb = Sequence(reg, device)
sb.declare_channel("ch1", "rydberg_global")
assert sb.current_phase_ref(4, basis="ground-rydberg") == 0
assert sb.current_phase_ref("q4", basis="ground-rydberg") == 0
var = sb.declare_variable("var")
sb.delay(var, "ch1")
with pytest.raises(RuntimeError, match="can't be called in parametrized"):
Expand All @@ -239,7 +241,7 @@ def test_screen():
def test_parametrized_in_eom_mode(mod_device):
# Case 1: Sequence becomes parametrized while in EOM mode
seq = Sequence(reg, mod_device)
seq.declare_channel("ch0", "rydberg_local", initial_target=0)
seq.declare_channel("ch0", "rydberg_local", initial_target="q0")

assert not seq.is_in_eom_mode("ch0")
seq.enable_eom_mode("ch0", amp_on=2.0, detuning_on=0.0)
Expand Down Expand Up @@ -278,8 +280,8 @@ def test_parametrized_before_eom_mode(mod_device):
# Case 2: Sequence is parametrized before entering EOM mode
seq = Sequence(reg, mod_device)

seq.declare_channel("ch0", "rydberg_local", initial_target=0)
seq.declare_channel("raman", "raman_local", initial_target=2)
seq.declare_channel("ch0", "rydberg_local", initial_target="q0")
seq.declare_channel("raman", "raman_local", initial_target="q2")
amp = seq.declare_variable("amp", dtype=float)
seq.add(Pulse.ConstantPulse(200, amp, -1, 0), "ch0")

Expand Down
7 changes: 5 additions & 2 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,8 +1614,11 @@ def test_str(reg, device, mod_device, det_map):

measure_msg = "\n\nMeasured in basis: digital"
assert seq.__str__() == msg_ch0 + msg_ch1 + msg_det_map + measure_msg

seq2 = Sequence(Register({"q0": (0, 0), 1: (5, 5)}), device)
with pytest.warns(
DeprecationWarning,
match="Usage of `int`s or any non-`str`types as `QubitId`s",
):
seq2 = Sequence(Register({"q0": (0, 0), 1: (5, 5)}), device)
seq2.declare_channel("ch1", "rydberg_global")
with pytest.raises(
NotImplementedError,
Expand Down
Loading

0 comments on commit 4e73021

Please sign in to comment.