diff --git a/CHANGELOG.md b/CHANGELOG.md index f8839e0f25..a598f2c4a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Fixed issue with `Tabs` where disabled tabs could still be activated by clicking the underline https://github.com/Textualize/textual/issues/4701 - Fixed scroll_visible with margin https://github.com/Textualize/textual/pull/4719 - Fixed programmatically disabling button stuck in hover state https://github.com/Textualize/textual/pull/4724 +- Fixed `DataTable` poor performance on startup and focus change when rows contain multi-line content https://github.com/Textualize/textual/pull/4748 - Fixed `Tree` and `DirectoryTree` horizontal scrolling off-by-2 https://github.com/Textualize/textual/pull/4744 - Fixed text-opacity in component styles https://github.com/Textualize/textual/pull/4747 diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index ee77bddbe1..107635e9e9 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from itertools import chain, zip_longest from operator import itemgetter -from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast +from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar import rich.repr from rich.console import RenderableType @@ -49,6 +49,8 @@ _DEFAULT_CELL_X_PADDING = 1 """Default padding to use on each side of a column in the data table.""" +_EMPTY_TEXT = Text(no_wrap=True, end="") + class CellDoesNotExist(Exception): """The cell key/index was invalid. @@ -152,22 +154,68 @@ def __rich_repr__(self): yield "column_key", self.column_key -def default_cell_formatter(obj: object) -> RenderableType: +def _find_newline(string: str, number: int) -> int: + """Find newline number n (the nth newline) in a string. + + Args: + string: The string to search. + number: The nth newline character to find. + + Returns: + The index of the nth newline character, or -1 if not found. + """ + if not string or number < 1: + return -1 + + pos = -1 + for _ in range(number): + pos = string.find("\n", pos + 1) + if pos == -1: + break + return pos + + +def default_cell_formatter( + obj: object, wrap: bool = True, height: int = 0 +) -> RenderableType: """Convert a cell into a Rich renderable for display. Args: obj: Data for a cell. + wrap: Enable or disable wrapping inside the cell. + height: The height of the cell, or `None` to render the entire cell. + This can be used to short-circuit rendering. e.g. If we know the cell + has a height of 1, we can render the cell as a single line of text + without any wrapping. Returns: A renderable to be displayed which represents the data. """ + # Get the string which will be displayed in the cell. + possible_markup = False if isinstance(obj, str): - return Text.from_markup(obj) - if isinstance(obj, float): - return f"{obj:.2f}" - if not is_renderable(obj): - return str(obj) - return cast(RenderableType, obj) + possible_markup = True + content = obj + elif isinstance(obj, float): + content = f"{obj:.2f}" + elif not is_renderable(obj): + content = str(obj) + else: + return obj + + if height: + # Let's throw away lines which definitely won't appear in the cell + # after wrapping using the height constraint. A cell can only grow + # vertically after wrapping occurs, so this is a safe operation. + trim_position = _find_newline(content, height) + if trim_position != -1 and trim_position != len(content) - 1: + content = content[:trim_position] + + if possible_markup: + text = Text.from_markup(content, end="") + text.no_wrap = not wrap + return text + return Text(content, no_wrap=not wrap, end="") @dataclass @@ -1032,7 +1080,11 @@ def get_row_height(self, row_key: RowKey) -> int: return self.rows[row_key].height def notify_style_update(self) -> None: - self._clear_caches() + self._row_render_cache.clear() + self._cell_render_cache.clear() + self._line_cache.clear() + self._styles_cache.clear() + self._get_styles_to_render_cell.cache_clear() self.refresh() def _on_resize(self, _: events.Resize) -> None: @@ -1268,19 +1320,37 @@ def _update_column_widths(self, updated_cells: set[CellKey]) -> None: """Update the widths of the columns based on the newly updated cell widths.""" for row_key, column_key in updated_cells: column = self.columns.get(column_key) - if column is None: + row = self.rows.get(row_key) + if column is None or row is None: continue console = self.app.console label_width = measure(console, column.label, 1) content_width = column.content_width cell_value = self._data[row_key][column_key] - new_content_width = measure(console, default_cell_formatter(cell_value), 1) + render_height = row.height + new_content_width = measure( + console, + default_cell_formatter( + cell_value, + wrap=row.height != 1, + height=render_height, + ), + 1, + ) if new_content_width < content_width: cells_in_column = self.get_column(column_key) cell_widths = [ - measure(console, default_cell_formatter(cell), 1) + measure( + console, + default_cell_formatter( + cell, + wrap=row.height != 1, + height=render_height, + ), + 1, + ) for cell in cells_in_column ] column.content_width = max([*cell_widths, label_width]) @@ -1586,7 +1656,9 @@ def add_row( column.key: cell for column, cell in zip_longest(self.ordered_columns, cells) } - label = Text.from_markup(label) if isinstance(label, str) else label + + label = Text.from_markup(label, end="") if isinstance(label, str) else label + # Rows with auto-height get a height of 0 because 1) we need an integer height # to do some intermediate computations and 2) because 0 doesn't impact the data # table while we don't figure out how tall this row is. @@ -1896,17 +1968,35 @@ def _get_row_renderables(self, row_index: int) -> RowRenderables: return RowRenderables(None, header_row) ordered_row = self.get_row_at(row_index) - empty = Text() - - formatted_row_cells = [ - Text() if datum is None else default_cell_formatter(datum) or empty + row_key = self._row_locations.get_key(row_index) + if row_key is None: + return RowRenderables(None, []) + row_metadata = self.rows.get(row_key) + if row_metadata is None: + return RowRenderables(None, []) + + formatted_row_cells: list[RenderableType] = [ + ( + _EMPTY_TEXT + if datum is None + else default_cell_formatter( + datum, + wrap=row_metadata.height != 1, + height=row_metadata.height, + ) + or _EMPTY_TEXT + ) for datum, _ in zip_longest(ordered_row, range(len(self.columns))) ] + label = None if self._should_render_row_labels: - row_metadata = self.rows.get(self._row_locations.get_key(row_index)) label = ( - default_cell_formatter(row_metadata.label) + default_cell_formatter( + row_metadata.label, + wrap=row_metadata.height != 1, + height=row_metadata.height, + ) if row_metadata.label else None ) @@ -1982,19 +2072,25 @@ def _render_cell( ) if is_header_cell: - options = self.app.console.options.update_dimensions( - width, self.header_height - ) + row_height = self.header_height + options = self.app.console.options.update_dimensions(width, row_height) else: - row = self.rows[row_key] # If an auto-height row hasn't had its height calculated, we don't fix # the value for `height` so that we can measure the height of the cell. + row = self.rows[row_key] if row.auto_height and row.height == 0: + row_height = 0 options = self.app.console.options.update_width(width) else: + row_height = row.height options = self.app.console.options.update_dimensions( - width, row.height + width, row_height ) + + # If the row height is explicitly set to 1, then we don't wrap. + if row_height == 1: + options = options.update(no_wrap=True) + lines = self.app.console.render_lines( Styled( Padding(cell, (0, self.cell_padding)), diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 63736e5194..018a3939a3 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -254,6 +254,7 @@ async def test_add_row_duplicate_key(): with pytest.raises(DuplicateKey): table.add_row("2", key="1") # Duplicate row key + async def test_add_row_too_many_values(): app = DataTableApp() async with app.run_test(): @@ -263,6 +264,7 @@ async def test_add_row_too_many_values(): with pytest.raises(ValueError): table.add_row("1", "2") + async def test_add_column_duplicate_key(): app = DataTableApp() async with app.run_test():