Skip to content

Commit

Permalink
Pass BaseSimulator and TaskManager directly to AndroidEnv.
Browse files Browse the repository at this point in the history
This is the first of a series of changes to keep the `AndroidEnv` and
`TaskManager` classes focused on RL interactions (e.g. rewards, begin/end of
episodes, resetting etc), while making the rest (e.g. `BaseSimulator`,
`Coordinator`) more independent and easier to use in other domains such as
LLMs.

For now, only `AndroidEnv.stats()` has been changed to minimize diffs, but more
will slowly come.

PiperOrigin-RevId: 683192156
  • Loading branch information
kenjitoyama authored and copybara-github committed Oct 7, 2024
1 parent 30d925c commit 7482480
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
5 changes: 2 additions & 3 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
Args:
simulator: A BaseSimulator instance.
task_manager: The TaskManager, responsible for coordinating RL tasks.
config: Settings to customize this Coordinator.
"""
self._simulator = simulator
self._task_manager = task_manager
Expand Down Expand Up @@ -453,9 +454,7 @@ def _get_time_since_last_observation(self) -> float:
def stats(self) -> dict[str, Any]:
"""Returns various statistics."""

output = copy.deepcopy(self._stats)
output.update(self._task_manager.stats())
return output
return copy.deepcopy(self._stats)

def load_state(
self, request: state_pb2.LoadStateRequest
Expand Down
16 changes: 13 additions & 3 deletions android_env/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,28 @@
from absl import logging
from android_env import env_interface
from android_env.components import coordinator as coordinator_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators import base_simulator
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
from android_env.proto import task_pb2
import dm_env
import numpy as np


class AndroidEnv(env_interface.AndroidEnvInterface):
"""An RL environment that interacts with Android apps."""

def __init__(self, coordinator: coordinator_lib.Coordinator):
def __init__(
self,
simulator: base_simulator.BaseSimulator,
coordinator: coordinator_lib.Coordinator,
task_manager: task_manager_lib.TaskManager,
):
"""Initializes the state of this AndroidEnv object."""

self._simulator = simulator
self._coordinator = coordinator
self._task_manager = task_manager
self._latest_action = {}
self._latest_observation = {}
self._latest_extras = {}
Expand Down Expand Up @@ -133,7 +141,9 @@ def raw_observation(self):
return self._latest_observation.copy()

def stats(self) -> dict[str, Any]:
return self._coordinator.stats()
coordinator_stats = self._coordinator.stats()
task_manager_stats = self._task_manager.stats()
return coordinator_stats | task_manager_stats

def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
return self._coordinator.execute_adb_call(call)
Expand Down
60 changes: 49 additions & 11 deletions android_env/environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

from absl.testing import absltest
from android_env import environment
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators.fake import fake_simulator
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
from android_env.proto import task_pb2
import dm_env
import numpy as np

Expand All @@ -47,7 +49,14 @@ def _create_mock_coordinator() -> coordinator_lib.Coordinator:
class AndroidEnvTest(absltest.TestCase):

def test_specs(self):
env = environment.AndroidEnv(_create_mock_coordinator())
simulator = fake_simulator.FakeSimulator(
config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456))
)
coordinator = _create_mock_coordinator()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
)

# Check action spec.
self.assertNotEmpty(env.action_spec())
Expand Down Expand Up @@ -77,7 +86,11 @@ def test_specs(self):
self.assertEqual(env.observation_spec()['orientation'].shape, (4,))

def test_reset_and_step(self):
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
simulator = fake_simulator.FakeSimulator(
config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456))
)
coordinator = _create_mock_coordinator()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
coordinator.action_spec.return_value = {
'action_type':
dm_env.specs.DiscreteArray(num_values=3),
Expand All @@ -90,7 +103,9 @@ def test_reset_and_step(self):
'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64),
'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8),
}
env = environment.AndroidEnv(coordinator)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
)
coordinator.rl_reset.return_value = dm_env.TimeStep(
step_type=dm_env.StepType.FIRST,
reward=0.0,
Expand Down Expand Up @@ -125,9 +140,8 @@ def test_reset_and_step(self):
self.assertIn('click', extras)
self.assertEqual(extras['click'], np.array([246], dtype=np.int64))

coordinator.stats.return_value = {
'my_measurement': 135,
}
coordinator.stats.return_value = {'my_measurement': 135}
task_manager.stats.return_value = {'another_measurement': 79}

# Step again in the environment and check expectations again.
pixels = np.random.rand(987, 654, 3)
Expand Down Expand Up @@ -189,8 +203,14 @@ def test_reset_and_step(self):
np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))

def test_adb_call(self):
simulator = fake_simulator.FakeSimulator(
config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456))
)
coordinator = _create_mock_coordinator()
env = environment.AndroidEnv(coordinator)
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
)
call = adb_pb2.AdbRequest(
force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah'))
expected_response = adb_pb2.AdbResponse(
Expand All @@ -203,8 +223,14 @@ def test_adb_call(self):
coordinator.execute_adb_call.assert_called_once_with(call)

def test_load_state(self):
simulator = fake_simulator.FakeSimulator(
config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456))
)
coordinator = _create_mock_coordinator()
env = environment.AndroidEnv(coordinator)
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
)
expected_response = state_pb2.LoadStateResponse(
status=state_pb2.LoadStateResponse.Status.OK
)
Expand All @@ -215,8 +241,14 @@ def test_load_state(self):
coordinator.load_state.assert_called_once_with(request)

def test_save_state(self):
simulator = fake_simulator.FakeSimulator(
config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456))
)
coordinator = _create_mock_coordinator()
env = environment.AndroidEnv(coordinator)
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
)
expected_response = state_pb2.SaveStateResponse(
status=state_pb2.SaveStateResponse.Status.OK
)
Expand All @@ -227,8 +259,14 @@ def test_save_state(self):
coordinator.save_state.assert_called_once_with(request)

def test_double_close(self):
simulator = fake_simulator.FakeSimulator(
config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456))
)
coordinator = _create_mock_coordinator()
env = environment.AndroidEnv(coordinator)
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
)
env.close()
env.close()
coordinator.close.assert_called_once()
Expand Down
4 changes: 3 additions & 1 deletion android_env/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv:
raise ValueError('Unsupported simulator config: {config.simulator}')

coordinator = coordinator_lib.Coordinator(simulator, task_manager)
return environment.AndroidEnv(coordinator=coordinator)
return environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
)


def _process_emulator_launcher_config(
Expand Down

0 comments on commit 7482480

Please sign in to comment.