diff --git a/android_world/suite_utils.py b/android_world/suite_utils.py index 03c2f62..4c65efa 100644 --- a/android_world/suite_utils.py +++ b/android_world/suite_utils.py @@ -308,6 +308,7 @@ def _run_task_suite( checkpointer: checkpointer_lib.Checkpointer = checkpointer_lib.NullCheckpointer(), demo_mode: bool = False, agent_name: str = '', + return_full_episode_data: bool = False, ) -> list[dict[str, Any]]: """Runs e2e system on suite. @@ -318,6 +319,8 @@ def _run_task_suite( checkpointer: See docstring from `run`. demo_mode: Whether to display the scoreboard. agent_name: The name of the agent. + return_full_episode_data: Whether to return full episode data instead of + just metadata. Returns: Metadata for each episode, including the scripted reward. @@ -334,7 +337,12 @@ def _run_task_suite( completed_tasks, failed_tasks = _get_task_info( checkpointer.load(fields=metadata_fields) ) + if (completed_tasks or failed_tasks) and return_full_episode_data: + raise ValueError( + 'Cannot return full episode data when resuming from a checkpoint.' + ) episodes_metadata: list[dict[str, Any]] = [] + full_episode_data = [] correct, total = 0, 0 for name, instances in suite.items(): msg = 'Running task: ' + name @@ -364,6 +372,9 @@ def _run_task_suite( episode[constants.EpisodeConstants.INSTANCE_ID] = i checkpointer.save_episodes([episode], instance_name) + if return_full_episode_data: + full_episode_data.append(episode) + episodes_metadata.append({k: episode[k] for k in metadata_fields}) process_episodes(episodes_metadata, print_summary=True) @@ -376,7 +387,7 @@ def _run_task_suite( _update_scoreboard(correct, total, env.controller) print() - return episodes_metadata + return full_episode_data if return_full_episode_data else episodes_metadata def run( @@ -384,6 +395,7 @@ def run( agent: base_agent.EnvironmentInteractingAgent, checkpointer: checkpointer_lib.Checkpointer = checkpointer_lib.NullCheckpointer(), demo_mode: bool = False, + return_full_episode_data: bool = False, ) -> list[dict[str, Any]]: """Create suite and runs eval suite. @@ -396,6 +408,8 @@ def run( are executed. demo_mode: Whether to run in demo mode, which displays a scoreboard and the task instruction as a notification. + return_full_episode_data: Whether to return full episode data instead of + just metadata. Returns: Step-by-step data from each episode. @@ -431,6 +445,7 @@ def run_episode(task: task_eval.TaskEval) -> episode_runner.EpisodeResult: checkpointer=checkpointer, demo_mode=demo_mode, agent_name=agent.name, + return_full_episode_data=return_full_episode_data, ) return results