Skip to content

Commit

Permalink
[WIP] Fix transpose copy (#39)
Browse files Browse the repository at this point in the history
* fix copy transpose

* keep transpose

* fix test
  • Loading branch information
GiggleLiu authored and Roger-luo committed Dec 1, 2019
1 parent 578babc commit 376ae24
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ os:
- osx
julia:
- 1.0
- 1.2
- 1.3
- nightly
matrix:
allow_failures:
Expand Down
53 changes: 52 additions & 1 deletion src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

using LinearAlgebra

export isnormalized, normalize!
export isnormalized, normalize!, regadd!, regsub!, regscale!

"""
isnormalized(r::ArrayReg) -> Bool
Expand Down Expand Up @@ -37,12 +37,46 @@ for op in [:+, :-]
return ArrayReg(($op)(state(lhs), state(rhs)))
end

@eval function Base.$op(lhs::ArrayReg{B,T1,<:Transpose}, rhs::ArrayReg{B,T2,<:Transpose}) where {B,T1,T2}
return ArrayReg(transpose(($op)(state(lhs).parent, state(rhs).parent)))
end

@eval function Base.$op(lhs::AdjointArrayReg{B}, rhs::AdjointArrayReg{B}) where {B}
r = $op(parent(lhs), parent(rhs))
return adjoint(r)
end
end

function regadd!(lhs::ArrayReg{B}, rhs::ArrayReg{B}) where {B}
lhs.state .+= rhs.state
lhs
end

function regsub!(lhs::ArrayReg{B}, rhs::ArrayReg{B}) where {B}
lhs.state .-= rhs.state
lhs
end

function regadd!(lhs::ArrayReg{B,T1,<:Transpose}, rhs::ArrayReg{B,T2,<:Transpose}) where {B,T1,T2}
lhs.state.parent .+= rhs.state.parent
lhs
end

function regsub!(lhs::ArrayReg{B,T1,<:Transpose}, rhs::ArrayReg{B,T2,<:Transpose}) where {B,T1,T2}
lhs.state.parent .-= rhs.state.parent
lhs
end

function regscale!(reg::ArrayReg{B,T1,<:Transpose}, x) where {B,T1}
reg.state.parent .*= x
reg
end

function regscale!(reg::ArrayReg{B}, x) where {B,T1}
reg.state .*= x
reg
end

# *, /
for op in [:*, :/]
@eval function Base.$op(lhs::RT, rhs::Number) where {B,RT<:ArrayReg{B}}
Expand Down Expand Up @@ -85,6 +119,23 @@ function Base.:*(bra::AdjointArrayReg{1}, ket::ArrayReg{1})
end

Base.:*(bra::AdjointArrayReg{B}, ket::ArrayReg{B}) where {B} = bra .* ket
function Base.:*(bra::AdjointArrayReg{B,T1,<:Transpose}, ket::ArrayReg{B,T2,<:Transpose}) where {B,T1,T2}
if nremain(bra) == nremain(ket) == 0 # all active
A, C = parent(state(parent(bra))), parent(state(ket))
res = zeros(eltype(promote_type(T1, T2)), B)
#return mapreduce((x, y) -> conj(x) * y, +, ; dims=2)
for j=1:size(A, 2)
for i=1:size(A, 1)
@inbounds res[i] += conj(A[i, j]) * C[i, j]
end
end
res
elseif nremain(bra) == 0 # <s|active> |remain>
bra .* ket
else
error("partially contract ⟨bra|ket⟩ is not supported, expect ⟨bra| to be fully actived. nactive(bra)/nqubits(bra)=$(nactive(bra))/$(nqubits(bra))")
end
end

# broadcast
broadcastable(r::ArrayRegOrAdjointArrayReg{1}) = Ref(r)
Expand Down
2 changes: 2 additions & 0 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ Initialize a new `ArrayReg` by an existing `ArrayReg`. This is equivalent
to `copy`.
"""
ArrayReg(r::ArrayReg{B}) where {B} = ArrayReg{B}(copy(r.state))
ArrayReg(r::ArrayReg{B,T,<:Transpose}) where {B,T} = ArrayReg{B}(Transpose(copy(r.state.parent)))

transpose_storage(reg::ArrayReg{B,T,<:Transpose}) where {B,T} = ArrayReg{B}(copy(reg.state))
transpose_storage(reg::ArrayReg{B,T}) where {B,T} = ArrayReg{B}(transpose(copy(transpose(reg.state))))
Expand Down Expand Up @@ -450,6 +451,7 @@ end
Returns an `ArrayReg` with `1:n` qubits activated.
"""
oneto(r::ArrayReg{B}, n::Int = nqubits(r)) where {B} = ArrayReg{B}(reshape(copy(r.state), 1 << n, :))
oneto(r::ArrayReg{B,T,<:Transpose}, n::Int = nqubits(r)) where {B,T} = transpose_storage(ArrayReg{B}(reshape(r.state, 1 << n, :)))

"""
oneto(n::Int) -> f(register)
Expand Down
27 changes: 27 additions & 0 deletions test/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,31 @@ end
focus!(ket, 1)
focus!(bra, 2)
@test_throws ErrorException bra' * ket

reg1 = rand_state(5; nbatch=10)
reg2 = rand_state(5; nbatch=10)
@test reg1' * reg2 reg1' .* reg2
reg1 = rand_state(2; nbatch=10)
reg2 = rand_state(5; nbatch=10)
focus!(reg2, 2:3)
@test all(reg1' * reg2 .≈ reg1' .* reg2)
end

@testset "inplace funcs" begin
for nbatch in [1, 10]
reg = rand_state(5; nbatch=nbatch)
reg0 = copy(reg)
@test regscale!(reg, 0.3) 0.3*reg0
reg1 = rand_state(5; nbatch=nbatch)
reg2 = rand_state(5; nbatch=nbatch)
reg10 = copy(reg1)
regsub!(reg1, reg2)
@test reg1 reg10 - reg2

reg1 = rand_state(5; nbatch=nbatch)
reg2 = rand_state(5; nbatch=nbatch)
reg10 = copy(reg1)
regadd!(reg1, reg2)
@test reg1 reg10 + reg2
end
end
15 changes: 15 additions & 0 deletions test/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,18 @@ end
copyto!(r1, r2)
@test r1 == r2
end

@testset "transpose copy" begin
reg = rand_state(5; nbatch=10)
reg1 = copy(reg)
reg2 = focus!(copy(reg), (3,5))
reg3 = relax!(copy(reg2), (3,5))
reg4 = oneto(reg1, 3)
@test reg1.state isa Transpose
@test reg2.state isa Transpose
@test reg3.state isa Transpose
@test reg4.state isa Transpose
@test reg4.state oneto(reg, 3).state
reg4.state[1] = 2.0
@test reg.state[1] != 2.0
end

0 comments on commit 376ae24

Please sign in to comment.