Skip to content

Commit

Permalink
Feature: TableInfo now allows either rows or (n_rows AND `get_c…
Browse files Browse the repository at this point in the history
…ell`). The latter is useful to prevent copying data in memory unnecessarily.
  • Loading branch information
rudolfbyker committed Aug 22, 2024
1 parent 7049ef5 commit 76d10b3
Showing 1 changed file with 66 additions and 17 deletions.
83 changes: 66 additions & 17 deletions aa_py_openpyxl_util/_write_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import logging
import warnings
from dataclasses import dataclass, field
from dataclasses import dataclass
from itertools import zip_longest
from typing import Optional, Any, Sequence, Generator, List, Iterable
from typing import Optional, Any, Sequence, Generator, List, Iterable, Callable

from openpyxl import Workbook
from openpyxl.cell import WriteOnlyCell, Cell
Expand Down Expand Up @@ -123,40 +123,80 @@ def create_openpyxl_cell(
)


@dataclass
class TableInfo:
"""
Info about a table which is to be written to a sheet using `openpyxl`.
"""

name: str
"""
The table name
"""

column_names: Sequence[str]
"""
The column names
The column names.
"""

rows: Sequence[Sequence[FormattedCell]]
n_rows: int
"""
The table rows.
Each row MUST have the same number of items as `column_names`.
The number of rows in the table.
"""

pre_rows: Sequence[Sequence[FormattedCell]] = field(default_factory=list)
get_cell: Callable[[int, int], FormattedCell]
"""
A function that returns a cell for a given row and column index. The indices are 0-based.
"""

pre_rows: Sequence[Sequence[FormattedCell]]
"""
Rows to write outside the table, above the header, but below the name and description.
This may be wider or narrower as the table if required.
"""

style: Optional[TableStyleInfo] = field(default=default_table_style)
style: Optional[TableStyleInfo]
"""
The table style
"""

description: str = field(default="")
description: str
"""
A table description to write below the table name.
"""

def __init__(
self,
*,
name: str,
column_names: Sequence[str],
n_rows: int | None = None,
get_cell: Callable[[int, int], FormattedCell] | None = None,
rows: Sequence[Sequence[FormattedCell]] | None = None,
pre_rows: Sequence[Sequence[FormattedCell]] | None = None,
style: Optional[TableStyleInfo] | None = None,
description: str | None = None,
):
self.name = name
self.column_names = column_names
self.pre_rows = pre_rows or []
self.style = style or default_table_style
self.description = description or ""

if rows is None:
if n_rows is None or get_cell is None:
raise ValueError(
"Either `rows` or (`n_rows` and `get_cell`) must be provided."
)
self.n_rows = n_rows
self.get_cell = get_cell
else:
if n_rows is not None or get_cell is not None:
raise ValueError(
"`rows` and (`n_rows` and `get_cell`) are mutually exclusive."
)
self.n_rows = len(rows)
self.get_cell = lambda i_row, i_col: rows[i_row][i_col] # type: ignore[index]

@property
def width(self) -> int:
"""
Expand All @@ -167,10 +207,16 @@ def width(self) -> int:
[
len(self.column_names),
*(len(r) for r in self.pre_rows),
*(len(r) for r in self.rows),
]
)

@property
def rows(self) -> Generator[Generator[FormattedCell, None, None], None, None]:
for i_row in range(self.n_rows):
yield (
self.get_cell(i_row, i_col) for i_col in range(len(self.column_names))
)


def write_tables_side_by_side_over_multiple_sheets(
*,
Expand Down Expand Up @@ -316,7 +362,7 @@ def write_tables_side_by_side(
first_row=first_row,
name=t.name,
column_names=t.column_names,
n_data_rows=len(t.rows),
n_data_rows=t.n_rows,
style=t.style,
)
results[t.name] = (coords, lo)
Expand Down Expand Up @@ -427,16 +473,19 @@ def header_row() -> Generator[FormattedCell, None, None]:
widths = [t.width for t in tables]

def row(
data: Sequence[Optional[Sequence[FormattedCell]]],
data: Iterable[Optional[Iterable[FormattedCell]]],
) -> Generator[FormattedCell, None, None]:
for w, d in zip(widths, data):
yield from [FormattedCell(None)] * col_margin

if d is None:
yield from [FormattedCell(None)] * w
else:
yield from d
yield from [FormattedCell(None)] * (w - len(d))
d_len = 0
for cell in d:
yield cell
d_len += 1
yield from [FormattedCell(None)] * (w - d_len)

for _ in range(row_margin):
yield []
Expand All @@ -446,8 +495,8 @@ def row(
yield list(table_description_row())

if write_pre_rows:
for row_data in zip_longest(*(t.pre_rows for t in tables), fillvalue=None):
yield list(row(row_data))
for pre_row_data in zip_longest(*(t.pre_rows for t in tables), fillvalue=None):
yield list(row(pre_row_data))

yield list(header_row())

Expand Down

0 comments on commit 76d10b3

Please sign in to comment.