Need an adjoint for constructor Base.OneTo{Int64} #483

Closed
AzamatB opened this issue Jan 28, 2020 · 7 comments · Fixed by #485

Comments

@AzamatB
Contributor

AzamatB commented Jan 28, 2020

Consider

using Zygote, LinearAlgebra
v = Ref(rand(2))
A = rand(2,3)
gradient(x -> sum(v .⋅ eachcol(x)), A)

which throws

ERROR: Need an adjoint for constructor Base.OneTo{Int64}. Gradient is of type Array{Nothing,1}

Noting that eachcol is defined as

eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2))

and replacing in that definition axes(A, 2) with 1:size(A, 2) makes the error go away.

I'm happy to submit a PR fixing this if someone can point me to where the adjoint for UnitRange{Int64} is defined.
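A small Base-only check may help explain why the error names `Base.OneTo` rather than `UnitRange`. This is only an illustrative sketch; the variable names `ax` and `ur` are made up for the example.

```julia
A = rand(2, 3)

ax = axes(A, 2)       # what eachcol iterates over: a Base.OneTo range
ur = 1:size(A, 2)     # the replacement from the comment above: a UnitRange

println(typeof(ax))   # the range type behind axes
println(typeof(ur))   # the range type behind 1:n
println(ax == ur)     # both describe the same column indices
```

The two ranges hold the same indices but have different types, so a differentiation rule that covers one of them does not automatically cover the other.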

@mcabbott
Member

I think you just want Zygote.@nograd axes.

The error you had is still produced by things like gradient(x -> (sum(Base.OneTo(3) .+ x)), 4), so perhaps that should be nograd-ed too.
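As a rough sketch of that suggestion, and assuming a Zygote version in which these rules are not yet defined (the situation described in this issue), the workaround could look like the following. Later versions may already handle both cases, and may prefer `ChainRulesCore.@non_differentiable` for this purpose.

```julia
using Zygote

# Assumption: no adjoint exists yet for these two, as in this issue.
# @nograd tells Zygote to treat the call as a constant with no gradient.
Zygote.@nograd axes
Zygote.@nograd Base.OneTo

# The example from the comment above: only x should receive a gradient,
# while the range Base.OneTo(3) is treated as a constant.
g = gradient(x -> sum(Base.OneTo(3) .+ x), 4)
println(g)
```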

@jamblejoe
Contributor

I do not see why Base.OneTo should be nograd-ed. This example here

gradient(x-> sum(Base.OneTo(x)), 2)

clearly has a non-trivial gradient, but errors with

Need an adjoint for constructor Base.OneTo{Int64}. Gradient is of type FillArrays.Fill{Int64,1,Tuple{Base.OneTo{Int64}}}

I think this issue should be reopened and fixed by implementing the corresponding adjoint.

@mcabbott
Member

Wouldn't the gradient want to be a limit of this?

(sum(Base.OneTo(2 + 0.001)) - sum(Base.OneTo(2))) / 0.001

I don't think this has a sensible answer, so I guess the answer should be nothing (or one of ChainRules's zeros).
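To spell out the limit argument, suppose (purely as an assumption for illustration) that the range is extended to real $x$ by truncating the endpoint, so the function being differentiated is $f(x) = \sum_{n=1}^{\lfloor x \rfloor} n$. At $x = 2$ and for $0 < h < 1$ the two one-sided difference quotients are

$$
\frac{f(2+h) - f(2)}{h} = \frac{3 - 3}{h} = 0,
\qquad
\frac{f(2) - f(2-h)}{h} = \frac{3 - 1}{h} = \frac{2}{h} \to \infty \quad (h \to 0^{+}).
$$

The two sides disagree, so no derivative exists at the integer point, and away from integers $f$ is locally constant with derivative zero.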

@jamblejoe
Contributor

I was thinking too mathematically here. For integer x one knows that $\sum_{n=1}^x n = \frac{x(x+1)}{2}$, which of course has a non-zero derivative if one extends the right-hand side to real values. That is of course not what Base.OneTo is doing. Sorry for that!
gradient(x-> sum(1:x), 2)
results in
(nothing, )
As Base.OneTo is just the type stable version of it, I think it should have the same behaviour.

I am not familiar with Zygote, but maybe this is a good first issue. Would this be easy to implement @mcabbott ?

@mcabbott
Member

But 1:x still wants integer x, it's still not continuously variable. It truncates, e.g. 1:3.9 == 1:3.

Zygote allows you to use integers in things like gradient(x -> x^3, 5), but they have to be real numbers which just happen to have integer value, for the gradient to make sense.

Step ranges might make sense, e.g. perhaps gradient(x -> sum(range(1,x,length=20)), 3) should be made to work; I'm not sure how hard it would be. But I have not run into a need for such things!
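The contrast between the two kinds of range can be checked with plain finite differences, without Zygote. This is only a sketch: the helper names `f`, `g`, `fd_trunc` and `fd_range` are made up for the example, and the step `h` is an arbitrary choice.

```julia
h = 1e-3

# 1:x truncates a non-integer endpoint, so sum(1:x) is piecewise constant in x.
f(x) = sum(1:x)
fd_trunc = (f(2 + h) - f(2)) / h
println(1:3.9 == 1:3)   # the truncation mentioned above
println(fd_trunc)       # forward difference of the truncating range

# With a fixed length, every point of the range moves smoothly with x.
g(x) = sum(range(1, x, length=20))
fd_range = (g(3 + h) - g(3)) / h
println(fd_range)       # forward difference of the fixed-length range
```

For the fixed-length range the sum works out to $10(x + 1)$, so a slope close to 10 is what one would expect from the second difference quotient, while the truncating range gives a flat response to a small change in the endpoint.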

@jamblejoe
Contributor

@mcabbott I agree that 1:x should have gradient nothing, as should Base.OneTo.

I get why sum(range(1,x,length=20)) is different, at least when I write it out in math terms.

I was asking because I had some code that gets the number of vertices of a graph from SimpleGraphs.jl. This ended up in a Base.OneTo call, which Zygote.jl did not know what to do with.

Why though, do I get the above error for

gradient(x-> sum(Base.OneTo(x)), 2)

but this works fine

gradient(x -> (s=zero(x); for i in Base.OneTo(x); s+=i*x; end; s), 3)
(6,)

? I am confused now.
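One way to read the result (6,) from the loop version, assuming the range endpoint is treated as a constant and only the factor $x$ inside the loop body is differentiated:

$$
s(x) = \sum_{i=1}^{3} i \, x = 6x,
\qquad
\frac{ds}{dx} = \sum_{i=1}^{3} i = 6 .
$$

This matches the reported value, which suggests that the loop bound contributes nothing to the gradient in that example.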

@mcabbott
Member

Oh right, it needs Zygote.@nograd Base.OneTo I think. You could copy the above PR to add this permanently.
