Scale block not supporting chainrules/Zygote diff yet #323
Thanks for the issue. This is not because the Scale block is not supported, but the …

```julia
julia> using Yao, Zygote
julia> N = 2;
julia> psi_0 = zero_state(N);
julia> U0 = chain(N, put(1=>Rx(0.0)), put(2=>Ry(0.0)));
julia> C = 2.1 * sum([chain(N, put(k=>Z)) for k=1:N]);
julia> function loss(theta)
           U = dispatch(U0, theta)               # bind the parameters to the circuit
           psi0 = copy(psi_0)
           psi1 = apply(psi0, U)                 # |ψ(θ)⟩
           psi2 = Zygote.@ignore apply(psi1, C)  # C|ψ(θ)⟩, treated as a constant by Zygote
           result = real(sum(conj(state(psi1)) .* state(psi2)))  # ⟨ψ|C|ψ⟩
           return result
       end
julia> theta = [1.7, 2.5];
julia> println(expect'(C, copy(psi_0) => dispatch(U0, theta))[2])
[-2.0824961019501838, -1.2567915026183087]
julia> grad = Zygote.gradient(theta -> loss(theta), theta)[1];
julia> println(grad)
[-1.0412480509750919, -0.6283957513091544]
```
Thanks, that could work for most cases! Also, why does your gradient differ by a factor of 2 in both components?
If you are asking why `expect'` returns the correct gradient, it is because you are using Yao's built-in AD engine, which handles this automatically. The gradients probably differ by a factor of two because the `Zygote.@ignore` macro also discards half of `psi`'s gradient: the gradient path through `psi2` is cut.
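A short derivation of where the factor of two comes from. Here |ψ(θ)⟩ = U(θ)|ψ₀⟩, and C is assumed Hermitian (true for the scaled sum of `Z` terms in this example):

```latex
f(\theta) = \langle \psi(\theta) | C | \psi(\theta) \rangle, \qquad
\frac{\partial f}{\partial \theta_k}
  = \langle \partial_k \psi | C | \psi \rangle + \langle \psi | C | \partial_k \psi \rangle
  = 2\,\mathrm{Re}\,\langle \partial_k \psi | C | \psi \rangle .

% With psi2 = C|psi(theta_0)> frozen (what Zygote.@ignore does):
\tilde f(\theta) = \mathrm{Re}\,\langle \psi(\theta) | \phi \rangle, \quad \phi = C|\psi(\theta_0)\rangle
\;\Rightarrow\;
\frac{\partial \tilde f}{\partial \theta_k}\Big|_{\theta_0}
  = \mathrm{Re}\,\langle \partial_k \psi | C | \psi \rangle
  = \tfrac{1}{2}\,\frac{\partial f}{\partial \theta_k}\Big|_{\theta_0} .
```

This matches the numbers in the thread: each Zygote component is half of the corresponding `expect'` component, up to floating-point rounding.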
Scale blocks appear to be unsupported by the chainrules in YaoBlocks
A minimal example setting:
The above loss function effectively computes an expectation value, equivalent to `expect(C, psi_0 => U)`. Computing `expect'` is no problem, but when we use Zygote instead, we get the following error:
However, if we instead put the scale factor in front of each `Z` rather than in front of the whole `sum([chain(...) ...])` block, `expect'` and `Zygote.gradient` yield the same result, `[-2.0824961019501838, -1.2567915026183087]`, as expected.
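As a sanity check independent of Yao and Zygote, the sketch below rebuilds the two-qubit example with hand-written gate matrices and differentiates by central finite differences. The helpers `rx_mat`, `ry_mat`, `Cm`, `psi`, and `fd_grad` are hypothetical and not Yao API. For this circuit both Rx(θ₁)|0⟩ and Ry(θ₂)|0⟩ give ⟨Z⟩ = cos θ, so ⟨C⟩ = 2.1(cos θ₁ + cos θ₂) and the correct gradient is -2.1(sin θ₁, sin θ₂). Freezing C|ψ(θ₀)⟩, which mimics wrapping `apply(psi1, C)` in `Zygote.@ignore`, should give half of it.

```julia
using LinearAlgebra

# Plain-matrix single-qubit rotations (textbook definitions, not Yao types).
rx_mat(θ) = [cos(θ/2) -im*sin(θ/2); -im*sin(θ/2) cos(θ/2)]
ry_mat(θ) = [cos(θ/2) -sin(θ/2); sin(θ/2) cos(θ/2)]

const Zm = [1.0 0.0; 0.0 -1.0]
const I2 = Matrix{Float64}(I, 2, 2)

# Observable 2.1 * (Z on one qubit + Z on the other).
const Cm = 2.1 * (kron(Zm, I2) + kron(I2, Zm))

# |ψ(θ)⟩: Rx(θ[1]) on one qubit, Ry(θ[2]) on the other, applied to |00⟩.
psi(θ) = kron(ry_mat(θ[2]), rx_mat(θ[1])) * ComplexF64[1, 0, 0, 0]

# Full expectation value ⟨ψ|C|ψ⟩ (dot conjugates its first argument).
energy(θ) = real(dot(psi(θ), Cm * psi(θ)))

# Central finite-difference gradient of a scalar function f at θ.
function fd_grad(f, θ; h = 1e-6)
    g = zeros(length(θ))
    for k in eachindex(θ)
        e = zeros(length(θ))
        e[k] = h
        g[k] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

θ0 = [1.7, 2.5]
full = fd_grad(energy, θ0)

# Mimic Zygote.@ignore: C|ψ(θ0)⟩ is computed once and held constant.
φ = Cm * psi(θ0)
half = fd_grad(θ -> real(dot(psi(θ), φ)), θ0)

# Analytic gradient of 2.1 * (cos θ1 + cos θ2).
closed = -2.1 .* sin.(θ0)

println(full)   # ≈ [-2.0825, -1.2568]
println(half)   # ≈ [-1.0412, -0.6284]
```

The full gradient reproduces the `expect'` values, while the frozen version reproduces the `Zygote.@ignore` values, which supports the explanation given in the thread.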
The two methods are mathematically equivalent, but support for the former would be useful/clean!