Skip to content

Commit

Permalink
Bugfix and more cleanup (#21)
Browse files Browse the repository at this point in the history
* More visitor cleanup

* Better array_literal handling

* Fix power mod issue

* Another test
  • Loading branch information
kshyatt-aws authored Nov 27, 2024
1 parent c0a2811 commit 2d8e369
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 80 deletions.
2 changes: 2 additions & 0 deletions src/qasm_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Base.copy(qasm_expr::QasmExpression) = QasmExpression(qasm_expr.head, deepcopy(q

head(qasm_expr::QasmExpression) = qasm_expr.head

Base.convert(::Type{Vector{QasmExpression}}, expr::QasmExpression) = head(expr) == :array_literal ? convert(Vector{QasmExpression}, expr.args) : [expr]

AbstractTrees.children(qasm_expr::QasmExpression) = qasm_expr.args
AbstractTrees.printnode(io::IO, qasm_expr::QasmExpression) = print(io, "QasmExpression :$(qasm_expr.head)")

Expand Down
157 changes: 77 additions & 80 deletions src/visitor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,9 @@ mutable struct QasmFunctionVisitor <: AbstractVisitor
return v
end
end
function QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::Vector{QasmExpression}, provided_arguments::QasmExpression)
head(provided_arguments) == :array_literal && return QasmFunctionVisitor(parent, declared_arguments, convert(Vector{QasmExpression}, provided_arguments.args))
QasmFunctionVisitor(parent, declared_arguments, [provided_arguments])
end
function QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::QasmExpression, provided_arguments)
head(declared_arguments) == :array_literal && return QasmFunctionVisitor(parent, convert(Vector{QasmExpression}, declared_arguments.args), provided_arguments)
QasmFunctionVisitor(parent, [declared_arguments], provided_arguments)
end
QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::Vector{QasmExpression}, provided_arguments::QasmExpression) = QasmFunctionVisitor(parent, declared_arguments, convert(Vector{QasmExpression}, provided_arguments))
QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::QasmExpression, provided_arguments) = QasmFunctionVisitor(parent, convert(Vector{QasmExpression}, declared_arguments), provided_arguments)

Base.parent(v::AbstractVisitor) = v.parent

hasgate(v::AbstractVisitor, gate_name::String) = hasgate(parent(v), gate_name)
Expand Down Expand Up @@ -281,9 +276,9 @@ function evaluate_modifiers(v::V, expr::QasmExpression) where {V<:AbstractVisito
arg_val::Int = v(first(expr.args)::QasmExpression)::Int
isinteger(arg_val) || throw(QasmVisitorError("cannot apply non-integer ($arg_val) number of controls or negcontrols."))
true_inner = expr.args[2]::QasmExpression
inner = QasmExpression(head(expr), true_inner)
inner = QasmExpression(head(expr), true_inner)
while arg_val > 2
inner = QasmExpression(head(expr), inner)
inner = QasmExpression(head(expr), inner)
arg_val -= 1
end
else
Expand Down Expand Up @@ -346,6 +341,7 @@ end
evaluate_qubits(v::AbstractVisitor, qubit_targets::QasmExpression) = evaluate_qubits(v::AbstractVisitor, [qubit_targets])

function remap(ix, target_mapper::Dict{Int, Int})
isempty(target_mapper) && return ix
mapped_targets = map(t->getindex(target_mapper, t), ix.targets)
mapped_controls = map(c->getindex(target_mapper, c[1])=>c[2], ix.controls)
return (type=ix.type, arguments=ix.arguments, targets=mapped_targets, controls=mapped_controls, exponent=ix.exponent)
Expand All @@ -359,57 +355,64 @@ function process_gate_arguments(v::AbstractVisitor, gate_name::String, defined_a
def_has_arguments = !isempty(defined_arguments)
call_has_arguments = !isempty(v(called_arguments))
if def_has_arguments call_has_arguments
def_has_arguments && throw(QasmVisitorError("gate $gate_name requires arguments but none were provided."))
def_has_arguments && throw(QasmVisitorError("gate $gate_name requires arguments but none were provided."))
call_has_arguments && throw(QasmVisitorError("gate $gate_name does not accept arguments but arguments were provided."))
end
if def_has_arguments
evaled_args = v(called_arguments)
argument_values = Dict{Symbol, Real}(Symbol(arg_name)=>argument for (arg_name, argument) in zip(defined_arguments, evaled_args))
return map(ix->bind_arguments!(ix, argument_values), gate_body)
else
return deepcopy(gate_body)
end
!def_has_arguments && return deepcopy(gate_body) # deep copy to avoid overwriting canonical definition

evaled_args = v(called_arguments)
argument_values = Dict{Symbol, Real}(Symbol(arg_name)=>argument for (arg_name, argument) in zip(defined_arguments, evaled_args))
return map(ix->bind_arguments!(ix, argument_values), gate_body)
end

function handle_gate_modifiers(ixs, mods::Vector{QasmExpression}, control_qubits::Vector{Int}, is_gphase::Bool)
for mod in Iterators.reverse(mods)
control_qubit = head(mod) (:negctrl, :ctrl) ? pop!(control_qubits) : -1
for (ii, ix) in enumerate(ixs)
if head(mod) == :pow
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=ix.exponent*mod.args[1])
elseif head(mod) == :inv
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=-ix.exponent)
# need to handle "extra" target
elseif head(mod) (:negctrl, :ctrl)
if head(mod) (:negctrl, :ctrl)
control_qubit = pop!(control_qubits)
for (ii, ix) in enumerate(ixs)
exp = ix.exponent
targets = ix.targets
controls = ix.controls
bit = head(mod) == :ctrl ? 1 : 0
if is_gphase
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=pushfirst!(ix.controls, control_qubit=>bit), exponent=ix.exponent)
else
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=pushfirst!(ix.targets, control_qubit), controls=pushfirst!(ix.controls, control_qubit=>bit), exponent=ix.exponent)
controls = pushfirst!(controls, control_qubit=>bit)
if !is_gphase
targets = pushfirst!(targets, control_qubit)
end
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=targets, controls=controls, exponent=exp)
end
elseif head(mod) == :inv
reverse!(ixs)
for (ii, ix) in enumerate(ixs)
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=-ix.exponent)
end
elseif head(mod) == :pow
pow_exp = mod.args[1]
(isinteger(pow_exp) || length(ixs) == 1) || throw(QasmVisitorError("can't apply a non-integer exponent to a gate of multiple instructions")) # can't do 2.5 for a list... yet
if length(ixs) > 1
pow_exp < 0 && reverse!(ixs)
ixs = repeat(ixs, abs(pow_exp))
else
ixs[1] = (type=ixs[1].type, arguments=ixs[1].arguments, targets=ixs[1].targets, controls=ixs[1].controls, exponent=ixs[1].exponent*pow_exp)
end
end
head(mod) == :inv && reverse!(ixs)
end
return ixs
end

