diff --git a/pyrorl/pyrorl/envs/environment/environment.py b/pyrorl/pyrorl/envs/environment/environment.py index 724b6d4..02e1caf 100644 --- a/pyrorl/pyrorl/envs/environment/environment.py +++ b/pyrorl/pyrorl/envs/environment/environment.py @@ -12,7 +12,7 @@ from scipy.stats import bernoulli import torch -from .environment_constant import fire_mask +from .environment_constant import fire_mask, linear_wind_transform """ Indices corresponding to each layer of state @@ -29,7 +29,18 @@ class FireWorld: while the 5 represents each of the following: [fire, fuel, populated_areas, evacuating, paths] """ - def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, num_fire_cells = 2, custom_fire_locations = None): + def __init__( + self, + num_rows, + num_cols, + populated_areas, + paths, + paths_to_pops, + num_fire_cells = 2, + custom_fire_locations = None, + wind_speed = None, + wind_angle = None, + ): """ The constructor defines the state and action space, initializes the fires, and sets the paths and populated areas. @@ -93,6 +104,13 @@ def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops, nu # Set the timestep self.time_step = 0 + #Factor in wind speeds + if wind_speed is not None or wind_angle is not None: + if wind_speed is None or wind_angle is None: + raise TypeError("When setting wind details, wind speed and wind angle must both be provided") + global fire_mask + fire_mask = linear_wind_transform(wind_speed, wind_angle) + def sample_fire_propogation(self): """ Sample the next state of the wildfire model. diff --git a/pyrorl/pyrorl/envs/environment/environment_constant.py b/pyrorl/pyrorl/envs/environment/environment_constant.py index 90ba4a2..ee5fe74 100644 --- a/pyrorl/pyrorl/envs/environment/environment_constant.py +++ b/pyrorl/pyrorl/envs/environment/environment_constant.py @@ -34,4 +34,4 @@ def linear_wind_transform(wind_speed : float, wind_angle : float): wind_vector = np.array([[np.cos(wind_angle)], [np.sin(wind_angle)]]) scaling_term = -(neighbor_vectors @ wind_vector) * speed_to_percent_ratio * wind_speed + 1 - return np.clip(scaling_term * base_fire_mask, a_min=0, a_max=1) \ No newline at end of file + return np.clip(torch.from_numpy(scaling_term) * base_fire_mask, a_min=0, a_max=1) \ No newline at end of file diff --git a/pyrorl/pyrorl/envs/pyrorl.py b/pyrorl/pyrorl/envs/pyrorl.py index 63eda19..5793954 100644 --- a/pyrorl/pyrorl/envs/pyrorl.py +++ b/pyrorl/pyrorl/envs/pyrorl.py @@ -19,7 +19,16 @@ GRASS_COLOR = pygame.Color("#06d6a0") class WildfireEvacuationEnv(gym.Env): - def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops): + def __init__( + self, + num_rows, + num_cols, + populated_areas, + paths, + paths_to_pops, + wind_speed = None, + wind_angle = None, + ): """ Set up the basic environment and its parameters. """ @@ -29,7 +38,15 @@ def __init__(self, num_rows, num_cols, populated_areas, paths, paths_to_pops): self.populated_areas = populated_areas self.paths = paths self.paths_to_pops = paths_to_pops - self.fire_env = FireWorld(num_rows, num_cols, populated_areas, paths, paths_to_pops) + self.fire_env = FireWorld( + num_rows, + num_cols, + populated_areas, + paths, + paths_to_pops, + wind_speed=wind_speed, + wind_angle=wind_angle + ) # Set up action space actions = self.fire_env.get_actions()