Skip to content

Commit

Permalink
return forcings as states
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Dec 6, 2024
1 parent cc795e4 commit 49b804e
Show file tree
Hide file tree
Showing 19 changed files with 227 additions and 150 deletions.
74 changes: 32 additions & 42 deletions src/anemoi/inference/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,26 @@ def __init__(self, context):
self.kinds = dict(unknown=True) # Used for debugging

@abstractmethod
def load_forcings(self, state, date):
def load_forcings_array(self, dates, current_state) -> np.ndarray:
"""Load the forcings for the given dates."""
pass

def __repr__(self):
return f"{self.__class__.__name__}"

def _state_to_numpy(self, state, variables, dates) -> np.ndarray:
"""Convert the state dictionary to a numpy array.
This assumes that the state dictionary contains the fields for the given variables.
And that the fields values are sorted by dates.
"""
fields = state["fields"]
result = np.stack([fields[v] for v in variables], axis=0)
assert result.shape[:2] == (
len(variables),
len(dates),
), (result.shape, variables, dates)
return result


class ComputedForcings(Forcings):
"""Compute forcings like `cos_julian_day` or `insolation`."""
Expand All @@ -50,16 +64,16 @@ def __init__(self, context, variables, mask):
def __repr__(self):
return f"{self.__class__.__name__}({self.variables})"

def load_forcings(self, state, dates):
def load_forcings_array(self, dates, current_state):

LOG.debug("Adding dynamic forcings %s", self.variables)

if not isinstance(dates, (list, tuple)):
dates = [dates]

source = UnstructuredGridFieldList.from_values(
latitudes=state["latitudes"],
longitudes=state["longitudes"],
latitudes=current_state["latitudes"],
longitudes=current_state["longitudes"],
)

ds = ekd.from_source("forcings", source, date=dates, param=self.variables)
Expand All @@ -73,7 +87,7 @@ def rename(f, _, metadata):

forcing = ds.order_by(name=self.variables, valid_datetime="ascending")

# Forcing are sorted by `compute_forcings` in the order (varaible, date)
# Forcing are sorted by `compute_forcings` in the order (variable, date)

return forcing.to_numpy(dtype=np.float32, flatten=True).reshape(len(self.variables), len(dates), -1)

Expand All @@ -86,45 +100,17 @@ def __init__(self, context, input, variables, mask):
self.variables = variables
self.mask = mask
self.input = input
# self.grid = context.checkpoint.grid
# self.area = context.checkpoint.area
# self.use_grib_paramid = True # TODO: find a way to `use_grib_paramid``
self.kinds = dict(retrieved=True) # Used for debugging

def __repr__(self):
return f"{self.__class__.__name__}({self.variables})"

def load_forcings(self, state, dates):
data = self.input.load_forcings(variables=self.variables, dates=dates)

# expected_shape = (len(self.variables), len(dates), state["latitudes"].size)
# assert data.shape == expected_shape, (data.shape, expected_shape)

return data

# assert False, "Not implemented yet"
# from .inputs.mars import retrieve

# requests = self.context.checkpoint.mars_requests(
# variables=self.variables,
# dates=dates,
# use_grib_paramid=self.use_grib_paramid,
# )

# if not requests:
# raise ValueError("No requests for %s (%s)" % (self.variables, dates))

# for r in requests:
# LOG.info("Request: %s", r)

# fields = retrieve(requests=requests, grid=self.grid, area=self.area, expver=1)

# if not fields:
# raise ValueError("No fields retrieved for {self.variables} ({dates})")

# fields = self.checkpoint.name_fields(fields).order_by(name=self.variables, valid_datetime="ascending")

# return fields.to_numpy(dtype=np.float32, flatten=True).reshape(len(self.variables), len(dates), -1)
def load_forcings_array(self, dates, current_state):
return self._state_to_numpy(
self.input.load_forcings_state(variables=self.variables, dates=dates, current_state=current_state),
self.variables,
dates,
)


