Remove batching on testing dataset
kamil2oster committed Jul 8, 2024
1 parent 6e6d483 commit 8fa025e
Showing 1 changed file with 1 addition and 3 deletions.
kan/KAN.py — 4 changes: 1 addition & 3 deletions
@@ -865,7 +865,6 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
             batch_size_test = dataset['test_input'].shape[0]
         else:
             batch_size = batch
-            batch_size_test = batch
 
         global train_loss, reg_
 
@@ -890,7 +889,6 @@ def closure():
         for _ in pbar:
 
             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
-            test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
 
             if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
                 self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
@@ -911,7 +909,7 @@ def closure():
                 loss.backward()
                 optimizer.step()
 
-            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), dataset['test_label'][test_id].to(device))
+            test_loss = loss_fn_eval(self.forward(dataset['test_input'].to(device)), dataset['test_label'].to(device))
 
             if _ % log == 0:
                 pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
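Net effect of the diff: training still draws a random mini-batch from the train split each step, but the test loss is now evaluated on the entire test split rather than on a randomly sampled test batch, so the logged test metric no longer fluctuates with the test sampling. Below is a minimal sketch of that pattern; it is not pykan's actual fit loop, and the model, loss_fn_eval, optimizer, and dataset tensors are stand-ins introduced for illustration only.

# Minimal sketch (not pykan's fit loop): shows train-batch sampling kept,
# test-batch sampling removed, as in this commit. All names below are stand-ins.
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 1)                      # stand-in for the KAN model
loss_fn_eval = nn.MSELoss()                  # stand-in for loss_fn_eval
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

dataset = {
    'train_input': torch.randn(1000, 2), 'train_label': torch.randn(1000, 1),
    'test_input':  torch.randn(200, 2),  'test_label':  torch.randn(200, 1),
}
batch_size = 32

for step in range(100):
    # training still uses a random mini-batch, exactly as before the commit
    train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
    pred = model(dataset['train_input'][train_id])
    train_loss = loss_fn_eval(pred, dataset['train_label'][train_id])
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    # after the commit: test loss is computed on the full test split,
    # not on a random test batch, so the reported value is deterministic
    with torch.no_grad():
        test_loss = loss_fn_eval(model(dataset['test_input']), dataset['test_label'])

The trade-off is a full forward pass over the test data at every logging step in exchange for a deterministic, lower-variance test loss.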
