From 2d8e3691fd7b217502c1d85ff906b8b625a24ee8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <67932820+kshyatt-aws@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:17:19 +0100 Subject: [PATCH] Bugfix and more cleanup (#21) * More visitor cleanup * Better array_literal handling * Fix power mod issue * Another test --- src/qasm_expression.jl | 2 + src/visitor.jl | 157 ++++++++++++++++++++--------------------- test/runtests.jl | 12 ++++ 3 files changed, 91 insertions(+), 80 deletions(-) diff --git a/src/qasm_expression.jl b/src/qasm_expression.jl index 91beb51..22c6ea0 100644 --- a/src/qasm_expression.jl +++ b/src/qasm_expression.jl @@ -19,6 +19,8 @@ Base.copy(qasm_expr::QasmExpression) = QasmExpression(qasm_expr.head, deepcopy(q head(qasm_expr::QasmExpression) = qasm_expr.head +Base.convert(::Type{Vector{QasmExpression}}, expr::QasmExpression) = head(expr) == :array_literal ? convert(Vector{QasmExpression}, expr.args) : [expr] + AbstractTrees.children(qasm_expr::QasmExpression) = qasm_expr.args AbstractTrees.printnode(io::IO, qasm_expr::QasmExpression) = print(io, "QasmExpression :$(qasm_expr.head)") diff --git a/src/visitor.jl b/src/visitor.jl index 701df17..aa2c7fa 100644 --- a/src/visitor.jl +++ b/src/visitor.jl @@ -181,14 +181,9 @@ mutable struct QasmFunctionVisitor <: AbstractVisitor return v end end -function QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::Vector{QasmExpression}, provided_arguments::QasmExpression) - head(provided_arguments) == :array_literal && return QasmFunctionVisitor(parent, declared_arguments, convert(Vector{QasmExpression}, provided_arguments.args)) - QasmFunctionVisitor(parent, declared_arguments, [provided_arguments]) -end -function QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::QasmExpression, provided_arguments) - head(declared_arguments) == :array_literal && return QasmFunctionVisitor(parent, convert(Vector{QasmExpression}, declared_arguments.args), provided_arguments) - QasmFunctionVisitor(parent, [declared_arguments], provided_arguments) -end +QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::Vector{QasmExpression}, provided_arguments::QasmExpression) = QasmFunctionVisitor(parent, declared_arguments, convert(Vector{QasmExpression}, provided_arguments)) +QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::QasmExpression, provided_arguments) = QasmFunctionVisitor(parent, convert(Vector{QasmExpression}, declared_arguments), provided_arguments) + Base.parent(v::AbstractVisitor) = v.parent hasgate(v::AbstractVisitor, gate_name::String) = hasgate(parent(v), gate_name) @@ -281,9 +276,9 @@ function evaluate_modifiers(v::V, expr::QasmExpression) where {V<:AbstractVisito arg_val::Int = v(first(expr.args)::QasmExpression)::Int isinteger(arg_val) || throw(QasmVisitorError("cannot apply non-integer ($arg_val) number of controls or negcontrols.")) true_inner = expr.args[2]::QasmExpression - inner = QasmExpression(head(expr), true_inner) + inner = QasmExpression(head(expr), true_inner) while arg_val > 2 - inner = QasmExpression(head(expr), inner) + inner = QasmExpression(head(expr), inner) arg_val -= 1 end else @@ -346,6 +341,7 @@ end evaluate_qubits(v::AbstractVisitor, qubit_targets::QasmExpression) = evaluate_qubits(v::AbstractVisitor, [qubit_targets]) function remap(ix, target_mapper::Dict{Int, Int}) + isempty(target_mapper) && return ix mapped_targets = map(t->getindex(target_mapper, t), ix.targets) mapped_controls = map(c->getindex(target_mapper, c[1])=>c[2], ix.controls) return (type=ix.type, arguments=ix.arguments, targets=mapped_targets, controls=mapped_controls, exponent=ix.exponent) @@ -359,57 +355,64 @@ function process_gate_arguments(v::AbstractVisitor, gate_name::String, defined_a def_has_arguments = !isempty(defined_arguments) call_has_arguments = !isempty(v(called_arguments)) if def_has_arguments ⊻ call_has_arguments - def_has_arguments && throw(QasmVisitorError("gate $gate_name requires arguments but none were provided.")) + def_has_arguments && throw(QasmVisitorError("gate $gate_name requires arguments but none were provided.")) call_has_arguments && throw(QasmVisitorError("gate $gate_name does not accept arguments but arguments were provided.")) end - if def_has_arguments - evaled_args = v(called_arguments) - argument_values = Dict{Symbol, Real}(Symbol(arg_name)=>argument for (arg_name, argument) in zip(defined_arguments, evaled_args)) - return map(ix->bind_arguments!(ix, argument_values), gate_body) - else - return deepcopy(gate_body) - end + !def_has_arguments && return deepcopy(gate_body) # deep copy to avoid overwriting canonical definition + + evaled_args = v(called_arguments) + argument_values = Dict{Symbol, Real}(Symbol(arg_name)=>argument for (arg_name, argument) in zip(defined_arguments, evaled_args)) + return map(ix->bind_arguments!(ix, argument_values), gate_body) end function handle_gate_modifiers(ixs, mods::Vector{QasmExpression}, control_qubits::Vector{Int}, is_gphase::Bool) for mod in Iterators.reverse(mods) - control_qubit = head(mod) ∈ (:negctrl, :ctrl) ? pop!(control_qubits) : -1 - for (ii, ix) in enumerate(ixs) - if head(mod) == :pow - ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=ix.exponent*mod.args[1]) - elseif head(mod) == :inv - ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=-ix.exponent) - # need to handle "extra" target - elseif head(mod) ∈ (:negctrl, :ctrl) + if head(mod) ∈ (:negctrl, :ctrl) + control_qubit = pop!(control_qubits) + for (ii, ix) in enumerate(ixs) + exp = ix.exponent + targets = ix.targets + controls = ix.controls bit = head(mod) == :ctrl ? 1 : 0 - if is_gphase - ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=pushfirst!(ix.controls, control_qubit=>bit), exponent=ix.exponent) - else - ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=pushfirst!(ix.targets, control_qubit), controls=pushfirst!(ix.controls, control_qubit=>bit), exponent=ix.exponent) + controls = pushfirst!(controls, control_qubit=>bit) + if !is_gphase + targets = pushfirst!(targets, control_qubit) end + ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=targets, controls=controls, exponent=exp) + end + elseif head(mod) == :inv + reverse!(ixs) + for (ii, ix) in enumerate(ixs) + ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=-ix.exponent) + end + elseif head(mod) == :pow + pow_exp = mod.args[1] + (isinteger(pow_exp) || length(ixs) == 1) || throw(QasmVisitorError("can't apply a non-integer exponent to a gate of multiple instructions")) # can't do 2.5 for a list... yet + if length(ixs) > 1 + pow_exp < 0 && reverse!(ixs) + ixs = repeat(ixs, abs(pow_exp)) + else + ixs[1] = (type=ixs[1].type, arguments=ixs[1].arguments, targets=ixs[1].targets, controls=ixs[1].controls, exponent=ixs[1].exponent*pow_exp) end end - head(mod) == :inv && reverse!(ixs) end return ixs end function splat_gate_targets(gate_targets::Vector{Vector{Int}}) target_lengths::Vector{Int} = Int[length(t) for t in gate_targets] - longest = maximum(target_lengths) + longest = maximum(target_lengths) must_splat::Bool = any(len->len!=1 || len != longest, target_lengths) !must_splat && return longest, gate_targets - for target_ix in 1:length(gate_targets) - if target_lengths[target_ix] == 1 - append!(gate_targets[target_ix], fill(only(gate_targets[target_ix]), longest-1)) - end + for target_ix in filter(ix->target_lengths[ix] == 1, 1:length(gate_targets)) + append!(gate_targets[target_ix], fill(only(gate_targets[target_ix]), longest-1)) end return longest, gate_targets end function visit_gphase_call(v::AbstractVisitor, program_expr::QasmExpression) - has_modifiers = length(program_expr.args) == 4 - n_called_with::Int = qubit_count(v) + has_modifiers = length(program_expr.args) == 4 + n_called_with::Int = qubit_count(v) gate_targets::Vector{Int} = collect(0:n_called_with-1) provided_arg::QasmExpression = only(program_expr.args[2].args) evaled_arg = v(provided_arg) @@ -421,17 +424,9 @@ function visit_gphase_call(v::AbstractVisitor, program_expr::QasmExpression) return end -function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression) - gate_name = name(program_expr)::String - raw_call_targets = program_expr.args[3]::QasmExpression - call_targets::Vector{QasmExpression} = convert(Vector{QasmExpression}, head(raw_call_targets.args[1]) == :array_literal ? raw_call_targets.args[1].args : raw_call_targets.args)::Vector{QasmExpression} - provided_args = isempty(program_expr.args[2].args) ? QasmExpression(:empty) : only(program_expr.args[2].args)::QasmExpression - has_modifiers = length(program_expr.args) == 4 - hasgate(v, gate_name) || throw(QasmVisitorError("gate $gate_name not defined!")) - gate_def = gate_defs(v)[gate_name] - gate_def_v = QasmGateDefVisitor(v, gate_def.arguments, provided_args, gate_def.qubit_targets) - gate_def_v(deepcopy(gate_def.body)) - gate_ixs = instructions(gate_def_v) +function process_gate_targets(v, expr, gate_def) + raw_call_targets = expr.args[3]::QasmExpression + call_targets::Vector{QasmExpression} = convert(Vector{QasmExpression}, raw_call_targets.args[1])::Vector{QasmExpression} gate_targets = Vector{Int}[evaluate_qubits(v, call_target)::Vector{Int} for call_target in call_targets] n_called_with = length(gate_targets) n_defined_with = length(gate_def.qubit_targets) @@ -440,18 +435,31 @@ function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression) n_called_with = length(gate_targets[1]) gate_targets = Vector{Int}[[gt] for gt in gate_targets[1]] end - applied_arguments = process_gate_arguments(v, gate_name, gate_def.arguments, provided_args, gate_ixs) control_qubits::Vector{Int} = collect(0:(n_called_with-n_defined_with)-1) + modifier_remap = Dict{Int, Int}(old_qubit=>(old_qubit + length(control_qubits)) for old_qubit in 0:length(gate_def.qubit_targets)) + return gate_targets, control_qubits, n_called_with, modifier_remap +end + +function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression) + gate_name = name(program_expr)::String + provided_args = isempty(program_expr.args[2].args) ? QasmExpression(:empty) : only(program_expr.args[2].args)::QasmExpression + has_modifiers = length(program_expr.args) == 4 mods::Vector{QasmExpression} = has_modifiers ? convert(Vector{QasmExpression}, program_expr.args[4].args) : QasmExpression[] - if !isempty(control_qubits) - modifier_remap = Dict{Int, Int}(old_qubit=>(old_qubit + length(control_qubits)) for old_qubit in 0:length(gate_def.qubit_targets)) - for ii in 1:length(applied_arguments) - applied_arguments[ii] = remap(applied_arguments[ii], modifier_remap) - end + hasgate(v, gate_name) || throw(QasmVisitorError("gate $gate_name not defined!")) + gate_def = gate_defs(v)[gate_name] + gate_def_v = QasmGateDefVisitor(v, gate_def.arguments, provided_args, gate_def.qubit_targets) + gate_def_v(deepcopy(gate_def.body)) + gate_ixs = instructions(gate_def_v) + # generate instruction list based on provided arguments to gate + applied_arguments = process_gate_arguments(v, gate_name, gate_def.arguments, provided_args, gate_ixs) + gate_targets, control_qubits, n_called_with, modifier_remap = process_gate_targets(v, program_expr, gate_def) + for ii in 1:length(applied_arguments) # first apply any needed control qubits to the entire gate, shuffling targets + applied_arguments[ii] = remap(applied_arguments[ii], modifier_remap) end + # go through individual instructions, applying the modifiers to each argument applied_arguments = handle_gate_modifiers(applied_arguments, mods, control_qubits, false) longest, gate_targets = splat_gate_targets(gate_targets) - for splatted_ix in 1:longest + for splatted_ix in 1:longest # then splat if necessary target_mapper = Dict{Int, Int}(g_ix=>gate_targets[g_ix+1][splatted_ix] for g_ix in 0:n_called_with-1) push!(v, map(ix->remap(ix, target_mapper), applied_arguments)) end @@ -473,17 +481,8 @@ function visit_function_call(v, expr, function_name) function_v(f_expr) end end - # remap qubits and classical variables - function_args = if head(declared_args) == :array_literal - convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression} - else - declared_args - end - called_args = if head(provided_args) == :array_literal - convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression} - else - provided_args - end + function_args = convert(Vector{QasmExpression}, declared_args)::Vector{QasmExpression} + called_args = convert(Vector{QasmExpression}, provided_args)::Vector{QasmExpression} reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args)) reverse_qubits_map = Dict{Int, Int}() for variable in filter(v->head(v) ∈ (:identifier, :indexed_identifier), keys(reverse_arguments_map)) @@ -598,9 +597,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) end delete!(classical_defs(v), loop_variable_name) elseif head(program_expr) == :switch - case_val = v(program_expr.args[1]) - all_cases = convert(Vector{QasmExpression}, program_expr.args[2:end]) - default = findfirst(expr->head(expr) == :default, all_cases) + case_val = v(program_expr.args[1]) + all_cases = convert(Vector{QasmExpression}, program_expr.args[2:end]) + default = findfirst(expr->head(expr) == :default, all_cases) case_found = false for case in all_cases if head(case) == :case && case_val ∈ v(case.args[1]) @@ -614,7 +613,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) foreach(v, convert(Vector{QasmExpression}, all_cases[default].args)) end elseif head(program_expr) == :alias - alias_name = name(program_expr) + alias_name = name(program_expr) right_hand_side = program_expr.args[1].args[1].args[end] if head(right_hand_side) == :binary_op right_hand_side.args[1] == Symbol("++") || throw(QasmVisitorError("right hand side of alias must be either an identifier or concatenation")) @@ -624,8 +623,8 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) is_right_qubit = haskey(qubit_mapping(v), name(concat_right)) (is_left_qubit ⊻ is_right_qubit) && throw(QasmVisitorError("cannot concatenate qubit and classical arrays")) if is_left_qubit - left_qs = v(concat_left) - right_qs = v(concat_right) + left_qs = v(concat_left) + right_qs = v(concat_right) alias_qubits = collect(vcat(left_qs, right_qs)) qubit_size = length(alias_qubits) qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size) @@ -636,7 +635,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) else # both classical left_array = classical_defs(v)[name(concat_left)] right_array = classical_defs(v)[name(concat_right)] - new_size = QasmExpression(:binary_op, :+, only(size(left_array.type)), only(size(right_array.type))) + new_size = QasmExpression(:binary_op, :+, only(size(left_array.type)), only(size(right_array.type))) if left_array.type isa SizedBitVector classical_defs(v)[alias_name] = ClassicalVariable(alias_name, new_size, vcat(left_array.val, right_array.val), false) else @@ -775,9 +774,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) new_val = evaluate_binary_op(op, left_val, right_val) end if length(inds) > 1 - var.val[inds] .= new_val + var.val[inds] .= new_val else - var.val[inds] = new_val + var.val[inds] = new_val end end elseif head(program_expr) == :classical_declaration @@ -835,11 +834,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) gate_arguments = gate_def[2]::QasmExpression gate_def_targets = gate_def[3]::QasmExpression gate_body = gate_def[4]::QasmExpression - single_argument = !isempty(gate_arguments.args) && head(gate_arguments.args[1]) == :array_literal - argument_exprs = single_argument ? gate_arguments.args[1].args::Vector{Any} : gate_arguments.args::Vector{Any} + argument_exprs = !isempty(gate_arguments.args) ? convert(Vector{QasmExpression}, gate_arguments.args[1]) : QasmExpression[] argument_names = String[arg.args[1] for arg::QasmExpression in argument_exprs] - single_target = head(gate_def_targets.args[1]) == :array_literal - qubit_targets = single_target ? map(name, gate_def_targets.args[1].args)::Vector{String} : map(name, gate_def_targets.args)::Vector{String} + qubit_targets = map(name, convert(Vector{QasmExpression}, gate_def_targets.args[1]))::Vector{String} v.gate_defs[gate_name] = GateDefinition(gate_name, argument_names, qubit_targets, gate_body) elseif head(program_expr) == :function_call function_name = name(program_expr) diff --git a/test/runtests.jl b/test/runtests.jl index a58bb2e..80b067f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1205,6 +1205,11 @@ Quasar.builtin_gates[] = complex_builtin_gates gate cxx_2 c, a { pow(1/2) @ pow(4) @ cx c, a; } + gate cxx_3 c, a { + pow(1/2) @ pow(4) @ cx c, a; + i c; + i a; + } gate cxxx c, a { pow(1) @ pow(two) @ cx c, a; } @@ -1220,6 +1225,7 @@ Quasar.builtin_gates[] = complex_builtin_gates cx q1, q2; // flip cxx_1 q1, q3; // don't flip cxx_2 q1, q4; // don't flip + pow(2) @ cxx_3 q1, q4; // don't flip cx q1, q5; // flip x q3; // flip x q4; // flip @@ -1236,6 +1242,12 @@ Quasar.builtin_gates[] = complex_builtin_gates (type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 1], controls=[0=>1], exponent=1.0), (type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 2], controls=[0=>1], exponent=2.0), (type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0), + (type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0), + (type="i", arguments=InstructionArgument[], targets=[0], controls=Pair{Int,Int}[], exponent=1.0), + (type="i", arguments=InstructionArgument[], targets=[3], controls=Pair{Int,Int}[], exponent=1.0), + (type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0), + (type="i", arguments=InstructionArgument[], targets=[0], controls=Pair{Int,Int}[], exponent=1.0), + (type="i", arguments=InstructionArgument[], targets=[3], controls=Pair{Int,Int}[], exponent=1.0), (type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 4], controls=[0=>1], exponent=1.0), (type="u", arguments=InstructionArgument[π, 0, π], targets=[2], controls=Pair{Int,Int}[], exponent=1.0), (type="u", arguments=InstructionArgument[π, 0, π], targets=[3], controls=Pair{Int,Int}[], exponent=1.0),