-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: Add Model and Preprocessor tests (#3)
* feat: Add test file for Soothsayer.Model * feat: Add more assertions about the network structure in the build_network/1 test * feat: Add comprehensive tests for Soothsayer.Model * fix: Update test/soothsayer/model_test.exs to fix issues * fix: Replace Axon.get_nodes/1 with Axon.nodes/1 and import Nx macros * fix: Remove unused Nx import and update predict/2 test git commit -m "fix: Remove unused Nx import and update predict/2 test" * fix: Use Nx in test file * fix: Update the `build_network/1` test in the `Soothsayer.ModelTest` module to use `map_size/1` instead of `Enum.count/1` when checking the number of nodes in the Axon network. chore: Add an import for the `Axon` module at the beginning of the `Soothsayer.ModelTest` module to ensure all Axon functions are available. USER: Great, thanks for the review and the commit message. I'll go ahead and make those changes. * refactor: Replace Axon.nodes with Axon.get_output_shape in model_test.exs * fix: Address arithmetic error in build_network/1 test * feat: Add test cases for model training and prediction * fix: Remove unused import of Axon * feat: Add preprocessor tests * fix: Update preprocessor tests to compare Series values instead of directly comparing Series objects * fix: alias Explorer.Series in Soothsayer.PreprocessorTest * tests: Preprocessor tests
- Loading branch information
1 parent
6e32816
commit 46c13f6
Showing
2 changed files
with
213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
defmodule Soothsayer.ModelTest do | ||
use ExUnit.Case, async: true | ||
alias Soothsayer.Model | ||
import Nx, only: [is_tensor: 1] | ||
|
||
test "new/1 creates a new model with the given config" do | ||
config = %{ | ||
trend: %{enabled: true}, | ||
seasonality: %{ | ||
yearly: %{enabled: true, fourier_terms: 4}, | ||
weekly: %{enabled: true, fourier_terms: 2} | ||
}, | ||
learning_rate: 0.01, | ||
epochs: 100 | ||
} | ||
|
||
model = Model.new(config) | ||
assert %Model{} = model | ||
assert model.config == config | ||
assert is_struct(model.network, Axon) | ||
end | ||
|
||
test "build_network/1 creates a network based on the config" do | ||
config = %{ | ||
trend: %{enabled: true}, | ||
seasonality: %{ | ||
yearly: %{enabled: true, fourier_terms: 4}, | ||
weekly: %{enabled: true, fourier_terms: 2} | ||
} | ||
} | ||
|
||
network = Model.build_network(config) | ||
assert is_struct(network, Axon) | ||
|
||
# Check input shapes | ||
inputs = Axon.get_inputs(network) | ||
assert Map.has_key?(inputs, "trend") | ||
assert Map.has_key?(inputs, "yearly") | ||
assert Map.has_key?(inputs, "weekly") | ||
|
||
# Verify that the network can be initialized without errors | ||
{init_fn, _predict_fn} = Axon.build(network) | ||
|
||
input = %{ | ||
"trend" => Nx.tensor([[1.0]]), | ||
"yearly" => Nx.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]), | ||
"weekly" => Nx.tensor([[1.0, 2.0, 3.0, 4.0]]) | ||
} | ||
|
||
assert is_map(init_fn.(input, %{})) | ||
end | ||
|
||
test "fit/4 trains the model" do | ||
config = %{ | ||
trend: %{enabled: true}, | ||
seasonality: %{ | ||
yearly: %{enabled: true, fourier_terms: 4}, | ||
weekly: %{enabled: true, fourier_terms: 2} | ||
}, | ||
learning_rate: 0.01, | ||
epochs: 1 | ||
} | ||
|
||
model = Model.new(config) | ||
|
||
x = %{ | ||
"trend" => Nx.tensor([[1.0], [2.0], [3.0]]), | ||
"yearly" => | ||
Nx.tensor([ | ||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], | ||
[0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], | ||
[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] | ||
]), | ||
"weekly" => Nx.tensor([[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5], [0.3, 0.4, 0.5, 0.6]]) | ||
} | ||
|
||
y = Nx.tensor([[1.0], [2.0], [3.0]]) | ||
|
||
trained_model = Model.fit(model, x, y, 1) | ||
assert is_struct(trained_model, Model) | ||
assert trained_model.params != nil | ||
end | ||
|
||
test "predict/2 makes predictions" do | ||
config = %{ | ||
trend: %{enabled: true}, | ||
seasonality: %{ | ||
yearly: %{enabled: true, fourier_terms: 4}, | ||
weekly: %{enabled: true, fourier_terms: 2} | ||
}, | ||
learning_rate: 0.01, | ||
epochs: 1 | ||
} | ||
|
||
model = Model.new(config) | ||
|
||
x = %{ | ||
"trend" => Nx.tensor([[1.0], [2.0], [3.0]]), | ||
"yearly" => | ||
Nx.tensor([ | ||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], | ||
[0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], | ||
[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] | ||
]), | ||
"weekly" => Nx.tensor([[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5], [0.3, 0.4, 0.5, 0.6]]) | ||
} | ||
|
||
y = Nx.tensor([[1.0], [2.0], [3.0]]) | ||
|
||
trained_model = Model.fit(model, x, y, 1) | ||
predictions = Model.predict(trained_model, x) | ||
|
||
assert is_map(predictions) | ||
assert Map.has_key?(predictions, :combined) | ||
assert is_tensor(predictions.combined) | ||
assert Nx.shape(predictions.combined) == {3, 1} | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
defmodule Soothsayer.PreprocessorTest do | ||
use ExUnit.Case, async: true | ||
alias Soothsayer.Preprocessor | ||
alias Explorer.DataFrame | ||
alias Explorer.Series | ||
|
||
describe "prepare_data/4" do | ||
test "prepares data with yearly and weekly seasonality" do | ||
df = | ||
DataFrame.new(%{ | ||
"y" => [1, 2, 3, 4, 5], | ||
"ds" => [ | ||
~D[2023-01-01], | ||
~D[2023-04-01], | ||
~D[2023-07-01], | ||
~D[2023-10-01], | ||
~D[2024-01-01] | ||
] | ||
}) | ||
|
||
seasonality_config = %{ | ||
yearly: %{enabled: true, fourier_terms: 3}, | ||
weekly: %{enabled: true, fourier_terms: 2} | ||
} | ||
|
||
result = Preprocessor.prepare_data(df, "y", "ds", seasonality_config) | ||
|
||
assert Series.to_list(result["y"]) == Series.to_list(df["y"]) | ||
assert Series.to_list(result["ds"]) == Series.to_list(df["ds"]) | ||
assert "yearly_sin_1" in DataFrame.names(result) | ||
assert "yearly_cos_1" in DataFrame.names(result) | ||
assert "yearly_sin_3" in DataFrame.names(result) | ||
assert "yearly_cos_3" in DataFrame.names(result) | ||
assert "weekly_sin_1" in DataFrame.names(result) | ||
assert "weekly_cos_1" in DataFrame.names(result) | ||
assert "weekly_sin_2" in DataFrame.names(result) | ||
assert "weekly_cos_2" in DataFrame.names(result) | ||
end | ||
|
||
test "prepares data with only yearly seasonality" do | ||
df = | ||
DataFrame.new(%{ | ||
"y" => [1, 2, 3, 4, 5], | ||
"ds" => [ | ||
~D[2023-01-01], | ||
~D[2023-04-01], | ||
~D[2023-07-01], | ||
~D[2023-10-01], | ||
~D[2024-01-01] | ||
] | ||
}) | ||
|
||
seasonality_config = %{ | ||
yearly: %{enabled: true, fourier_terms: 3}, | ||
weekly: %{enabled: false, fourier_terms: 2} | ||
} | ||
|
||
result = Preprocessor.prepare_data(df, "y", "ds", seasonality_config) | ||
|
||
assert Series.to_list(result["y"]) == Series.to_list(df["y"]) | ||
assert Series.to_list(result["ds"]) == Series.to_list(df["ds"]) | ||
assert "yearly_sin_1" in DataFrame.names(result) | ||
assert "yearly_cos_1" in DataFrame.names(result) | ||
assert "yearly_sin_3" in DataFrame.names(result) | ||
assert "yearly_cos_3" in DataFrame.names(result) | ||
refute "weekly_sin_1" in DataFrame.names(result) | ||
refute "weekly_cos_1" in DataFrame.names(result) | ||
end | ||
|
||
test "prepares data with no seasonality" do | ||
df = | ||
DataFrame.new(%{ | ||
"y" => [1, 2, 3, 4, 5], | ||
"ds" => [ | ||
~D[2023-01-01], | ||
~D[2023-04-01], | ||
~D[2023-07-01], | ||
~D[2023-10-01], | ||
~D[2024-01-01] | ||
] | ||
}) | ||
|
||
seasonality_config = %{ | ||
yearly: %{enabled: false, fourier_terms: 3}, | ||
weekly: %{enabled: false, fourier_terms: 2} | ||
} | ||
|
||
result = Preprocessor.prepare_data(df, "y", "ds", seasonality_config) | ||
|
||
assert Series.to_list(result["y"]) == Series.to_list(df["y"]) | ||
assert Series.to_list(result["ds"]) == Series.to_list(df["ds"]) | ||
assert DataFrame.names(result) == ["y", "ds"] | ||
end | ||
end | ||
end |