Skip to content

Commit

Permalink
Update KAN.py
Browse files Browse the repository at this point in the history
In the line 131 of the KANLayer.py, the sclae_base holds shape of (input_dim ,output_dim), while here in KAN.py, the scale_base hold shape of input_dim * output_dim, which would make the construction of a KAN fails.
  • Loading branch information
CrazyNicolas authored Jul 12, 2024
1 parent 61ab4ac commit a2811a7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion kan/KAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, scale_base_mu=0.0,
# splines
#scale_base = 1 / np.sqrt(width[l]) + (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * noise_scale_base
scale_base = scale_base_mu * 1 / np.sqrt(width[l]) + \
scale_base_sigma * (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * 1/np.sqrt(width[l])
scale_base_sigma * (torch.randn(width[l] , width[l + 1], ) * 2 - 1) * 1/np.sqrt(width[l])
sp_batch = KANLayer(in_dim=width[l], out_dim=width[l + 1], num=grid, k=k, noise_scale=noise_scale, scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable,
sb_trainable=sb_trainable, device=device)
self.act_fun.append(sp_batch)
Expand Down

0 comments on commit a2811a7

Please sign in to comment.