diff --git a/src/measure.jl b/src/measure.jl index 612f9427b..ef4eb122c 100644 --- a/src/measure.jl +++ b/src/measure.jl @@ -16,14 +16,14 @@ function _measure(pl::AbstractMatrix, nshots::Int) return res end -YaoBase.measure(reg::ArrayReg{1}; nshots::Int=1) = _measure(reg |> probs, nshots) +YaoBase.measure(::ComputationalBasis, reg::ArrayReg{1}, ::AllLocs; nshots::Int=1) = _measure(reg |> probs, nshots) -function YaoBase.measure(reg::ArrayReg{B}; nshots::Int=1) where B +function YaoBase.measure(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs; nshots::Int=1) where B pl = dropdims(sum(reg |> rank3 .|> abs2, dims=2), dims=2) return _measure(pl, nshots) end -function YaoBase.measure_remove!(reg::ArrayReg{B}) where B +function YaoBase.measure_remove!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs) where B state = reg |> rank3 nstate = similar(reg.state, 1< abs2, dims=2), dims=2) @@ -37,7 +37,7 @@ function YaoBase.measure_remove!(reg::ArrayReg{B}) where B return res end -function YaoBase.measure!(reg::ArrayReg{B}) where B +function YaoBase.measure!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs) where B state = reg |> rank3 nstate = zero(state) res = measure_remove!(reg) @@ -49,7 +49,7 @@ function YaoBase.measure!(reg::ArrayReg{B}) where B return res end -function YaoBase.measure_collapseto!(reg::ArrayReg{B}; config::Integer=0) where B +function YaoBase.measure_collapseto!(::ComputationalBasis, reg::ArrayReg{B}, ::AllLocs; config::Integer=0) where B state = rank3(reg) M, N, B1 = size(state) nstate = zero(state) diff --git a/src/register.jl b/src/register.jl index b9cd7c7e0..bf80c51f0 100644 --- a/src/register.jl +++ b/src/register.jl @@ -11,6 +11,7 @@ export ArrayReg, nbatch, viewbatch, addbits!, + insert_qubits!, datatype, probs, reorder!, @@ -72,9 +73,13 @@ function ArrayReg{B}(raw::MT) where {B, T, MT <: AbstractMatrix{T}} return ArrayReg{B, T, MT}(raw) end -ArrayReg(raw::AbstractVector{<:Complex}) = ArrayReg(reshape(raw, :, 1)) -ArrayReg(raw::AbstractMatrix{<:Complex}) = ArrayReg{size(raw, 2)}(raw) -ArrayReg(raw::AbstractArray{<:Complex, 3}) = ArrayReg{size(raw, 3)}(reshape(raw, size(raw, 1), :)) +function _warn_type(raw) + eltype(raw) <: Complex || @warn "Input type of `ArrayReg` is not Complex, got $(eltype(raw))" +end + +ArrayReg(raw::AbstractVector) = (_warn_type(raw); ArrayReg(reshape(raw, :, 1))) +ArrayReg(raw::AbstractMatrix) = (_warn_type(raw); ArrayReg{size(raw, 2)}(raw)) +ArrayReg(raw::AbstractArray{<:Any, 3}) = (_warn_type(raw); ArrayReg{size(raw, 3)}(reshape(raw, size(raw, 1), :))) # bit literal # NOTE: batch size B and element type T are 1 and ComplexF64 by default @@ -138,6 +143,14 @@ function YaoBase.addbits!(r::ArrayReg, n::Int) return r end +function YaoBase.insert_qubits!(reg::ArrayReg{B}, loc::Int; nqubits::Int=1) where B + na = nactive(reg) + focus!(reg, 1:loc-1) + reg2 = join(zero_state(nqubits; nbatch=B), reg) |> relax! |> focus!((1:na+nqubits)...) + reg.state = reg2.state + reg +end + function YaoBase.probs(r::ArrayReg{1}) if size(r.state, 2) == 1 return vec(r.state .|> abs2) diff --git a/test/focus.jl b/test/focus.jl index cfa2dc5e7..efd9d28ee 100644 --- a/test/focus.jl +++ b/test/focus.jl @@ -35,6 +35,8 @@ end @test copy(reg) |> addbits!(2) |> nactive == 5 reg2 = copy(reg) |> addbits!(2) |> focus!(4,5) @test (reg2 |> measure_remove!; reg2) |> relax!(to_nactive=nqubits(reg2)) ≈ reg + + @test insert_qubits!(copy(reg), 2; nqubits=2) |> nactive == 5 end @testset "Focus 2" begin diff --git a/test/measure.jl b/test/measure.jl index fa4d4751d..e425a5070 100644 --- a/test/measure.jl +++ b/test/measure.jl @@ -12,7 +12,7 @@ using Test, YaoArrayRegister, YaoBase @test r3 ≈ r2 end -@testset "measure and reset/remove" begin +@testset "measure and collapseto/remove" begin reg = rand_state(4) res = measure_collapseto!(reg, (4,)) @test isnormalized(reg) @@ -24,4 +24,8 @@ end res = measure_remove!(reg) select(reg0, res) @test select(reg0, res) |> normalize! ≈ reg + + reg = rand_state(6, nbatch=5) |> focus!((1:5)...) + measure_collapseto!(reg, 1) + @test nactive(reg) == 5 end diff --git a/test/register.jl b/test/register.jl index 8f87ffa86..d829c14c1 100644 --- a/test/register.jl +++ b/test/register.jl @@ -4,7 +4,7 @@ using Test, YaoArrayRegister, BitBasis, LinearAlgebra @test ArrayReg{3}(rand(4, 6)) isa ArrayReg{3} @test_throws DimensionMismatch ArrayReg{2}(rand(4, 3)) @test_throws DimensionMismatch ArrayReg{2}(rand(5, 2)) - @test_throws MethodError ArrayReg(rand(4, 3)) + @test_logs (:warn, "Input type of `ArrayReg` is not Complex, got Float64") ArrayReg(rand(4, 3)) @test ArrayReg(rand(ComplexF64, 4, 3)) isa ArrayReg{3} @test ArrayReg(rand(ComplexF64, 4)) isa ArrayReg{1}