Skip to content

Commit

Permalink
Merge pull request #216 from mr4msm/make_load_compatible_with_v02
Browse files Browse the repository at this point in the history
make ReplayBuffer.load() compatible with v0.2.0.
  • Loading branch information
toslunar authored Jan 18, 2018
2 parents 99da574 + 5d154cb commit a51847f
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion chainerrl/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty
import collections

import numpy as np
import six.moves.cPickle as pickle
Expand Down Expand Up @@ -153,6 +154,10 @@ def save(self, filename):
def load(self, filename):
with open(filename, 'rb') as f:
self.memory = pickle.load(f)
if isinstance(self.memory, collections.deque):
# Load v0.2
self.memory = RandomAccessQueue(
self.memory, maxlen=self.memory.maxlen)

def stop_current_episode(self):
pass
Expand Down Expand Up @@ -281,7 +286,23 @@ def save(self, filename):

def load(self, filename):
with open(filename, 'rb') as f:
self.memory, self.episodic_memory = pickle.load(f)
memory = pickle.load(f)
if isinstance(memory, tuple):
self.memory, self.episodic_memory = memory
else:
# Load v0.2
# FIXME: The code works with EpisodicReplayBuffer
# but not with PrioritizedEpisodicReplayBuffer
self.memory = RandomAccessQueue(memory)
self.episodic_memory = RandomAccessQueue()

# Recover episodic_memory with best effort.
episode = []
for item in self.memory:
episode.append(item)
if item['is_state_terminal']:
self.episodic_memory.append(episode)
episode = []

def stop_current_episode(self):
if self.current_episode:
Expand Down

0 comments on commit a51847f

Please sign in to comment.