Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move some utilities in YaoExtensions to Yao.EasyBuild submodule. #315

Merged
merged 9 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/EasyBuild/block_extension/Bag.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
export AbstractBag
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
"""
AbstractBag{BT, N}<:TagBlock{BT, N}

Abstract `Bag` is a wrapper of a block that conserves all properties.
Including `mat`, `apply!`, `ishermitian`, `isreflexive`, `isunitary`,
`occupied_locs`, `apply_back!` and `mat_back!`.
"""
abstract type AbstractBag{BT, N}<:TagBlock{BT, N} end

YaoAPI.mat(::Type{T}, bag::AbstractBag{N}) where {T,N} = mat(T, content(bag))
_apply!(reg::AbstractRegister, bag::AbstractBag) = _apply!(reg, content(bag))
LinearAlgebra.ishermitian(bag::AbstractBag) = ishermitian(content(bag))
YaoAPI.isreflexive(bag::AbstractBag) = isreflexive(content(bag))
YaoAPI.isunitary(bag::AbstractBag) = isunitary(content(bag))
YaoAPI.occupied_locs(bag::AbstractBag) = occupied_locs(content(bag))

function YaoBlocks.AD.apply_back!(state, b::AbstractBag, collector)
YaoBlocks.AD.apply_back!(state, content(b), collector)
end
function YaoBlocks.AD.mat_back!(::Type{T}, b::AbstractBag, adjy, collector) where T
YaoBlocks.AD.mat_back!(T, content(b), adjy, collector)
end

export Bag, enable_block!, disable_block!, setcontent!, isenabled
"""
Bag{N}<:TagBlock{AbstractBlock, N}

A bag is a trivil container, but can
* `setcontent!(bag, content)`
* `disable_block!(bag)`
* `enable_block!(bag)`
"""
mutable struct Bag{N}<:AbstractBag{AbstractBlock, N}
content::AbstractBlock{N}
mask::Bool
end
Bag(b::AbstractBlock) = Bag(b,true)

YaoBlocks.content(bag::Bag{N}) where N = bag.mask ? bag.content : put(N, 1=>I2)
YaoBlocks.chcontent(bag::Bag, content) = Bag(content)
setcontent!(bag::Bag, content) = (bag.content = content; bag)
disable_block!(b::Bag) = (b.mask = false; b)
enable_block!(b::Bag) = (b.mask = true; b)
isenabled(b::Bag) = b.mask

function YaoBlocks.print_annotation(io::IO, bag::Bag)
printstyled(io, isenabled(bag) ? "[⊙] " : "[⊗] "; bold=true, color=isenabled(bag) ? :green : :red)
end

function Base.show(io::IO, ::MIME"plain/text", blk::Bag)
return print_tree(io, blk; title=false, compact=false)
end
33 changes: 33 additions & 0 deletions src/EasyBuild/block_extension/ConditionBlock.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using YaoBlocks: print_prefix
export ConditionBlock, condition

struct ConditionBlock{N, BTT<:AbstractBlock{N}, BTF<:AbstractBlock{N}} <: CompositeBlock{N}
m::Measure
block_true::BTT
block_false::BTF
function ConditionBlock(m::Measure, block_true::BTT, block_false::BTF) where {N, BTT<:AbstractBlock{N}, BTF<:AbstractBlock{N}}
new{N, BTT, BTF}(m, block_true, block_false)
end
end

YaoAPI.subblocks(c::ConditionBlock) = (c.m, c.block_true, c.block_false)
YaoAPI.chsubblocks(c::ConditionBlock, blocks) = ConditionBlock(blocks...)

function _apply!(reg::AbstractRegister{B}, c::ConditionBlock) where B
if !isdefined(c.m, :results)
println("Conditioned on a measurement that has not been performed.")
throw(UndefRefError())
end
for i = 1:B
viewbatch(reg, i) |> (c.m.results[i] == 0 ? c.block_false : c.block_true)
end
reg
end

