From 919b1ab138a8f981c5674771224d82e6d1c54949 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 6 Dec 2024 11:48:29 +0000 Subject: [PATCH] update --- src/anemoi/inference/forcings.py | 25 ++++-- src/anemoi/inference/runner.py | 4 - src/anemoi/inference/tasks/runner.py | 20 +++-- src/anemoi/inference/transport.py | 15 +++- src/anemoi/inference/transports/processes.py | 82 ++++---------------- 5 files changed, 61 insertions(+), 85 deletions(-) diff --git a/src/anemoi/inference/forcings.py b/src/anemoi/inference/forcings.py index 159f363..9699fb1 100644 --- a/src/anemoi/inference/forcings.py +++ b/src/anemoi/inference/forcings.py @@ -44,11 +44,20 @@ def _state_to_numpy(self, state, variables, dates) -> np.ndarray: And that the fields values are sorted by dates. """ fields = state["fields"] - result = np.stack([fields[v] for v in variables], axis=0) - assert result.shape[:2] == ( - len(variables), - len(dates), - ), (result.shape, variables, dates) + + if len(dates) == 1: + result = np.stack([fields[v] for v in variables], axis=0) + if len(result.shape) == 2: + result = result[:, np.newaxis] + else: + result = np.stack([fields[v] for v in variables], axis=0) + + assert len(result.shape) == 3 and result.shape[0] == len(variables) and result.shape[1] == len(dates), ( + result.shape, + variables, + dates, + ) + return result @@ -107,7 +116,11 @@ def __repr__(self): def load_forcings_array(self, dates, current_state): return self._state_to_numpy( - self.input.load_forcings_state(variables=self.variables, dates=dates, current_state=current_state), + self.input.load_forcings_state( + variables=self.variables, + dates=dates, + current_state=current_state, + ), self.variables, dates, ) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index dacf0c2..e3ac693 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -411,10 +411,6 @@ def add_boundary_forcings_to_input_tensor(self, input_tensor_torch, state, date, # TO DO: add some consistency checks as above return input_tensor_torch - def exchange_tensors(self, state): - # To be overriden by coupled models - pass - def validate_input_state(self, input_state): if not isinstance(input_state, dict): diff --git a/src/anemoi/inference/tasks/runner.py b/src/anemoi/inference/tasks/runner.py index 5123455..99995ca 100644 --- a/src/anemoi/inference/tasks/runner.py +++ b/src/anemoi/inference/tasks/runner.py @@ -47,19 +47,27 @@ def __init__(self, task, transport, couplings): self.task = task self.transport = transport self.couplings = couplings + self.constants = {} def load_forcings_state(self, *, variables, dates, current_state): LOG.info("Adding dynamic forcings %s %s", variables, dates) - state = dict(variables=variables, dates=dates) + state = dict(variables=variables, date=dates) for c in self.couplings: - c.apply(self.task, self.transport, input_state=current_state, output_state=state) + c.apply(self.task, self.transport, input_state=current_state, output_state=state, constants=self.constants) - if len(dates) == 1: - state["date"] = dates[0] + for f, v in state["fields"].items(): + assert len(v.shape) == 1, (f, v.shape) + + assert state["date"] == dates, (state["date"], dates) return state + def initial_state(self, state): + # We want to copy the constants that may be requested by the other tasks + # For now, we keep it simple and just copy the whole state + self.constants = state["fields"].copy() + @task_registry.register("runner") class RunnerTask(Task): @@ -74,11 +82,13 @@ def run(self, transport): LOG.info("Running task %s", self.name) couplings = transport.couplings(self) - runner = CoupledRunner(self.config, CoupledInput(self, transport, couplings)) + coupler = CoupledInput(self, transport, couplings) + runner = CoupledRunner(self.config, coupler) input = runner.create_input() output = runner.create_output() input_state = input.create_input_state(date=self.config.date) + coupler.initial_state(output.reduce(input_state)) if self.config.write_initial_state: output.write_initial_state(input_state) diff --git a/src/anemoi/inference/transport.py b/src/anemoi/inference/transport.py index 860d7d6..1af55d1 100644 --- a/src/anemoi/inference/transport.py +++ b/src/anemoi/inference/transport.py @@ -27,18 +27,25 @@ def __str__(self): class CouplingSend(Coupling): """_summary_""" - def apply(self, task, transport, *, input_state, output_state): + def apply(self, task, transport, *, input_state, output_state, constants): transport.send_state( - task, self.target, input_state=input_state, output_state=output_state, variables=self.variables + task, + self.target, + input_state=input_state, + variables=self.variables, + constants=constants, ) class CouplingRecv(Coupling): """_summary_""" - def apply(self, task, transport, *, input_state, output_state): + def apply(self, task, transport, *, input_state, output_state, constants): transport.receive_state( - task, self.source, input_state=input_state, output_state=output_state, variables=self.variables + task, + self.source, + output_state=output_state, + variables=self.variables, ) diff --git a/src/anemoi/inference/transports/processes.py b/src/anemoi/inference/transports/processes.py index 4fba573..ce30104 100644 --- a/src/anemoi/inference/transports/processes.py +++ b/src/anemoi/inference/transports/processes.py @@ -11,7 +11,6 @@ import logging import os import pickle -import select import struct from anemoi.utils.logs import enable_logging_name @@ -83,7 +82,7 @@ def wait(self): for pid in self.children: os.kill(pid, 15) - def send_state(self, sender, target, *, input_state, output_state, variables): + def send_state(self, sender, target, *, input_state, variables, constants): assert isinstance(input_state, dict) @@ -92,10 +91,22 @@ def send_state(self, sender, target, *, input_state, output_state, variables): fields = input_state["fields"] - LOG.info(f"{sender}: sending to {target} {variables}") + LOG.info(f"{sender}: sending to {target} {variables} {input_state['date']}") fields = {v: fields[v] for v in variables if v in fields} + for v in variables: + if v not in fields: + # Check in the constants + if v in constants: + LOG.warning(f"{sender}: {v} not in fields, using the value from constants") + fields[v] = constants[v] + else: + raise ValueError(f"{sender}: Variable {v} not in fields or constants") + + for f, v in fields.items(): + assert len(v.shape) == 1, f"Expected got {v.shape}" + state = input_state.copy() state["fields"] = fields @@ -113,7 +124,7 @@ def send_state(self, sender, target, *, input_state, output_state, variables): os.write(write_fd, struct.pack("!Q", len(pickle_data))) os.write(write_fd, pickle_data) - def receive_state(self, receiver, source, *, input_state, output_state, variables): + def receive_state(self, receiver, source, *, output_state, variables): assert receiver.name != source.name, f"Cannot receive from self {receiver}" @@ -126,7 +137,6 @@ def receive_state(self, receiver, source, *, input_state, output_state, variable raise state assert isinstance(state, dict) - assert input_state["date"] == state["date"] assert "fields" in state assert isinstance(state["fields"], dict), f"Expected dict got {type(state['fields'])}" @@ -144,64 +154,4 @@ def receive_state(self, receiver, source, *, input_state, output_state, variable fields_out[v] = fields_in[v] - def rpc(self, sender, proc, *args, **kwargs): - - target = self.rpcs[proc] - - assert sender.name != target, f"Cannot send to self {sender}" - _, write_fd = self.pipes[(sender.name, target)] - read_fd, _ = self.pipes[(target, sender.name)] - - LOG.info(f"{sender}: sending rpc {proc} to {target} {read_fd} {write_fd}") - - os.write(write_fd, "r".encode()) - data = pickle.dumps((proc, args, kwargs)) - os.write(write_fd, struct.pack("!I", len(data))) - os.write(write_fd, data) - - code = os.read(read_fd, 1).decode() - assert code == "r", f"Expected array got {code}" - size = struct.unpack("!I", os.read(read_fd, 4))[0] - data = os.read(read_fd, size) - result = pickle.loads(data) - if isinstance(result, Exception): - raise result - return result - - def dispatch(self, task, dispatcher): - LOG.info(f"{task}: waiting for messages {self.pipes} {task.name}") - while True: - fds = [fd[0] for (peers, fd) in self.pipes.items() if task.name == peers[1]] - remotes = {fd[0]: peers[0] for (peers, fd) in self.pipes.items() if task.name == peers[1]} - - if not fds: - LOG.info(f"{task}: no more messages") - break - - LOG.info(f"{task}: waiting on {fds}") - read_fds, _, _ = select.select(fds, [], []) - LOG.info(f"{task}: got message {read_fds}") - - for read_fd in read_fds: - - LOG.info(f"{task}: reading from {read_fd}, remote is {remotes[read_fd]}") - - code = os.read(read_fd, 1).decode() - assert code == "r", f"Expected array got {code}" - size = struct.unpack("!I", os.read(read_fd, 4))[0] - data = os.read(read_fd, size) - (proc, args, kwargs) = pickle.loads(data) - - LOG.info(f"{task}: received rpc {proc} {args} {kwargs}") - - try: - result = dispatcher[proc](*args, **kwargs) - except Exception as e: - LOG.exception(e) - result = e - - _, write_fd = self.pipes[(task.name, remotes[read_fd])] - os.write(write_fd, "r".encode()) - data = pickle.dumps(result) - os.write(write_fd, struct.pack("!I", len(data))) - os.write(write_fd, data) + assert len(fields_out[v].shape) == 1, f"Expected got {fields_out[v].shape}"