Skip to content

Commit

Permalink
Basic circuit vizualization (#22)
Browse files Browse the repository at this point in the history
* new circuit plot

* fine tune

* fine tune

* rm 1.0 travis CI

* make CNOT a CNOT

* general update

Co-authored-by: Rogerluo <rogerluo.rl18@gmail.com>
  • Loading branch information
GiggleLiu and Roger-luo authored Sep 2, 2020
1 parent 8c3fe5a commit 227cba8
Show file tree
Hide file tree
Showing 16 changed files with 288 additions and 16 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ language: julia
notifications:
email: false
julia:
- 1.0
- 1.5
- nightly
os:
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
Viznet = "52a3aca4-6234-47fd-b74a-806bdf78ede9"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
ZXCalculus = "3525faa3-032d-4235-a8d4-8c2939a218dd"

[compat]
Expand Down
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,21 @@

[![Build Status](https://travis-ci.com/QuantumBFS/YaoPlots.jl.svg?branch=master)](https://travis-ci.com/QuantumBFS/YaoPlots.jl)
[![Coverage](https://codecov.io/gh/QuantumBFS/YaoPlots.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/QuantumBFS/YaoPlots.jl)

## Example 1: Visualize a QBIR define in Yao

```julia
using YaoExtensions, YaoPlots
using Compose

# show a qft circuit
plot(qft_circuit(5))
```

If you are using a Pluto/Jupyter notebook, Atom/VSCode editor, you should see the following image in your plotting panel.

![qft](examples/qft.png)

Otherwise, you might be interested to learn [how to save it as an image](https://giovineitalia.github.io/Compose.jl/latest/tutorial/).

See more [examples](examples/circuits.jl).
23 changes: 23 additions & 0 deletions examples/circuits.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using YaoExtensions, YaoPlots
using Compose, Cairo

_save(str) = PNG(joinpath(@__DIR__, str))

# qft circuit
vizcircuit(qft_circuit(5)) |> _save("qft.png")

# variational circuit
vizcircuit(variational_circuit(5)) |> _save("variational.png")
# vizcircuit(variational_circuit(5; mode=:Merged))

# general U4 gate
vizcircuit(general_U4()) |> _save("u4.png")

# quantum supremacy circuit
vizcircuit(rand_supremacy2d(2, 2, 8)) |> _save("supremacy2d.png")

# google 52 qubit
vizcircuit(rand_google53(10)) |> _save("google53.png")

# control blocks
vizcircuit(chain(control(5, (2,-3), 4=>X), control(5, (-4, -2), 1=>Z))) |> _save("controls.png")
Binary file added examples/controls.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/google53.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/qft.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/supremacy2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/u4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/variational.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions src/YaoPlots.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module YaoPlots

using Compose

include("vizcircuit.jl")
include("zx_plot.jl")

end
182 changes: 182 additions & 0 deletions src/vizcircuit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
using Viznet: canvas
using YaoBlocks

export CircuitStyles, CircuitGrid, circuit_canvas, vizcircuit

module CircuitStyles
using Compose
const r = Ref(0.2)
const lw = Ref(1pt)
const textsize = Ref(16pt)
const paramtextsize = Ref(10pt)
const fontfamily = Ref("Helvetica Neue")
G() = compose(context(), rectangle(-r[], -r[], 2*r[], 2*r[]), fill("white"), stroke("black"), linewidth(lw[]))
C() = compose(context(), circle(0.0, 0.0, r[]/3), fill("black"))
NC() = compose(context(), circle(0.0, 0.0, r[]/3), fill("white"), stroke("black"), linewidth(lw[]))
X() = compose(context(), xgon(0.0, 0.0, r[], 4), fill("black"))
NOT() = compose(context(),
(context(), circle(0.0, 0.0, r[]), stroke("black"), linewidth(lw[]), fill("transparent")),
(context(), polygon([(-r[], 0.0), (r[], 0.0)]), stroke("black"), linewidth(lw[])),
(context(), polygon([(0.0, -r[]), (0.0, r[])]), stroke("black"), linewidth(lw[]))
)
WG() = compose(context(), rectangle(-1.5*r[], -r[], 3*r[], 2*r[]), fill("white"), stroke("black"), linewidth(lw[]))
LINE() = compose(context(), line(), stroke("black"), linewidth(lw[]))
TEXT() = compose(context(), text(0.0, 0.0, "", hcenter, vcenter), fontsize(textsize[]), font(fontfamily[]))
PARAMTEXT() = compose(context(), text(0.0, 0.0, "", hcenter, vcenter), fontsize(paramtextsize[]), font(fontfamily[]))
function setlw(_lw)
lw[] = _lw
end
function setr(_r)
r[] = _r
end
function settextsize(_textsize)
textsize[] = _textsize
end
function setparamtextsize(_paramtextsize)
paramtextsize[] = _paramtextsize
end
function setfontfamily(_fontfamily)
fontfamily[] = _fontfamily
end
end

struct CircuitGrid
frontier::Vector{Int}
w_depth::Float64
w_line::Float64
end

nline(c::CircuitGrid) = length(c.frontier)
depth(c::CircuitGrid) = frontier(c, 1, nline(c))
Base.getindex(c::CircuitGrid, i, j) = (c.w_depth*i, c.w_line*j)
Base.typed_vcat(c::CircuitGrid, ij1, ij2) = (c[ij1...], c[ij2...])

function CircuitGrid(nline::Int; w_depth=1.0, w_line=1.0)
CircuitGrid(zeros(Int, nline), w_depth, w_line)
end

function frontier(c::CircuitGrid, args...)
maximum(i->c.frontier[i], min(args..., nline(c)):max(args..., 1))
end

function _draw!(c::CircuitGrid, loc_brush_texts)
locs = getindex.(loc_brush_texts, 1)
i = frontier(c, locs...) + 1
local jpre
loc_brush_texts = sort(loc_brush_texts, by=x->x[1])
for (k, (j, b, txt)) in enumerate(loc_brush_texts)
b >> c[i, j]
if length(txt) >= 3
CircuitStyles.PARAMTEXT() >> (c[i, j], txt)
elseif length(txt) >= 1
CircuitStyles.TEXT() >> (c[i, j], txt)
end
if k!=1
CircuitStyles.LINE() >> c[(i, j); (i, jpre)]
end
jpre = j
end

jmin, jmax = min(locs..., nline(c)), max(locs..., 1)
for j = jmin:jmax
CircuitStyles.LINE() >> c[(i, j); (c.frontier[j], j)]
c.frontier[j] = i
end
end

function finalize!(c::CircuitGrid)
i = frontier(c, 1, nline(c)) + 1
for j=1:nline(c)
CircuitStyles.LINE() >> c[(i, j-0.2); (i, j+0.2)]
CircuitStyles.LINE() >> c[(i, j); (c.frontier[j], j)]
end
c.frontier .= i
end

function draw!(c::CircuitGrid, b::AbstractBlock, address)
error("block type $(typeof(b)) does not support visualization.")
end

function draw!(c::CircuitGrid, p::ChainBlock{N}, address) where N
draw!.(Ref(c), subblocks(p), Ref(address))
end

function draw!(c::CircuitGrid, p::PutBlock{N,1,<:PrimitiveBlock}, address) where N
locs = [address[p.locs[1]]]
draw!(c, p.content, [address[p.locs[1]]])
end

function draw!(c::CircuitGrid, p::PutBlock{N,M,<:PrimitiveBlock}, address) where {N,M}
locs = [address[i] for i in p.locs]
_draw!(c, [(loc, CircuitStyles.G(), "") for loc in locs])
end

function draw!(c::CircuitGrid, p::PrimitiveBlock{1}, address)
_draw!(c, [(address[], get_brush_text(p)...)])
end

function draw!(c::CircuitGrid, p::PutBlock{N,M,<:ChainBlock}, address) where {N,M}
locs = [address[i] for i in p.locs]
draw!.(Ref(c), subblocks(p.content), Ref(locs))
end

function draw!(c::CircuitGrid, p::PutBlock{N,2,<:SWAPGate}, address) where N
locs = [address[i] for i in p.locs]
_draw!(c, [(locs[1], CircuitStyles.X(), ""), (locs[2], CircuitStyles.X(), "")])
end

function draw!(c::CircuitGrid, cb::ControlBlock{N,GT,C,1}, address) where {N,GT,C}
ctrl_locs = [address[i] for i in cb.ctrl_locs]
locs = [address[i] for i in cb.locs]
_draw!(c, [[(loc, (bit == 1 ? CircuitStyles.C() : CircuitStyles.NC()), "") for (loc, bit)=zip(ctrl_locs, cb.ctrl_config)]..., (locs..., get_cbrush_text(cb.content)...)])
end

for (GATE, SYM) in [(:XGate, :Rx), (:YGate, :Ry), (:ZGate, :Rz)]
@eval get_brush_text(b::RotationGate{1,T,<:$GATE}) where T = (CircuitStyles.WG(), "$($(SYM))($(pretty_angle(b.theta)))")
end

pretty_angle(theta) = theta
pretty_angle(theta::Float64) = round(theta; digits=2)

get_brush_text(b::PrimitiveBlock{1}) = (CircuitStyles.G(), "")
get_brush_text(b::ShiftGate) = (CircuitStyles.WG(), "ϕ($(pretty_angle(b.theta)))")
get_brush_text(b::PhaseGate) = (CircuitStyles.WG(), "$(pretty_angle(b.theta))im")
get_brush_text(b::T) where T<:ConstantGate = (CircuitStyles.G(), string(T.name.name)[1:end-4])

get_cbrush_text(b::AbstractBlock) = get_brush_text(b)
get_cbrush_text(b::XGate) = (CircuitStyles.NOT(), "")
get_cbrush_text(b::ZGate) = (CircuitStyles.C(), "")

# front end
plot(blk::AbstractBlock; kwargs...) = vizcircuit(blk; kwargs...)
function vizcircuit(blk::AbstractBlock; w_depth=0.85, w_line=0.75, scale=1.0)
circuit_canvas(nqubits(blk); w_depth=w_depth, w_line=w_line) do c
basicstyle(blk) >> c
end |> rescale(scale)
end

function circuit_canvas(f, nline::Int; w_depth=0.85, w_line=0.75)
c = CircuitGrid(nline; w_depth=w_depth, w_line=w_line)
g = canvas() do
f(c)
finalize!(c)
end
a, b = (depth(c)+1)*w_depth, nline*w_line
Compose.set_default_graphic_size(a*2.5*cm, b*2.5*cm)
compose(context(0.5/a, -0.5/b, 1/a, 1/b), g)
end

Base.:>>(blk::AbstractBlock{N}, c::CircuitGrid) where N = draw!(c, blk, collect(1:N))
Base.:>>(blk::Function, c::CircuitGrid) = blk(nline(c)) >> c

function rescale(factor)
a, b = Compose.default_graphic_width, Compose.default_graphic_height
Compose.set_default_graphic_size(a*factor, b*factor)
graph -> compose(context(), graph)
end

vizcircuit(; kwargs...) = c->vizcircuit(c; kwargs...)

function basicstyle(blk::AbstractBlock)
YaoBlocks.Optimise.simplify(blk, rules=[YaoBlocks.Optimise.to_basictypes])
end
1 change: 0 additions & 1 deletion src/zx_plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ using LightGraphs
using GraphPlot: gplot
using Colors
using ZXCalculus: qubit_loc
using Compose

export plot

Expand Down
19 changes: 5 additions & 14 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
using YaoPlots, ZXCalculus, LightGraphs
using Test

# @testset "YaoPlots.jl" begin
# # Write your tests here.
# end
@testset "vizcircuit" begin
include("vizcircuit.jl")
end

@testset "zx_plot.jl" begin
g = Multigraph(6)
for e in [[1,3],[2,3],[3,4],[4,5],[4,6]]
add_edge!(g, e)
end
ps = [0, 0, 0//1, 2//1, 0, 0]
v_t = [SpiderType.In, SpiderType.Out, SpiderType.X, SpiderType.Z, SpiderType.Out, SpiderType.In]
zxd = ZXDiagram(g, v_t, ps)
plot(zxd)
replace!(Rule{:b}(), zxd)
plot(zxd)
@testset "zx_plot" begin
include("zx_plot.jl")
end
39 changes: 39 additions & 0 deletions test/vizcircuit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using YaoPlots
using Compose
using Test
using YaoBlocks

@testset "gate styles" begin
@test YaoPlots.get_brush_text(X)[2] == "X"
@test YaoPlots.get_brush_text(Rx(0.5))[2] == "Rx(0.5)"
@test YaoPlots.get_brush_text(shift(0.5))[2] == "ϕ(0.5)"
@test YaoPlots.get_brush_text(YaoBlocks.phase(0.5))[2] == "0.5im"
end

@testset "circuit canvas" begin
c = CircuitGrid(5)
@test YaoPlots.nline(c) == 5
@test YaoPlots.frontier(c, 2, 3) == 0
@test YaoPlots.depth(c) == 0
circuit_canvas(5) do c
YaoPlots.draw!(c, put(5, 3=>X), 1:5)
@test YaoPlots.frontier(c, 1, 2) == 0
@test YaoPlots.frontier(c, 3, 5) == 1
@test YaoPlots.depth(c) == 1
end

gg = circuit_canvas(5) do c
put(3=>X) >> c
control(2, 3=>X) >> c
chain(5, control(2, 3=>X), put(1=>X)) >> c
@test YaoPlots.depth(c) == 3
end
@test gg isa Context

g = put(5, (3, 4)=>SWAP) |> vizcircuit(; scale=0.7, w_line=0.8, w_depth=0.9)
@test g isa Context
@test vizcircuit(put(5, (3,4)=>kron(X, Y)); scale=0.7, w_line=0.8, w_depth=0.9) isa Context

@test vizcircuit(control(10, (2, -3), 6=>X)) isa Context
@test plot(put(7, (2,3)=>matblock(randn(4,4)))) isa Context
end
16 changes: 16 additions & 0 deletions test/zx_plot.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using Test, YaoPlots
using ZXCalculus

@testset "zx plot" begin
g = Multigraph(6)
for e in [[1,3],[2,3],[3,4],[4,5],[4,6]]
add_edge!(g, e)
end
ps = [0, 0, 0//1, 2//1, 0, 0]
v_t = [SpiderType.In, SpiderType.Out, SpiderType.X, SpiderType.Z, SpiderType.Out, SpiderType.In]
zxd = ZXDiagram(g, v_t, ps)
plot(zxd)
replace!(Rule{:b}(), zxd)
plt = plot(zxd)
@test plt !== nothing
end

0 comments on commit 227cba8

Please sign in to comment.