Skip to content

Commit

Permalink
feat: add pretty nbytes repr to .show and jupyter repr (#3348)
Browse files Browse the repository at this point in the history
* add pretty nbytes repr to .show and jupyter repr

* ak.Array.show(): add backend arg; fix sorting of rows; fix doc string; KB->kB

* address Jim's comments

* address Jim's comments
  • Loading branch information
pfackeldey authored Dec 18, 2024
1 parent c59a49c commit 55f1909
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 78 deletions.
190 changes: 112 additions & 78 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from awkward._regularize import is_non_string_like_iterable
from awkward._typing import Any, TypeVar
from awkward._util import STDOUT
from awkward.prettyprint import Formatter
from awkward.prettyprint import Formatter, highlevel_array_show_rows
from awkward.prettyprint import valuestr as prettyprint_valuestr

__all__ = ("Array", "ArrayBuilder", "Record")
Expand Down Expand Up @@ -1398,10 +1398,13 @@ def show(
self,
limit_rows=20,
limit_cols=80,
*,
type=False,
named_axis=False,
nbytes=False,
backend=False,
all=False,
stream=STDOUT,
*,
formatter=None,
precision=3,
):
Expand All @@ -1411,9 +1414,16 @@ def show(
limit_cols (int): Maximum number of columns (characters wide).
type (bool): If True, print the type as well. (Doesn't count toward number
of rows/lines limit.)
named_axis (bool): If True, print the named axis as well. (Doesn't count toward number
of rows/lines limit.)
nbytes (bool): If True, print the number of bytes as well. (Doesn't count toward number
of rows/lines limit.)
backend (bool): If True, print the backend of the array as well. (Doesn't count toward number
of rows/lines limit.)
all (bool): If True, print the 'type', 'named axis', 'nbytes', and 'backend' of the array. (Doesn't count toward number
of rows/lines limit.)
stream (object with a ``write(str)`` method or None): Stream to write the
output to. If None, return a string instead of writing to a stream.
formatter (Mapping or None): Mapping of types/type-classes to string formatters.
If None, use the default formatter.
Expand All @@ -1426,64 +1436,70 @@ def show(
key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting
string values, falling back upon `"str_kind"`.
"""
formatter_impl = Formatter(formatter, precision=precision)

valuestr = prettyprint_valuestr(
self, limit_rows, limit_cols, formatter=formatter_impl
rows = highlevel_array_show_rows(
array=self,
limit_rows=limit_rows,
limit_cols=limit_cols,
type=type or all,
named_axis=named_axis or all,
nbytes=nbytes or all,
backend=backend or all,
formatter=formatter,
precision=precision,
)
array_line = rows.pop(0)

out_io = io.StringIO()
if type:
out_io.write("type: ")
self.type.show(stream=out_io)
if named_axis and self.named_axis:
out_io.write("axes: ")
out_io.write(
_prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
)
# it's always the second row (after the array)
type_line = rows.pop(0)
out_io.write(type_line)

# the rest of the rows we sort by the length of their '<prefix>:'
# but we sort it from shortest to longest contrary to _repr_mimebundle_
sorted_rows = sorted([r for r in rows if r], key=lambda x: len(x.split(":")[0]))

if sorted_rows:
out_io.write("\n".join(sorted_rows))
out_io.write("\n")
out_io.write(valuestr)

out_io.write(array_line)
if stream is None:
return out_io
return out_io.getvalue()
else:
if stream is STDOUT:
stream = STDOUT.stream
stream.write(out_io.getvalue() + "\n")

def _repr_mimebundle_(self, include=None, exclude=None):
# order: 1. array, 2. named_axis, 3. type
value_buff = io.StringIO()
self.show(type=False, named_axis=False, stream=value_buff)
header_lines = value_buff.getvalue().splitlines()

named_axis_line = ""
if self.named_axis:
named_axis_buff = io.StringIO()
named_axis_buff.write("axes: ")
named_axis_buff.write(
_prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
)
named_axis_line = named_axis_buff.getvalue()
# order:
# first: array,
# last: type,
# middle: rest sorted by length of prefix (longest first)

rows = highlevel_array_show_rows(
array=self,
type=True,
named_axis=True,
nbytes=True,
backend=True,
)
header_lines = rows.pop(0).removesuffix("\n").splitlines()

type_buff = io.StringIO()
self.type.show(stream=type_buff)
footer_lines = type_buff.getvalue().splitlines()
# Prepend a `type: ` prefix to the type information
footer_lines[0] = f"type: {footer_lines[0]}"
# it's always the second row (after the array)
type_lines = [rows.pop(0).removesuffix("\n")]

if header_lines[-1] == "":
del header_lines[-1]
# the rest of the rows we sort by the length of their '<prefix>:'
# but we sort it from longest to shortest for _repr_mimebundle_
sorted_rows = sorted(rows, key=lambda x: -len(x.split(":")[0]))

n_cols = max(
len(line)
for line in itertools.chain(header_lines, [named_axis_line], footer_lines)
len(line) for line in itertools.chain(header_lines, sorted_rows, type_lines)
)
body_lines = header_lines
body_lines.append("-" * n_cols)
if named_axis_line:
body_lines.append(named_axis_line)
body_lines.extend(footer_lines)
body_lines.extend(sorted_rows)
body_lines.extend(type_lines)
body = "\n".join(body_lines)

return {
Expand Down Expand Up @@ -2317,10 +2333,13 @@ def show(
self,
limit_rows=20,
limit_cols=80,
*,
type=False,
named_axis=False,
nbytes=False,
backend=False,
all=False,
stream=STDOUT,
*,
formatter=None,
precision=3,
):
Expand All @@ -2330,6 +2349,14 @@ def show(
limit_cols (int): Maximum number of columns (characters wide).
type (bool): If True, print the type as well. (Doesn't count toward number
of rows/lines limit.)
named_axis (bool): If True, print the named axis as well. (Doesn't count toward number
of rows/lines limit.)
nbytes (bool): If True, print the number of bytes as well. (Doesn't count toward number
of rows/lines limit.)
backend (bool): If True, print the backend of the array as well. (Doesn't count toward number
of rows/lines limit.)
all (bool): If True, print the 'type', 'named axis', 'nbytes', and 'backend' of the array. (Doesn't count toward number
of rows/lines limit.)
stream (object with a ``write(str)`` method or None): Stream to write the
output to. If None, return a string instead of writing to a stream.
formatter (Mapping or None): Mapping of types/type-classes to string formatters.
Expand All @@ -2344,23 +2371,34 @@ def show(
key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting
string values, falling back upon `"str_kind"`.
"""
formatter_impl = Formatter(formatter, precision=precision)
valuestr = prettyprint_valuestr(
self, limit_rows, limit_cols, formatter=formatter_impl
rows = highlevel_array_show_rows(
array=self,
limit_rows=limit_rows,
limit_cols=limit_cols,
type=type or all,
named_axis=named_axis or all,
nbytes=nbytes or all,
backend=backend or all,
formatter=formatter,
precision=precision,
)
array_line = rows.pop(0)

out_io = io.StringIO()
if type:
out_io.write("type: ")
self.type.show(stream=out_io)
if named_axis and self.named_axis:
out_io.write("axes: ")
out_io.write(
_prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
)
# it's always the second row (after the array)
type_line = rows.pop(0)
out_io.write(type_line)

# the rest of the rows we sort by the length of their '<prefix>:'
# but we sort it from shortest to longest contrary to _repr_mimebundle_
sorted_rows = sorted([r for r in rows if r], key=lambda x: len(x.split(":")[0]))

if sorted_rows:
out_io.write("\n".join(sorted_rows))
out_io.write("\n")
out_io.write(valuestr)

out_io.write(array_line)
if stream is None:
return out_io.getvalue()
else:
Expand All @@ -2369,38 +2407,34 @@ def show(
stream.write(out_io.getvalue() + "\n")

def _repr_mimebundle_(self, include=None, exclude=None):
# order: 1. array, 2. named_axis, 3. type
value_buff = io.StringIO()
self.show(type=False, named_axis=False, stream=value_buff)
header_lines = value_buff.getvalue().splitlines()

named_axis_line = ""
if self.named_axis:
named_axis_buff = io.StringIO()
named_axis_buff.write("axes: ")
named_axis_buff.write(
_prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
)
named_axis_line = named_axis_buff.getvalue()
# order:
# first: array,
# last: type,
# middle: rest sorted by length of prefix (longest first)

rows = highlevel_array_show_rows(
array=self,
type=True,
named_axis=True,
nbytes=True,
backend=True,
)
header_lines = rows.pop(0).removesuffix("\n").splitlines()

type_buff = io.StringIO()
self.type.show(stream=type_buff)
footer_lines = type_buff.getvalue().splitlines()
# Prepend a `type: ` prefix to the type information
footer_lines[0] = f"type: {footer_lines[0]}"
# it's always the second row (after the array)
type_lines = [rows.pop(0).removesuffix("\n")]

if header_lines[-1] == "":
del header_lines[-1]
# the rest of the rows we sort by the length of their '<prefix>:'
# but we sort it from longest to shortest for _repr_mimebundle_
sorted_rows = sorted(rows, key=lambda x: -len(x.split(":")[0]))

n_cols = max(
len(line)
for line in itertools.chain(header_lines, [named_axis_line], footer_lines)
len(line) for line in itertools.chain(header_lines, sorted_rows, type_lines)
)
body_lines = header_lines
body_lines.append("-" * n_cols)
if named_axis_line:
body_lines.append(named_axis_line)
body_lines.extend(footer_lines)
body_lines.extend(sorted_rows)
body_lines.extend(type_lines)
body = "\n".join(body_lines)

return {
Expand Down
61 changes: 61 additions & 0 deletions src/awkward/prettyprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from __future__ import annotations

import io
import math
import re
from collections.abc import Callable

import awkward as ak
from awkward._layout import wrap_layout
from awkward._namedaxis import _prettify_named_axes
from awkward._nplikes.numpy import Numpy, NumpyMetadata
from awkward._typing import TYPE_CHECKING, Any, TypeAlias, TypedDict

Expand Down Expand Up @@ -436,3 +438,62 @@ def valuestr(

else:
raise AssertionError(type(data))


def bytes_repr(nbytes: int) -> str:
    """Format a byte count as a short human-readable string.

    Decimal (SI) units are used: 1 kB = 1e3 B, 1 MB = 1e6 B, 1 GB = 1e9 B.
    Counts below 1 kB are printed as an integer with thousands separators
    (e.g. ``"999 B"``); larger counts are scaled to the largest unit they
    reach and printed with one decimal place (e.g. ``"1.5 kB"``).

    Args:
        nbytes (int): Number of bytes.

    Returns:
        str: The formatted count followed by a space and the unit.

    Note:
        Because of rounding to one decimal place, a value just below a
        threshold can still print as e.g. ``"1,000.0 kB"``.
    """
    # Largest unit first, so the first threshold that is reached wins.
    # The comparison is inclusive: exactly 1e6 bytes is "1.0 MB", not
    # "1,000.0 kB".
    for factor, unit in ((1e9, "GB"), (1e6, "MB"), (1e3, "kB")):
        if nbytes >= factor:
            return f"{nbytes / factor:,.1f} {unit}"
    return f"{nbytes:,} B"


def highlevel_array_show_rows(
    array,
    limit_rows=20,
    limit_cols=80,
    type=False,
    named_axis=False,
    nbytes=False,
    backend=False,
    *,
    formatter=None,
    precision=3,
) -> list[str]:
    """Build the text rows used to display a high-level array.

    The returned list always starts with the rendered values of ``array``.
    The optional rows follow in a fixed order: ``"type: ..."``,
    ``"named axis: ..."``, ``"nbytes: ..."``, ``"backend: ..."``.

    Args:
        array: Object to display; its ``type``, ``named_axis``, ``nbytes``
            and ``layout.backend.name`` attributes are read on demand.
        limit_rows (int): Forwarded to ``valuestr``.
        limit_cols (int): Forwarded to ``valuestr``.
        type (bool): If True, add a ``type: `` row (always at index 1).
        named_axis (bool): If True and ``array.named_axis`` is truthy, add
            a ``named axis: `` row.
        nbytes (bool): If True, add an ``nbytes: `` row formatted with
            ``bytes_repr``.
        backend (bool): If True, add a ``backend: `` row.
        formatter: Passed to ``Formatter`` together with ``precision``.
        precision (int): Passed to ``Formatter``.

    Returns:
        list[str]: The rows, values first.
    """
    # Row 0: the values themselves.
    rows = [
        valuestr(
            array,
            limit_rows,
            limit_cols,
            formatter=Formatter(formatter, precision=precision),
        )
    ]

    if type:
        # Render the type into a buffer and drop one trailing newline.
        type_buffer = io.StringIO()
        array.type.show(stream=type_buffer)
        rows.append("type: " + type_buffer.getvalue().removesuffix("\n"))

    # Remaining metadata rows; callers are free to reorder these.
    if named_axis and array.named_axis:
        pretty_axes = _prettify_named_axes(
            array.named_axis, delimiter=", ", maxlen=None
        )
        rows.append("named axis: " + pretty_axes)
    if nbytes:
        rows.append(f"nbytes: {bytes_repr(array.nbytes)}")
    if backend:
        rows.append(f"backend: {array.layout.backend.name}")

    # Callers rely on the type row sitting directly after the values.
    if type:
        assert rows[1].startswith("type: ")
    return rows

0 comments on commit 55f1909

Please sign in to comment.