function splat_gate_targets(gate_targets::Vector{Vector{Int}})
target_lengths::Vector{Int} = Int[length(t) for t in gate_targets]
longest = maximum(target_lengths)
longest = maximum(target_lengths)
must_splat::Bool = any(len->len!=1 || len != longest, target_lengths)
!must_splat && return longest, gate_targets
for target_ix in 1:length(gate_targets)
if target_lengths[target_ix] == 1
append!(gate_targets[target_ix], fill(only(gate_targets[target_ix]), longest-1))
end
for target_ix in filter(ix->target_lengths[ix] == 1, 1:length(gate_targets))
append!(gate_targets[target_ix], fill(only(gate_targets[target_ix]), longest-1))
end
return longest, gate_targets
end

function visit_gphase_call(v::AbstractVisitor, program_expr::QasmExpression)
has_modifiers = length(program_expr.args) == 4
n_called_with::Int = qubit_count(v)
has_modifiers = length(program_expr.args) == 4
n_called_with::Int = qubit_count(v)
gate_targets::Vector{Int} = collect(0:n_called_with-1)
provided_arg::QasmExpression = only(program_expr.args[2].args)
evaled_arg = v(provided_arg)
Expand All @@ -421,17 +424,9 @@ function visit_gphase_call(v::AbstractVisitor, program_expr::QasmExpression)
return
end

