Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix num_warmups arg, added test for cmdstan cmdline args #78

Merged
merged 3 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
tmp
Manifest.toml
deps/data/bridgestan/bin
/.vscode
7 changes: 3 additions & 4 deletions src/stanrun/cmdline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ function cmdline(m::SampleModel, id; kwargs...)
cmd = ``
# Handle the model name field for unix and windows
cmd = `$(m.exec_path)`

if m.use_cpp_chains
cmd = :num_threads in keys(kwargs) ? `$cmd num_threads=$(m.num_threads)` : `$cmd`
cmd = `$cmd method=sample num_chains=$(m.num_cpp_chains)`
Expand All @@ -25,7 +24,7 @@ function cmdline(m::SampleModel, id; kwargs...)
end

cmd = :num_samples in keys(kwargs) ? `$cmd num_samples=$(m.num_samples)` : `$cmd`
cmd = :num_warmup in keys(kwargs) ? `$cmd num_warmup=$(m.num_warmups)` : `$cmd`
cmd = :num_warmups in keys(kwargs) ? `$cmd num_warmup=$(m.num_warmups)` : `$cmd`
cmd = :save_warmup in keys(kwargs) ? `$cmd save_warmup=$(m.save_warmup)` : `$cmd`
cmd = :save_warmup in keys(kwargs) ? `$cmd thin=$(m.thin)` : `$cmd`
cmd = `$cmd adapt engaged=$(m.engaged)`
Expand All @@ -38,8 +37,8 @@ function cmdline(m::SampleModel, id; kwargs...)
cmd = :window in keys(kwargs) ? `$cmd window=$(m.window)` : `$cmd`
cmd = :save_metric in keys(kwargs) ? `$cmd save_metric=$(m.save_metric)` : `$cmd`

# Algorithm section
cmd = :algorithm in keys(kwargs) ? `$cmd algorithm=$(string(m.algorithm))` : `$cmd`
# Algorithm section, algorithm can only be HMC
cmd = `$cmd algorithm=$(string(m.algorithm))`
if m.algorithm == :hmc
cmd = :engine in keys(kwargs) ? `$cmd engine=$(string(m.engine))` : `$cmd`
if m.engine == :nuts
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ if haskey(ENV, "CMDSTAN") || haskey(ENV, "JULIA_CMDSTAN_HOME")
"test_basic_runs/test_bernoulli_dict.jl",
"test_basic_runs/test_bernoulli_array_dict_1.jl",
"test_basic_runs/test_bernoulli_array_dict_2.jl",
"test_basic_runs/test_parse_interpolate.jl"
"test_basic_runs/test_parse_interpolate.jl",
"test_basic_runs/test_cmdstan_args.jl",
]

@testset "Bernoulli basic run tests" begin
Expand Down Expand Up @@ -242,4 +243,3 @@ if haskey(ENV, "CMDSTAN") || haskey(ENV, "JULIA_CMDSTAN_HOME")
else
println("\nCMDSTAN and JULIA_CMDSTAN_HOME not set. Skipping tests")
end

59 changes: 59 additions & 0 deletions test/test_basic_runs/test_cmdstan_args.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using StanSample, Test

ProjDir = @__DIR__
cd(ProjDir) # do

bernoulli_model = "
data {
int<lower=1> N;
array[N] int<lower=0,upper=1> y;
}
parameters {
real<lower=0,upper=1> theta;
}
model {
theta ~ beta(1,1);
y ~ bernoulli(theta);
}
";

sm = SampleModel("bernoulli", bernoulli_model)
observeddata = Dict("N" => 10, "y" => [0, 1, 0, 1, 0, 0, 0, 0, 0, 1])
rc = stan_sample(
sm;
data=observeddata,
num_samples=13,
num_warmups=17,
save_warmup=true,
num_chains=1,
sig_figs=2,
stepsize=0.7,
)

@test success(rc)
samples = read_samples(sm, :array)

shape = size(samples)
# number of samples, number of chains, number of parameters
@test shape == (30, 1, 1)

# read the log file
f = open(sm.log_file[1], "r")
# remove leading whitespace and chop off the "(default)" suffix
config = [chopsuffix(lstrip(x), r"\s+\(default\)$"i) for x in eachline(f) if length(x) > 0]
close(f)
# check that the config is as expected

required_entries = [
"method = sample",
"num_samples = 13",
"num_warmup = 17",
"save_warmup = true",
"num_chains = 1",
"sig_figs = 2",
"stepsize = 0.7",
]

for entry in required_entries
@test entry in config
end
Loading