From 2cb968190f5a00bc68469a93c92ee339a018ff71 Mon Sep 17 00:00:00 2001 From: joey-obrien Date: Fri, 15 Dec 2023 15:10:56 -0800 Subject: [PATCH] fixes to environment.py --- .../envs/environment/environment.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/wildfire_evac/wildfire_evac/envs/environment/environment.py b/wildfire_evac/wildfire_evac/envs/environment/environment.py index b6936b6..4ba9fa3 100644 --- a/wildfire_evac/wildfire_evac/envs/environment/environment.py +++ b/wildfire_evac/wildfire_evac/envs/environment/environment.py @@ -11,7 +11,6 @@ import random from scipy.stats import bernoulli import torch - from .environment_constant import fire_mask """ @@ -30,6 +29,9 @@ class FireWorld: """ def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, num_fire_cells = 2, custom_fire_locations = None): + # if they don't pass in populated areas and don't say they want auto-generated paths raise an exception, don't do an assert + + """ The constructor defines the state and action space, initializes the fires, and sets the paths and populated areas. @@ -37,7 +39,12 @@ def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, nu # Define the state and action space self.reward = 0 self.state_space = np.zeros([5, num_rows, num_cols]) - self.actions = list(np.arange(len(paths) + 1)) # extra action for doing nothing + + num_actions = 0 + for key in paths_to_pops: + for _ in range(len(paths_to_pops[key])): + num_actions += 1 + self.actions = list(np.arange(num_actions + 1)) # extra action for doing nothing # Associate paths with populated areas and actions # Note: there seems to be an error that keeps popping up where this dictionary is not @@ -47,9 +54,13 @@ def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, nu # We want to remember which action index corresponds to which population center # and which path (because we just provide an array like [1,2,3,4,5,6,7]) which # would each be mapped to a given population area taking a given path + self.action_to_pop_and_path = { self.actions[-1] : None} + action_val = 0 for key in self.paths_to_pops: - self.action_to_pop_and_path[key] = (paths_to_pops[key], key) # action index: list of pop x,y index and path index [[x,y],path_index] + for pop_area in paths_to_pops[key]: + self.action_to_pop_and_path[key] = (pop_area, action_val) + action_val += 1 # State for the evacuation of populated areas self.evacuating_paths = {} # path_index : list of pop x,y indices that are evacuating [[x,y],[x,y],...] @@ -77,7 +88,7 @@ def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, nu # Note: right now paths is different from self.paths but we can change this later if needed self.paths = [] for path in paths: - path_array = np.array(path) + path_array = np.array(path).astype(int) path_rows, path_cols = path_array[:,0], path_array[:,1] self.state_space[PATHS_INDEX,path_rows,path_cols] += 1 @@ -128,7 +139,6 @@ def update_paths_and_evactuations(self): 2. Also stops evacuating any areas that were taking a burned down path 3. Also decrements the evacuation timestamps """ - self.state_space[FIRE_INDEX][1,1] = 1 for i in range(len(self.paths)): # Decrement path counts and remove path if self.paths[i][1] and np.sum(np.logical_and(self.state_space[FIRE_INDEX], self.paths[i][0])) > 0: @@ -137,7 +147,8 @@ def update_paths_and_evactuations(self): # Stop evacuating an area if it was taking the removed path if i in self.evacuating_paths: - pop_centers = np.array(self.evacuating_paths[i])[0] + # pop_centers = np.array(self.evacuating_paths[i])[0] + pop_centers = np.array(self.evacuating_paths[i]) pop_rows, pop_cols = pop_centers[:,0], pop_centers[:,1] # Reset timestamp and evacuation index @@ -150,7 +161,7 @@ def update_paths_and_evactuations(self): # for the below, this code works for if multiple population centers are taking the same path and # finish at the same time, but if we have it so that two population centers can't take the same # path it could probably be simplified - pop_centers = np.array(self.evacuating_paths[i])[0] + pop_centers = np.array(self.evacuating_paths[i]) pop_rows, pop_cols = pop_centers[:,0], pop_centers[:,1] self.evacuating_timestamps[pop_rows,pop_cols] -= 1 done_evacuating = np.where(self.evacuating_timestamps == 0) @@ -165,7 +176,7 @@ def update_paths_and_evactuations(self): done_evacuating = np.array([done_evacuating[0], done_evacuating[1]]) done_evacuating = np.transpose(done_evacuating) for j in range(done_evacuating.shape[0]): - self.evacuating_paths[i][0].remove(list(done_evacuating[j])) + self.evacuating_paths[i].remove(list(done_evacuating[j])) # this population center is done evacuating, so we can set its timestamp back to infinity # (this is important so that we don't try to remove this from self.evacuating paths twice - @@ -174,10 +185,10 @@ def update_paths_and_evactuations(self): self.evacuating_timestamps[update_row, update_col] = np.inf # no more population centers are using this path, so we delete it - if len(self.evacuating_paths[i][0]) == 0: + if len(self.evacuating_paths[i]) == 0: del self.evacuating_paths[i] - def advance_to_next_timestep(self): + def advance_to_next_timestep(self, manual_fire = None): """ Take three steps: 1. Advance fire forward one timestep @@ -261,8 +272,9 @@ def get_state(self): returned_state[PATHS_INDEX] = np.clip(returned_state[PATHS_INDEX], 0, 1) return returned_state + def get_terminated(self): """ Get the status of the simulation. """ - return ( self.time_step >= 100 ) + return ( self.time_step >= 100 ) \ No newline at end of file