From 0f8fcdcff386d027e89b5a8c7fb8a71c12720f1c Mon Sep 17 00:00:00 2001 From: Adam Scibior Date: Thu, 11 Jul 2024 18:22:37 -0700 Subject: [PATCH 1/2] Let updated present mask overwrite no reentry mask --- torchdrivesim/simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdrivesim/simulator.py b/torchdrivesim/simulator.py index 753a3bf..14751b6 100644 --- a/torchdrivesim/simulator.py +++ b/torchdrivesim/simulator.py @@ -2143,7 +2143,7 @@ def update_exposed_agents(self): def update_present_mask(self, present_mask): self.inner_simulator.update_present_mask(present_mask) - self.previous_present_mask = self.agent_functor(torch.logical_and, present_mask, self.previous_present_mask) + self.previous_present_mask = present_mask def get_present_mask(self): present_mask = super().get_present_mask() From 419f5b95b0633a6f7f67b84645296e4b0c9c7ea9 Mon Sep 17 00:00:00 2001 From: Adam Scibior Date: Thu, 11 Jul 2024 18:40:39 -0700 Subject: [PATCH 2/2] Fix previous_present_mask update to ensure no reentry --- torchdrivesim/simulator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchdrivesim/simulator.py b/torchdrivesim/simulator.py index 14751b6..1b29cf3 100644 --- a/torchdrivesim/simulator.py +++ b/torchdrivesim/simulator.py @@ -2139,7 +2139,9 @@ def update_exposed_agents(self): self.proximal_timesteps = self.across_agent_types( lambda x, y: x.mul(y), self.proximal_timesteps, self.previous_present_mask ) - self.update_present_mask(self.get_present_mask()) + present_mask = self.get_present_mask() + self.inner_simulator.update_present_mask(present_mask) + self.previous_present_mask = self.agent_functor(torch.logical_and, present_mask, self.previous_present_mask) def update_present_mask(self, present_mask): self.inner_simulator.update_present_mask(present_mask)