Skip to content

Commit

Permalink
Feature/constants (#60)
Browse files Browse the repository at this point in the history
* support swapped dimensions

* add support for `match_all_dates`
  • Loading branch information
b8raoult authored Oct 1, 2024
1 parent 69fcd94 commit 3130d73
Show file tree
Hide file tree
Showing 13 changed files with 324 additions and 78 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,4 @@ _dev/
test.ipynb
*tmp_data/
tempCodeRunnerFile.python
Untitled-*.py
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Keep it human-readable, your future self will thank you!
### Added

- New `rescale` keyword in `open_dataset` to change units of variables #36
- Add support for constant fields when creating datasets
- Simplify imports

### Changed
Expand Down
14 changes: 12 additions & 2 deletions src/anemoi/datasets/create/functions/sources/xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import logging

from earthkit.data.core.fieldlist import MultiFieldList
from earthkit.data.indexing.fieldlist import FieldArray

from anemoi.datasets.data.stores import name_to_zarr_store
from anemoi.datasets.utils.fields import NewMetadataField

from .. import iterate_patterns
from .fieldlist import XarrayFieldList
Expand All @@ -29,7 +31,7 @@ def check(what, ds, paths, **kwargs):
raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})")


def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs):
def load_one(emoji, context, dates, dataset, options={}, match_all_dates=False, flavour=None, **kwargs):
import xarray as xr

