Skip to content

Commit

Permalink
fixes to environment.py
Browse files Browse the repository at this point in the history
  • Loading branch information
joey-obrien committed Dec 15, 2023
1 parent bd6d61a commit 2cb9681
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions wildfire_evac/wildfire_evac/envs/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import random
from scipy.stats import bernoulli
import torch

from .environment_constant import fire_mask

"""
Expand All @@ -30,14 +29,22 @@ class FireWorld:
"""

def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, num_fire_cells = 2, custom_fire_locations = None):
# if they don't pass in populated areas and don't say they want auto-generated paths raise an exception, don't do an assert


"""
The constructor defines the state and action space, initializes the fires,
and sets the paths and populated areas.
"""
# Define the state and action space
self.reward = 0
self.state_space = np.zeros([5, num_rows, num_cols])
self.actions = list(np.arange(len(paths) + 1)) # extra action for doing nothing

num_actions = 0
for key in paths_to_pops:
for _ in range(len(paths_to_pops[key])):
num_actions += 1
self.actions = list(np.arange(num_actions + 1)) # extra action for doing nothing

# Associate paths with populated areas and actions
# Note: there seems to be an error that keeps popping up where this dictionary is not
Expand All @@ -47,9 +54,13 @@ def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, nu
# We want to remember which action index corresponds to which population center
# and which path (because we just provide an array like [1,2,3,4,5,6,7]) which
# would each be mapped to a given population area taking a given path

self.action_to_pop_and_path = { self.actions[-1] : None}
action_val = 0
for key in self.paths_to_pops:
self.action_to_pop_and_path[key] = (paths_to_pops[key], key) # action index: list of pop x,y index and path index [[x,y],path_index]
for pop_area in paths_to_pops[key]:
self.action_to_pop_and_path[key] = (pop_area, action_val)
action_val += 1

# State for the evacuation of populated areas
self.evacuating_paths = {} # path_index : list of pop x,y indices that are evacuating [[x,y],[x,y],...]
Expand Down Expand Up @@ -77,7 +88,7 @@ def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, nu
# Note: right now paths is different from self.paths but we can change this later if needed
self.paths = []
for path in paths:
path_array = np.array(path)
path_array = np.array(path).astype(int)
path_rows, path_cols = path_array[:,0], path_array[:,1]
self.state_space[PATHS_INDEX,path_rows,path_cols] += 1

Expand Down Expand Up @@ -128,7 +139,6 @@ def update_paths_and_evactuations(self):
2. Also stops evacuating any areas that were taking a burned down path
3. Also decrements the evacuation timestamps
"""
self.state_space[FIRE_INDEX][1,1] = 1
for i in range(len(self.paths)):
# Decrement path counts and remove path
if self.paths[i][1] and np.sum(np.logical_and(self.state_space[FIRE_INDEX], self.paths[i][0])) > 0:
Expand All @@ -137,7 +147,8 @@ def update_paths_and_evactuations(self):

# Stop evacuating an area if it was taking the removed path
if i in self.evacuating_paths:
pop_centers = np.array(self.evacuating_paths[i])[0]
# pop_centers = np.array(self.evacuating_paths[i])[0]
pop_centers = np.array(self.evacuating_paths[i])
pop_rows, pop_cols = pop_centers[:,0], pop_centers[:,1]

# Reset timestamp and evacuation index
Expand All @@ -150,7 +161,7 @@ def update_paths_and_evactuations(self):
# for the below, this code works for if multiple population centers are taking the same path and
# finish at the same time, but if we have it so that two population centers can't take the same
# path it could probably be simplified
pop_centers = np.array(self.evacuating_paths[i])[0]
pop_centers = np.array(self.evacuating_paths[i])
pop_rows, pop_cols = pop_centers[:,0], pop_centers[:,1]
self.evacuating_timestamps[pop_rows,pop_cols] -= 1
done_evacuating = np.where(self.evacuating_timestamps == 0)
Expand All @@ -165,7 +176,7 @@ def update_paths_and_evactuations(self):
done_evacuating = np.array([done_evacuating[0], done_evacuating[1]])
done_evacuating = np.transpose(done_evacuating)
for j in range(done_evacuating.shape[0]):
self.evacuating_paths[i][0].remove(list(done_evacuating[j]))
self.evacuating_paths[i].remove(list(done_evacuating[j]))

# this population center is done evacuating, so we can set its timestamp back to infinity
# (this is important so that we don't try to remove this from self.evacuating paths twice -
Expand All @@ -174,10 +185,10 @@ def update_paths_and_evactuations(self):
self.evacuating_timestamps[update_row, update_col] = np.inf

# no more population centers are using this path, so we delete it
if len(self.evacuating_paths[i][0]) == 0:
if len(self.evacuating_paths[i]) == 0:
del self.evacuating_paths[i]

def advance_to_next_timestep(self):
def advance_to_next_timestep(self, manual_fire = None):
"""
Take three steps:
1. Advance fire forward one timestep
Expand Down Expand Up @@ -261,8 +272,9 @@ def get_state(self):
returned_state[PATHS_INDEX] = np.clip(returned_state[PATHS_INDEX], 0, 1)
return returned_state


def get_terminated(self):
"""
Get the status of the simulation.
"""
return ( self.time_step >= 100 )
return ( self.time_step >= 100 )

0 comments on commit 2cb9681

Please sign in to comment.