Skip to content

Commit

Permalink
Add vega plot backend (#43)
Browse files Browse the repository at this point in the history
* rework with Compose

* fix tests

* fix version

* add Vega backend

* add tests

* resolve version
  • Loading branch information
ChenZhao44 authored Dec 8, 2021
1 parent 7fd99f2 commit f04e3f4
Show file tree
Hide file tree
Showing 8 changed files with 479 additions and 253 deletions.
14 changes: 8 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ version = "0.6.1"
BitBasis = "50ba71b6-fa0f-514d-ae9a-0916efc90dcf"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Multigraphs = "7ebac608-6c66-46e6-9856-b5f43e107bac"
Vega = "239c3e63-733f-47ad-beb7-a12fde22c578"
Viznet = "52a3aca4-6234-47fd-b74a-806bdf78ede9"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
ZXCalculus = "3525faa3-032d-4235-a8d4-8c2939a218dd"
Expand All @@ -17,16 +17,18 @@ ZXCalculus = "3525faa3-032d-4235-a8d4-8c2939a218dd"
BitBasis = "0.7"
Colors = "0.11, 0.12"
Compose = "0.8, 0.9"
GraphPlot = "0.4"
LightGraphs = "1.3"
Multigraphs = "0.2"
Viznet = "0.3.1"
YaoBlocks = "0.11"
ZXCalculus = "0.3, 0.4"
ZXCalculus = "0.4.4"
julia = "1"

[extras]
CompilerPluginTools = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3638"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c"
YaoHIR = "6769671a-fce8-4286-b3f7-6099e1b1298a"
YaoLocations = "66df03fb-d475-48f7-b449-3d9064bf085b"

[targets]
test = ["Test"]
test = ["Test", "CompilerPluginTools", "Yao", "YaoHIR", "YaoLocations"]
2 changes: 2 additions & 0 deletions src/YaoPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ plot(;kwargs...) = x->plot(x;kwargs...)
include("helperblock.jl")
include("vizcircuit.jl")
include("zx_plot.jl")
include("zx_plot_vega.jl")
include("zx_plot_compose.jl")

end
233 changes: 4 additions & 229 deletions src/zx_plot.jl
Original file line number Diff line number Diff line change
@@ -1,231 +1,6 @@
using ZXCalculus
using LightGraphs, Multigraphs
using GraphPlot: gplot
using Colors
using ZXCalculus: qubit_loc

function Multigraph2Graph(mg::Multigraph)
g = SimpleGraph(nv(mg))
vs = sort!(vertices(mg))
for me in edges(mg)
add_edge!(g, searchsortedfirst(vs, src(me)), searchsortedfirst(vs, dst(me)))
end
multiplicities = ["×$(mul(mg, vs[src(e)], vs[dst(e)]))" for e in edges(g)]
for i = 1:length(multiplicities)
if multiplicities[i] == "×1"
multiplicities[i] = ""
end
end
return g, multiplicities
end

ZX2Graph(zxd::ZXDiagram) = Multigraph2Graph(zxd.mg)
function ZX2Graph(zxg::ZXGraph)
g = SimpleGraph(nv(zxg.mg))
vs = sort!(vertices(zxg.mg))
for me in edges(zxg.mg)
add_edge!(g, searchsortedfirst(vs, src(me)), searchsortedfirst(vs, dst(me)))
end
# multiplicities = ["$(mul(mg, src(e), dst(e)))" for e in edges(g)]
multiplicities = [ZXCalculus.is_hadamard(zxg, vs[src(e)], vs[dst(e)]) ? "×2" : "" for e in edges(g)]
return g, multiplicities
end

function et2color(et::String)
et == "" && return colorant"black"
return colorant"blue"
end

function st2color(S::SpiderType.SType)
S == SpiderType.Z && return colorant"green"
S == SpiderType.X && return colorant"red"
S == SpiderType.H && return colorant"yellow"
S == SpiderType.In && return colorant"lightblue"
S == SpiderType.Out && return colorant"gray"
end

ZX2nodefillc(zxd) = [st2color(zxd.st[v]) for v in sort!(vertices(zxd.mg))]

function ZX2nodelabel(zxd)
nodelabel = String[]
for v in sort!(vertices(zxd.mg))
zxd.st[v] == SpiderType.Z && push!(nodelabel, "[$(v)]\n$(print_phase(zxd.ps[v]))")
zxd.st[v] == SpiderType.X && push!(nodelabel, "[$(v)]\n$(print_phase(zxd.ps[v]))")
zxd.st[v] == SpiderType.H && push!(nodelabel, "[$(v)]")
zxd.st[v] == SpiderType.In && push!(nodelabel, "[$(v)]")
zxd.st[v] == SpiderType.Out && push!(nodelabel, "[$(v)]")
end
return nodelabel
end

function print_phase(p)
if typeof(p) <: Rational
return "$(p.num)π/$(p.den)"
else
return "$p π"
end
end

function layout2locs(zxd::ZXDiagram{T,P}) where {T,P}
lo = zxd.layout
spider_seq = ZXCalculus.spider_sequence(zxd)
vs = sort!(spiders(zxd))
locs = Dict()
nqubit = lo.nbits
frontier_v = ones(T, nqubit)
frontier_locs = ones(nqubit)

while sum([frontier_v[i] <= length(spider_seq[i]) for i = 1:nqubit]) > 0
for q = 1:nqubit
if frontier_v[q] <= length(spider_seq[q])
v = spider_seq[q][frontier_v[q]]
nb = neighbors(zxd, v)
if length(nb) <= 2
locs[v] = (Float64(frontier_locs[q]), Float64(q))
frontier_locs[q] += 1
frontier_v[q] += 1
else
v1 = nb[[qubit_loc(zxd, u) != q for u in nb]][1]
if spider_type(zxd, v1) == SpiderType.H
v1 = setdiff(neighbors(zxd, v1), [v])[1]
end
if sum([findfirst(isequal(u), spider_seq[qubit_loc(zxd, u)]) != frontier_v[qubit_loc(zxd, u)] for u in [v, v1]]) == 0
x = maximum(frontier_locs[min(qubit_loc(zxd, v), qubit_loc(zxd, v1)):max(qubit_loc(zxd, v), qubit_loc(zxd, v1))])
for u in [v, v1]
locs[u] = (Float64(x), Float64(qubit_loc(zxd, u)))
frontier_v[qubit_loc(zxd, u)] += 1
end
for q in min(qubit_loc(zxd, v), qubit_loc(zxd, v1)):max(qubit_loc(zxd, v), qubit_loc(zxd, v1))
frontier_locs[q] = x + 1
end
end
end
end
end
end
for v in vs
if !haskey(locs, v)
v1, v2 = neighbors(zxd, v)
x1, y1 = locs[v1]
x2, y2 = locs[v2]
locs[v] = ((x1+x2)/2, (y1+y2)/2)
end
end
locs_x = [locs[v][1] for v in vs]
locs_y = [locs[v][2] for v in vs]
return locs_x, locs_y
end

function layout2locs(zxg::ZXGraph{T,P}) where {T,P}
lo = zxg.layout
spider_seq = ZXCalculus.spider_sequence(zxg)
vs = sort!(spiders(zxg))
locs = Dict()
nqubit = lo.nbits
frontier_v = ones(T, nqubit)
frontier_locs = ones(nqubit)
phase_gadget_loc = 1.0

for v in vs
if qubit_loc(zxg, v) !== nothing
y = qubit_loc(zxg, v)
x = findfirst(isequal(v), spider_seq[y])
locs[v] = (Float64(x), Float64(y))
else
locs[v] = nothing
end
end
for v in vs
if locs[v] === nothing
nb = neighbors(zxg, v)
if length(nb) == 1
gads = [v]
u = v
w = setdiff(neighbors(zxg, u), gads)[1]
while locs[w] === nothing
push!(gads, w)
u = w
w = setdiff(neighbors(zxg, u), gads)[1]
end
push!(gads, w)
for j = 1:(length(gads) - 1)
locs[gads[length(gads)-j]] = (phase_gadget_loc, Float64(nqubit + j))
end
phase_gadget_loc += 1
end
end
end
for v in vs
if locs[v] === nothing
# println(v)
locs[v] = (phase_gadget_loc, Float64(nqubit + 1))
phase_gadget_loc += 1
end
end
locs_x = [locs[v][1] for v in vs]
locs_y = [locs[v][2] for v in vs]
return locs_x, locs_y
end

function plot(zxd::ZXDiagram; size_x=nothing, size_y=nothing, kwargs...)
g, edgelabel = ZX2Graph(zxd)
nodelabel = ZX2nodelabel(zxd)
nodefillc = ZX2nodefillc(zxd)
edgelabelc = colorant"black"
if zxd.layout.nbits > 0
locs_x, locs_y = layout2locs(zxd)
if size_x === nothing
size_x = maximum(locs_x) - minimum(locs_x)
end
if size_y === nothing
size_y = maximum(locs_y) - minimum(locs_y)
end
set_default_graphic_size(3size_x*cm, 3size_y*cm)
composition = gplot(g,
locs_x, locs_y;
nodelabel = nodelabel, edgelabel = edgelabel, edgelabelc = edgelabelc, nodefillc = nodefillc,
NODESIZE = 1/(2size_x),
kwargs...
# EDGELINEWIDTH = 8.0 / sqrt(nv(g))
)
# draw(SVG("test.svg", size_x*cm, size_y*cm), composition)
else
gplot(g;
nodelabel = nodelabel, edgelabel = edgelabel, edgelabelc = edgelabelc, nodefillc = nodefillc,
kwargs...
# NODESIZE = 0.35 / sqrt(nv(g)), EDGELINEWIDTH = 8.0 / sqrt(nv(g))
)
end
end
function plot(zxd::ZXGraph; size_x=nothing, size_y=nothing, kwargs...)
g, edge_types = ZX2Graph(zxd)

nodelabel = ZX2nodelabel(zxd)
nodefillc = ZX2nodefillc(zxd)
edgestrokec = et2color.(edge_types)
if zxd.layout.nbits > 0
locs_x, locs_y = layout2locs(zxd)
if size_x === nothing
size_x = maximum(locs_x) - minimum(locs_x)
end
if size_y === nothing
size_y = maximum(locs_y) - minimum(locs_y)
end
set_default_graphic_size(3size_x*cm, 3size_y*cm)
gplot(g,
locs_x, locs_y;
nodelabel = nodelabel,
edgestrokec = edgestrokec,
nodefillc = nodefillc,
NODESIZE = 1/(2size_x),
kwargs...
)
else
gplot(g;
nodelabel = nodelabel,
edgestrokec = edgestrokec,
nodefillc = nodefillc,
kwargs...
)
end
end
function plot(zxd::AbstractZXDiagram; backend = :vega, kwargs...)
backend === :vega && return plot_vega(zxd; kwargs...)
backend === :compose && return plot_compose(zxd; kwargs...)
end
110 changes: 110 additions & 0 deletions src/zx_plot_compose.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
using Compose, ZXCalculus

function plot_compose(zxd::Union{ZXDiagram, ZXGraph}; scale = 2)
zxd = copy(zxd)
ZXCalculus.generate_layout!(zxd)
vs = spiders(zxd)
x_locs = zxd.layout.spider_col
x_min = minimum(values(x_locs)) - 0.5
x_max = maximum(values(x_locs)) + 0.5
x_range = x_max - x_min
y_locs = zxd.layout.spider_q
y_min = minimum(values(y_locs)) - 0.5
y_max = maximum(values(y_locs)) + 0.5
y_range = y_max - y_min
x_locs_normal = copy(x_locs)
for (v, x) in x_locs_normal
x_locs_normal[v] = (x - x_min) - 0.25
end
y_locs_normal = copy(y_locs)
for (v, y) in y_locs_normal
y_locs_normal[v] = (y - y_min) - 0.25
end
st = zxd.st
ps = zxd.ps
nodes = generate_nodes(vs, st, ps, 0.1cm, scale)
edges = generate_edges(zxd, x_locs_normal, y_locs_normal, scale)
ct_vs = context()
for v in vs
ct_v = (context(x_locs_normal[v]*scale*cm,
y_locs_normal[v]*scale*cm,
0.5*scale*cm,
0.5*scale*cm
),
nodes[v],
)
ct_vs = compose(context(), ct_vs, ct_v)
end
set_default_graphic_size(x_range*scale*cm, y_range*scale*cm)
return compose(context(), ct_vs, edges)
end

function generate_nodes(vs, st, ps, ftsize, scale)
nodes = Dict()
for v in vs
if st[v] (ZXCalculus.SpiderType.In, ZXCalculus.SpiderType.Out)
spider_shape = :circle
spider_color = "gray"
spider_text = "[$v]"
elseif st[v] == ZXCalculus.SpiderType.H
spider_shape = :box
spider_color = "yellow"
spider_text = "[$v]"
elseif st[v] == ZXCalculus.SpiderType.X
spider_shape = :circle
spider_color = "red"
spider_text = "[$v]" * (iszero(ps[v]) ? "" : "\n$(ps[v])")
elseif st[v] == ZXCalculus.SpiderType.Z
spider_shape = :circle
spider_color = "green"
spider_text = "[$v]" * (iszero(ps[v]) ? "" : "\n$(ps[v])")
end
nodes[v] = (context(),
(context(), text(0.5, 0.5, spider_text, hcenter, vcenter), fontsize(ftsize*scale)),
(context(), (spider_shape === :circle) ? circle() : rectangle(0.25, 0.25, 0.5, 0.5),
fill(spider_color), stroke("black"), linewidth(0.3*scale*mm)),
)
end
return nodes
end

function generate_edges(zxd::ZXDiagram, x_locs_normal, y_locs_normal, scale)
ct_edges = context()
for me in ZXCalculus.edges(zxd.mg)
x_center = (x_locs_normal[me.src]+x_locs_normal[me.dst]+0.5)/2*scale*cm
y_center = (y_locs_normal[me.src]+y_locs_normal[me.dst]+0.5)/2*scale*cm
theta = angle((x_locs_normal[me.dst]-x_locs_normal[me.src])+im*(y_locs_normal[me.dst]-y_locs_normal[me.src]))
theta = rem(theta, pi, RoundDown)
r = Rotation(theta, x_center, y_center)
ct_edges = (context(), ct_edges,
(context(),
text(x_center,
y_center,
((me.mul > 1) ? "× $(me.mul)\n" : ""),
hcenter, vcenter, r
),
fill("black"), fontsize(1.5mm*scale)
),
(context(),
line([((x_locs_normal[me.src]+0.25)*scale*cm, (y_locs_normal[me.src]+0.25)*scale*cm),
((x_locs_normal[me.dst]+0.25)*scale*cm, (y_locs_normal[me.dst]+0.25)*scale*cm)]),
stroke("gray"), linewidth(0.3*scale*mm),
)
)
end
return ct_edges
end
function generate_edges(zxg::ZXGraph, x_locs_normal, y_locs_normal, scale)
ct_edges = context()
for me in ZXCalculus.edges(zxg.mg)
ct_edges = (context(), ct_edges,
(context(),
line([((x_locs_normal[me.src]+0.25)*scale*cm, (y_locs_normal[me.src]+0.25)*scale*cm),
((x_locs_normal[me.dst]+0.25)*scale*cm, (y_locs_normal[me.dst]+0.25)*scale*cm)]),
stroke(ZXCalculus.is_hadamard(zxg, me.src, me.dst) ? "blue" : "black"),
linewidth(0.3*scale*mm),
)
)
end
return ct_edges
end
Loading

0 comments on commit f04e3f4

Please sign in to comment.