condition(m, a::AbstractBlock{N}, b::Nothing) where N = ConditionBlock(m, a, eyeblock(N))
condition(m, a::Nothing, b::AbstractBlock{N}) where N = ConditionBlock(m, eyeblock(N), b)
YaoAPI.mat(c::ConditionBlock) = throw(ArgumentError("ConditionBlock does not has matrix representation, try `mat(c.block_true)` or `mat(c.block_false)`"))

function YaoBlocks.print_block(io::IO, c::ConditionBlock)
print(io, "if result(id = $(objectid(c.m)))")
end
17 changes: 17 additions & 0 deletions src/EasyBuild/block_extension/EchoBlock.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
export EchoBlock
struct EchoBlock{N,OT<:IO} <: PrimitiveBlock{N}
sym::Symbol
io::OT
end

EchoBlock(nbits::Int, sym::Symbol) = EchoBlock{nbits, typeof(stdout)}(sym, stdout)
EchoBlock(nbits::Int) = EchoBlock(nbits, :ECHO)
EchoBlock() = nbits->EchoBlock(nbits)

_apply!(reg::AbstractRegister, ec::EchoBlock) = (println(ec.io, "apply!(::$(typeof(reg)), $(ec.sym))"); reg)
YaoAPI.mat(::Type{T}, ec::EchoBlock{N}) where {T,N} = (println(ec.io, "mat(::Type{$T}, $(ec.sym))"); IMatrix{1<<N}())
LinearAlgebra.ishermitian(ec::EchoBlock{N}) where N = (println(ec.io, "ishermitian($(ec.sym))"); true)
YaoAPI.isunitary(ec::EchoBlock{N}) where N = (println(ec.io, "isunitary($(ec.sym))"); true)
YaoAPI.isreflexive(ec::EchoBlock{N}) where N = (println(ec.io, "isreflexive($(ec.sym))"); true)
YaoAPI.getiparams(ec::EchoBlock{N}) where N = (println(ec.io, "getiparams($(ec.sym))");())
YaoBlocks.print_block(io::IO, ec::EchoBlock) = print(io, "EchoBlock($(ec.sym))")
31 changes: 31 additions & 0 deletions src/EasyBuild/block_extension/FSimGate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
export FSimGate

# https://arxiv.org/pdf/1711.04789.pdf
# Google supremacy paper
mutable struct FSimGate{T<:Number} <: PrimitiveBlock{2}
theta::T
phi::T
end

function Base.:(==)(fs1::FSimGate, fs2::FSimGate)
return fs1.theta == fs2.theta && fs1.phi == fs2.phi
end

function YaoAPI.mat(::Type{T}, fs::FSimGate) where T
θ, ϕ = fs.theta, fs.phi
T[1 0 0 0;
0 cos(θ) -im*sin(θ) 0;
0 -im*sin(θ) cos(θ) 0;
0 0 0 exp(-im*ϕ)]
end

YaoAPI.iparams_eltype(::FSimGate{T}) where T = T
YaoAPI.getiparams(fs::FSimGate{T}) where T = (fs.theta, fs.phi)
function YaoAPI.setiparams!(fs::FSimGate{T}, θ, ϕ) where T
fs.theta = θ
fs.phi = ϕ
return fs
end

YaoBlocks.@dumpload_fallback FSimGate FSimGate
YaoBlocks.Optimise.to_basictypes(fs::FSimGate) = fsim_gate(fs.theta, fs.phi)
90 changes: 90 additions & 0 deletions src/EasyBuild/block_extension/RotBasis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
export RotBasis, randpolar, polar2u, u2polar, rot_basis, basis_rotor

"""
RotBasis{T} <: PrimitiveBlock{1, Complex{T}}

A special rotation block that transform basis to angle θ and ϕ in bloch sphere.
"""
mutable struct RotBasis{T} <: PrimitiveBlock{1}
theta::T
phi::T
end

