Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorganization of internal utilities and adding type hints for internal utils #341

Merged
merged 12 commits into from
Nov 8, 2023
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_requirements(requirements_filename):
"peter.marinescu@colostate.edu",
],
license="BSD-3-Clause License",
packages=[PACKAGE_NAME, PACKAGE_NAME + ".utils"],
packages=[PACKAGE_NAME, PACKAGE_NAME + ".utils", PACKAGE_NAME + ".utils.internal"],
install_requires=get_requirements("requirements.txt"),
test_requires=["pytest"],
zip_safe=False,
Expand Down
1 change: 1 addition & 0 deletions tobac/utils/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .basic import *
203 changes: 70 additions & 133 deletions tobac/utils/internal.py → tobac/utils/internal/basic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
"""Internal tobac utilities
"""
from __future__ import annotations

import numpy as np
import skimage.measure
import xarray as xr
import iris
import iris.cube
import pandas as pd
import warnings
from . import iris_utils
from . import xarray_utils as xr_utils
from typing import Union, Callable

# list of common vertical coordinates to search for in various functions
COMMON_VERT_COORDS: list[str] = [
"z",
"model_level_number",
"altitude",
"geopotential_height",
]


def _warn_auto_coordinate():
Expand All @@ -17,7 +32,7 @@
)


def get_label_props_in_dict(labels):
def get_label_props_in_dict(labels: np.array) -> dict:
"""Function to get the label properties into a dictionary format.

Parameters
Expand All @@ -40,7 +55,7 @@
return region_properties_dict


def get_indices_of_labels_from_reg_prop_dict(region_property_dict):
def get_indices_of_labels_from_reg_prop_dict(region_property_dict: dict) -> tuple[dict]:
"""Function to get the x, y, and z indices (as well as point count) of all labeled regions.
Parameters
----------
Expand Down Expand Up @@ -94,7 +109,7 @@
return [curr_loc_indices, y_indices, x_indices]


def iris_to_xarray(func):
def iris_to_xarray(func: Callable) -> Callable:
"""Decorator that converts all input of a function that is in the form of
Iris cubes into xarray DataArrays and converts all outputs with type
xarray DataArrays back into Iris cubes.
Expand Down Expand Up @@ -164,7 +179,7 @@
return wrapper


def xarray_to_iris(func):
def xarray_to_iris(func: Callable) -> Callable:
"""Decorator that converts all input of a function that is in the form of
xarray DataArrays into Iris cubes and converts all outputs with type
Iris cubes back into xarray DataArrays.
Expand Down Expand Up @@ -248,7 +263,7 @@
return wrapper


def irispandas_to_xarray(func):
def irispandas_to_xarray(func: Callable) -> Callable:
"""Decorator that converts all input of a function that is in the form of
Iris cubes/pandas Dataframes into xarray DataArrays/xarray Datasets and
converts all outputs with the type xarray DataArray/xarray Dataset
Expand Down Expand Up @@ -328,7 +343,7 @@
return wrapper


def xarray_to_irispandas(func):
def xarray_to_irispandas(func: Callable) -> Callable:
"""Decorator that converts all input of a function that is in the form of
DataArrays/xarray Datasets into xarray Iris cubes/pandas Dataframes and
converts all outputs with the type Iris cubes/pandas Dataframes back into
Expand Down Expand Up @@ -431,7 +446,7 @@
return wrapper


def njit_if_available(func, **kwargs):
def njit_if_available(func: Callable, **kwargs) -> Callable:
"""Decorator to wrap a function with numba.njit if available.
If numba isn't available, it just returns the function.

Expand All @@ -456,12 +471,15 @@
return func


def find_vertical_axis_from_coord(variable_cube, vertical_coord=None):
def find_vertical_axis_from_coord(
variable_cube: Union[iris.cube.Cube, xr.DataArray],
vertical_coord: Union[str, None] = None,
) -> str:
"""Function to find the vertical coordinate in the iris cube

