diff --git a/src/anemoi/inference/forcings.py b/src/anemoi/inference/forcings.py
index 1939e69..159f363 100644
--- a/src/anemoi/inference/forcings.py
+++ b/src/anemoi/inference/forcings.py
@@ -31,12 +31,26 @@ def __init__(self, context):
         self.kinds = dict(unknown=True)  # Used for debugging

     @abstractmethod
-    def load_forcings(self, state, date):
+    def load_forcings_array(self, dates, current_state) -> np.ndarray:
+        """Load the forcings for the given dates."""
         pass

     def __repr__(self):
         return f"{self.__class__.__name__}"

+    def _state_to_numpy(self, state, variables, dates) -> np.ndarray:
+        """Convert the state dictionary to a numpy array.
+        This assumes that the state dictionary contains the fields for the given variables,
+        and that the field values are sorted by date.
+        """
+        fields = state["fields"]
+        result = np.stack([fields[v] for v in variables], axis=0)
+        assert result.shape[:2] == (
+            len(variables),
+            len(dates),
+        ), (result.shape, variables, dates)
+        return result
+

 class ComputedForcings(Forcings):
     """Compute forcings like `cos_julian_day` or `insolation`."""
@@ -50,7 +64,7 @@ def __init__(self, context, variables, mask):
     def __repr__(self):
         return f"{self.__class__.__name__}({self.variables})"

-    def load_forcings(self, state, dates):
+    def load_forcings_array(self, dates, current_state):

         LOG.debug("Adding dynamic forcings %s", self.variables)

@@ -58,8 +72,8 @@ def load_forcings(self, state, dates):
             dates = [dates]

         source = UnstructuredGridFieldList.from_values(
-            latitudes=state["latitudes"],
-            longitudes=state["longitudes"],
+            latitudes=current_state["latitudes"],
+            longitudes=current_state["longitudes"],
         )

         ds = ekd.from_source("forcings", source, date=dates, param=self.variables)
@@ -73,7 +87,7 @@ def rename(f, _, metadata):

         forcing = ds.order_by(name=self.variables, valid_datetime="ascending")

-        # Forcing are sorted by `compute_forcings` in the order (varaible, date)
+        # Forcings are sorted by `compute_forcings` in the order (variable, date)

         return forcing.to_numpy(dtype=np.float32, flatten=True).reshape(len(self.variables), len(dates), -1)

@@ -86,45 +100,17 @@ def __init__(self, context, input, variables, mask):
         self.variables = variables
         self.mask = mask
         self.input = input
-        # self.grid = context.checkpoint.grid
-        # self.area = context.checkpoint.area
-        # self.use_grib_paramid = True  # TODO: find a way to `use_grib_paramid``
         self.kinds = dict(retrieved=True)  # Used for debugging

     def __repr__(self):
         return f"{self.__class__.__name__}({self.variables})"

-    def load_forcings(self, state, dates):
-        data = self.input.load_forcings(variables=self.variables, dates=dates)
-
-        # expected_shape = (len(self.variables), len(dates), state["latitudes"].size)
-        # assert data.shape == expected_shape, (data.shape, expected_shape)
-
-        return data
-
-        # assert False, "Not implemented yet"
-        # from .inputs.mars import retrieve
-
-        # requests = self.context.checkpoint.mars_requests(
-        #     variables=self.variables,
-        #     dates=dates,
-        #     use_grib_paramid=self.use_grib_paramid,
-        # )
-
-        # if not requests:
-        #     raise ValueError("No requests for %s (%s)" % (self.variables, dates))
-
-        # for r in requests:
-        #     LOG.info("Request: %s", r)
-
-        # fields = retrieve(requests=requests, grid=self.grid, area=self.area, expver=1)
-
-        # if not fields:
-        #     raise ValueError("No fields retrieved for {self.variables} ({dates})")
-
-        # fields = self.checkpoint.name_fields(fields).order_by(name=self.variables, valid_datetime="ascending")
-
-        # return fields.to_numpy(dtype=np.float32, flatten=True).reshape(len(self.variables), len(dates), -1)
+    def load_forcings_array(self, dates, current_state):
+        return self._state_to_numpy(
+            self.input.load_forcings_state(variables=self.variables, dates=dates, current_state=current_state),
+            self.variables,
+            dates,
+        )


 class BoundaryForcings(Forcings):
