Skip to content

Commit

Permalink
Improvements to mode estimation docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Jun 6, 2024
1 parent 9c74411 commit 303e248
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 54 deletions.
88 changes: 44 additions & 44 deletions tutorials/docs-17-mode-estimation/Manifest.toml
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ version = "0.6.1"

[[deps.AdvancedMH]]
deps = ["AbstractMCMC", "Distributions", "FillArrays", "LinearAlgebra", "LogDensityProblems", "Random", "Requires"]
git-tree-sha1 = "16589dbdd36c782ff01700908e962b303474f641"
git-tree-sha1 = "fa4e8d6f9bae913aaa40224cf9407163e693d829"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.8.1"
version = "0.8.2"
weakdeps = ["DiffResults", "ForwardDiff", "MCMCChains", "StructArrays"]

[deps.AdvancedMH.extensions]
Expand All @@ -116,9 +116,9 @@ weakdeps = ["Libtask"]

[[deps.AdvancedVI]]
deps = ["ADTypes", "Bijectors", "DiffResults", "Distributions", "DistributionsAD", "DocStringExtensions", "ForwardDiff", "LinearAlgebra", "ProgressMeter", "Random", "Requires", "StatsBase", "StatsFuns", "Tracker"]
git-tree-sha1 = "187f67ab998f25208651262fee9539d845016b26"
git-tree-sha1 = "3e97de1a2ccce08978cd80570d8cbb9ff3f08bd3"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.2.5"
version = "0.2.6"

[deps.AdvancedVI.extensions]
AdvancedVIEnzymeExt = ["Enzyme"]
Expand Down Expand Up @@ -149,9 +149,9 @@ version = "1.1.1"

