Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: miscellaneous RNTuple improvements #1250

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
34 changes: 26 additions & 8 deletions src/uproot/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
"""
from __future__ import annotations

import struct

import numpy

# determines when a file is "big"
Expand Down Expand Up @@ -118,8 +116,6 @@
kStreamedMemberWise = numpy.uint16(1 << 14)

############ RNTuple https://github.com/root-project/root/blob/master/tree/ntuple/v7/doc/specifications.md
_rntuple_frame_format = struct.Struct("<Q")
rntuple_env_header = _rntuple_frame_format.pack(0) # TODO: need to check this
rntuple_col_num_to_dtype_dict = {
1: "uint64",
2: "uint32",
Expand Down Expand Up @@ -212,7 +208,29 @@
"splitzigzagint16": 28,
}

rntuple_role_leaf = 0
rntuple_role_vector = 1
rntuple_role_struct = 2
rntuple_role_union = 3
rntuple_locator_type_uri = 0x01
rntuple_locator_type_daos = 0x02

rntuple_env_type_header = 0x01
rntuple_env_type_footer = 0x02
rntuple_env_type_pagelist = 0x03
rntuple_env_type_metadata = 0x04

rntuple_field_role_leaf = 0x00
rntuple_field_role_vector = 0x01
rntuple_field_role_struct = 0x02
rntuple_field_role_union = 0x03
rntuple_field_role_unsplit = 0x04

rntuple_field_flag_repetitive = 0x01
rntuple_field_flag_projected = 0x02
rntuple_field_flag_checksum = 0x04

rntuple_col_flag_deferred = 0x08

rntuple_extra_type_identifier_root = 0x00

rntuple_user_metadata_type_int = 0x01
rntuple_user_metadata_type_bool = 0x02
rntuple_user_metadata_type_double = 0x03
rntuple_user_metadata_type_string = 0x04
Comment on lines +211 to +236
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe these should be enums

152 changes: 112 additions & 40 deletions src/uproot/models/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,41 @@

import uproot

# https://github.com/root-project/root/blob/aa513463b0b512517370cb91cca025e53a8b13a2/tree/ntuple/v7/inc/ROOT/RNTupleAnchor.hxx#L69
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#anchor-schema
_rntuple_anchor_format = struct.Struct(">HHHHQQQQQQQ")

# https://github.com/root-project/root/blob/aa513463b0b512517370cb91cca025e53a8b13a2/tree/ntuple/v7/doc/specifications.md#envelopes
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#feature-flags
_rntuple_feature_flag_format = struct.Struct("<Q")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#frames
_rntuple_frame_size_format = struct.Struct("<q")
_rntuple_frame_num_items_format = struct.Struct("<i")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#locators-and-envelope-links
_rntuple_locator_format = struct.Struct("<iQ")
_rntuple_envlink_size_format = struct.Struct("<Q")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#envelopes
_rntuple_env_header_format = struct.Struct("<Q")
_rntuple_checksum_format = struct.Struct("<Q")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#field-description
_rntuple_field_description_format = struct.Struct("<IIIHH")
_rntuple_repetition_format = struct.Struct("<Q")
_rntuple_column_record_format = struct.Struct("<HHII")
_rntuple_source_field_id_format = struct.Struct("<I")
_rntuple_root_streamer_checksum_format = struct.Struct("<I")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#column-description
_rntuple_column_record_format = struct.Struct("<HHIHH")
_rntuple_first_element_index_format = struct.Struct("<I")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#alias-columns
_rntuple_alias_column_format = struct.Struct("<II")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#extra-type-information
_rntuple_extra_type_info_format = struct.Struct("<III")
_rntuple_record_size_format = struct.Struct("<q")
_rntuple_frame_header_format = struct.Struct("<qi")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#cluster-group-record-frame
_rntuple_cluster_group_format = struct.Struct("<qqi")
_rntuple_locator_format = struct.Struct("<iQ")
_rntuple_column_group_id_format = struct.Struct("<I")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#cluster-summary-record-frame
_rntuple_cluster_summary_format = struct.Struct("<QQ")
_rntuple_checksum_format = struct.Struct("<Q")
_rntuple_envlink_size_format = struct.Struct("<Q")
# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#page-locations
_rntuple_page_num_elements_format = struct.Struct("<i")
_rntuple_column_group_id_format = struct.Struct("<I")
_rntuple_first_ele_index_format = struct.Struct("<I")


def from_zigzag(n):
def _from_zigzag(n):
return n >> 1 ^ -(n & 1)


Expand Down Expand Up @@ -97,6 +108,8 @@ def read_members(self, chunk, cursor, context, file):
self._alias_columns_dict_ = None
self._related_ids_ = None
self._column_records_dict_ = None
self._num_entries = None
self._length = None

self._page_list_envelopes = []

Expand Down Expand Up @@ -236,14 +249,45 @@ def footer(self):
def cluster_summaries(self):
return self.page_list_envelopes.cluster_summaries

# FIXME
@property
def _length(self):
return sum(x.num_entries for x in self.cluster_summaries)
def num_entries(self):
if self._num_entries is None:
self._num_entries = sum(x.num_entries for x in self.cluster_summaries)
return self._num_entries

def __len__(self):
if self._length is None:
self._length = len(self.keys())
return self._length

def __repr__(self):
if len(self) == 0:
return f"<RNTuple {self.name!r} at 0x{id(self):012x}>"
else:
return f"<RNTuple {self.name!r} ({len(self)} columns) at 0x{id(self):012x}>"

@property
def name(self):
"""
Name of the ``RNTuple``.
"""
return self.parent.fName

@property
def object_path(self):
"""
Object path of the ``RNTuple``.
"""
return self.parent.object_path

@property
def cache_key(self):
"""
String that uniquely specifies this ``RNTuple`` in its path, to use as
part of object and array cache keys.
"""
return f"{self.parent.cache_key}{self.name};{self.parent.fCycle}"

def read_locator(self, loc, uncomp_size, context):
cursor = uproot.source.cursor.Cursor(loc.offset)
chunk = self.file.source.chunk(loc.offset, loc.offset + loc.num_bytes)
Expand Down Expand Up @@ -278,9 +322,9 @@ def base_col_form(self, cr, col_id, parameters=None, cardinality=False):

form_key = f"column-{col_id}" + ("-cardinality" if cardinality else "")
dtype_byte = cr.type
if dtype_byte == uproot.const.rntuple_role_union:
if dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]:
return form_key
elif dtype_byte > uproot.const.rntuple_role_struct:
elif dtype_byte > uproot.const.rntuple_col_type_to_num_dict["switch"]:
dt_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte]
if dt_str == "bit":
dt_str = "bool"
Expand Down Expand Up @@ -335,7 +379,7 @@ def field_form(self, this_id, seen):
seen.add(this_id)
structural_role = this_record.struct_role
if (
structural_role == uproot.const.rntuple_role_leaf
structural_role == uproot.const.rntuple_field_role_leaf
and this_record.repetition == 0
):
# deal with std::atomic
Expand All @@ -350,7 +394,7 @@ def field_form(self, this_id, seen):
# base case of recursion
# n.b. the split may happen in column
return self.col_form(this_id)
elif structural_role == uproot.const.rntuple_role_leaf:
elif structural_role == uproot.const.rntuple_field_role_leaf:
if this_id in self._related_ids:
# std::array has only one subfield
child_id = self._related_ids[this_id][0]
Expand All @@ -360,7 +404,7 @@ def field_form(self, this_id, seen):
inner = self.col_form(this_id)
keyname = f"RegularForm-{this_id}"
return ak.forms.RegularForm(inner, this_record.repetition, form_key=keyname)
elif structural_role == uproot.const.rntuple_role_vector:
elif structural_role == uproot.const.rntuple_field_role_vector:
if this_id not in self._related_ids or len(self._related_ids[this_id]) != 1:
keyname = f"vector-{this_id}"
newids = self._related_ids.get(this_id, [])
Expand All @@ -384,15 +428,15 @@ def field_form(self, this_id, seen):
child_id = self._related_ids[this_id][0]
inner = self.field_form(child_id, seen)
return ak.forms.ListOffsetForm("i64", inner, form_key=keyname)
elif structural_role == uproot.const.rntuple_role_struct:
elif structural_role == uproot.const.rntuple_field_role_struct:
newids = []
if this_id in self._related_ids:
newids = self._related_ids[this_id]
# go find N in the rest, N is the # of fields in struct
recordlist = [self.field_form(i, seen) for i in newids]
namelist = [field_records[i].field_name for i in newids]
return ak.forms.RecordForm(recordlist, namelist, form_key="whatever")
elif structural_role == uproot.const.rntuple_role_union:
elif structural_role == uproot.const.rntuple_field_role_union:
keyname = self.col_form(this_id)
newids = []
if this_id in self._related_ids:
Expand All @@ -402,6 +446,10 @@ def field_form(self, this_id, seen):
"i8", "i64", recordlist, form_key=keyname + "-union"
)
return ak.forms.IndexedOptionForm("i64", inner, form_key=keyname)
elif structural_role == uproot.const.rntuple_field_role_unsplit:
raise NotImplementedError(
f"Unsplit fields are not supported. {this_record}"
)
else:
# everything should recurse above this branch
raise AssertionError("this should be unreachable")
Expand Down Expand Up @@ -480,13 +528,13 @@ def read_pagedesc(self, destination, desc, dtype_str, dtype, nbits, split):
# needed to chop off extra bits incase we used `unpackbits`
destination[:] = content[:num_elements]

def read_col_pages(self, ncol, cluster_range, pad_missing_ele=False):
def read_col_pages(self, ncol, cluster_range, pad_missing_element=False):
res = numpy.concatenate(
[self.read_col_page(ncol, i) for i in cluster_range], axis=0
)
if pad_missing_ele:
first_ele_index = self.column_records[ncol].first_ele_index
res = numpy.pad(res, (first_ele_index, 0))
if pad_missing_element:
first_element_index = self.column_records[ncol].first_element_index
res = numpy.pad(res, (first_element_index, 0))
return res

def read_col_page(self, ncol, cluster_i):
Expand Down Expand Up @@ -523,7 +571,7 @@ def read_col_page(self, ncol, cluster_i):
if index:
res = numpy.insert(res, 0, 0) # for offsets
if zigzag:
res = from_zigzag(res)
res = _from_zigzag(res)
elif delta:
res = numpy.cumsum(res)
return res
Expand All @@ -539,7 +587,7 @@ def arrays(
):
ak = uproot.extras.awkward()

entry_stop = entry_stop or self._length
entry_stop = entry_stop or self.num_entries

clusters = self.cluster_summaries
cluster_starts = numpy.array([c.num_first_entry for c in clusters])
Expand All @@ -566,7 +614,7 @@ def arrays(
content = self.read_col_pages(
key_nr,
range(start_cluster_idx, stop_cluster_idx),
pad_missing_ele=True,
pad_missing_element=True,
)
if "cardinality" in key:
content = numpy.diff(content)
Expand Down Expand Up @@ -639,6 +687,9 @@ def __init__(self):
def read(self, chunk, cursor, context):
out = MetaData(type(self).__name__)
out.env_header = _envelop_header(chunk, cursor, context)
assert (
out.env_header["env_type_id"] == uproot.const.rntuple_env_type_pagelist
), f"env_type_id={out.env_header['env_type_id']}"
out.header_checksum = cursor.field(chunk, _rntuple_checksum_format, context)
out.cluster_summaries = self.list_cluster_summaries.read(chunk, cursor, context)
out.pagelinklist = self.nested_page_locations.read(chunk, cursor, context)
Expand All @@ -652,6 +703,8 @@ def read(self, chunk, cursor, context):
out.num_bytes, out.offset = cursor.fields(
chunk, _rntuple_locator_format, context
)
if out.num_bytes < 0:
raise NotImplementedError("Non-disk locators are not supported.")
return out


Expand Down Expand Up @@ -692,7 +745,7 @@ def __init__(self, payload):

def read(self, chunk, cursor, context):
local_cursor = cursor.copy()
num_bytes = local_cursor.field(chunk, _rntuple_record_size_format, context)
num_bytes = local_cursor.field(chunk, _rntuple_frame_size_format, context)
assert num_bytes >= 0, f"num_bytes={num_bytes}"
cursor.skip(num_bytes)
return self.payload.read(chunk, local_cursor, context)
Expand All @@ -704,10 +757,9 @@ def __init__(self, payload):

def read(self, chunk, cursor, context):
local_cursor = cursor.copy()
num_bytes, num_items = local_cursor.fields(
chunk, _rntuple_frame_header_format, context
)
num_bytes = local_cursor.field(chunk, _rntuple_frame_size_format, context)
assert num_bytes < 0, f"num_bytes={num_bytes}"
num_items = local_cursor.field(chunk, _rntuple_frame_num_items_format, context)
cursor.skip(-num_bytes)
return [
self.payload.read(chunk, local_cursor, context) for _ in range(num_items)
Expand All @@ -725,10 +777,24 @@ def read(self, chunk, cursor, context):
out.struct_role,
out.flags,
) = cursor.fields(chunk, _rntuple_field_description_format, context)
if out.flags == 0x0001:
if out.flags == uproot.const.rntuple_field_flag_repetitive:
out.repetition = cursor.field(chunk, _rntuple_repetition_format, context)
out.source_field_id = None
out.checksum = None
elif out.flags == uproot.const.rntuple_field_flag_projected:
out.repetition = 0
out.source_field_id = cursor.field(
chunk, _rntuple_source_field_id_format, context
)
out.checksum = None
elif out.flags == uproot.const.rntuple_field_flag_checksum:
out.repetition = 0
out.source_field_id = None
out.checksum = cursor.field(chunk, _rntuple_checksum_format, context)
else:
out.repetition = 0
out.source_field_id = None
out.checksum = None
out.field_name, out.type_name, out.type_alias, out.field_desc = (
cursor.rntuple_string(chunk, context) for _ in range(4)
)
Expand All @@ -739,15 +805,15 @@ def read(self, chunk, cursor, context):
class ColumnRecordReader:
def read(self, chunk, cursor, context):
out = MetaData("ColumnRecordFrame")
out.type, out.nbits, out.field_id, out.flags = cursor.fields(
out.type, out.nbits, out.field_id, out.flags, out.repr_idx = cursor.fields(
chunk, _rntuple_column_record_format, context
)
if out.flags & 0x08:
out.first_ele_index = cursor.field(
chunk, _rntuple_first_ele_index_format, context
if out.flags & uproot.const.rntuple_col_flag_deferred:
out.first_element_index = cursor.field(
chunk, _rntuple_first_element_index_format, context
)
else:
out.first_ele_index = 0
out.first_element_index = 0
return out


Expand Down Expand Up @@ -790,6 +856,9 @@ def __init__(self):
def read(self, chunk, cursor, context):
out = MetaData(type(self).__name__)
out.env_header = _envelop_header(chunk, cursor, context)
assert (
out.env_header["env_type_id"] == uproot.const.rntuple_env_type_header
), f"env_type_id={out.env_header['env_type_id']}"
out.feature_flag = cursor.field(chunk, _rntuple_feature_flag_format, context)
out.name, out.ntuple_description, out.writer_identifier = (
cursor.rntuple_string(chunk, context) for _ in range(3)
Expand Down Expand Up @@ -846,7 +915,7 @@ def read(self, chunk, cursor, context):
class RNTupleSchemaExtension:
def read(self, chunk, cursor, context):
out = MetaData(type(self).__name__)
out.size = cursor.field(chunk, _rntuple_record_size_format, context)
out.size = cursor.field(chunk, _rntuple_frame_size_format, context)
assert out.size >= 0, f"size={out.size}"
out.field_records = ListFrameReader(
RecordFrameReader(FieldRecordReader())
Expand Down Expand Up @@ -880,6 +949,9 @@ def __init__(self):
def read(self, chunk, cursor, context):
out = MetaData("Footer")
out.env_header = _envelop_header(chunk, cursor, context)
assert (
out.env_header["env_type_id"] == uproot.const.rntuple_env_type_footer
), f"env_type_id={out.env_header['env_type_id']}"
out.feature_flag = cursor.field(chunk, _rntuple_feature_flag_format, context)
out.header_checksum = cursor.field(chunk, _rntuple_checksum_format, context)
out.extension_links = self.extension_header_links.read(chunk, cursor, context)
Expand Down
Loading
Loading