@@ -143,11 +129,15 @@ def __init__(self, context, input, variables, variables_mask):
     def __repr__(self):
         return f"{self.__class__.__name__}({self.variables})"

-    def load_forcings(self, state, dates):
-        data = self.input.load_forcings(variables=self.variables, dates=dates)
+    def load_forcings_array(self, dates, current_state):
+        data = self._state_to_numpy(
+            self.input.load_forcings_state(variables=self.variables, dates=dates, current_state=current_state),
+            self.variables,
+            dates,
+        )
         data = data[..., self.spatial_mask]

-        expected_shape = (len(self.variables), len(dates), state["latitudes"][self.spatial_mask].size)
+        expected_shape = (len(self.variables), len(dates), current_state["latitudes"][self.spatial_mask].size)
         assert data.shape == expected_shape, (data.shape, expected_shape)

         return data
diff --git a/src/anemoi/inference/input.py b/src/anemoi/inference/input.py
index 9331e16..9e73a56 100644
--- a/src/anemoi/inference/input.py
+++ b/src/anemoi/inference/input.py
@@ -33,7 +33,7 @@ def create_input_state(self, *, date=None):
         pass

     @abstractmethod
-    def load_forcings(self, *, variables, dates):
+    def load_forcings_state(self, *, variables, dates, current_state):
        """Load forcings (constant and dynamic)."""
         pass

diff --git a/src/anemoi/inference/inputs/dataset.py b/src/anemoi/inference/inputs/dataset.py
index c7122e3..8aa7811 100644
--- a/src/anemoi/inference/inputs/dataset.py
+++ b/src/anemoi/inference/inputs/dataset.py
@@ -40,6 +40,14 @@ def ds(self):

         return open_dataset(*self.args, **self.kwargs)

+    @cached_property
+    def latitudes(self):
+        return self.ds.latitudes
+
+    @cached_property
+    def longitudes(self):
+        return self.ds.longitudes
+
     def __repr__(self):
         return f"DatasetInput({self.args}, {self.kwargs})"

@@ -51,8 +59,8 @@ def create_input_state(self, *, date=None):

         input_state = dict(
             date=date,
-            latitudes=self.ds.latitudes,
-            longitudes=self.ds.longitudes,
+            latitudes=self.latitudes,
+            longitudes=self.longitudes,
             fields=dict(),
         )

@@ -73,7 +81,7 @@ def create_input_state(self, *, date=None):

         return input_state

-    def load_forcings(self, *, variables, dates):
+    def load_forcings_state(self, *, variables, dates, current_state):
         data = self._load_dates(dates)  # (date, variables, ensemble, values)

         requested_variables = np.array([self.ds.name_to_index[v] for v in variables])
@@ -82,7 +90,14 @@ def load_forcings(self, *, variables, dates):
         data = np.squeeze(data, axis=2)
         # Reorder the dimensions to (variable, date, values)
         data = np.swapaxes(data, 0, 1)
-        return data
+        fields = {v: data[i] for i, v in enumerate(variables)}
+
+        return dict(
+            fields=fields,
+            dates=dates,
+            latitudes=self.latitudes,
+            longitudes=self.longitudes,
+        )

     def _load_dates(self, dates):

diff --git a/src/anemoi/inference/inputs/ekd.py b/src/anemoi/inference/inputs/ekd.py
index 1db8947..fc26123 100644
--- a/src/anemoi/inference/inputs/ekd.py
+++ b/src/anemoi/inference/inputs/ekd.py
@@ -79,7 +79,7 @@ def __init__(self, context, *, namer=None):
         self._namer = namer if namer is not None else self.checkpoint.default_namer()
         assert callable(self._namer), type(self._namer)