function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression)
gate_name = name(program_expr)::String
raw_call_targets = program_expr.args[3]::QasmExpression
call_targets::Vector{QasmExpression} = convert(Vector{QasmExpression}, head(raw_call_targets.args[1]) == :array_literal ? raw_call_targets.args[1].args : raw_call_targets.args)::Vector{QasmExpression}
provided_args = isempty(program_expr.args[2].args) ? QasmExpression(:empty) : only(program_expr.args[2].args)::QasmExpression
has_modifiers = length(program_expr.args) == 4
hasgate(v, gate_name) || throw(QasmVisitorError("gate $gate_name not defined!"))
gate_def = gate_defs(v)[gate_name]
gate_def_v = QasmGateDefVisitor(v, gate_def.arguments, provided_args, gate_def.qubit_targets)
gate_def_v(deepcopy(gate_def.body))
gate_ixs = instructions(gate_def_v)
function process_gate_targets(v, expr, gate_def)
raw_call_targets = expr.args[3]::QasmExpression
call_targets::Vector{QasmExpression} = convert(Vector{QasmExpression}, raw_call_targets.args[1])::Vector{QasmExpression}
gate_targets = Vector{Int}[evaluate_qubits(v, call_target)::Vector{Int} for call_target in call_targets]
n_called_with = length(gate_targets)
n_defined_with = length(gate_def.qubit_targets)
Expand All @@ -440,18 +435,31 @@ function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression)
n_called_with = length(gate_targets[1])
gate_targets = Vector{Int}[[gt] for gt in gate_targets[1]]
end
applied_arguments = process_gate_arguments(v, gate_name, gate_def.arguments, provided_args, gate_ixs)
control_qubits::Vector{Int} = collect(0:(n_called_with-n_defined_with)-1)
modifier_remap = Dict{Int, Int}(old_qubit=>(old_qubit + length(control_qubits)) for old_qubit in 0:length(gate_def.qubit_targets))
return gate_targets, control_qubits, n_called_with, modifier_remap
end

