Help wanted on creating something similar to StrideArray
#73
Comments
It sounds like you've done most of the work, so what's left? It's trivial to go from strides to indices, and not much harder to start from the size instead. If you didn't have all the information as part of the type, I'd encourage balancing sparseness of the runtime representation against overspecialization. But that's a non-issue when all the information is known at compile time. You can do
Then define all your usual
As an aside, I prefer having slicing with
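A minimal base-Julia sketch of the two conversions described as trivial above, assuming 1-based, column-major indexing: deriving strides from a size tuple, and mapping a Cartesian index to a linear offset through the strides. `colmajor_strides`, `_strides` and `linear_index` are hypothetical helper names, not part of StrideArraysCore or CuTe.jl.

```julia
# Size -> strides (column-major): the stride of dimension k is prod(sz[1:k-1]).
# Recursive over tuples, so it stays type-stable for any number of dimensions.
colmajor_strides(sz::Tuple) = _strides(1, sz)
_strides(acc, ::Tuple{}) = ()
_strides(acc, sz::Tuple) = (acc, _strides(acc * first(sz), Base.tail(sz))...)

# Strides -> index: a 1-based Cartesian index I maps to 1 + sum over k of (I[k] - 1) * stride[k].
linear_index(I::NTuple{N, Int}, strides::NTuple{N, Int}) where {N} =
    1 + foldl(+, map((i, s) -> (i - 1) * s, I, strides); init = 0)

s = colmajor_strides((3, 4, 5))   # (1, 3, 12)
k = linear_index((2, 3, 4), s)    # 44, the same as LinearIndices((3, 4, 5))[2, 3, 4]
```

When the sizes are compile-time constants (for example, type parameters of a layout), small recursive tuple functions like these can usually be constant-folded by the compiler, which is why storing everything in the type costs little at runtime.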
Thank you for your reply. Please allow me to provide more information before I read your code carefully. Ultimately, I want to create a CuTeArray that looks like this:

struct CuTeArray{T, N, E<:Engine{T}, L<:Layout{N}} <: AbstractArray{T, N}
    engine::E
    layout::L
end

Here, the engine is a one-dimensional vector holding the underlying data. I would like to create two kinds of engines, one owning and one non-owning. Below is some draft code:

abstract type Engine{T} <: DenseVector{T} end
struct ViewEngine{T, P<:Ref{T}} <: Engine{T}  # non-owning
    ptr::P
    len::Int
end

mutable struct ArrayEngine{T, L} <: Engine{T}  # owning, must have a static size
    data::NTuple{L, T}
end

I am not sure how to proceed. From a general reading of the StrideArraysCore source code, it seems I should write the following:

@inline Base.unsafe_convert(::Type{Ptr{T}}, A::ArrayEngine{T}) where {T} = Base.unsafe_convert(Ptr{T}, pointer_from_objref(A))
@inline Base.pointer(A::ArrayEngine{T}) where {T} = Base.unsafe_convert(Ptr{T}, pointer_from_objref(A))
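A small self-contained check of why the `pointer_from_objref` approach works for an owning engine: for a mutable object, `pointer_from_objref` returns the address of its field data, and an isbits `NTuple` field is stored inline there, so that address is also the address of element 1. `TupleBuffer` is a hypothetical stand-in for the `ArrayEngine` draft; note that `pointer_from_objref` only accepts mutable objects, which is one reason the owning engine is a `mutable struct`.

```julia
# Hypothetical stand-in for ArrayEngine: a mutable struct whose only field is an
# isbits NTuple, so the tuple lives inline in the heap-allocated object.
mutable struct TupleBuffer{T, L}
    data::NTuple{L, T}
end

# The object's data address is the address of its first field, i.e. of element 1.
Base.pointer(b::TupleBuffer{T}) where {T} = convert(Ptr{T}, pointer_from_objref(b))

b = TupleBuffer{Float64, 4}((1.0, 2.0, 3.0, 4.0))
# Keep `b` rooted while the raw pointer is in use.
x = GC.@preserve b unsafe_load(pointer(b), 3)    # reads element 3
GC.@preserve b unsafe_store!(pointer(b), 10.0, 2) # overwrites element 2 in place
```

After the store, `b.data` should read `(1.0, 10.0, 3.0, 4.0)`, showing that the pointer really addresses the object's own storage rather than a copy.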
Yes. You can implement
The non-owning engine, of course, does not need
This is my first attempt, and I'm not sure whether I made any mistakes or did something unnecessary.

using ManualMemory, LayoutPointers
using Static: StaticInt, static

abstract type Engine{T} <: DenseVector{T} end
@inline Base.IndexStyle(::Type{<:Engine}) = IndexLinear()
@inline Base.elsize(::Engine{T}) where {T} = sizeof(T)
struct ViewEngine{T, P} <: Engine{T}  # non-owning
    ptr::P
    len::Int
end
@inline function ViewEngine(ptr::Ptr{T}, len::Int) where {T}
    return ViewEngine{T, typeof(ptr)}(ptr, len)
end
@inline function ViewEngine(A::AbstractArray)
    p = LayoutPointers.memory_reference(A)[1]  # not sure what this does
    return ViewEngine(p, length(A))
end
@inline Base.pointer(A::ViewEngine) = getfield(A, :ptr)
@inline function Base.unsafe_convert(p::Type{<:Ref{T}}, A::ViewEngine{T}) where {T}
    return Base.unsafe_convert(p, pointer(A))
end
@inline Base.size(A::ViewEngine) = tuple(getfield(A, :len))
@inline Base.length(A::ViewEngine) = getfield(A, :len)
@inline function Base.getindex(A::ViewEngine{T}, i::Integer) where {T}
    @boundscheck checkbounds(A, i)
    return unsafe_load(pointer(A), i)
end
@inline function Base.setindex!(A::ViewEngine{T}, val, i::Integer) where {T}
    @boundscheck checkbounds(A, i)
    unsafe_store!(pointer(A), val, i)
    return val
end
@inline ManualMemory.preserve_buffer(::ViewEngine) = nothing
mutable struct ArrayEngine{T, L} <: Engine{T}  # owning; subtypes Engine so the IndexStyle/elsize methods apply
    data::NTuple{L, T}
    @inline ArrayEngine{T, L}(::UndefInitializer) where {T, L} = new{T, L}()
    @inline function ArrayEngine{T}(::UndefInitializer, ::StaticInt{L}) where {T, L}
        return ArrayEngine{T, L}(undef)
    end
end
@inline function Base.unsafe_convert(::Type{Ptr{T}}, A::ArrayEngine{T}) where {T}
    return Base.unsafe_convert(Ptr{T}, pointer_from_objref(A))
end
@inline function Base.pointer(A::ArrayEngine{T}) where {T}
    return Base.unsafe_convert(Ptr{T}, pointer_from_objref(A))
end
@inline Base.size(::ArrayEngine{T, L}) where {T, L} = (L,)
@inline Base.length(::ArrayEngine{T, L}) where {T, L} = L
@inline Base.similar(::ArrayEngine{T, L}) where {T, L} = ArrayEngine{T, L}(undef)
@inline function Base.similar(A::ArrayEngine, ::Type{T}) where {T}
    return ArrayEngine{T}(undef, static(length(A)))
end
@inline function ArrayEngine{T}(f::Function, ::StaticInt{L}) where {T, L}
    A = ArrayEngine{T, L}(undef)
    @inbounds for i in eachindex(A)
        A[i] = f(i)
    end
    return A
end
# Preserve the object itself: the tuple lives inline in A, whereas getfield(A, :data)
# returns an isbits copy, and preserving that copy does not keep A alive.
@inline ManualMemory.preserve_buffer(A::ArrayEngine) = A
Base.@propagate_inbounds function Base.getindex(A::ArrayEngine, i::Union{Integer, StaticInt})
    b = ManualMemory.preserve_buffer(A)
    GC.@preserve b begin ViewEngine(A)[i] end
end
Base.@propagate_inbounds function Base.setindex!(A::ArrayEngine, val, i::Union{Integer, StaticInt})
    b = ManualMemory.preserve_buffer(A)
    GC.@preserve b begin ViewEngine(A)[i] = val end
end
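For comparison, a minimal base-Julia sketch of an owning, statically sized engine that implements the `AbstractVector` interface directly through its own pointer, without ManualMemory or LayoutPointers. `OwnedBuffer` is a hypothetical name, and the sketch assumes an isbits element type `T`. The key point is that `GC.@preserve` is applied to the mutable object `A` itself, since the elements live inside it.

```julia
# Hypothetical owning engine: L elements of isbits type T stored inline in a mutable object.
mutable struct OwnedBuffer{T, L} <: DenseVector{T}
    data::NTuple{L, T}
    OwnedBuffer{T, L}(::UndefInitializer) where {T, L} = new{T, L}()
end

Base.size(::OwnedBuffer{T, L}) where {T, L} = (L,)
Base.IndexStyle(::Type{<:OwnedBuffer}) = IndexLinear()
Base.pointer(A::OwnedBuffer{T}) where {T} = convert(Ptr{T}, pointer_from_objref(A))

Base.@propagate_inbounds function Base.getindex(A::OwnedBuffer, i::Int)
    @boundscheck checkbounds(A, i)
    # Root A (not a copy of A.data) while dereferencing the raw pointer.
    return GC.@preserve A unsafe_load(pointer(A), i)
end

Base.@propagate_inbounds function Base.setindex!(A::OwnedBuffer, v, i::Int)
    @boundscheck checkbounds(A, i)
    GC.@preserve A unsafe_store!(pointer(A), v, i)  # unsafe_store! converts v to eltype
    return v
end

# similar keeps the static length in the type.
Base.similar(::OwnedBuffer{T, L}, ::Type{S}) where {T, L, S} = OwnedBuffer{S, L}(undef)

b = OwnedBuffer{Float64, 4}(undef)
for i in eachindex(b)
    b[i] = i^2
end
c = similar(b, Int)
```

Because `size`, `getindex` and `setindex!` are defined, generic functions such as `sum(b)` or `b == [1.0, 4.0, 9.0, 16.0]` work without further methods.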
I'm not quite sure how to properly implement the pointer for the CuTeArray, which packs an engine and a layout:

struct CuTeArray{T, N, E <: DenseVector{T}, L <: Layout{N}} <: AbstractArray{T, N}
    engine::E
    layout::L
    function CuTeArray(engine::DenseVector{T}, layout::Layout{N}) where {T, N}
        return new{T, N, typeof(engine), typeof(layout)}(engine, layout)
    end
end
@inline engine(A::CuTeArray) = getfield(A, :engine)  # accessor used by the methods below
@inline function Base.unsafe_convert(::Type{Ptr{T}}, A::CuTeArray{T, N, <:ArrayEngine}) where {T, N}
    return Base.unsafe_convert(Ptr{T}, pointer_from_objref(engine(A)))
end
@inline function Base.pointer(A::CuTeArray{T, N, <:ArrayEngine}) where {T, N}
    return Base.unsafe_convert(Ptr{T}, pointer_from_objref(engine(A)))
end

Is this even correct?
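One possible design, sketched under assumptions: rather than special-casing `pointer_from_objref` for each engine type, the wrapper can delegate `pointer` and `unsafe_convert` to its engine, so a single method covers owning and non-owning engines alike. `StaticLayout` is a hypothetical stand-in for CuTe.jl's `Layout` (shape and strides as type parameters), `LayoutArray` stands in for `CuTeArray`, and a plain `Vector` is used as the engine to keep the example self-contained.

```julia
# Hypothetical stand-in for CuTe.jl's Layout: shape and strides are type parameters.
struct StaticLayout{Shape, Stride} end
StaticLayout(shape::Tuple, stride::Tuple) = StaticLayout{shape, stride}()

layout_size(::StaticLayout{Shape}) where {Shape} = Shape

# Calling the layout maps a 1-based Cartesian index to a 1-based offset into the engine.
(::StaticLayout{Shape, Stride})(I::Int...) where {Shape, Stride} =
    1 + foldl(+, map((i, s) -> (i - 1) * s, I, Stride); init = 0)

# Flat engine + layout, in the spirit of CuTeArray.
struct LayoutArray{T, N, E<:DenseVector{T}, L<:StaticLayout} <: AbstractArray{T, N}
    engine::E
    layout::L
end
function LayoutArray(engine::DenseVector{T}, layout::StaticLayout{Shape}) where {T, Shape}
    return LayoutArray{T, length(Shape), typeof(engine), typeof(layout)}(engine, layout)
end

Base.size(A::LayoutArray) = layout_size(getfield(A, :layout))

Base.@propagate_inbounds function Base.getindex(A::LayoutArray{T, N}, I::Vararg{Int, N}) where {T, N}
    @boundscheck checkbounds(A, I...)
    return getfield(A, :engine)[getfield(A, :layout)(I...)]
end

Base.@propagate_inbounds function Base.setindex!(A::LayoutArray{T, N}, v, I::Vararg{Int, N}) where {T, N}
    @boundscheck checkbounds(A, I...)
    getfield(A, :engine)[getfield(A, :layout)(I...)] = v
    return A
end

# Delegate to the engine: one method serves any engine that defines pointer/unsafe_convert.
# Callers must still GC.@preserve A (or its engine) while the raw pointer is in use.
Base.pointer(A::LayoutArray) = pointer(getfield(A, :engine))
Base.unsafe_convert(::Type{Ptr{T}}, A::LayoutArray{T}) where {T} =
    Base.unsafe_convert(Ptr{T}, getfield(A, :engine))

data = collect(1.0:12.0)
A  = LayoutArray(data, StaticLayout((3, 4), (1, 3)))  # column-major 3x4
At = LayoutArray(data, StaticLayout((4, 3), (3, 1)))  # same data, transposed via strides
```

With an `ArrayEngine` as the engine, the delegated `pointer` reduces to the `pointer_from_objref` conversion from the question, provided `ArrayEngine` defines `pointer` as in the draft above.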
Hi, I think I've got the hang of it, so I'm closing this.
Thank you for creating this package; I hope to get some guidance from you. I recently developed CuTe.jl, which abstracts the shape and stride(s) of an array into a structure called Layout. This abstraction allows various algebraic operations to be performed on a Layout. My goal is to implement something akin to StrideArray, except that the shape and strides are already described by the Layout, so the data only needs to be a one-dimensional vector. Moreover, the length of that vector should also be one of the type parameters. Do you have any suggestions or advice?
Also, you are welcome to have a look at the source code of CuTe.jl for any apparent mistakes.