Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Metal arrays #31

Closed
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- macOS-latest
- windows-latest
arch:
- x64
- aarch64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[compat]
Adapt = "3, 4"
AMDGPU = "0.3.7, 0.4, 0.5, 0.6, 0.7, 0.8"
Adapt = "3, 4"
CUDA = "3.12, 4, 5"
julia = "1.9" # Minimum required Julia version (supporting extensions and weak dependencies)
Metal = "1"
StaticArrays = "1"
julia = "1.9" # Minimum required Julia version (supporting extensions and weak dependencies)

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "AMDGPU", "CUDA"]
test = ["Test", "AMDGPU", "CUDA", "Metal"]
39 changes: 38 additions & 1 deletion src/CellArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,44 @@ function define_ROCCellArray()
end
end

"""
@define_MtlCellArray

Define the following type alias and constructors in the caller module:

********************************************************************************
MtlCellArray{T<:Cell,N,B,T_elem} <: AbstractArray{T,N} where Cell <: Union{Number, SArray, FieldArray}

`N`-dimensional CellArray with cells of type `T`, blocklength `B`, and `T_array` being a `MtlArray` of element type `T_elem`: alias for `CellArray{T,N,B,MtlArray{T_elem,CellArrays._N}}`.

--------------------------------------------------------------------------------

MtlCellArray{T,B}(undef, dims)
MtlCellArray{T}(undef, dims)

Construct an uninitialized `N`-dimensional `CellArray` containing `Cells` of type `T` which are stored in an array of kind `MtlArray`.

See also: [`CellArray`](@ref), [`CPUCellArray`](@ref), [`ROCCellArray`](@ref)
********************************************************************************

!!! note "Avoiding unneeded dependencies"
The type aliases and constructors for GPU `CellArray`s are provided via macros to avoid unneeded dependencies on the GPU packages in CellArrays.

See also: [`@define_MtlCellArray`](@ref)
"""
macro define_MtlCellArray() esc(define_MtlCellArray()) end

function define_MtlCellArray()
quote
const MtlCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,Metal.MtlArray{T_elem,CellArrays._N}}

MtlCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = (CellArrays.check_T(T); MtlCellArray{T,N,B,eltype(T)}(undef, dims))
MtlCellArray{T,B}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell,B} = MtlCellArray{T,B}(undef, dims)
MtlCellArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N} = MtlCellArray{T,0}(undef, dims)
MtlCellArray{T}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell} = MtlCellArray{T}(undef, dims)
end
end


## AbstractArray methods

Expand All @@ -203,7 +241,6 @@ end
CellArray{T,N,B}(T_arraykind{eltype(T),_N}, undef, dims)
end


@inline function Base.fill!(A::CellArray{T,N,B,T_array}, x) where {T<:Number,N,B,T_array}
cell = convert(T, x)
A.data[:, 1, :] .= cell
Expand Down
2 changes: 1 addition & 1 deletion src/CellArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ using .Exceptions
include("CellArray.jl")

## Exports (need to be after include of submodules if re-exports from them)
export CellArray, CPUCellArray, @define_CuCellArray, @define_ROCCellArray, cellsize, blocklength, field
export CellArray, CPUCellArray, @define_CuCellArray, @define_ROCCellArray, @define_MtlCellArray, cellsize, blocklength, field
end
Loading
Loading