mcabbott/SliceMap.jl

SliceMap.jl

This package provides some mapslices-like functions, with gradients defined for Tracker and Zygote:

mapcols(f, M) ≈ mapreduce(f, hcat, eachcol(M))
MapCols{d}(f, M)         # where d=size(M,1), for SVector slices
ThreadMapCols{d}(f, M)   # using Threads.@threads

maprows(f, M) ≈ mapslices(f, M, dims=2)

slicemap(f, A; dims) ≈ mapslices(f, A, dims=dims) # only Zygote
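The equivalences listed above can be checked in plain Base Julia, without SliceMap and without any gradients. This is a minimal sketch: the slice function `f` and the matrix `M` are arbitrary illustrative choices, and the threaded loop only shows the general idea behind a `Threads.@threads` column map. It is not how `ThreadMapCols` is actually implemented.

```julia
# Base-only check of what the column/row maps compute (no SliceMap, no gradients).
M = [1.0 2.0 3.0;
     4.0 5.0 6.0]            # 2×3 matrix: 3 columns of length 2

f(v) = v ./ sum(v)           # normalise one slice; output has the same length as v

# mapcols(f, M): apply f to each column, then glue the results side by side.
cols_hcat   = mapreduce(f, hcat, eachcol(M))
cols_slices = mapslices(f, M, dims=1)
@assert cols_hcat ≈ cols_slices

# maprows(f, M): apply f to each row instead.
rows_slices = mapslices(f, M, dims=2)
rows_manual = permutedims(mapreduce(f, hcat, eachrow(M)))
@assert rows_slices ≈ rows_manual

# Illustrative threaded column map (assumes f preserves the column length).
out = similar(M)
Threads.@threads for j in axes(M, 2)
    out[:, j] = f(view(M, :, j))
end
@assert out ≈ cols_slices
```

The threaded version writes into a preallocated output, so it only works when every slice maps to a vector of the same length. The `mapreduce(f, hcat, ...)` form has no such requirement on the input shape, but it allocates a new matrix at each `hcat`.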

The capitalised functions differ in two ways: each slice is passed to f as a StaticArrays SVector, and the gradient of each slice is computed with ForwardDiff, rather than with reverse-mode Tracker/Zygote as in the lowercase functions. For small slices, this will often be much faster, with or without gradients.
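To illustrate what "ForwardDiff for the gradient of each slice" means, here is a Base-only sketch. It uses a hand-rolled `Dual` type as a stand-in for ForwardDiff (it is not ForwardDiff's API), it does not use StaticArrays, and the helpers `slice_jacobian` and `mapcols_pullback` are hypothetical names, not SliceMap internals. Each column's small Jacobian `J` is built by forward-mode passes, and the reverse-mode cotangent for that column is then `J' * Δ[:, j]`.

```julia
# Hand-rolled forward-mode dual number: a stand-in for ForwardDiff.Dual, NOT the real type.
struct Dual <: Number
    val::Float64   # primal value
    der::Float64   # derivative along one seeded input direction
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:/(a::Dual, b::Dual) = Dual(a.val / b.val, (a.der * b.val - a.val * b.der) / b.val^2)

# Hypothetical helper: Jacobian of f at one small slice x, one forward pass per input entry.
function slice_jacobian(f, x::AbstractVector{<:Real})
    d = length(x)
    cols = map(1:d) do j
        xd = Dual.(x, (1:d) .== j)   # seed the j-th unit direction
        [y.der for y in f(xd)]        # column j of the Jacobian
    end
    return reduce(hcat, cols)         # size length(f(x)) × d
end

# Hypothetical helper: pullback of a column map, given the output cotangent Δ.
# Each column of the input gradient is J' * Δ[:, j], with J from forward mode.
function mapcols_pullback(f, M::AbstractMatrix, Δ::AbstractMatrix)
    return reduce(hcat, [slice_jacobian(f, M[:, j])' * Δ[:, j] for j in axes(M, 2)])
end

f(v) = v ./ sum(v)   # example slice function
M = [1.0 2.0 3.0;
     4.0 5.0 6.0]
Δ = [1.0 0.0 0.0;
     0.0 0.0 1.0]     # example cotangent of the output
dM = mapcols_pullback(f, M, Δ)
```

This sketch makes one forward pass per entry of a slice; ForwardDiff instead carries several partial derivatives in each pass. Either way, the cost grows with the slice length, which is why forward mode pays off mainly for small slices.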

The package also defines Zygote gradients for the Slices/Align functions in JuliennedArrays, which offer a good way to roll your own mapslices-like function (this is exactly how slicemap(f, A; dims) works). Similar gradients are also available in TensorCast and in LazyStack.
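The slice, map, then glue pattern can be sketched in Base Julia as a rough analogue of Slices followed by Align. This does not use JuliennedArrays and defines no gradients. It assumes Julia 1.9 or later, which provides `stack` and multi-dimensional `eachslice`; the array `A` and function `f` are arbitrary examples.

```julia
# Base-only analogue of "slice, map, glue back" (requires Julia 1.9+).
A = reshape(collect(1.0:24.0), 2, 3, 4)   # 2×3×4 array

f(v) = v ./ sum(v)                        # acts on one length-2 slice A[:, j, k]

slices = eachslice(A; dims=(2, 3))        # 3×4 collection of views A[:, j, k]
mapped = map(f, slices)                   # 3×4 array of length-2 vectors
glued  = stack(mapped)                    # 2×3×4: slice entries first, then (j, k)

@assert glued ≈ mapslices(f, A, dims=1)
```

`eachslice` keeps the listed `dims` as the outer index and slices along the rest, so the slices have the shape that `mapslices(f, A, dims=1)` would pass to `f`. `stack` places each slice's entries in the leading dimensions of the result.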

There are more details & examples at docs/intro.md.