Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add row filtering to nd2.index, as well as binary/roi data #151

Merged
merged 3 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 87 additions & 25 deletions src/nd2/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Iterable, Iterator, Sequence, cast, no_type_check
from typing import Any, Iterable, Iterator, Sequence, cast, no_type_check

from typing_extensions import TypedDict

Expand All @@ -33,6 +33,8 @@
dtype: str
shape: list[int]
axes: str
binary: bool
rois: bool
software_name: str
software_version: str
grabber: str
Expand All @@ -48,9 +50,11 @@
if nd.is_legacy:
software: dict = {}
acquired: str | None = ""
binary = False
else:
software = nd._rdr._app_info() # type: ignore
acquired = nd._rdr._acquisition_date() # type: ignore
binary = nd.binary_data is not None

stat = path.stat()
exp = [(x.type, x.count) for x in nd.experiment]
Expand All @@ -69,6 +73,8 @@
"dtype": str(nd.dtype),
"shape": list(shape),
"axes": "".join(axes),
"binary": binary,
"rois": bool(nd.rois),
"software_name": software.get("SWNameString", ""),
"software_version": software.get("VersionString", ""),
"grabber": software.get("GrabberString", ""),
Expand Down Expand Up @@ -96,29 +102,45 @@
return results


def _pretty_print_table(data: list[Record]) -> None:
def _pretty_print_table(data: list[Record], sort_column: str | None = None) -> None:
try:
from rich.console import Console
from rich.table import Table

except ImportError:
raise sys.exit(
"rich is required to print a pretty table. "
"Install it with `pip install rich`."
) from None

table = Table(show_header=True, header_style="bold")
headers = list(data[0])

# add headers, and highlight any sorted columns
sort_col = ""
if sort_column:
sort_col = (sort_column or "").rstrip("-")
direction = " ↓" if sort_column.endswith("-") else " ↑"
for header in headers:
if header == sort_col:
table.add_column(header + direction, style="green")
else:
table.add_column(header)

for header in data[0]:
table.add_column(header)
for row in data:
table.add_row(*[str(value) for value in row.values()])
table.add_row(*[_strify(value) for value in row.values()])

Console().print(table)


def _strify(val: Any) -> str:
if isinstance(val, bool):
return "✅" if val else ""
return str(val)


def _print_csv(records: list[Record], skip_header: bool = False) -> None:
import csv
import sys

writer = csv.DictWriter(sys.stdout, fieldnames=records[0].keys())
if not skip_header:
Expand Down Expand Up @@ -191,36 +213,79 @@
action="store_true",
help="Don't write the CSV header",
)
parser.add_argument(
"--filter",
"-F",
type=str,
action="append",
help="Filter the output. Each filter "
"should be a python expression (string)\nthat evaluates to True or False. "
"It will be evaluated in the context\nof each row. You can use any of the "
"column names as variables.\ne.g.: \"acquired > '2020' and kb < 500\". (May "
"be used multiple times).",
)

return parser.parse_args(argv or sys.argv[1:])


@no_type_check
def _filter_data(
    data: list[Record],
    sort_by: str | None = None,
    include: str | None = None,
    exclude: str | None = None,
    filters: Sequence[str] = (),
) -> list[Record]:
    """Sort the records, filter rows, and select which columns to keep.

    Rows are sorted and filtered while they still carry every column, so both
    ``sort_by`` and ``filters`` may refer to columns that are later hidden by
    ``include`` / ``exclude``.  The input list is not modified.

    Parameters
    ----------
    data : list[Record]
        the data to filter
    sort_by : str | None, optional
        Name of column to sort by; append ``-`` to sort in descending order.
        By default None
    include : str | None, optional
        Comma-separated list of columns to include, by default None
    exclude : str | None, optional
        Comma-separated list of columns to exclude, by default None
    filters : Sequence[str], optional
        Sequence of python expression strings to filter the data, by default ()

    Returns
    -------
    list[Record]
        The rows that passed every filter, in sorted order, restricted to the
        requested columns.
    """
    includes = include.split(",") if include else []
    unrecognized = set(includes) - set(HEADERS)
    if unrecognized:  # pragma: no cover
        print(f"Unrecognized columns: {', '.join(unrecognized)}", file=sys.stderr)
        includes = [x for x in includes if x not in unrecognized]
    excludes = set(exclude.split(",")) if exclude else set()

    if sort_by:
        descending = sort_by.endswith("-")
        column = sort_by[:-1] if descending else sort_by
        # sorted() rather than list.sort() so the caller's list is left intact
        data = sorted(data, key=lambda row: row[column], reverse=descending)

    # filters are in the form of a string expression, to be evaluated
    # against each row.  For example, "'TimeLoop' in experiment"
    # NOTE(security): eval() executes arbitrary Python.  The expressions come
    # from the command line of the person running the tool; never pass
    # strings from an untrusted source here.
    for expr in filters or ():
        try:
            data = [row for row in data if eval(expr, None, row)]
        except Exception as e:  # pragma: no cover
            print(f"Error evaluating filter {expr!r}: {e}", file=sys.stderr)
            sys.exit(1)

    if includes or excludes:
        # Start from the requested columns (preserving the order given in
        # `include`), or from all known columns, then drop the excluded ones.
        # Building one column list avoids looking up columns that `include`
        # has already removed from the rows.
        columns = [h for h in (includes or HEADERS) if h not in excludes]
        data = [{h: row[h] for h in columns} for row in data]

    return data

