Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added export constructors (I used the wrong local branch name). #73

Merged
merged 5 commits into from
Mar 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ SpecialFunctions 0.7.0
Distributions 0.16.2
AxisArrays 0.3.0
KernelDensity 0.5.1
DataFrames 0.17.1
6 changes: 5 additions & 1 deletion src/MCMCChains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import StatsBase: autocor, autocov, countmap, counts, describe, predict,
quantile, sample, sem, summarystats, sample, AbstractWeights
import LinearAlgebra: diag
import Serialization: serialize, deserialize
import Base: sort, range, names, get, hash
import Base: sort, range, names, get, hash, convert
import Statistics: cor
import Core.Array
import DataFrames: DataFrame

using RecipesBase
import RecipesBase: plot
Expand All @@ -21,6 +23,7 @@ const axes = Base.axes
export Chains, getindex, setindex!, chains, setinfo, chainscat
export describe, set_section, get_params, sections
export sample, AbstractWeights
export Array, DataFrame, sort_sections, convert

# export diagnostics functions
export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
Expand Down Expand Up @@ -50,6 +53,7 @@ include("utils.jl")

include("chains.jl")
include("chainsummary.jl")
include("constructors.jl")
include("discretediag.jl")
include("fileio.jl")
include("gelmandiag.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ Base.first(c::AbstractChains) = first(c.value[Axis{:iter}].val)
# Step of the values stored on the `:iter` axis of `c.value`.
function Base.step(c::AbstractChains)
    iterations = c.value[Axis{:iter}].val
    return step(iterations)
end
# Last of the values stored on the `:iter` axis of `c.value`.
function Base.last(c::AbstractChains)
    iterations = c.value[Axis{:iter}].val
    return last(iterations)
end

# Convert a `Chains` object to an `Array` by converting its `value` field.
function Base.convert(::Type{Array}, chn::MCMCChains.Chains)
    return convert(Array, chn.value)
end

#################### Auxilliary Functions ####################

function Base.hash(c::Chains)
Expand Down
206 changes: 206 additions & 0 deletions src/constructors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""
    sort_sections(chn::MCMCChains.AbstractChains)

Return the section names of `chn` (the keys of `chn.name_map`) as a `Vector{Symbol}`.

`:parameters` is placed first and `:internals` last when they are present; all other
sections sit in between, in the order in which the keys are iterated.
"""
function sort_sections(chn::MCMCChains.AbstractChains)
    section_keys = keys(chn.name_map)

    # Each piece is either empty or a single-element vector, so concatenating them
    # yields exactly one entry per key.
    leading = :parameters in section_keys ? [:parameters] : Symbol[]
    trailing = :internals in section_keys ? [:internals] : Symbol[]
    others = Symbol[
        key for key in section_keys if key != :parameters && key != :internals
    ]

    return vcat(leading, others, trailing)
end

"""
    Array(chn::MCMCChains.AbstractChains, sections::Vector{Symbol}=Symbol[];
          append_chains=true, remove_missing_union=true)

Construct an `Array` from the sampled values of an `MCMCChains.Chains` object.

Returns either a single array in which the chains are appended (one column per
parameter), or a `Vector` holding one array per chain. If only a single parameter is
selected, the parameter dimension is dropped and a vector is returned in place of a
matrix in both cases, as is e.g. required by `cde()`.

# Arguments
- `chn`: chains object to convert to an `Array`.
- `sections`: sections of `chn` to include. When empty (the default), every section
  is included, in the order returned by `sort_sections(chn)`.

# Keywords
- `append_chains=true`: append the chains into a single array. When `false`, a
  `Vector` with one array per chain is returned.
- `remove_missing_union=true`: convert the values to `Float64`, removing the
  `Union{Missing, Real}` element type. The conversion throws if the selected values
  cannot be converted to `Float64`.

# Examples
```julia
Array(chns)                               # all chains appended
Array(chns[:par])                         # vector for a single parameter
Array(chns, [:parameters])                # only the :parameters section
Array(chns, [:parameters, :internals])    # also include the :internals section
Array(chns, append_chains=false)          # Vector of arrays, one per chain
Array(chns, remove_missing_union=false)   # keep the original element type
```
"""
function Array(chn::MCMCChains.AbstractChains,
        sections::Vector{Symbol}=Symbol[];
        append_chains=true, remove_missing_union=true)

    section_list = isempty(sections) ? sort_sections(chn) : sections
    _, nparams, nchains = size(chn.value.data)

    local result
    first_parameter = true
    if append_chains
        for section in section_list, par in chn.name_map[section]
            name = Symbol(par)
            samples = get(chn, name)[name]
            if remove_missing_union
                samples = convert(Array{Float64}, samples)
            end
            # Stack the chains of this parameter into a single column.
            nrows, ncols = size(samples)
            column = reshape(samples, nrows * ncols)[:, 1]
            if first_parameter
                result = nparams == 1 ? reshape(column, size(column, 1)) : column
                first_parameter = false
            else
                result = hcat(result, column)
            end
        end
    else
        result = Vector(undef, nchains)
        # Parameters are the outer loop so each parameter is looked up only once;
        # the columns of every per-chain array still follow the parameter order.
        for section in section_list, par in chn.name_map[section]
            name = Symbol(par)
            samples = get(chn, name)[name]
            for i in 1:nchains
                column = samples[:, i]
                if remove_missing_union
                    # `Float64` (not the abstract `Real`) so that the element type
                    # matches the appended case and the documented behaviour.
                    column = convert(Array{Float64}, column)
                end
                if first_parameter
                    result[i] = nparams == 1 ? reshape(column, size(column, 1)) : column
                else
                    result[i] = hcat(result[i], column)
                end
            end
            first_parameter = false
        end
    end
    return result
end

# NOTE: `Base.convert(::Type{Array}, chn::MCMCChains.Chains)` is defined in chains.jl,
# which is included before this file. The identical definition that used to live here
# only overwrote that method, so it has been removed.

"""
    DataFrame(chn::MCMCChains.AbstractChains, sections::Vector{Symbol}=Symbol[];
              append_chains=true, remove_missing_union=true)

Construct a `DataFrame` from the sampled values of an `MCMCChains.Chains` object.

Returns either a single `DataFrame` in which the chains are appended (one column per
parameter), or a `Vector{DataFrame}` holding one `DataFrame` per chain.

# Arguments
- `chn`: chains object to convert to a `DataFrame`.
- `sections`: sections of `chn` to include. When empty (the default), every section
  is included, in the order returned by `sort_sections(chn)`.

# Keywords
- `append_chains=true`: append the chains into a single `DataFrame`. When `false`, a
  `Vector{DataFrame}` with one entry per chain is returned.
- `remove_missing_union=true`: convert the values to a plain `Array{Float64}`,
  removing the `Union{Missing, Real}` element type and the `AxisArray` wrapper. The
  conversion throws if the selected values cannot be converted to `Float64`.

# Examples
```julia
DataFrame(chns)                               # all chains appended
DataFrame(chns[:par])                         # single parameter, chains appended
DataFrame(chns, [:parameters])                # only the :parameters section
DataFrame(chns, [:parameters, :internals])    # also include the :internals section
DataFrame(chns, append_chains=false)          # Vector of DataFrames, one per chain
DataFrame(chns, remove_missing_union=false)   # keep the original element type
```
"""
function DataFrame(chn::MCMCChains.AbstractChains,
        sections::Vector{Symbol}=Symbol[];
        append_chains=true, remove_missing_union=true)

    section_list = isempty(sections) ? sort_sections(chn) : sections
    _, _, nchains = size(chn.value.data)

    if append_chains
        result = DataFrame()
        for section in section_list, par in chn.name_map[section]
            name = Symbol(par)
            samples = get(chn, name)[name]
            if remove_missing_union
                samples = convert(Array{Float64}, samples)
            end
            # Stack the chains of this parameter into a single column.
            nrows, ncols = size(samples)
            column = reshape(samples, nrows * ncols)[:, 1]
            result = hcat(result, DataFrame(name => column))
        end
    else
        result = Vector{DataFrame}(undef, nchains)
        for i in 1:nchains
            result[i] = DataFrame()
        end
        # Parameters are the outer loop so each parameter is looked up (and
        # converted) only once; column order still follows the parameter order.
        for section in section_list, par in chn.name_map[section]
            name = Symbol(par)
            samples = get(chn, name)[name]
            if remove_missing_union
                samples = convert(Array{Float64}, samples)
            end
            for i in 1:nchains
                # Column `i` holds the samples of chain `i`. Indexing column 1 here
                # would fill every per-chain DataFrame with the first chain's draws.
                result[i] = hcat(result[i], DataFrame(name => samples[:, i]))
            end
        end
    end
    return result
end
38 changes: 38 additions & 0 deletions test/arrayconstructor_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using Turing, MCMCChains, Test

@testset "Array constructor tests" begin

    # Small two-parameter model used only to produce chains to convert.
    @model gdemo(x) = begin
        m ~ Normal(1, 0.01)
        s ~ Normal(5, 0.01)
    end

    nchains = 4
    model = gdemo([1.5, 2.0])
    sampler = HMC(1000, 0.01, 5)
    chns = chainscat([sample(model, sampler, save_state=true) for _ in 1:nchains]...)

    d, p, c = size(chns.value.data)
    appended = (d * c, p)
    per_chain = (d, p)

    # Section selection.
    @test size(Array(chns)) == appended
    @test size(Array(chns, [:parameters])) == (d * c, 2)
    @test size(Array(chns, [:parameters, :internals])) == appended
    @test size(Array(chns, [:internals])) == (d * c, 6)

    # Appended versus per-chain output, with conversion to Float64.
    @test size(Array(chns, append_chains=true)) == appended
    @test size(Array(chns, append_chains=false)) == (nchains,)
    @test size(Array(chns, append_chains=false)[1]) == per_chain
    @test typeof(Array(chns, append_chains=true)) == Array{Float64, 2}

    # Same shapes when the original element type is kept.
    @test size(Array(chns, remove_missing_union=false)) == appended
    @test size(Array(chns, append_chains=false, remove_missing_union=false)) ==
        (nchains,)
    @test size(Array(chns, append_chains=false, remove_missing_union=false)[1]) ==
        per_chain
    @test typeof(Array(chns, append_chains=true, remove_missing_union=false)) ==
        Array{Union{Missing, Float64}, 2}

    # A single parameter yields a vector.
    @test size(Array(chns[:m])) == (d * c,)

    # Smoke calls: these only need to run without throwing.
    Array(chns)
    Array(chns[:s])
    Array(chns, [:parameters])
    Array(chns, [:parameters, :internals])

end
47 changes: 47 additions & 0 deletions test/dfconstructor_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using Turing, MCMCChains, Test

@testset "DataFrame constructor tests" begin

    # Small two-parameter model used only to produce chains to convert.
    @model gdemo(x) = begin
        m ~ Normal(1, 0.01)
        s ~ Normal(5, 0.01)
    end

    model = gdemo([1.5, 2.0])
    sampler = HMC(1000, 0.01, 5)
    chn = chainscat([sample(model, sampler, save_state=true) for _ in 1:4]...)

    # Chains appended: 4 chains x 1000 iterations give 4000 rows.
    @test size(DataFrame(chn)) == (4000, 8)
    @test size(DataFrame(chn, [:parameters])) == (4000, 2)
    @test size(DataFrame(chn, [:internals, :parameters])) == (4000, 8)

    # Single-parameter chains.
    @test size(DataFrame(chn[:s])) == (4000, 1)
    @test size(DataFrame(chn[:lp])) == (4000, 1)

    # One DataFrame per chain.
    params_by_chain = DataFrame(chn, [:parameters], append_chains=false)
    @test size(params_by_chain) == (4, )
    @test size(params_by_chain[1]) == (1000, 2)

    all_by_chain = DataFrame(chn, [:parameters, :internals], append_chains=false)
    @test size(all_by_chain) == (4, )
    @test size(all_by_chain[1]) == (1000, 8)

    # Same shapes when the original element type is kept.
    raw = DataFrame(chn, [:parameters, :internals], remove_missing_union=false)
    @test size(raw) == (4000, 8)

    raw_by_chain = DataFrame(chn, [:parameters, :internals],
        remove_missing_union=false, append_chains=false)
    @test size(raw_by_chain) == (4, )
    @test size(raw_by_chain[1]) == (1000, 8)

end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ include("serialization_tests.jl")

# run tests for sampling api
include("sampling_tests.jl")

# run tests for array constructor
include("arrayconstructor_tests.jl")

# run tests for dataframe constructor
include("dfconstructor_tests.jl")