Skip to content

Commit

Permalink
Merge pull request #17 from QuantumBFS/ZX
Browse files Browse the repository at this point in the history
Update for new Multigraph backend
  • Loading branch information
GiggleLiu authored Aug 29, 2020
2 parents aaebf82 + 8651ce7 commit b19cdbe
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions src/zx_plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export plot

function Multigraph2Graph(mg::Multigraph)
g = SimpleGraph(nv(mg))
vs = vertices(mg)
vs = sort!(vertices(mg))
for me in edges(mg)
add_edge!(g, searchsortedfirst(vs, src(me)), searchsortedfirst(vs, dst(me)))
end
Expand All @@ -24,11 +24,20 @@ function Multigraph2Graph(mg::Multigraph)
end

ZX2Graph(zxd::ZXDiagram) = Multigraph2Graph(zxd.mg)
ZX2Graph(zxg::ZXGraph) = Multigraph2Graph(zxg.mg)
function ZX2Graph(zxg::ZXGraph)
g = SimpleGraph(nv(zxg.mg))
vs = sort!(vertices(zxg.mg))
for me in edges(zxg.mg)
add_edge!(g, searchsortedfirst(vs, src(me)), searchsortedfirst(vs, dst(me)))
end
# multiplicities = ["$(mul(mg, src(e), dst(e)))" for e in edges(g)]
multiplicities = [ZXCalculus.is_hadamard(zxg, vs[src(e)], vs[dst(e)]) ? "×2" : "" for e in edges(g)]
return g, multiplicities
end

function et2color(et::String)
et == "" && return colorant"black"
et == "×2" && return colorant"blue"
return colorant"blue"
end

function st2color(S::SpiderType.SType)
Expand All @@ -39,11 +48,11 @@ function st2color(S::SpiderType.SType)
S == SpiderType.Out && return colorant"gray"
end

ZX2nodefillc(zxd) = [st2color(zxd.st[v]) for v in vertices(zxd.mg)]
ZX2nodefillc(zxd) = [st2color(zxd.st[v]) for v in sort!(vertices(zxd.mg))]

function ZX2nodelabel(zxd)
nodelabel = String[]
for v in vertices(zxd.mg)
for v in sort!(vertices(zxd.mg))
zxd.st[v] == SpiderType.Z && push!(nodelabel, "[$(v)]\n$(print_phase(zxd.ps[v]))")
zxd.st[v] == SpiderType.X && push!(nodelabel, "[$(v)]\n$(print_phase(zxd.ps[v]))")
zxd.st[v] == SpiderType.H && push!(nodelabel, "[$(v)]")
Expand All @@ -63,33 +72,34 @@ end

function layout2locs(zxd::ZXDiagram{T,P}) where {T,P}
lo = zxd.layout
vs = spiders(zxd)
spider_seq = ZXCalculus.spider_sequence(zxd)
vs = sort!(spiders(zxd))
locs = Dict()
nqubit = lo.nbits
frontier_v = ones(T, nqubit)
frontier_locs = ones(nqubit)

while sum([frontier_v[i] <= length(lo.spider_seq[i]) for i = 1:nqubit]) > 0
while sum([frontier_v[i] <= length(spider_seq[i]) for i = 1:nqubit]) > 0
for q = 1:nqubit
if frontier_v[q] <= length(lo.spider_seq[q])
v = lo.spider_seq[q][frontier_v[q]]
if frontier_v[q] <= length(spider_seq[q])
v = spider_seq[q][frontier_v[q]]
nb = neighbors(zxd, v)
if length(nb) <= 2
locs[v] = (Float64(frontier_locs[q]), Float64(q))
frontier_locs[q] += 1
frontier_v[q] += 1
else
v1 = nb[[qubit_loc(lo, u) != q for u in nb]][1]
v1 = nb[[qubit_loc(zxd, u) != q for u in nb]][1]
if spider_type(zxd, v1) == SpiderType.H
v1 = setdiff(neighbors(zxd, v1), [v])[1]
end
if sum([findfirst(isequal(u), lo.spider_seq[qubit_loc(lo, u)]) != frontier_v[qubit_loc(lo, u)] for u in [v, v1]]) == 0
x = maximum(frontier_locs[min(qubit_loc(lo, v), qubit_loc(lo, v1)):max(qubit_loc(lo, v), qubit_loc(lo, v1))])
if sum([findfirst(isequal(u), spider_seq[qubit_loc(zxd, u)]) != frontier_v[qubit_loc(zxd, u)] for u in [v, v1]]) == 0
x = maximum(frontier_locs[min(qubit_loc(zxd, v), qubit_loc(zxd, v1)):max(qubit_loc(zxd, v), qubit_loc(zxd, v1))])
for u in [v, v1]
locs[u] = (Float64(x), Float64(qubit_loc(lo, u)))
frontier_v[qubit_loc(lo, u)] += 1
locs[u] = (Float64(x), Float64(qubit_loc(zxd, u)))
frontier_v[qubit_loc(zxd, u)] += 1
end
for q in min(qubit_loc(lo, v), qubit_loc(lo, v1)):max(qubit_loc(lo, v), qubit_loc(lo, v1))
for q in min(qubit_loc(zxd, v), qubit_loc(zxd, v1)):max(qubit_loc(zxd, v), qubit_loc(zxd, v1))
frontier_locs[q] = x + 1
end
end
Expand All @@ -112,17 +122,18 @@ end

function layout2locs(zxd::ZXGraph{T,P}) where {T,P}
lo = zxd.layout
vs = spiders(zxd)
spider_seq = ZXCalculus.spider_sequence(zxd)
vs = sort!(spiders(zxd))
locs = Dict()
nqubit = lo.nbits
frontier_v = ones(T, nqubit)
frontier_locs = ones(nqubit)
phase_gadget_loc = 1.0

for v in vs
if qubit_loc(lo, v) != nothing
y = qubit_loc(lo, v)
x = findfirst(isequal(v), lo.spider_seq[y])
if qubit_loc(zxd, v) != nothing
y = qubit_loc(zxd, v)
x = findfirst(isequal(v), spider_seq[y])
locs[v] = (Float64(x), Float64(y))
else
locs[v] = nothing
Expand All @@ -144,7 +155,7 @@ function layout2locs(zxd::ZXGraph{T,P}) where {T,P}
# locs[v] = ((x1+x2)/2, (y1+y2)/2)
end
end
println(locs)
# println(locs)
locs_x = [locs[v][1] for v in vs]
locs_y = [locs[v][2] for v in vs]
return locs_x, locs_y
Expand Down

0 comments on commit b19cdbe

Please sign in to comment.