Skip to content

Commit

Permalink
fix: Add docs and typespecs (#4)
Browse files Browse the repository at this point in the history
* fix: Add documentation for public functions in Soothsayer library

* fix: Add typespecs to public functions in lib/

feat: Add typespecs notation to each public function in lib/

* fix: Add type definition for Soothsayer.Model.t()

* docs: Update Soothsayer.predict return type

* fix: Address Dialyzer warnings in add_fourier_terms/4 function

* refactor: Replace :float64 with {:f, 64} in Soothsayer.Preprocessor

* fix: Update doctest for Soothsayer.new/1

* fix: update doctest in lib/soothsayer.ex

* fix: Update Preprocessor module to use {:f, 64} instead of :float64

* feat: Add dialyxir dependency for static code analysis
  • Loading branch information
georgeguimaraes authored Sep 10, 2024
1 parent b2cf5f6 commit e88ae7f
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 5 deletions.
105 changes: 105 additions & 0 deletions lib/soothsayer.ex
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
defmodule Soothsayer do
@moduledoc """
The main module for the Soothsayer library, providing functions for creating, fitting, and using time series forecasting models.
"""

alias Explorer.DataFrame
alias Explorer.Series
alias Soothsayer.Model
alias Soothsayer.Preprocessor

@doc """
Creates a new Soothsayer model with the given configuration.
## Parameters
* `config` - A map containing the model configuration. Defaults to an empty map.
## Returns
A new `Soothsayer.Model` struct.
## Examples
iex> Soothsayer.new()
%Soothsayer.Model{config: %{trend: %{enabled: true}, seasonality: %{yearly: %{enabled: true, fourier_terms: 6}, weekly: %{enabled: true, fourier_terms: 3}}, epochs: 100, learning_rate: 0.01}, network: %Axon.Node{}, params: nil}
iex> Soothsayer.new(%{epochs: 200, learning_rate: 0.005})
%Soothsayer.Model{config: %{trend: %{enabled: true}, seasonality: %{yearly: %{enabled: true, fourier_terms: 6}, weekly: %{enabled: true, fourier_terms: 3}}, epochs: 200, learning_rate: 0.005}, network: %Axon.Node{}, params: nil}
"""
@spec new(map()) :: Soothsayer.Model.t()
def new(config \\ %{}) do
default_config = %{
trend: %{enabled: true},
Expand All @@ -19,6 +44,27 @@ defmodule Soothsayer do
Model.new(merged_config)
end

@doc """
Fits the Soothsayer model to the provided data.
## Parameters
* `model` - A `Soothsayer.Model` struct.
* `data` - An `Explorer.DataFrame` containing the training data.
## Returns
An updated `Soothsayer.Model` struct with fitted parameters.
## Examples
iex> model = Soothsayer.new()
iex> data = Explorer.DataFrame.new(%{"ds" => [...], "y" => [...]})
iex> fitted_model = Soothsayer.fit(model, data)
%Soothsayer.Model{config: %{}, network: %Axon.Node{}, params: %{}}
"""
@spec fit(Soothsayer.Model.t(), Explorer.DataFrame.t()) :: Soothsayer.Model.t()
def fit(%Model{} = model, %DataFrame{} = data) do
processed_data = Preprocessor.prepare_data(data, "y", "ds", model.config.seasonality)

Expand All @@ -43,11 +89,70 @@ defmodule Soothsayer do
}
end

@doc """
Makes predictions using a fitted Soothsayer model.
## Parameters
* `model` - A fitted `Soothsayer.Model` struct.
* `x` - An `Explorer.Series` containing the dates for which to make predictions.
## Returns
An `Nx.Tensor` containing the predicted values.
## Examples
iex> fitted_model = Soothsayer.fit(model, training_data)
iex> future_dates = Explorer.Series.from_list([~D[2023-01-01], ~D[2023-01-02], ~D[2023-01-03]])
iex> predictions = Soothsayer.predict(fitted_model, future_dates)
#Nx.Tensor<
f32[3][1]
[
[1.5],
[2.3],
[3.1]
]
>
"""
@spec predict(Soothsayer.Model.t(), Explorer.Series.t()) :: Nx.Tensor.t()
def predict(%Model{} = model, %Series{} = x) do
%{combined: combined} = predict_components(model, x)
combined
end

@doc """
Makes predictions and returns the individual components (trend, seasonality) using a fitted Soothsayer model.
## Parameters
* `model` - A fitted `Soothsayer.Model` struct.
* `x` - An `Explorer.Series` containing the dates for which to make predictions.
## Returns
A map containing the predicted values for each component (trend, yearly seasonality, weekly seasonality) and the combined prediction.
## Examples
iex> fitted_model = Soothsayer.fit(model, training_data)
iex> future_dates = Explorer.Series.from_list([~D[2023-01-01], ~D[2023-01-02], ~D[2023-01-03]])
iex> predictions = Soothsayer.predict_components(fitted_model, future_dates)
%{
combined: #Nx.Tensor<...>,
trend: #Nx.Tensor<...>,
yearly_seasonality: #Nx.Tensor<...>,
weekly_seasonality: #Nx.Tensor<...>
}
"""
@spec predict_components(Soothsayer.Model.t(), Explorer.Series.t()) :: %{
combined: Nx.Tensor.t(),
trend: Nx.Tensor.t(),
yearly_seasonality: Nx.Tensor.t(),
weekly_seasonality: Nx.Tensor.t()
}
def predict_components(%Model{} = model, %Series{} = x) do
processed_x =
Preprocessor.prepare_data(DataFrame.new(%{"ds" => x}), nil, "ds", model.config.seasonality)
Expand Down
103 changes: 103 additions & 0 deletions lib/soothsayer/model.ex
Original file line number Diff line number Diff line change
@@ -1,13 +1,61 @@
defmodule Soothsayer.Model do
@moduledoc """
Defines the structure and operations for the Soothsayer forecasting model.
"""

defstruct [:network, :params, :config]

@type t :: %__MODULE__{
network: Axon.t(),
params: term() | nil,
config: map()
}

@doc """
Creates a new Soothsayer.Model struct with the given configuration.
## Parameters
* `config` - A map containing the model configuration.
## Returns
A new `Soothsayer.Model` struct.
## Examples
iex> config = %{trend: %{enabled: true}, seasonality: %{yearly: %{enabled: true, fourier_terms: 6}}}
iex> Soothsayer.Model.new(config)
%Soothsayer.Model{network: ..., params: nil, config: ^config}
"""
@spec new(map()) :: t()
def new(config) do
%__MODULE__{
network: build_network(config),
config: config
}
end

@doc """
Builds the neural network for the Soothsayer model based on the given configuration.
## Parameters
* `config` - A map containing the model configuration.
## Returns
An Axon neural network structure.
## Examples
iex> config = %{trend: %{enabled: true}, seasonality: %{yearly: %{enabled: true, fourier_terms: 6}}}
iex> network = Soothsayer.Model.build_network(config)
#Axon.Node<...>
"""
@spec build_network(map()) :: Axon.t()
def build_network(config) do
trend_input = Axon.input("trend", shape: {nil, 1})
yearly_input = Axon.input("yearly", shape: {nil, 2 * config.seasonality.yearly.fourier_terms})
Expand Down Expand Up @@ -44,6 +92,30 @@ defmodule Soothsayer.Model do
})
end