Parameters
----------
variable_cube: iris.cube
variable_cube: iris.cube.Cube or xarray.DataArray
Input variable cube, containing a vertical coordinate.
vertical_coord: str
Vertical coordinate name. If None, this function tries to auto-detect.
Expand All @@ -476,69 +494,21 @@
ValueError
Raised if the vertical coordinate isn't found in the cube.
"""
list_vertical = [
"z",
"model_level_number",
"altitude",
"geopotential_height",
]

if vertical_coord == "auto":
_warn_auto_coordinate()

if isinstance(variable_cube, iris.cube.Cube):
list_coord_names = [coord.name() for coord in variable_cube.coords()]
elif isinstance(variable_cube, xr.Dataset) or isinstance(
variable_cube, xr.DataArray
):
list_coord_names = variable_cube.coords

if vertical_coord is None or vertical_coord == "auto":
# find the intersection
all_vertical_axes = list(set(list_coord_names) & set(list_vertical))
if len(all_vertical_axes) >= 1:
return all_vertical_axes[0]
else:
raise ValueError(
"Cube lacks suitable automatic vertical coordinate (z, model_level_number, altitude, or geopotential_height)"
)
elif vertical_coord in list_coord_names:
return vertical_coord
else:
raise ValueError("Please specify vertical coordinate found in cube")


def find_axis_from_coord(variable_cube, coord_name):
"""Finds the axis number in an iris cube given a coordinate name.

Parameters
----------
variable_cube: iris.cube
Input variable cube
coord_name: str
coordinate to look for
return iris_utils.find_vertical_axis_from_coord(variable_cube, vertical_coord)
if isinstance(variable_cube, xr.Dataset) or isinstance(variable_cube, xr.DataArray):
return xr_utils.find_vertical_axis_from_coord(variable_cube, vertical_coord)

Returns
-------
axis_number: int
the number of the axis of the given coordinate, or None if the coordinate
is not found in the cube or not a dimensional coordinate
"""

list_coord_names = [coord.name() for coord in variable_cube.coords()]
all_matching_axes = list(set(list_coord_names) & set((coord_name,)))
if (
len(all_matching_axes) == 1
and len(variable_cube.coord_dims(all_matching_axes[0])) > 0
):
return variable_cube.coord_dims(all_matching_axes[0])[0]
elif len(all_matching_axes) > 1:
raise ValueError("Too many axes matched.")
else:
return None
raise ValueError("variable_cube must be xr.DataArray or iris.cube.Cube")

Check warning on line 506 in tobac/utils/internal/basic.py

View check run for this annotation

Codecov / codecov/patch

tobac/utils/internal/basic.py#L506

Added line #L506 was not covered by tests


def find_dataframe_vertical_coord(variable_dataframe, vertical_coord=None):
def find_dataframe_vertical_coord(
variable_dataframe: pd.DataFrame, vertical_coord: Union[str, None] = None
) -> str:
"""Function to find the vertical coordinate in the iris cube

Parameters
Expand All @@ -563,8 +533,9 @@
_warn_auto_coordinate()

if vertical_coord is None or vertical_coord == "auto":
list_vertical = ["z", "model_level_number", "altitude", "geopotential_height"]
all_vertical_axes = list(set(variable_dataframe.columns) & set(list_vertical))
all_vertical_axes = list(
set(variable_dataframe.columns) & set(COMMON_VERT_COORDS)
)
if len(all_vertical_axes) == 1:
return all_vertical_axes[0]
else:
Expand All @@ -578,7 +549,7 @@


@njit_if_available
def calc_distance_coords(coords_1, coords_2):
def calc_distance_coords(coords_1: np.array, coords_2: np.array) -> float:
"""Function to calculate the distance between cartesian
coordinate set 1 and coordinate set 2.
Parameters
Expand All @@ -605,13 +576,17 @@
return np.sqrt(np.sum(deltas**2))


def find_hdim_axes_3D(field_in, vertical_coord=None, vertical_axis=None):
def find_hdim_axes_3D(
field_in: Union[iris.cube.Cube, xr.DataArray],
vertical_coord: Union[str, None] = None,
vertical_axis: Union[int, None] = None,
) -> tuple[int]:
"""Finds what the hdim axes are given a 3D (including z) or
4D (including z and time) dataset.

Parameters
----------
field_in: iris cube or xarray dataset
field_in: iris cube or xarray dataarray
Input field, can be 3D or 4D
vertical_coord: str
The name of the vertical coord, or None, which will attempt to find
Expand All @@ -626,7 +601,6 @@
The axes for hdim_1 and hdim_2

