Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Dec 6, 2024
1 parent 49b804e commit 919b1ab
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 85 deletions.
25 changes: 19 additions & 6 deletions src/anemoi/inference/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,20 @@ def _state_to_numpy(self, state, variables, dates) -> np.ndarray:
And that the fields values are sorted by dates.
"""
fields = state["fields"]
result = np.stack([fields[v] for v in variables], axis=0)
assert result.shape[:2] == (
len(variables),
len(dates),
), (result.shape, variables, dates)

if len(dates) == 1:
result = np.stack([fields[v] for v in variables], axis=0)
if len(result.shape) == 2:
result = result[:, np.newaxis]
else:
result = np.stack([fields[v] for v in variables], axis=0)

assert len(result.shape) == 3 and result.shape[0] == len(variables) and result.shape[1] == len(dates), (
result.shape,
variables,
dates,
)

return result


Expand Down Expand Up @@ -107,7 +116,11 @@ def __repr__(self):

def load_forcings_array(self, dates, current_state):
return self._state_to_numpy(
self.input.load_forcings_state(variables=self.variables, dates=dates, current_state=current_state),
self.input.load_forcings_state(
variables=self.variables,
dates=dates,
current_state=current_state,
),
self.variables,
dates,
)
Expand Down
4 changes: 0 additions & 4 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,6 @@ def add_boundary_forcings_to_input_tensor(self, input_tensor_torch, state, date,
# TO DO: add some consistency checks as above
return input_tensor_torch

def exchange_tensors(self, state):
# To be overriden by coupled models
pass

def validate_input_state(self, input_state):

if not isinstance(input_state, dict):
Expand Down
20 changes: 15 additions & 5 deletions src/anemoi/inference/tasks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,27 @@ def __init__(self, task, transport, couplings):
self.task = task
self.transport = transport
self.couplings = couplings
self.constants = {}

def load_forcings_state(self, *, variables, dates, current_state):
LOG.info("Adding dynamic forcings %s %s", variables, dates)
state = dict(variables=variables, dates=dates)
state = dict(variables=variables, date=dates)

for c in self.couplings:
c.apply(self.task, self.transport, input_state=current_state, output_state=state)
c.apply(self.task, self.transport, input_state=current_state, output_state=state, constants=self.constants)

if len(dates) == 1:
state["date"] = dates[0]
for f, v in state["fields"].items():
assert len(v.shape) == 1, (f, v.shape)

assert state["date"] == dates, (state["date"], dates)

return state

def initial_state(self, state):
# We want to copy the constants that may be requested by the other tasks
# For now, we keep it simple and just copy the whole state
self.constants = state["fields"].copy()


@task_registry.register("runner")
class RunnerTask(Task):
Expand All @@ -74,11 +82,13 @@ def run(self, transport):
LOG.info("Running task %s", self.name)
couplings = transport.couplings(self)

runner = CoupledRunner(self.config, CoupledInput(self, transport, couplings))
coupler = CoupledInput(self, transport, couplings)
runner = CoupledRunner(self.config, coupler)
input = runner.create_input()
output = runner.create_output()

input_state = input.create_input_state(date=self.config.date)
coupler.initial_state(output.reduce(input_state))

if self.config.write_initial_state:
output.write_initial_state(input_state)
Expand Down
15 changes: 11 additions & 4 deletions src/anemoi/inference/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,25 @@ def __str__(self):
class CouplingSend(Coupling):
"""_summary_"""

def apply(self, task, transport, *, input_state, output_state):
def apply(self, task, transport, *, input_state, output_state, constants):
transport.send_state(
task, self.target, input_state=input_state, output_state=output_state, variables=self.variables
task,
self.target,
input_state=input_state,
variables=self.variables,
constants=constants,
)


class CouplingRecv(Coupling):
"""_summary_"""

def apply(self, task, transport, *, input_state, output_state):
def apply(self, task, transport, *, input_state, output_state, constants):
transport.receive_state(
task, self.source, input_state=input_state, output_state=output_state, variables=self.variables
task,
self.source,
output_state=output_state,
variables=self.variables,
)


Expand Down
82 changes: 16 additions & 66 deletions src/anemoi/inference/transports/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import logging
import os
import pickle
import select
import struct

from anemoi.utils.logs import enable_logging_name
Expand Down Expand Up @@ -83,7 +82,7 @@ def wait(self):
for pid in self.children:
os.kill(pid, 15)

def send_state(self, sender, target, *, input_state, output_state, variables):
def send_state(self, sender, target, *, input_state, variables, constants):

assert isinstance(input_state, dict)

Expand All @@ -92,10 +91,22 @@ def send_state(self, sender, target, *, input_state, output_state, variables):

fields = input_state["fields"]

LOG.info(f"{sender}: sending to {target} {variables}")
LOG.info(f"{sender}: sending to {target} {variables} {input_state['date']}")

fields = {v: fields[v] for v in variables if v in fields}

for v in variables:
if v not in fields:
# Check in the constants
if v in constants:
LOG.warning(f"{sender}: {v} not in fields, using the value from constants")
fields[v] = constants[v]
else:
raise ValueError(f"{sender}: Variable {v} not in fields or constants")

for f, v in fields.items():
assert len(v.shape) == 1, f"Expected got {v.shape}"

state = input_state.copy()
state["fields"] = fields

Expand All @@ -113,7 +124,7 @@ def send_state(self, sender, target, *, input_state, output_state, variables):
os.write(write_fd, struct.pack("!Q", len(pickle_data)))
os.write(write_fd, pickle_data)

def receive_state(self, receiver, source, *, input_state, output_state, variables):
def receive_state(self, receiver, source, *, output_state, variables):

assert receiver.name != source.name, f"Cannot receive from self {receiver}"

Expand All @@ -126,7 +137,6 @@ def receive_state(self, receiver, source, *, input_state, output_state, variable
raise state

assert isinstance(state, dict)
assert input_state["date"] == state["date"]
assert "fields" in state
assert isinstance(state["fields"], dict), f"Expected dict got {type(state['fields'])}"

Expand All @@ -144,64 +154,4 @@ def receive_state(self, receiver, source, *, input_state, output_state, variable

fields_out[v] = fields_in[v]

def rpc(self, sender, proc, *args, **kwargs):

target = self.rpcs[proc]

assert sender.name != target, f"Cannot send to self {sender}"
_, write_fd = self.pipes[(sender.name, target)]
read_fd, _ = self.pipes[(target, sender.name)]

LOG.info(f"{sender}: sending rpc {proc} to {target} {read_fd} {write_fd}")

os.write(write_fd, "r".encode())
data = pickle.dumps((proc, args, kwargs))
os.write(write_fd, struct.pack("!I", len(data)))
os.write(write_fd, data)

code = os.read(read_fd, 1).decode()
assert code == "r", f"Expected array got {code}"
size = struct.unpack("!I", os.read(read_fd, 4))[0]
data = os.read(read_fd, size)
result = pickle.loads(data)
if isinstance(result, Exception):
raise result
return result

def dispatch(self, task, dispatcher):
LOG.info(f"{task}: waiting for messages {self.pipes} {task.name}")
while True:
fds = [fd[0] for (peers, fd) in self.pipes.items() if task.name == peers[1]]
remotes = {fd[0]: peers[0] for (peers, fd) in self.pipes.items() if task.name == peers[1]}

if not fds:
LOG.info(f"{task}: no more messages")
break

LOG.info(f"{task}: waiting on {fds}")
read_fds, _, _ = select.select(fds, [], [])
LOG.info(f"{task}: got message {read_fds}")

for read_fd in read_fds:

LOG.info(f"{task}: reading from {read_fd}, remote is {remotes[read_fd]}")

code = os.read(read_fd, 1).decode()
assert code == "r", f"Expected array got {code}"
size = struct.unpack("!I", os.read(read_fd, 4))[0]
data = os.read(read_fd, size)
(proc, args, kwargs) = pickle.loads(data)

LOG.info(f"{task}: received rpc {proc} {args} {kwargs}")

try:
result = dispatcher[proc](*args, **kwargs)
except Exception as e:
LOG.exception(e)
result = e

_, write_fd = self.pipes[(task.name, remotes[read_fd])]
os.write(write_fd, "r".encode())
data = pickle.dumps(result)
os.write(write_fd, struct.pack("!I", len(data)))
os.write(write_fd, data)
assert len(fields_out[v].shape) == 1, f"Expected got {fields_out[v].shape}"

0 comments on commit 919b1ab

Please sign in to comment.