Skip to content

Commit

Permalink
Merge pull request #8 from 4estlaine/feature/forward_diff_rule
Browse files Browse the repository at this point in the history
Add rule for forard mode AD via ForwardDiff.jl
  • Loading branch information
Forrest Laine authored Apr 25, 2022
2 parents b1a1610 + 364a382 commit a2f2758
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 343 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test/Manifest.toml
86 changes: 86 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.14.0"

[[deps.ChangesOfVariables]]
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.2"

[[deps.CodecBzip2]]
deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"]
git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7"
Expand All @@ -42,6 +48,12 @@ git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.0"

[[deps.CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.0"

[[deps.Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "96b0bc6c52df76506efc8a441c6cf1adcb1babc4"
Expand All @@ -60,14 +72,38 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[deps.DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.0.3"

[[deps.DiffRules]]
deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.10.0"

[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.6"

[[deps.Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[deps.ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "1bd6fc0c344fc0cbee1f42f8d2e7ec8253dda2d2"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.25"

[[deps.HTTP]]
deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"]
git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a"
Expand All @@ -83,6 +119,17 @@ version = "0.5.1"
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[deps.InverseFunctions]]
deps = ["Test"]
git-tree-sha1 = "91b5dcf362c5add98049e6c29ee756910b03051d"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
version = "0.1.3"

[[deps.IrrationalConstants]]
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.1.1"

[[deps.JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
Expand Down Expand Up @@ -124,9 +171,21 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
deps = ["Libdl", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.LogExpFunctions]]
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.12"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.9"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand Down Expand Up @@ -159,13 +218,28 @@ git-tree-sha1 = "8d9496b2339095901106961f44718920732616bb"
uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
version = "0.2.22"

[[deps.NaNMath]]
git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.7"

[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"

[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0"

[[deps.OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand Down Expand Up @@ -226,6 +300,18 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[deps.SpecialFunctions]]
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "5ba658aeecaaf96923dce0da9e703bd1fe7666f9"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.1.4"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.4.4"

[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.3.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
1 change: 1 addition & 0 deletions src/TensorGames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using PATHSolver
using SparseArrays
using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @non_differentiable
using LinearAlgebra: qr, diag, I
using ForwardDiff: ForwardDiff

include("nash.jl")

Expand Down
42 changes: 37 additions & 5 deletions src/nash.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function ChainRulesCore.rrule(::typeof(expected_cost), x, cost_tensor)
grad!(∂x, cost_tensor, xx, n, tensor_indices, primal_indices)
end
∂x .*= ∂value
∂x = [∂x[primal_indices[n,1]:primal_indices[n,2]] for n 1:N]
∂x = [∂x[primal_indices[n, 1]:primal_indices[n, 2]] for n 1:N]

∂ct = zero(cost_tensor)
for ind tensor_indices
Expand All @@ -56,8 +56,8 @@ end

function grad!(f, CT, x, n, indices, primal_inds)
f[primal_inds[n, 1]:primal_inds[n, 2]] .= 0.0
for ind indices
prob = prob_prod(x,ind,primal_inds, n)
for ind indices
prob = prob_prod(x, ind, primal_inds, n)
for i 1:primal_inds[n, 2]+1-primal_inds[n, 1]
if ind[n] == i
f[primal_inds[n, 1]+i-1] += CT[ind] * prob
Expand Down Expand Up @@ -168,12 +168,44 @@ function (T::Wrapper)(n::Cint,
return Cint(0)
end

function compute_equilibrium(cost_tensors::AbstractVector{<:AbstractArray{<:ForwardDiff.Dual{T}}}; kwargs...) where {T}
# strip off the duals:
cost_tensors_v = [ForwardDiff.value.(c) for c in cost_tensors]
cost_tensors_p = [ForwardDiff.partials.(c) for c in cost_tensors]
# forward pass
res = compute_equilibrium(cost_tensors_v; kwargs...)
# backward pass
# 1. compute jacobian
_back = _compute_equilibrium_pullback(res)
# 2. project input-sensitivy through jacobian to yield output sensitivy
x_p = let
# output sensitivities stacked for all players
x_p_stacked = sum(zip(cost_tensors_p, _back)) do (cost_tensor_p, ∂cost_tensor)
inflated_cost_tensor_p = reshape(cost_tensor_p, 1, size(cost_tensor_p)...)
dims = 2:ndims(inflated_cost_tensor_p)
reshape(sum(∂cost_tensor .* inflated_cost_tensor_p; dims), size(∂cost_tensor, 1))
end
# unstacking the output sensitivities
n_actions_per_player = size(first(cost_tensors))
x_p_stacked_it = Iterators.Stateful(x_p_stacked)
map(n_actions_per_player) do n_actions
Iterators.take(x_p_stacked_it, n_actions) |> collect
end
end

# 3. glue primal and dual results together into a ForwardDiff.Dual-valued result
x_d = [ForwardDiff.Dual{T}.(xi_v, xi_p) for (xi_v, xi_p) in zip(res.x, x_p)]

(; x = x_d, res.λ, res._deriv_info)
end


function compute_equilibrium(cost_tensors;
initialization = nothing,
ϵ = 0.0,
silent = true,
convergence_tolerance = 1e-6)

N = Cint(length(cost_tensors))
m = Cint.(size(cost_tensors[1]))
@assert all(m == size(tensor) for tensor cost_tensors)
Expand Down Expand Up @@ -240,7 +272,7 @@ function ChainRulesCore.rrule(::typeof(compute_equilibrium),
full_sensitivities = map(∂res.x, res.x) do r, x
r isa ZeroTangent ? zeros(x) : r
end
derivs = vcat(full_sensitivities...)
derivs = reduce(vcat, full_sensitivities)

map(_back) do ∂cost_tensor
dropdims(sum(∂cost_tensor .* derivs; dims = 1); dims = 1)
Expand Down Expand Up @@ -291,7 +323,7 @@ function _compute_equilibrium_pullback(res; bound_tolerance = 1e-6, singularity_
# Artificially returning zero-derivatives if solution is non-isolated.
return ∂cost_tensors
end
nJi = factorization\(-I)
nJi = factorization \ (-I)

for ind res._deriv_info.tensor_indices
for n 1:N
Expand Down
Loading

0 comments on commit a2f2758

Please sign in to comment.