Skip to content

Commit

Permalink
Fix visualization and correct test
Browse files Browse the repository at this point in the history
  • Loading branch information
cpondoc committed Apr 9, 2024
1 parent 1bcfb30 commit 6f00e5d
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 106 deletions.
3 changes: 2 additions & 1 deletion examples/map_gen_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Run basic environment.
"""
# Set up parameters
num_rows, num_cols = 200, 100
num_rows, num_cols = 25, 50
num_populated_areas = 5

# example of generating map (other parameters are set to their default values)
Expand Down Expand Up @@ -58,6 +58,7 @@
"populated_areas": populated_areas,
"paths": paths,
"paths_to_pops": paths_to_pops,
"skip": True,
}
env = gymnasium.make("pyrorl/PyroRL-v0", **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions pyrorl/pyrorl/envs/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def __init__(
custom_fire_locations: Optional[np.ndarray] = None,
wind_speed: Optional[float] = None,
wind_angle: Optional[float] = None,
fuel_mean:float = 8.5,
fuel_stdev:float = 3
fuel_mean: float = 8.5,
fuel_stdev: float = 3,
):
"""
The constructor defines the state and action space, initializes the fires,
Expand Down
170 changes: 81 additions & 89 deletions pyrorl/pyrorl/envs/pyrorl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Optional, Any

