Skip to content

Commit

Permalink
cat with multiple dims FluxML#156
Browse files Browse the repository at this point in the history
Co-authored-by: americast <sayan.sinha@iitkgp.ac.in>
  • Loading branch information
gustafsson and americast committed May 2, 2018
1 parent cb2f6e7 commit 377718c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
31 changes: 15 additions & 16 deletions src/tracker/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ for f in [:vcat, :hcat]
Base.$f(a::TrackedArray...) = track($f, a...)
Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...)

# assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
# assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...)
end
end
Expand All @@ -125,22 +125,21 @@ function back(::typeof(hcat), Δ, xs...)
end
end

Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...)
Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...)
Base.cat(dims, a::TrackedArray...) = track(Base.cat, dims, a...)
Base.cat(dims, a::TrackedArray, b::Array...) = track(Base.cat, dims, a, b...)
Base.cat(dims, a::Array, b::TrackedArray...) = track(Base.cat, dims, a, b...)

function back(::typeof(cat), Δ, dim, xs...)
start = 0
for xsi in xs
if ndims(xsi) < dim
i = map(_ -> :, size(xsi))
j = ones(Int, dim-ndims(xsi)-1)
@back(xsi, Δ[i..., j..., start+1])
else
i = fill(:, dim-1)
j = fill(:, ndims(xsi)-dim)
@back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...])
end
start += size(xsi, dim)
function back(::typeof(cat), Δ, dims, Xs...)
start = ntuple(i -> 0, Val{ndims(Δ)})
for xs in Xs
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})

xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})

@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))

start = start .+ till_xs
end
end

Expand Down
2 changes: 2 additions & 0 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ end
@test gradtest(catdim, rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
end

@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
Expand Down

0 comments on commit 377718c

Please sign in to comment.