function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression)
gate_name = name(program_expr)::String
provided_args = isempty(program_expr.args[2].args) ? QasmExpression(:empty) : only(program_expr.args[2].args)::QasmExpression
has_modifiers = length(program_expr.args) == 4
mods::Vector{QasmExpression} = has_modifiers ? convert(Vector{QasmExpression}, program_expr.args[4].args) : QasmExpression[]
if !isempty(control_qubits)
modifier_remap = Dict{Int, Int}(old_qubit=>(old_qubit + length(control_qubits)) for old_qubit in 0:length(gate_def.qubit_targets))
for ii in 1:length(applied_arguments)
applied_arguments[ii] = remap(applied_arguments[ii], modifier_remap)
end
hasgate(v, gate_name) || throw(QasmVisitorError("gate $gate_name not defined!"))
gate_def = gate_defs(v)[gate_name]
gate_def_v = QasmGateDefVisitor(v, gate_def.arguments, provided_args, gate_def.qubit_targets)
gate_def_v(deepcopy(gate_def.body))
gate_ixs = instructions(gate_def_v)
# generate instruction list based on provided arguments to gate
applied_arguments = process_gate_arguments(v, gate_name, gate_def.arguments, provided_args, gate_ixs)
gate_targets, control_qubits, n_called_with, modifier_remap = process_gate_targets(v, program_expr, gate_def)
for ii in 1:length(applied_arguments) # first apply any needed control qubits to the entire gate, shuffling targets
applied_arguments[ii] = remap(applied_arguments[ii], modifier_remap)
end
# go through individual instructions, applying the modifiers to each argument
applied_arguments = handle_gate_modifiers(applied_arguments, mods, control_qubits, false)
longest, gate_targets = splat_gate_targets(gate_targets)
for splatted_ix in 1:longest
for splatted_ix in 1:longest # then splat if necessary
target_mapper = Dict{Int, Int}(g_ix=>gate_targets[g_ix+1][splatted_ix] for g_ix in 0:n_called_with-1)
push!(v, map(ix->remap(ix, target_mapper), applied_arguments))
end
Expand All @@ -473,17 +481,8 @@ function visit_function_call(v, expr, function_name)
function_v(f_expr)
end
end
# remap qubits and classical variables
function_args = if head(declared_args) == :array_literal
convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression}
else
declared_args
end
called_args = if head(provided_args) == :array_literal
convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression}
else
provided_args
end
function_args = convert(Vector{QasmExpression}, declared_args)::Vector{QasmExpression}
called_args = convert(Vector{QasmExpression}, provided_args)::Vector{QasmExpression}
reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args))
reverse_qubits_map = Dict{Int, Int}()
for variable in filter(v->head(v) (:identifier, :indexed_identifier), keys(reverse_arguments_map))
Expand Down Expand Up @@ -598,9 +597,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
end
delete!(classical_defs(v), loop_variable_name)
elseif head(program_expr) == :switch
case_val = v(program_expr.args[1])
all_cases = convert(Vector{QasmExpression}, program_expr.args[2:end])
default = findfirst(expr->head(expr) == :default, all_cases)
case_val = v(program_expr.args[1])
all_cases = convert(Vector{QasmExpression}, program_expr.args[2:end])
default = findfirst(expr->head(expr) == :default, all_cases)
case_found = false
for case in all_cases
if head(case) == :case && case_val v(case.args[1])
Expand All @@ -614,7 +613,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
foreach(v, convert(Vector{QasmExpression}, all_cases[default].args))
end
elseif head(program_expr) == :alias
alias_name = name(program_expr)
alias_name = name(program_expr)
right_hand_side = program_expr.args[1].args[1].args[end]
if head(right_hand_side) == :binary_op
right_hand_side.args[1] == Symbol("++") || throw(QasmVisitorError("right hand side of alias must be either an identifier or concatenation"))
Expand All @@ -624,8 +623,8 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
is_right_qubit = haskey(qubit_mapping(v), name(concat_right))
(is_left_qubit is_right_qubit) && throw(QasmVisitorError("cannot concatenate qubit and classical arrays"))
if is_left_qubit
left_qs = v(concat_left)
right_qs = v(concat_right)
left_qs = v(concat_left)
right_qs = v(concat_right)
alias_qubits = collect(vcat(left_qs, right_qs))
qubit_size = length(alias_qubits)
qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size)
Expand All @@ -636,7 +635,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
else # both classical
left_array = classical_defs(v)[name(concat_left)]
right_array = classical_defs(v)[name(concat_right)]
new_size = QasmExpression(:binary_op, :+, only(size(left_array.type)), only(size(right_array.type)))
new_size = QasmExpression(:binary_op, :+, only(size(left_array.type)), only(size(right_array.type)))
if left_array.type isa SizedBitVector
classical_defs(v)[alias_name] = ClassicalVariable(alias_name, new_size, vcat(left_array.val, right_array.val), false)
else
Expand Down Expand Up @@ -775,9 +774,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
new_val = evaluate_binary_op(op, left_val, right_val)
end
if length(inds) > 1
var.val[inds] .= new_val
var.val[inds] .= new_val
else
var.val[inds] = new_val
var.val[inds] = new_val
end
end
elseif head(program_expr) == :classical_declaration
Expand Down Expand Up @@ -835,11 +834,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
gate_arguments = gate_def[2]::QasmExpression
gate_def_targets = gate_def[3]::QasmExpression
gate_body = gate_def[4]::QasmExpression
single_argument = !isempty(gate_arguments.args) && head(gate_arguments.args[1]) == :array_literal
argument_exprs = single_argument ? gate_arguments.args[1].args::Vector{Any} : gate_arguments.args::Vector{Any}
argument_exprs = !isempty(gate_arguments.args) ? convert(Vector{QasmExpression}, gate_arguments.args[1]) : QasmExpression[]
argument_names = String[arg.args[1] for arg::QasmExpression in argument_exprs]
single_target = head(gate_def_targets.args[1]) == :array_literal
qubit_targets = single_target ? map(name, gate_def_targets.args[1].args)::Vector{String} : map(name, gate_def_targets.args)::Vector{String}
qubit_targets = map(name, convert(Vector{QasmExpression}, gate_def_targets.args[1]))::Vector{String}
v.gate_defs[gate_name] = GateDefinition(gate_name, argument_names, qubit_targets, gate_body)
elseif head(program_expr) == :function_call
function_name = name(program_expr)
Expand Down
12 changes: 12 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,11 @@ Quasar.builtin_gates[] = complex_builtin_gates
gate cxx_2 c, a {
pow(1/2) @ pow(4) @ cx c, a;
}
gate cxx_3 c, a {
pow(1/2) @ pow(4) @ cx c, a;
i c;
i a;
}
gate cxxx c, a {
pow(1) @ pow(two) @ cx c, a;
}
Expand All @@ -1220,6 +1225,7 @@ Quasar.builtin_gates[] = complex_builtin_gates
cx q1, q2; // flip
cxx_1 q1, q3; // don't flip
cxx_2 q1, q4; // don't flip
pow(2) @ cxx_3 q1, q4; // don't flip
cx q1, q5; // flip
x q3; // flip
x q4; // flip
Expand All @@ -1236,6 +1242,12 @@ Quasar.builtin_gates[] = complex_builtin_gates
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 1], controls=[0=>1], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 2], controls=[0=>1], exponent=2.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0),
(type="i", arguments=InstructionArgument[], targets=[0], controls=Pair{Int,Int}[], exponent=1.0),
(type="i", arguments=InstructionArgument[], targets=[3], controls=Pair{Int,Int}[], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0),
(type="i", arguments=InstructionArgument[], targets=[0], controls=Pair{Int,Int}[], exponent=1.0),
(type="i", arguments=InstructionArgument[], targets=[3], controls=Pair{Int,Int}[], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 4], controls=[0=>1], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[2], controls=Pair{Int,Int}[], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[3], controls=Pair{Int,Int}[], exponent=1.0),
Expand Down

0 comments on commit 2d8e369

Please sign in to comment.