From 38395a78ff0326bf5ffe122050e4598e7ef469a6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <67932820+kshyatt-aws@users.noreply.github.com> Date: Fri, 5 Jul 2024 20:39:33 -0400 Subject: [PATCH] fix: StateVector shouldn't be a supported pragma for DM simulator (#25) Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com> --- src/braket/simulator_v2/base_simulator_v2.py | 7 ++- .../density_matrix_simulator_v2.py | 1 - src/braket/simulator_v2/julia_import.py | 50 ++++++++----------- .../test_density_matrix_simulator_v2.py | 1 - .../test_state_vector_simulator_v2.py | 49 +++++++++++------- 5 files changed, 55 insertions(+), 53 deletions(-) diff --git a/src/braket/simulator_v2/base_simulator_v2.py b/src/braket/simulator_v2/base_simulator_v2.py index 63bf16c..213f229 100644 --- a/src/braket/simulator_v2/base_simulator_v2.py +++ b/src/braket/simulator_v2/base_simulator_v2.py @@ -3,7 +3,6 @@ from collections.abc import Sequence from typing import Any, Optional, Union -import juliacall import numpy as np from braket.default_simulator.result_types import TargetedResultType from braket.default_simulator.simulator import BaseLocalSimulator @@ -100,7 +99,7 @@ def run_jaqcd( translated_ir, qubit_count = self._jaqcd_to_jl(circuit_ir, shots) try: r = jl.simulate(self._device, translated_ir, qubit_count, shots) - except juliacall.JuliaError as e: + except JuliaError as e: _handle_julia_error(e) r.additionalMetadata.action = circuit_ir r = _result_value_to_ndarray(r) @@ -158,7 +157,7 @@ def run_openqasm( """ try: r = jl.simulate(self._device, self._openqasm_to_jl(openqasm_ir), shots) - except juliacall.JuliaError as e: + except JuliaError as e: _handle_julia_error(e) r.additionalMetadata.action = openqasm_ir # attach the result types @@ -209,7 +208,7 @@ def run_multiple( shots=shots, inputs=inputs, ) - except juliacall.JuliaError as e: + except JuliaError as e: _handle_julia_error(e) for r_ix, result in enumerate(results): diff --git a/src/braket/simulator_v2/density_matrix_simulator_v2.py b/src/braket/simulator_v2/density_matrix_simulator_v2.py index 9204b62..027aefb 100644 --- a/src/braket/simulator_v2/density_matrix_simulator_v2.py +++ b/src/braket/simulator_v2/density_matrix_simulator_v2.py @@ -115,7 +115,6 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities: ], "supportedPragmas": [ "braket_unitary_matrix", - "braket_result_type_state_vector", "braket_result_type_density_matrix", "braket_result_type_sample", "braket_result_type_expectation", diff --git a/src/braket/simulator_v2/julia_import.py b/src/braket/simulator_v2/julia_import.py index 4c058f0..94d5c6c 100644 --- a/src/braket/simulator_v2/julia_import.py +++ b/src/braket/simulator_v2/julia_import.py @@ -1,42 +1,34 @@ import os -import sys import warnings +import juliacall + # Check if JuliaCall is already loaded, and if so, warn the user # about the relevant environment variables. If not loaded, # set up sensible defaults. -if "juliacall" in sys.modules: +# Required to avoid segfaults (https://juliapy.github.io/PythonCall.jl/dev/faq/) +if os.environ.get("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes") != "yes": warnings.warn( - "`juliacall` module has already been imported. " - + "Make sure that you have set the environment variable " - + "`PYTHON_JULIACALL_HANDLE_SIGNALS=yes` to avoid segfaults. " + "`PYTHON_JULIACALL_HANDLE_SIGNALS` environment variable " + + "is set to something other than 'yes' or ''. " + + "You will experience segfaults if running with Julia multithreading." ) -else: - # Required to avoid segfaults (https://juliapy.github.io/PythonCall.jl/dev/faq/) - if os.environ.get("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes") != "yes": - warnings.warn( - "`PYTHON_JULIACALL_HANDLE_SIGNALS` environment variable " - + "is set to something other than 'yes' or ''. " - + "You will experience segfaults if running with Julia multithreading." - ) - - if os.environ.get("PYTHON_JULIACALL_THREADS", "auto") != "auto": - warnings.warn( - "`PYTHON_JULIACALL_THREADS` environment variable is set to " - + "something other than `auto`, so `amazon-braket-simulator-v2` " - + "was not able to set it." - ) - for k, default in ( - ("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"), - ("PYTHON_JULIACALL_THREADS", "auto"), - ("PYTHON_JULIACALL_OPTLEVEL", "3"), - # let the user's Conda/Pip handle installing things - ("JULIA_CONDAPKG_BACKEND", "Null"), - ): - os.environ[k] = os.environ.get(k, default) +if os.environ.get("PYTHON_JULIACALL_THREADS", "auto") != "auto": + warnings.warn( + "`PYTHON_JULIACALL_THREADS` environment variable is set to " + + "something other than `auto`, so `amazon-braket-simulator-v2` " + + "was not able to set it." + ) -import juliacall +for k, default in ( + ("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"), + ("PYTHON_JULIACALL_THREADS", "auto"), + ("PYTHON_JULIACALL_OPTLEVEL", "3"), + # let the user's Conda/Pip handle installing things + ("JULIA_CONDAPKG_BACKEND", "Null"), +): + os.environ[k] = os.environ.get(k, default) jl = juliacall.Base.Module() diff --git a/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator_v2.py b/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator_v2.py index 5d4e1a2..f823965 100644 --- a/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator_v2.py +++ b/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator_v2.py @@ -267,7 +267,6 @@ def test_properties(): ], "supportedPragmas": [ "braket_unitary_matrix", - "braket_result_type_state_vector", "braket_result_type_density_matrix", "braket_result_type_sample", "braket_result_type_expectation", diff --git a/test/unit_tests/braket/simulator_v2/test_state_vector_simulator_v2.py b/test/unit_tests/braket/simulator_v2/test_state_vector_simulator_v2.py index bcb8c0e..6afd193 100644 --- a/test/unit_tests/braket/simulator_v2/test_state_vector_simulator_v2.py +++ b/test/unit_tests/braket/simulator_v2/test_state_vector_simulator_v2.py @@ -85,7 +85,6 @@ def test_simulator_run_grcs_16(grcs_16_qubit, batch_size): if isinstance(grcs_16_qubit.circuit_ir, JaqcdProgram): result = simulator.run( grcs_16_qubit.circuit_ir, - qubit_count=16, shots=0, batch_size=batch_size, ) @@ -102,9 +101,7 @@ def test_simulator_run_bell_pair(bell_ir, batch_size, caplog): simulator = StateVectorSimulator() shots_count = 10000 if isinstance(bell_ir, JaqcdProgram): - result = simulator.run( - bell_ir, qubit_count=2, shots=shots_count, batch_size=batch_size - ) + result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size) else: result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size) @@ -729,7 +726,6 @@ def test_simulator_identity(caplog): if isinstance(program, JaqcdProgram): result = simulator.run( program, - qubit_count=2, shots=shots_count, ) else: @@ -756,7 +752,7 @@ def test_simulator_instructions_not_supported(circuit_noise): ) with pytest.raises(TypeError, match=no_noise): if isinstance(circuit_noise, JaqcdProgram): - simulator.run(circuit_noise, qubit_count=2, shots=0) + simulator.run(circuit_noise, shots=0) else: simulator.run(circuit_noise, shots=0) @@ -765,7 +761,7 @@ def test_simulator_run_no_results_no_shots(bell_ir): simulator = StateVectorSimulator() with pytest.raises(ValueError): if isinstance(bell_ir, JaqcdProgram): - simulator.run(bell_ir, qubit_count=2, shots=0) + simulator.run(bell_ir, shots=0) else: simulator.run(bell_ir, shots=0) @@ -788,7 +784,7 @@ def test_simulator_run_amplitude_shots(): """ ) with pytest.raises(ValueError): - simulator.run(jaqcd, qubit_count=2, shots=100) + simulator.run(jaqcd, shots=100) with pytest.raises(ValueError): simulator.run(qasm, shots=100) @@ -838,7 +834,7 @@ def test_simulator_run_statevector_shots(): """ ) with pytest.raises(ValueError): - simulator.run(jaqcd, qubit_count=2, shots=100) + simulator.run(jaqcd, shots=100) with pytest.raises(ValueError): simulator.run(qasm, shots=100) @@ -871,7 +867,7 @@ def test_simulator_run_result_types_shots(caplog): """ ) shots_count = 100 - jaqcd_result = simulator.run(jaqcd, qubit_count=2, shots=shots_count) + jaqcd_result = simulator.run(jaqcd, shots=shots_count) qasm_result = simulator.run(qasm, shots=shots_count) for result in jaqcd_result, qasm_result: assert all([len(measurement) == 2] for measurement in result.measurements) @@ -911,7 +907,7 @@ def test_simulator_run_result_types_shots_basis_rotation_gates(caplog): """ ) shots_count = 1000 - jaqcd_result = simulator.run(jaqcd, qubit_count=2, shots=shots_count) + jaqcd_result = simulator.run(jaqcd, shots=shots_count) qasm_result = simulator.run(qasm, shots=shots_count) for result in jaqcd_result, qasm_result: assert all([len(measurement) == 2] for measurement in result.measurements) @@ -941,7 +937,7 @@ def test_simulator_run_result_types_shots_basis_rotation_gates_value_error(): ) ) shots_count = 1000 - simulator.run(ir, qubit_count=2, shots=shots_count) + simulator.run(ir, shots=shots_count) @pytest.mark.parametrize( @@ -1031,7 +1027,7 @@ def test_simulator_run_observable_references_invalid_qubit(ir, qubit_count): shots_count = 0 if isinstance(ir, JaqcdProgram): with pytest.raises(ValueError): - simulator.run(ir, qubit_count=qubit_count, shots=shots_count) + simulator.run(ir, shots=shots_count) else: # index error since you're indexing from a logical qubit with pytest.raises(IndexError): @@ -1046,7 +1042,7 @@ def test_simulator_bell_pair_result_types( simulator = StateVectorSimulator() ir = bell_ir_with_result(targets) if isinstance(ir, JaqcdProgram): - result = simulator.run(ir, qubit_count=2, shots=0, batch_size=batch_size) + result = simulator.run(ir, shots=0, batch_size=batch_size) else: result = simulator.run(ir, shots=0, batch_size=batch_size) assert len(result.resultTypes) == 2 @@ -1082,7 +1078,7 @@ def test_simulator_fails_samples_0_shots(): """ ) with pytest.raises(ValueError): - simulator.run(jaqcd, qubit_count=1, shots=0) + simulator.run(jaqcd, shots=0) with pytest.raises(ValueError): simulator.run(qasm, shots=0) @@ -1161,7 +1157,7 @@ def test_simulator_valid_observables(result_types, expected): } ) ) - result = simulator.run(prog, qubit_count=2, shots=0) + result = simulator.run(prog, shots=0) for i in range(len(result_types)): assert np.allclose(result.resultTypes[i].value, expected[i]) @@ -1482,7 +1478,7 @@ def test_simulator_analytic_value_type(jaqcd_string, oq3_pragma, jaqcd_type): #pragma braket result {oq3_pragma} """ ) - result = simulator.run(jaqcd, qubit_count=2, shots=0) + result = simulator.run(jaqcd, shots=0) assert result.resultTypes[0].type == jaqcd_type assert isinstance(result.resultTypes[0].value, np.ndarray) result = simulator.run(qasm, shots=0) @@ -1568,12 +1564,29 @@ def test_noncontiguous_qubits_jaqcd_multiple_targets(): "results": [{"type": "expectation", "observable": ["z"], "targets": [4]}], } prg = JaqcdProgram.parse_raw(json.dumps(jaqcd_program)) - result = StateVectorSimulator().run(prg, qubit_count=2, shots=0) + result = StateVectorSimulator().run(prg, shots=0) assert result.measuredQubits == [0, 1] assert result.resultTypes[0].value == -1 +def test_run_multiple_single_circuit(): + payload = [ + OpenQASMProgram( + source=""" + OPENQASM 3.0; + bit[1] b; + qubit[1] q; + h q[0]; + #pragma braket result state_vector + """ + ) + ] + simulator = StateVectorSimulator() + results = simulator.run_multiple(payload, shots=0) + assert np.allclose(results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2)) + + def test_run_multiple(): payloads = [ OpenQASMProgram(