From d8387264ada30845074fa89b936462de80cdaedc Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Mon, 24 Oct 2022 12:54:38 +0800 Subject: [PATCH] handle permutedims rrule with ZeroTangent --- src/rulesets/Base/array.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 2461c5561..d817f4cce 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -143,18 +143,21 @@ end function rrule(::typeof(permutedims), x::AbstractVector) project = ProjectTo(x) permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy)))) + permutedims_pullback_1(::ZeroTangent) = (NoTangent(), ZeroTangent()) return permutedims(x), permutedims_pullback_1 end function rrule(::typeof(permutedims), x::AbstractArray, perm) pr = ProjectTo(x) # projection restores e.g. transpose([1,2,3]) permutedims_back_2(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent()) + permutedims_back_2(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent()) return permutedims(x, perm), permutedims_back_2 end function rrule(::Type{<:PermutedDimsArray}, x::AbstractArray, perm) pr = ProjectTo(x) permutedims_back_3(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent()) + permutedims_back_3(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent()) return PermutedDimsArray(x, perm), permutedims_back_3 end