Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
KindXiaoming committed Apr 29, 2024
1 parent 0cc83b4 commit bfe4e84
Show file tree
Hide file tree
Showing 706 changed files with 13,032 additions and 209,745 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ __pycache__/
docs/_build/
docs/_static/
docs/_templates
folder
68 changes: 54 additions & 14 deletions docs/API_demo/.ipynb_checkpoints/API_9_video-checkpoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,35 @@
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 6.39e-03 | test loss: 6.40e-03 | reg: 7.91e+00 : 100%|██| 50/50 [01:30<00:00, 1.81s/it]\n"
"train loss: 5.89e-03 | test loss: 5.99e-03 | reg: 7.89e+00 : 100%|██| 50/50 [01:36<00:00, 1.92s/it]\n"
]
},
}
],
"source": [
"from kan import KAN, create_dataset\n",
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[4,2,1,1], grid=3, k=3, seed=0)\n",
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n",
"dataset = create_dataset(f, n_var=4, train_num=3000)\n",
"\n",
"image_folder = 'video_img'\n",
"\n",
"# train the model\n",
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n",
"model.train(dataset, opt=\"LBFGS\", steps=50, lamb=5e-5, lamb_entropy=2., save_fig=True, beta=10, \n",
" in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n",
" out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n",
" img_folder=image_folder);\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c18245a3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
Expand All @@ -49,21 +75,35 @@
}
],
"source": [
"from kan import KAN, create_dataset\n",
"import torch\n",
"import os\n",
"import numpy as np\n",
"import moviepy.video.io.ImageSequenceClip # moviepy == 1.0.3\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[4,2,1,1], grid=3, k=3, seed=0)\n",
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n",
"dataset = create_dataset(f, n_var=4, train_num=3000)\n",
"video_name='video'\n",
"fps=5\n",
"\n",
"# train the model\n",
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n",
"model.train(dataset, opt=\"LBFGS\", steps=50, lamb=5e-5, lamb_entropy=2., save_video=True, beta=10, \n",
" in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n",
" out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n",
" video_name='video', fps=5);"
"fps = fps\n",
"files = os.listdir(image_folder)\n",
"train_index = []\n",
"for file in files:\n",
" if file[0].isdigit() and file.endswith('.jpg'):\n",
" train_index.append(int(file[:-4]))\n",
"\n",
"train_index = np.sort(train_index)\n",
"\n",
"image_files = [image_folder+'/'+str(train_index[index])+'.jpg' for index in train_index]\n",
"\n",
"clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)\n",
"clip.write_videofile(video_name+'.mp4')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88d0d737",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
68 changes: 54 additions & 14 deletions docs/API_demo/API_9_video.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,35 @@
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 6.39e-03 | test loss: 6.40e-03 | reg: 7.91e+00 : 100%|██| 50/50 [01:30<00:00, 1.81s/it]\n"
"train loss: 5.89e-03 | test loss: 5.99e-03 | reg: 7.89e+00 : 100%|██| 50/50 [01:36<00:00, 1.92s/it]\n"
]
},
}
],
"source": [
"from kan import KAN, create_dataset\n",
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[4,2,1,1], grid=3, k=3, seed=0)\n",
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n",
"dataset = create_dataset(f, n_var=4, train_num=3000)\n",
"\n",
"image_folder = 'video_img'\n",
"\n",
"# train the model\n",
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n",
"model.train(dataset, opt=\"LBFGS\", steps=50, lamb=5e-5, lamb_entropy=2., save_fig=True, beta=10, \n",
" in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n",
" out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n",
" img_folder=image_folder);\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c18245a3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
Expand All @@ -49,21 +75,35 @@
}
],
"source": [
"from kan import KAN, create_dataset\n",
"import torch\n",
"import os\n",
"import numpy as np\n",
"import moviepy.video.io.ImageSequenceClip # moviepy == 1.0.3\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[4,2,1,1], grid=3, k=3, seed=0)\n",
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n",
"dataset = create_dataset(f, n_var=4, train_num=3000)\n",
"video_name='video'\n",
"fps=5\n",
"\n",
"# train the model\n",
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n",
"model.train(dataset, opt=\"LBFGS\", steps=50, lamb=5e-5, lamb_entropy=2., save_video=True, beta=10, \n",
" in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n",
" out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n",
" video_name='video', fps=5);"
"fps = fps\n",
"files = os.listdir(image_folder)\n",
"train_index = []\n",
"for file in files:\n",
" if file[0].isdigit() and file.endswith('.jpg'):\n",
" train_index.append(int(file[:-4]))\n",
"\n",
"train_index = np.sort(train_index)\n",
"\n",
"image_files = [image_folder+'/'+str(train_index[index])+'.jpg' for index in train_index]\n",
"\n",
"clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)\n",
"clip.write_videofile(video_name+'.mp4')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88d0d737",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Binary file removed docs/API_demo/figures/sp_0_0_0.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_0_1.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_0_2.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_1_0.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_1_1.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_1_2.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_2_0.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_2_1.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_3_0.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_0_3_1.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_1_0_0.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_1_0_1.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_1_1_0.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_1_1_1.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_1_2_0.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_1_2_1.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_2_0_0.png
Binary file not shown.
Binary file removed docs/API_demo/figures/sp_2_1_0.png
Binary file not shown.
File renamed without changes.
File renamed without changes.
Binary file removed docs/example_files/example_1_1.png
Binary file not shown.
Binary file removed docs/model_ckpt/ckpt1
Binary file not shown.
Binary file removed docs/model_ckpt/ckpt2
Binary file not shown.
Binary file removed docs/model_ckpt/ckpt3
Binary file not shown.
Binary file removed docs/video.mp4
Binary file not shown.
Binary file removed docs/video/0.jpg
Binary file not shown.
Binary file removed docs/video/1.jpg
Binary file not shown.
Binary file removed docs/video/10.jpg
Binary file not shown.
Binary file removed docs/video/11.jpg
Binary file not shown.
Binary file removed docs/video/12.jpg
Binary file not shown.
Binary file removed docs/video/13.jpg
Binary file not shown.
Binary file removed docs/video/14.jpg
Diff not rendered.
Binary file removed docs/video/15.jpg
Diff not rendered.
Binary file removed docs/video/16.jpg
Diff not rendered.
Binary file removed docs/video/17.jpg
Diff not rendered.
Binary file removed docs/video/18.jpg
Diff not rendered.
Binary file removed docs/video/19.jpg
Diff not rendered.
Binary file removed docs/video/2.jpg
Diff not rendered.
Binary file removed docs/video/20.jpg
Diff not rendered.
Binary file removed docs/video/21.jpg
Diff not rendered.
Binary file removed docs/video/22.jpg
Diff not rendered.
Binary file removed docs/video/23.jpg
Diff not rendered.
Binary file removed docs/video/24.jpg
Diff not rendered.
Binary file removed docs/video/25.jpg
Diff not rendered.
Binary file removed docs/video/26.jpg
Diff not rendered.
Binary file removed docs/video/27.jpg
Diff not rendered.
Binary file removed docs/video/28.jpg
Diff not rendered.
Binary file removed docs/video/29.jpg
Diff not rendered.
Binary file removed docs/video/3.jpg
Diff not rendered.
Binary file removed docs/video/30.jpg
Diff not rendered.
Binary file removed docs/video/31.jpg
Diff not rendered.
Binary file removed docs/video/32.jpg
Diff not rendered.
Binary file removed docs/video/33.jpg
Diff not rendered.
Binary file removed docs/video/34.jpg
Diff not rendered.
Binary file removed docs/video/35.jpg
Diff not rendered.
Binary file removed docs/video/36.jpg
Diff not rendered.
Binary file removed docs/video/37.jpg
Diff not rendered.
Binary file removed docs/video/38.jpg
Diff not rendered.
Binary file removed docs/video/39.jpg
Diff not rendered.
Binary file removed docs/video/4.jpg
Diff not rendered.
Binary file removed docs/video/40.jpg
Diff not rendered.
Binary file removed docs/video/41.jpg
Diff not rendered.
Binary file removed docs/video/42.jpg
Diff not rendered.
Binary file removed docs/video/43.jpg
Diff not rendered.
Binary file removed docs/video/44.jpg
Diff not rendered.
Binary file removed docs/video/45.jpg
Diff not rendered.
Binary file removed docs/video/46.jpg
Diff not rendered.
Binary file removed docs/video/47.jpg
Diff not rendered.
Binary file removed docs/video/48.jpg
Diff not rendered.
Binary file removed docs/video/49.jpg
Diff not rendered.
Binary file removed docs/video/5.jpg
Diff not rendered.
Binary file removed docs/video/6.jpg
Diff not rendered.
Binary file removed docs/video/7.jpg
Diff not rendered.
Binary file removed docs/video/8.jpg
Diff not rendered.
Binary file removed docs/video/9.jpg
Diff not rendered.
Binary file removed docs/video/sp_0_0_0.png
Diff not rendered.
Binary file removed docs/video/sp_0_0_1.png
Diff not rendered.
Binary file removed docs/video/sp_0_1_0.png
Diff not rendered.
Binary file removed docs/video/sp_0_1_1.png
Diff not rendered.
Binary file removed docs/video/sp_0_2_0.png
Diff not rendered.
Binary file removed docs/video/sp_0_2_1.png
Diff not rendered.
Binary file removed docs/video/sp_0_3_0.png
Diff not rendered.
Binary file removed docs/video/sp_0_3_1.png
Diff not rendered.
Binary file removed docs/video/sp_1_0_0.png
Diff not rendered.
Binary file removed docs/video/sp_1_1_0.png
Diff not rendered.
Binary file removed docs/video/sp_2_0_0.png
Diff not rendered.
Binary file removed figures/sp_0_0_0.png
Diff not rendered.
Binary file removed figures/sp_0_0_1.png
Diff not rendered.
Binary file removed figures/sp_0_0_2.png
Diff not rendered.
Binary file removed figures/sp_0_0_3.png
Diff not rendered.
Binary file removed figures/sp_0_0_4.png
Diff not rendered.
Binary file removed figures/sp_0_1_0.png
Diff not rendered.
Binary file removed figures/sp_0_1_1.png
Diff not rendered.
Binary file removed figures/sp_0_1_2.png
Diff not rendered.
Binary file removed figures/sp_0_1_3.png
Diff not rendered.
Binary file removed figures/sp_0_1_4.png
Diff not rendered.
Binary file removed figures/sp_1_0_0.png
Diff not rendered.
Binary file removed figures/sp_1_1_0.png
Diff not rendered.
Binary file removed figures/sp_1_2_0.png
Diff not rendered.
Binary file removed figures/sp_1_3_0.png
Diff not rendered.
Binary file removed figures/sp_1_4_0.png
Diff not rendered.
37 changes: 7 additions & 30 deletions kan/KAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from tqdm import tqdm
import random
import copy
import moviepy.video.io.ImageSequenceClip


class KAN(nn.Module):
Expand Down Expand Up @@ -773,7 +772,7 @@ def score2alpha(score):



def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1 = 1., lamb_entropy = 2., lamb_coef = 0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn = None, lr=1., stop_grid_update_step=50, batch=-1, small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_video=False, in_vars=None, out_vars=None, video_folder='./video', beta=3, save_fig_freq=1, video_name='video', fps=1, device='cpu'):
def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1 = 1., lamb_entropy = 2., lamb_coef = 0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn = None, lr=1., stop_grid_update_step=50, batch=-1, small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu'):
'''
training
Expand Down Expand Up @@ -813,10 +812,6 @@ def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1 = 1., l
device
save_fig_freq : int
save figure every (save_fig_freq) step
video_name : str
the name of the video, default 'video'
fps : int
frame per second, default fps=1
Returns:
--------
Expand Down Expand Up @@ -905,9 +900,9 @@ def closure():
objective.backward()
return objective

if save_video:
if not os.path.exists(video_folder):
os.makedirs(video_folder)
if save_fig:
if not os.path.exists(img_folder):
os.makedirs(img_folder)

for _ in pbar:

Expand Down Expand Up @@ -950,31 +945,13 @@ def closure():
results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
results['reg'].append(reg_.cpu().detach().numpy())

if save_video and _ % save_fig_freq == 0:
if save_fig and _ % save_fig_freq == 0:

self.plot(folder=video_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
plt.savefig(video_folder+'/'+str(_)+'.jpg', bbox_inches='tight', dpi=200)
self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
plt.savefig(img_folder+'/'+str(_)+'.jpg', bbox_inches='tight', dpi=200)
plt.close()


if save_video:
image_folder = video_folder
#video_name = 'haha'
fps = fps
files = os.listdir(image_folder)
train_index = []
for file in files:
if file[0].isdigit() and file.endswith('.jpg'):
train_index.append(int(file[:-4]))

train_index = np.sort(train_index)

image_files = [image_folder+'/'+str(train_index[index])+'.jpg' for index in train_index]

clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)
clip.write_videofile(video_name+'.mp4')
#print('saving video to',video_name+'.mp4')

return results


Expand Down
49 changes: 49 additions & 0 deletions pykan.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,52 @@ Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE

<img width="600" alt="kan_plot" src="https://github.com/KindXiaoming/pykan/assets/23551623/a2d2d225-b4d2-4c1e-823e-bc45c7ea96f9">

# Kolmogorov-Arnold Newtworks (KANs)

This the github repo for the paper "KAN: Kolmogorov-Arnold Networks" [link]. The documentation can be found here [link].

Kolmogorov-Arnold Networks (KANs) are promising alternatives of Multi-Layer Perceptrons (MLPs). KANs have strong mathematical foundations just like MLPs: MLPs are based on the [universal approximation theorem](https://en.wikipedia.org/wiki/Universal_approximation_theorem), while KANs are based on [Kolmogorov-Arnold representation theorem](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Arnold_representation_theorem). KANs and MLPs are dual: KANs have activation functions on edges, while MLPs have activation functions on nodes. This simple change makes KANs better (sometimes much better!) than MLPs in terms of both model accuracy and interpretability.

<img width="1163" alt="mlp_kan_compare" src="https://github.com/KindXiaoming/pykan/assets/23551623/695adc2d-0d0b-4e4b-bcff-db2c8070f841">

## Installation
There are two ways to install pykan, through pypi or github.

**Installation via github**

```python
git clone https://github.com/KindXiaoming/pykan.git
cd pykan
pip install -e .
```

**Installation via pypi (soon)**

```python
pip install pykan
```


To install requirements:
```python
pip install -r requirements.txt
```

## Documentation
The documenation can be found here [].

## Tutorials

**Quickstart**

Get started with [hellokan.ipynb](./hellokan.ipynb) notebook

**More demos**

Jupyter Notebooks in [docs/Examples](./docs/Examples) and [docs/API_demo](./docs/API\_demo) are ready to play. You may also find these examples in documentation.


1 change: 1 addition & 0 deletions pykan.egg-info/SOURCES.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
LICENSE
README.md
setup.py
kan/KAN.py
Expand Down
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
matplotlib==3.6.2
moviepy==1.0.3
numpy==1.24.4
scikit_learn==1.1.3
setuptools==65.5.0
sphinx_rtd_theme==2.0.0
sympy==1.11.1
torch==2.2.2
tqdm==4.66.2
Loading

0 comments on commit bfe4e84

Please sign in to comment.