From e4f3f8bf677ffced379a63d52525b51bc429ea2f Mon Sep 17 00:00:00 2001 From: Jim137 Date: Fri, 10 May 2024 12:47:05 +0800 Subject: [PATCH] Fix dtype error in `curve2coef` --- kan/spline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kan/spline.py b/kan/spline.py index 6f64d883..1992729d 100644 --- a/kan/spline.py +++ b/kan/spline.py @@ -133,6 +133,6 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"): torch.Size([5, 13]) ''' # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar - mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1) + mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1).to(y_eval.dtype) coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu')).solution[:, :, 0] # sometimes 'cuda' version may diverge return coef.to(device)