-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from JuliaGPU/tb/wrappers
Rework wrapper type
- Loading branch information
Showing
5 changed files
with
101 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# adaptors and type aliases for working with array wrappers | ||
|
||
using LinearAlgebra | ||
|
||
permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm | ||
|
||
|
||
export WrappedArray | ||
|
||
# database of array wrappers | ||
const _wrappers = ( | ||
:(SubArray{T,N,<:Src}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))), | ||
:(PermutedDimsArray{T,N,<:Any,<:Any,<:Src}) => (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)), | ||
:(Base.ReshapedArray{T,N,<:Src}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)), | ||
:(Base.ReinterpretArray{T,N,<:Src}) => (A,mut)->Base.reinterpret(eltype(A), mut(parent(A))), | ||
:(LinearAlgebra.Adjoint{T,<:Dst}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))), | ||
:(LinearAlgebra.Transpose{T,<:Dst}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))), | ||
:(LinearAlgebra.LowerTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))), | ||
:(LinearAlgebra.UnitLowerTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UnitLowerTriangular(mut(parent(A))), | ||
:(LinearAlgebra.UpperTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))), | ||
:(LinearAlgebra.UnitUpperTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))), | ||
:(LinearAlgebra.Diagonal{T,<:Dst}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))), | ||
:(LinearAlgebra.Tridiagonal{T,<:Dst}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)), | ||
) | ||
|
||
for (W, ctor) in _wrappers | ||
mut = :(A -> adapt(to, A)) | ||
@eval adapt_structure(to, wrapper::$W where {T,N,Src,Dst}) = $ctor(wrapper, $mut) | ||
end | ||
|
||
""" | ||
WrappedArray{T,N,Src,Dst} | ||
Union-type that encodes all array wrappers known by Adapt.jl. Typevars `T` and `N` encode | ||
the type and dimensionality of the resulting container. | ||
Two additional typevars are used to encode the parent array type: `Src` when the wrapper | ||
uses the parent array as a source, but changes its properties (e.g. | ||
`SubArray{T,1,Array{T,2}` changes `N`), and `Dst` when those properties are copied and thus | ||
are identical to the destination wrapper's properties (e.g. `Transpose{T,Array{T,N}}` has | ||
the same dimensionality as the inner array). When creating an alias for this type, e.g. | ||
`WrappedSomeArray{T,N} = WrappedArray{T,N,...}` the `Dst` typevar should typically be set to | ||
`SomeArray{T,N}` while `Src` should be more lenient, e.g., `SomeArray`. | ||
Only use this type for dispatch purposes. To convert instances of an array wrapper, use | ||
[`adapt`](@ref). | ||
""" | ||
const WrappedArray{T,N,Src,Dst} = @eval Union{$([W for (W,ctor) in Adapt._wrappers]...)} where {T,N,Src,Dst} | ||
|
||
# XXX: this Union is a hack: | ||
# - only works with one level of wrappi ng | ||
# - duplication of Src and Dst typevars (without it, we get `WrappedGPUArray{T,N,AT{T,N}}` | ||
# not matching `SubArray{T,1,AT{T,2}}`, and leaving out `{T,N}` makes it impossible to | ||
# match e.g. `Diagonal{T,AT}` and get `N` out of that). alternatively, computed types | ||
# would make it possible to do `SubArray{T,N,<:AT.name.wrapper}` or `Diagonal{T,AT{T,N}}`. | ||
# | ||
# ideally, Base would have, e.g., `Transpose <: WrappedArray`, and we could use | ||
# `Union{SomeArray, WrappedArray{<:Any, <:SomeArray}}` for dispatch. | ||
# https://github.com/JuliaLang/julia/pull/31563 | ||
|
||
# accessors for extracting information about the wrapper type | ||
Base.ndims(::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(N) ? N : ndims(Dst) | ||
Base.eltype(::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(T) ? T : ndims(Dst) | ||
Base.parent(W::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(Dst) ? Dst.name.wrapper : Src.name.wrapper | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
a62a256
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
a62a256
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/16140
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: