Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle permutedims rrule with ZeroTangent #683

Merged
merged 1 commit into from
Dec 3, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,21 @@ end
function rrule(::typeof(permutedims), x::AbstractVector)
project = ProjectTo(x)
permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy))))
permutedims_pullback_1(::ZeroTangent) = (NoTangent(), ZeroTangent())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDK why there are several Zeros, but perhaps if doing this it may as well accept all of them:

Suggested change
permutedims_pullback_1(::ZeroTangent) = (NoTangent(), ZeroTangent())
permutedims_pullback_1(z::AbstractZero) = (NoTangent(), z)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDK either, but I follow how CR.jl previously do.

norm_pullback_p(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent())

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instances of the wider pattern, which I didn't write, include:

isa AbstractZero && return (NoTangent(), V̄)

isa AbstractZero && return (NoTangent(), V̄)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, then should I update this PR, or just leave as it is?

return permutedims(x), permutedims_pullback_1
end

function rrule(::typeof(permutedims), x::AbstractArray, perm)
pr = ProjectTo(x) # projection restores e.g. transpose([1,2,3])
permutedims_back_2(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
permutedims_back_2(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent())
return permutedims(x, perm), permutedims_back_2
end

function rrule(::Type{<:PermutedDimsArray}, x::AbstractArray, perm)
pr = ProjectTo(x)
permutedims_back_3(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
permutedims_back_3(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent())
return PermutedDimsArray(x, perm), permutedims_back_3
end

Expand Down