class BoundaryForcings(Forcings):
Expand All @@ -143,11 +129,15 @@ def __init__(self, context, input, variables, variables_mask):
def __repr__(self):
return f"{self.__class__.__name__}({self.variables})"

def load_forcings(self, state, dates):
data = self.input.load_forcings(variables=self.variables, dates=dates)
def load_forcings_array(self, dates, current_state):
data = self._state_to_numpy(
self.input.load_forcings_state(variables=self.variables, dates=dates, current_state=current_state),
self.variables,
dates,
)
data = data[..., self.spatial_mask]

expected_shape = (len(self.variables), len(dates), state["latitudes"][self.spatial_mask].size)
expected_shape = (len(self.variables), len(dates), current_state["latitudes"][self.spatial_mask].size)
assert data.shape == expected_shape, (data.shape, expected_shape)

return data
2 changes: 1 addition & 1 deletion src/anemoi/inference/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def create_input_state(self, *, date=None):
pass

@abstractmethod
def load_forcings(self, *, variables, dates):
def load_forcings_state(self, *, variables, dates, current_state):
"""Load forcings (constant and dynamic)."""
pass

Expand Down
23 changes: 19 additions & 4 deletions src/anemoi/inference/inputs/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def ds(self):

return open_dataset(*self.args, **self.kwargs)

@cached_property
def latitudes(self):
return self.ds.latitudes

@cached_property
def longitudes(self):
return self.ds.longitudes

def __repr__(self):
return f"DatasetInput({self.args}, {self.kwargs})"

Expand All @@ -51,8 +59,8 @@ def create_input_state(self, *, date=None):

input_state = dict(
date=date,
latitudes=self.ds.latitudes,
longitudes=self.ds.longitudes,
latitudes=self.latitudes,
longitudes=self.longitudes,
fields=dict(),
)

Expand All @@ -73,7 +81,7 @@ def create_input_state(self, *, date=None):

return input_state

def load_forcings(self, *, variables, dates):
def load_forcings_state(self, *, variables, dates, current_state):
data = self._load_dates(dates) # (date, variables, ensemble, values)

requested_variables = np.array([self.ds.name_to_index[v] for v in variables])
Expand All @@ -82,7 +90,14 @@ def load_forcings(self, *, variables, dates):
data = np.squeeze(data, axis=2)
# Reorder the dimensions to (variable, date, values)
data = np.swapaxes(data, 0, 1)
return data
fields = {v: data[i] for i, v in enumerate(variables)}

return dict(
fields=fields,
dates=dates,
latitudes=self.latitudes,
longitudes=self.longitudes,
)

def _load_dates(self, dates):

Expand Down
15 changes: 11 additions & 4 deletions src/anemoi/inference/inputs/ekd.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, context, *, namer=None):
self._namer = namer if namer is not None else self.checkpoint.default_namer()
assert callable(self._namer), type(self._namer)

