Skip to content

Commit

Permalink
Merge pull request #2 from google/master
Browse files Browse the repository at this point in the history
latest
  • Loading branch information
SwanseaLeo authored Jun 12, 2019
2 parents 897d3e4 + 4f9299d commit 3751065
Show file tree
Hide file tree
Showing 18 changed files with 1,088 additions and 24 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ which demonstrate how to use Dopamine.
This is not an official Google product.

## What's new
* **11/06/2019:** Visualization utilities added to generate videos and still
images of a trained agent interacting with its environment. See an example
colaboratory
[here](https://colab.research.google.com/github/google/dopamine/blob/master/dopamine/colab/agent_visualizer.ipynb).
* **30/01/2019:** Dopamine 2.0 now supports general discrete-domain gym
environments.
* **01/11/2018:** Download links for each individual checkpoint, to avoid
Expand Down
24 changes: 19 additions & 5 deletions dopamine/agents/dqn/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def __init__(self,
epsilon=0.00001,
centered=True),
summary_writer=None,
summary_writing_frequency=500):
summary_writing_frequency=500,
allow_partial_reload=False):
"""Initializes the agent and constructs the components of its graph.
Args:
Expand Down Expand Up @@ -142,6 +143,8 @@ def __init__(self,
Summary writing disabled if set to None.
summary_writing_frequency: int, frequency with which summaries will be
written. Lower values will result in slower training.
allow_partial_reload: bool, whether we allow reloading a partial agent
(for instance, only the network parameters).
"""
assert isinstance(observation_shape, tuple)
tf.logging.info('Creating %s agent with the following parameters:',
Expand All @@ -157,6 +160,8 @@ def __init__(self,
tf.logging.info('\t tf_device: %s', tf_device)
tf.logging.info('\t use_staging: %s', use_staging)
tf.logging.info('\t optimizer: %s', optimizer)
tf.logging.info('\t max_tf_checkpoints_to_keep: %d',
max_tf_checkpoints_to_keep)

self.num_actions = num_actions
self.observation_shape = tuple(observation_shape)
Expand All @@ -178,6 +183,7 @@ def __init__(self,
self.optimizer = optimizer
self.summary_writer = summary_writer
self.summary_writing_frequency = summary_writing_frequency
self.allow_partial_reload = allow_partial_reload

with tf.device(tf_device):
# Create a placeholder for the state input to the DQN network.
Expand Down Expand Up @@ -516,13 +522,21 @@ def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
"""
try:
# self._replay.load() will throw a NotFoundError if it does not find all
# the necessary files, in which case we abort the process & return False.
# the necessary files.
self._replay.load(checkpoint_dir, iteration_number)
except tf.errors.NotFoundError:
if not self.allow_partial_reload:
# If we don't allow partial reloads, we will return False.
return False
tf.logging.warning('Unable to reload replay buffer!')
if bundle_dictionary is not None:
for key in self.__dict__:
if key in bundle_dictionary:
self.__dict__[key] = bundle_dictionary[key]
elif not self.allow_partial_reload:
return False
for key in self.__dict__:
if key in bundle_dictionary:
self.__dict__[key] = bundle_dictionary[key]
else:
tf.logging.warning("Unable to reload the agent's parameters!")
# Restore the agent's TensorFlow graph.
self._saver.restore(self._sess,
os.path.join(checkpoint_dir,
Expand Down
6 changes: 6 additions & 0 deletions dopamine/colab/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ In this
[colab](https://colab.research.google.com/github/google/dopamine/blob/master/dopamine/colab/load_statistics.ipynb)
we illustrate how to load and visualize the logs data produced by Dopamine.

## Visualizing trained agents
In this
[colab](https://colab.research.google.com/github/google/dopamine/blob/master/dopamine/colab/agent_visualizer.ipynb)
we illustrate how to visualize a trained agent using the visualization utilities
provided with Dopamine.

## Visualizing with Tensorboard
In this
[colab](https://colab.research.google.com/github/google/dopamine/blob/master/dopamine/colab/tensorboard.ipynb)
Expand Down
199 changes: 199 additions & 0 deletions dopamine/colab/agent_visualizer.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dopamine/colab/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def read_experiment(log_path,
experiment_path, iteration_number=iteration_number, verbose=verbose)

summary = summarize_data(raw_data, summary_keys)
for iteration in range(last_iteration):
for iteration in range(last_iteration + 1):
# The row contains all the parameters, the iteration, and finally the
# requested values.
row_data = (list(parameter_tuple) + [iteration] +
Expand Down
2 changes: 1 addition & 1 deletion dopamine/discrete_domains/atari_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@


import atari_py
import gin
import gym
from gym.spaces.box import Box
import numpy as np
import tensorflow as tf

import gin.tf
import cv2

slim = tf.contrib.slim
Expand Down
10 changes: 9 additions & 1 deletion dopamine/discrete_domains/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,28 @@

import os
import pickle

import gin
import tensorflow as tf

CHECKPOINT_DURATION = 4


def get_latest_checkpoint_number(base_directory):
@gin.configurable
def get_latest_checkpoint_number(base_directory, override_number=None):
"""Returns the version number of the latest completed checkpoint.
Args:
base_directory: str, directory in which to look for checkpoint files.
override_number: None or int, allows the user to manually override
the checkpoint number via a gin-binding.
Returns:
int, the iteration number of the latest checkpoint, or -1 if none was found.
"""
if override_number is not None:
return override_number

glob = os.path.join(base_directory, 'sentinel_checkpoint_complete.*')
def extract_iteration(x):
return int(x[x.rfind('.') + 1:])
Expand Down
16 changes: 10 additions & 6 deletions dopamine/discrete_domains/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,12 @@ def __init__(self,
self._summary_writer = tf.summary.FileWriter(self._base_dir)

self._environment = create_environment_fn()
config = tf.ConfigProto(allow_soft_placement=True)
# Allocate only a subset of the GPU memory, growing it as needed, which allows
# running multiple agents/workers on the same GPU.
config.gpu_options.allow_growth = True
# Set up a session and initialize variables.
self._sess = tf.Session('',
config=tf.ConfigProto(allow_soft_placement=True))
self._sess = tf.Session('', config=config)
self._agent = create_agent_fn(self._sess, self._environment,
summary_writer=self._summary_writer)
self._summary_writer.add_graph(graph=tf.get_default_graph())
Expand Down Expand Up @@ -231,10 +234,11 @@ def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
latest_checkpoint_version)
if self._agent.unbundle(
self._checkpoint_dir, latest_checkpoint_version, experiment_data):
assert 'logs' in experiment_data
assert 'current_iteration' in experiment_data
self._logger.data = experiment_data['logs']
self._start_iteration = experiment_data['current_iteration'] + 1
if experiment_data is not None:
assert 'logs' in experiment_data
assert 'current_iteration' in experiment_data
self._logger.data = experiment_data['logs']
self._start_iteration = experiment_data['current_iteration'] + 1
tf.logging.info('Reloaded checkpoint and will start from iteration %d',
self._start_iteration)

Expand Down
126 changes: 126 additions & 0 deletions dopamine/utils/agent_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2018 The Dopamine Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to visualize different aspects of an agent's behaviour.
This file defines the class AgentVisualizer, which allows one to combine
a number of Plotter objects into a series of single images, generated during
agent interaction with the environment.
If requested, this class will combine the image files into a movie.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import subprocess


import gin
import numpy as np
import pygame
import scipy


@gin.configurable
class AgentVisualizer(object):
  """Composes the output of several Plotters into frames saved to disk.

  Each call to `visualize` advances an internal step counter; every
  `render_rate` steps the plotters are drawn onto an off-screen pygame
  surface and the result is written out as an image file. The saved PNG
  frames can afterwards be stitched into a video with `generate_video`.
  """

  def __init__(self,
               record_path,
               plotters,
               screen_width=160,
               screen_height=210,
               render_rate=1,
               file_types=('png', ''),
               filename_format='frame_{:06d}'):
    """Constructor for the AgentVisualizer class.

    This class generates a series of images built by a set of Plotters. These
    images are then saved to disk.
    It can optionally generate a video by concatenating all the images with
    ffmpeg.

    Args:
      record_path: str, path where to save files.
      plotters: list of `Plotter` objects to draw.
      screen_width: int, width of generated images.
      screen_height: int, height of generated images.
      render_rate: int, frame frequency at which to generate files.
      file_types: list of str, specifies the file types to generate. Empty
        strings are ignored.
      filename_format: str, format to use for saving files. It is formatted
        with the frame number and must not include the file extension.
    """
    self.record_path = record_path
    self.plotters = plotters
    self.screen_width = screen_width
    self.screen_height = screen_height
    self.render_rate = render_rate
    self.file_types = file_types
    self.filename_format = filename_format
    self.step = 0
    # Reusable (height, width, 3) uint8 buffer holding the frame to be saved.
    self.record_frame = np.zeros((self.screen_height, self.screen_width, 3),
                                 dtype=np.uint8)
    # This is necessary to avoid a `pygame.error: No available video device`
    # error.
    os.environ['SDL_VIDEODRIVER'] = 'dummy'
    pygame.init()
    # 32 bits per pixel, so that `save_frame` can read one int32 per pixel.
    self.screen = pygame.display.set_mode(
        (self.screen_width, self.screen_height), 0, 32)

  def visualize(self):
    """Draws all plotters and saves a frame every `render_rate` steps.

    The step counter is incremented on every call, whether or not a frame was
    rendered.
    """
    if self.step % self.render_rate == 0:
      self.screen.fill((0, 0, 0))
      for plotter in self.plotters:
        self.screen = self.screen.copy()  # To avoid locked Surfaces issue.
        self.screen.blit(plotter.draw(), (plotter.x, plotter.y))
      self.save_frame()
    self.step += 1

  def save_frame(self):
    """Saves the current screen to disk, once per non-empty file type.

    Files are written to `record_path` and named by formatting
    `filename_format` with the frame number (`step // render_rate`), followed
    by the file type as extension.
    """
    # View the surface's raw pixels as one int32 per pixel.
    screen_buffer = (
        np.frombuffer(self.screen.get_buffer(), dtype=np.int32)
        .reshape(self.screen_height, self.screen_width))
    # Unpack the three low bytes of each pixel into separate channels: bits
    # 0-7 go to channel 2, bits 8-15 to channel 1 and bits 16-23 to channel 0
    # (presumably yielding RGB order -- depends on the surface's pixel masks).
    self.record_frame[..., 2] = screen_buffer % 256
    self.record_frame[..., 1] = (screen_buffer >> 8) % 256
    self.record_frame[..., 0] = (screen_buffer >> 16) % 256
    frame_number = self.step // self.render_rate
    for file_type in self.file_types:
      if not file_type:
        continue
      filename = (
          self.filename_format.format(frame_number) + '.{}'.format(file_type))
      # NOTE(review): `scipy.misc.imsave` was deprecated in SciPy 1.0 and
      # removed in SciPy 1.2; confirm the pinned SciPy version still provides
      # it, or migrate to an image-writing library.
      scipy.misc.imsave(os.path.join(self.record_path, filename),
                        self.record_frame)

  def generate_video(self, video_file='video.mp4'):
    """Generates a video, requires 'png' be in file_types.

    Note that this will issue a `subprocess.call` to `ffmpeg`, so only use this
    functionality with trusted paths.

    Args:
      video_file: str, name of video file to generate. A relative name is
        resolved against `record_path`.
    """
    if 'png' not in self.file_types:
      return
    # Turn the str.format pattern (e.g. 'frame_{:06d}') into the printf-style
    # pattern ffmpeg expects for image sequences (e.g. 'frame_%06d.png').
    frame_pattern = self.filename_format.replace('{:', '%').replace('}', '')
    frame_pattern += '.png'
    # Run ffmpeg inside `record_path` via `cwd` rather than `os.chdir`, so the
    # working directory of the calling process is left untouched.
    subprocess.call(['ffmpeg', '-r', '30', '-f', 'image2', '-s', '1920x1080',
                     '-i', frame_pattern, '-vcodec', 'libx264', '-crf', '25',
                     '-pix_fmt', 'yuv420p', video_file],
                    cwd=self.record_path)
68 changes: 68 additions & 0 deletions dopamine/utils/atari_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2018 The Dopamine Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AtariPlotter used for rendering Atari 2600 frames.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from dopamine.utils import plotter
import gin
import numpy as np
import pygame


@gin.configurable
class AtariPlotter(plotter.Plotter):
  """Plotter that draws the current Atari 2600 screen onto a pygame Surface."""

  _defaults = {
      'x': 0,
      'y': 0,
      'width': 160,
      'height': 210,
  }

  def __init__(self, parameter_dict=None):
    """Constructor for AtariPlotter.

    Args:
      parameter_dict: None or dict of parameter specifications for
        visualization. If an expected parameter is present, its value will
        be used, otherwise it will use defaults. After the parent constructor
        has run, `self.parameters` must contain an 'environment' entry.
    """
    super(AtariPlotter, self).__init__(parameter_dict)
    assert 'environment' in self.parameters
    surface_size = (self.parameters['width'], self.parameters['height'])
    self.game_surface = pygame.Surface(surface_size)

  def draw(self):
    """Render the Atari 2600 frame.

    Returns:
      object to be rendered by AgentVisualizer (the Surface returned by
      `pygame.transform.scale`).
    """
    environment = self.parameters['environment']
    # Flat int32 view over the surface's pixel memory; writing to it updates
    # the surface in place.
    surface_pixels = np.frombuffer(self.game_surface.get_buffer(),
                                   dtype=np.int32)
    frame = environment.render(mode='rgb_array').astype(np.int32)
    # Bring the channel axis to the front so that each channel can be picked
    # out by its index.
    channels = np.transpose(frame, (2, 0, 1))
    # Pack the three channels into one integer per pixel: channel 0 in bits
    # 16-23, channel 1 in bits 8-15 and channel 2 in bits 0-7.
    packed = channels[2] | (channels[1] << 8) | (channels[0] << 16)
    np.copyto(surface_pixels, packed.ravel())
    target_size = (self.parameters['width'], self.parameters['height'])
    return pygame.transform.scale(self.game_surface, target_size)
Loading

0 comments on commit 3751065

Please sign in to comment.