"""
from iris import cube as iris_cube

if vertical_coord == "auto":
_warn_auto_coordinate()
Expand All @@ -635,91 +609,54 @@
if vertical_coord != "auto":
raise ValueError("Cannot set both vertical_coord and vertical_axis.")

if type(field_in) is iris_cube.Cube:
return find_hdim_axes_3D_iris(field_in, vertical_coord, vertical_axis)
if type(field_in) is iris.cube.Cube:
return iris_utils.find_hdim_axes_3d(field_in, vertical_coord, vertical_axis)
elif type(field_in) is xr.DataArray:
raise NotImplementedError("Xarray find_hdim_axes_3D not implemented")
else:
raise ValueError("Unknown data type: " + type(field_in).__name__)


def find_hdim_axes_3D_iris(field_in, vertical_coord=None, vertical_axis=None):
"""Finds what the hdim axes are given a 3D (including z) or
4D (including z and time) dataset.
def find_axis_from_coord(
variable_arr: Union[iris.cube.Cube, xr.DataArray], coord_name: str
) -> int:
"""Finds the axis number in an xarray or iris cube given a coordinate or dimension name.

Parameters
----------
field_in: iris cube
Input field, can be 3D or 4D
vertical_coord: str or None
The name of the vertical coord, or None, which will attempt to find
the vertical coordinate name
vertical_axis: int or None
The axis number of the vertical coordinate, or None. Note
that only one of vertical_axis or vertical_coord can be set.
variable_arr: iris.cube.Cube or xarray.DataArray
Input variable cube
coord_name: str
coordinate or dimension to look for

Returns
-------
(hdim_1_axis, hdim_2_axis): (int, int)
The axes for hdim_1 and hdim_2
axis_number: int
the number of the axis of the given coordinate, or None if the coordinate
is not found in the variable or not a dimensional coordinate
"""

if vertical_coord == "auto":
_warn_auto_coordinate()

if vertical_coord is not None and vertical_axis is not None:
if vertical_coord != "auto":
raise ValueError("Cannot set both vertical_coord and vertical_axis.")

time_axis = find_axis_from_coord(field_in, "time")
if vertical_axis is not None:
vertical_coord_axis = vertical_axis
vert_coord_found = True
if isinstance(variable_arr, iris.cube.Cube):
return iris_utils.find_axis_from_coord(variable_arr, coord_name)
elif isinstance(variable_arr, xr.DataArray):
raise NotImplementedError(

Check warning on line 642 in tobac/utils/internal/basic.py

View check run for this annotation

Codecov / codecov/patch

tobac/utils/internal/basic.py#L641-L642

Added lines #L641 - L642 were not covered by tests
"xarray version of find_axis_from_coord not implemented."
)
else:
try:
vertical_axis = find_vertical_axis_from_coord(
field_in, vertical_coord=vertical_coord
)
except ValueError:
vert_coord_found = False
else:
vert_coord_found = True
ndim_vertical = field_in.coord_dims(vertical_axis)
if len(ndim_vertical) > 1:
raise ValueError(
"please specify 1 dimensional vertical coordinate."
" Current vertical coordinates: {0}".format(ndim_vertical)
)
if len(ndim_vertical) != 0:
vertical_coord_axis = ndim_vertical[0]
else:
# this means the vertical coordinate is an auxiliary coordinate of some kind.
vert_coord_found = False

if not vert_coord_found:
# if we don't have a vertical coordinate, and we are 3D or lower
# that is okay.
if (field_in.ndim == 3 and time_axis is not None) or field_in.ndim < 3:
vertical_coord_axis = None
else:
raise ValueError("No suitable vertical coordinate found")
# Once we know the vertical coordinate, we can resolve the
# horizontal coordinates

all_axes = np.arange(0, field_in.ndim)
output_vals = tuple(
all_axes[np.logical_not(np.isin(all_axes, [time_axis, vertical_coord_axis]))]
)
return output_vals
raise ValueError("variable_arr must be Iris Cube or Xarray DataArray")

Check warning on line 646 in tobac/utils/internal/basic.py

View check run for this annotation

Codecov / codecov/patch

tobac/utils/internal/basic.py#L646

Added line #L646 was not covered by tests


@irispandas_to_xarray
def detect_latlon_coord_name(in_dataset, latitude_name=None, longitude_name=None):
def detect_latlon_coord_name(
in_dataset: Union[xr.DataArray, iris.cube.Cube],
latitude_name: Union[str, None] = None,
longitude_name: Union[str, None] = None,
) -> tuple[str]:
"""Function to detect the name of latitude/longitude coordinates

Parameters
----------
in_dataset: iris.cube.Cube, xarray.Dataset, or xarray.Dataarray
in_dataset: iris.cube.Cube or xarray.DataArray
Input dataset to detect names from
latitude_name: str
The name of the latitude coordinate. If None, tries to auto-detect.
Expand Down
Loading
Loading