Skip to content

Commit

Permalink
add KAN back
Browse files Browse the repository at this point in the history
  • Loading branch information
KindXiaoming committed Jul 8, 2024
1 parent 50529bb commit 8075a4e
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 10 deletions.
6 changes: 3 additions & 3 deletions kan/MultKAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def get_range(self, l, i, j, verbose=True):
print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']')
return x_min, x_max, y_min, y_max

def plot(self, folder="./figures", beta=3, mask=False, metric='fa', scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None):
def plot(self, folder="./figures", beta=3, mask=False, metric='fa', scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None, varscale=1.0):

global Symbol

Expand Down Expand Up @@ -694,7 +694,7 @@ def score2alpha(score):
n = self.width_in[0]
for i in range(n):
if isinstance(in_vars[i], sympy.Expr):
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, f'${latex(in_vars[i])}$', fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, f'${latex(in_vars[i])}$', fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center')
else:
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i], fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')

Expand All @@ -706,7 +706,7 @@ def score2alpha(score):
if isinstance(out_vars[i], sympy.Expr):
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0+z0) * (len(self.width) - 1) + 0.15, f'${latex(out_vars[i])}$', fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')
else:
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0+z0) * (len(self.width) - 1) + 0.15, out_vars[i], fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0+z0) * (len(self.width) - 1) + 0.15, out_vars[i], fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center')

if title != None:
plt.gcf().get_axes()[0].text(0.5, (y0+z0) * (len(self.width) - 1) + 0.3, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')
Expand Down
1 change: 1 addition & 0 deletions kan/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .MultKAN import *
from .KAN import *
20 changes: 19 additions & 1 deletion kan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sklearn.linear_model import LinearRegression
import sympy
import yaml

from sympy.utilities.lambdify import lambdify

# sigmoid = sympy.Function('sigmoid')
# name: (torch implementation, sympy implementation)
Expand Down Expand Up @@ -303,3 +303,21 @@ def ex_round(ex1, n_digit):
if isinstance(a, sympy.Float):
ex2 = ex2.subs(a, round(a, n_digit))
return ex2


def augment_input(orig_vars, aux_vars, x):

# if x is a tensor
if isinstance(x, torch.Tensor):

for aux_var in aux_vars:
func = lambdify(orig_vars, aux_var,'numpy') # returns a numpy-ready function
aux_value = torch.from_numpy(func(*[x[:,[i]].numpy() for i in range(len(orig_vars))]))
x = torch.cat([x, aux_value], dim=1)

# if x is a dataset
elif isinstance(x, dict):
x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input'])
x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input'])

return x
Binary file removed pde/Saloniki_June24.jpeg
Binary file not shown.
Binary file modified sr/figures/sp_0_0_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified sr/figures/sp_0_0_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified sr/figures/sp_0_1_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified sr/figures/sp_0_1_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "5a8b091f",
"id": "f6900184",
"metadata": {},
"outputs": [
{
Expand All @@ -106,7 +106,7 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "0ba4fcef",
"id": "c3a785ae",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -119,7 +119,7 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "60c879b7",
"id": "10d710ec",
"metadata": {},
"outputs": [
{
Expand Down
6 changes: 3 additions & 3 deletions tutorials/MultKAN_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "5a8b091f",
"id": "f6900184",
"metadata": {},
"outputs": [
{
Expand All @@ -106,7 +106,7 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "0ba4fcef",
"id": "c3a785ae",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -119,7 +119,7 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "60c879b7",
"id": "10d710ec",
"metadata": {},
"outputs": [
{
Expand Down

0 comments on commit 8075a4e

Please sign in to comment.