Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better support for AbstractString #397

Merged
merged 3 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "5.6.1"
version = "5.7.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
22 changes: 11 additions & 11 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function Chains(
end

"""
Chains(c::Chains, section::Union{Symbol,String})
Chains(c::Chains, section::Union{Symbol,AbstractString})
Chains(c::Chains, sections)

Return a new chain with only a specific `section` or multiple `sections` pulled out.
Expand All @@ -101,7 +101,7 @@ julia> names(chn2)
:a
```
"""
Chains(c::Chains, section::Union{Symbol,String}) = Chains(c, (section,))
Chains(c::Chains, section::Union{Symbol,AbstractString}) = Chains(c, (section,))
function Chains(chn::Chains, sections)
# Make sure the sections exist first.
all(haskey(chn.name_map, Symbol(x)) for x in sections) ||
Expand All @@ -121,7 +121,7 @@ Chains(chain::Chains, ::Nothing) = chain
# Groups of parameters

"""
namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
namesingroup(chains::Chains, sym::Union{AbstractString,Symbol}; index_type::Symbol=:bracket)

Return the parameters with the same name `sym`, but have a different index. Bracket indexing format
in the form of `:sym[index]` is assumed by default. Use `index_type=:dot` for parameters with dot
Expand All @@ -147,7 +147,7 @@ julia> namesingroup(chn, :A; index_type=:dot)
Symbol("A.2")
```
"""
namesingroup(chains::Chains, sym::String; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...)
namesingroup(chains::Chains, sym::AbstractString; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...)
function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
if index_type !== :bracket && index_type !== :dot
error("index_type must be :bracket or :dot")
Expand All @@ -161,14 +161,14 @@ function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
end

"""
group(chains::Chains, name::Union{String,Symbol}; index_type::Symbol=:bracket)
group(chains::Chains, name::Union{AbstractString,Symbol}; index_type::Symbol=:bracket)

Return a subset of the chain containing parameters with the same `name`, but a different index.

Bracket indexing format in the form of `:name[index]` is assumed by default. Use `index_type=:dot` for parameters with dot
indexing, i.e. `:sym.index`.
"""
function group(chains::Chains, name::Union{String,Symbol}; kwargs...)
function group(chains::Chains, name::Union{AbstractString,Symbol}; kwargs...)
return chains[:, namesingroup(chains, name; kwargs...), :]
end

Expand All @@ -177,8 +177,8 @@ end
Base.getindex(c::Chains, i::Integer) = c[i, :, :]
Base.getindex(c::Chains, i::AbstractVector{<:Integer}) = c[i, :, :]

Base.getindex(c::Chains, v::String) = c[:, Symbol(v), :]
Base.getindex(c::Chains, v::AbstractVector{String}) = c[:, Symbol.(v), :]
Base.getindex(c::Chains, v::AbstractString) = c[:, Symbol(v), :]
Base.getindex(c::Chains, v::AbstractVector{<:AbstractString}) = c[:, Symbol.(v), :]

Base.getindex(c::Chains, v::Symbol) = c[:, v, :]
Base.getindex(c::Chains, v::AbstractVector{Symbol}) = c[:, v, :]
Expand All @@ -199,7 +199,7 @@ _toindex(i, j, k::Integer) = (i, string2symbol(j), k:k)
_toindex(i::Integer, j, k::Integer) = (i:i, string2symbol(j), k:k)

# return an array or a number if a single parameter is specified
const SingleIndex = Union{Symbol,String,Integer}
const SingleIndex = Union{Symbol,AbstractString,Integer}
_toindex(i, j::SingleIndex, k) = (i, string2symbol(j), k)
_toindex(i::Integer, j::SingleIndex, k) = (i, string2symbol(j), k)
_toindex(i, j::SingleIndex, k::Integer) = (i, string2symbol(j), k)
Expand Down Expand Up @@ -542,7 +542,7 @@ Return multiple `Chains` objects, each containing only a single section.
function get_sections(chains::Chains, sections = keys(chains.name_map))
return [Chains(chains, section) for section in sections]
end
get_sections(chains::Chains, section::Union{Symbol, String}) = Chains(chains, section)
get_sections(chains::Chains, section::Union{Symbol, AbstractString}) = Chains(chains, section)

"""
sections(c::Chains)
Expand Down Expand Up @@ -727,7 +727,7 @@ function _clean_sections(chains::Chains, sections)
haskey(chains.name_map, Symbol(section))
end
end
function _clean_sections(chains::Chains, section::Union{String,Symbol})
function _clean_sections(chains::Chains, section::Union{AbstractString,Symbol})
return haskey(chains.name_map, Symbol(section)) ? section : ()
end
_clean_sections(::Chains, ::Nothing) = nothing
Expand Down
10 changes: 5 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ end

Convert strings to symbols.

If `x isa String`, the corresponding `Symbol` is returned. Likewise, if
`x isa AbstractVector{String}`, the corresponding vector of `Symbol`s is returned. In all
other cases, input `x` is returned.
If `x isa AbstractString`, the corresponding `Symbol` is returned.
Likewise, if `x isa AbstractVector{<:AbstractString}`, the corresponding vector of `Symbol`s is returned.
In all other cases, input `x` is returned.
"""
string2symbol(x) = x
string2symbol(x::String) = Symbol(x)
string2symbol(x::AbstractVector{String}) = Symbol.(x)
string2symbol(x::AbstractString) = Symbol(x)
string2symbol(x::AbstractVector{<:AbstractString}) = Symbol.(x)

#################### Mathematical Operators ####################
function cummean(x::AbstractArray)
Expand Down
66 changes: 50 additions & 16 deletions test/diagnostic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,42 @@ end
end

@testset "indexing tests" begin
@test chn[:,1,:] isa AbstractMatrix
@test chn[200:300, "param_1", :] isa AbstractMatrix
@test chn[200:300, ["param_1", "param_3"], :] isa Chains
@test chn[200:300, "param_1", 1] isa AbstractVector
@test size(chn[:,1,:]) == (niter, nchains)
@test chn[:,1,1] == val[:,1,1]
@test chn[:,1,2] == val[:,1,2]
c = chn[:, 1, :]
@test c isa AbstractMatrix
@test size(c) == (niter, nchains)
@test c == val[:, 1, :]

for i in 1:2
c = chn[:, 1, i]
@test c isa AbstractVector
@test length(c) == niter
@test c == val[:, 1, i]
end

for p in (:param_1, "param_1", SubString("param_1", 1))
c = chn[200:300, p, :]
@test c isa AbstractMatrix
@test size(c) == (101, size(chn, 3))
@test c == val[200:300, 1, :]

c = chn[200:300, p, 1]
@test c isa AbstractVector
@test length(c) == 101
@test c == val[200:300, 1, 1]
end

for ps in (
[:param_1, :param_3],
["param_1", "param_3"],
[SubString("param_1", 1), "param_3"],
["param_1", SubString("param_3", 1)],
[SubString("param_1", 1), SubString("param_3", 1)],
)
c = chn[200:300, ps, :]
@test c isa Chains
@test size(c) == (101, 2, nchains)
@test c.value.data == val[200:300, [1, 3], :]
end
end

@testset "names and groups tests" begin
Expand All @@ -116,18 +145,23 @@ end
(@inferred replacenames(chn, Dict("param_2" => "param[2]",
"param_3" => "param[3]"))).value
@test names(chn2) == [:param_1, Symbol("param[2]"), Symbol("param[3]"), :param_4]
@test namesingroup(chn2, "param") == Symbol.(["param[2]", "param[3]"])
for p in (:param, "param", SubString("param", 1))
@test namesingroup(chn2, p) == Symbol.(["param[2]", "param[3]"])
end

chn3 = group(chn2, "param")
@test names(chn3) == Symbol.(["param[2]", "param[3]"])
@test chn3.value == chn[:, [:param_2, :param_3], :].value
for p in (:param, "param", SubString("param", 1))
chn3 = group(chn2, p)
@test names(chn3) == Symbol.(["param[2]", "param[3]"])
@test chn3.value == chn[:, [:param_2, :param_3], :].value
end

stan_chn = Chains(rand(100, 3, 1), ["a.1", "a[2]", "b"])
@test namesingroup(stan_chn, "a"; index_type=:dot) == [Symbol("a.1")]
@test namesingroup(stan_chn, :a; index_type=:dot) == [Symbol("a.1")]
@test names(group(stan_chn, :a; index_type=:dot)) == [Symbol("a.1")]
@test_throws Exception namesingroup(stan_chn, :a; index_type=:x)
@test_throws Exception group(stan_chn, :a; index_type=:x)
for p in (:a, "a", SubString("a", 1))
@test namesingroup(stan_chn, p; index_type=:dot) == [Symbol("a.1")]
@test names(group(stan_chn, p; index_type=:dot)) == [Symbol("a.1")]
@test_throws Exception namesingroup(stan_chn, p; index_type=:x)
@test_throws Exception group(stan_chn, p; index_type=:x)
end
end

@testset "function tests" begin
Expand Down