Skip to content

Commit

Permalink
refactor inference progress
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Sep 3, 2024
1 parent f8f027a commit ed8d14b
Showing 1 changed file with 122 additions and 90 deletions.
212 changes: 122 additions & 90 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,122 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
def _initialize_inference_model(self):
pass

def _process_batch(self, ex: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""Run prediction model on batch.
This method handles running inference on a batch and postprocessing.
Args:
ex: a dictionary holding the input for inference.
Returns:
The input dictionary updated with the predictions.
"""
# Skip inference if model is not loaded
if self.inference_model is None:
return ex

# Run inference on current batch.
preds = self.inference_model.predict_on_batch(ex, numpy=True)

# Add model outputs to the input data example.
ex.update(preds)

# Convert to numpy arrays if not already.
if isinstance(ex["video_ind"], tf.Tensor):
ex["video_ind"] = ex["video_ind"].numpy().flatten()
if isinstance(ex["frame_ind"], tf.Tensor):
ex["frame_ind"] = ex["frame_ind"].numpy().flatten()

# Adjust for potential SizeMatcher scaling.
offset_x = ex.get("offset_x", 0)
offset_y = ex.get("offset_y", 0)
ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2])
ex["instance_peaks"] /= np.expand_dims(
np.expand_dims(ex["scale"], axis=1), axis=1
)

return ex

def _run_batch_json(
self,
examples: List[Dict[str, np.ndarray]],
n_total: int,
max_length: int = 30,
) -> Iterator[Dict[str, np.ndarray]]:
n_processed = 0
n_recent = deque(maxlen=max_length)
elapsed_recent = deque(maxlen=max_length)
last_report = time()
t0_all = time()
t0_batch = time()
for ex in examples:
# Process batch of examples.
ex = self._process_batch(ex)

# Track timing and progress.
elapsed_batch = time() - t0_batch
t0_batch = time()
n_batch = len(ex["frame_ind"])
n_processed += n_batch
elapsed_all = time() - t0_all

# Compute recent rate.
n_recent.append(n_batch)
elapsed_recent.append(elapsed_batch)
rate = sum(n_recent) / sum(elapsed_recent)
eta = (n_total - n_processed) / rate

# Report.
if time() > last_report + self.report_period:
print(
json.dumps(
{
"n_processed": n_processed,
"n_total": n_total,
"elapsed": elapsed_all,
"rate": rate,
"eta": eta,
}
),
flush=True,
)
last_report = time()

# Return results.
yield ex

def _run_batch_rich(
self,
examples: List[Dict[str, np.ndarray]],
n_total: int,
) -> Iterator[Dict[str, np.ndarray]]:
with rich.progress.Progress(
"{task.description}",
rich.progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
"ETA:",
rich.progress.TimeRemainingColumn(),
RateColumn(),
auto_refresh=False,
refresh_per_second=self.report_rate,
speed_estimate_period=5,
) as progress:
task = progress.add_task("Predicting...", total=n_total)
last_report = time()
for ex in examples:
ex = self._process_batch(ex)

progress.update(task, advance=len(ex["frame_ind"]))

# Handle refreshing manually to support notebooks.
if time() > last_report + self.report_period:
progress.refresh()
last_report = time()

# Return results.
yield ex

def _predict_generator(
self, data_provider: Provider
) -> Iterator[Dict[str, np.ndarray]]:
Expand All @@ -389,106 +505,22 @@ def _predict_generator(
if self.inference_model is None:
self._initialize_inference_model()

def process_batch(ex):
# Run inference on current batch.
preds = self.inference_model.predict_on_batch(ex, numpy=True)

# Add model outputs to the input data example.
ex.update(preds)

# Convert to numpy arrays if not already.
if isinstance(ex["video_ind"], tf.Tensor):
ex["video_ind"] = ex["video_ind"].numpy().flatten()
if isinstance(ex["frame_ind"], tf.Tensor):
ex["frame_ind"] = ex["frame_ind"].numpy().flatten()

# Adjust for potential SizeMatcher scaling.
offset_x = ex.get("offset_x", 0)
offset_y = ex.get("offset_y", 0)
ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2])
ex["instance_peaks"] /= np.expand_dims(
np.expand_dims(ex["scale"], axis=1), axis=1
)

return ex

# Compile loop examples before starting time to improve ETA
n_total=len(data_provider)
examples = self.pipeline.make_dataset()

# Loop over data batches with optional progress reporting.
if self.verbosity == "rich":
with rich.progress.Progress(
"{task.description}",
rich.progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
"ETA:",
rich.progress.TimeRemainingColumn(),
RateColumn(),
auto_refresh=False,
refresh_per_second=self.report_rate,
speed_estimate_period=5,
) as progress:
task = progress.add_task("Predicting...", total=len(data_provider))
last_report = time()
for ex in examples:
ex = process_batch(ex)
progress.update(task, advance=len(ex["frame_ind"]))

# Handle refreshing manually to support notebooks.
elapsed_since_last_report = time() - last_report
if elapsed_since_last_report > self.report_period:
progress.refresh()

# Return results.
yield ex
for ex in self._run_batch_rich(examples, n_total=n_total):
yield ex

elif self.verbosity == "json":
n_processed = 0
n_total = len(data_provider)
n_recent = deque(maxlen=30)
elapsed_recent = deque(maxlen=30)
last_report = time()
t0_all = time()
t0_batch = time()
for ex in examples:
# Process batch of examples.
ex = process_batch(ex)

# Track timing and progress.
elapsed_batch = time() - t0_batch
t0_batch = time()
n_batch = len(ex["frame_ind"])
n_processed += n_batch
elapsed_all = time() - t0_all

# Compute recent rate.
n_recent.append(n_batch)
elapsed_recent.append(elapsed_batch)
rate = sum(n_recent) / sum(elapsed_recent)
eta = (n_total - n_processed) / rate

# Report.
elapsed_since_last_report = time() - last_report
if elapsed_since_last_report > self.report_period:
print(
json.dumps(
{
"n_processed": n_processed,
"n_total": n_total,
"elapsed": elapsed_all,
"rate": rate,
"eta": eta,
}
),
flush=True,
)
last_report = time()

# Return results.
for ex in self._run_batch_json(examples, n_total=n_total):
yield ex

else:
for ex in examples:
yield process_batch(ex)
yield self._process_batch(ex)

def predict(
self, data: Union[Provider, sleap.Labels, sleap.Video], make_labels: bool = True
Expand Down

0 comments on commit ed8d14b

Please sign in to comment.