diff --git a/test/test_repr.py b/test/test_repr.py new file mode 100644 index 000000000..6fb5770a4 --- /dev/null +++ b/test/test_repr.py @@ -0,0 +1,37 @@ +import uxarray as ux +import os + +import pytest + +from pathlib import Path + +current_path = Path(os.path.dirname(os.path.realpath(__file__))) + + +grid_path = current_path / 'meshfiles' / "ugrid" / "quad-hexagon" / 'grid.nc' +data_path = current_path / 'meshfiles' / "ugrid" / "quad-hexagon" / 'data.nc' + + + +def test_grid_repr(): + uxgrid = ux.open_grid(grid_path) + + out = uxgrid._repr_html_() + + assert out is not None + + +def test_dataset_repr(): + uxds = ux.open_dataset(grid_path, data_path) + + out = uxds._repr_html_() + + assert out is not None + + +def test_dataarray_repr(): + uxds = ux.open_dataset(grid_path, data_path) + + out = uxds['t2m']._repr_html_() + + assert out is not None diff --git a/uxarray/conventions/descriptors.py b/uxarray/conventions/descriptors.py index 14cc2a727..2cd0eb2a7 100644 --- a/uxarray/conventions/descriptors.py +++ b/uxarray/conventions/descriptors.py @@ -1,9 +1,18 @@ -DESCRIPTOR_NAMES = ["face_areas", "edge_face_distances", "edge_node_distances"] +DESCRIPTOR_NAMES = [ + "face_areas", + "n_nodes_per_face", + "edge_face_distances", + "edge_node_distances", +] FACE_AREAS_DIMS = ["n_face"] -FACE_AREAS_ATTRS = {"cf_role": "face_areas"} +FACE_AREAS_ATTRS = {"cf_role": "face_areas", "long_name": "Area of each face."} + + +# TODO: add n_nodes_per_face + EDGE_FACE_DISTANCES_DIMS = ["n_edge"] EDGE_FACE_DISTANCES_ATTRS = { diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index f83f39e1a..e976c5816 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -3,8 +3,15 @@ import xarray as xr import numpy as np + from typing import TYPE_CHECKING, Optional, Union, Hashable, Literal +from uxarray.formatting_html import array_repr + +from html import escape + +from xarray.core.options import OPTIONS + from uxarray.grid import Grid import uxarray.core.dataset @@ -78,6 +85,11 @@ def __init__(self, *args, uxgrid: Grid = None, **kwargs): subset = UncachedAccessor(DataArraySubsetAccessor) remap = UncachedAccessor(UxDataArrayRemapAccessor) + def _repr_html_(self) -> str: + if OPTIONS["display_style"] == "text": + return f"
{escape(repr(self))}
" + return array_repr(self) + @classmethod def _construct_direct(cls, *args, **kwargs): """Override to make the result a ``uxarray.UxDataArray`` class.""" diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index dc2afed87..9a2f522a0 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -14,6 +14,12 @@ from xarray.core.utils import UncachedAccessor +from uxarray.formatting_html import dataset_repr + +from html import escape + +from xarray.core.options import OPTIONS + from uxarray.remap import UxDatasetRemapAccessor from warnings import warn @@ -76,6 +82,11 @@ def __init__( plot = UncachedAccessor(UxDatasetPlotAccessor) remap = UncachedAccessor(UxDatasetRemapAccessor) + def _repr_html_(self) -> str: + if OPTIONS["display_style"] == "text": + return f"
{escape(repr(self))}
" + return dataset_repr(self) + def __getitem__(self, key): """Override to make sure the result is an instance of ``uxarray.UxDataArray`` or ``uxarray.UxDataset``.""" diff --git a/uxarray/formatting.py b/uxarray/formatting.py new file mode 100644 index 000000000..2ae28399f --- /dev/null +++ b/uxarray/formatting.py @@ -0,0 +1 @@ +pass diff --git a/uxarray/formatting_html.py b/uxarray/formatting_html.py new file mode 100644 index 000000000..8f15978b5 --- /dev/null +++ b/uxarray/formatting_html.py @@ -0,0 +1,205 @@ +from html import escape +import xarray.core.formatting_html as xrfm + +from functools import partial + +from uxarray.conventions import ugrid, descriptors + +from collections import OrderedDict + + +def _grid_header(grid, header_name=None): + if header_name is None: + obj_type = f"uxarray.{type(grid).__name__}" + else: + obj_type = f"{header_name}" + + header_components = [f"
{escape(obj_type)}
"] + + return header_components + + +def _grid_sections(grid, max_items_collapse=15): + cartesian_coordinates = list( + [coord for coord in ugrid.CARTESIAN_COORDS if coord in grid._ds] + ) + spherical_coordinates = list( + [coord for coord in ugrid.SPHERICAL_COORDS if coord in grid._ds] + ) + descritor = list( + [desc for desc in descriptors.DESCRIPTOR_NAMES if desc in grid._ds] + ) + connectivity = grid.connectivity + + sections = [xrfm.dim_section(grid._ds)] + + sections.append( + grid_spherical_coordinates_section( + grid._ds[spherical_coordinates], + max_items_collapse=max_items_collapse, + name="Spherical Coordinates", + ) + ) + sections.append( + grid_cartesian_coordinates_section( + grid._ds[cartesian_coordinates], + max_items_collapse=max_items_collapse, + name="Cartesian Coordinates", + ) + ) + + sections.append( + grid_connectivity_section( + grid._ds[connectivity], + max_items_collapse=max_items_collapse, + name="Connectivity", + ) + ) + + sections.append( + grid_descriptor_section( + grid._ds[descritor], + max_items_collapse=max_items_collapse, + name="Descriptors", + ) + ) + + sections.append( + grid_attr_section( + grid._ds.attrs, max_items_collapse=max_items_collapse, name="Attributes" + ) + ) + + return sections + + +def grid_repr(grid, max_items_collapse=15, header_name=None) -> str: + """HTML repr for ``Grid`` class.""" + header_components = _grid_header(grid, header_name) + + sections = _grid_sections(grid, max_items_collapse) + + return xrfm._obj_repr(grid, header_components, sections) + + +grid_spherical_coordinates_section = partial( + xrfm._mapping_section, + details_func=xrfm.summarize_vars, + expand_option_name="display_expand_data_vars", +) + +grid_cartesian_coordinates_section = partial( + xrfm._mapping_section, + details_func=xrfm.summarize_vars, + expand_option_name="display_expand_data_vars", +) + +grid_connectivity_section = partial( + xrfm._mapping_section, + details_func=xrfm.summarize_vars, + expand_option_name="display_expand_data_vars", +) + +grid_descriptor_section = partial( + xrfm._mapping_section, + details_func=xrfm.summarize_vars, + expand_option_name="display_expand_data_vars", +) + +grid_attr_section = partial( + xrfm._mapping_section, + details_func=xrfm.summarize_attrs, + expand_option_name="display_expand_attrs", +) + + +def _obj_repr_with_grid(obj, header_components, sections): + """Return HTML repr of an uxarray object. + + If CSS is not injected (untrusted notebook), fallback to the plain + text repr. + """ + # Construct header and sections for the main object + header = f"
{''.join(h for h in header_components)}
" + sections = "".join(f"
  • {s}
  • " for s in sections) + + grid_html_repr = grid_repr( + obj.uxgrid, + max_items_collapse=0, + header_name=f"uxarray.{type(obj).__name__}.uxgrid", + ) + + icons_svg, css_style = xrfm._load_static_files() + obj_repr_html = ( + "
    " + f"{icons_svg}" + f"
    {escape(repr(obj))}
    " + "" + "
    " + ) + + return ( + "
    " + f"{icons_svg}" + f"
    {escape(repr(obj))}
    " + "" + "
    " + ) + + +def dataset_repr(ds) -> str: + """HTML repr for ``UxDataset`` class.""" + obj_type = f"uxarray.{type(ds).__name__}" + + header_components = [f"
    {escape(obj_type)}
    "] + + sections = [ + xrfm.dim_section(ds), + xrfm.coord_section(ds.coords), + xrfm.datavar_section(ds.data_vars), + xrfm.index_section(xrfm._get_indexes_dict(ds.xindexes)), + xrfm.attr_section(ds.attrs), + ] + + return _obj_repr_with_grid(ds, header_components, sections) + + +def array_repr(arr) -> str: + """HTML repr for ``UxDataArray`` class.""" + + dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) + if hasattr(arr, "xindexes"): + indexed_dims = arr.xindexes.dims + else: + indexed_dims = {} + + obj_type = f"uxarray.{type(arr).__name__}" + arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else "" + + header_components = [ + f"
    {obj_type}
    ", + f"
    {arr_name}
    ", + xrfm.format_dims(dims, indexed_dims), + ] + + sections = [xrfm.array_section(arr)] + + if hasattr(arr, "coords"): + sections.append(xrfm.coord_section(arr.coords)) + + if hasattr(arr, "xindexes"): + indexes = xrfm._get_indexes_dict(arr.xindexes) + sections.append(xrfm.index_section(indexes)) + + sections.append(xrfm.attr_section(arr.attrs)) + + return _obj_repr_with_grid(arr, header_components, sections) diff --git a/uxarray/grid/grid.py b/uxarray/grid/grid.py index ae480b132..5b92c1026 100644 --- a/uxarray/grid/grid.py +++ b/uxarray/grid/grid.py @@ -3,6 +3,10 @@ import xarray as xr import numpy as np +from html import escape + +from xarray.core.options import OPTIONS + from typing import ( Optional, Union, @@ -23,6 +27,8 @@ from uxarray.io._topology import _read_topology from uxarray.io._geos import _read_geos_cs +from uxarray.formatting_html import grid_repr + from uxarray.io.utils import _parse_grid_type from uxarray.grid.area import get_all_face_area_from_coords from uxarray.grid.coordinates import ( @@ -396,6 +402,11 @@ def __repr__(self): + descriptors_str ) + def _repr_html_(self) -> str: + if OPTIONS["display_style"] == "text": + return f"
    {escape(repr(self))}
    " + return grid_repr(self) + def __getitem__(self, item): """Implementation of getitem operator for indexing a grid to obtain variables.