diff --git a/examples/map_gen_example.py b/examples/map_gen_example.py index ecd3e93..6d2ab8d 100644 --- a/examples/map_gen_example.py +++ b/examples/map_gen_example.py @@ -18,7 +18,7 @@ Run basic environment. """ # Set up parameters - num_rows, num_cols = 200, 100 + num_rows, num_cols = 25, 50 num_populated_areas = 5 # example of generating map (other parameters are set to their default values) @@ -58,6 +58,7 @@ "populated_areas": populated_areas, "paths": paths, "paths_to_pops": paths_to_pops, + "skip": True, } env = gymnasium.make("pyrorl/PyroRL-v0", **kwargs) diff --git a/pyrorl/pyrorl/envs/environment/environment.py b/pyrorl/pyrorl/envs/environment/environment.py index a1527de..86e41ae 100644 --- a/pyrorl/pyrorl/envs/environment/environment.py +++ b/pyrorl/pyrorl/envs/environment/environment.py @@ -39,8 +39,8 @@ def __init__( custom_fire_locations: Optional[np.ndarray] = None, wind_speed: Optional[float] = None, wind_angle: Optional[float] = None, - fuel_mean:float = 8.5, - fuel_stdev:float = 3 + fuel_mean: float = 8.5, + fuel_stdev: float = 3, ): """ The constructor defines the state and action space, initializes the fires, diff --git a/pyrorl/pyrorl/envs/pyrorl.py b/pyrorl/pyrorl/envs/pyrorl.py index f35eb15..4123726 100644 --- a/pyrorl/pyrorl/envs/pyrorl.py +++ b/pyrorl/pyrorl/envs/pyrorl.py @@ -14,7 +14,7 @@ from typing import Optional, Any # Constants for visualization -WIDTH, HEIGHT = 475, 475 +IMG_DIRECTORY = "grid_screenshots/" FIRE_COLOR = pygame.Color("#ef476f") POPULATED_COLOR = pygame.Color("#073b4c") EVACUATING_COLOR = pygame.Color("#118ab2") @@ -34,8 +34,9 @@ def __init__( custom_fire_locations: Optional[np.ndarray] = None, wind_speed: Optional[float] = None, wind_angle: Optional[float] = None, - fuel_mean:float = 8.5, - fuel_stdev:float = 3 + fuel_mean: float = 8.5, + fuel_stdev: float = 3, + skip: bool = False, ): """ Set up the basic environment and its parameters. @@ -51,6 +52,7 @@ def __init__( self.wind_angle = wind_angle self.fuel_mean = fuel_mean self.fuel_stdev = fuel_stdev + self.skip = skip self.fire_env = FireWorld( num_rows, num_cols, @@ -60,8 +62,8 @@ def __init__( custom_fire_locations=custom_fire_locations, wind_speed=wind_speed, wind_angle=wind_angle, - fuel_mean = fuel_mean, - fuel_stdev = fuel_stdev + fuel_mean=fuel_mean, + fuel_stdev=fuel_stdev, ) # Set up action space @@ -74,13 +76,9 @@ def __init__( low=0, high=200, shape=observations.shape, dtype=np.float64 ) - # Set up grid constants - self.grid_width = WIDTH // num_rows - self.grid_height = HEIGHT // num_cols - # Create directory to store screenshots - if os.path.exists("grid_screenshots") is False: - os.mkdir("grid_screenshots") + if os.path.exists(IMG_DIRECTORY) is False: + os.mkdir(IMG_DIRECTORY) def reset( self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None @@ -125,61 +123,42 @@ def render_hf( surface_width = screen.get_width() surface_height = screen.get_height() - # Starting locations - x = int(surface_width * 0.35) - y = int(surface_height * 0.8) + # Starting locations and timestep + x_offset, y_offset = 0.05, 0.05 + timestep = self.fire_env.get_timestep() # Set title of the screen - timestep = self.fire_env.get_timestep() text = font.render("Timestep #: " + str(timestep), True, (0, 0, 0)) - screen.blit(text, (50, 25)) - - # Grass component - text = font.render("Grass", True, (0, 0, 0)) - screen.blit(text, (x + 65, y + 20)) - pygame.draw.rect(screen, GRASS_COLOR, (x, y, 50, 50)) - - # Update y - y += 75 - - # Fire component - text = font.render("Fire", True, (0, 0, 0)) - screen.blit(text, (x + 65, y + 20)) - pygame.draw.rect(screen, FIRE_COLOR, (x, y, 50, 50)) - - # Update locations - x += 175 - y -= 75 - - # Populated component - text = font.render("Populated", True, (0, 0, 0)) - screen.blit(text, (x + 65, y + 20)) - pygame.draw.rect(screen, POPULATED_COLOR, (x, y, 50, 50)) - - # Update y - y += 75 - - # Evacuating component - text = font.render("Evacuating", True, (0, 0, 0)) - screen.blit(text, (x + 65, y + 20)) - pygame.draw.rect(screen, EVACUATING_COLOR, (x, y, 50, 50)) - - # Update locations - x += 175 - y -= 75 - - # Path component - text = font.render("Path", True, (0, 0, 0)) - screen.blit(text, (x + 65, y + 20)) - pygame.draw.rect(screen, PATH_COLOR, (x, y, 50, 50)) - - # Update y location - y += 75 - - # Evacuated component - text = font.render("Evacuated", True, (0, 0, 0)) - screen.blit(text, (x + 65, y + 20)) - pygame.draw.rect(screen, FINISHED_COLOR, (x, y, 50, 50)) + screen.blit(text, (surface_width * x_offset, surface_height * y_offset)) + + # Set initial grid squares and offsets + grid_squares = [ + (GRASS_COLOR, "Grass"), + (FIRE_COLOR, "Fire"), + (POPULATED_COLOR, "Populated"), + (EVACUATING_COLOR, "Evacuating"), + (PATH_COLOR, "Path"), + (FINISHED_COLOR, "Finished"), + ] + x_offset, y_offset = 0.2, 0.045 + + # Iterate through, create the grid squares + for i in range(len(grid_squares)): + + # Get the color and name, set in the screen + (color, name) = grid_squares[i] + pygame.draw.rect( + screen, + color, + (surface_width * x_offset, surface_height * y_offset, 25, 25), + ) + text = font.render(name, True, (0, 0, 0)) + screen.blit( + text, (surface_width * x_offset + 35, surface_height * y_offset + 5) + ) + + # Adjust appropriate offset + x_offset += 0.125 return screen @@ -198,17 +177,30 @@ def render(self): screen_width = screen_info.current_w screen_height = screen_info.current_h - # Set up pygame and font - screen = pygame.display.set_mode([screen_width, screen_height]) - self.grid_width = (0.8 * screen_width) // self.num_rows - self.grid_height = (0.6 * screen_height) // self.num_cols + # Set up screen and font + surface_width = screen_width * 0.8 + surface_height = screen_height * 0.8 + screen = pygame.display.set_mode([surface_width, surface_height]) font = pygame.font.Font(None, 25) # Set screen details screen.fill((255, 255, 255)) - pygame.display.set_caption("Wildfire Evacuation RL Gym Environment") + pygame.display.set_caption("PyroRL") screen = self.render_hf(screen, font) + # Calculation for square + total_width = 0.85 * surface_width - 2 * (cols - 1) + total_height = 0.85 * surface_height - 2 * (rows - 1) + square_dim = min(int(total_width / cols), int(total_height / rows)) + + # Calculate start x, start y + start_x = surface_width - 2 * (cols - 1) - square_dim * cols + start_y = ( + surface_height - 2 * (rows - 1) - square_dim * rows + 0.05 * surface_height + ) + start_x /= 2 + start_y /= 2 + # Running the loop! running = True while running: @@ -216,45 +208,45 @@ def render(self): for event in pygame.event.get(): if event.type == pygame.QUIT: timestep = self.fire_env.get_timestep() - pygame.image.save( - screen, "grid_screenshots/" + str(timestep) + ".png" - ) + pygame.image.save(screen, IMG_DIRECTORY + str(timestep) + ".png") running = False # Iterate through all of the squares # Note: try to vectorize? - for x in range(rows): - for y in range(cols): + for x in range(cols): + for y in range(rows): # Set color of the square color = GRASS_COLOR - if state_space[4][x][y] > 0: + if state_space[4][y][x] > 0: color = PATH_COLOR - if state_space[0][x][y] == 1: + if state_space[0][y][x] == 1: color = FIRE_COLOR - if state_space[2][x][y] == 1: + if state_space[2][y][x] == 1: color = POPULATED_COLOR - if state_space[3][x][y] > 0: + if state_space[3][y][x] > 0: color = EVACUATING_COLOR - if [x, y] in finished_evacuating: + if [y, x] in finished_evacuating: color = FINISHED_COLOR # Draw the square - self.grid_dim = min(self.grid_width, self.grid_height) + # self.grid_dim = min(self.grid_width, self.grid_height) square_rect = pygame.Rect( - 50 + x * (self.grid_dim + 2), - 50 + y * (self.grid_dim + 2), - self.grid_dim, - self.grid_dim, - # 50 + x * (self.grid_width + 2), - # 50 + y * (self.grid_height + 2), - # self.grid_width, - # self.grid_height, + start_x + x * (square_dim + 2), + start_y + y * (square_dim + 2), + square_dim, + square_dim, ) pygame.draw.rect(screen, color, square_rect) # Render and then quit outside pygame.display.flip() + + # If we skip, then we basically just render the canvas and then quit outside + if self.skip: + timestep = self.fire_env.get_timestep() + pygame.image.save(screen, IMG_DIRECTORY + str(timestep) + ".png") + running = False pygame.quit() def generate_gif(self): @@ -262,6 +254,6 @@ def generate_gif(self): Save run as a GIF. """ files = [str(i) for i in range(1, self.fire_env.get_timestep() + 1)] - images = [imageio.imread("grid_screenshots/" + f + ".png") for f in files] + images = [imageio.imread(IMG_DIRECTORY + f + ".png") for f in files] imageio.mimsave("training.gif", images) - shutil.rmtree("grid_screenshots") + shutil.rmtree(IMG_DIRECTORY) diff --git a/pyrorl/pyrorl/map_helpers/create_map_info.py b/pyrorl/pyrorl/map_helpers/create_map_info.py index 8cac9ec..a186a56 100644 --- a/pyrorl/pyrorl/map_helpers/create_map_info.py +++ b/pyrorl/pyrorl/map_helpers/create_map_info.py @@ -40,7 +40,7 @@ def generate_pop_locations(num_rows, num_cols, num_populated_areas): # map pop_row = random.randint(1, num_rows - 2) pop_col = random.randint(1, num_cols - 2) - # Continue generating populated areas until one that + # Continue generating populated areas until one that # has not already been created is made while (pop_row, pop_col) in populated_areas: pop_row = random.randint(1, num_rows - 2) @@ -49,17 +49,18 @@ def generate_pop_locations(num_rows, num_cols, num_populated_areas): populated_areas = np.array(list(populated_areas)) return populated_areas + def save_map_info( num_rows, num_cols, num_populated_areas, populated_areas, paths, paths_to_pops ): """ This function saves five files: - - map_info.txt: lets the user easily see the number of rows, + - map_info.txt: lets the user easily see the number of rows, the number of columns, and the number of populated areas - populated_areas_array.pkl: saves the populated areas array - paths_array.pkl: saves the paths array - paths_to_pops_array.pkl: saves the paths to pops array - - map_size_and_percent_populated_list.pkl: saves a list that contains + - map_size_and_percent_populated_list.pkl: saves a list that contains the number of rows, number of columns, and number of populated areas """ # the map information is saved in the user's current working directory @@ -89,9 +90,14 @@ def save_array_to_pickle(current_map_directory, array, name): array_filename = os.path.join(current_map_directory, name) with open(array_filename, "wb") as f: pkl.dump(array, f) - save_array_to_pickle(current_map_directory, populated_areas, "populated_areas_array.pkl") + + save_array_to_pickle( + current_map_directory, populated_areas, "populated_areas_array.pkl" + ) save_array_to_pickle(current_map_directory, paths, "paths_array.pkl") - save_array_to_pickle(current_map_directory, paths_to_pops, "paths_to_pops_array.pkl") + save_array_to_pickle( + current_map_directory, paths_to_pops, "paths_to_pops_array.pkl" + ) # save the number of rows, number of columns, and number of populated areas map_size_and_percent_populated_list = [num_rows, num_cols, num_populated_areas] @@ -101,6 +107,7 @@ def save_array_to_pickle(current_map_directory, array, name): with open(map_size_and_percent_populated_list_filename, "wb") as f: pkl.dump(map_size_and_percent_populated_list, f) + def load_map_info(map_directory_path): """ This function loads in six variables to initialize a wildfire environment: @@ -111,10 +118,9 @@ def load_map_info(map_directory_path): - paths to pops array - number of populated areas """ + def load_pickle_file(name): - array_filename = os.path.join( - map_directory_path, name - ) + array_filename = os.path.join(map_directory_path, name) with open(array_filename, "rb") as f: return pkl.load(f) @@ -124,7 +130,9 @@ def load_pickle_file(name): paths_to_pops = load_pickle_file("paths_to_pops_array.pkl") # load the number of rows, number of columns, and number of populated areas - map_size_and_percent_populated_list = load_pickle_file("map_size_and_percent_populated_list.pkl") + map_size_and_percent_populated_list = load_pickle_file( + "map_size_and_percent_populated_list.pkl" + ) num_rows = map_size_and_percent_populated_list[0] num_cols = map_size_and_percent_populated_list[1] num_populated_areas = map_size_and_percent_populated_list[2] diff --git a/tests/package_test.py b/tests/package_test.py index f8f041b..f2b282e 100644 --- a/tests/package_test.py +++ b/tests/package_test.py @@ -178,13 +178,11 @@ def test_render(mocker): env.render() # Check that render requirements are satisfied - pygame.display.set_mode.assert_called_once_with([600, 725]) - pygame.display.set_caption.assert_called_once_with( - "Wildfire Evacuation RL Gym Environment" - ) + pygame.display.set_mode.assert_called() + pygame.display.set_caption.assert_called_once_with("PyroRL") pygame.display.flip.assert_called() num_drawn_rects = pygame.draw.rect.call_count - assert num_drawn_rects == num_rows * num_cols + 5 + assert num_drawn_rects == num_rows * num_cols + 6 def test_generate_gif(mocker):