-    def _create_input_state(
+    def _create_state(
         self,
         input_fields,
         *,
@@ -208,6 +208,13 @@ def _name(field, _, original_metadata):
             data = FieldArray([f.clone(name=_name) for f in data])

         return data.sel(name=name, **kwargs)

-    def _load_forcings(self, fields, variables, dates):
-        data = self._filter_and_sort(fields, variables=variables, dates=dates, title="Load forcings")
-        return data.to_numpy(dtype=np.float32, flatten=True).reshape(len(variables), len(dates), -1)
+    def _load_forcings_state(self, fields, variables, dates, current_state):
+        return self._create_state(
+            fields,
+            variables=variables,
+            date=dates,
+            latitudes=current_state["latitudes"],
+            longitudes=current_state["longitudes"],
+            dtype=np.float32,
+            flatten=True,
+        )
diff --git a/src/anemoi/inference/inputs/gribfile.py b/src/anemoi/inference/inputs/gribfile.py
index 0720aef..e6636b7 100644
--- a/src/anemoi/inference/inputs/gribfile.py
+++ b/src/anemoi/inference/inputs/gribfile.py
@@ -31,10 +31,12 @@ def __init__(self, context, path, *, namer=None, **kwargs):
         self.path = path

     def create_input_state(self, *, date):
-        return self._create_input_state(ekd.from_source("file", self.path), variables=None, date=date)
+        return self._create_state(ekd.from_source("file", self.path), variables=None, date=date)

-    def load_forcings(self, *, variables, dates):
-        return self._load_forcings(ekd.from_source("file", self.path), variables=variables, dates=dates)
+    def load_forcings_state(self, *, variables, dates, current_state):
+        return self._load_forcings_state(
+            ekd.from_source("file", self.path), variables=variables, dates=dates, current_state=current_state
+        )

     def template(self, variable, date, **kwargs):
         fields = ekd.from_source("file", self.path)
diff --git a/src/anemoi/inference/inputs/icon.py b/src/anemoi/inference/inputs/icon.py
index 87a29d8..a775429 100644
--- a/src/anemoi/inference/inputs/icon.py
+++ b/src/anemoi/inference/inputs/icon.py
@@ -45,7 +45,7 @@ def create_input_state(self, *, date):

         LOG.info("Done")

-        return self._create_input_state(
+        return self._create_state(
             ekd.from_source("file", self.path),
             variables=None,
             date=date,
@@ -53,5 +53,7 @@ def create_input_state(self, *, date):
             longitudes=longitudes,
         )

-    def load_forcings(self, *, variables, dates):
-        return self._load_forcings(ekd.from_source("file", self.path), variables=variables, dates=dates)
+    def load_forcings_state(self, *, variables, dates, current_state):
+        return self._load_forcings_state(
+            ekd.from_source("file", self.path), variables=variables, dates=dates, current_state=current_state
+        )
diff --git a/src/anemoi/inference/inputs/mars.py b/src/anemoi/inference/inputs/mars.py
index c3182d0..9703eb4 100644
--- a/src/anemoi/inference/inputs/mars.py
+++ b/src/anemoi/inference/inputs/mars.py
@@ -126,7 +126,7 @@ def create_input_state(self, *, date):

         date = to_datetime(date)

-        return self._create_input_state(
+        return self._create_state(
             self.retrieve(
                 self.variables,
                 [date + h for h in self.checkpoint.lagged],
@@ -151,5 +151,7 @@ def retrieve(self, variables, dates):
     def template(self, variable, date, **kwargs):
         return self.retrieve([variable], [date])[0]

-    def load_forcings(self, variables, dates):
-        return self._load_forcings(self.retrieve(variables, dates), variables, dates)
+    def load_forcings_state(self, *, variables, dates, current_state):
+        return self._load_forcings_state(
+            self.retrieve(variables, dates), variables=variables, dates=dates, current_state=current_state
+        )
diff --git a/src/anemoi/inference/inputs/multi.py b/src/anemoi/inference/inputs/multi.py
index 42fe159..b37ad6f 100644
--- a/src/anemoi/inference/inputs/multi.py
+++ b/src/anemoi/inference/inputs/multi.py
@@ -44,7 +44,8 @@ def __repr__(self):
     def create_input_state(self, *, date):
         raise NotImplementedError("MultiInput.create_input_state() not implemented")

-    def load_forcings(self, variables, dates):
+    def load_forcings_state(self, *, variables, dates, current_state):
+        raise NotImplementedError("MultiInput.load_forcings_state() not implemented")
         inputs = defaultdict(list)
         for v in variables:
             if v not in self.inputs:
@@ -54,7 +55,7 @@ def load_forcings(self, variables, dates):
         rows = {}

         for input, vs in inputs.items():
-            array = self._input_per_id[input].load_forcings(variables=vs, dates=dates)
+            array = self._input_per_id[input].load_forcings_state(variables=vs, dates=dates)
             for i, v in enumerate(vs):
                 rows[v] = array[i]
diff --git a/src/anemoi/inference/inputs/templates.py b/src/anemoi/inference/inputs/templates.py
index 3f6ec75..8c6ea45 100644
--- a/src/anemoi/inference/inputs/templates.py
+++ b/src/anemoi/inference/inputs/templates.py
@@ -83,5 +83,5 @@ def template(self, variable, date, **kwargs):

         return ekd.from_source("memory", template)[0]

-    def load_forcings(self, variables, dates):
-        raise NotImplementedError("TemplatesInput.load_forcings() not implemented")
+    def load_forcings_state(self, *, variables, dates, current_state):
+        raise NotImplementedError("TemplatesInput.load_forcings_state() not implemented")
diff --git a/src/anemoi/inference/outputs/grib.py b/src/anemoi/inference/outputs/grib.py
index 63a3099..096fc4f 100644
--- a/src/anemoi/inference/outputs/grib.py
+++ b/src/anemoi/inference/outputs/grib.py
@@ -109,7 +109,7 @@ def write_state(self, state):
                     self.quiet.add(name)

                 variable_keys = variable.grib_keys.copy()
-                for key in ("class", "type", "stream", "expver", "date", "time", "step"):
+                for key in ("class", "type", "stream", "expver", "date", "time", "step", "domain"):
                     variable_keys.pop(key, None)

                 keys.update(variable_keys)
diff --git a/src/anemoi/inference/plugin.py b/src/anemoi/inference/plugin.py
index 7aad9ec..b7a2e92 100644
--- a/src/anemoi/inference/plugin.py
+++ b/src/anemoi/inference/plugin.py
@@ -35,8 +35,10 @@ def __init__(self, context, *, input_fields):
     def create_input_state(self, *, date):
         return self._create_input_state(self.input_fields, variables=None, date=date)

-    def load_forcings(self, *, variables, dates):
-        return self._load_forcings(self.input_fields, variables=variables, dates=dates)
+    def load_forcings_state(self, *, variables, dates, current_state):
+        return self._load_forcings_state(
+            self.input_fields, variables=variables, dates=dates, current_state=current_state
+        )

     def set_private_attributes(self, state, input_fields):
         input_fields = input_fields.order_by("valid_datetime")
diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py
index 9ff0d2a..dacf0c2 100644
--- a/src/anemoi/inference/runner.py
+++ b/src/anemoi/inference/runner.py
@@ -109,6 +109,9 @@ def checkpoint(self):

     def run(self, *, input_state, lead_time):

+        input_state = input_state.copy()
+        input_state["step"] = 0
+
         self.constant_forcings_inputs = self.checkpoint.constant_forcings_inputs(self, input_state)
         self.dynamic_forcings_inputs = self.checkpoint.dynamic_forcings_inputs(self, input_state)
         self.boundary_forcings_inputs = self.checkpoint.boundary_forcings_inputs(self, input_state)
@@ -145,22 +148,34 @@ def add_initial_forcings_to_input_state(self, input_state):

         # TODO: Check for user provided forcings

-        for source in self.constant_forcings_inputs:
+        # We may need different forcings for the initial conditions
+        initial_constant_forcings_inputs = self.initial_constant_forcings_inputs(self.constant_forcings_inputs)
+        initial_dynamic_forcings_inputs = self.initial_dynamic_forcings_inputs(self.dynamic_forcings_inputs)
+
+        for source in initial_constant_forcings_inputs:
             LOG.info("Constant forcings input: %s %s (%s)", source, source.variables, dates)
-            arrays = source.load_forcings(input_state, dates)
+            arrays = source.load_forcings_array(dates, input_state)
             for name, forcing in zip(source.variables, arrays):
                 assert isinstance(forcing, np.ndarray), (name, forcing)
                 fields[name] = forcing
                 self._input_kinds[name] = Kind(forcing=True, constant=True, **source.kinds)

-        for source in self.dynamic_forcings_inputs:
+        for source in initial_dynamic_forcings_inputs:
             LOG.info("Dynamic forcings input: %s %s (%s)", source, source.variables, dates)
-            arrays = source.load_forcings(input_state, dates)
+            arrays = source.load_forcings_array(dates, input_state)
             for name, forcing in zip(source.variables, arrays):
                 assert isinstance(forcing, np.ndarray), (name, forcing)
                 fields[name] = forcing
                 self._input_kinds[name] = Kind(forcing=True, constant=True, **source.kinds)

+    def initial_constant_forcings_inputs(self, constant_forcings_inputs):
+        # Give an opportunity to modify the forcings for the first step
+        return constant_forcings_inputs
+
+    def initial_dynamic_forcings_inputs(self, dynamic_forcings_inputs):
+        # Give an opportunity to modify the forcings for the first step
+        return dynamic_forcings_inputs
+
     def prepare_input_tensor(self, input_state, dtype=np.float32):

         if "latitudes" not in input_state:
@@ -253,8 +268,8 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):

         LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)

-        result = input_state.copy()  # We should not modify the input state
-        result["fields"] = dict()
+        new_state = input_state.copy()  # We should not modify the input state
+        new_state["fields"] = dict()

         start = input_state["date"]

@@ -279,7 +294,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
             date = start + step
             LOG.info("Forecasting step %s (%s)", step, date)

-            result["date"] = date
+            new_state["date"] = date

             # Predict next state of atmosphere
             with torch.autocast(device_type=self.device, dtype=self.autocast):
@@ -290,12 +305,13 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):

             # Update state
             for i in range(output.shape[1]):
-                result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
+                new_state["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]

             if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
                 self._print_output_tensor("Output tensor", output)

-            yield result
+            new_state["step"] = s + 1
+            yield new_state

             # Update tensor for next iteration

@@ -305,10 +321,8 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):

             del y_pred  # Recover memory

-            input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, input_state, date, check)
-            input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
-                input_tensor_torch, input_state, date, check
-            )
+            input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, new_state, date, check)
+            input_tensor_torch = self.add_boundary_forcings_to_input_tensor(input_tensor_torch, new_state, date, check)

             if not check.all():
                 # Not all variables have been updated
