Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 692978188
  • Loading branch information
Chris Rawles authored and The android_world Authors committed Nov 4, 2024
1 parent 8deb3ab commit 70f826d
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion android_world/suite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def _run_task_suite(
checkpointer: checkpointer_lib.Checkpointer = checkpointer_lib.NullCheckpointer(),
demo_mode: bool = False,
agent_name: str = '',
return_full_episode_data: bool = False,
) -> list[dict[str, Any]]:
"""Runs e2e system on suite.
Expand All @@ -318,6 +319,8 @@ def _run_task_suite(
checkpointer: See docstring from `run`.
demo_mode: Whether to display the scoreboard.
agent_name: The name of the agent.
return_full_episode_data: Whether to return full episode data instead of
just metadata.
Returns:
Metadata for each episode, including the scripted reward, or the full
episode data for each episode if `return_full_episode_data` is True.
Expand All @@ -334,7 +337,12 @@ def _run_task_suite(
completed_tasks, failed_tasks = _get_task_info(
checkpointer.load(fields=metadata_fields)
)
if (completed_tasks or failed_tasks) and return_full_episode_data:
raise ValueError(
'Cannot return full episode data when resuming from a checkpoint.'
)
episodes_metadata: list[dict[str, Any]] = []
full_episode_data = []
correct, total = 0, 0
for name, instances in suite.items():
msg = 'Running task: ' + name
Expand Down Expand Up @@ -364,6 +372,9 @@ def _run_task_suite(
episode[constants.EpisodeConstants.INSTANCE_ID] = i
checkpointer.save_episodes([episode], instance_name)

if return_full_episode_data:
full_episode_data.append(episode)

episodes_metadata.append({k: episode[k] for k in metadata_fields})
process_episodes(episodes_metadata, print_summary=True)

Expand All @@ -376,14 +387,15 @@ def _run_task_suite(
_update_scoreboard(correct, total, env.controller)
print()

return episodes_metadata
return full_episode_data if return_full_episode_data else episodes_metadata


def run(
suite: Suite,
agent: base_agent.EnvironmentInteractingAgent,
checkpointer: checkpointer_lib.Checkpointer = checkpointer_lib.NullCheckpointer(),
demo_mode: bool = False,
return_full_episode_data: bool = False,
) -> list[dict[str, Any]]:
"""Create suite and runs eval suite.
Expand All @@ -396,6 +408,8 @@ def run(
are executed.
demo_mode: Whether to run in demo mode, which displays a scoreboard and the
task instruction as a notification.
return_full_episode_data: Whether to return full episode data instead of
just metadata.
Returns:
Step-by-step data from each episode.
Expand Down Expand Up @@ -431,6 +445,7 @@ def run_episode(task: task_eval.TaskEval) -> episode_runner.EpisodeResult:
checkpointer=checkpointer,
demo_mode=demo_mode,
agent_name=agent.name,
return_full_episode_data=return_full_episode_data,
)

return results
Expand Down

0 comments on commit 70f826d

Please sign in to comment.