Expand All @@ -229,20 +294,17 @@
"""Index ND2 files and print the results as a table."""
args = _parse_args(argv)

to_include = cast("list[str]", args.include.split(",") if args.include else [])
if args.sort_by and to_include and args.sort_by not in to_include:
raise sys.exit( # pragma: no cover
f"The sort column {args.sort_by!r} must be in the "
f"included columns: {to_include!r}."
)

data = _index_files(paths=args.paths, recurse=args.recurse, glob=args.glob_pattern)
data = _filter_data(
data, to_include=to_include, sort_by=args.sort_by, exclude=args.exclude
data,
sort_by=args.sort_by,
include=args.include,
exclude=args.exclude,
filters=args.filter,
)

if args.format == "table":
_pretty_print_table(data)
_pretty_print_table(data, args.sort_by)

Check warning on line 307 in src/nd2/index.py

View check run for this annotation

Codecov / codecov/patch

src/nd2/index.py#L307

Added line #L307 was not covered by tests
elif args.format == "csv":
_print_csv(data, args.no_header)
elif args.format == "json":
Expand Down
1 change: 1 addition & 0 deletions src/nd2/nd2file.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def rois(self) -> dict[int, ROI]:
try:
_rois = [ROI._from_meta_dict(d) for d in dicts]
except Exception as e: # pragma: no cover
return {}
raise ValueError(f"Could not parse ROI metadata: {e}") from e
return {r.id: r for r in _rois}

Expand Down
8 changes: 8 additions & 0 deletions src/nd2/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,14 @@ def _from_meta_dict(cls, val: dict) -> ROI:
)


class T(TypedDict):
    """Typed mapping with ``Id``, ``Info``, ``GUID`` and ``AnimParams_Size`` keys.

    NOTE(review): presumably describes the raw ROI metadata dict consumed by
    ``ROI._from_meta_dict`` -- confirm against the caller.
    """

    Id: int
    Info: dict
    GUID: str
    AnimParams_Size: int
    # Not expressible as a fixed key; presumably one `AnimParams_{i}` dict
    # entry exists per index below `AnimParams_Size` -- TODO confirm.
    # AnimParams_{i}: dict


@dataclass
class AnimParam:
"""Parameters of ROI position/shape."""
Expand Down
16 changes: 11 additions & 5 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,31 @@ def test_format(records, fmt, capsys):
filtered = nd2.index._filter_data(records)

if fmt == "table":
nd2.index._pretty_print_table(filtered)
nd2.index._pretty_print_table(filtered, sort_column="name")
elif fmt == "csv":
nd2.index._print_csv(filtered)
elif fmt == "json":
nd2.index._print_json(filtered)
captured = capsys.readouterr()
assert "path" in captured.out
assert captured.out
assert not captured.err


@pytest.mark.parametrize(
"filters",
[
{},
{"to_include": ["path", "name", "version"]},
{"include": "path,name,version"},
{"sort_by": "version"},
{"sort_by": "version-"},
{"exclude": "path"},
{"filters": ("'TimeLoop' in experiment",)},
{"filters": ["acquired > '2020' and kb < 500"], "sort_by": "kb-"},
],
)
def test_filter_data(records, filters: dict):
def test_filter_data(records, filters: dict) -> None:
filtered = nd2.index._filter_data(records, **filters)
assert isinstance(filtered, list)
assert len(filtered) == len(records)
if filters.get("to_include"):
assert len(filtered[0]) == len(filters["to_include"])
sb = filters.get("sort_by")
Expand All @@ -48,6 +50,10 @@ def test_filter_data(records, filters: dict):
assert first_version == "3.0" if sb.endswith("-") else "1.0"
if filters.get("exclude"):
assert "path" not in filtered[0]
if filters.get("filters"):
assert len(filtered) < len(records)
else:
assert len(filtered) == len(records)


def test_index(capsys):
Expand Down