Skip to content

Commit

Permalink
update measure
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Apr 25, 2019
1 parent 4f64983 commit 8d9cb01
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
10 changes: 5 additions & 5 deletions src/measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ function _measure(pl::AbstractMatrix, nshots::Int)
return res
end

YaoBase.measure(reg::ArrayReg{1}; nshots::Int=1) = _measure(reg |> probs, nshots)
YaoBase.measure(::ComputationalBasis, reg::ArrayReg{1}, ::AllLocs; nshots::Int=1) = _measure(reg |> probs, nshots)

function YaoBase.measure(reg::ArrayReg{B}; nshots::Int=1) where B
function YaoBase.measure(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs; nshots::Int=1) where B
pl = dropdims(sum(reg |> rank3 .|> abs2, dims=2), dims=2)
return _measure(pl, nshots)
end

function YaoBase.measure_remove!(reg::ArrayReg{B}) where B
function YaoBase.measure_remove!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs) where B
state = reg |> rank3
nstate = similar(reg.state, 1<<nremain(reg), B)
pl = dropdims(sum(state .|> abs2, dims=2), dims=2)
Expand All @@ -37,7 +37,7 @@ function YaoBase.measure_remove!(reg::ArrayReg{B}) where B
return res
end

function YaoBase.measure!(reg::ArrayReg{B}) where B
function YaoBase.measure!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs) where B
state = reg |> rank3
nstate = zero(state)
res = measure_remove!(reg)
Expand All @@ -49,7 +49,7 @@ function YaoBase.measure!(reg::ArrayReg{B}) where B
return res
end

function YaoBase.measure_collapseto!(reg::ArrayReg{B}; config::Integer=0) where B
function YaoBase.measure_collapseto!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs; config::Integer=0) where B
state = rank3(reg)
M, N, B1 = size(state)
nstate = zero(state)
Expand Down
19 changes: 16 additions & 3 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export ArrayReg,
nbatch,
viewbatch,
addbits!,
insert_qubits!,
datatype,
probs,
reorder!,
Expand Down Expand Up @@ -72,9 +73,13 @@ function ArrayReg{B}(raw::MT) where {B, T, MT <: AbstractMatrix{T}}
return ArrayReg{B, T, MT}(raw)
end

ArrayReg(raw::AbstractVector{<:Complex}) = ArrayReg(reshape(raw, :, 1))
ArrayReg(raw::AbstractMatrix{<:Complex}) = ArrayReg{size(raw, 2)}(raw)
ArrayReg(raw::AbstractArray{<:Complex, 3}) = ArrayReg{size(raw, 3)}(reshape(raw, size(raw, 1), :))
function _warn_type(raw)
eltype(raw) <: Complex || @warn "Input type of `ArrayReg` is not Complex, got $(eltype(raw))"
end

ArrayReg(raw::AbstractVector) = (_warn_type(raw); ArrayReg(reshape(raw, :, 1)))
ArrayReg(raw::AbstractMatrix) = (_warn_type(raw); ArrayReg{size(raw, 2)}(raw))
ArrayReg(raw::AbstractArray{<:Any, 3}) = (_warn_type(raw); ArrayReg{size(raw, 3)}(reshape(raw, size(raw, 1), :)))

# bit literal
# NOTE: batch size B and element type T are 1 and ComplexF64 by default
Expand Down Expand Up @@ -138,6 +143,14 @@ function YaoBase.addbits!(r::ArrayReg, n::Int)
return r
end

function YaoBase.insert_qubits!(reg::ArrayReg{B}, loc::Int; nbit::Int=1) where B
na = nactive(reg)
focus!(reg, 1:loc-1)
reg2 = join(zero_state(nbit, B), reg) |> relax! |> focus!((1:na+nbit)...)
reg.state = reg2.state
reg
end

function YaoBase.probs(r::ArrayReg{1})
if size(r.state, 2) == 1
return vec(r.state .|> abs2)
Expand Down
6 changes: 5 additions & 1 deletion test/measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using Test, YaoArrayRegister, YaoBase
@test r3 r2
end

@testset "measure and reset/remove" begin
@testset "measure and collapseto/remove" begin
reg = rand_state(4)
res = measure_collapseto!(reg, (4,))
@test isnormalized(reg)
Expand All @@ -24,4 +24,8 @@ end
res = measure_remove!(reg)
select(reg0, res)
@test select(reg0, res) |> normalize! reg

reg = rand_state(6, nbatch=5) |> focus!((1:5)...)
measure_collapseto!(reg, 1)
@test nactive(reg) == 5
end
2 changes: 1 addition & 1 deletion test/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Test, YaoArrayRegister, BitBasis, LinearAlgebra
@test ArrayReg{3}(rand(4, 6)) isa ArrayReg{3}
@test_throws DimensionMismatch ArrayReg{2}(rand(4, 3))
@test_throws DimensionMismatch ArrayReg{2}(rand(5, 2))
@test_throws MethodError ArrayReg(rand(4, 3))
@test_logs (:warn, "Input type of `ArrayReg` is not Complex, got Float64") ArrayReg(rand(4, 3))

@test ArrayReg(rand(ComplexF64, 4, 3)) isa ArrayReg{3}
@test ArrayReg(rand(ComplexF64, 4)) isa ArrayReg{1}
Expand Down

0 comments on commit 8d9cb01

Please sign in to comment.