@doc """
Fits the Soothsayer model to the provided data.
## Parameters
* `model` - A `Soothsayer.Model` struct.
* `x` - A map of input tensors.
* `y` - A tensor of target values.
* `epochs` - The number of training epochs.
## Returns
An updated `Soothsayer.Model` struct with fitted parameters.
## Examples
iex> model = Soothsayer.Model.new(config)
iex> x = %{"trend" => trend_tensor, "yearly" => yearly_tensor, "weekly" => weekly_tensor}
iex> y = target_tensor
iex> fitted_model = Soothsayer.Model.fit(model, x, y, 100)
%Soothsayer.Model{...}
"""
@spec fit(t(), %{String.t() => Nx.Tensor.t()}, Nx.Tensor.t(), non_neg_integer()) :: t()
def fit(model, x, y, epochs) do
{init_fn, _predict_fn} = Axon.build(model.network)
initial_params = init_fn.(x, %{})
Expand All @@ -63,6 +135,37 @@ defmodule Soothsayer.Model do
%{model | params: trained_params}
end

@doc """
Makes predictions using a fitted Soothsayer model.
## Parameters
* `model` - A fitted `Soothsayer.Model` struct.
* `x` - A map of input tensors.
## Returns
A map containing the predicted values for each component and the combined prediction.
## Examples
iex> fitted_model = Soothsayer.Model.fit(model, training_x, training_y, 100)
iex> x = %{"trend" => future_trend_tensor, "yearly" => future_yearly_tensor, "weekly" => future_weekly_tensor}
iex> predictions = Soothsayer.Model.predict(fitted_model, x)
%{
combined: #Nx.Tensor<...>,
trend: #Nx.Tensor<...>,
yearly_seasonality: #Nx.Tensor<...>,
weekly_seasonality: #Nx.Tensor<...>
}
"""
@spec predict(t(), %{String.t() => Nx.Tensor.t()}) :: %{
combined: Nx.Tensor.t(),
trend: Nx.Tensor.t(),
yearly_seasonality: Nx.Tensor.t(),
weekly_seasonality: Nx.Tensor.t()
}
def predict(model, x) do
{_init_fn, predict_fn} = Axon.build(model.network)
predict_fn.(model.params, x)
Expand Down
35 changes: 32 additions & 3 deletions lib/soothsayer/preprocessor.ex
Original file line number Diff line number Diff line change
@@ -1,7 +1,34 @@
defmodule Soothsayer.Preprocessor do
@moduledoc """
Provides data preprocessing functionality for the Soothsayer forecasting model.
"""

alias Explorer.DataFrame
alias Explorer.Series

