Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

latest #2

Merged
merged 11 commits into from
Jun 12, 2019
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ which demonstrate how to use Dopamine.
This is not an official Google product.

## What's new
* **11/06/2019:** Visualization utilities added to generate videos and still
images of a trained agent interacting with its environment. See an example
colaboratory
[here](https://colab.research.google.com/github/google/dopamine/blob/master/dopamine/colab/agent_visualizer.ipynb).
* **30/01/2019:** Dopamine 2.0 now supports general discrete-domain gym
environments.
* **01/11/2018:** Download links for each individual checkpoint, to avoid
Expand Down
24 changes: 19 additions & 5 deletions dopamine/agents/dqn/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def __init__(self,
epsilon=0.00001,
centered=True),
summary_writer=None,
summary_writing_frequency=500):
summary_writing_frequency=500,
allow_partial_reload=False):
"""Initializes the agent and constructs the components of its graph.

Args:
Expand Down Expand Up @@ -142,6 +143,8 @@ def __init__(self,
Summary writing disabled if set to None.
summary_writing_frequency: int, frequency with which summaries will be
written. Lower values will result in slower training.
allow_partial_reload: bool, whether we allow reloading a partial agent
(for instance, only the network parameters).
"""
assert isinstance(observation_shape, tuple)
tf.logging.info('Creating %s agent with the following parameters:',
Expand All @@ -157,6 +160,8 @@ def __init__(self,
tf.logging.info('\t tf_device: %s', tf_device)
tf.logging.info('\t use_staging: %s', use_staging)
tf.logging.info('\t optimizer: %s', optimizer)
tf.logging.info('\t max_tf_checkpoints_to_keep: %d',
max_tf_checkpoints_to_keep)

self.num_actions = num_actions
self.observation_shape = tuple(observation_shape)
Expand All @@ -178,6 +183,7 @@ def __init__(self,
self.optimizer = optimizer
self.summary_writer = summary_writer
self.summary_writing_frequency = summary_writing_frequency
self.allow_partial_reload = allow_partial_reload

with tf.device(tf_device):
# Create a placeholder for the state input to the DQN network.
Expand Down Expand Up @@ -516,13 +522,21 @@ def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
"""
try:
# self._replay.load() will throw a NotFoundError if it does not find all
# the necessary files, in which case we abort the process & return False.
# the necessary files.
self._replay.load(checkpoint_dir, iteration_number)
except tf.errors.NotFoundError:
if not self.allow_partial_reload:
# If we don't allow partial reloads, we will return False.
return False
tf.logging.warning('Unable to reload replay buffer!')
if bundle_dictionary is not None:
for key in self.__dict__:
if key in bundle_dictionary:
self.__dict__[key] = bundle_dictionary[key]
elif not self.allow_partial_reload:
return False
for key in self.__dict__:
if key in bundle_dictionary:
self.__dict__[key] = bundle_dictionary[key]
else:
tf.logging.warning("Unable to reload the agent's parameters!")
# Restore the agent's TensorFlow graph.
self._saver.restore(self._sess,
os.path.join(checkpoint_dir,
Expand Down
6 changes: 6 additions & 0 deletions dopamine/colab/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ In this
[colab](https://colab.research.google.com/github/google/dopamine/blob/master/dopamine/colab/load_statistics.ipynb)
we illustrate how to load and visualize the logs data produced by Dopamine.

## Visualizing trained agents
In this
[colab](https://colab.research.google.com/github/google/dopamine/blob/master/dopamine/colab/agent_visualizer.ipynb)
we illustrate how to visualize a trained agent using the visualization utilities
provided with Dopamine.

## Visualizing with Tensorboard
In this
[colab](https://colab.research.google.com/github/google/dopamine/blob/master/dopamine/colab/tensorboard.ipynb)
Expand Down
199 changes: 199 additions & 0 deletions dopamine/colab/agent_visualizer.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dopamine/colab/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def read_experiment(log_path,
experiment_path, iteration_number=iteration_number, verbose=verbose)

summary = summarize_data(raw_data, summary_keys)
for iteration in range(last_iteration):
for iteration in range(last_iteration + 1):
# The row contains all the parameters, the iteration, and finally the
# requested values.
row_data = (list(parameter_tuple) + [iteration] +
Expand Down
2 changes: 1 addition & 1 deletion dopamine/discrete_domains/atari_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@


import atari_py
import gin
import gym
from gym.spaces.box import Box
import numpy as np
import tensorflow as tf

import gin.tf
import cv2

slim = tf.contrib.slim
Expand Down
10 changes: 9 additions & 1 deletion dopamine/discrete_domains/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,28 @@

import os
import pickle

import gin
import tensorflow as tf

CHECKPOINT_DURATION = 4


def get_latest_checkpoint_number(base_directory):
@gin.configurable
def get_latest_checkpoint_number(base_directory, override_number=None):
"""Returns the version number of the latest completed checkpoint.

Args:
base_directory: str, directory in which to look for checkpoint files.
override_number: None or int, allows the user to manually override
the checkpoint number via a gin-binding.

Returns:
int, the iteration number of the latest checkpoint, or -1 if none was found.
"""
if override_number is not None:
return override_number

glob = os.path.join(base_directory, 'sentinel_checkpoint_complete.*')
def extract_iteration(x):
return int(x[x.rfind('.') + 1:])
Expand Down
16 changes: 10 additions & 6 deletions dopamine/discrete_domains/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,12 @@ def __init__(self,
self._summary_writer = tf.summary.FileWriter(self._base_dir)

self._environment = create_environment_fn()
config = tf.ConfigProto(allow_soft_placement=True)
# Allocate only subset of the GPU memory as needed which allows for running
# multiple agents/workers on the same GPU.
config.gpu_options.allow_growth = True
# Set up a session and initialize variables.
self._sess = tf.Session('',
config=tf.ConfigProto(allow_soft_placement=True))
self._sess = tf.Session('', config=config)
self._agent = create_agent_fn(self._sess, self._environment,
summary_writer=self._summary_writer)
self._summary_writer.add_graph(graph=tf.get_default_graph())
Expand Down Expand Up @@ -231,10 +234,11 @@ def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
latest_checkpoint_version)
if self._agent.unbundle(
self._checkpoint_dir, latest_checkpoint_version, experiment_data):
assert 'logs' in experiment_data
assert 'current_iteration' in experiment_data
self._logger.data = experiment_data['logs']
self._start_iteration = experiment_data['current_iteration'] + 1
if experiment_data is not None:
assert 'logs' in experiment_data
assert 'current_iteration' in experiment_data
self._logger.data = experiment_data['logs']
self._start_iteration = experiment_data['current_iteration'] + 1
tf.logging.info('Reloaded checkpoint and will start from iteration %d',
self._start_iteration)

Expand Down
126 changes: 126 additions & 0 deletions dopamine/utils/agent_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2018 The Dopamine Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to visualize different aspects of an agent's behaviour.

This file defines the class AgentVisualizer, which allows one to combine
a number of Plotter objects into a series of single images, generated during
agent interaction with the environment.
If requested, this class will combine the image files into a movie.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import subprocess


import gin
import numpy as np
import pygame
import scipy


@gin.configurable
class AgentVisualizer(object):
"""Code to visualize an agent's behaviour."""

def __init__(self,
record_path,
plotters,
screen_width=160,
screen_height=210,
render_rate=1,
file_types=('png', ''),
filename_format='frame_{:06d}'):
"""Constructor for the AgentVisualizer class.

This class generates a series of images built by a set of Plotters. These
images are then saved to disk.

It can optionally generate a video by concatenating all the images with
ffmpeg.

Args:
record_path: str, path where to save files.
plotters: list of `Plotter` objects to draw.
screen_width: int, width of generated images.
screen_height: int, height of generated images.
render_rate: int, frame frequency at which to generate files.
file_types: list of str, specifies the file types to generate.
filename_format: str, format to use for saving files.
"""
self.record_path = record_path
self.plotters = plotters
self.screen_width = screen_width
self.screen_height = screen_height
self.render_rate = render_rate
self.file_types = file_types
self.filename_format = filename_format
self.step = 0
self.record_frame = np.zeros((self.screen_height, self.screen_width, 3),
dtype=np.uint8)
# This is necessary to avoid a `pygame.error: No available video device`
# error.
os.environ['SDL_VIDEODRIVER'] = 'dummy'
pygame.init()
self.screen = pygame.display.set_mode((self.screen_width,
self.screen_height),
0, 32)

def visualize(self):
if self.step % self.render_rate == 0:
self.screen.fill((0, 0, 0))
for plotter in self.plotters:
self.screen = self.screen.copy() # To avoid locked Surfaces issue.
self.screen.blit(plotter.draw(), (plotter.x, plotter.y))
self.save_frame()
self.step += 1

def save_frame(self):
"""Save a frame to disk and generate a video, if enabled."""
screen_buffer = (
np.frombuffer(self.screen.get_buffer(), dtype=np.int32)
.reshape(self.screen_height, self.screen_width))
sb = screen_buffer[:, 0:self.screen_width]
self.record_frame[..., 2] = sb % 256
self.record_frame[..., 1] = (sb >> 8) % 256
self.record_frame[..., 0] = (sb >> 16) % 256
frame_number = self.step // self.render_rate
for file_type in self.file_types:
if not file_type:
continue
filename = (
self.filename_format.format(frame_number) + '.{}'.format(file_type))
scipy.misc.imsave(os.path.join(self.record_path, filename),
self.record_frame)

def generate_video(self, video_file='video.mp4'):
"""Generates a video, requires 'png' be in file_types.

Note that this will issue a `subprocess.call` to `ffmpeg`, so only use this
functionality with trusted paths.

Args:
video_file: str, name of video file to generate.
"""
if 'png' not in self.file_types:
return
os.chdir(self.record_path)
file_regex = self.filename_format.replace('{:', '%').replace('}', '')
file_regex += '.png'
subprocess.call(['ffmpeg', '-r', '30', '-f', 'image2', '-s', '1920x1080',
'-i', file_regex, '-vcodec', 'libx264', '-crf', '25',
'-pix_fmt', 'yuv420p', video_file])
68 changes: 68 additions & 0 deletions dopamine/utils/atari_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2018 The Dopamine Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AtariPlotter used for rendering Atari 2600 frames.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from dopamine.utils import plotter
import gin
import numpy as np
import pygame


@gin.configurable
class AtariPlotter(plotter.Plotter):
"""A Plotter for rendering Atari 2600 frames."""

_defaults = {
'x': 0,
'y': 0,
'width': 160,
'height': 210,
}

def __init__(self, parameter_dict=None):
"""Constructor for AtariPlotter.

Args:
parameter_dict: None or dict of parameter specifications for
visualization. If an expected parameter is present, its value will
be used, otherwise it will use defaults.
"""
super(AtariPlotter, self).__init__(parameter_dict)
assert 'environment' in self.parameters
self.game_surface = pygame.Surface((self.parameters['width'],
self.parameters['height']))

def draw(self):
"""Render the Atari 2600 frame.

Returns:
object to be rendered by AgentVisualizer.
"""
environment = self.parameters['environment']
numpy_surface = np.frombuffer(self.game_surface.get_buffer(),
dtype=np.int32)
obs = environment.render(mode='rgb_array').astype(np.int32)
obs = np.transpose(obs)
obs = np.swapaxes(obs, 1, 2)
obs = obs[2] | (obs[1] << 8) | (obs[0] << 16)
np.copyto(numpy_surface, obs.ravel())
return pygame.transform.scale(self.game_surface,
(self.parameters['width'],
self.parameters['height']))
Loading