diff --git a/Project.toml b/Project.toml index 03707e8d..8fb60638 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "5.7.0" +version = "5.7.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/chains.jl b/src/chains.jl index aa652592..b6852b67 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -137,6 +137,12 @@ julia> namesingroup(chn, :A) 2-element Vector{Symbol}: Symbol("A[1]") Symbol("A[2]") + +julia> # Also works for specific elements. + namesingroup(chn, Symbol("A[1]")) +1-element Vector{Symbol}: + Symbol("A[1]") + ``` ```jldoctest julia> chn = Chains(rand(100, 3, 2), ["A.1", "A.2", "B"]); @@ -155,7 +161,7 @@ function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket) idx_str = index_type == :bracket ? "[" : "." # Start by looking up the symbols in the list of parameter names. names_of_params = names(chains) - regex = Regex("^$sym\$|^$sym\\$idx_str") + regex = Regex("^\\Q$sym\\E\$|^\\Q$sym$idx_str\\E") indices = findall(x -> match(regex, string(x)) !== nothing, names(chains)) return names_of_params[indices] end