diff --git a/src/bmi/estimators/neural/_training_log.py b/src/bmi/estimators/neural/_training_log.py
index 0e7348f4..69a46a1b 100644
--- a/src/bmi/estimators/neural/_training_log.py
+++ b/src/bmi/estimators/neural/_training_log.py
@@ -22,7 +22,10 @@ def __init__(
         Args:
             max_n_steps: maximum number of training steps allowed
            early_stopping: whether early stopping is turned on
-            train_smooth_factor: TODO(Frederic, Pawel): Add description.
+            train_smooth_factor: fraction of the training history length used
+                as the smoothing window for convergence checks. E.g. when
+                `train_smooth_factor=0.1` and the current training history
+                has length 1000, averages over 100 steps will be used. Max 0.5.
             verbose: whether to print information during the training
             enable_tqdm: whether to use tqdm's progress bar during training
             history_in_additional_information: whether the generated additional
@@ -30,9 +33,11 @@ def __init__(
                 training and test populations).
                 We recommend keeping this flag turned on.
         """
+        assert train_smooth_factor <= 0.5, "train_smooth_factor can be at most 0.5"
+
         self.max_n_steps = max_n_steps
         self.early_stopping = early_stopping
-        self.train_smooth_window = int(max_n_steps * train_smooth_factor)
+        self.train_smooth_factor = train_smooth_factor
         self.verbose = verbose
         self._train_history_in_additional_information = train_history_in_additional_information
@@ -114,40 +119,46 @@ def detect_warnings(self):  # noqa: C901
             if self.verbose:
                 print("WARNING: Early stopping enabled but max_n_steps reached.")

-        # analyze training
+        # get train MI history
         train_mi = jnp.array([mi for _step, mi in self._mi_train_history])
-        w = self.train_smooth_window
-        cs = jnp.cumsum(train_mi)
-        # TODO(Pawel, Frederic): If training smooth window is too
-        #  long we will have an error that subtraction between (n,)
-        #  and (0,) arrays cannot be performed.
+        w = int(self.train_smooth_factor * len(train_mi))
+
+        # check if training long enough to compute diagnostics
+        if w < 1:
+            self._additional_information["training_too_short_for_diagnostics"] = True
+            if self.verbose:
+                print("WARNING: Training too short to compute diagnostics.")
+            return
+
+        # compute smoothed mi
+        cs = jnp.cumsum(jnp.concatenate([jnp.zeros(1), train_mi]))
         train_mi_smooth = (cs[w:] - cs[:-w]) / w
-        if len(train_mi_smooth) > 0:
-            train_mi_smooth_max = float(train_mi_smooth.max())
-            train_mi_smooth_fin = float(train_mi_smooth[-1])
-            if train_mi_smooth_max > 1.05 * train_mi_smooth_fin:
-                self._additional_information["max_training_mi_decreased"] = True
-                if self.verbose:
-                    print(
-                        f"WARNING: Smoothed training MI fell compared to highest value: "
-                        f"max={train_mi_smooth_max:.3f} vs "
-                        f"final={train_mi_smooth_fin:.3f}"
-                    )
-
-        w = self.train_smooth_window
-        if len(train_mi_smooth) >= w:
-            train_mi_smooth_fin = float(train_mi_smooth[-1])
-            train_mi_smooth_prv = float(train_mi_smooth[-w])
-            if train_mi_smooth_fin > 1.05 * train_mi_smooth_prv:
-                self._additional_information["training_mi_still_increasing"] = True
-                if self.verbose:
-                    print(
-                        f"WARNING: Smoothed raining MI was still "
-                        f"increasing when training stopped: "
-                        f"final={train_mi_smooth_fin:.3f} vs "
-                        f"{w} step(s) ago={train_mi_smooth_prv:.3f}"
-                    )
+        # n + 1 - w >= w + 1 since w <= int(0.5 * n)
+        assert len(train_mi_smooth) >= w + 1
+
+        train_mi_smooth_max = float(train_mi_smooth.max())
+        train_mi_smooth_fin = float(train_mi_smooth[-1])
+        if train_mi_smooth_max > 1.05 * train_mi_smooth_fin:
+            self._additional_information["max_training_mi_decreased"] = True
+            if self.verbose:
+                print(
+                    f"WARNING: Smoothed training MI fell compared to highest value: "
+                    f"max={train_mi_smooth_max:.3f} vs "
+                    f"final={train_mi_smooth_fin:.3f}"
+                )
+
+        train_mi_smooth_fin = float(train_mi_smooth[-1])
+        train_mi_smooth_prv = float(train_mi_smooth[-w])
+        if train_mi_smooth_fin > 1.05 * train_mi_smooth_prv:
+            self._additional_information["training_mi_still_increasing"] = True
+            if self.verbose:
+                print(
+                    f"WARNING: Smoothed training MI was still "
+                    f"increasing when training stopped: "
+                    f"final={train_mi_smooth_fin:.3f} vs "
+                    f"{w} step(s) ago={train_mi_smooth_prv:.3f}"
+                )

     def _tqdm_init(self):
         self._tqdm = tqdm.tqdm(
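
Reviewer note, not part of the patch: a minimal standalone sketch (the toy MI values are invented for illustration) of the zero-padded cumulative-sum trick used above, where `cs[w:] - cs[:-w]` gives exact length-`w` window sums, and of the `w + 1` length guarantee that keeps `train_mi_smooth[-w]` in bounds.

import jax.numpy as jnp

# Toy MI history standing in for self._mi_train_history (values invented).
train_mi = jnp.array([0.1, 0.4, 0.5, 0.8, 0.9, 1.0])  # n = 6
train_smooth_factor = 0.5
w = int(train_smooth_factor * len(train_mi))  # w = 3

# Prepending a zero makes cs[i] the sum of the first i entries,
# so cs[w:] - cs[:-w] is the sum of every length-w window.
cs = jnp.cumsum(jnp.concatenate([jnp.zeros(1), train_mi]))  # length n + 1
train_mi_smooth = (cs[w:] - cs[:-w]) / w  # length n + 1 - w

# Identical to the naive moving average, window by window.
naive = jnp.array(
    [train_mi[i : i + w].mean() for i in range(len(train_mi) - w + 1)]
)
assert jnp.allclose(train_mi_smooth, naive)

# Since w = int(train_smooth_factor * n) <= n // 2, there are at least
# w + 1 smoothed values, so train_mi_smooth[-w] never goes out of bounds.
assert len(train_mi_smooth) >= w + 1

# Mirroring the second diagnostic: the final window (0.9) exceeds the one
# w positions back (~0.567) by more than 5%, i.e. MI was still increasing.
assert float(train_mi_smooth[-1]) > 1.05 * float(train_mi_smooth[-w])

Note that the old unpadded `jnp.cumsum(train_mi)` dropped the first window (yielding `n - w` values instead of `n + 1 - w`) and broke entirely for `w = 0`, since `cs[:-0]` is an empty slice; the zero prepend together with the `w < 1` guard fixes both.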