Skip to content

Commit

Permalink
Update types and propagate callback.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed May 23, 2024
1 parent d757552 commit a82d7e6
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions flybody/agents/ray_distributed_dmpo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Classes for DMPO agent distributed with Ray."""

from typing import Optional, Iterator, Callable
from typing import Iterator, Callable
import socket
import dataclasses
import copy
Expand Down Expand Up @@ -61,6 +61,7 @@ class DMPOConfig:
replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE
print_fn: Callable = logging.info
userdata: dict | None = None
actor_observation_callback: Callable | None = None


class ReplayServer():
Expand Down Expand Up @@ -241,10 +242,9 @@ def __init__(
dmpo_config,
actor_or_evaluator='actor',
label=None,
ray_head_node_ip: Optional[str] = None,
egl_device_id_head_node: Optional[list] = None, # ['1', '2', '3']
egl_device_id_worker_node: Optional[
list] = None, # ['0', '1', '2', '3']
ray_head_node_ip: str | None = None,
egl_device_id_head_node: list | None = None,
egl_device_id_worker_node: list | None = None,
):
"""The actor process."""

Expand All @@ -258,12 +258,8 @@ def __init__(
running_on_head_node = True
break
if running_on_head_node:
# egl_device_id = egl_device_id_head_node[
# actor_count % len(egl_device_id_head_node)]
egl_device_id = np.random.choice(egl_device_id_head_node)
else:
# egl_device_id = egl_device_id_worker_node[
# actor_count % len(egl_device_id_worker_node)]
egl_device_id = np.random.choice(egl_device_id_worker_node)
os.environ['MUJOCO_EGL_DEVICE_ID'] = str(egl_device_id)

Expand Down Expand Up @@ -314,9 +310,11 @@ def wrapped_network_factory(action_spec):
save_data = self._config.logger_save_csv_data

# Create the agent.
actor = self._make_actor(policy_network=policy_network,
adder=adder,
variable_source=variable_source)
actor = self._make_actor(
policy_network=policy_network,
adder=adder,
variable_source=variable_source,
observation_callback=self._config.actor_observation_callback)

# Create logger and counter; actors will not spam bigtable.
counter = counting.Counter(parent=counter, prefix=actor_or_evaluator)
Expand Down Expand Up @@ -347,8 +345,9 @@ def isready(self):
def _make_actor(
self,
policy_network: snt.Module,
adder: Optional[adders.Adder] = None,
variable_source: Optional[core.VariableSource] = None,
adder: adders.Adder | None = None,
variable_source: core.VariableSource | None = None,
observation_callback: Callable | None = None,
):
"""Create an actor instance."""
if variable_source:
Expand All @@ -369,7 +368,8 @@ def _make_actor(
return DelayedFeedForwardActor(policy_network=policy_network,
adder=adder,
variable_client=variable_client,
action_delay=None)
action_delay=None,
observation_callback=observation_callback)

def _make_adder(self, replay_client: reverb.Client) -> adders.Adder:
"""Create an adder which records data generated by the actor/environment."""
Expand Down

0 comments on commit a82d7e6

Please sign in to comment.