Skip to content

Commit

Permalink
fix issue Yao/#299
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Aug 14, 2021
1 parent 156d8d4 commit e5d0004
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions examples/PortZygote/chainrules_patch.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import ChainRulesCore: rrule, @non_differentiable, NoTangent
using Yao, Yao.AD

function rrule(::typeof(apply!), reg::ArrayReg, block::AbstractBlock)
out = apply!(reg, block)
function rrule(::typeof(apply), reg::ArrayReg, block::AbstractBlock)
out = apply(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((out, outδ), block)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block)
return (NoTangent(), inδ, paramsδ)
end
end

function rrule(::typeof(dispatch!), block::AbstractBlock, params)
out = dispatch!(block, params)
function rrule(::typeof(dispatch), block::AbstractBlock, params)
out = dispatch(block, params)
out, function (outδ)
(NoTangent(), NoTangent(), outδ)
end
Expand Down
4 changes: 2 additions & 2 deletions examples/PortZygote/gate_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ Learn a general U4 gate. The optimizer is LBFGS.
function learn_u4(u::AbstractMatrix; niter=100)
ansatz = general_U4() * put(2, 1=>phase(0.0)) # initial values are 0, here, we attach a global phase.
params = parameters(ansatz)
g!(G, x) = (dispatch!(ansatz, x); G .= Zygote.gradient(ansatz->loss(u, ansatz), ansatz)[1])
optimize(x->(dispatch!(ansatz, x); loss(u, ansatz)), g!, parameters(ansatz),
g!(G, x) = (ansatz=dispatch(ansatz, x); G .= Zygote.gradient(ansatz->loss(u, ansatz), ansatz)[1])
optimize(x->(ansatz=dispatch(ansatz, x); loss(u, ansatz)), g!, parameters(ansatz),
LBFGS(), Optim.Options(iterations=niter))
println("final loss = $(loss(u,ansatz))")
return ansatz
Expand Down
8 changes: 4 additions & 4 deletions examples/PortZygote/shared_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ h = YaoExtensions.heisenberg(5)

function loss(h, c, θ) where N
# the assignment is necessary!
c = dispatch!(c, fill(θ, nparameters(c)))
reg = apply!(zero_state(nqubits(c)), c)
c = dispatch(c, fill(θ, nparameters(c)))
reg = apply(zero_state(nqubits(c)), c)
real(expect(h, reg))
end

Expand All @@ -28,9 +28,9 @@ true_grad = sum(gparams)
# the batched version
function loss2(h, c, θ) where N
# the assignment is necessary!
c = dispatch!(c, fill(θ, nparameters(c)))
c = dispatch(c, fill(θ, nparameters(c)))
reg = zero_state(nqubits(c),nbatch=2)
reg = apply!(reg, c)
reg = apply(reg, c)
sum(real(expect(h, reg)))
end

Expand Down
2 changes: 1 addition & 1 deletion examples/PortZygote/simple_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dispatch!(c, :random)

function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where N
#copy(reg) |> circuit
reg = apply!(copy(reg), circuit)
reg = apply(copy(reg), circuit)
st = state(reg)
sum(real(st.*st))
end
Expand Down

0 comments on commit e5d0004

Please sign in to comment.