From 1b99b8131cc9c2791ccc22783628bef2e02b448b Mon Sep 17 00:00:00 2001 From: Marten Lienen Date: Mon, 7 Dec 2020 21:52:11 +0100 Subject: [PATCH] Make ndims and parent inferrable Previously, the `@isdefined` conditionals prevented julia from inferring the type of these methods at compile time. With this change broadcast operations in CUDA.jl like `C .- a'` can now be fully inferred which was previously impossible because the BroadcastStyle (which uses ndims internally) could not be decided at compile time. --- src/wrappers.jl | 37 ++++++++++++++++++++++--------------- test/runtests.jl | 8 +++++++- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/wrappers.jl b/src/wrappers.jl index 482ea7b..7cb53f0 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -103,19 +103,26 @@ WrappedArray{T,N,Src,Dst} = Union{ # https://github.com/JuliaLang/julia/pull/31563 # accessors for extracting information about the wrapper type -Base.ndims(W::Type{<:WrappedArray{<:Any,N}}) where {N} = - @isdefined(N) ? N : specialized_ndims(W) +Base.ndims(::Type{<:Base.LogicalIndex}) = 1 +Base.ndims(::Type{<:LinearAlgebra.Adjoint}) = 2 +Base.ndims(::Type{<:LinearAlgebra.Transpose}) = 2 +Base.ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2 +Base.ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2 +Base.ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2 +Base.ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2 +Base.ndims(::Type{<:LinearAlgebra.Diagonal}) = 2 +Base.ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2 +Base.ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N + Base.eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar -Base.parent(::Type{<:WrappedArray{<:Any,<:Any,Src,Dst}}) where {Src,Dst} = - @isdefined(Dst) ? Dst.name.wrapper : Src.name.wrapper - -# some wrappers don't have a N typevar because it is constant, but we can't extract that from <:WrappedArray -specialized_ndims(::Type{<:Base.LogicalIndex}) = 1 -specialized_ndims(::Type{<:LinearAlgebra.Adjoint}) = 2 -specialized_ndims(::Type{<:LinearAlgebra.Transpose}) = 2 -specialized_ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2 -specialized_ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2 -specialized_ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2 -specialized_ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2 -specialized_ndims(::Type{<:LinearAlgebra.Diagonal}) = 2 -specialized_ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2 + +for T in [:(Base.LogicalIndex{<:Any,<:Src}), + :(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}), + :(WrappedReinterpretArray{<:Any,<:Any,<:Src}), + :(WrappedReshapedArray{<:Any,<:Any,<:Src}), + :(WrappedSubArray{<:Any,<:Any,<:Src})] + @eval begin + Base.parent(::Type{<:$T}) where {Src} = Src.name.wrapper + end +end +Base.parent(::Type{<:WrappedArray{<:Any,<:Any,<:Any,Dst}}) where {Dst} = Dst.name.wrapper diff --git a/test/runtests.jl b/test/runtests.jl index 23cdd44..5c417d6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -107,4 +107,10 @@ const du = CustomArray{Float64,1}(rand(2)) const d = CustomArray{Float64,1}(rand(3)) @test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomArray -@test ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2 +@testset "Extracting type information" begin + @test ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2 + @test ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3 + + @test parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array + @test parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array +end