Skip to content

Commit

Permalink
docs: Add example in Livebook format
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeguimaraes committed Sep 8, 2024
1 parent 402547a commit 6e32816
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ components = Soothsayer.predict_components(fitted_model, Series.from_list(Enum.t
#> %{comboned: ..., trend: ..., yearly_seasonality: ..., weekly_seasonality: ...}
```

### Livebook Example

You can check the `livebook` dir with some examples on how to use Soothsayer.

### Customizing the Model

You can customize various aspects of the model:
Expand Down
222 changes: 222 additions & 0 deletions livebook/soothsayer_demo_1_1_1.livemd
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Soothsayer Demo

```elixir
# Livebook Setup

Mix.install([
{:soothsayer, path: "~/code/soothsayer"},
{:explorer, "~> 0.9.1"},
{:vega_lite, ">= 0.0.0"},
{:kino_vega_lite, ">= 0.0.0"},
{:kino_explorer, "~> 0.1.20"},
{:req, "~> 0.5.6"}
])

alias Explorer.DataFrame
alias Explorer.Series
alias VegaLite, as: Vl

Nx.global_default_backend(EXLA.Backend)
```

## Section

```elixir
# Generate synthetic data

start_date = ~D[2020-01-01]
end_date = ~D[2023-12-31]
dates = Date.range(start_date, end_date)

y =
Enum.map(dates, fn date ->
days_since_start = Date.diff(date, start_date)
trend = 1000 + 0.5 * days_since_start
yearly_seasonality = 50 * :math.sin(2 * :math.pi() * days_since_start / 365.25)
weekly_seasonality = 20 * :math.cos(2 * :math.pi() * Date.day_of_week(date) / 7)
noise = :rand.normal(0, 200)
trend + yearly_seasonality + weekly_seasonality + noise
end)

df = DataFrame.new(%{"ds" => dates, "y" => y})
```

<!-- livebook:{"attrs":"eyJjaGFydF90aXRsZSI6bnVsbCwiaGVpZ2h0Ijo2MDAsImxheWVycyI6W3siYWN0aXZlIjp0cnVlLCJjaGFydF90eXBlIjoibGluZSIsImNvbG9yX2ZpZWxkIjpudWxsLCJjb2xvcl9maWVsZF9hZ2dyZWdhdGUiOm51bGwsImNvbG9yX2ZpZWxkX2JpbiI6bnVsbCwiY29sb3JfZmllbGRfc2NhbGVfc2NoZW1lIjpudWxsLCJjb2xvcl9maWVsZF90eXBlIjpudWxsLCJkYXRhX3ZhcmlhYmxlIjoiZGYiLCJnZW9kYXRhX2NvbG9yIjoiYmx1ZSIsImxhdGl0dWRlX2ZpZWxkIjpudWxsLCJsb25naXR1ZGVfZmllbGQiOm51bGwsInhfZmllbGQiOiJkcyIsInhfZmllbGRfYWdncmVnYXRlIjpudWxsLCJ4X2ZpZWxkX2JpbiI6bnVsbCwieF9maWVsZF9zY2FsZV90eXBlIjpudWxsLCJ4X2ZpZWxkX3R5cGUiOiJ0ZW1wb3JhbCIsInlfZmllbGQiOiJ5IiwieV9maWVsZF9hZ2dyZWdhdGUiOm51bGwsInlfZmllbGRfYmluIjpudWxsLCJ5X2ZpZWxkX3NjYWxlX3R5cGUiOm51bGwsInlfZmllbGRfdHlwZSI6InF1YW50aXRhdGl2ZSJ9XSwidmxfYWxpYXMiOiJFbGl4aXIuVmwiLCJ3aWR0aCI6ODAwfQ","chunks":null,"kind":"Elixir.KinoVegaLite.ChartCell","livebook_object":"smart_cell"} -->

```elixir
Vl.new(width: 800, height: 600)
|> Vl.data_from_values(df, only: ["ds", "y"])
|> Vl.mark(:line)
|> Vl.encode_field(:x, "ds", type: :temporal)
|> Vl.encode_field(:y, "y", type: :quantitative)
```

```elixir
# Create and fit Soothsayer model

model =
Soothsayer.new(%{
trend_config: %{enabled: true},
seasonality_config: %{
yearly: %{enabled: true},
weekly: %{enabled: true}
},
epochs: 10
})

fitted_model = Soothsayer.fit(model, df)
```

```elixir
# Make predictions and extract components
predictions = Soothsayer.predict(fitted_model, df["ds"])

df_with_predictions =
df
|> DataFrame.put("yhat", predictions)

df_with_predictions
```

```elixir
Vl.new(width: 800, height: 600)
|> Vl.data_from_values(df_with_predictions, only: ["ds", "y", "yhat"])
|> Vl.layers([
Vl.new()
|> Vl.mark(:line)
|> Vl.encode_field(:x, "ds", type: :temporal)
|> Vl.encode_field(:y, "y", type: :quantitative),
Vl.new()
|> Vl.mark(:line, color: "tomato")
|> Vl.encode_field(:x, "ds", type: :temporal)
|> Vl.encode_field(:y, "yhat", type: :quantitative)
])
```

```elixir
Vl.new(width: 800, height: 600)
|> Vl.data_from_values(df_with_predictions, only: ["ds", "y", "yhat"])
|> Vl.layers([
Vl.new()
|> Vl.mark(:line)
|> Vl.encode_field(:x, "ds",
type: :temporal,
scale: [domain: ["2023-05-01", "2023-06-30"]]
)
|> Vl.encode_field(:y, "y", type: :quantitative),
Vl.new()
|> Vl.mark(:line, color: "tomato")
|> Vl.encode_field(:x, "ds",
type: :temporal,
scale: [domain: ["2023-05-01", "2023-06-30"]]
)
|> Vl.encode_field(:y, "yhat", type: :quantitative)
])
```

```elixir
input = %{
"trend" => Nx.template({1, 1}, :f32),
"yearly" => Nx.template({1, 12}, :f32), # Assuming 6 Fourier terms for yearly seasonality
"weekly" => Nx.template({1, 6}, :f32) # Assuming 3 Fourier terms for weekly seasonality
}

# Display the graph
Axon.Display.as_graph(model.network, input)
```

```elixir
components = Soothsayer.predict_components(fitted_model, df["ds"])

df_with_components =
df
|> DataFrame.put("yhat", predictions)
|> DataFrame.put("trend", components.trend)
|> DataFrame.put("yearly_seasonality", components.yearly_seasonality)
|> DataFrame.put("weekly_seasonality", components.weekly_seasonality)
```

```elixir
Vl.new(width: 800, height: 600)
|> Vl.data_from_values(df_with_components)
|> Vl.layers([
Vl.new()
|> Vl.mark(:line)
|> Vl.encode_field(:x, "ds", type: :temporal)
|> Vl.encode_field(:y, "trend", type: :quantitative),
])
```

```elixir
Vl.new(width: 800, height: 600)
|> Vl.data_from_values(df_with_components)
|> Vl.layers([
Vl.new()
|> Vl.mark(:line)
|> Vl.encode_field(:x, "ds", type: :temporal)
|> Vl.encode_field(:y, "yearly_seasonality", type: :quantitative),
])
```

```elixir
Vl.new(width: 800, height: 600)
|> Vl.data_from_values(df_with_components)
|> Vl.layers([
Vl.new()
|> Vl.mark(:line)
|> Vl.encode_field(:x, "ds", type: :temporal, scale: [domain: ["2023-01-01", "2023-06-30"]])
|> Vl.encode_field(:y, "weekly_seasonality", type: :quantitative, scale: [domain: [1000, 1400]])
])
```

## Real data

Daily energy price data over the 4 years from Spain.

```elixir
real_df = DataFrame.from_csv!("https://raw.githubusercontent.com/ourownstory/neuralprophet-data/main/kaggle-energy/datasets/tutorial01.csv")

real_df =
real_df
|> DataFrame.put("ds", real_df["ds"] |> Explorer.Series.cast(:date))
```

<!-- livebook:{"attrs":"eyJjaGFydF90aXRsZSI6bnVsbCwiaGVpZ2h0Ijo2MDAsImxheWVycyI6W3siYWN0aXZlIjp0cnVlLCJjaGFydF90eXBlIjoicG9pbnQiLCJjb2xvcl9maWVsZCI6bnVsbCwiY29sb3JfZmllbGRfYWdncmVnYXRlIjpudWxsLCJjb2xvcl9maWVsZF9iaW4iOm51bGwsImNvbG9yX2ZpZWxkX3NjYWxlX3NjaGVtZSI6bnVsbCwiY29sb3JfZmllbGRfdHlwZSI6bnVsbCwiZGF0YV92YXJpYWJsZSI6InJlYWxfZGYiLCJnZW9kYXRhX2NvbG9yIjoiYmx1ZSIsImxhdGl0dWRlX2ZpZWxkIjpudWxsLCJsb25naXR1ZGVfZmllbGQiOm51bGwsInhfZmllbGQiOiJkcyIsInhfZmllbGRfYWdncmVnYXRlIjpudWxsLCJ4X2ZpZWxkX2JpbiI6bnVsbCwieF9maWVsZF9zY2FsZV90eXBlIjpudWxsLCJ4X2ZpZWxkX3R5cGUiOiJ0ZW1wb3JhbCIsInlfZmllbGQiOiJ5IiwieV9maWVsZF9hZ2dyZWdhdGUiOm51bGwsInlfZmllbGRfYmluIjpudWxsLCJ5X2ZpZWxkX3NjYWxlX3R5cGUiOm51bGwsInlfZmllbGRfdHlwZSI6InF1YW50aXRhdGl2ZSJ9XSwidmxfYWxpYXMiOiJFbGl4aXIuVmwiLCJ3aWR0aCI6ODAwfQ","chunks":null,"kind":"Elixir.KinoVegaLite.ChartCell","livebook_object":"smart_cell"} -->

```elixir
Vl.new(width: 800, height: 600)
|> Vl.data_from_values(real_df, only: ["ds", "y"])
|> Vl.mark(:point)
|> Vl.encode_field(:x, "ds", type: :temporal)
|> Vl.encode_field(:y, "y", type: :quantitative)
```

```elixir
model = Soothsayer.new(%{epochs: 100})
fitted_model = Soothsayer.fit(model, real_df)
```

```elixir
predictions = Soothsayer.predict(fitted_model, real_df["ds"])

real_df_with_predictions =
real_df
|> DataFrame.put("yhat", predictions)

real_df_with_predictions
```

```elixir
Vl.new(width: 800, height: 600)
|> Vl.data_from_values(real_df_with_predictions, only: ["ds", "y", "yhat"])
|> Vl.layers([
Vl.new()
|> Vl.mark(:line)
|> Vl.encode_field(:x, "ds", type: :temporal)
|> Vl.encode_field(:y, "y", type: :quantitative),
Vl.new()
|> Vl.mark(:line, color: "tomato")
|> Vl.encode_field(:x, "ds", type: :temporal)
|> Vl.encode_field(:y, "yhat", type: :quantitative)
])
```

0 comments on commit 6e32816

Please sign in to comment.