diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index 513d44206d..928d3f71a4 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -45,7 +45,7 @@ from awkward._regularize import is_non_string_like_iterable from awkward._typing import Any, TypeVar from awkward._util import STDOUT -from awkward.prettyprint import Formatter +from awkward.prettyprint import Formatter, highlevel_array_show_rows from awkward.prettyprint import valuestr as prettyprint_valuestr __all__ = ("Array", "ArrayBuilder", "Record") @@ -1398,10 +1398,13 @@ def show( self, limit_rows=20, limit_cols=80, + *, type=False, named_axis=False, + nbytes=False, + backend=False, + all=False, stream=STDOUT, - *, formatter=None, precision=3, ): @@ -1411,9 +1414,16 @@ def show( limit_cols (int): Maximum number of columns (characters wide). type (bool): If True, print the type as well. (Doesn't count toward number of rows/lines limit.) + named_axis (bool): If True, print the named axis as well. (Doesn't count toward number + of rows/lines limit.) + nbytes (bool): If True, print the number of bytes as well. (Doesn't count toward number + of rows/lines limit.) + backend (bool): If True, print the backend of the array as well. (Doesn't count toward number + of rows/lines limit.) + all (bool): If True, print the 'type', 'named axis', 'nbytes', and 'backend' of the array. (Doesn't count toward number + of rows/lines limit.) stream (object with a ``write(str)`` method or None): Stream to write the output to. If None, return a string instead of writing to a stream. - formatter (Mapping or None): Mapping of types/type-classes to string formatters. If None, use the default formatter. @@ -1426,64 +1436,70 @@ def show( key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting string values, falling back upon `"str_kind"`. 
""" - formatter_impl = Formatter(formatter, precision=precision) - - valuestr = prettyprint_valuestr( - self, limit_rows, limit_cols, formatter=formatter_impl + rows = highlevel_array_show_rows( + array=self, + limit_rows=limit_rows, + limit_cols=limit_cols, + type=type or all, + named_axis=named_axis or all, + nbytes=nbytes or all, + backend=backend or all, + formatter=formatter, + precision=precision, ) + array_line = rows.pop(0) out_io = io.StringIO() if type: - out_io.write("type: ") - self.type.show(stream=out_io) - if named_axis and self.named_axis: - out_io.write("axes: ") - out_io.write( - _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) - ) + # it's always the second row (after the array) + type_line = rows.pop(0) + out_io.write(type_line) + + # the rest of the rows we sort by the length of their ':' + # but we sort it from shortest to longest contrary to _repr_mimebundle_ + sorted_rows = sorted([r for r in rows if r], key=lambda x: len(x.split(":")[0])) + + if sorted_rows: + out_io.write("\n".join(sorted_rows)) out_io.write("\n") - out_io.write(valuestr) + out_io.write(array_line) if stream is None: - return out_io + return out_io.getvalue() else: if stream is STDOUT: stream = STDOUT.stream stream.write(out_io.getvalue() + "\n") def _repr_mimebundle_(self, include=None, exclude=None): - # order: 1. array, 2. named_axis, 3. 
type - value_buff = io.StringIO() - self.show(type=False, named_axis=False, stream=value_buff) - header_lines = value_buff.getvalue().splitlines() - - named_axis_line = "" - if self.named_axis: - named_axis_buff = io.StringIO() - named_axis_buff.write("axes: ") - named_axis_buff.write( - _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) - ) - named_axis_line = named_axis_buff.getvalue() + # order: + # first: array, + # last: type, + # middle: rest sorted by length of prefix (longest first) + + rows = highlevel_array_show_rows( + array=self, + type=True, + named_axis=True, + nbytes=True, + backend=True, + ) + header_lines = rows.pop(0).removesuffix("\n").splitlines() - type_buff = io.StringIO() - self.type.show(stream=type_buff) - footer_lines = type_buff.getvalue().splitlines() - # Prepend a `type: ` prefix to the type information - footer_lines[0] = f"type: {footer_lines[0]}" + # it's always the second row (after the array) + type_lines = [rows.pop(0).removesuffix("\n")] - if header_lines[-1] == "": - del header_lines[-1] + # the rest of the rows we sort by the length of their ':' + # but we sort it from longest to shortest for _repr_mimebundle_ + sorted_rows = sorted(rows, key=lambda x: -len(x.split(":")[0])) n_cols = max( - len(line) - for line in itertools.chain(header_lines, [named_axis_line], footer_lines) + len(line) for line in itertools.chain(header_lines, sorted_rows, type_lines) ) body_lines = header_lines body_lines.append("-" * n_cols) - if named_axis_line: - body_lines.append(named_axis_line) - body_lines.extend(footer_lines) + body_lines.extend(sorted_rows) + body_lines.extend(type_lines) body = "\n".join(body_lines) return { @@ -2317,10 +2333,13 @@ def show( self, limit_rows=20, limit_cols=80, + *, type=False, named_axis=False, + nbytes=False, + backend=False, + all=False, stream=STDOUT, - *, formatter=None, precision=3, ): @@ -2330,6 +2349,14 @@ def show( limit_cols (int): Maximum number of columns (characters wide). 
type (bool): If True, print the type as well. (Doesn't count toward number of rows/lines limit.) + named_axis (bool): If True, print the named axis as well. (Doesn't count toward number + of rows/lines limit.) + nbytes (bool): If True, print the number of bytes as well. (Doesn't count toward number + of rows/lines limit.) + backend (bool): If True, print the backend of the array as well. (Doesn't count toward number + of rows/lines limit.) + all (bool): If True, print the 'type', 'named axis', 'nbytes', and 'backend' of the array. (Doesn't count toward number + of rows/lines limit.) stream (object with a ``write(str)`` method or None): Stream to write the output to. If None, return a string instead of writing to a stream. formatter (Mapping or None): Mapping of types/type-classes to string formatters. @@ -2344,23 +2371,34 @@ def show( key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting string values, falling back upon `"str_kind"`. """ - formatter_impl = Formatter(formatter, precision=precision) - valuestr = prettyprint_valuestr( - self, limit_rows, limit_cols, formatter=formatter_impl + rows = highlevel_array_show_rows( + array=self, + limit_rows=limit_rows, + limit_cols=limit_cols, + type=type or all, + named_axis=named_axis or all, + nbytes=nbytes or all, + backend=backend or all, + formatter=formatter, + precision=precision, ) + array_line = rows.pop(0) out_io = io.StringIO() if type: - out_io.write("type: ") - self.type.show(stream=out_io) - if named_axis and self.named_axis: - out_io.write("axes: ") - out_io.write( - _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) - ) + # it's always the second row (after the array) + type_line = rows.pop(0) + out_io.write(type_line) + + # the rest of the rows we sort by the length of their ':' + # but we sort it from shortest to longest contrary to _repr_mimebundle_ + sorted_rows = sorted([r for r in rows if r], key=lambda x: len(x.split(":")[0])) + + if sorted_rows: + 
out_io.write("\n".join(sorted_rows)) out_io.write("\n") - out_io.write(valuestr) + out_io.write(array_line) if stream is None: return out_io.getvalue() else: @@ -2369,38 +2407,34 @@ def show( stream.write(out_io.getvalue() + "\n") def _repr_mimebundle_(self, include=None, exclude=None): - # order: 1. array, 2. named_axis, 3. type - value_buff = io.StringIO() - self.show(type=False, named_axis=False, stream=value_buff) - header_lines = value_buff.getvalue().splitlines() - - named_axis_line = "" - if self.named_axis: - named_axis_buff = io.StringIO() - named_axis_buff.write("axes: ") - named_axis_buff.write( - _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) - ) - named_axis_line = named_axis_buff.getvalue() + # order: + # first: array, + # last: type, + # middle: rest sorted by length of prefix (longest first) + + rows = highlevel_array_show_rows( + array=self, + type=True, + named_axis=True, + nbytes=True, + backend=True, + ) + header_lines = rows.pop(0).removesuffix("\n").splitlines() - type_buff = io.StringIO() - self.type.show(stream=type_buff) - footer_lines = type_buff.getvalue().splitlines() - # Prepend a `type: ` prefix to the type information - footer_lines[0] = f"type: {footer_lines[0]}" + # it's always the second row (after the array) + type_lines = [rows.pop(0).removesuffix("\n")] - if header_lines[-1] == "": - del header_lines[-1] + # the rest of the rows we sort by the length of their ':' + # but we sort it from longest to shortest for _repr_mimebundle_ + sorted_rows = sorted(rows, key=lambda x: -len(x.split(":")[0])) n_cols = max( - len(line) - for line in itertools.chain(header_lines, [named_axis_line], footer_lines) + len(line) for line in itertools.chain(header_lines, sorted_rows, type_lines) ) body_lines = header_lines body_lines.append("-" * n_cols) - if named_axis_line: - body_lines.append(named_axis_line) - body_lines.extend(footer_lines) + body_lines.extend(sorted_rows) + body_lines.extend(type_lines) body = 
def bytes_repr(nbytes: int) -> str:
    """Render a byte count as a short human-readable string.

    Decimal units are used: 1 kB = 1e3 B, 1 MB = 1e6 B, 1 GB = 1e9 B.
    Counts below 1 kB are printed as a whole number of bytes; anything
    larger is scaled to the largest unit that fits and shown with one
    decimal place. Thousands separators are always included.

    Args:
        nbytes: Number of bytes.

    Returns:
        A string such as ``"999 B"``, ``"1.5 kB"`` or ``"2.5 GB"``.
    """
    # Largest unit first. The comparison is inclusive so that a value sitting
    # exactly on a boundary is promoted (1000 -> "1.0 kB", not "1,000 B").
    for scale, unit in ((1e9, "GB"), (1e6, "MB"), (1e3, "kB")):
        if nbytes >= scale:
            return f"{nbytes / scale:,.1f} {unit}"
    return f"{nbytes:,} B"


def highlevel_array_show_rows(
    array,
    limit_rows=20,
    limit_cols=80,
    type=False,
    named_axis=False,
    nbytes=False,
    backend=False,
    *,
    formatter=None,
    precision=3,
) -> list[str]:
    """Build the text rows used to display a high-level array.

    The returned list always starts with the rendered values (the output of
    `valuestr`). The optional rows follow in a fixed order: type, named
    axis, nbytes, backend.

    Args:
        array: Object to describe. This function reads its ``type``,
            ``named_axis``, ``nbytes`` and ``layout.backend.name``
            attributes, depending on which rows are requested.
        limit_rows: Forwarded to `valuestr`.
        limit_cols: Forwarded to `valuestr`.
        type: If True, add a ``"type: ..."`` row directly after the values.
        named_axis: If True and ``array.named_axis`` is truthy, add a
            ``"named axis: ..."`` row.
        nbytes: If True, add an ``"nbytes: ..."`` row formatted by
            `bytes_repr`.
        backend: If True, add a ``"backend: ..."`` row.
        formatter: Passed to `Formatter` together with ``precision``.
        precision: Passed to `Formatter`.

    Returns:
        The rows as a list of strings; index 0 is the values, and index 1 is
        the type row whenever ``type`` is True.
    """
    formatter_impl = Formatter(formatter, precision=precision)

    rows = [valuestr(array, limit_rows, limit_cols, formatter=formatter_impl)]

    if type:
        typeio = io.StringIO()
        array.type.show(stream=typeio)
        # Drop the trailing newline so the row joins cleanly with the others.
        rows.append("type: " + typeio.getvalue().removesuffix("\n"))

    # other info
    if named_axis and array.named_axis:
        rows.append(
            "named axis: "
            + _prettify_named_axes(array.named_axis, delimiter=", ", maxlen=None)
        )
    if nbytes:
        rows.append(f"nbytes: {bytes_repr(array.nbytes)}")
    if backend:
        rows.append(f"backend: {array.layout.backend.name}")

    # Callers pop the type row by position, so it must stay the second row.
    if type:
        assert rows[1].startswith("type: ")
    return rows