From 2c35dab352780ea19c479cd024b88dda70d60d71 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 16:04:55 -0400 Subject: [PATCH 1/4] Define resize --- src/array.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ test/array.jl | 21 +++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/src/array.jl b/src/array.jl index f2fb1a60..0679e79a 100644 --- a/src/array.jl +++ b/src/array.jl @@ -478,3 +478,45 @@ function Base.unsafe_wrap(::Type{Array}, arr::oneArray{T,N,oneL0.SharedBuffer}) ptr = reinterpret(Ptr{T}, pointer(arr)) unsafe_wrap(Array, ptr, size(arr)) end + +## resizing + +""" + resize!(a::oneVector, n::Integer) + +Resize `a` to contain `n` elements. If `n` is smaller than the current collection length, +the first `n` elements will be retained. If `n` is larger, the new elements are not +guaranteed to be initialized. +""" +function Base.resize!(a::oneVector{T}, n::Integer) where {T} + # TODO: add additional space to allow for quicker resizing + maxsize = n * sizeof(T) + bufsize = if isbitstype(T) + maxsize + else + # type tag array past the data + maxsize + n + end + + # replace the data with a new one. this 'unshares' the array. + # as a result, we can safely support resizing unowned buffers. + ctx = context() + dev = device(a) + buf = allocate(buftype(a), ctx, dev, bufsize, Base.datatype_alignment(T)) + ptr = ZePtr{T}(buf) + m = min(length(a), n) + if m > 0 + unsafe_copyto!(device(a), ptr, pointer(a), m) + end + new_data = DataRef(buf) do buf + free(buf) + end + unsafe_free!(a) + + a.data = new_data + a.dims = (n,) + a.maxsize = maxsize + a.offset = 0 + + a +end diff --git a/test/array.jl b/test/array.jl index b6b47dab..73121f86 100644 --- a/test/array.jl +++ b/test/array.jl @@ -74,3 +74,24 @@ end e = c .+ d @test oneAPI.buftype(e) == oneL0.SharedBuffer end + +@testset "resizing" begin + a = oneArray([1,2,3]) + + resize!(a, 3) + @test length(a) == 3 + @test Array(a) == [1,2,3] + + resize!(a, 5) + @test length(a) == 5 + @test Array(a)[1:3] == [1,2,3] + + resize!(a, 2) + @test length(a) == 2 + @test Array(a)[1:2] == [1,2] + + b = oneArray{Int}(undef, 0) + @test length(b) == 0 + resize!(b, 1) + @test length(b) == 1 +end From 9f0c88c26f1838bce28ddd2843bb8d7de8181b00 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 16:22:45 -0400 Subject: [PATCH 2/4] Try fixing pointer conversion --- src/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array.jl b/src/array.jl index 0679e79a..6c669793 100644 --- a/src/array.jl +++ b/src/array.jl @@ -503,7 +503,7 @@ function Base.resize!(a::oneVector{T}, n::Integer) where {T} ctx = context() dev = device(a) buf = allocate(buftype(a), ctx, dev, bufsize, Base.datatype_alignment(T)) - ptr = ZePtr{T}(buf) + ptr = pointer(buf) m = min(length(a), n) if m > 0 unsafe_copyto!(device(a), ptr, pointer(a), m) From 6c43e8b6c80d7346f0163727c2b52d292865111f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 16:28:32 -0400 Subject: [PATCH 3/4] Try fixing unsafe_copyto call --- src/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array.jl b/src/array.jl index 6c669793..c0aa6e0f 100644 --- a/src/array.jl +++ b/src/array.jl @@ -500,13 +500,13 @@ function Base.resize!(a::oneVector{T}, n::Integer) where {T} # replace the data with a new one. this 'unshares' the array. # as a result, we can safely support resizing unowned buffers. - ctx = context() + ctx = context(a) dev = device(a) buf = allocate(buftype(a), ctx, dev, bufsize, Base.datatype_alignment(T)) ptr = pointer(buf) m = min(length(a), n) if m > 0 - unsafe_copyto!(device(a), ptr, pointer(a), m) + unsafe_copyto!(ctx, dev, ptr, pointer(a), m) end new_data = DataRef(buf) do buf free(buf) From 4ee1d5c7048ae7b7d511b1ad1feca8b29d61fc36 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 16:38:13 -0400 Subject: [PATCH 4/4] Try improving pointer conversion --- src/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array.jl b/src/array.jl index c0aa6e0f..f765d2b3 100644 --- a/src/array.jl +++ b/src/array.jl @@ -503,7 +503,7 @@ function Base.resize!(a::oneVector{T}, n::Integer) where {T} ctx = context(a) dev = device(a) buf = allocate(buftype(a), ctx, dev, bufsize, Base.datatype_alignment(T)) - ptr = pointer(buf) + ptr = convert(ZePtr{T}, buf) m = min(length(a), n) if m > 0 unsafe_copyto!(ctx, dev, ptr, pointer(a), m)