[[deps.ArrayInterface]]
deps = ["Adapt", "LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "133a240faec6e074e07c31ee75619c90544179cf"
git-tree-sha1 = "ed2ec3c9b483842ae59cd273834e5b46206d6dda"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.10.0"
version = "7.11.0"

[deps.ArrayInterface.extensions]
ArrayInterfaceBandedMatricesExt = "BandedMatrices"
Expand Down Expand Up @@ -228,9 +228,9 @@ version = "0.1.1"

[[deps.Bijectors]]
deps = ["ArgCheck", "ChainRules", "ChainRulesCore", "ChangesOfVariables", "Compat", "Distributions", "Functors", "InverseFunctions", "IrrationalConstants", "LinearAlgebra", "LogExpFunctions", "MappedArrays", "Random", "Reexport", "Requires", "Roots", "SparseArrays", "Statistics"]
git-tree-sha1 = "49491db48b1c70eefa5115e626100dbd6c0ff4c0"
git-tree-sha1 = "2173b2974d6afb2dbc72002c51c84803c08e8aa0"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.13.12"
version = "0.13.13"

[deps.Bijectors.extensions]
BijectorsDistributionsADExt = "DistributionsAD"
Expand Down Expand Up @@ -261,15 +261,15 @@ version = "0.5.1"

[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
git-tree-sha1 = "291821c1251486504f6bae435227907d734e94d2"
git-tree-sha1 = "5ec157747036038ec70b250f578362268f0472f1"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.66.0"
version = "1.68.0"

[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b"
git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.23.0"
version = "1.24.0"
weakdeps = ["SparseArrays"]

[deps.ChainRulesCore.extensions]
Expand Down Expand Up @@ -402,9 +402,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "22c595ca4146c07b16bcf9c8bea86f731f7109d2"
git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.108"
version = "0.25.109"
weakdeps = ["ChainRulesCore", "DensityInterface", "Test"]

[deps.Distributions.extensions]
Expand Down Expand Up @@ -551,9 +551,9 @@ version = "0.1.3"

[[deps.Functors]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "d3e63d9fa13f8eaa2f06f64949e2afc593ff52c2"
git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.4.10"
version = "0.4.11"

[[deps.Future]]
deps = ["Random"]
Expand Down Expand Up @@ -677,9 +677,9 @@ version = "0.4.1"

[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a"
git-tree-sha1 = "065c36f95709dd4a676dc6839a35d6fa6f192f24"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "6.6.3"
version = "7.1.0"

[deps.LLVM.extensions]
BFloat16sExt = "BFloat16s"
Expand Down Expand Up @@ -800,9 +800,9 @@ version = "1.9.0"

[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.27"
version = "0.3.28"
weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"]

[deps.LogExpFunctions.extensions]
Expand Down Expand Up @@ -839,9 +839,9 @@ version = "2024.1.0+0"

[[deps.MLJModelInterface]]
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7"
git-tree-sha1 = "88ef480f46e0506143681b3fb14d86742f3cecb1"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
version = "1.9.6"
version = "1.10.0"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
Expand Down Expand Up @@ -908,9 +908,9 @@ version = "2.7.1+0"

[[deps.NNlib]]
deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
git-tree-sha1 = "e0cea7ec219ada9ac80ec2e82e374ab2f154ae05"
git-tree-sha1 = "3d4617f943afe6410206a5294a95948c8d1b35bd"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.9.16"
version = "0.9.17"

[deps.NNlib.extensions]
NNlibAMDGPUExt = "AMDGPU"
Expand All @@ -932,9 +932,9 @@ version = "1.0.2"

[[deps.NamedArrays]]
deps = ["Combinatorics", "DataStructures", "DelimitedFiles", "InvertedIndices", "LinearAlgebra", "Random", "Requires", "SparseArrays", "Statistics"]
git-tree-sha1 = "0ae91efac93c3859f5c812a24c9468bb9e50b028"
git-tree-sha1 = "c7aab3836df3f31591a2b4167fcd87b741dacfc9"
uuid = "86f7a689-2022-50b4-a561-43c23ac3c673"
version = "0.10.1"
version = "0.10.2"

[[deps.NaturalSort]]
git-tree-sha1 = "eda490d06b9f7c00752ee81cfa451efe55521e21"
Expand Down Expand Up @@ -1075,9 +1075,9 @@ version = "1.4.3"

[[deps.PrettyTables]]
deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"]
git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660"
git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
version = "2.3.1"
version = "2.3.2"

[[deps.Printf]]
deps = ["Unicode"]
Expand All @@ -1096,9 +1096,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.10.0"

[[deps.PtrArrays]]
git-tree-sha1 = "077664975d750757f30e739c870fbbdc01db7913"
git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
version = "1.1.0"
version = "1.2.0"

[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
Expand Down Expand Up @@ -1157,9 +1157,9 @@ version = "1.3.4"

[[deps.RecursiveArrayTools]]
deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "SparseArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "758bc86b90e9fee2edc4af2a750b0d3f2d5c02c5"
git-tree-sha1 = "2cea01606a852c2431ded77293eb533b511b19e6"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "3.19.0"
version = "3.22.0"

[deps.RecursiveArrayTools.extensions]
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
Expand Down Expand Up @@ -1243,10 +1243,10 @@ uuid = "26aad666-b158-4e64-9d35-0e672562fa48"
version = "0.1.1"

[[deps.SciMLBase]]
deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "265f1a7a804d8093fa0b17e33e45373a77e56ca5"
deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "1d1d1ff37d2917cad263fa186cbc19ce4b587ccf"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
version = "2.38.0"
version = "2.40.0"

[deps.SciMLBase.extensions]
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -1355,9 +1355,9 @@ version = "1.4.2"

[[deps.StatisticalTraits]]
deps = ["ScientificTypesBase"]
git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782"
git-tree-sha1 = "983c41a0ddd6c19f5607ca87271d7c7620ab5d50"
uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9"
version = "3.2.0"
version = "3.3.0"

[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand Down Expand Up @@ -1417,9 +1417,9 @@ version = "7.2.1+1"

[[deps.SymbolicIndexingInterface]]
deps = ["Accessors", "ArrayInterface", "RuntimeGeneratedFunctions", "StaticArraysCore"]
git-tree-sha1 = "b479c7a16803f08779ac5b7f9844a42621baeeda"
git-tree-sha1 = "a5f6f138b740c9d93d76f0feddd3092e6ef002b7"
uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
version = "0.3.21"
version = "0.3.22"

[[deps.TOML]]
deps = ["Dates"]
Expand Down Expand Up @@ -1484,10 +1484,10 @@ version = "0.4.82"
Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e"

[[deps.Turing]]
deps = ["ADTypes", "AbstractMCMC", "Accessors", "AdvancedHMC", "AdvancedMH", "AdvancedPS", "AdvancedVI", "BangBang", "Bijectors", "DataStructures", "Distributions", "DistributionsAD", "DocStringExtensions", "DynamicPPL", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "LogDensityProblemsAD", "MCMCChains", "NamedArrays", "Printf", "Random", "Reexport", "Requires", "SciMLBase", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "cfb3b446a5e52e1da4cc71b77a9350c309c581f0"
deps = ["ADTypes", "AbstractMCMC", "Accessors", "AdvancedHMC", "AdvancedMH", "AdvancedPS", "AdvancedVI", "BangBang", "Bijectors", "Compat", "DataStructures", "Distributions", "DistributionsAD", "DocStringExtensions", "DynamicPPL", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "LogDensityProblemsAD", "MCMCChains", "NamedArrays", "Optimization", "OptimizationOptimJL", "OrderedCollections", "Printf", "Random", "Reexport", "Requires", "SciMLBase", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "6ea505cb1829868b333f9615e4049d7c83e97ce7"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.32.0"
version = "0.33.0"

[deps.Turing.extensions]
TuringDynamicHMCExt = "DynamicHMC"
Expand Down Expand Up @@ -1516,9 +1516,9 @@ version = "0.2.1"

[[deps.UnsafeAtomicsLLVM]]
deps = ["LLVM", "UnsafeAtomics"]
git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e"
git-tree-sha1 = "d9f5962fecd5ccece07db1ff006fb0b5271bdfdd"
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
version = "0.1.3"
version = "0.1.4"

[[deps.WoodburyMatrices]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand Down
24 changes: 14 additions & 10 deletions tutorials/docs-17-mode-estimation/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Pkg;
Pkg.instantiate();
```

In addition to sampling from the distributions of a statistical model, one may be interested in finding the most likely parameter values that maximise for instance the posterior distribution density function or the likelihood. This is called mode estimation. Turing provides support for two mode estimation techniques, [maximum likelihood estimation](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation) (MLE) and [maximum a posterior](https://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation) (MAP) estimation.
After defining a statistical model, in addition to sampling from its distributions, one may be interested in finding the parameter values that maximise for instance the posterior distribution density function or the likelihood. This is called mode estimation. Turing provides support for two mode estimation techniques, [maximum likelihood estimation](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation) (MLE) and [maximum a posterior](https://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation) (MAP) estimation.

To demonstrate mode estimation, let us load Turing and declare a model:

Expand Down Expand Up @@ -45,23 +45,23 @@ mle_estimate = maximum_likelihood(model)
map_estimate = maximum_a_posteriori(model)
```

The estimates are returned as instances of the `ModeResult` type. It has the fields `values` for the parameter values found and `lp` for the log probability at the optimum, as well as `f` for the objective function and `optim_result` for more detailed results of the optimisation procedure, such as convergence information.
The estimates are returned as instances of the `ModeResult` type. It has the fields `values` for the parameter values found and `lp` for the log probability at the optimum, as well as `f` for the objective function and `optim_result` for more detailed results of the optimisation procedure.

```{julia}
@show mle_estimate.values
@show mle_estimate.lp
@show mle_estimate.lp;
```

## Controlling the optimisation process

Under the hood `maximum_likelihood` and `maximum_a_posteriori` use the [Optimization.jl](https://github.com/SciML/Optimization.jl) package, which provides a unified interface to many other optimisation packages. By default Turing typically uses the [LBFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) method from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), but we can easily change that,
Under the hood `maximum_likelihood` and `maximum_a_posteriori` use the [Optimization.jl](https://github.com/SciML/Optimization.jl) package, which provides a unified interface to many other optimisation packages. By default Turing typically uses the [LBFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) method from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl) to find the mode estimate, but we can easily change that:

```{julia}
using OptimizationOptimJL: NelderMead
maximum_likelihood(model, NelderMead())
@show maximum_likelihood(model, NelderMead())
using OptimizationNLopt: NLopt.LD_TNEWTON_PRECOND_RESTART
maximum_likelihood(model, NLopt.LD_TNEWTON_PRECOND_RESTART())
@show maximum_likelihood(model, LD_TNEWTON_PRECOND_RESTART());
```

The above are just two examples, Optimization.jl supports [many more](https://docs.sciml.ai/Optimization/stable/).
Expand All @@ -71,7 +71,9 @@ We can also help the optimisation by giving it a starting point we know is close
```{julia}
using ADTypes: AutoReverseDiff
import ReverseDiff
maximum_likelihood(model, NelderMead(); initial_params=[0.1, 2], adtype=AutoReverseDiff())
maximum_likelihood(
model, NelderMead(); initial_params=[0.1, 2], adtype=AutoReverseDiff()
)
```

When providing values to arguments like `initial_params` the parameters are typically specified in the order in which they appear in the code of the model, so in this case first `` then `m`. More precisely it's the order returned by `Turing.Inference.getparams(model, Turing.VarInfo(model))`.
Expand All @@ -82,16 +84,18 @@ We can also do constrained optimisation, by providing either intervals within wh
maximum_likelihood(model; lb=[0.0, -1.0], ub=[0.01, 1.0])
```

The arguments for lower (`lb`) and upper (`ub`) bounds follow the arguments of `Optimization.OptimizationProblem`, as do other parameters for providing bounds such as `cons`. Any extraneous keyword arguments given to `maximum_likelihood` or `maximum_a_posteriori` are passed to `Optimization.solve`. Some often useful ones are `maxiters` for controlling the maximum number of iterations and `abstol` and `reltol` for the absolute and relative convergence tolerances:
The arguments for lower (`lb`) and upper (`ub`) bounds follow the arguments of `Optimization.OptimizationProblem`, as do other parameters for providing [constraints](https://docs.sciml.ai/Optimization/stable/tutorials/constraints/), such as `cons`. Any extraneous keyword arguments given to `maximum_likelihood` or `maximum_a_posteriori` are passed to `Optimization.solve`. Some often useful ones are `maxiters` for controlling the maximum number of iterations and `abstol` and `reltol` for the absolute and relative convergence tolerances:

```{julia}
badly_converged_mle = maximum_likelihood(model, NelderMead(); maxiters=10, reltol=1e-9)
badly_converged_mle = maximum_likelihood(
model, NelderMead(); maxiters=10, reltol=1e-9
)
```

We can check whether the optimisation converged using the `optim_result` field of the result:

```{julia}
@show badly_converged_mle.optim_result
@show badly_converged_mle.optim_result;
```

For more details, such as a full list of possible arguments, we encourage the reader to read the docstring of the function `Turing.Optimisation.estimate_mode`, which is what `maximum_likelihood` and `maximum_a_posteriori` call, and the documentation of [Optimization.jl](https://docs.sciml.ai/Optimization/stable/).
Expand Down

0 comments on commit 303e248

Please sign in to comment.