Skip to content

Commit

Permalink
Add step_hooks argument to train_agent with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
muupan committed May 2, 2017
1 parent beab149 commit 16f3494
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
20 changes: 15 additions & 5 deletions chainerrl/experiments/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=''):

def train_agent(agent, env, steps, outdir, max_episode_len=None,
step_offset=0, evaluator=None, successful_score=None,
logger=None):
step_hooks=[], logger=None):

logger = logger or logging.getLogger(__name__)

Expand All @@ -56,6 +56,9 @@ def train_agent(agent, env, steps, outdir, max_episode_len=None,
episode_r += r
episode_len += 1

for hook in step_hooks:
hook(env, agent, t)

if done or episode_len == max_episode_len or t == steps:
agent.stop_episode_and_train(obs, r, done=done)
logger.info('outdir:%s step:%s episode:%s R:%s',
Expand Down Expand Up @@ -90,7 +93,7 @@ def train_agent_with_evaluation(
agent, env, steps, eval_n_runs, eval_interval,
outdir, max_episode_len=None, step_offset=0, eval_explorer=None,
eval_max_episode_len=None, eval_env=None, successful_score=None,
render=False, logger=None):
render=False, step_hooks=[], logger=None):
"""Run a DQN-like agent.
Args:
Expand All @@ -106,6 +109,9 @@ def train_agent_with_evaluation(
eval_env: Environment used for evaluation.
successful_score (float): Finish training if the mean score is greater
or equal to this value if not None
step_hooks (list): List of callable objects that accepts
(env, agent, step) as arguments. They are called every step.
See chainerrl.experiments.hooks.
"""

logger = logger or logging.getLogger(__name__)
Expand All @@ -128,6 +134,10 @@ def train_agent_with_evaluation(
logger=logger)

train_agent(
agent, env, steps, outdir, max_episode_len=max_episode_len,
step_offset=step_offset, evaluator=evaluator,
successful_score=successful_score, logger=logger)
agent, env, steps, outdir,
max_episode_len=max_episode_len,
step_offset=step_offset,
evaluator=evaluator,
successful_score=successful_score,
step_hooks=step_hooks,
logger=logger)
55 changes: 55 additions & 0 deletions tests/experiments_tests/test_train_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases()
import tempfile
import unittest

import mock

import chainerrl


class TestTrainAgent(unittest.TestCase):
    """Check how train_agent drives the agent, the env and the step hooks."""

    def test(self):
        """Run train_agent for five steps against mocks and verify the calls.

        The mocked environment yields a single episode that terminates on
        its fifth step, so exactly one reset, five steps, five
        ``act_and_train`` calls, one ``stop_episode_and_train`` call and
        five hook invocations are expected.
        """
        # Imported locally so that this block does not rely on a new
        # module-level import.
        import shutil

        outdir = tempfile.mkdtemp()
        # mkdtemp() never deletes what it creates; register the removal
        # right away so the directory goes away even if an assertion fails.
        self.addCleanup(shutil.rmtree, outdir, ignore_errors=True)

        agent = mock.Mock()
        env = mock.Mock()
        # Reaches the terminal state after five actions
        env.reset.side_effect = [('state', 0)]
        env.step.side_effect = [
            (('state', 1), 0, False, {}),
            (('state', 2), 0, False, {}),
            (('state', 3), -0.5, False, {}),
            (('state', 4), 0, False, {}),
            (('state', 5), 1, True, {}),
        ]
        hook = mock.Mock()

        chainerrl.experiments.train_agent(
            agent=agent,
            env=env,
            steps=5,
            outdir=outdir,
            step_hooks=[hook])

        self.assertEqual(agent.act_and_train.call_count, 5)
        self.assertEqual(agent.stop_episode_and_train.call_count, 1)

        self.assertEqual(env.reset.call_count, 1)
        self.assertEqual(env.step.call_count, 5)

        self.assertEqual(hook.call_count, 5)
        # A hook receives (env, agent, step)
        for i, call in enumerate(hook.call_args_list):
            args, kwargs = call
            self.assertEqual(args[0], env)
            self.assertEqual(args[1], agent)
            # step starts with 1
            self.assertEqual(args[2], i + 1)

0 comments on commit 16f3494

Please sign in to comment.