Skip to content

Commit

Permalink
support for frames model
Browse files Browse the repository at this point in the history
  • Loading branch information
gferraro committed Oct 2, 2024
1 parent 34f2f21 commit a34b205
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
7 changes: 7 additions & 0 deletions src/ml_tools/hyperparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ def square_width(self):
def frame_size(self):
return self.get("frame_size", 32)

def set_use_segments(self, use_segments):
self["use_segments"] = use_segments
if use_segments:
self["square_width"] = 5
else:
self["square_width"] = 1

#
# @property
# def red_type(self):
Expand Down
3 changes: 3 additions & 0 deletions src/ml_tools/kerasmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def load_training_meta(self, base_dir):
self.ds_by_label = meta.get("by_label", True)
self.excluded_labels = meta.get("excluded_labels")
self.remapped_labels = meta.get("remapped_labels")
self.params.set_use_segments(
meta.get("config").get("build", {}).get("use_segments", True)
)

def shape(self):
if self.model is None:
Expand Down
7 changes: 5 additions & 2 deletions src/ml_tools/thermaldataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def load_dataset(filenames, remap_lookup, labels, args):
extra_label_map=extra_label_map,
include_track=args.get("include_track", False),
num_frames=args.get("num_frames", 25),
channels=args.get("channels", [TrackChannels.thermal.name]),
channels=args.get(
"channels", [TrackChannels.thermal.name, TrackChannels.filtered.name]
),
),
num_parallel_calls=AUTOTUNE,
deterministic=deterministic,
Expand Down Expand Up @@ -183,7 +185,7 @@ def read_tfrecord(
channels=[TrackChannels.thermal.name, TrackChannels.filtered.name],
):
logging.info(
"Read tf record with image %s lbls %s labeld %s aug %s prepr %s only features %s one hot %s include fetures %s",
"Read tf record with image %s lbls %s labeld %s aug %s prepr %s only features %s one hot %s include fetures %s num frames %s",
image_size,
num_labels,
labeled,
Expand All @@ -192,6 +194,7 @@ def read_tfrecord(
only_features,
one_hot,
include_features,
num_frames,
)
load_images = not only_features
tfrecord_format = {
Expand Down

0 comments on commit a34b205

Please sign in to comment.