Skip to content

Commit

Permalink
fix maxthreadid() vs. nthreads() related issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Moelf committed Oct 10, 2023
1 parent d86e5c0 commit cba1360
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 19 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ AbstractTrees = "^0.3.0, 0.4"
ArraysOfArrays = "^0.5.3, ^0.6"
Arrow = "2 - 2.5"
BitIntegers = "^0.2.6, ^0.3"
CodecLz4 = "^0.3.0, ^0.4.0"
CodecXz = "^0.6.0, ^0.7.0"
CodecZstd = "^0.6.0, ^0.7.0"
CodecLz4 = "^0.3, ^0.4"
CodecXz = "^0.6, ^0.7"
CodecZstd = "^0.6, ^0.7, ^0.8"
HTTP = "^0.9.7, 1"
IterTools = "^1"
LRUCache = "^1.3.0"
Expand Down
7 changes: 4 additions & 3 deletions src/RNTuple/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ struct RNTupleField{R, F, O, E} <: AbstractVector{E}
function RNTupleField(rn::R, field::F) where {R, F}
O = _field_output_type(F)
E = eltype(O)
buffers = Vector{O}(undef, Threads.nthreads())
thread_locks = [ReentrantLock() for _ in 1:Threads.nthreads()]
buffer_ranges = [0:-1 for _ in 1:Threads.nthreads()]
Nthreads = _maxthreadid()
buffers = Vector{O}(undef, Nthreads)
thread_locks = [ReentrantLock() for _ in 1:Nthreads]
buffer_ranges = [0:-1 for _ in 1:Nthreads]

Check warning on line 25 in src/RNTuple/highlevel.jl

View check run for this annotation

Codecov / codecov/patch

src/RNTuple/highlevel.jl#L22-L25

Added lines #L22 - L25 were not covered by tests
new{R, F, O, E}(rn, field, buffers, thread_locks, buffer_ranges)
end
end
Expand Down
7 changes: 7 additions & 0 deletions src/UnROOT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ include("RNTuple/displays.jl")
# show(devnull, df)
# show(devnull, df[1])
# end
#

_maxthreadid() = @static if VERSION < v"1.9"
Threads.nthreads()
else
Threads.maxthreadid()
end

if VERSION >= v"1.9"
let
Expand Down
7 changes: 4 additions & 3 deletions src/iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,12 @@ mutable struct LazyBranch{T,J,B} <: AbstractVector{T}
_buffer = VectorOfVectors(T(), Int32[1])
T = SubArray{eltype(T), 1, T, Tuple{UnitRange{Int64}}, true}
end
Nthreads = _maxthreadid()
return new{T,J,typeof(_buffer)}(f, b, length(b),
b.fBasketEntry,
[_buffer for _ in 1:Threads.nthreads()],
[ReentrantLock() for _ in 1:Threads.nthreads()],
[0:-1 for _ in 1:Threads.nthreads()])
[_buffer for _ in 1:Nthreads],
[ReentrantLock() for _ in 1:Nthreads],
[0:-1 for _ in 1:Nthreads])
end
end
LazyBranch(f::ROOTFile, s::AbstractString) = LazyBranch(f, f[s])
Expand Down
11 changes: 10 additions & 1 deletion test/rntuple_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,17 @@ end
Threads.@threads for i in eachindex(field)
@inbounds accumulator[Threads.threadid()] += field[i]
end

# test we've hit each thread's buffer
@test all(!isempty, field.buffers)
@test all(
map(eachindex(field.buffers)) do b
if !isassigned(field.buffers, b)
return true
else
return !isempty(field.buffers[b])
end

end)
@test sum(accumulator) == sum(1:5e4)

accumulator .= 0
Expand Down
12 changes: 3 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using StaticArrays
using InteractiveUtils
using MD5

const nthreads = Threads.nthreads()
const nthreads = UnROOT._maxthreadid()
nthreads == 1 && @warn "Running on a single thread. Please re-run the test suite with at least two threads (`julia --threads 2 ...`)"

const SAMPLES_DIR = joinpath(@__DIR__, "samples")
Expand Down Expand Up @@ -766,18 +766,12 @@ end


if get(ENV, "CI", "false") == "true"
if nthreads >= 1
@test Threads.nthreads()>1
else
if nthreads == 1
@warn "CI wasn't run with multiple threads"
end
end

nmus = if isdefined(Threads, :maxthreadid)
zeros(Int, Threads.maxthreadid())
else
zeros(Int, Threads.nthreads())
end
nmus = zeros(Int, nthreads)

Threads.@threads for i in 1:length(t)
nmus[Threads.threadid()] += length(t.Muon_pt[i])
Expand Down

0 comments on commit cba1360

Please sign in to comment.