_make_rot_mat(I, block, theta) = I * cos(theta / 2) - im * sin(theta / 2) * block
# chain -> *
# mat(rb::RotBasis{T}) where T = mat(Ry(-rb.theta))*mat(Rz(-rb.phi))
function YaoAPI.mat(::Type{TM}, x::RotBasis{T}) where {TM, T}
R1 = _make_rot_mat(IMatrix{2, Complex{T}}(), mat(TM, Z), -x.phi)
R2 = _make_rot_mat(IMatrix{2, Complex{T}}(), mat(TM, Y), -x.theta)
R2 * R1
end

Base.:(==)(rb1::RotBasis, rb2::RotBasis) = rb1.theta == rb2.theta && rb1.phi == rb2.phi

Base.copy(block::RotBasis{T}) where T = RotBasis{T}(block.theta, block.phi)
YaoAPI.dispatch!(block::RotBasis, params::Vector) = ((block.theta, block.phi) = params; block)

YaoAPI.getiparams(rb::RotBasis) = (rb.theta, rb.phi)
function YaoAPI.setiparams!(rb::RotBasis, θ::Real, ϕ::Real)
rb.theta, rb.phi = θ, ϕ
rb
end
YaoAPI.niparams(::Type{<:RotBasis}) = 2
YaoAPI.niparams(::RotBasis) = 2
YaoBlocks.render_params(r::RotBasis, ::Val{:random}) = rand()*π, rand()*2π

function YaoBlocks.print_block(io::IO, R::RotBasis)
print(io, "RotBasis($(R.theta), $(R.phi))")
end

function Base.hash(gate::RotBasis, h::UInt)
hash(hash(gate.theta, gate.phi, objectid(gate)), h)
end

YaoBlocks.cache_key(gate::RotBasis) = (gate.theta, gate.phi)

rot_basis(num_bit::Int) = dispatch!(chain(num_bit, put(i=>RotBasis(0.0, 0.0)) for i=1:num_bit), randpolar(num_bit) |> vec)

"""
u2polar(vec::Array) -> Array

transform su(2) state vector to polar angle, apply to the first dimension of size 2.
"""
function u2polar(vec::Vector)
ratio = vec[2]/vec[1]
[atan(abs(ratio))*2, angle(ratio)]
end

"""
polar2u(vec::Array) -> Array

transform polar angle to su(2) state vector, apply to the first dimension of size 2.
"""
function polar2u(polar::Vector)
theta, phi = polar
[cos(theta/2)*exp(-im*phi/2), sin(theta/2)*exp(im*phi/2)]
end

u2polar(arr::Array) = mapslices(u2polar, arr, dims=[1])
polar2u(arr::Array) = mapslices(polar2u, arr, dims=[1])

"""
randpolar(params::Int...) -> Array

random polar basis, number of basis
"""
randpolar(params::Int...) = rand(2, params...)*pi

"""
basis_rotor(::ZGate) -> AbstractBlock
basis_rotor(basis::PauliGate, nbit, locs) -> AbstractBlock

Return a block to rotate the basis to pauli basis for measurements.
"""
basis_rotor(::ZGate) = I2Gate()
basis_rotor(::XGate) = Ry(-0.5π)
basis_rotor(::YGate) = Rx(0.5π)

basis_rotor(basis::YaoBlocks.PauliGate, nbit, locs) = repeat(nbit, basis_rotor(basis), locs)

YaoBlocks.@dumpload_fallback RotBasis RotBasis
9 changes: 9 additions & 0 deletions src/EasyBuild/block_extension/blocks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import .YaoBlocks: _apply!
include("shortcuts.jl")
include("EchoBlock.jl")
include("Bag.jl")
include("ConditionBlock.jl")
include("RotBasis.jl")
include("reflect_gate.jl")
include("pauli_strings.jl")
include("FSimGate.jl")
107 changes: 107 additions & 0 deletions src/EasyBuild/block_extension/pauli_strings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import YaoBlocks.YaoArrayRegister.StaticArrays: SizedVector
import YaoBlocks.YaoArrayRegister.StatsBase
using YaoBlocks: PauliGate
export PauliString