@doc """
Prepares the input data by adding Fourier terms for yearly and weekly seasonality based on the provided configuration.
## Parameters
* `df` - An `Explorer.DataFrame` containing the input data.
* `y_column` - The name of the target variable column.
* `ds_column` - The name of the date column.
* `seasonality_config` - A map containing the seasonality configuration.
## Returns
An `Explorer.DataFrame` with additional columns for Fourier terms based on the seasonality configuration.
## Examples
iex> df = Explorer.DataFrame.new(%{"ds" => [...], "y" => [...]})
iex> seasonality_config = %{yearly: %{enabled: true, fourier_terms: 6}, weekly: %{enabled: true, fourier_terms: 3}}
iex> prepared_df = Soothsayer.Preprocessor.prepare_data(df, "y", "ds", seasonality_config)
#Explorer.DataFrame<...>
"""
@spec prepare_data(Explorer.DataFrame.t(), String.t() | nil, String.t(), map()) :: Explorer.DataFrame.t()
def prepare_data(df, y_column, ds_column, seasonality_config) do
df =
if seasonality_config.yearly.enabled do
Expand Down Expand Up @@ -39,16 +66,16 @@ defmodule Soothsayer.Preprocessor do
|> Series.from_list()

Series.day_of_year(date_series)
|> Series.cast(:float)
|> Series.cast({:f, 64})
|> Series.divide(days_in_year)

:weekly ->
Series.day_of_week(date_series)
|> Series.cast(:float)
|> Series.cast({:f, 64})
|> Series.divide(Series.from_list(List.duplicate(7.0, Series.size(date_series))))
end

Enum.reduce(1..fourier_terms, df, fn i, acc_df ->
result_df = Enum.reduce(1..fourier_terms, df, fn i, acc_df ->
acc_df
|> DataFrame.put(
"#{period_type}_sin_#{i}",
Expand All @@ -59,5 +86,7 @@ defmodule Soothsayer.Preprocessor do
Series.cos(t |> Series.multiply(2 * :math.pi() * i))
)
end)

result_df
end
end
3 changes: 2 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ defmodule Soothsayer.MixProject do
{:nx, "~> 0.7.3"},
{:axon, "~> 0.6.1"},
{:exla, "~> 0.7.3"},
{:ex_doc, ">= 0.0.0", only: :docs}
{:ex_doc, ">= 0.0.0", only: :docs},
{:dialyxir, "~> 1.0", only: [:dev], runtime: false}
]
end

Expand Down
2 changes: 2 additions & 0 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
"axon": {:hex, :axon, "0.6.1", "1d042fdba1c1b4413a3d65800524feebd1bc8ed218f8cdefe7a97510c3f427f3", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.6.0 or ~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "d6b0ae2f0dd284f6bf702edcab71e790d6c01ca502dd06c4070836554f5a48e1"},
"castore": {:hex, :castore, "1.0.8", "dedcf20ea746694647f883590b82d9e96014057aff1d44d03ec90f36a5c0dc6e", [:mix], [], "hexpm", "0b2b66d2ee742cb1d9cb8c8be3b43c3a70ee8651f37b75a8b982e036752983f1"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"dialyxir": {:hex, :dialyxir, "1.4.3", "edd0124f358f0b9e95bfe53a9fcf806d615d8f838e2202a9f430d59566b6b53b", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "bf2cfb75cd5c5006bec30141b131663299c661a864ec7fbbc72dfa557487a986"},
"earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
"erlex": {:hex, :erlex, "0.2.7", "810e8725f96ab74d17aac676e748627a07bc87eb950d2b83acd29dc047a30595", [:mix], [], "hexpm", "3ed95f79d1a844c3f6bf0cea61e0d5612a42ce56da9c03f01df538685365efb0"},
"ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"},
"exla": {:hex, :exla, "0.7.3", "51310270a0976974fc758f7b28ebd6ca8e099b3d6fc78b0d484c808e977cb914", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "5b3d5741a24aada21d3b0feb4b99d1fc3c8457f995a63ea16684d8d5678b96ff"},
"explorer": {:hex, :explorer, "0.9.1", "9c6f175dfd2fa2f432d5fe9a86b81875438a9a1110af5b952c284842bee434e4", [:mix], [{:adbc, "~> 0.1", [hex: :adbc, repo: "hexpm", optional: true]}, {:aws_signature, "~> 0.3", [hex: :aws_signature, repo: "hexpm", optional: false]}, {:castore, "~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:flame, "~> 0.3", [hex: :flame, repo: "hexpm", optional: true]}, {:fss, "~> 0.1", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}, {:rustler, "~> 0.34.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.7", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1 or ~> 4.0.0", [hex: :table_rex, repo: "hexpm", optional: false]}], "hexpm", "d88ec0e78f904c5eaf0b37c4a0ce4632de133515f3740a29fbddd2c0d0a78e77"},
Expand Down
1 change: 0 additions & 1 deletion test/soothsayer_test.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
defmodule SoothsayerTest do
use ExUnit.Case
doctest Soothsayer

alias Explorer.DataFrame
alias Explorer.Series
Expand Down

0 comments on commit e88ae7f

Please sign in to comment.