From 8d9cb017daa4e415a61f59410260d6d5116c909b Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 25 Apr 2019 21:49:59 +0800 Subject: [PATCH 1/2] update measure --- src/measure.jl | 10 +++++----- src/register.jl | 19 ++++++++++++++++--- test/measure.jl | 6 +++++- test/register.jl | 2 +- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/measure.jl b/src/measure.jl index 612f9427b..ef4eb122c 100644 --- a/src/measure.jl +++ b/src/measure.jl @@ -16,14 +16,14 @@ function _measure(pl::AbstractMatrix, nshots::Int) return res end -YaoBase.measure(reg::ArrayReg{1}; nshots::Int=1) = _measure(reg |> probs, nshots) +YaoBase.measure(::ComputationalBasis, reg::ArrayReg{1}, ::AllLocs; nshots::Int=1) = _measure(reg |> probs, nshots) -function YaoBase.measure(reg::ArrayReg{B}; nshots::Int=1) where B +function YaoBase.measure(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs; nshots::Int=1) where B pl = dropdims(sum(reg |> rank3 .|> abs2, dims=2), dims=2) return _measure(pl, nshots) end -function YaoBase.measure_remove!(reg::ArrayReg{B}) where B +function YaoBase.measure_remove!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs) where B state = reg |> rank3 nstate = similar(reg.state, 1< abs2, dims=2), dims=2) @@ -37,7 +37,7 @@ function YaoBase.measure_remove!(reg::ArrayReg{B}) where B return res end -function YaoBase.measure!(reg::ArrayReg{B}) where B +function YaoBase.measure!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs) where B state = reg |> rank3 nstate = zero(state) res = measure_remove!(reg) @@ -49,7 +49,7 @@ function YaoBase.measure!(reg::ArrayReg{B}) where B return res end -function YaoBase.measure_collapseto!(reg::ArrayReg{B}; config::Integer=0) where B +function YaoBase.measure_collapseto!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs; config::Integer=0) where B state = rank3(reg) M, N, B1 = size(state) nstate = zero(state) diff --git a/src/register.jl b/src/register.jl index b9cd7c7e0..c4997828c 100644 --- a/src/register.jl +++ b/src/register.jl @@ -11,6 +11,7 @@ export ArrayReg, nbatch, viewbatch, addbits!, + insert_qubits!, datatype, probs, reorder!, @@ -72,9 +73,13 @@ function ArrayReg{B}(raw::MT) where {B, T, MT <: AbstractMatrix{T}} return ArrayReg{B, T, MT}(raw) end -ArrayReg(raw::AbstractVector{<:Complex}) = ArrayReg(reshape(raw, :, 1)) -ArrayReg(raw::AbstractMatrix{<:Complex}) = ArrayReg{size(raw, 2)}(raw) -ArrayReg(raw::AbstractArray{<:Complex, 3}) = ArrayReg{size(raw, 3)}(reshape(raw, size(raw, 1), :)) +function _warn_type(raw) + eltype(raw) <: Complex || @warn "Input type of `ArrayReg` is not Complex, got $(eltype(raw))" +end + +ArrayReg(raw::AbstractVector) = (_warn_type(raw); ArrayReg(reshape(raw, :, 1))) +ArrayReg(raw::AbstractMatrix) = (_warn_type(raw); ArrayReg{size(raw, 2)}(raw)) +ArrayReg(raw::AbstractArray{<:Any, 3}) = (_warn_type(raw); ArrayReg{size(raw, 3)}(reshape(raw, size(raw, 1), :))) # bit literal # NOTE: batch size B and element type T are 1 and ComplexF64 by default @@ -138,6 +143,14 @@ function YaoBase.addbits!(r::ArrayReg, n::Int) return r end +function YaoBase.insert_qubits!(reg::ArrayReg{B}, loc::Int; nbit::Int=1) where B + na = nactive(reg) + focus!(reg, 1:loc-1) + reg2 = join(zero_state(nbit, B), reg) |> relax! |> focus!((1:na+nbit)...) + reg.state = reg2.state + reg +end + function YaoBase.probs(r::ArrayReg{1}) if size(r.state, 2) == 1 return vec(r.state .|> abs2) diff --git a/test/measure.jl b/test/measure.jl index fa4d4751d..e425a5070 100644 --- a/test/measure.jl +++ b/test/measure.jl @@ -12,7 +12,7 @@ using Test, YaoArrayRegister, YaoBase @test r3 ≈ r2 end -@testset "measure and reset/remove" begin +@testset "measure and collapseto/remove" begin reg = rand_state(4) res = measure_collapseto!(reg, (4,)) @test isnormalized(reg) @@ -24,4 +24,8 @@ end res = measure_remove!(reg) select(reg0, res) @test select(reg0, res) |> normalize! ≈ reg + + reg = rand_state(6, nbatch=5) |> focus!((1:5)...) + measure_collapseto!(reg, 1) + @test nactive(reg) == 5 end diff --git a/test/register.jl b/test/register.jl index 8f87ffa86..d829c14c1 100644 --- a/test/register.jl +++ b/test/register.jl @@ -4,7 +4,7 @@ using Test, YaoArrayRegister, BitBasis, LinearAlgebra @test ArrayReg{3}(rand(4, 6)) isa ArrayReg{3} @test_throws DimensionMismatch ArrayReg{2}(rand(4, 3)) @test_throws DimensionMismatch ArrayReg{2}(rand(5, 2)) - @test_throws MethodError ArrayReg(rand(4, 3)) + @test_logs (:warn, "Input type of `ArrayReg` is not Complex, got Float64") ArrayReg(rand(4, 3)) @test ArrayReg(rand(ComplexF64, 4, 3)) isa ArrayReg{3} @test ArrayReg(rand(ComplexF64, 4)) isa ArrayReg{1} From adaab563a9d14367155ebb9bff84ba9dc102e948 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 25 Apr 2019 22:49:53 +0800 Subject: [PATCH 2/2] update measure --- src/register.jl | 4 ++-- test/focus.jl | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/register.jl b/src/register.jl index c4997828c..bf80c51f0 100644 --- a/src/register.jl +++ b/src/register.jl @@ -143,10 +143,10 @@ function YaoBase.addbits!(r::ArrayReg, n::Int) return r end -function YaoBase.insert_qubits!(reg::ArrayReg{B}, loc::Int; nbit::Int=1) where B +function YaoBase.insert_qubits!(reg::ArrayReg{B}, loc::Int; nqubits::Int=1) where B na = nactive(reg) focus!(reg, 1:loc-1) - reg2 = join(zero_state(nbit, B), reg) |> relax! |> focus!((1:na+nbit)...) + reg2 = join(zero_state(nqubits; nbatch=B), reg) |> relax! |> focus!((1:na+nqubits)...) reg.state = reg2.state reg end diff --git a/test/focus.jl b/test/focus.jl index cfa2dc5e7..efd9d28ee 100644 --- a/test/focus.jl +++ b/test/focus.jl @@ -35,6 +35,8 @@ end @test copy(reg) |> addbits!(2) |> nactive == 5 reg2 = copy(reg) |> addbits!(2) |> focus!(4,5) @test (reg2 |> measure_remove!; reg2) |> relax!(to_nactive=nqubits(reg2)) ≈ reg + + @test insert_qubits!(copy(reg), 2; nqubits=2) |> nactive == 5 end @testset "Focus 2" begin