# TODO: expand to Clifford?
struct PauliString{N, BT <: ConstantGate{1}, VT <: SizedVector{N, BT}} <: CompositeBlock{N}
blocks::VT
PauliString(blocks::SizedVector{N, BT}) where {N, BT <: ConstantGate{1}} =
new{N, BT, typeof(blocks)}(blocks)
end

"""
PauliString(xs::PauliGate...)

Create a `PauliString` from some Pauli gates.

# Example

```julia
julia> PauliString(X, Y, Z)
nqubits: 3
PauliString
├─ X gate
├─ Y gate
└─ Z gate
```
"""
PauliString(xs::PauliGate...) = PauliString(SizedVector{length(xs), PauliGate}(xs))

"""
PauliString(list::Vector)

Create a `PauliString` from a list of Pauli gates.

# Example

```julia
julia> PauliString([X, Y, Z])
nqubits: 3
PauliString
├─ X gate
├─ Y gate
└─ Z gate
```
"""
function PauliString(xs::Vector)
for each in xs
if !(each isa PauliGate)
error("expect pauli gates")
end
end
return PauliString(SizedVector{length(xs), PauliGate}(xs))
end

YaoAPI.subblocks(ps::PauliString) = ps.blocks
YaoAPI.chsubblocks(pb::PauliString, blocks::Vector) = PauliString(blocks)
YaoAPI.chsubblocks(pb::PauliString, it) = PauliString(collect(it))

YaoAPI.occupied_locs(ps::PauliString) = (findall(x->!(x isa I2Gate), ps.blocks)...,)

YaoBlocks.cache_key(ps::PauliString) = map(cache_key, ps.blocks)

LinearAlgebra.ishermitian(::PauliString) = true
YaoAPI.isreflexive(::PauliString) = true
YaoAPI.isunitary(::PauliString) = true

Base.copy(ps::PauliString) = PauliString(copy(ps.blocks))
Base.getindex(ps::PauliString, x) = getindex(ps.blocks, x)
Base.lastindex(ps::PauliString) = lastindex(ps.blocks)
Base.iterate(ps::PauliString) = iterate(ps.blocks)
Base.iterate(ps::PauliString, st) = iterate(ps.blocks, st)
Base.length(ps::PauliString) = length(ps.blocks)
Base.eltype(ps::PauliString) = eltype(ps.blocks)
Base.eachindex(ps::PauliString) = eachindex(ps.blocks)
Base.getindex(ps::PauliString, index::Union{UnitRange, Vector}) =
PauliString(getindex(ps.blocks, index))
function Base.setindex!(ps::PauliString, v::PauliGate, index::Union{Int})
ps.blocks[index] = v
return ps
end

function Base.:(==)(lhs::PauliString{N}, rhs::PauliString{N}) where N
(length(lhs.blocks) == length(rhs.blocks)) && all(lhs.blocks .== rhs.blocks)
end

xgates(ps::PauliString{N}) where N = RepeatedBlock{N}(X, (findall(x->x isa XGate, (ps.blocks...,))...,))
ygates(ps::PauliString{N}) where N = RepeatedBlock{N}(Y, (findall(x->x isa YGate, (ps.blocks...,))...,))
zgates(ps::PauliString{N}) where N = RepeatedBlock{N}(Z, (findall(x->x isa ZGate, (ps.blocks...,))...,))

function _apply!(reg::AbstractRegister, ps::PauliString)
for pauligates in [xgates, ygates, zgates]
blk = pauligates(ps)
_apply!(reg, blk)
end
return reg
end

function YaoAPI.mat(::Type{T}, ps::PauliString) where T
return mat(T, xgates(ps)) * mat(T, ygates(ps)) * mat(T, zgates(ps))
end

function YaoBlocks.print_block(io::IO, x::PauliString)
printstyled(io, "PauliString"; bold=true, color=YaoBlocks.color(PauliString))
end

YaoBlocks.color(::Type{T}) where {T <: PauliString} = :cyan
Loading