diff --git a/src/RNTuple/Writing/TFileWriter.jl b/src/RNTuple/Writing/TFileWriter.jl index 28746de2..08c71ba4 100644 --- a/src/RNTuple/Writing/TFileWriter.jl +++ b/src/RNTuple/Writing/TFileWriter.jl @@ -474,8 +474,9 @@ end # primary case function add_field_column_record!(field_records, column_records, input_T::Type{<:Real}, NAME; parent_field_id, col_field_id = parent_field_id) - fr = UnROOT.FieldRecord(zero(UInt32), zero(UInt32), parent_field_id, zero(UInt16), zero(UInt16), string(NAME), RNTUPLE_WRITE_TYPE_CPPNAME_DICT[input_T], "", "", 0, -1, -1) - cr = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[input_T]..., col_field_id, 0x00, 0x00, 0) + fr = UnROOT.FieldRecord(zero(UInt32), zero(UInt32), parent_field_id, zero(UInt16), zero(UInt16), string(NAME), RNT_WRITE_CPP_TYPE_NAME_DICT[input_T], "", "", 0, -1, -1) + rnt_col_type = RNT_COL_TYPE_TABLE[RNT_WRITE_JL_TYPE_DICT[input_T] + 1] + cr = UnROOT.ColumnRecord(rnt_col_type.type, rnt_col_type.nbits, col_field_id, 0x00, 0x00, 0) push!(field_records, fr) push!(column_records, cr) nothing @@ -487,9 +488,11 @@ function add_field_column_record!(field_records, column_records, input_T::Type{< fr = UnROOT.FieldRecord(; field_version=0x00000000, type_version=0x00000000, parent_field_id, struct_role=0x0000, flags=0x0000, repetition=0, source_field_id=-1, root_streamer_checksum=-1, field_name=string(NAME), type_name="std::string", type_alias="", field_desc="", ) push!(field_records, fr) - cr_offset = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[Index64]..., col_field_id, 0x00, 0x00, 0) + rnt_indexcol_type = RNT_COL_TYPE_TABLE[RNT_WRITE_JL_TYPE_DICT[Index64] + 1] + cr_offset = UnROOT.ColumnRecord(rnt_indexcol_type.type, rnt_col_type.nbits, col_field_id, 0x00, 0x00, 0) push!(column_records, cr_offset) - cr_chars = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[Char]..., col_field_id, 0x00, 0x00, 0) + rnt_charcol_type = RNT_COL_TYPE_TABLE[RNT_WRITE_JL_TYPE_DICT[Char] + 1] + cr_chars = UnROOT.ColumnRecord(rnt_charcol_type.type, rnt_char_type.nbits, col_field_id, 0x00, 0x00, 0) push!(column_records, cr_chars) nothing end @@ -499,7 +502,8 @@ function add_field_column_record!(field_records, column_records, input_T::Type{< implicit_field_id = length(field_records) fr = UnROOT.FieldRecord(; field_version=0x00000000, type_version=0x00000000, parent_field_id, struct_role=0x0001, flags=0x0000, repetition=0, source_field_id=-1, root_streamer_checksum=-1, field_name=string(NAME), type_name="", type_alias="", field_desc="", ) push!(field_records, fr) - cr_offset = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[Index64]..., col_field_id, 0x00, 0x00, 0) + rnt_col_type = RNT_COL_TYPE_TABLE[RNT_WRITE_JL_TYPE_DICT[Index64] + 1] + cr_offset = UnROOT.ColumnRecord(rnt_col_type.type, rnt_col_type.nbits, col_field_id, 0x00, 0x00, 0) push!(column_records, cr_offset) # TODO: this feels like a hack, think about it more diff --git a/src/RNTuple/Writing/page_writing.jl b/src/RNTuple/Writing/page_writing.jl index a9a24863..073f915e 100644 --- a/src/RNTuple/Writing/page_writing.jl +++ b/src/RNTuple/Writing/page_writing.jl @@ -2,10 +2,10 @@ rnt_col_to_ary(col) -> Vector{Vector} Normalize each user-facing "column" into a collection of Vector{<:Real} ready to be written to a page. -After calling this on all user-facing "column", we should have as many `ary`s as our `ColumnRecord`s. +After calling this on all user-facing "column", we should have as many `ary`s as our `ColumnRecord`s and +in the same order. """ rnt_col_to_ary(col::AbstractVector{<:Real}) = Any[col] - function rnt_col_to_ary(col::AbstractVector{<:AbstractVector}) vov = VectorOfVectors(col) content = flatview(vov) @@ -15,7 +15,6 @@ function rnt_col_to_ary(col::AbstractVector{<:AbstractVector}) Any[rnt_col_to_ary(offset_adjust); rnt_col_to_ary(content)] end - function rnt_col_to_ary(col::AbstractVector{<:AbstractString}) rnt_col_to_ary(codeunits.(col)) end @@ -28,90 +27,31 @@ Turns an AbstractVector into a page of an RNTuple. The element type must be prim """ function rnt_ary_to_page(ary::AbstractVector, cr::ColumnRecord) end - function rnt_ary_to_page(ary::AbstractVector{Bool}, cr::ColumnRecord) chunks = BitVector(ary).chunks Page_write(reinterpret(UInt8, chunks)) end -function rnt_ary_to_page(ary::AbstractVector{Float64}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - if split - Page_write(split8_encode(reinterpret(UInt8, ary))) - else - Page_write(reinterpret(UInt8, ary)) - end -end - -function rnt_ary_to_page(ary::AbstractVector{Float32}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - if split - Page_write(split4_encode(reinterpret(UInt8, ary))) - else - Page_write(reinterpret(UInt8, ary)) - end -end - -function rnt_ary_to_page(ary::AbstractVector{Float16}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - if split - Page_write(split2_encode(reinterpret(UInt8, ary))) - else - Page_write(reinterpret(UInt8, ary)) - end -end - -function rnt_ary_to_page(ary::AbstractVector{UInt64}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - if split - Page_write(split8_encode(reinterpret(UInt8, ary))) - else - Page_write(reinterpret(UInt8, ary)) - end -end - -function rnt_ary_to_page(ary::AbstractVector{UInt32}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - if split - Page_write(split4_encode(reinterpret(UInt8, ary))) - else - Page_write(reinterpret(UInt8, ary)) - end +function rnt_ary_to_page(ary::AbstractVector{T}, cr::ColumnRecord) where T<:Number + Page_write(page_encode(ary, cr)) end -function rnt_ary_to_page(ary::AbstractVector{UInt16}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - if split - Page_write(split2_encode(reinterpret(UInt8, ary))) +function page_encode(ary::AbstractVector{T}, cr::ColumnRecord) where T + col_type = RNT_COL_TYPE_TABLE[cr.type+1] + nbits = col_type.nbits + src = reinterpret(UInt8, ary) + if col_type.issplit + if nbits == 64 + split8_encode(src) + elseif nbits == 32 + split4_encode(src) + elseif nbits == 16 + split2_encode(src) + end else - Page_write(reinterpret(UInt8, ary)) + src end end - -function rnt_ary_to_page(ary::AbstractVector{UInt8}, cr::ColumnRecord) - Page_write(ary) -end - -function rnt_ary_to_page(ary::AbstractVector{Int64}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - Page_write(reinterpret(UInt8, ary)) -end - -function rnt_ary_to_page(ary::AbstractVector{Int32}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - Page_write(reinterpret(UInt8, ary)) -end - -function rnt_ary_to_page(ary::AbstractVector{Int16}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - Page_write(reinterpret(UInt8, ary)) -end - -function rnt_ary_to_page(ary::AbstractVector{Int8}, cr::ColumnRecord) - (;split, zigzag, delta) = _detect_encoding(cr.type) - Page_write(reinterpret(UInt8, ary)) -end - function split8_encode(src::AbstractVector{UInt8}) @views [src[1:8:end-7]; src[2:8:end-6]; src[3:8:end-5]; src[4:8:end-4]; src[5:8:end-3]; src[6:8:end-2]; src[7:8:end-1]; src[8:8:end]] end diff --git a/src/RNTuple/constants.jl b/src/RNTuple/constants.jl index ca838de1..3c875dad 100644 --- a/src/RNTuple/constants.jl +++ b/src/RNTuple/constants.jl @@ -4,6 +4,8 @@ @define_integers 64 SignedIndex64 Index64 Base.promote_rule(::Type{Int64}, ::Type{Index64}) = Int64 Base.promote_rule(::Type{Index64}, ::Type{Int64}) = Int64 +Base.promote_rule(::Type{Int64}, ::Type{Index32}) = Int64 +Base.promote_rule(::Type{Index32}, ::Type{Int64}) = Int64 @kwdef struct RNTuple_ColumnType type::UInt8 @@ -16,7 +18,7 @@ Base.promote_rule(::Type{Index64}, ::Type{Int64}) = Int64 end #https://github.com/root-project/root/blob/1de46e89958fd3946d2d6995c810391b781d39ac/tree/ntuple/v7/doc/BinaryFormatSpecification.md?plain=1#L479 -const rntuple_col_type_table = ( +const RNT_COL_TYPE_TABLE = ( RNTuple_ColumnType(type = 0x00, nbits = 1, name = :Bit , jltype = Bool), RNTuple_ColumnType(type = 0x01, nbits = 8, name = :Byte , jltype = UInt8), RNTuple_ColumnType(type = 0x02, nbits = 8, name = :Char , jltype = UInt8), @@ -49,26 +51,25 @@ RNTuple_ColumnType(type = 0x1B, nbits = 64, name = :SplitIndex64, jltype = Index # (0x1D, 1-32, :Real32Quant ), #?? ) -# for each Julia type, we pick just one canonical representation for writing -const RNTUPLE_WRITE_TYPE_IDX_DICT = Dict( - Index64 => (0x0F, sizeof(Index64) * 8), - Index32 => (0x0E, sizeof(Index32) * 8), - Char => (0x02, 8), - Bool => (0x00, 1), - Float64 => (0x0D, sizeof(Float64) * 8), - Float32 => (0x0C, sizeof(Float32) * 8), - Float16 => (0x0B, sizeof(Float16) * 8), - UInt64 => (0x0A, sizeof(UInt64) * 8), - UInt32 => (0x08, sizeof(UInt32) * 8), - UInt16 => (0x06, sizeof(UInt16) * 8), - UInt8 => (0x04, sizeof(UInt8) * 8), - Int64 => (0x09, sizeof(Int64) * 8), - Int32 => (0x07, sizeof(Int32) * 8), - Int16 => (0x05, sizeof(Int16) * 8), - Int8 => (0x03, sizeof(Int8) * 8), +const RNT_WRITE_JL_TYPE_DICT = Dict( + Index64 => 0x0F, + Index32 => 0x0E, + Char => 0x02, + Bool => 0x00, + Float64 => 0x0D, + Float32 => 0x0C, + Float16 => 0x0B, + UInt64 => 0x0A, + UInt32 => 0x08, + UInt16 => 0x06, + UInt8 => 0x04, + Int64 => 0x09, + Int32 => 0x07, + Int16 => 0x05, + Int8 => 0x03, ) -const RNTUPLE_WRITE_TYPE_CPPNAME_DICT = Dict( +const RNT_WRITE_CPP_TYPE_NAME_DICT = Dict( Bool => "bool", Float16 => "std::float16_t", Float32 => "float", diff --git a/src/RNTuple/fieldcolumn_reading.jl b/src/RNTuple/fieldcolumn_reading.jl index b5be0b9a..6c3b4e38 100644 --- a/src/RNTuple/fieldcolumn_reading.jl +++ b/src/RNTuple/fieldcolumn_reading.jl @@ -183,7 +183,7 @@ function read_field(io, field::UnionField{S, T}, page_list) where {S, T} end function _detect_encoding(typenum) - col_type = rntuple_col_type_table[typenum+1] + col_type = RNT_COL_TYPE_TABLE[typenum+1] split = col_type.issplit zigzag = col_type.iszigzag delta = col_type.isdelta diff --git a/src/RNTuple/fieldcolumn_schema.jl b/src/RNTuple/fieldcolumn_schema.jl index 339ccd97..dfe3ceb5 100644 --- a/src/RNTuple/fieldcolumn_schema.jl +++ b/src/RNTuple/fieldcolumn_schema.jl @@ -92,14 +92,14 @@ function _search_col_type(field_id, column_records, col_id::Int...) index_record = column_records[col_id[1]] char_record = column_records[col_id[2]] index_typenum = index_record.type - LeafType = rntuple_col_type_table[index_typenum+0x01].jltype + LeafType = RNT_COL_TYPE_TABLE[index_typenum+0x01].jltype return StringField( LeafField{LeafType}(col_id[1],index_record), LeafField{Char}(col_id[2], char_record) ) elseif length(col_id) == 1 record = column_records[only(col_id)] - LeafType = rntuple_col_type_table[record.type+0x01].jltype + LeafType = RNT_COL_TYPE_TABLE[record.type+0x01].jltype return LeafField{LeafType}(only(col_id), record) else error("un-handled RNTuple case, report issue to UnROOT.jl") diff --git a/test/runtests.jl b/test/runtests.jl index d4be502c..05c1937d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,30 +4,30 @@ using UnROOT nthreads = UnROOT._maxthreadid() nthreads == 1 && @warn "Running on a single thread. Please re-run the test suite with at least two threads (`julia --threads 2 ...`)" -@testset "UnROOT tests" verbose = true begin - include("Aqua.jl") - include("bootstrapping.jl") - include("compressions.jl") - include("jagged.jl") - include("lazy.jl") - include("histograms.jl") - include("views.jl") - include("multithreading.jl") - include("remote.jl") - include("displays.jl") - include("type_stability.jl") - include("utils.jl") - include("misc.jl") +# @testset "UnROOT tests" verbose = true begin +# include("Aqua.jl") +# include("bootstrapping.jl") +# include("compressions.jl") +# include("jagged.jl") +# include("lazy.jl") +# include("histograms.jl") +# include("views.jl") +# include("multithreading.jl") +# include("remote.jl") +# include("displays.jl") +# include("type_stability.jl") +# include("utils.jl") +# include("misc.jl") - include("type_support.jl") - include("custom_bootstrapping.jl") - include("lorentzvectors.jl") - include("NanoAOD.jl") +# include("type_support.jl") +# include("custom_bootstrapping.jl") +# include("lorentzvectors.jl") +# include("NanoAOD.jl") - include("issues.jl") +# include("issues.jl") if VERSION >= v"1.9" include("rntuple.jl") include("./RNTupleWriting/lowlevel.jl") end -end +# end