diff --git a/Project.toml b/Project.toml index 4c91764..55a5711 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Strided" uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" authors = ["Lukas Devos ", "Maarten Van Damme ", "Jutho Haegeman "] -version = "2.0.2" +version = "2.0.3" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/mapreduce.jl b/src/mapreduce.jl index b0147eb..1c8ae48 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -162,7 +162,7 @@ function _mapreduce_block!(@nospecialize(f), @nospecialize(op), @nospecialize(in newarrays = (threadedout, Base.tail(arrays)...) _mapreduce_threaded!(f, op, nothing, dims, blocks, strides, offsets, costs, - newarrays, get_num_threads(), spacing) + newarrays, get_num_threads(), spacing, 1) for i in 1:get_num_threads() a = op(a, threadedout[(i - 1) * spacing + 1]) @@ -194,9 +194,9 @@ end # reduction function _mapreduce_threaded!(@nospecialize(f), @nospecialize(op), @nospecialize(initop), dims, blocks, strides, offsets, costs, arrays, nthreads, - spacing) + spacing, taskindex) if nthreads == 1 || prod(dims) <= MINTHREADLENGTH - offset1 = offsets[1] + spacing * (Threads.threadid() - 1) + offset1 = offsets[1] + spacing * (taskindex - 1) spacedoffsets = (offset1, Base.tail(offsets)...) _mapreduce_kernel!(f, op, initop, dims, blocks, arrays, strides, spacedoffsets) else @@ -213,13 +213,13 @@ function _mapreduce_threaded!(@nospecialize(f), @nospecialize(op), @nospecialize newoffsets = offsets t = Threads.@spawn _mapreduce_threaded!(f, op, initop, newdims, blocks, strides, newoffsets, costs, arrays, nnthreads, - spacing) + spacing, taskindex) stridesi = getindex.(strides, i) newoffsets2 = offsets .+ ndi .* stridesi newdims2 = setindex(dims, di - ndi, i) nnthreads2 = nthreads - nnthreads _mapreduce_threaded!(f, op, initop, newdims2, blocks, strides, newoffsets2, - costs, arrays, nnthreads2, spacing) + costs, arrays, nnthreads2, spacing, taskindex + nnthreads) wait(t) end end diff --git a/test/othertests.jl b/test/othertests.jl index 7a23ad6..2208779 100644 --- a/test/othertests.jl +++ b/test/othertests.jl @@ -115,7 +115,7 @@ end @test minimum(real, R1) ≈ minimum(real, StridedView(R1)) @test sum(x -> real(x) < 0, R1) == sum(x -> real(x) < 0, StridedView(R1)) - R1 = permutedims(R1, (randperm(6)...,)) + R1 = PermutedDimsArray(R1, (randperm(6)...,)) @test sum(R1) ≈ sum(StridedView(R1)) @test maximum(abs, R1) ≈ maximum(abs, StridedView(R1))