@@ -361,7 +375,7 @@ def add_dynamic_forcings_to_input_tensor(self, input_tensor_torch, state, date,

         for source in self.dynamic_forcings_inputs:
-            forcings = source.load_forcings(state, [date])  # shape: (variables, dates, values)
+            forcings = source.load_forcings_array([date], state)  # shape: (variables, dates, values)

             forcings = np.squeeze(forcings, axis=1)  # Drop the dates dimension

@@ -385,7 +399,7 @@ def add_boundary_forcings_to_input_tensor(self, input_tensor_torch, state, date,
         # batch is always 1
         sources = self.boundary_forcings_inputs
         for source in sources:
-            forcings = source.load_forcings(state, [date])  # shape: (variables, dates, values)
+            forcings = source.load_forcings_array([date], state)  # shape: (variables, dates, values)

             forcings = np.squeeze(forcings, axis=1)  # Drop the dates dimension

@@ -397,6 +411,10 @@ def add_boundary_forcings_to_input_tensor(self, input_tensor_torch, state, date,
         # TO DO: add some consistency checks as above
         return input_tensor_torch

+    def exchange_tensors(self, state):
+        # To be overridden by coupled models
+        pass
+
     def validate_input_state(self, input_state):

         if not isinstance(input_state, dict):
diff --git a/src/anemoi/inference/runners/simple.py b/src/anemoi/inference/runners/simple.py
index 5910fac..6d68582 100644
--- a/src/anemoi/inference/runners/simple.py
+++ b/src/anemoi/inference/runners/simple.py
@@ -30,7 +30,7 @@ def __init__(self, context, variables, mask):
         self.mask = mask
         self.kinds = dict(unknown=True)

-    def load_forcings(self, state, date):
+    def load_forcings_state(self, state, date):
         pass

diff --git a/src/anemoi/inference/tasks/input.py b/src/anemoi/inference/tasks/input.py
index 4fc136f..b2f00dd 100644
--- a/src/anemoi/inference/tasks/input.py
+++ b/src/anemoi/inference/tasks/input.py
@@ -29,17 +29,6 @@ def __init__(self, name, config):
         super().__init__(name)
         self.config = load_config(config, [])

-    # def run(self, transport):
-    #     transport.dispatch(
-    #         self,
-    #         {
-    #             "load_forcings": self.load_forcings,
-    #         },
-    #     )
-
-    # def load_forcings(self, variables, dates):
-    #     assert False, (variables, dates)
-
     def run(self, transport):
         LOG.info("Running task %s", self.name)
         couplings = transport.couplings(self)
@@ -58,7 +47,7 @@ def run(self, transport):
         while date <= last:
             LOG.info(f"=============== Loading: {dates}")
             for input in inputs:
-                tensor = input.load_forcings({}, dates)
+                tensor = input.load_forcings_state({}, dates)
                 LOG.info(f"Sending matrix: {tensor.shape} {tensor.size * tensor.itemsize}")
                 for c in couplings:
                     c.apply(self, transport, tensor, tag=tag)
diff --git a/src/anemoi/inference/tasks/runner.py b/src/anemoi/inference/tasks/runner.py
index 3854d9c..5123455 100644
--- a/src/anemoi/inference/tasks/runner.py
+++ b/src/anemoi/inference/tasks/runner.py
@@ -10,8 +10,6 @@

 import logging

-import numpy as np
-
 from anemoi.inference.config import load_config
 from anemoi.inference.forcings import CoupledForcings
 from anemoi.inference.runners.default import DefaultRunner
@@ -33,6 +31,14 @@ def create_dynamic_coupled_forcings(self, variables, mask):
         result = CoupledForcings(self, self.input, variables, mask)
         return [result]

+    def initial_dynamic_forcings_inputs(self, dynamic_forcings_inputs):
+        # For the initial state we need to load the forcings
+        # from the default input.
+        result = []
+        for c in dynamic_forcings_inputs:
+            result.extend(super().create_dynamic_coupled_forcings(c.variables, c.mask))
+        return result
+

 class CoupledInput:
     """_summary_"""
@@ -41,19 +47,18 @@ def __init__(self, task, transport, couplings):
         self.task = task
         self.transport = transport
         self.couplings = couplings
-        self.tag = 0

-    # def load_forcings(self, variables, dates):
-    #     return self.transport.rpc(self.task, "load_forcings", variables, dates)
+    def load_forcings_state(self, *, variables, dates, current_state):
+        LOG.info("Adding dynamic forcings %s %s", variables, dates)
+        state = dict(variables=variables, dates=dates)

-    def load_forcings(self, variables, dates):
-        LOG.info("Adding dynamic forcings %s %s", len(variables), len(dates))
-        tensor = self.tensor = np.zeros(shape=(11, len(dates), 40320), dtype=np.float32)
         for c in self.couplings:
-            c.apply(self.task, self.transport, tensor, tag=self.tag)
+            c.apply(self.task, self.transport, input_state=current_state, output_state=state)
+
+        if len(dates) == 1:
+            state["date"] = dates[0]

-        self.tag += 1
-        return tensor
+        return state


 @task_registry.register("runner")
diff --git a/src/anemoi/inference/transport.py b/src/anemoi/inference/transport.py
index 2292dd3..860d7d6 100644
--- a/src/anemoi/inference/transport.py
+++ b/src/anemoi/inference/transport.py
@@ -15,28 +15,31 @@
 class Coupling:
     """_summary_"""

-    def __init__(self, source, sidx, target, tidx):
+    def __init__(self, source, target, variables):
         self.source = source
-        self.sidx = sidx
         self.target = target
-        self.tidx = tidx
+        self.variables = variables

     def __str__(self):
-        return f"{self.source}:{self.sidx}->{self.target}:{self.tidx}"
+        return f"{self.source}->{self.target}"


 class CouplingSend(Coupling):
     """_summary_"""

-    def apply(self, task, transport, tensor, tag):
-        transport.send_array(task, tensor[self.sidx], self.target, tag)
+    def apply(self, task, transport, *, input_state, output_state):
+        transport.send_state(
+            task, self.target, input_state=input_state, output_state=output_state, variables=self.variables
+        )


 class CouplingRecv(Coupling):
     """_summary_"""

-    def apply(self, task, transport, tensor, tag):
-        transport.receive_array(task, tensor[self.tidx], self.source, tag)
+    def apply(self, task, transport, *, input_state, output_state):
+        transport.receive_state(
+            task, self.source, input_state=input_state, output_state=output_state, variables=self.variables
+        )


 class Transport(ABC):
@@ -54,37 +57,38 @@ def couplings(self, task):
         couplings = []
         for coupling in self._couplings:
-            source, target = coupling.split("->")
-            source, sidx = source.strip().split(":")
-            target, tidx = target.strip().split(":")
+            assert isinstance(coupling, dict)
+            assert len(coupling) == 1
+            k, variables = list(coupling.items())[0]
+            source, target = k.split("->")
+            source = source.strip()
+            target = target.strip()

             if task.name == source:
                 couplings.append(
                     CouplingSend(
                         self.tasks[source],
-                        int(sidx),
                         self.tasks[target],
-                        int(tidx),
+                        variables,
                     )
                 )
             if task.name == target:
                 couplings.append(
                     CouplingRecv(
                         self.tasks[source],
-                        int(sidx),
                         self.tasks[target],
-                        int(tidx),
+                        variables,
                     )
                 )

         return couplings

     # @abstractmethod
-    # def send_array(self, data, destination):
+    # def send_state(self, data, destination):
     #     """_summary_"""
     #     pass

     # @abstractmethod
-    # def receive_array(self, source):
+    # def receive_state(self, source):
     #     """_summary_"""
     #     pass
diff --git a/src/anemoi/inference/transports/mpi.py b/src/anemoi/inference/transports/mpi.py
index 806000d..802979f 100644
--- a/src/anemoi/inference/transports/mpi.py
+++ b/src/anemoi/inference/transports/mpi.py
@@ -44,13 +44,13 @@ def start(self):
     def wait(self):
         self.comm.barrier()

-    def send_array(self, sender, tensor, target, tag):
+    def send_state(self, sender, tensor, target, tag):
         assert sender.name != target.name, f"Cannot send to self {sender}"
         LOG.info(f"{sender}: sending to {target} {tag}")
         self.comm.Send(tensor, dest=self.ranks[target.name], tag=tag)
         LOG.info(f"{sender}: sent to {target} {tag}")

-    def receive_array(self, receiver, tensor, source, tag):
+    def receive_state(self, receiver, tensor, source, tag):
         assert receiver.name != source.name, f"Cannot receive from self {receiver}"
         LOG.info(f"{receiver}: receiving from {source} {tag}")
         self.comm.Recv(tensor, source=self.ranks[source.name], tag=tag)
diff --git a/src/anemoi/inference/transports/processes.py b/src/anemoi/inference/transports/processes.py
index 8c676ba..4fba573 100644
--- a/src/anemoi/inference/transports/processes.py
+++ b/src/anemoi/inference/transports/processes.py
@@ -14,7 +14,8 @@
 import select
 import struct

-import numpy as np
+from anemoi.utils.logs import enable_logging_name
+from anemoi.utils.logs import set_logging_name

 from ..transport import Transport
 from . import transport_registry
@@ -29,8 +30,10 @@ class ProcessesTransport(Transport):
     def __init__(self, couplings, rpcs, tasks, *args, **kwargs):
         super().__init__(couplings, rpcs, tasks)
         self.children = {}
+        enable_logging_name("main")

     def child_process(self, task):
+        set_logging_name(task.name)

         # Close all the pipes that are not needed
         for (task1, task2), (read_fd, write_fd) in self.pipes.items():
@@ -63,7 +66,7 @@ def start(self):
             if pid == 0:
                 os._exit(self.child_process(task))
             else:
-                self.children[name] = pid
+                self.children[pid] = name

         # We need to close the pipes in the parent process
         for read_fd, write_fd in self.pipes.values():
@@ -71,38 +74,75 @@ def start(self):
             os.close(write_fd)

     def wait(self):
-        for name, pid in self.children.items():
-            if os.waitpid(pid, 0) != (pid, 0):
-                raise RuntimeError("Child process failed %s %s" % (name, pid))
+        while self.children:
+            (pid, status) = os.wait()
+            LOG.info(f"Child process {pid} ({self.children[pid]}) exited with status {status}")
+            del self.children[pid]
+
+            if status != 0:
+                for pid in self.children:
+                    os.kill(pid, 15)
+
+    def send_state(self, sender, target, *, input_state, output_state, variables):
+
+        assert isinstance(input_state, dict)

-    def send_array(self, sender, tensor, target, tag):
         assert sender.name != target.name, f"Cannot send to self {sender}"
         _, write_fd = self.pipes[(sender.name, target.name)]
-        os.write(write_fd, "a".encode())  # a for array

+        fields = input_state["fields"]
+
+        LOG.info(f"{sender}: sending to {target} {variables}")
+
+        fields = {v: fields[v] for v in variables if v in fields}
+
+        state = input_state.copy()
+        state["fields"] = fields
+
+        # Don't send unnecessary data
+        state["latitudes"] = None
+        state["longitudes"] = None
+        for s in list(state.keys()):
+            if s.startswith("_"):
+                del state[s]
+
+        # TODO: something more efficient than pickle
+
+        pickle_data = pickle.dumps(state)

-        header = np.array([tag, tensor.size * tensor.itemsize], dtype=np.uint64)
+        os.write(write_fd, struct.pack("!Q", len(pickle_data)))
+        os.write(write_fd, pickle_data)

-        os.write(write_fd, header.tobytes())
-        os.write(write_fd, tensor.tobytes())
+    def receive_state(self, receiver, source, *, input_state, output_state, variables):

-    def receive_array(self, receiver, tensor, source, tag):
         assert receiver.name != source.name, f"Cannot receive from self {receiver}"
-        tag = np.uint64(tag)
         read_fd, _ = self.pipes[(source.name, receiver.name)]
-        code = os.read(read_fd, 1).decode()
-        assert code == "a", f"Expected array got {code}"
-
-        # Read the data
-        header = os.read(read_fd, np.dtype(np.uint64).itemsize * 2)
-        header = np.frombuffer(header, dtype=np.uint64)
-        assert tag == header[0]
-        size = header[1]
-        LOG.info(f"{receiver}: receiving from {source} {tag} {size} {tensor.size * tensor.itemsize} {tensor.shape}")
-        tensor[:] = np.ndarray(buffer=os.read(read_fd, size), dtype=tensor.dtype, shape=tensor.shape)
-        LOG.info(f"{receiver}: received from {source} {tag}")
+        size = struct.unpack("!Q", os.read(read_fd, 8))[0]
+        data = os.read(read_fd, size)
+        state = pickle.loads(data)
+        if isinstance(state, Exception):
+            raise state
+
+        assert isinstance(state, dict)
+        assert input_state["date"] == state["date"]
+        assert "fields" in state
+        assert isinstance(state["fields"], dict), f"Expected dict got {type(state['fields'])}"
+
+        output_state.setdefault("fields", {})
+
+        fields_in = state["fields"]
+        fields_out = output_state["fields"]
+
+        for v in variables:
+            if v in fields_out:
+                raise ValueError(f"Variable {v} already in output state")
+
+            if v not in fields_in:
+                raise ValueError(f"Variable {v} not in input state")
+
+            fields_out[v] = fields_in[v]

     def rpc(self, sender, proc, *args, **kwargs):

diff --git a/src/anemoi/inference/transports/threads.py b/src/anemoi/inference/transports/threads.py
index e02ce5d..14bca99 100644
--- a/src/anemoi/inference/transports/threads.py
+++ b/src/anemoi/inference/transports/threads.py
@@ -64,13 +64,13 @@ def wait(self):
             if wrapped_task.error:
                 raise wrapped_task.error

-    def send_array(self, sender, tensor, target, tag):
+    def send_state(self, sender, tensor, target, tag):
         assert sender.name != target.name, f"Cannot send to self {sender}"
         LOG.info(f"{sender}: sending to {target} {tag}")
         self.wrapped_tasks[target.name].queue.put((sender.name, tensor, tag))
         LOG.info(f"{sender}: sent to {target} {tag}")

-    def receive_array(self, receiver, tensor, source, tag):
+    def receive_state(self, receiver, tensor, source, tag):
         assert receiver.name != source.name, f"Cannot receive from self {receiver}"
         LOG.info(f"{receiver}: receiving from {source} {tag} (backlog: {len(self.backlogs[receiver.name])})")