Skip to content

Commit

Permalink
Merge pull request #330 from CrazyNicolas/master
Browse files Browse the repository at this point in the history
[Fix]Update KAN.py
  • Loading branch information
KindXiaoming authored Jul 12, 2024
2 parents 3029652 + a2811a7 commit e5abcb7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion kan/KAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, scale_base_mu=0.0,
# splines
#scale_base = 1 / np.sqrt(width[l]) + (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * noise_scale_base
scale_base = scale_base_mu * 1 / np.sqrt(width[l]) + \
scale_base_sigma * (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * 1/np.sqrt(width[l])
scale_base_sigma * (torch.randn(width[l] , width[l + 1], ) * 2 - 1) * 1/np.sqrt(width[l])
sp_batch = KANLayer(in_dim=width[l], out_dim=width[l + 1], num=grid, k=k, noise_scale=noise_scale, scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable,
sb_trainable=sb_trainable, device=device)
self.act_fun.append(sp_batch)
Expand Down

0 comments on commit e5abcb7

Please sign in to comment.