Skip to content

Commit

Permalink
Make ndims and parent inferrable
Browse files Browse the repository at this point in the history
Previously, the `@isdefined` conditionals prevented julia from inferring the type of these
methods at compile time. With this change broadcast operations in CUDA.jl like `C .- a'`
can now be fully inferred which was previously impossible because the
BroadcastStyle (which uses ndims internally) could not be decided at compile time.
  • Loading branch information
martenlienen authored and maleadt committed Dec 8, 2020
1 parent 1edd29c commit 1b99b81
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
37 changes: 22 additions & 15 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,26 @@ WrappedArray{T,N,Src,Dst} = Union{
# https://github.com/JuliaLang/julia/pull/31563

# accessors for extracting information about the wrapper type
Base.ndims(W::Type{<:WrappedArray{<:Any,N}}) where {N} =
@isdefined(N) ? N : specialized_ndims(W)
Base.ndims(::Type{<:Base.LogicalIndex}) = 1
Base.ndims(::Type{<:LinearAlgebra.Adjoint}) = 2
Base.ndims(::Type{<:LinearAlgebra.Transpose}) = 2
Base.ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2
Base.ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2
Base.ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2
Base.ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2
Base.ndims(::Type{<:LinearAlgebra.Diagonal}) = 2
Base.ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2
Base.ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N

Base.eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar
Base.parent(::Type{<:WrappedArray{<:Any,<:Any,Src,Dst}}) where {Src,Dst} =
@isdefined(Dst) ? Dst.name.wrapper : Src.name.wrapper

# some wrappers don't have a N typevar because it is constant, but we can't extract that from <:WrappedArray
specialized_ndims(::Type{<:Base.LogicalIndex}) = 1
specialized_ndims(::Type{<:LinearAlgebra.Adjoint}) = 2
specialized_ndims(::Type{<:LinearAlgebra.Transpose}) = 2
specialized_ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2
specialized_ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2
specialized_ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2
specialized_ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2
specialized_ndims(::Type{<:LinearAlgebra.Diagonal}) = 2
specialized_ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2

for T in [:(Base.LogicalIndex{<:Any,<:Src}),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}),
:(WrappedReinterpretArray{<:Any,<:Any,<:Src}),
:(WrappedReshapedArray{<:Any,<:Any,<:Src}),
:(WrappedSubArray{<:Any,<:Any,<:Src})]
@eval begin
Base.parent(::Type{<:$T}) where {Src} = Src.name.wrapper
end
end
Base.parent(::Type{<:WrappedArray{<:Any,<:Any,<:Any,Dst}}) where {Dst} = Dst.name.wrapper
8 changes: 7 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,10 @@ const du = CustomArray{Float64,1}(rand(2))
const d = CustomArray{Float64,1}(rand(3))
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomArray

@test ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2
@testset "Extracting type information" begin
@test ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2
@test ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3

@test parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array
@test parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array
end

0 comments on commit 1b99b81

Please sign in to comment.