Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move edt as part of track asset #117

Open
wants to merge 10 commits into
base: v1.0.0
Choose a base branch
from
41 changes: 41 additions & 0 deletions examples/create_edt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from f110_gym.envs.track import Track
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from f110_gym.envs.track import Track
from f1tenth_gym.envs.track import Track

from scipy.ndimage import distance_transform_edt as edt
import numpy as np

DEFAULT_MAP_NAMES = [
"Austin",
"BrandsHatch",
"Budapest",
"Catalunya",
"Hockenheim",
"IMS",
"Melbourne",
"MexicoCity",
"Montreal",
"Monza",
"MoscowRaceway",
"Nuerburgring",
"Oschersleben",
"Sakhir",
"SaoPaulo",
"Sepang",
"Shanghai",
"Silverstone",
"Sochi",
"Spa",
"Spielberg",
"YasMarina",
"Zandvoort",
]

for track_name in DEFAULT_MAP_NAMES:
track = Track.from_track_name(track_name)
occupancy_map = track.occupancy_map
resolution = track.spec.resolution

dt = resolution * edt(occupancy_map)

# saving
np.save(track.filepath, dt)

track_wedt = Track.from_track_name(track_name)
5 changes: 4 additions & 1 deletion f1tenth_gym/envs/laser_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,10 @@ def set_map(self, map: str | Track):
self.orig_c = np.cos(self.origin[2])

# get the distance transform
self.dt = get_dt(self.map_img, self.map_resolution)
if self.track.edt is not None:
self.dt = self.track.edt
else:
self.dt = get_dt(self.map_img, self.map_resolution)

return True

Expand Down
19 changes: 19 additions & 0 deletions f1tenth_gym/envs/track/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
from dataclasses import dataclass
from typing import Tuple, Optional
import os

import numpy as np
import yaml
Expand Down Expand Up @@ -31,6 +32,7 @@ class Track:
filepath: str
ext: str
occupancy_map: np.ndarray
edt: np.ndarray
centerline: Raceline
raceline: Raceline

Expand All @@ -40,6 +42,7 @@ def __init__(
filepath: str,
ext: str,
occupancy_map: np.ndarray,
edt: Optional[np.ndarray] = None,
centerline: Optional[Raceline] = None,
raceline: Optional[Raceline] = None,
):
Expand All @@ -56,6 +59,8 @@ def __init__(
file extension of the track image file
occupancy_map : np.ndarray
occupancy grid map
edt : np.ndarray
distance transform of the map
centerline : Raceline, optional
centerline of the track, by default None
raceline : Raceline, optional
Expand All @@ -65,6 +70,7 @@ def __init__(
self.filepath = filepath
self.ext = ext
self.occupancy_map = occupancy_map
self.edt = edt
self.centerline = centerline
self.raceline = raceline

Expand Down Expand Up @@ -125,6 +131,18 @@ def from_track_name(track: str):
occupancy_map[occupancy_map <= 128] = 0.0
occupancy_map[occupancy_map > 128] = 255.0

# if exists and it has been created for the current map image, load edt
map_filepath = (track_dir / map_filename).absolute()
track_filepath = map_filepath.with_suffix("")
edt_filepath = track_dir / f"{track}_map.npy"
if edt_filepath.exists() and os.path.getmtime(edt_filepath) >= os.path.getmtime(map_filepath):
edt = np.load(track_dir / f"{track}_map.npy")
else:
from scipy.ndimage import distance_transform_edt as edt
resolution = track_spec.resolution
edt = resolution * edt(occupancy_map)
np.save(track_filepath, edt)

# if exists, load centerline
if (track_dir / f"{track}_centerline.csv").exists():
centerline = Raceline.from_centerline_file(
Expand All @@ -146,6 +164,7 @@ def from_track_name(track: str):
filepath=str((track_dir / map_filename.stem).absolute()),
ext=map_filename.suffix,
occupancy_map=occupancy_map,
edt=edt,
centerline=centerline,
raceline=raceline,
)
Expand Down
2 changes: 1 addition & 1 deletion f1tenth_gym/envs/track/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def find_track_dir(track_name: str) -> pathlib.Path:
FileNotFoundError
if no map directory matching the track name is found
"""
map_dir = pathlib.Path(__file__).parent.parent.parent.parent / "maps"
map_dir = pathlib.Path(__file__).parent.parent.parent.parent.parent / "maps"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this? it does not point to the right "maps" directory in this way

Suggested change
map_dir = pathlib.Path(__file__).parent.parent.parent.parent.parent / "maps"
map_dir = pathlib.Path(__file__).parent.parent.parent.parent / "maps"


if not (map_dir / track_name).exists():
print("Downloading Files for: " + track_name)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_track.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime
import os
import pathlib
import time
import unittest
Expand Down Expand Up @@ -138,3 +140,29 @@ def test_download_racetrack(self):
# rename the backup track dir to its original name
track_backup_dir = find_track_dir(tmp_dir.stem)
track_backup_dir.rename(track_dir)

def test_edt_update(self):
"""
Test the re-creation of the edt if the map modification time is more recent.
"""
track = Track.from_track_name("Spielberg")

# set the map image modification/access time to now
now = datetime.datetime.now()
dt_epoch = now.timestamp()
map_filepath = pathlib.Path(track.filepath).parent / track.spec.image
os.utime(map_filepath, (dt_epoch, dt_epoch))

# check the edt modification time is now < the map image time
edt_filepath = map_filepath.with_suffix(".npy")
self.assertTrue(os.path.getmtime(map_filepath) > os.path.getmtime(edt_filepath),
f"expected the map image modification time to be > the edt modification time")

# this should force the edt to be recomputed
# check the edt modification time is not > the map image time
track2 = Track.from_track_name("Spielberg")
self.assertTrue(os.path.getmtime(map_filepath) < os.path.getmtime(edt_filepath),
f"expected the map image modification time to be > the edt modification time")

# check consistency in the maps edts
self.assertTrue(np.allclose(track.edt, track2.edt), f"expected the same edt transform for {track.spec.name}")
Loading