# Constants for visualization
WIDTH, HEIGHT = 475, 475
IMG_DIRECTORY = "grid_screenshots/"
FIRE_COLOR = pygame.Color("#ef476f")
POPULATED_COLOR = pygame.Color("#073b4c")
EVACUATING_COLOR = pygame.Color("#118ab2")
Expand All @@ -34,8 +34,9 @@ def __init__(
custom_fire_locations: Optional[np.ndarray] = None,
wind_speed: Optional[float] = None,
wind_angle: Optional[float] = None,
fuel_mean:float = 8.5,
fuel_stdev:float = 3
fuel_mean: float = 8.5,
fuel_stdev: float = 3,
skip: bool = False,
):
"""
Set up the basic environment and its parameters.
Expand All @@ -51,6 +52,7 @@ def __init__(
self.wind_angle = wind_angle
self.fuel_mean = fuel_mean
self.fuel_stdev = fuel_stdev
self.skip = skip
self.fire_env = FireWorld(
num_rows,
num_cols,
Expand All @@ -60,8 +62,8 @@ def __init__(
custom_fire_locations=custom_fire_locations,
wind_speed=wind_speed,
wind_angle=wind_angle,
fuel_mean = fuel_mean,
fuel_stdev = fuel_stdev
fuel_mean=fuel_mean,
fuel_stdev=fuel_stdev,
)

# Set up action space
Expand All @@ -74,13 +76,9 @@ def __init__(
low=0, high=200, shape=observations.shape, dtype=np.float64
)

# Set up grid constants
self.grid_width = WIDTH // num_rows
self.grid_height = HEIGHT // num_cols

# Create directory to store screenshots
if os.path.exists("grid_screenshots") is False:
os.mkdir("grid_screenshots")
if os.path.exists(IMG_DIRECTORY) is False:
os.mkdir(IMG_DIRECTORY)

def reset(
self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None
Expand Down Expand Up @@ -125,61 +123,42 @@ def render_hf(
surface_width = screen.get_width()
surface_height = screen.get_height()

# Starting locations
x = int(surface_width * 0.35)
y = int(surface_height * 0.8)
# Starting locations and timestep
x_offset, y_offset = 0.05, 0.05
timestep = self.fire_env.get_timestep()

# Set title of the screen
timestep = self.fire_env.get_timestep()
text = font.render("Timestep #: " + str(timestep), True, (0, 0, 0))
screen.blit(text, (50, 25))

# Grass component
text = font.render("Grass", True, (0, 0, 0))
screen.blit(text, (x + 65, y + 20))
pygame.draw.rect(screen, GRASS_COLOR, (x, y, 50, 50))

# Update y
y += 75

# Fire component
text = font.render("Fire", True, (0, 0, 0))
screen.blit(text, (x + 65, y + 20))
pygame.draw.rect(screen, FIRE_COLOR, (x, y, 50, 50))

# Update locations
x += 175
y -= 75

# Populated component
text = font.render("Populated", True, (0, 0, 0))
screen.blit(text, (x + 65, y + 20))
pygame.draw.rect(screen, POPULATED_COLOR, (x, y, 50, 50))

# Update y
y += 75

# Evacuating component
text = font.render("Evacuating", True, (0, 0, 0))
screen.blit(text, (x + 65, y + 20))
pygame.draw.rect(screen, EVACUATING_COLOR, (x, y, 50, 50))

# Update locations
x += 175
y -= 75

# Path component
text = font.render("Path", True, (0, 0, 0))
screen.blit(text, (x + 65, y + 20))
pygame.draw.rect(screen, PATH_COLOR, (x, y, 50, 50))

# Update y location
y += 75

# Evacuated component
text = font.render("Evacuated", True, (0, 0, 0))
screen.blit(text, (x + 65, y + 20))
pygame.draw.rect(screen, FINISHED_COLOR, (x, y, 50, 50))
screen.blit(text, (surface_width * x_offset, surface_height * y_offset))

# Set initial grid squares and offsets
grid_squares = [
(GRASS_COLOR, "Grass"),
(FIRE_COLOR, "Fire"),
(POPULATED_COLOR, "Populated"),
(EVACUATING_COLOR, "Evacuating"),
(PATH_COLOR, "Path"),
(FINISHED_COLOR, "Finished"),
]
x_offset, y_offset = 0.2, 0.045

# Iterate through, create the grid squares
for i in range(len(grid_squares)):

# Get the color and name, set in the screen
(color, name) = grid_squares[i]
pygame.draw.rect(
screen,
color,
(surface_width * x_offset, surface_height * y_offset, 25, 25),
)
text = font.render(name, True, (0, 0, 0))
screen.blit(
text, (surface_width * x_offset + 35, surface_height * y_offset + 5)
)

# Adjust appropriate offset
x_offset += 0.125

return screen

Expand All @@ -198,70 +177,83 @@ def render(self):
screen_width = screen_info.current_w
screen_height = screen_info.current_h

# Set up pygame and font
screen = pygame.display.set_mode([screen_width, screen_height])
self.grid_width = (0.8 * screen_width) // self.num_rows
self.grid_height = (0.6 * screen_height) // self.num_cols
# Set up screen and font
surface_width = screen_width * 0.8
surface_height = screen_height * 0.8
screen = pygame.display.set_mode([surface_width, surface_height])
font = pygame.font.Font(None, 25)

# Set screen details
screen.fill((255, 255, 255))
pygame.display.set_caption("Wildfire Evacuation RL Gym Environment")
pygame.display.set_caption("PyroRL")
screen = self.render_hf(screen, font)

# Calculation for square
total_width = 0.85 * surface_width - 2 * (cols - 1)
total_height = 0.85 * surface_height - 2 * (rows - 1)
square_dim = min(int(total_width / cols), int(total_height / rows))

# Calculate start x, start y
start_x = surface_width - 2 * (cols - 1) - square_dim * cols
start_y = (
surface_height - 2 * (rows - 1) - square_dim * rows + 0.05 * surface_height
)
start_x /= 2
start_y /= 2

# Running the loop!
running = True
while running:
# Did the user click the window close button?
for event in pygame.event.get():
if event.type == pygame.QUIT:
timestep = self.fire_env.get_timestep()
pygame.image.save(
screen, "grid_screenshots/" + str(timestep) + ".png"
)
pygame.image.save(screen, IMG_DIRECTORY + str(timestep) + ".png")
running = False

# Iterate through all of the squares
# Note: try to vectorize?
for x in range(rows):
for y in range(cols):
for x in range(cols):
for y in range(rows):

# Set color of the square
color = GRASS_COLOR
if state_space[4][x][y] > 0:
if state_space[4][y][x] > 0:
color = PATH_COLOR
if state_space[0][x][y] == 1:
if state_space[0][y][x] == 1:
color = FIRE_COLOR
if state_space[2][x][y] == 1:
if state_space[2][y][x] == 1:
color = POPULATED_COLOR
if state_space[3][x][y] > 0:
if state_space[3][y][x] > 0:
color = EVACUATING_COLOR
if [x, y] in finished_evacuating:
if [y, x] in finished_evacuating:
color = FINISHED_COLOR

# Draw the square
self.grid_dim = min(self.grid_width, self.grid_height)
# self.grid_dim = min(self.grid_width, self.grid_height)
square_rect = pygame.Rect(
50 + x * (self.grid_dim + 2),
50 + y * (self.grid_dim + 2),
self.grid_dim,
self.grid_dim,
# 50 + x * (self.grid_width + 2),
# 50 + y * (self.grid_height + 2),
# self.grid_width,
# self.grid_height,
start_x + x * (square_dim + 2),
start_y + y * (square_dim + 2),
square_dim,
square_dim,
)
pygame.draw.rect(screen, color, square_rect)

# Render and then quit outside
pygame.display.flip()

# If we skip, then we basically just render the canvas and then quit outside
if self.skip:
timestep = self.fire_env.get_timestep()
pygame.image.save(screen, IMG_DIRECTORY + str(timestep) + ".png")
running = False
pygame.quit()

def generate_gif(self):
"""
Save run as a GIF.
"""
files = [str(i) for i in range(1, self.fire_env.get_timestep() + 1)]
images = [imageio.imread("grid_screenshots/" + f + ".png") for f in files]
images = [imageio.imread(IMG_DIRECTORY + f + ".png") for f in files]
imageio.mimsave("training.gif", images)
shutil.rmtree("grid_screenshots")
shutil.rmtree(IMG_DIRECTORY)
26 changes: 17 additions & 9 deletions pyrorl/pyrorl/map_helpers/create_map_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def generate_pop_locations(num_rows, num_cols, num_populated_areas):
# map
pop_row = random.randint(1, num_rows - 2)
pop_col = random.randint(1, num_cols - 2)
# Continue generating populated areas until one that
# Continue generating populated areas until one that
# has not already been created is made
while (pop_row, pop_col) in populated_areas:
pop_row = random.randint(1, num_rows - 2)
Expand All @@ -49,17 +49,18 @@ def generate_pop_locations(num_rows, num_cols, num_populated_areas):
populated_areas = np.array(list(populated_areas))
return populated_areas


def save_map_info(
num_rows, num_cols, num_populated_areas, populated_areas, paths, paths_to_pops
):
"""
This function saves five files:
- map_info.txt: lets the user easily see the number of rows,
- map_info.txt: lets the user easily see the number of rows,
the number of columns, and the number of populated areas
- populated_areas_array.pkl: saves the populated areas array
- paths_array.pkl: saves the paths array
- paths_to_pops_array.pkl: saves the paths to pops array
- map_size_and_percent_populated_list.pkl: saves a list that contains
- map_size_and_percent_populated_list.pkl: saves a list that contains
the number of rows, number of columns, and number of populated areas
"""
# the map information is saved in the user's current working directory
Expand Down Expand Up @@ -89,9 +90,14 @@ def save_array_to_pickle(current_map_directory, array, name):
array_filename = os.path.join(current_map_directory, name)
with open(array_filename, "wb") as f:
pkl.dump(array, f)
save_array_to_pickle(current_map_directory, populated_areas, "populated_areas_array.pkl")

save_array_to_pickle(
current_map_directory, populated_areas, "populated_areas_array.pkl"
)
save_array_to_pickle(current_map_directory, paths, "paths_array.pkl")
save_array_to_pickle(current_map_directory, paths_to_pops, "paths_to_pops_array.pkl")
save_array_to_pickle(
current_map_directory, paths_to_pops, "paths_to_pops_array.pkl"
)

# save the number of rows, number of columns, and number of populated areas
map_size_and_percent_populated_list = [num_rows, num_cols, num_populated_areas]
Expand All @@ -101,6 +107,7 @@ def save_array_to_pickle(current_map_directory, array, name):
with open(map_size_and_percent_populated_list_filename, "wb") as f:
pkl.dump(map_size_and_percent_populated_list, f)


def load_map_info(map_directory_path):
"""
This function loads in six variables to initialize a wildfire environment:
Expand All @@ -111,10 +118,9 @@ def load_map_info(map_directory_path):
- paths to pops array
- number of populated areas
"""

def load_pickle_file(name):
array_filename = os.path.join(
map_directory_path, name
)
array_filename = os.path.join(map_directory_path, name)
with open(array_filename, "rb") as f:
return pkl.load(f)

Expand All @@ -124,7 +130,9 @@ def load_pickle_file(name):
paths_to_pops = load_pickle_file("paths_to_pops_array.pkl")

# load the number of rows, number of columns, and number of populated areas
map_size_and_percent_populated_list = load_pickle_file("map_size_and_percent_populated_list.pkl")
map_size_and_percent_populated_list = load_pickle_file(
"map_size_and_percent_populated_list.pkl"
)
num_rows = map_size_and_percent_populated_list[0]
num_cols = map_size_and_percent_populated_list[1]
num_populated_areas = map_size_and_percent_populated_list[2]
Expand Down
8 changes: 3 additions & 5 deletions tests/package_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,11 @@ def test_render(mocker):
env.render()

# Check that render requirements are satisfied
pygame.display.set_mode.assert_called_once_with([600, 725])
pygame.display.set_caption.assert_called_once_with(
"Wildfire Evacuation RL Gym Environment"
)
pygame.display.set_mode.assert_called()
pygame.display.set_caption.assert_called_once_with("PyroRL")
pygame.display.flip.assert_called()
num_drawn_rects = pygame.draw.rect.call_count
assert num_drawn_rects == num_rows * num_cols + 5
assert num_drawn_rects == num_rows * num_cols + 6


def test_generate_gif(mocker):
Expand Down

0 comments on commit 6f00e5d

Please sign in to comment.