From 76d10b307df8a12af309d24742850e03ee2f5ddb Mon Sep 17 00:00:00 2001 From: dolf Date: Fri, 23 Aug 2024 00:01:23 +0200 Subject: [PATCH] Feature: `TableInfo` now allows either `rows` or (`n_rows` AND `get_cell`). The latter is useful to prevent copying data in memory unnecessarily. --- aa_py_openpyxl_util/_write_only.py | 83 ++++++++++++++++++++++++------ 1 file changed, 66 insertions(+), 17 deletions(-) diff --git a/aa_py_openpyxl_util/_write_only.py b/aa_py_openpyxl_util/_write_only.py index 110175d..1c4e636 100644 --- a/aa_py_openpyxl_util/_write_only.py +++ b/aa_py_openpyxl_util/_write_only.py @@ -6,9 +6,9 @@ import logging import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass from itertools import zip_longest -from typing import Optional, Any, Sequence, Generator, List, Iterable +from typing import Optional, Any, Sequence, Generator, List, Iterable, Callable from openpyxl import Workbook from openpyxl.cell import WriteOnlyCell, Cell @@ -123,8 +123,11 @@ def create_openpyxl_cell( ) -@dataclass class TableInfo: + """ + Info about a table which is to be written to a sheet using `openpyxl`. + """ + name: str """ The table name @@ -132,31 +135,68 @@ class TableInfo: column_names: Sequence[str] """ - The column names + The column names. """ - rows: Sequence[Sequence[FormattedCell]] + n_rows: int """ - The table rows. - Each row MUST have the same number of items as `column_names`. + The number of rows in the table. """ - pre_rows: Sequence[Sequence[FormattedCell]] = field(default_factory=list) + get_cell: Callable[[int, int], FormattedCell] + """ + A function that returns a cell for a given row and column index. The indices are 0-based. + """ + + pre_rows: Sequence[Sequence[FormattedCell]] """ Rows to write outside the table, above the header, but below the name and description. This may be wider or narrower as the table if required. """ - style: Optional[TableStyleInfo] = field(default=default_table_style) + style: Optional[TableStyleInfo] """ The table style """ - description: str = field(default="") + description: str """ A table description to write below the table name. """ + def __init__( + self, + *, + name: str, + column_names: Sequence[str], + n_rows: int | None = None, + get_cell: Callable[[int, int], FormattedCell] | None = None, + rows: Sequence[Sequence[FormattedCell]] | None = None, + pre_rows: Sequence[Sequence[FormattedCell]] | None = None, + style: Optional[TableStyleInfo] | None = None, + description: str | None = None, + ): + self.name = name + self.column_names = column_names + self.pre_rows = pre_rows or [] + self.style = style or default_table_style + self.description = description or "" + + if rows is None: + if n_rows is None or get_cell is None: + raise ValueError( + "Either `rows` or (`n_rows` and `get_cell`) must be provided." + ) + self.n_rows = n_rows + self.get_cell = get_cell + else: + if n_rows is not None or get_cell is not None: + raise ValueError( + "`rows` and (`n_rows` and `get_cell`) are mutually exclusive." + ) + self.n_rows = len(rows) + self.get_cell = lambda i_row, i_col: rows[i_row][i_col] # type: ignore[index] + @property def width(self) -> int: """ @@ -167,10 +207,16 @@ def width(self) -> int: [ len(self.column_names), *(len(r) for r in self.pre_rows), - *(len(r) for r in self.rows), ] ) + @property + def rows(self) -> Generator[Generator[FormattedCell, None, None], None, None]: + for i_row in range(self.n_rows): + yield ( + self.get_cell(i_row, i_col) for i_col in range(len(self.column_names)) + ) + def write_tables_side_by_side_over_multiple_sheets( *, @@ -316,7 +362,7 @@ def write_tables_side_by_side( first_row=first_row, name=t.name, column_names=t.column_names, - n_data_rows=len(t.rows), + n_data_rows=t.n_rows, style=t.style, ) results[t.name] = (coords, lo) @@ -427,7 +473,7 @@ def header_row() -> Generator[FormattedCell, None, None]: widths = [t.width for t in tables] def row( - data: Sequence[Optional[Sequence[FormattedCell]]], + data: Iterable[Optional[Iterable[FormattedCell]]], ) -> Generator[FormattedCell, None, None]: for w, d in zip(widths, data): yield from [FormattedCell(None)] * col_margin @@ -435,8 +481,11 @@ def row( if d is None: yield from [FormattedCell(None)] * w else: - yield from d - yield from [FormattedCell(None)] * (w - len(d)) + d_len = 0 + for cell in d: + yield cell + d_len += 1 + yield from [FormattedCell(None)] * (w - d_len) for _ in range(row_margin): yield [] @@ -446,8 +495,8 @@ def row( yield list(table_description_row()) if write_pre_rows: - for row_data in zip_longest(*(t.pre_rows for t in tables), fillvalue=None): - yield list(row(row_data)) + for pre_row_data in zip_longest(*(t.pre_rows for t in tables), fillvalue=None): + yield list(row(pre_row_data)) yield list(header_row())