Skip to content

Commit

Permalink
Add tests for checking the interface and that the GIF is generated
Browse files Browse the repository at this point in the history
  • Loading branch information
cpondoc committed Mar 5, 2024
1 parent f9e5982 commit fa38e4c
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 21 deletions.
Binary file modified examples/training.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion pyrorl/pyrorl/envs/pyrorl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyrorl.envs.environment.environment import *
import gymnasium as gym
from gymnasium import spaces
import imageio
import imageio.v2 as imageio
import numpy as np
import os
import pygame
Expand Down
2 changes: 0 additions & 2 deletions tests/environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,10 @@ def test_setup():
action = random.randint(0, len(all_actions) - 1)

test_world.set_action(all_actions[action])
print("Action: " + str(action))

# Advance the gridworld and get the reward
test_world.advance_to_next_timestep()
reward = test_world.get_state_utility()
print("Reward: " + str(reward) + "\n")


def test_remove_path_on_fire():
Expand Down
111 changes: 93 additions & 18 deletions tests/package_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@

import gymnasium
import numpy as np
import os
import pygame
import pytest
import pyrorl


def PGEvent():
def __init__(self):
self.type = pygame.QUIT


def test_constructor():
"""
Test the constructor to make sure all variables are accounted for.
Expand Down Expand Up @@ -54,15 +50,18 @@ def test_constructor():
env = gymnasium.make("pyrorl/PyroRL-v0", **kwargs)

# Make basic checks for the constructor
assert env.num_rows == num_rows
assert env.num_cols == num_cols
np.testing.assert_array_equal(env.populated_areas, populated_areas)
np.testing.assert_array_equal(env.paths, paths)
assert env.get_wrapper_attr("num_rows") == num_rows
assert env.get_wrapper_attr("num_cols") == num_cols
np.testing.assert_array_equal(
env.get_wrapper_attr("populated_areas"), populated_areas
)
np.testing.assert_array_equal(env.get_wrapper_attr("paths"), paths)

# Special check for paths to populated areas
for key in paths_to_pops:
np.testing.assert_array_equal(
np.array(env.paths_to_pops[key]), np.array(paths_to_pops[key])
np.array(env.get_wrapper_attr("paths_to_pops")[key]),
np.array(paths_to_pops[key]),
)


Expand Down Expand Up @@ -115,19 +114,22 @@ def test_reset():

# Check that reset makes it all the same
env.reset()
assert env.num_rows == num_rows
assert env.num_cols == num_cols
np.testing.assert_array_equal(env.populated_areas, populated_areas)
np.testing.assert_array_equal(env.paths, paths)
assert env.get_wrapper_attr("num_rows") == num_rows
assert env.get_wrapper_attr("num_cols") == num_cols
np.testing.assert_array_equal(
env.get_wrapper_attr("populated_areas"), populated_areas
)
np.testing.assert_array_equal(env.get_wrapper_attr("paths"), paths)

# Special check for paths to populated areas
for key in paths_to_pops:
np.testing.assert_array_equal(
np.array(env.paths_to_pops[key]), np.array(paths_to_pops[key])
np.array(env.get_wrapper_attr("paths_to_pops")[key]),
np.array(paths_to_pops[key]),
)


def test_reset(mocker):
def test_render(mocker):
"""
Test that basic rendering is working through mocking.
"""
Expand Down Expand Up @@ -166,13 +168,86 @@ def test_reset(mocker):
}
env = gymnasium.make("pyrorl/PyroRL-v0", **kwargs)

# Render and make simple test check
# Reset environment
env.reset()

# Mock all of the Pygame elements
display_mock = mocker.patch("pygame.display")
draw_mock = mocker.patch("pygame.draw")
image_mock = mocker.patch("pygame.image")
event_mock = mocker.patch("pygame.event.get")
event_mock.return_value = [pygame.event.Event(pygame.QUIT)]

# Render environment
env.render()
# pygame.quit()

# Check that render requirements are satisfied
pygame.display.set_mode.assert_called_once_with([600, 725])
pygame.display.set_caption.assert_called_once_with(
"Wildfire Evacuation RL Gym Environment"
)
pygame.display.flip.assert_called()
num_drawn_rects = pygame.draw.rect.call_count
assert num_drawn_rects == num_rows * num_cols + 5


def test_generate_gif(mocker):
"""
Tests that a GIF is correctly generated by the code.
"""
# Set up parameters
num_rows, num_cols = 10, 10
populated_areas = np.array([[1, 2], [4, 8], [6, 4], [8, 7]])
paths = np.array(
[
[[1, 0], [1, 1]],
[[2, 2], [3, 2], [4, 2], [4, 1], [4, 0]],
[[2, 9], [2, 8], [3, 8]],
[[5, 8], [6, 8], [6, 9]],
[[7, 7], [6, 7], [6, 8], [6, 9]],
[[8, 6], [8, 5], [9, 5]],
[[8, 5], [9, 5], [7, 5], [7, 4]],
],
dtype=object,
)
paths_to_pops = {
0: [[1, 2]],
1: [[1, 2]],
2: [[4, 8]],
3: [[4, 8]],
4: [[8, 7]],
5: [[8, 7]],
6: [[6, 4]],
}

# Create environment
kwargs = {
"num_rows": num_rows,
"num_cols": num_cols,
"populated_areas": populated_areas,
"paths": paths,
"paths_to_pops": paths_to_pops,
}
env = gymnasium.make("pyrorl/PyroRL-v0", **kwargs)

# Reset the environment
env.reset()

# Mock Pygame event component
event_mock = mocker.patch("pygame.event.get")
event_mock.return_value = [pygame.event.Event(pygame.QUIT)]

# Run a simple loop of the environment
for _ in range(10):

# Take action and observation
action = env.get_wrapper_attr("action_space").sample()
env.step(action)

# Render environment and print reward
env.render()

# Generate the gif, check that it exists, and then remove it
env.generate_gif()
assert os.path.exists("training.gif")
os.remove("training.gif")

0 comments on commit fa38e4c

Please sign in to comment.