Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework wrapper type #23

Merged
merged 4 commits into from
Jun 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Adapt"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.1.0"
version = "2.0.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
42 changes: 1 addition & 41 deletions src/Adapt.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,5 @@
module Adapt

using LinearAlgebra


export WrappedArray

# database of array wrappers
#
# LHS entries are a symbolic type with AT for the array type
#
# RHS entries consist of a closure to reconstruct the wrapper, with as arguments
# a wrapper instance and mutator function to apply to the inner array
const wrappers = (
:(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,AT})=> (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)),
:(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)),
:(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))),
:(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))),
:(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))),
:(LinearAlgebra.UnitLowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitLowerTriangular(mut(parent(A))),
:(LinearAlgebra.UpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))),
:(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))),
:(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))),
:(LinearAlgebra.Tridiagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)),
)

"""
WrappedArray{AT}

Union-type that encodes all array wrappers known by Adapt.jl.

Only use this type for dispatch purposes. To convert instances of an array wrapper, use
[`adapt`](@ref).
"""
const WrappedArray{AT} = @eval Union{$([W for (W,ctor) in Adapt.wrappers]...)} where AT

# XXX: this Unions is a hack, and only works with one level of wrray wrappers. ideally, Base
# would have `Transpose <: WrappedArray <: AbstractArray` and we could define methods
# in terms of `Union{SomeArray, WrappedArray{<:Any, <:SomeArray}}`.
# https://github.com/JuliaLang/julia/pull/31563


export adapt

"""
Expand Down Expand Up @@ -84,5 +43,6 @@ adapt_structure(to, x) = adapt_storage(to, x)
adapt_storage(to, x) = x

include("base.jl")
include("wrappers.jl")

end # module
12 changes: 0 additions & 12 deletions src/base.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
# predefined adaptors for working with types from the Julia standard library

## Base

adapt_structure(to, xs::Union{Tuple,NamedTuple}) = map(x->adapt(to,x), xs)


## Array wrappers

permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm

for (W, ctor) in wrappers
mut = :(A -> adapt(to, A))
@eval adapt_structure(to, wrapper::$W where {AT <: Any}) = $ctor(wrapper, $mut)
end


## Broadcast

import Base.Broadcast: Broadcasted, Extruded
Expand Down
65 changes: 65 additions & 0 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# adaptors and type aliases for working with array wrappers

using LinearAlgebra

permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm


export WrappedArray

# database of array wrappers
const _wrappers = (
:(SubArray{T,N,<:Src}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))),
:(PermutedDimsArray{T,N,<:Any,<:Any,<:Src}) => (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)),
:(Base.ReshapedArray{T,N,<:Src}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)),
:(Base.ReinterpretArray{T,N,<:Src}) => (A,mut)->Base.reinterpret(eltype(A), mut(parent(A))),
:(LinearAlgebra.Adjoint{T,<:Dst}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))),
:(LinearAlgebra.Transpose{T,<:Dst}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))),
:(LinearAlgebra.LowerTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))),
:(LinearAlgebra.UnitLowerTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UnitLowerTriangular(mut(parent(A))),
:(LinearAlgebra.UpperTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))),
:(LinearAlgebra.UnitUpperTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))),
:(LinearAlgebra.Diagonal{T,<:Dst}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))),
:(LinearAlgebra.Tridiagonal{T,<:Dst}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)),
)

for (W, ctor) in _wrappers
mut = :(A -> adapt(to, A))
@eval adapt_structure(to, wrapper::$W where {T,N,Src,Dst}) = $ctor(wrapper, $mut)
end

"""
WrappedArray{T,N,Src,Dst}

Union-type that encodes all array wrappers known by Adapt.jl. Typevars `T` and `N` encode
the type and dimensionality of the resulting container.

