From 04364eaff6f3a73cbe11dbefcc949fe86d684279 Mon Sep 17 00:00:00 2001 From: lassepe Date: Sat, 23 Apr 2022 16:35:08 +0200 Subject: [PATCH 1/8] [wip] rudimentary rule for ForwardDiff.jl --- Manifest.toml | 86 ++++++++++++++++++++++++++++++++++++++++++++++ Project.toml | 1 + src/TensorGames.jl | 1 + src/nash.jl | 43 ++++++++++++++++++++--- 4 files changed, 126 insertions(+), 5 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index df716d0..289f833 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -30,6 +30,12 @@ git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" version = "1.14.0" +[[deps.ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.2" + [[deps.CodecBzip2]] deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"] git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7" @@ -42,6 +48,12 @@ git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" version = "0.7.0" +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + [[deps.Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] git-tree-sha1 = "96b0bc6c52df76506efc8a441c6cf1adcb1babc4" @@ -60,14 +72,38 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +[[deps.DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.10.0" + [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.6" + [[deps.Downloads]] deps = ["ArgTools", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "1bd6fc0c344fc0cbee1f42f8d2e7ec8253dda2d2" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.25" + [[deps.HTTP]] deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"] git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a" @@ -83,6 +119,17 @@ version = "0.5.1" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "91b5dcf362c5add98049e6c29ee756910b03051d" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.3" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + [[deps.JLLWrappers]] deps = ["Preferences"] git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" @@ -124,9 +171,21 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.12" + [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.9" + [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -159,6 +218,11 @@ git-tree-sha1 = "8d9496b2339095901106961f44718920732616bb" uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" version = "0.2.22" +[[deps.NaNMath]] +git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.7" + [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -166,6 +230,16 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + [[deps.OrderedCollections]] git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -226,6 +300,18 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +[[deps.SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "5ba658aeecaaf96923dce0da9e703bd1fe7666f9" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.1.4" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.4.4" + [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/Project.toml b/Project.toml index 358f069..dae4390 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.3.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/TensorGames.jl b/src/TensorGames.jl index 27d3627..c8aab77 100644 --- a/src/TensorGames.jl +++ b/src/TensorGames.jl @@ -4,6 +4,7 @@ using PATHSolver using SparseArrays using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @non_differentiable using LinearAlgebra: qr, diag, I +using ForwardDiff: ForwardDiff include("nash.jl") diff --git a/src/nash.jl b/src/nash.jl index fc509b1..1943e80 100644 --- a/src/nash.jl +++ b/src/nash.jl @@ -43,7 +43,7 @@ function ChainRulesCore.rrule(::typeof(expected_cost), x, cost_tensor) grad!(∂x, cost_tensor, xx, n, tensor_indices, primal_indices) end ∂x .*= ∂value - ∂x = [∂x[primal_indices[n,1]:primal_indices[n,2]] for n ∈ 1:N] + ∂x = [∂x[primal_indices[n, 1]:primal_indices[n, 2]] for n ∈ 1:N] ∂ct = zero(cost_tensor) for ind ∈ tensor_indices @@ -56,8 +56,8 @@ end function grad!(f, CT, x, n, indices, primal_inds) f[primal_inds[n, 1]:primal_inds[n, 2]] .= 0.0 - for ind ∈ indices - prob = prob_prod(x,ind,primal_inds, n) + for ind ∈ indices + prob = prob_prod(x, ind, primal_inds, n) for i ∈ 1:primal_inds[n, 2]+1-primal_inds[n, 1] if ind[n] == i f[primal_inds[n, 1]+i-1] += CT[ind] * prob @@ -168,12 +168,45 @@ function (T::Wrapper)(n::Cint, return Cint(0) end +function compute_equilibrium(cost_tensors::Vector{<:Array{<:ForwardDiff.Dual}}; kwargs...) + # strip off the duals: + cost_tensors_v = [ForwardDiff.value.(c) for c in cost_tensors] + cost_tensors_p = [ForwardDiff.partials.(c) for c in cost_tensors] + # forward pass + res = compute_equilibrium(cost_tensors_v; kwargs...) + # backward pass + # 1. compute jacobian + _back = _compute_equilibrium_pullback(res) + # 2. project input-sensitivy through jacobian to yield output sensitivy + x_p = let + # output sensitivities stacked for all players + x_p_stacked = sum(zip(cost_tensors_p, _back)) do (cost_tensor_p, ∂cost_tensor) + inflated_cost_tensor_p = reshape(cost_tensor_p, 1, size(cost_tensor_p)...) + dims = 2:ndims(inflated_cost_tensor_p) + reshape(sum(∂cost_tensor .* inflated_cost_tensor_p; dims), size(∂cost_tensor, 1)) + end + # unstacking the output sensitivities + n_actions_per_player = size(first(cost_tensors)) + x_p_stacked_it = Iterators.Stateful(x_p_stacked) + map(n_actions_per_player) do n_actions + Iterators.take(x_p_stacked_it, n_actions) |> collect + end + end + + # 3. glue primal and dual results together into a ForwardDiff.Dual-valued result + # TODO: think about tagging to avoid perturbation confusion + x_d = [ForwardDiff.Dual.(xi_v, xi_p) for (xi_v, xi_p) in zip(res.x, x_p)] + + (; x = x_d, res.λ, res._deriv_info) +end + function compute_equilibrium(cost_tensors; initialization = nothing, ϵ = 0.0, silent = true, convergence_tolerance = 1e-6) + N = Cint(length(cost_tensors)) m = Cint.(size(cost_tensors[1])) @assert all(m == size(tensor) for tensor ∈ cost_tensors) @@ -240,7 +273,7 @@ function ChainRulesCore.rrule(::typeof(compute_equilibrium), full_sensitivities = map(∂res.x, res.x) do r, x r isa ZeroTangent ? zeros(x) : r end - derivs = vcat(full_sensitivities...) + derivs = reduce(vcat, full_sensitivities) map(_back) do ∂cost_tensor dropdims(sum(∂cost_tensor .* derivs; dims = 1); dims = 1) @@ -291,7 +324,7 @@ function _compute_equilibrium_pullback(res; bound_tolerance = 1e-6, singularity_ # Artificially returning zero-derivatives if solution is non-isolated. return ∂cost_tensors end - nJi = factorization\(-I) + nJi = factorization \ (-I) for ind ∈ res._deriv_info.tensor_indices for n ∈ 1:N From d23dbfd73cecd860474043b3fe7af09caec70f63 Mon Sep 17 00:00:00 2001 From: lassepe Date: Mon, 25 Apr 2022 15:24:23 +0200 Subject: [PATCH 2/8] JuliaFormatter.jl --- test/test_derivatives.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_derivatives.jl b/test/test_derivatives.jl index 6654ce1..7304916 100644 --- a/test/test_derivatives.jl +++ b/test/test_derivatives.jl @@ -1,41 +1,41 @@ -function compute_value(cost_tensors, i=1; init=nothing, ϵ=0.0) - sol = compute_equilibrium(cost_tensors; ϵ=ϵ, initialization=init, convergence_tolerance=1e-10) +function compute_value(cost_tensors, i = 1; init = nothing, ϵ = 0.0) + sol = compute_equilibrium(cost_tensors; ϵ = ϵ, initialization = init, convergence_tolerance = 1e-10) value = expected_cost(sol.x, cost_tensors[i]) end @testset "derivative computation" begin - for N ∈ [2,3,4] + for N ∈ [2, 3, 4] for iteration ∈ 1:3 for min_prob ∈ [0.0, 0.05] ϵ = 1e-6 ϵ2 = 1e-3 d_max = 50.0 # ignore very large numerical derivatives, as they imply non-local solution sol_tol = 1e-10 - d = 3*ones(Int,N) + d = 3 * ones(Int, N) cost_tensors = [randn(d...) for n ∈ 1:N] - sol = compute_equilibrium(cost_tensors; ϵ=min_prob, convergence_tolerance=1e-10) + sol = compute_equilibrium(cost_tensors; ϵ = min_prob, convergence_tolerance = 1e-10) x = vcat(sol.x...) init = [x; sol.λ] - values = [compute_value(cost_tensors, i; ϵ=min_prob) for i ∈ 1:N] - compute_val_eps(CT, ind) = compute_value(CT, ind; ϵ=min_prob) + values = [compute_value(cost_tensors, i; ϵ = min_prob) for i ∈ 1:N] + compute_val_eps(CT, ind) = compute_value(CT, ind; ϵ = min_prob) value_grads = [gradient(compute_val_eps, cost_tensors, i) for i ∈ 1:N] - + for n = 1:N cinds = CartesianIndices(cost_tensors[n]) for i ∈ eachindex(cost_tensors[n]) ct2 = deepcopy(cost_tensors) ct2[n][i] += ϵ - values_perturbed = [compute_value(ct2, i; ϵ=min_prob, init=init) for i ∈ 1:N] + values_perturbed = [compute_value(ct2, i; ϵ = min_prob, init = init) for i ∈ 1:N] ii = cinds[i] - num_derivs = (values_perturbed-values)./ϵ + num_derivs = (values_perturbed - values) ./ ϵ if any(abs.(num_derivs) .≥ d_max) continue end for j ∈ 1:length(num_derivs) - @test ≈(num_derivs[j], value_grads[j][1][n][ii]; atol=ϵ2) + @test ≈(num_derivs[j], value_grads[j][1][n][ii]; atol = ϵ2) end end end From 9f50891b33d9551e36c966dd75290753e474008f Mon Sep 17 00:00:00 2001 From: lassepe Date: Mon, 25 Apr 2022 17:52:03 +0200 Subject: [PATCH 3/8] Don't check in Manifest.toml --- .gitignore | 1 + test/Manifest.toml | 314 --------------------------------------------- 2 files changed, 1 insertion(+), 314 deletions(-) create mode 100644 .gitignore delete mode 100644 test/Manifest.toml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c181d1f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +test/Manifest.toml diff --git a/test/Manifest.toml b/test/Manifest.toml deleted file mode 100644 index 59ab023..0000000 --- a/test/Manifest.toml +++ /dev/null @@ -1,314 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.7.2" -manifest_format = "2.0" - -[[deps.AbstractFFTs]] -deps = ["ChainRulesCore", "LinearAlgebra"] -git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.1.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.ChainRules]] -deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] -git-tree-sha1 = "8b887daa6af5daf705081061e36386190204ac87" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.1" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.14.0" - -[[deps.ChangesOfVariables]] -deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1" -uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.2" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "96b0bc6c52df76506efc8a441c6cf1adcb1babc4" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.42.0" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[deps.DiffResults]] -deps = ["StaticArrays"] -git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.0.3" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.10.0" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.6" - -[[deps.Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.2" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "1bd6fc0c344fc0cbee1f42f8d2e7ec8253dda2d2" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.25" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "7f43342f8d5fd30ead0ba1b49ab1a3af3b787d24" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.5" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "91b5dcf362c5add98049e6c29ee756910b03051d" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.3" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.1.1" - -[[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" - -[[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "58f25e56b706f95125dcb796f39e1fb01d913a71" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.10" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.9" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" - -[[deps.NaNMath]] -git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.7" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "d3538e7f8a790dc8903519090857ef8e1283eecd" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.2.5" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.SpecialFunctions]] -deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "5ba658aeecaaf96923dce0da9e703bd1fe7666f9" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.1.4" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "4f6ec5d99a28e1a749559ef7dd518663c5eca3d5" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "52adc0a505b6421a8668f13dcdb0c4cb498bd72c" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.37" - -[[deps.ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.2" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" From fe81e647290896af3d186993eb7abf3859339eaa Mon Sep 17 00:00:00 2001 From: lassepe Date: Mon, 25 Apr 2022 17:52:24 +0200 Subject: [PATCH 4/8] Widen dispatch bounds for ForwardDiff.Dual --- src/nash.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nash.jl b/src/nash.jl index 1943e80..8621007 100644 --- a/src/nash.jl +++ b/src/nash.jl @@ -168,7 +168,7 @@ function (T::Wrapper)(n::Cint, return Cint(0) end -function compute_equilibrium(cost_tensors::Vector{<:Array{<:ForwardDiff.Dual}}; kwargs...) +function compute_equilibrium(cost_tensors::AbstractVector{<:AbstractArray{<:ForwardDiff.Dual}}; kwargs...) # strip off the duals: cost_tensors_v = [ForwardDiff.value.(c) for c in cost_tensors] cost_tensors_p = [ForwardDiff.partials.(c) for c in cost_tensors] From 4e2f2ec22e1b0c42a5525851174f359caa7eb337 Mon Sep 17 00:00:00 2001 From: lassepe Date: Mon, 25 Apr 2022 17:52:42 +0200 Subject: [PATCH 5/8] Add tests for forward mode AD --- test/runtests.jl | 4 +++- test/test_derivatives.jl | 48 ++++++++++++++++++++++++++-------------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4dec2dc..4270cc9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,9 @@ using Test using TensorGames using LinearAlgebra: norm -using Zygote: gradient +using Zygote: gradient, forwarddiff +using Random: MersenneTwister +using FiniteDifferences: to_vec include("test_equilibrium.jl") include("test_derivatives.jl") diff --git a/test/test_derivatives.jl b/test/test_derivatives.jl index 7304916..e3eb7db 100644 --- a/test/test_derivatives.jl +++ b/test/test_derivatives.jl @@ -4,6 +4,7 @@ function compute_value(cost_tensors, i = 1; init = nothing, ϵ = 0.0) end @testset "derivative computation" begin + rng = MersenneTwister(1) for N ∈ [2, 3, 4] for iteration ∈ 1:3 for min_prob ∈ [0.0, 0.05] @@ -12,30 +13,45 @@ end d_max = 50.0 # ignore very large numerical derivatives, as they imply non-local solution sol_tol = 1e-10 d = 3 * ones(Int, N) - cost_tensors = [randn(d...) for n ∈ 1:N] + cost_tensors = [randn(rng, d...) for n ∈ 1:N] sol = compute_equilibrium(cost_tensors; ϵ = min_prob, convergence_tolerance = 1e-10) x = vcat(sol.x...) init = [x; sol.λ] values = [compute_value(cost_tensors, i; ϵ = min_prob) for i ∈ 1:N] - compute_val_eps(CT, ind) = compute_value(CT, ind; ϵ = min_prob) - value_grads = [gradient(compute_val_eps, cost_tensors, i) for i ∈ 1:N] + compute_val_eps_reverse_mode_ad(CT, ind) = compute_value(CT, ind; ϵ = min_prob) + compute_val_eps_forward_mode_ad(CT, ind) = begin + CT_stacked = cat(CT..., dims = ndims(eltype(CT)) + 1) + forwarddiff(CT_stacked) do CT_stacked + CT_unstacked = collect(eachslice(CT_stacked; dims = ndims(CT_stacked))) + compute_value(CT_unstacked, ind; ϵ = min_prob) + end + end - for n = 1:N - cinds = CartesianIndices(cost_tensors[n]) - for i ∈ eachindex(cost_tensors[n]) - ct2 = deepcopy(cost_tensors) - ct2[n][i] += ϵ + for (mode_name, f) in [ + ("reverse", compute_val_eps_reverse_mode_ad), + ("forward", compute_val_eps_forward_mode_ad), + ] + @testset "ndims: $N, ad-mode: $mode_name, \ + iteration: $iteration, min_prob: $min_prob" begin + value_grads_reverse_diff = [gradient(f, cost_tensors, i) for i ∈ 1:N] + for n = 1:N + cinds = CartesianIndices(cost_tensors[n]) + for i ∈ eachindex(cost_tensors[n]) + ct2 = deepcopy(cost_tensors) + ct2[n][i] += ϵ - values_perturbed = [compute_value(ct2, i; ϵ = min_prob, init = init) for i ∈ 1:N] + values_perturbed = [compute_value(ct2, i; ϵ = min_prob, init = init) for i ∈ 1:N] - ii = cinds[i] - num_derivs = (values_perturbed - values) ./ ϵ - if any(abs.(num_derivs) .≥ d_max) - continue - end - for j ∈ 1:length(num_derivs) - @test ≈(num_derivs[j], value_grads[j][1][n][ii]; atol = ϵ2) + ii = cinds[i] + num_derivs = (values_perturbed - values) ./ ϵ + if any(abs.(num_derivs) .≥ d_max) + continue + end + for j ∈ 1:length(num_derivs) + @test ≈(num_derivs[j], value_grads_reverse_diff[j][1][n][ii]; atol = ϵ2) + end + end end end end From b6559c098e4675ace18d25c0e5b6485f93d28072 Mon Sep 17 00:00:00 2001 From: lassepe Date: Mon, 25 Apr 2022 18:09:03 +0200 Subject: [PATCH 6/8] Add missing test dependencies --- test/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 52ba12d..343fce4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +TensorGames = "dad72d94-522a-402a-84d3-a630e8ffd4d6" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From 930d205263ca05abe928bb3002ac7b25668c411c Mon Sep 17 00:00:00 2001 From: lassepe Date: Mon, 25 Apr 2022 21:45:46 +0200 Subject: [PATCH 7/8] Add dual tags to avoid perturbation confusion --- src/nash.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/nash.jl b/src/nash.jl index 8621007..452f94d 100644 --- a/src/nash.jl +++ b/src/nash.jl @@ -168,7 +168,7 @@ function (T::Wrapper)(n::Cint, return Cint(0) end -function compute_equilibrium(cost_tensors::AbstractVector{<:AbstractArray{<:ForwardDiff.Dual}}; kwargs...) +function compute_equilibrium(cost_tensors::AbstractVector{<:AbstractArray{<:ForwardDiff.Dual{T}}}; kwargs...) where {T} # strip off the duals: cost_tensors_v = [ForwardDiff.value.(c) for c in cost_tensors] cost_tensors_p = [ForwardDiff.partials.(c) for c in cost_tensors] @@ -194,8 +194,7 @@ function compute_equilibrium(cost_tensors::AbstractVector{<:AbstractArray{<:Forw end # 3. glue primal and dual results together into a ForwardDiff.Dual-valued result - # TODO: think about tagging to avoid perturbation confusion - x_d = [ForwardDiff.Dual.(xi_v, xi_p) for (xi_v, xi_p) in zip(res.x, x_p)] + x_d = [ForwardDiff.Dual{T}.(xi_v, xi_p) for (xi_v, xi_p) in zip(res.x, x_p)] (; x = x_d, res.λ, res._deriv_info) end From 364a382c44165f6758bf19ff887d141a03a1b349 Mon Sep 17 00:00:00 2001 From: lassepe Date: Mon, 25 Apr 2022 21:51:12 +0200 Subject: [PATCH 8/8] Fix tests --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4270cc9..f4c3a5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,6 @@ using TensorGames using LinearAlgebra: norm using Zygote: gradient, forwarddiff using Random: MersenneTwister -using FiniteDifferences: to_vec include("test_equilibrium.jl") include("test_derivatives.jl")