Skip to content

Commit

Permalink
Rename LargeList.dtype to LargeList.feature (#7106)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova authored Aug 26, 2024
1 parent 3813ce8 commit 88f646c
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 28 deletions.
38 changes: 18 additions & 20 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,11 +1175,11 @@ class LargeList:
It is backed by `pyarrow.LargeListType`, which is like `pyarrow.ListType` but with 64-bit rather than 32-bit offsets.
Args:
dtype ([`FeatureType`]):
feature ([`FeatureType`]):
Child feature data type of each item within the large list.
"""

dtype: Any
feature: Any
id: Optional[str] = None
# Automatically constructed
pa_type: ClassVar[Any] = None
Expand Down Expand Up @@ -1218,8 +1218,6 @@ def _check_non_null_non_empty_recursive(obj, schema: Optional[FeatureType] = Non
pass
elif isinstance(schema, (list, tuple)):
schema = schema[0]
elif isinstance(schema, LargeList):
schema = schema.dtype
else:
schema = schema.feature
return _check_non_null_non_empty_recursive(obj[0], schema)
Expand Down Expand Up @@ -1252,7 +1250,7 @@ def get_nested_type(schema: FeatureType) -> pa.DataType:
value_type = get_nested_type(schema[0])
return pa.list_(value_type)
elif isinstance(schema, LargeList):
value_type = get_nested_type(schema.dtype)
value_type = get_nested_type(schema.feature)
return pa.large_list(value_type)
elif isinstance(schema, Sequence):
value_type = get_nested_type(schema.feature)
Expand Down Expand Up @@ -1303,7 +1301,7 @@ def encode_nested_example(schema, obj, level=0):
return None
else:
if len(obj) > 0:
sub_schema = schema.dtype
sub_schema = schema.feature
for first_elmt in obj:
if _check_non_null_non_empty_recursive(first_elmt, sub_schema):
break
Expand Down Expand Up @@ -1384,7 +1382,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
if obj is None:
return None
else:
sub_schema = schema.dtype
sub_schema = schema.feature
if len(obj) > 0:
for first_elmt in obj:
if _check_non_null_non_empty_recursive(first_elmt, sub_schema):
Expand Down Expand Up @@ -1463,8 +1461,8 @@ def generate_from_dict(obj: Any):
raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}")

if class_type == LargeList:
dtype = obj.pop("dtype")
return LargeList(generate_from_dict(dtype), **obj)
feature = obj.pop("feature")
return LargeList(feature=generate_from_dict(feature), **obj)
if class_type == Sequence:
feature = obj.pop("feature")
return Sequence(feature=generate_from_dict(feature), **obj)
Expand Down Expand Up @@ -1493,8 +1491,8 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType:
return [feature]
return Sequence(feature=feature)
elif isinstance(pa_type, pa.LargeListType):
dtype = generate_from_arrow_type(pa_type.value_type)
return LargeList(dtype)
feature = generate_from_arrow_type(pa_type.value_type)
return LargeList(feature=feature)
elif isinstance(pa_type, _ArrayXDExtensionType):
array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims]
return array_feature(shape=pa_type.shape, dtype=pa_type.value_type)
Expand Down Expand Up @@ -1601,7 +1599,7 @@ def _visit(feature: FeatureType, func: Callable[[FeatureType], Optional[FeatureT
elif isinstance(feature, (list, tuple)):
out = func([_visit(feature[0], func)])
elif isinstance(feature, LargeList):
out = func(LargeList(_visit(feature.dtype, func)))
out = func(LargeList(_visit(feature.feature, func)))
elif isinstance(feature, Sequence):
out = func(Sequence(_visit(feature.feature, func), length=feature.length))
else:
Expand All @@ -1624,7 +1622,7 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False
elif isinstance(feature, (list, tuple)):
return require_decoding(feature[0])
elif isinstance(feature, LargeList):
return require_decoding(feature.dtype)
return require_decoding(feature.feature)
elif isinstance(feature, Sequence):
return require_decoding(feature.feature)
else:
Expand All @@ -1644,7 +1642,7 @@ def require_storage_cast(feature: FeatureType) -> bool:
elif isinstance(feature, (list, tuple)):
return require_storage_cast(feature[0])
elif isinstance(feature, LargeList):
return require_storage_cast(feature.dtype)
return require_storage_cast(feature.feature)
elif isinstance(feature, Sequence):
return require_storage_cast(feature.feature)
else:
Expand All @@ -1664,7 +1662,7 @@ def require_storage_embed(feature: FeatureType) -> bool:
elif isinstance(feature, (list, tuple)):
return require_storage_cast(feature[0])
elif isinstance(feature, LargeList):
return require_storage_cast(feature.dtype)
return require_storage_cast(feature.feature)
elif isinstance(feature, Sequence):
return require_storage_cast(feature.feature)
else:
Expand Down Expand Up @@ -1876,8 +1874,8 @@ def to_yaml_inner(obj: Union[dict, list]) -> dict:
if isinstance(obj, dict):
_type = obj.pop("_type", None)
if _type == "LargeList":
value_type = obj.pop("dtype")
return simplify({"large_list": to_yaml_inner(value_type), **obj})
_feature = obj.pop("feature")
return simplify({"large_list": to_yaml_inner(_feature), **obj})
elif _type == "Sequence":
_feature = obj.pop("feature")
return simplify({"sequence": to_yaml_inner(_feature), **obj})
Expand Down Expand Up @@ -1947,8 +1945,8 @@ def from_yaml_inner(obj: Union[dict, list]) -> Union[dict, list]:
return {}
_type = next(iter(obj))
if _type == "large_list":
_dtype = unsimplify(obj).pop(_type)
return {"dtype": from_yaml_inner(_dtype), **obj, "_type": "LargeList"}
_feature = unsimplify(obj).pop(_type)
return {"feature": from_yaml_inner(_feature), **obj, "_type": "LargeList"}
if _type == "sequence":
_feature = unsimplify(obj).pop(_type)
return {"feature": from_yaml_inner(_feature), **obj, "_type": "Sequence"}
Expand Down Expand Up @@ -2180,7 +2178,7 @@ def recursive_reorder(source, target, stack=""):
elif isinstance(source, LargeList):
if not isinstance(target, LargeList):
raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position)
return LargeList(recursive_reorder(source.dtype, target.dtype, stack))
return LargeList(recursive_reorder(source.feature, target.feature, stack))
else:
return source

Expand Down
8 changes: 5 additions & 3 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2017,7 +2017,7 @@ def cast_array_to_feature(
array_offsets = _combine_list_array_offsets_with_mask(array)
return pa.ListArray.from_arrays(array_offsets, casted_array_values)
elif isinstance(feature, LargeList):
casted_array_values = _c(array.values, feature.dtype)
casted_array_values = _c(array.values, feature.feature)
if pa.types.is_large_list(array.type) and casted_array_values.type == array.values.type:
# Both array and feature have equal large_list type and values (within the list) type
return array
Expand Down Expand Up @@ -2075,7 +2075,9 @@ def cast_array_to_feature(
return pa.ListArray.from_arrays(array_offsets, _c(array.values, feature[0]), mask=array.is_null())
elif isinstance(feature, LargeList):
array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size
return pa.LargeListArray.from_arrays(array_offsets, _c(array.values, feature.dtype), mask=array.is_null())
return pa.LargeListArray.from_arrays(
array_offsets, _c(array.values, feature.feature), mask=array.is_null()
)
elif isinstance(feature, Sequence):
if feature.length > -1:
if feature.length == array.type.list_size:
Expand Down Expand Up @@ -2155,7 +2157,7 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"):
# feature must be LargeList(subfeature)
# Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError
array_offsets = _combine_list_array_offsets_with_mask(array)
return pa.LargeListArray.from_arrays(array_offsets, _e(array.values, feature.dtype))
return pa.LargeListArray.from_arrays(array_offsets, _e(array.values, feature.feature))
elif pa.types.is_fixed_size_list(array.type):
# feature must be Sequence(subfeature)
if isinstance(feature, Sequence) and feature.length > -1:
Expand Down
8 changes: 4 additions & 4 deletions tests/features/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def test_features_flatten_with_list_types(features_dict, expected_features_dict)
{"col": [Value("int32")]},
),
(
{"col": {"dtype": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}},
{"col": {"feature": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}},
{"col": LargeList(Value("int32"))},
),
(
Expand All @@ -738,7 +738,7 @@ def test_features_flatten_with_list_types(features_dict, expected_features_dict)
{"col": [{"sub_col": Value("int32")}]},
),
(
{"col": {"dtype": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}},
{"col": {"feature": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}},
{"col": LargeList({"sub_col": Value("int32")})},
),
(
Expand All @@ -760,7 +760,7 @@ def test_features_from_dict_with_list_types(deserialized_features_dict, expected
[Value("int32")],
),
(
{"dtype": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"},
{"feature": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"},
LargeList(Value("int32")),
),
(
Expand All @@ -772,7 +772,7 @@ def test_features_from_dict_with_list_types(deserialized_features_dict, expected
[{"sub_col": Value("int32")}],
),
(
{"dtype": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"},
{"feature": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"},
LargeList({"sub_col": Value("int32")}),
),
(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_dataset_info_from_dict_with_large_list():
dataset_info_dict = {
"citation": "",
"description": "",
"features": {"col_1": {"dtype": {"dtype": "int64", "_type": "Value"}, "_type": "LargeList"}},
"features": {"col_1": {"feature": {"dtype": "int64", "_type": "Value"}, "_type": "LargeList"}},
"homepage": "",
"license": "",
}
Expand Down

0 comments on commit 88f646c

Please sign in to comment.