diff --git a/kan/KAN.py b/kan/KAN.py
index c91d2123..2c0a673b 100644
--- a/kan/KAN.py
+++ b/kan/KAN.py
@@ -865,7 +865,6 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
             batch_size_test = dataset['test_input'].shape[0]
         else:
             batch_size = batch
-            batch_size_test = batch
 
         global train_loss, reg_
 
@@ -890,7 +889,6 @@ def closure():
         for _ in pbar:
 
             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
-            test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
 
             if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
                 self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
@@ -911,7 +909,7 @@ def closure():
                 loss.backward()
                 optimizer.step()
 
-            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), dataset['test_label'][test_id].to(device))
+            test_loss = loss_fn_eval(self.forward(dataset['test_input'].to(device)), dataset['test_label'].to(device))
 
             if _ % log == 0:
                 pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
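This diff drops the per-step random subsampling of the test set and evaluates the test loss on the full test set instead. A minimal sketch of the before/after behavior, outside the actual `KAN.train()` loop: `model`, `dataset`, `loss_fn_eval`, and `batch_size_test` mirror the names in the diff, but the toy model and data here are illustrative stand-ins, not the repo's objects.

```python
import torch
import numpy as np

# Hypothetical stand-ins for the objects used inside KAN.train().
torch.manual_seed(0)
model = torch.nn.Linear(2, 1)  # toy model in place of self.forward
dataset = {
    'test_input': torch.randn(100, 2),
    'test_label': torch.randn(100, 1),
}
loss_fn_eval = torch.nn.MSELoss()
batch_size_test = 32

# Before this patch: test loss on a random mini-batch each step. The
# estimate is noisy, and np.random.choice(..., replace=False) raises a
# ValueError whenever batch_size_test exceeds the test-set size.
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
noisy_test_loss = loss_fn_eval(model(dataset['test_input'][test_id]),
                               dataset['test_label'][test_id])

# After this patch: test loss on the full test set, deterministic per step.
full_test_loss = loss_fn_eval(model(dataset['test_input']),
                              dataset['test_label'])
print(noisy_test_loss.item(), full_test_loss.item())
```

The trade-off: a full-test-set forward pass per step costs more compute on large test sets, but the reported test loss is no longer a noisy subsample estimate and no longer depends on `batch_size_test` being valid for the test set.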