diff --git a/src/autoqasm/types/types.py b/src/autoqasm/types/types.py index 3739740..1963049 100644 --- a/src/autoqasm/types/types.py +++ b/src/autoqasm/types/types.py @@ -96,7 +96,13 @@ def __init__( class ArrayVar(oqpy.ArrayVar): - def __init__(self, *args, annotations: str | Iterable[str] | None = None, **kwargs): + def __init__( + self, + init_expression: Iterable, + *args, + annotations: str | Iterable[str] | None = None, + **kwargs, + ): if ( program.get_program_conversion_context().subroutines_processing or not program.get_program_conversion_context().at_function_root_scope @@ -104,8 +110,17 @@ def __init__(self, *args, annotations: str | Iterable[str] | None = None, **kwar raise errors.InvalidArrayDeclaration( "Arrays may only be declared at the root scope of an AutoQASM main function." ) + + if not isinstance(init_expression, Iterable): + raise errors.InvalidArrayDeclaration("init_expression must be an iterable type.") + + dimensions = [len(init_expression)] super(ArrayVar, self).__init__( - *args, annotations=make_annotations_list(annotations), **kwargs + init_expression=init_expression, + *args, + annotations=make_annotations_list(annotations), + dimensions=dimensions, + **kwargs, ) self.name = program.get_program_conversion_context().next_var_name(oqpy.ArrayVar) diff --git a/test/unit_tests/autoqasm/test_types.py b/test/unit_tests/autoqasm/test_types.py index 5b63dc4..ec4ff93 100644 --- a/test/unit_tests/autoqasm/test_types.py +++ b/test/unit_tests/autoqasm/test_types.py @@ -186,9 +186,9 @@ def test_declare_array(): @aq.main def declare_array(): - a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3]) + a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar) a[0] = 11 - b = aq.ArrayVar([4, 5, 6], base_type=aq.IntVar, dimensions=[3]) + b = aq.ArrayVar([4, 5, 6], base_type=aq.IntVar) b[2] = 14 b = a @@ -207,8 +207,8 @@ def test_invalid_array_assignment(): @aq.main def invalid(): - a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3]) - b = aq.ArrayVar([4, 5], base_type=aq.IntVar, dimensions=[2]) + a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar) + b = aq.ArrayVar([4, 5], base_type=aq.IntVar) a = b # noqa: F841 with pytest.raises(aq.errors.InvalidAssignmentStatement): @@ -221,7 +221,7 @@ def test_declare_array_in_local_scope(): @aq.main def declare_array(): if aq.BoolVar(True): - _ = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3]) + _ = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar) with pytest.raises(aq.errors.InvalidArrayDeclaration): declare_array.build() @@ -236,7 +236,7 @@ def main() -> list[int]: @aq.subroutine def declare_array(): - _ = aq.ArrayVar([1, 2, 3], dimensions=[3]) + _ = aq.ArrayVar([1, 2, 3]) with pytest.raises(aq.errors.InvalidArrayDeclaration): main.build() @@ -383,7 +383,7 @@ def annotation_test(input: list[int]): @aq.main def main(): - a = aq.ArrayVar([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dimensions=[10]) + a = aq.ArrayVar([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) annotation_test(a) with pytest.raises(aq.errors.ParameterTypeError): @@ -724,3 +724,41 @@ def main(): with pytest.raises(aq.errors.ParameterTypeError): main.build() + + +def test_array_does_not_accept_dimensions_argument(): + @aq.main + def declare_array(): + aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3]) + + with pytest.raises(TypeError): + declare_array.build() + + +def test_array_requires_init_expression(): + @aq.main + def declare_array(): + aq.ArrayVar() + + with pytest.raises(TypeError): + declare_array.build() + + +def test_array_init_expression_type(): + @aq.main + def declare_array(): + aq.ArrayVar(1) + + with pytest.raises(aq.errors.InvalidArrayDeclaration): + declare_array.build() + + +def test_array_supports_multidimensional_arrays(): + @aq.main + def declare_array(): + aq.ArrayVar([[1, 2], [3, 4]]) + + expected = """OPENQASM 3.0; +array[int[32], 2, 2] a = {{1, 2}, {3, 4}};""" + + declare_array.build().to_ir() == expected