def _create_input_state(
def _create_state(
self,
input_fields,
*,
Expand Down Expand Up @@ -208,6 +208,13 @@ def _name(field, _, original_metadata):
data = FieldArray([f.clone(name=_name) for f in data])
return data.sel(name=name, **kwargs)

def _load_forcings(self, fields, variables, dates):
data = self._filter_and_sort(fields, variables=variables, dates=dates, title="Load forcings")
return data.to_numpy(dtype=np.float32, flatten=True).reshape(len(variables), len(dates), -1)
def _load_forcings_state(self, fields, variables, dates, current_state):
return self._create_state(
fields,
variables=variables,
date=dates,
latitudes=current_state["latitudes"],
longitudes=current_state["longitudes"],
dtype=np.float32,
flatten=True,
)
8 changes: 5 additions & 3 deletions src/anemoi/inference/inputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ def __init__(self, context, path, *, namer=None, **kwargs):
self.path = path

def create_input_state(self, *, date):
return self._create_input_state(ekd.from_source("file", self.path), variables=None, date=date)
return self._create_state(ekd.from_source("file", self.path), variables=None, date=date)

def load_forcings(self, *, variables, dates):
return self._load_forcings(ekd.from_source("file", self.path), variables=variables, dates=dates)
def load_forcings_state(self, *, variables, dates, current_state):
return self._load_forcings_state(
ekd.from_source("file", self.path), variables=variables, dates=dates, current_state=current_state
)

def template(self, variable, date, **kwargs):
fields = ekd.from_source("file", self.path)
Expand Down
8 changes: 5 additions & 3 deletions src/anemoi/inference/inputs/icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ def create_input_state(self, *, date):

LOG.info("Done")

return self._create_input_state(
return self._create_state(
ekd.from_source("file", self.path),
variables=None,
date=date,
latitudes=latitudes,
longitudes=longitudes,
)

def load_forcings(self, *, variables, dates):
return self._load_forcings(ekd.from_source("file", self.path), variables=variables, dates=dates)
def load_forcings_state(self, *, variables, dates, current_state):
return self._load_forcings_state(
ekd.from_source("file", self.path), variables=variables, dates=dates, current_state=current_state
)
8 changes: 5 additions & 3 deletions src/anemoi/inference/inputs/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def create_input_state(self, *, date):

date = to_datetime(date)

return self._create_input_state(
return self._create_state(
self.retrieve(
self.variables,
[date + h for h in self.checkpoint.lagged],
Expand All @@ -151,5 +151,7 @@ def retrieve(self, variables, dates):
def template(self, variable, date, **kwargs):
return self.retrieve([variable], [date])[0]

def load_forcings(self, variables, dates):
return self._load_forcings(self.retrieve(variables, dates), variables, dates)
def load_forcings_state(self, *, variables, dates, current_state):
return self._load_forcings_state(
self.retrieve(variables, dates), variables=variables, dates=dates, current_state=current_state
)
5 changes: 3 additions & 2 deletions src/anemoi/inference/inputs/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def __repr__(self):
def create_input_state(self, *, date):
raise NotImplementedError("MultiInput.create_input_state() not implemented")

def load_forcings(self, variables, dates):
def load_forcings_state(self, *, variables, dates, current_state):
raise NotImplementedError("MultiInput.load_forcings_state() not implemented")
inputs = defaultdict(list)
for v in variables:
if v not in self.inputs:
Expand All @@ -54,7 +55,7 @@ def load_forcings(self, variables, dates):
rows = {}

for input, vs in inputs.items():
array = self._input_per_id[input].load_forcings(variables=vs, dates=dates)
array = self._input_per_id[input].load_forcings_state(variables=vs, dates=dates)
for i, v in enumerate(vs):
rows[v] = array[i]

Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/inputs/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,5 @@ def template(self, variable, date, **kwargs):

return ekd.from_source("memory", template)[0]

def load_forcings(self, variables, dates):
raise NotImplementedError("TemplatesInput.load_forcings() not implemented")
def load_forcings_state(self, *, variables, dates, current_state):
raise NotImplementedError("TemplatesInput.load_forcings_state() not implemented")
2 changes: 1 addition & 1 deletion src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def write_state(self, state):
self.quiet.add(name)

variable_keys = variable.grib_keys.copy()
for key in ("class", "type", "stream", "expver", "date", "time", "step"):
for key in ("class", "type", "stream", "expver", "date", "time", "step", "domain"):
variable_keys.pop(key, None)

keys.update(variable_keys)
Expand Down
6 changes: 4 additions & 2 deletions src/anemoi/inference/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def __init__(self, context, *, input_fields):
def create_input_state(self, *, date):
return self._create_input_state(self.input_fields, variables=None, date=date)

def load_forcings(self, *, variables, dates):
return self._load_forcings(self.input_fields, variables=variables, dates=dates)
def load_forcings_state(self, *, variables, dates, current_state):
return self._load_forcings_state(
self.input_fields, variables=variables, dates=dates, current_state=current_state
)

def set_private_attributes(self, state, input_fields):
input_fields = input_fields.order_by("valid_datetime")
Expand Down
Loading

0 comments on commit 49b804e

Please sign in to comment.