Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Jun 5, 2018
2 parents 2abc58f + b6446b2 commit a5ce710
Show file tree
Hide file tree
Showing 20 changed files with 222 additions and 41 deletions.
2 changes: 1 addition & 1 deletion example/test_qft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ reg = rand_state(num_bit)
rv = copy(statevec(reg))

rotgate(gate::AbstractMatrix, θ::Real) = expm(-0.5im*θ*Matrix(gate))
apply2mat(applyfunc!::Function, num_bit::Int) = applyfunc!(eye(Complex128, 1<<num_bit))
apply2mat(applyfunc!::Function, num_bit::Int) = applyfunc!(eye(ComplexF64, 1<<num_bit))

@testset "fft" begin
@test Matrix(mat(chain(IQFT(3), QFT(3)))) eye(1<<3)
Expand Down
2 changes: 2 additions & 0 deletions src/Blocks/IOSyntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ end
# Color Traits
color(::Type{T}) where {T <: Roller} = :cyan
color(::Type{T}) where {T <: KronBlock} = :cyan
color(::Type{T}) where {T <: RepeatedBlock} = :cyan
color(::Type{T}) where {T <: ChainBlock} = :blue
color(::Type{T}) where {T <: ControlBlock} = :red
color(::Type{T}) where {T <: Swap} = :magenta

