From 024704ed5a583b5ee5271525af6e8c98e3924131 Mon Sep 17 00:00:00 2001 From: Peter Date: Sun, 4 Dec 2022 05:51:58 +0800 Subject: [PATCH] handle permutedims rrule with ZeroTangent (#683) --- src/rulesets/Base/array.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 4ae424151..cea023cda 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -143,18 +143,21 @@ end function rrule(::typeof(permutedims), x::AbstractVector) project = ProjectTo(x) permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy)))) + permutedims_pullback_1(::ZeroTangent) = (NoTangent(), ZeroTangent()) return permutedims(x), permutedims_pullback_1 end function rrule(::typeof(permutedims), x::AbstractArray, perm) pr = ProjectTo(x) # projection restores e.g. transpose([1,2,3]) permutedims_back_2(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent()) + permutedims_back_2(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent()) return permutedims(x, perm), permutedims_back_2 end function rrule(::Type{<:PermutedDimsArray}, x::AbstractArray, perm) pr = ProjectTo(x) permutedims_back_3(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent()) + permutedims_back_3(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent()) return PermutedDimsArray(x, perm), permutedims_back_3 end