"""
Expand All @@ -49,7 +51,15 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
data = xr.open_dataset(dataset, **options)

fs = XarrayFieldList.from_xarray(data, flavour)
result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])

if match_all_dates:
match = fs.sel(**kwargs)
result = []
for date in dates:
result.append(FieldArray([NewMetadataField(f, valid_datetime=date) for f in match]))
result = MultiFieldList(result)
else:
result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])

if len(result) == 0:
LOG.warning(f"No data found for {dataset} and dates {dates} and {kwargs}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class Coordinate:
is_step = False
is_date = False
is_member = False
is_x = False
is_y = False

def __init__(self, variable):
self.variable = variable
Expand All @@ -66,10 +68,11 @@ def __len__(self):
return 1 if self.scalar else len(self.variable)

def __repr__(self):
return "%s[name=%s,values=%s]" % (
return "%s[name=%s,values=%s,shape=%s]" % (
self.__class__.__name__,
self.variable.name,
self.variable.values if self.scalar else len(self),
self.variable.shape,
)

def reduced(self, i):
Expand Down Expand Up @@ -225,11 +228,13 @@ class LatitudeCoordinate(Coordinate):

class XCoordinate(Coordinate):
is_grid = True
is_x = True
mars_names = ("x",)


class YCoordinate(Coordinate):
is_grid = True
is_y = True
mars_names = ("y",)


Expand Down
17 changes: 13 additions & 4 deletions src/anemoi/datasets/create/functions/sources/xarray/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,18 @@ def __init__(self, owner, selection):
def shape(self):
return self._shape

def to_numpy(self, flatten=False, dtype=None):
values = self.selection.values
def to_numpy(self, flatten=False, dtype=None, index=None):
if index is not None:
values = self.selection[index]
else:
values = self.selection

assert dtype is None

if flatten:
return values.flatten()
return values.reshape(self.shape)
return values.values.flatten()

return values # .reshape(self.shape)

def _make_metadata(self):
return XArrayMetadata(self)
Expand Down Expand Up @@ -113,3 +118,7 @@ def forecast_reference_time(self):

def __repr__(self):
return repr(self._metadata)

def _values(self):
# we don't use .values as this will download the data
return self.selection
32 changes: 16 additions & 16 deletions src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,25 @@ def _skip_attr(v, attr_name):
skip.update(attr_val.split(" "))

for name in ds.data_vars:
v = ds[name]
_skip_attr(v, "coordinates")
_skip_attr(v, "bounds")
_skip_attr(v, "grid_mapping")
variable = ds[name]
_skip_attr(variable, "coordinates")
_skip_attr(variable, "bounds")
_skip_attr(variable, "grid_mapping")

# Select only geographical variables
for name in ds.data_vars:

if name in skip:
continue

v = ds[name]
variable = ds[name]
coordinates = []

for coord in v.coords:
for coord in variable.coords:

c = guess.guess(ds[coord], coord)
assert c, f"Could not guess coordinate for {coord}"
if coord not in v.dims:
if coord not in variable.dims:
c.is_dim = False
coordinates.append(c)

Expand All @@ -98,17 +98,17 @@ def _skip_attr(v, attr_name):
if grid_coords < 2:
continue

variables.append(
Variable(
ds=ds,
var=v,
coordinates=coordinates,
grid=guess.grid(coordinates),
time=Time.from_coordinates(coordinates),
metadata={},
)
v = Variable(
ds=ds,
variable=variable,
coordinates=coordinates,
grid=guess.grid(coordinates, variable),
time=Time.from_coordinates(coordinates),
metadata={},
)

variables.append(v)

return cls(ds, variables)

def sel(self, **kwargs):
Expand Down
121 changes: 111 additions & 10 deletions src/anemoi/datasets/create/functions/sources/xarray/flavour.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#


import logging

from .coordinates import DateCoordinate
from .coordinates import EnsembleCoordinate
from .coordinates import LatitudeCoordinate
Expand All @@ -18,8 +20,13 @@
from .coordinates import TimeCoordinate
from .coordinates import XCoordinate
from .coordinates import YCoordinate
from .coordinates import is_scalar
from .grid import MeshedGrid
from .grid import MeshProjectionGrid
from .grid import UnstructuredGrid
from .grid import UnstructuredProjectionGrid

LOG = logging.getLogger(__name__)


class CoordinateGuesser:
Expand Down Expand Up @@ -155,31 +162,125 @@ def _guess(self, c, coord):
f" {long_name=}, {standard_name=}, units\n\n{c}\n\n{type(c.values)} {c.shape}"
)

def grid(self, coordinates):
def grid(self, coordinates, variable):
lat = [c for c in coordinates if c.is_lat]
lon = [c for c in coordinates if c.is_lon]

if len(lat) != 1:
raise NotImplementedError(f"Expected 1 latitude coordinate, got {len(lat)}")
if len(lat) == 1 and len(lon) == 1:
return self._lat_lon_provided(lat, lon, variable)

x = [c for c in coordinates if c.is_x]
y = [c for c in coordinates if c.is_y]

if len(x) == 1 and len(y) == 1:
return self._x_y_provided(x, y, variable)

if len(lon) != 1:
raise NotImplementedError(f"Expected 1 longitude coordinate, got {len(lon)}")
raise NotImplementedError(f"Cannot establish grid {coordinates}")

def _lat_lon_provided(self, lat, lon, variable):
lat = lat[0]
lon = lon[0]

if (lat.name, lon.name) in self._cache:
return self._cache[(lat.name, lon.name)]
if lat.variable.dims != lon.variable.dims:
raise ValueError(f"Dimensions do not match {lat.name}{lat.variable.dims} != {lon.name}{lon.variable.dims}")

dim_vars = variable.dims[-len(lat.variable.dims) :]

if set(lat.variable.dims) != set(dim_vars):
raise ValueError(
f"Dimensions do not match {variable.name}{variable.dims} != {lat.name}{lat.variable.dims} and {lon.name}{lon.variable.dims}"
)

if (lat.name, lon.name, dim_vars) in self._cache:
return self._cache[(lat.name, lon.name, dim_vars)]

assert len(lat.variable.shape) == len(lon.variable.shape), (lat.variable.shape, lon.variable.shape)
if len(lat.variable.shape) == 1:
grid = MeshedGrid(lat, lon)
grid = MeshedGrid(lat, lon, dim_vars)
else:
grid = UnstructuredGrid(lat, lon)
grid = UnstructuredGrid(lat, lon, dim_vars)

self._cache[(lat.name, lon.name)] = grid
self._cache[(lat.name, lon.name, dim_vars)] = grid
return grid

def _x_y_provided(self, x, y, variable):
x = x[0]
y = y[0]

if x.variable.dims != y.variable.dims:
raise ValueError(f"Dimensions do not match {x.name}{x.variable.dims} != {y.name}{y.variable.dims}")

dim_vars = variable.dims[-len(x.variable.dims) :]

if x.variable.dims != dim_vars:
raise ValueError(
f"Dimensions do not match {variable.name}{variable.dims} != {x.name}{x.variable.dims} and {y.name}{y.variable.dims}"
)

if (x.name, y.name) in self._cache:
return self._cache[(x.name, y.name)]

if (x.name, y.name) in self._cache:
return self._cache[(x.name, y.name)]

assert len(x.variable.shape) == len(x.variable.shape), (x.variable.shape, y.variable.shape)

grid_mapping = variable.attrs.get("grid_mapping", None)

if grid_mapping is None:
LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'")
LOG.warning("Trying to guess...")

PROBE = {
"prime_meridian_name",
"reference_ellipsoid_name",
"crs_wkt",
"horizontal_datum_name",
"semi_major_axis",
"spatial_ref",
"inverse_flattening",
"semi_minor_axis",
"geographic_crs_name",
"GeoTransform",
"grid_mapping_name",
"longitude_of_prime_meridian",
}
candidate = None
for v in self.ds.variables:
var = self.ds[v]
if not is_scalar(var):
continue

if PROBE.intersection(var.attrs.keys()):
if candidate:
raise ValueError(f"Multiple candidates for 'grid_mapping': {candidate} and {v}")
candidate = v

if candidate:
LOG.warning(f"Using '{candidate}' as 'grid_mapping'")
grid_mapping = candidate
else:
LOG.warning("Could not fine a candidate for 'grid_mapping'")

if grid_mapping is None:
if "crs" in self.ds[variable].attrs:
grid_mapping = self.ds[variable].attrs["crs"]
LOG.warning(f"Using CRS {grid_mapping} from variable '{variable.name}' attributes")

if grid_mapping is None:
if "crs" in self.ds.attrs:
grid_mapping = self.ds.attrs["crs"]
LOG.warning(f"Using CRS {grid_mapping} from global attributes")

if grid_mapping is not None:
if len(x.variable.shape) == 1:
return MeshProjectionGrid(x, y, grid_mapping)
else:
return UnstructuredProjectionGrid(x, y, grid_mapping)

LOG.error("Could not fine a candidate for 'grid_mapping'")
raise NotImplementedError(f"Unstructured grid {x.name} {y.name}")


class DefaultCoordinateGuesser(CoordinateGuesser):
def __init__(self, ds):
Expand Down
Loading

0 comments on commit 3130d73

Please sign in to comment.