# Default Charset
BlockTreeCharSet() = BlockTreeCharSet('','','','')
Expand Down
7 changes: 1 addition & 6 deletions src/Blocks/Primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@ method to enable key value cache.
"""
abstract type PrimitiveBlock{N, T} <: MatrixBlock{N, T} end

# include("ConstantGate.jl")
include("ConstGate.jl")
include("PhaseGate.jl")
include("RotationGate.jl")
# include("CachedBlock.jl")

# TODO:
# 1. new Primitive: SWAP gate
# include("SwapGate.jl")
include("SwapGate.jl")
15 changes: 13 additions & 2 deletions src/Blocks/Repeated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ mutable struct RepeatedBlock{N, T, GT<:MatrixBlock} <: CompositeBlock{N, T}
block::GT
lines::Vector{Int}

function RepeatedBlock{N, T}(block::GT) where {N, T, GT <: MatrixBlock}
function RepeatedBlock{N}(block::GT) where {N, M, T, GT <: MatrixBlock{M, T}}
new{N, T, GT}(block, Vector{Int}(1:N))
end

function RepeatedBlock{N, T}(block::GT, lines::Vector{Int}) where {N, T, GT <: MatrixBlock}
function RepeatedBlock{N}(block::GT, lines::Vector{Int}) where {N, M, T, GT <: MatrixBlock{M, T}}
new{N, T, GT}(block, lines)
end
end
Expand Down Expand Up @@ -40,3 +40,14 @@ end
function ==(lhs::RepeatedBlock{N, T, GT}, rhs::RepeatedBlock{N, T, GT}) where {N, T, GT}
(lhs.block == rhs.block) && (lhs.lines == rhs.lines)
end

function print_block(io::IO, rb::RepeatedBlock{N}) where N
printstyled(io, "repeat on ("; bold=true, color=color(RepeatedBlock))
for i in eachindex(rb.lines)
printstyled(io, rb.lines[i]; bold=true, color=color(RepeatedBlock))
if i != lastindex(rb.lines)
printstyled(io, ", "; bold=true, color=color(RepeatedBlock))
end
end
printstyled(io, ")"; bold=true, color=color(RepeatedBlock))
end
81 changes: 77 additions & 4 deletions src/Blocks/SwapGate.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,82 @@
export Swap

struct Swap{N, T} <: PrimitiveBlock{N, T}
line1::UInt
line2::UInt
end
line1::Int
line2::Int

Swap(::Type{T}, N::Int, addr1, addr2) where T = Swap{N, T}(addr1, addr2)
Swap{N, T}(line1::Int, line2::Int) where {N, T} = new{N, T}(line1, line2)
end

function mat(g::Swap{N, T}) where {N, T}
mask = bmask(g.line1, g.line2)
order = map(b->swapbits(b, mask) + 1, basis(N))
PermMatrix(order, ones(T, 1<<N))
end

function apply!(r::AbstractRegister{1}, rb::Swap)
if nremains(r) == 0
swapapply!(vec(state(r)), rb.line1, rb.line2)
else
swapapply!(state(r), rb.line1, rb.line2)
end
end

function apply!(r::AbstractRegister, rb::Swap)
swapapply!(state(r), rb.line1, rb.line2)
end

function swapapply!(state::Matrix{T}, b1::Int, b2::Int) where T
mask1 = bmask(b1)
mask2 = bmask(b2)
mask12 = mask1|mask2
M, N = size(state)

@simd for b = basis(state)
local temp::T
local i_::Int
if b&mask1==0 && b&mask2==mask2
i = b+1
i_ = b mask12 + 1
@simd for c = 1:N
@inbounds temp = state[i, c]
@inbounds state[i, c] = state[i_, c]
@inbounds state[i_, c] = temp
end
end
end
state
end

function swapapply!(state::Vector{T}, b1::Int, b2::Int) where T
mask1 = bmask(b1)
mask2 = bmask(b2)
mask12 = mask1|mask2
M = length(state)

@simd for b = basis(state)
local temp::T
local i_::Int
if b&mask1==0 && b&mask2==mask2
i = b+1
i_ = b mask12 + 1
@inbounds temp = state[i]
@inbounds state[i] = state[i_]
@inbounds state[i_] = temp
end
end
state
end

function hash(swap::Swap, h::UInt)
hashkey = hash(swap.line1, h)
hashkey = hash(swap.line2, hashkey)
hashkey
end

function ==(lhs::Swap, rhs::Swap)
(lhs.line1 == rhs.line1) && (lhs.line2 == rhs.line2)
end

function print_block(io::IO, swap::Swap)
printstyled(io, "swap(", swap.line1, ", ", swap.line2, ")"; bold=true, color=color(Swap))
end
2 changes: 2 additions & 0 deletions src/Boost/Boost.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ export xgate, ygate, zgate
export cxgate, cygate, czgate
export controlled_U1

export xapply!, yapply!, zapply!
export cxapply!, cyapply!, czapply!

include("gates.jl")
include("applys.jl")
Expand Down
2 changes: 0 additions & 2 deletions src/Boost/gates.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
####################### Gate Utilities ######################


###################### X, Y, Z Gates ######################
"""
xgate(::Type{MT}, num_bit::Int, bits::Ints) -> PermMatrix
Expand Down Expand Up @@ -166,7 +165,6 @@ function controlled_U1(num_bit::Int, gate::AbstractMatrix, cbits::Vector{Int}, c
general_controlled_gates(num_bit, [c==1 ? mat(P1) : mat(P0) for c in cvals], cbits, [gate], [b2])
end


# arbituary control PermMatrix gate: SparseMatrixCSC
# TODO: to interface
#toffoligate(num_bit::Int, b1::Int, b2::Int, b3::Int) = controlled_U1(num_bit, PAULI_X, [b1, b2], b3)
49 changes: 41 additions & 8 deletions src/Interfaces/Composite.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,40 @@
function parse_block(n::Int, x::Pair{Int, BT}) where {BT <: MatrixBlock}
kron(n, x)
end

function parse_block(n::Int, x::Pair{I, BT}) where {I, BT <: MatrixBlock}
kron(n, i=>x.second for i in x.first)
end

function parse_block(n::Int, x::Pair{Int, BT}) where {BT <: ConstantGate}
repeat(n, x.second, [x.first, ])
end

function parse_block(n::Int, x::Pair{I, BT}) where {I, BT <: ConstantGate}
repeat(n, x.second, collect(x.first))
end

function parse_block(n::Int, x::Function)
x(n)
end

function parse_block(n::Int, x::MatrixBlock{N}) where N
n == N || throw(ArgumentError("number of qubits does not match: $x"))
x
end

# 2. composite blocks
# 2.1 chain block
export chain

chain(n::Int) = ChainBlock(MatrixBlock{n}[])
chain() = n -> chain(n)

function chain(n, blocks)
_2block(x::Function) = x(n)
_2block(x::MatrixBlock) = x

if blocks isa Union{Function, MatrixBlock}
ChainBlock([_2block(blocks)])
function chain(n::Int, blocks)
if blocks isa Union{Function, MatrixBlock, Pair}
ChainBlock([parse_block(n, blocks)])
else
ChainBlock(MatrixBlock{n}[_2block(each) for each in blocks])
ChainBlock(MatrixBlock{n}[parse_block(n, each) for each in blocks])
end
end

Expand All @@ -22,8 +44,12 @@ function chain(blocks::Vector{MatrixBlock{N}}) where N
ChainBlock(Vector{MatrixBlock{N}}(blocks))
end

function chain(n, blocks...)
ChainBlock(MatrixBlock{n}[parse_block(n, each) for each in blocks])
end

function chain(blocks::MatrixBlock{N}...) where N
ChainBlock(blocks...)
ChainBlock(collect(MatrixBlock{N}, blocks))
end

Base.getindex(::typeof(chain), xs...) = ChainBlock(xs...)
Expand Down Expand Up @@ -102,3 +128,10 @@ function roll(N::Int, blocks::MatrixBlock...)
end

roll(blocks::MatrixBlock...) = n->roll(n, blocks...)

# 2.5 repeat

import Base: repeat
repeat(n::Int, x::MatrixBlock, lines) = RepeatedBlock{n}(x, lines)
repeat(n::Int, x::MatrixBlock) = RepeatedBlock{n}(x)
repeat(x::MatrixBlock, params...) = n->repeat(n, x, params...)
7 changes: 6 additions & 1 deletion src/Interfaces/Primitive.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export H, phase, shift, Rx, Ry, Rz, rot
export H, phase, shift, Rx, Ry, Rz, rot, swap

include("PauliGates.jl")

Expand Down Expand Up @@ -83,3 +83,8 @@ function rot end

rot(::Type{T}, U::GT, theta=0.0) where {T, GT} = RotationGate{real(T), GT}(U, theta)
rot(U::MatrixBlock, theta=0.0) = rot(DefaultType, U, theta)

swap(n::Int, ::Type{T}, line1::Int, line2::Int) where T = Swap{n, T}(line1, line2)
swap(::Type{T}, line1::Int, line2::Int) where T = n -> swap(n, T, line1, line2)
swap(n::Int, line1::Int, line2::Int) = Swap{n, DefaultType}(line1, line2)
swap(line1::Int, line2::Int) = n->swap(n, line1, line2)
15 changes: 9 additions & 6 deletions src/Intrinsics/Basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ const DInts = Union{Vector{DInt}, DInt, UnitRange{DInt}}
Returns the UnitRange for basis in Hilbert Space of num_bit qubits.
"""
basis(num_bit::Int) = UnitRange{DInt}(0, 1<<num_bit-1)
basis(state::AbstractArray)::UnitRange{DInt} = UnitRange{DInt}(0, size(state, 1)-1)


########## BitArray views ###################
import Base: BitArray
Expand Down Expand Up @@ -94,15 +96,16 @@ Return an integer with all bits flipped (with total number of bit `num_bit`).
neg(index::DInt, num_bit::Int)::DInt = bmask(1:num_bit) index

"""
swapbits(num::Int, i::Int, j::Int) -> Int
swapbits(num::Int, mask12::Int) -> Int
Return an integer with bits at `i` and `j` flipped.
"""
function swapbits(num::DInt, i::Int, j::Int)::DInt
i = i-1
j = j-1
k = (num >> j) & 1 - (num >> i) & 1
num + k*(1<<i) - k*(1<<j)
function swapbits(b::Int, mask12::Int)::Int
bm = b&mask12
if bm!=0 && bm!=mask12
b ⊻= mask12
end
b
end

"""
Expand Down
2 changes: 2 additions & 0 deletions src/Registers/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ zero_state(n::Int, nbatch::Int=1) = register(n, nbatch, :zero)
rand_state(n::Int, nbatch::Int=1) = register(n, nbatch, :rand)
randn_state(n::Int, nbatch::Int=1) = register(n, nbatch, :randn)

basis(r::AbstractRegister) = basis(nqubits(r))

# function ghz(num_bit::Int; x::DInt=zero(DInt))
# v = zeros(DefaultType, 1<<num_bit)
# v[x+1] = 1/sqrt(2)
Expand Down
3 changes: 2 additions & 1 deletion src/Registers/Registers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import Base: show

# import package APIs
import ..Yao: DefaultType, nqubits, address, state, focus!
import ..Intrinsics: basis

# APIs
export nqubits, nactive, nremain, nbatch, address, state, statevec, focus!
export nqubits, nactive, nremain, nbatch, address, state, statevec, focus!, basis
export AbstractRegister, Register

# factories
Expand Down
2 changes: 1 addition & 1 deletion test/Blocks/Control.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ end
end

@testset "inverse control" begin
g = ControlBlock{2}([-1, ], X, 2)
g = ControlBlock{2}([-1, ], X, 2)

op = U mat(P0) + IMatrix(U) mat(P1)
@test mat(g) op
Expand Down
4 changes: 4 additions & 0 deletions test/Blocks/Primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ end
@testset "Rotation Gate" begin
include("RotationGate.jl")
end

@testset "Swap Gate" begin
include("SwapGate.jl")
end
16 changes: 16 additions & 0 deletions test/Blocks/SwapGate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using Compat
using Compat.Test

using Yao.Blocks
using Yao.LuxurySparse
import Yao.Blocks: swapapply!

apply2mat(applyfunc!::Function, num_bit::Int) = applyfunc!(Matrix{ComplexF64}(I, 1<<num_bit, 1<<num_bit))

@testset "matrix" begin
@test mat(Swap{2, ComplexF64}(1, 2)) PermMatrix([1, 3, 2, 4], ones(1<<2))
end

@testset "apply" begin
@test mat(Swap{4, ComplexF64}(1, 3)) apply2mat(s->swapapply!(s, 1,3), 4)
end
4 changes: 2 additions & 2 deletions test/Boost/Binding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ using Yao.Boost
@testset "Repeated" begin
for G in [:X, :Y, :Z]
@eval begin
rb = RepeatedBlock{2, Complex128}($G, [1,2])
rb = RepeatedBlock{2}($G, [1,2])
@test mat(rb) kron(mat($G), mat($G))
end
end
end

@testset "Multiple Control" begin
mcb = ControlBlock{3, Complex128}(X, [3, 2], 1)
mcb = ControlBlock{3}([3, 2], 1=>X)
@test mat(mcb) mat(Toffoli)
end
14 changes: 14 additions & 0 deletions test/Boost/Boost.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Compat
using Compat.Test

@testset "binding" begin
include("Binding.jl")
end

@testset "applys" begin
include("applys.jl")
end

@testset "gates" begin
include("Gates.jl")
end
Loading

0 comments on commit a5ce710

Please sign in to comment.