Two additional typevars are used to encode the parent array type: `Src` when the wrapper
uses the parent array as a source, but changes its properties (e.g.
`SubArray{T,1,Array{T,2}` changes `N`), and `Dst` when those properties are copied and thus
are identical to the destination wrapper's properties (e.g. `Transpose{T,Array{T,N}}` has
the same dimensionality as the inner array). When creating an alias for this type, e.g.
`WrappedSomeArray{T,N} = WrappedArray{T,N,...}` the `Dst` typevar should typically be set to
`SomeArray{T,N}` while `Src` should be more lenient, e.g., `SomeArray`.

Only use this type for dispatch purposes. To convert instances of an array wrapper, use
[`adapt`](@ref).
"""
const WrappedArray{T,N,Src,Dst} = @eval Union{$([W for (W,ctor) in Adapt._wrappers]...)} where {T,N,Src,Dst}

# XXX: this Union is a hack:
# - only works with one level of wrappi ng
# - duplication of Src and Dst typevars (without it, we get `WrappedGPUArray{T,N,AT{T,N}}`
# not matching `SubArray{T,1,AT{T,2}}`, and leaving out `{T,N}` makes it impossible to
# match e.g. `Diagonal{T,AT}` and get `N` out of that). alternatively, computed types
# would make it possible to do `SubArray{T,N,<:AT.name.wrapper}` or `Diagonal{T,AT{T,N}}`.
#
# ideally, Base would have, e.g., `Transpose <: WrappedArray`, and we could use
# `Union{SomeArray, WrappedArray{<:Any, <:SomeArray}}` for dispatch.
# https://github.com/JuliaLang/julia/pull/31563

# accessors for extracting information about the wrapper type
Base.ndims(::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(N) ? N : ndims(Dst)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type-piracy is unfortunate. @jakebolewski and I just hunted this down as the source of breaking GPU inference, because Julia doesn't inline this function (might be the unused Src type parameter, e.g. dangling type parameter).

why is the @defined needed here?

x-ref: CliMA/ClimateMachine.jl#1260 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

julia> f(a) = eltype(a)
f (generic function with 1 method)

julia> @code_typed f(B)
CodeInfo(
1 ─     return Float64
) => Type{Float64}

julia> using Adapt
[ Info: Precompiling Adapt [79e6a3ab-5dfb-504d-930d-738a2a938a0e]

julia> @code_typed f(B)
CodeInfo(
1 ─ %1 = invoke Main.eltype(_2::SubArray{Float64,1,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},Int64},true})::Core.Compiler.Const(Float64, false)
└──      return %1
) => Type{Float64}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the @defined needed here?

julia> x(::Base.ReshapedArray{Float32,2,Array{Float32,2}})
ERROR: syntax: invalid "::" syntax around REPL[3]:1
Stacktrace:
 [1] top-level scope at REPL[3]:1

julia> x(Base.ReshapedArray{Float32,2,Array{Float32,2}})
(Float32, 2)

julia> x(LinearAlgebra.Adjoint{Float32,Array{Float32,2}})
ERROR: UndefVarError: N not defined
Stacktrace:
 [1] x(::Type{Adjoint{Float32,Array{Float32,2}}}) at ./REPL[2]:1
 [2] top-level scope at REPL[8]:1

For types that 'fix' the N typevar I can't get a hold of it from the WrappedArray union...
It's all a bit of a kludge, admittedly, but it brings us closer to trying out something that might end up in Base some day.

Suggested fix? Adapt.eltype and Adapt.ndims?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had not seen #26

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think making these Adapt.eltype would make sense

Base.eltype(::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(T) ? T : ndims(Dst)
Base.parent(W::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(Dst) ? Dst.name.wrapper : Src.name.wrapper

58 changes: 34 additions & 24 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,67 +5,77 @@ using Test
# custom array type

struct CustomArray{T,N} <: AbstractArray{T,N}
arr::AbstractArray
arr::Array
end

CustomArray(x::AbstractArray{T,N}) where {T,N} = CustomArray{T,N}(x)
Adapt.adapt_storage(::Type{<:CustomArray}, xs::AbstractArray) = CustomArray(xs)
CustomArray(x::Array{T,N}) where {T,N} = CustomArray{T,N}(x)
Adapt.adapt_storage(::Type{<:CustomArray}, xs::Array) = CustomArray(xs)

Base.size(x::CustomArray, y...) = size(x.arr, y...)
Base.getindex(x::CustomArray, y...) = getindex(x.arr, y...)


const val = CustomArray{Float64,2}(rand(2,2))
const mat = CustomArray{Float64,2}(rand(2,2))
const vec = CustomArray{Float64,1}(rand(2))

macro test_adapt(to, src, dst)
quote
@test adapt($to, $src) == $dst
@test typeof(adapt($to, $src)) == typeof($dst)
end
end


# basic adaption
@test adapt(CustomArray, val.arr) == val
@test adapt(CustomArray, val.arr) isa CustomArray
@test_adapt CustomArray mat.arr mat

# idempotency
@test adapt(CustomArray, val) == val
@test adapt(CustomArray, val) isa CustomArray
@test_adapt CustomArray mat mat

# custom wrapper
struct Wrapper{T}
arr::T
end
Wrapper(x::T) where T = Wrapper{T}(x)
Adapt.adapt_structure(to, xs::Wrapper) = Wrapper(adapt(to, xs.arr))
@test adapt(CustomArray, Wrapper(val.arr)) == Wrapper(val)
@test adapt(CustomArray, Wrapper(val.arr)) isa Wrapper{<:CustomArray}
@test_adapt CustomArray Wrapper(mat.arr) Wrapper(mat)


## base wrappers

@test @inferred(adapt(nothing, NamedTuple())) == NamedTuple()
@test adapt(CustomArray, (val.arr,)) == (val,)
@test_adapt CustomArray (mat.arr,) (mat,)
@test @allocated(adapt(nothing, ())) == 0
@test @allocated(adapt(nothing, (1,))) == 0
@test @allocated(adapt(nothing, (1,2,3,4,5,6,7,8,9,10))) == 0

@test adapt(CustomArray, (a=val.arr,)) == (a=val,)
@test_adapt CustomArray (a=mat.arr,) (a=mat,)

@test adapt(CustomArray, view(val.arr,:,:)) == view(val,:,:)
@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:)
const inds = CustomArray{Int,1}([1,2])
@test adapt(CustomArray, view(val.arr,inds.arr,:)) == view(val,inds,:)
@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:)

# NOTE: manual creation of PermutedDimsArray because permutedims collects
@test adapt(CustomArray, PermutedDimsArray(val.arr,(2,1))) == PermutedDimsArray(val,(2,1))
@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1))

# NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape`
@test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) == reshape(val,(2,2))
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2))


using LinearAlgebra

@test adapt(CustomArray, val.arr') == val'
@test_adapt CustomArray mat.arr' mat'

@test_adapt CustomArray transpose(mat.arr) transpose(mat)

@test adapt(CustomArray, transpose(val.arr)) == transpose(val)
@test_adapt CustomArray LowerTriangular(mat.arr) LowerTriangular(mat)
@test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat)
@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat)
@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat)

@test adapt(CustomArray, LowerTriangular(val.arr)) == LowerTriangular(val)
@test adapt(CustomArray, UnitLowerTriangular(val.arr)) == UnitLowerTriangular(val)
@test adapt(CustomArray, UpperTriangular(val.arr)) == UpperTriangular(val)
@test adapt(CustomArray, UnitUpperTriangular(val.arr)) == UnitUpperTriangular(val)
@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec)

@test adapt(CustomArray, Diagonal(val.arr)) == Diagonal(val)
@test adapt(CustomArray, Tridiagonal(val.arr)) == Tridiagonal(val)
const dl = CustomArray{Float64,1}(rand(2))
const du = CustomArray{Float64,1}(rand(2))
const d = CustomArray{Float64,1}(rand(3))
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du)