diff --git a/torchdrivesim/simulator.py b/torchdrivesim/simulator.py index 753a3bf..1b29cf3 100644 --- a/torchdrivesim/simulator.py +++ b/torchdrivesim/simulator.py @@ -2139,11 +2139,13 @@ def update_exposed_agents(self): self.proximal_timesteps = self.across_agent_types( lambda x, y: x.mul(y), self.proximal_timesteps, self.previous_present_mask ) - self.update_present_mask(self.get_present_mask()) + present_mask = self.get_present_mask() + self.inner_simulator.update_present_mask(present_mask) + self.previous_present_mask = self.agent_functor(torch.logical_and, present_mask, self.previous_present_mask) def update_present_mask(self, present_mask): self.inner_simulator.update_present_mask(present_mask) - self.previous_present_mask = self.agent_functor(torch.logical_and, present_mask, self.previous_present_mask) + self.previous_present_mask = present_mask def get_present_mask(self): present_mask = super().get_present_mask()