Skip to content

Commit

Permalink
add support for hindcasts
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Dec 19, 2024
1 parent 1d8c550 commit c3531cc
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class Config:
development_hacks: dict = {}
"""A dictionary of development hacks to apply to the runner. This is used to test new features or to work around"""

debugging_info: dict = {}
"""A dictionary to store debug information. This is ignored."""


def load_config(path, overrides, defaults=None, Configuration=Configuration):

Expand Down
61 changes: 56 additions & 5 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,58 @@
LOG = logging.getLogger(__name__)


class HindcastOutput:

def __init__(self, reference_year):
self.reference_year = reference_year

def __call__(self, values, template, keys):

if "date" not in keys:
assert template.metadata("hdate", default=None) is None, template
date = template.metadata("date")
else:
date = keys.pop("date")

for k in ("date", "hdate"):
keys.pop(k, None)

keys["edition"] = 1
keys["localDefinitionNumber"] = 30
keys["dataDate"] = int(to_datetime(date).strftime("%Y%m%d"))
keys["referenceDate"] = int(to_datetime(date).replace(year=self.reference_year).strftime("%Y%m%d"))

return values, template, keys


MODIFIERS = dict(hindcast=HindcastOutput)


def modifier_factory(modifiers):

if modifiers is None:
return []

if not isinstance(modifiers, list):
modifiers = [modifiers]

result = []
for modifier in modifiers:
assert isinstance(modifier, dict), modifier
assert len(modifier) == 1, modifier

klass = list(modifier.keys())[0]
result.append(MODIFIERS[klass](**modifier[klass]))

return result


class GribOutput(Output):
"""
Handles grib
"""

def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None):
def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None, modifiers=None):
super().__init__(context)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
Expand All @@ -40,6 +86,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g
self._template_date = None
self._template_reuse = None
self.use_closest_template = False # Off for now
self.modifiers = modifier_factory(modifiers)

def write_initial_state(self, state):
# We trust the GribInput class to provide the templates
Expand Down Expand Up @@ -76,7 +123,8 @@ def write_initial_state(self, state):
quiet=self.quiet,
)

# LOG.info("Step 0 GRIB %s\n%s", template, json.dumps(keys, indent=4))
for modifier in self.modifiers:
values, template, keys = modifier(values, template, keys)

self.write_message(values, template=template, **keys)

Expand All @@ -95,7 +143,7 @@ def write_state(self, state):
self.quiet.add("_grib_templates_for_output")
LOG.warning("Input is not GRIB.")

for name, value in state["fields"].items():
for name, values in state["fields"].items():
keys = {}

variable = self.typed_variables[name]
Expand All @@ -118,7 +166,7 @@ def write_state(self, state):
keys.update(self.encoding)

keys = grib_keys(
values=value,
values=values,
template=template,
date=reference_date.strftime("%Y-%m-%d"),
time=reference_date.hour,
Expand All @@ -131,11 +179,14 @@ def write_state(self, state):
quiet=self.quiet,
)

for modifier in self.modifiers:
values, template, keys = modifier(values, template, keys)

if LOG.isEnabledFor(logging.DEBUG):
LOG.info("Encoding GRIB %s\n%s", template, json.dumps(keys, indent=4))

try:
self.write_message(value, template=template, **keys)
self.write_message(values, template=template, **keys)
except Exception:
LOG.error("Error writing field %s", name)
LOG.error("Template: %s", template)
Expand Down
10 changes: 9 additions & 1 deletion src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,17 @@ def __init__(
templates=None,
grib1_keys=None,
grib2_keys=None,
modifiers=None,
**kwargs,
):
super().__init__(context, encoding=encoding, templates=templates, grib1_keys=grib1_keys, grib2_keys=grib2_keys)
super().__init__(
context,
encoding=encoding,
templates=templates,
grib1_keys=grib1_keys,
grib2_keys=grib2_keys,
modifiers=modifiers,
)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)
self.archiving = defaultdict(ArchiveCollector)
Expand Down

0 comments on commit c3531cc

Please sign in to comment.