Commit

update grid extension

KindXiaoming committed Nov 14, 2024
1 parent c5ebd60 commit b120446
Showing 47 changed files with 759 additions and 155 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,3 +12,4 @@ expressiveness
figures
molecule
applications
experiments
34 changes: 31 additions & 3 deletions kan/.ipynb_checkpoints/KANLayer-checkpoint.py
@@ -109,7 +109,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_

self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable)
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * self.mask).requires_grad_(sp_trainable) # make scale trainable
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask).requires_grad_(sp_trainable) # make scale trainable
self.base_fun = base_fun
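The updated line divides scale_sp by sqrt(in_dim), mirroring the scale_base initialization above, so that summing in_dim spline branches into one output keeps the variance roughly independent of fan-in. A minimal sketch of the effect (standalone values, not the repository's API):

    import numpy as np
    import torch

    in_dim, out_dim, scale_sp = 64, 8, 1.0
    mask = torch.ones(in_dim, out_dim)

    # old init: every spline branch starts at scale_sp
    old = torch.ones(in_dim, out_dim) * scale_sp * mask
    # new init: branches shrink as 1/sqrt(in_dim), so the sum over
    # in_dim branches stays O(1) as the fan-in grows
    new = torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * mask

    print(old[0, 0].item(), new[0, 0].item())  # 1.0 vs 0.125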


@@ -197,11 +197,13 @@ def update_grid_from_samples(self, x, mode='sample'):
def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
margin = 0.00
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]] + 2 * margin)/num_interval
grid_uniform = grid_adaptive[:,[0]] - margin + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid


grid = get_grid(num_interval)

if mode == 'grid':
@@ -210,6 +212,8 @@ def get_grid(num_interval):
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)

self.grid.data = extend_grid(grid, k_extend=self.k)
#print('x_pos 2', x_pos.shape)
#print('y_eval 2', y_eval.shape)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
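In update_grid_from_samples, the layer sorts the incoming samples, builds an adaptive grid from their quantiles, blends it with a uniform grid via grid_eps, and refits the spline coefficients on the extended grid; the new margin term (currently 0.00) would pad the uniform grid past the sample range. A single-input sketch of the grid rule (standalone, not the repository's API):

    import torch

    def blended_grid(x, num_interval, grid_eps=0.02, margin=0.0):
        # sorted samples -> adaptive grid from quantiles
        x_pos = torch.sort(x)[0]
        batch = x_pos.shape[0]
        ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
        grid_adaptive = x_pos[ids]
        # uniform grid spanning the same (margin-padded) range
        h = (grid_adaptive[-1] - grid_adaptive[0] + 2 * margin) / num_interval
        grid_uniform = grid_adaptive[0] - margin + h * torch.arange(num_interval + 1)
        # grid_eps = 1 -> fully uniform, grid_eps = 0 -> fully adaptive
        return grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive

    print(blended_grid(torch.randn(1000), num_interval=5))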

def initialize_grid_from_parent(self, parent, x, mode='sample'):
@@ -240,16 +244,40 @@ def initialize_grid_from_parent(self, parent, x, mode='sample'):

batch = x.shape[0]

# shrink grid
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
num_interval = self.grid.shape[1] - 1 - 2*self.k


'''
# based on samples
def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid'''

#print('p', parent.grid)
# based on interpolating parent grid
def get_grid(num_interval):
x_pos = parent.grid[:,parent.k:-parent.k]
#print('x_pos', x_pos)
sp2 = KANLayer(in_dim=1, out_dim=self.in_dim,k=1,num=x_pos.shape[1]-1,scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device)

#print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim))
#print('sp2_coef_shape', sp2.coef.shape)
sp2_coef = curve2coef(sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim), x_pos.permute(1,0).unsqueeze(dim=2), sp2.grid[:,:], k=1).permute(1,0,2)
shp = sp2_coef.shape
#sp2_coef = torch.cat([torch.zeros(shp[0], shp[1], 1), sp2_coef, torch.zeros(shp[0], shp[1], 1)], dim=2)
#print('sp2_coef',sp2_coef)
#print(sp2.coef.shape)
sp2.coef.data = sp2_coef
percentile = torch.linspace(-1,1,self.num+1).to(self.device)
grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1,0)
#print('c', grid)
return grid

grid = get_grid(num_interval)
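The sample-based get_grid is kept above as a commented-out block; the replacement treats the parent's grid points as a function of uniform percentiles on [-1, 1], fits a degree-1 KANLayer spline to that function, and evaluates it at self.num + 1 uniform percentiles to obtain the finer child grid. A dependency-free sketch of the same idea, with np.interp standing in for the k=1 spline:

    import numpy as np

    def refine_grid(parent_grid, num):
        # treat parent grid points as a function of uniform percentiles
        # on [-1, 1], then resample that function at num + 1 percentiles
        old_pct = np.linspace(-1, 1, len(parent_grid))
        new_pct = np.linspace(-1, 1, num + 1)
        return np.interp(new_pct, old_pct, parent_grid)

    parent = np.array([-1.0, -0.2, 0.1, 1.0])  # 3 non-uniform intervals
    print(refine_grid(parent, num=6))          # 6 intervals, same density profile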
51 changes: 38 additions & 13 deletions kan/.ipynb_checkpoints/MultKAN-checkpoint.py
@@ -164,9 +164,13 @@ def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, sca
self.act_fun = []
self.depth = len(width) - 1

#print('haha1', width)
for i in range(len(width)):
if type(width[i]) == int:
#print(type(width[i]), type(width[i]) == int)
if type(width[i]) == int or type(width[i]) == np.int64:
width[i] = [width[i],0]

#print('haha2', width)

self.width = width

@@ -196,7 +200,18 @@ def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, sca

for l in range(self.depth):
# splines
sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
if isinstance(grid, list):
grid_l = grid[l]
else:
grid_l = grid

if isinstance(k, list):
k_l = k[l]
else:
k_l = k


sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid_l, k=k_l, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
self.act_fun.append(sp_batch)

self.node_bias = []
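With this change, grid and k may be given per layer as lists, while scalars still apply to every layer. A usage sketch, assuming the usual from kan import KAN alias for MultKAN and illustrative widths:

    from kan import KAN  # KAN aliases MultKAN in pykan

    # layer 0: 5-interval grid with cubic (k=3) splines;
    # layer 1: 10-interval grid with quadratic (k=2) splines
    model = KAN(width=[2, 5, 1], grid=[5, 10], k=[3, 2])

    # scalars still broadcast to every layer, as before
    model_uniform = KAN(width=[2, 5, 1], grid=5, k=3)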
@@ -951,14 +966,14 @@ def unfix_symbolic(self, l, i, j, log_history=True):
if log_history:
self.log_history('unfix_symbolic')

def unfix_symbolic_all(self):
def unfix_symbolic_all(self, log_history=True):
'''
unfix all activation functions.
'''
for l in range(len(self.width) - 1):
for i in range(self.width[l]):
for j in range(self.width[l + 1]):
self.unfix_symbolic(l, i, j)
for i in range(self.width_in[l]):
for j in range(self.width_out[l + 1]):
self.unfix_symbolic(l, i, j, log_history)

def get_range(self, l, i, j, verbose=True):
'''
@@ -1522,6 +1537,10 @@ def closure():

if _ == steps-1 and old_save_act:
self.save_act = True

if save_fig and _ % save_fig_freq == 0:
save_act = self.save_act
self.save_act = True

train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
@@ -1579,6 +1598,7 @@ def closure():
self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=200)
plt.close()
self.save_act = save_act

self.log_history('fit')
# revert back to original state
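The two hunks above bracket the plotting branch of fit(): activation caching is forced on for steps that save a figure, and the caller's setting is restored afterwards. The same save/restore pattern as a hypothetical context-manager helper (forced_save_act is not part of the library):

    from contextlib import contextmanager

    @contextmanager
    def forced_save_act(model):
        # temporarily force activation caching, then restore it
        old = model.save_act
        model.save_act = True
        try:
            yield model
        finally:
            model.save_act = old

    # usage: with forced_save_act(model): model.plot()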
@@ -2160,7 +2180,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No

return best_name, best_fun, best_r2, best_c;

def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0):
'''
automatic symbolic regression for all edges
@@ -2174,7 +2194,10 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
library of candidate symbolic functions
verbose : int
larger values => more verbose output
weight_simple : float
a weight that prioritizes simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
r2_threshold : float
If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
Returns:
--------
None
@@ -2191,17 +2214,19 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
for l in range(len(self.width_in) - 1):
for i in range(self.width_in[l]):
for j in range(self.width_out[l + 1]):
#if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
print(f'skipping ({l},{i},{j}) since already symbolic')
elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
print(f'fixing ({l},{i},{j}) with 0')
else:
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
if verbose >= 1:
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple)
if r2 >= r2_threshold:
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
if verbose >= 1:
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
else:
print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted; keep training or try a different threshold.')

self.log_history('auto_symbolic')
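A usage sketch of the two new knobs, with illustrative values and function names from pykan's symbolic library:

    # assumes a trained KAN model
    model.auto_symbolic(
        lib=['x', 'x^2', 'exp', 'sin'],
        weight_simple=0.8,   # 0.0 ranks purely by r2; larger favors simpler functions
        r2_threshold=0.9,    # leave an edge unfixed if its best fit has r2 < 0.9
    )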

26 changes: 17 additions & 9 deletions kan/.ipynb_checkpoints/spline-checkpoint.py
@@ -68,7 +68,7 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"):
Returns:
--------
y_eval : 3D torch.tensor
shape (number of samples, in_dim, out_dim)
shape (batch, in_dim, out_dim)
'''

@@ -78,16 +78,16 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"):
return y_eval


def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
def curve2coef(x_eval, y_eval, grid, k):
'''
converting B-spline curves to B-spline coefficients using least squares.
Args:
-----
x_eval : 2D torch.tensor
shape (in_dim, out_dim, number of samples)
y_eval : 2D torch.tensor
shape (in_dim, out_dim, number of samples)
shape (batch, in_dim)
y_eval : 3D torch.tensor
shape (batch, in_dim, out_dim)
grid : 2D torch.tensor
shape (in_dim, grid+2*k)
k : int
@@ -100,25 +100,33 @@ def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
coef : 3D torch.tensor
shape (in_dim, out_dim, G+k)
'''
#print('haha', x_eval.shape, y_eval.shape, grid.shape)
batch = x_eval.shape[0]
in_dim = x_eval.shape[1]
out_dim = y_eval.shape[2]
n_coef = grid.shape[1] - k - 1
mat = B_batch(x_eval, grid, k)
mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef)
#print('mat', mat.shape)
y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3)
#print('y_eval', y_eval.shape)
device = mat.device

#coef = torch.linalg.lstsq(mat, y_eval,
#driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]

#coef = torch.linalg.lstsq(mat, y_eval, driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
try:
coef = torch.linalg.lstsq(mat, y_eval).solution[:,:,:,0]
except:
print('lstsq failed')

# manual pseudo-inverse
'''lamb=1e-8
XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
A = XtX + lamb * identity
B = Xty
coef = (A.pinverse() @ B)[:,:,:,0]
coef = (A.pinverse() @ B)[:,:,:,0]'''

return coef
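The solver now calls torch.linalg.lstsq without a driver argument, and the ridge-regularized normal-equation fallback is kept as a comment: it solves (X^T X + lamb * I) c = X^T y. A cleaned-up sketch of that fallback, using torch.linalg.solve in place of the pinverse from the comment:

    import torch

    def ridge_lstsq(mat, y, lamb=1e-8):
        # mat: (..., batch, n_coef), y: (..., batch, 1) -> coef: (..., n_coef)
        XtX = mat.transpose(-2, -1) @ mat
        Xty = mat.transpose(-2, -1) @ y
        eye = torch.eye(XtX.shape[-1], device=mat.device).expand_as(XtX)
        return torch.linalg.solve(XtX + lamb * eye, Xty)[..., 0]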

9 changes: 6 additions & 3 deletions kan/.ipynb_checkpoints/utils-checkpoint.py
@@ -384,7 +384,7 @@ def augment_input(orig_vars, aux_vars, x):
return x


def batch_jacobian(func, x, create_graph=False):
def batch_jacobian(func, x, create_graph=False, mode='scalar'):
'''
jacobian
@@ -408,7 +408,10 @@ def batch_jacobian(func, x, create_graph=False):
# x in shape (Batch, Length)
def _func_sum(x):
return func(x).sum(dim=0)
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
if mode == 'scalar':
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
elif mode == 'vector':
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)
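batch_jacobian now handles vector-valued functions: mode='scalar' keeps the old behaviour for functions returning shape (batch, 1), while mode='vector' returns a per-sample Jacobian of shape (batch, out_dim, in_dim). A quick sketch, assuming batch_jacobian is importable from kan.utils:

    import torch
    from kan.utils import batch_jacobian

    def f_scalar(x):  # (batch, 3) -> (batch, 1)
        return (x ** 2).sum(dim=1, keepdim=True)

    def f_vector(x):  # (batch, 3) -> (batch, 2)
        return torch.stack([x[:, 0] * x[:, 1], x[:, 2].sin()], dim=1)

    x = torch.randn(5, 3)
    J1 = batch_jacobian(f_scalar, x, mode='scalar')  # shape (5, 3)
    J2 = batch_jacobian(f_vector, x, mode='vector')  # shape (5, 2, 3)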

def batch_hessian(model, x, create_graph=False):
'''
@@ -588,4 +591,4 @@ def model2param(model):
p = torch.tensor([]).to(model.device)
for params in model.parameters():
p = torch.cat([p, params.reshape(-1,)], dim=0)
return p
return p
34 changes: 31 additions & 3 deletions kan/KANLayer.py
@@ -109,7 +109,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_

self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable)
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * self.mask).requires_grad_(sp_trainable) # make scale trainable
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask).requires_grad_(sp_trainable) # make scale trainable
self.base_fun = base_fun


@@ -197,11 +197,13 @@ def update_grid_from_samples(self, x, mode='sample'):
def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
margin = 0.00
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]] + 2 * margin)/num_interval
grid_uniform = grid_adaptive[:,[0]] - margin + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid


grid = get_grid(num_interval)

if mode == 'grid':
@@ -210,6 +212,8 @@ def get_grid(num_interval):
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)

self.grid.data = extend_grid(grid, k_extend=self.k)
#print('x_pos 2', x_pos.shape)
#print('y_eval 2', y_eval.shape)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

def initialize_grid_from_parent(self, parent, x, mode='sample'):
@@ -240,16 +244,40 @@ def initialize_grid_from_parent(self, parent, x, mode='sample'):

batch = x.shape[0]

# shrink grid
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
num_interval = self.grid.shape[1] - 1 - 2*self.k


'''
# based on samples
def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid'''

#print('p', parent.grid)
# based on interpolating parent grid
def get_grid(num_interval):
x_pos = parent.grid[:,parent.k:-parent.k]
#print('x_pos', x_pos)
sp2 = KANLayer(in_dim=1, out_dim=self.in_dim,k=1,num=x_pos.shape[1]-1,scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device)

#print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim))
#print('sp2_coef_shape', sp2.coef.shape)
sp2_coef = curve2coef(sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim), x_pos.permute(1,0).unsqueeze(dim=2), sp2.grid[:,:], k=1).permute(1,0,2)
shp = sp2_coef.shape
#sp2_coef = torch.cat([torch.zeros(shp[0], shp[1], 1), sp2_coef, torch.zeros(shp[0], shp[1], 1)], dim=2)
#print('sp2_coef',sp2_coef)
#print(sp2.coef.shape)
sp2.coef.data = sp2_coef
percentile = torch.linspace(-1,1,self.num+1).to(self.device)
grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1,0)
#print('c', grid)
return grid

grid = get_grid(num_interval)
