-
Notifications
You must be signed in to change notification settings - Fork 24
/
state_generator.py
34 lines (30 loc) · 1.21 KB
/
state_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
class KarelStateGenerator(object):
def __init__(self, seed=None):
self.rng = np.random.RandomState(seed)
# generate an initial env
def generate_single_state(self, h=8, w=8, wall_prob=0.1):
s = np.zeros([h, w, 16]) > 0
# Wall
s[:, :, 4] = self.rng.rand(h, w) > 1 - wall_prob
s[0, :, 4] = True
s[h-1, :, 4] = True
s[:, 0, 4] = True
s[:, w-1, 4] = True
# Karel initial location
valid_loc = False
while(not valid_loc):
y = self.rng.randint(0, h)
x = self.rng.randint(0, w)
if not s[y, x, 4]:
valid_loc = True
s[y, x, self.rng.randint(0, 4)] = True
# Marker: num of max marker == 1 for now
s[:, :, 6] = (self.rng.rand(h, w) > 0.9) * (s[:, :, 4] == False) > 0
s[:, :, 5] = 1 - (np.sum(s[:, :, 6:], axis=-1) > 0) > 0
assert np.sum(s[:, :, 5:]) == h*w, np.sum(s[:, :, :5])
marker_weight = np.reshape(np.array(range(11)), (1, 1, 11))
return s, y, x, np.sum(s[:, :, 4]), np.sum(marker_weight*s[:, :, 5:])