diff --git a/src/cmdstan/arguments/arg_laplace.hpp b/src/cmdstan/arguments/arg_laplace.hpp index d0c0e27761..b312e7b9b0 100644 --- a/src/cmdstan/arguments/arg_laplace.hpp +++ b/src/cmdstan/arguments/arg_laplace.hpp @@ -21,11 +21,15 @@ class arg_laplace : public categorical_argument { "")); _subarguments.push_back( new arg_single_bool("jacobian", - "When true, include change-of-variables adjustment" - " for constraining parameter transforms", + "When true, include change-of-variables adjustment " + "for constraining parameter transforms.", true)); _subarguments.push_back(new arg_single_int_nonneg( "draws", "Number of draws from the laplace approximation", 1000)); + _subarguments.push_back(new arg_single_bool( + "calculate_lp", + "If true, calculate the log probability of the model at each draw.", + true)); } }; diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index 84d0befb9c..64478e7826 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -232,12 +232,20 @@ int command(int argc, const char *argv[]) { init_filestream_writers(sample_writers, num_chains, id, output_file, "", ".csv", sig_figs, "# "); if (!diagnostic_file.empty()) { - init_filestream_writers(diagnostic_csv_writers, num_chains, id, - diagnostic_file, "", ".csv", sig_figs, "# "); + if (user_method->arg("laplace")) { + init_filestream_writers(diagnostic_json_writers, num_chains, id, + diagnostic_file, "", ".json", sig_figs); + init_null_writers(diagnostic_csv_writers, num_chains); + + } else { + init_filestream_writers(diagnostic_csv_writers, num_chains, id, + diagnostic_file, "", ".csv", sig_figs, "# "); + init_null_writers(diagnostic_json_writers, num_chains); + } } else { init_null_writers(diagnostic_csv_writers, num_chains); + init_null_writers(diagnostic_json_writers, num_chains); } - init_null_writers(diagnostic_json_writers, num_chains); } if (user_method->arg("sample") && get_arg_val(parser, "method", "sample", "adapt", @@ -383,15 +391,17 @@ int command(int argc, const char *argv[]) { } Eigen::VectorXd theta_hat = get_laplace_mode(fname, model); bool jacobian = get_arg_val(*laplace_arg, "jacobian"); + bool calculate_lp + = get_arg_val(*laplace_arg, "calculate_lp"); int draws = get_arg_val(*laplace_arg, "draws"); if (jacobian) { return_code = stan::services::laplace_sample( - model, theta_hat, draws, random_seed, refresh, interrupt, logger, - sample_writers[0]); + model, theta_hat, draws, calculate_lp, random_seed, refresh, + interrupt, logger, sample_writers[0], diagnostic_json_writers[0]); } else { return_code = stan::services::laplace_sample( - model, theta_hat, draws, random_seed, refresh, interrupt, logger, - sample_writers[0]); + model, theta_hat, draws, calculate_lp, random_seed, refresh, + interrupt, logger, sample_writers[0], diagnostic_json_writers[0]); } // ---- laplace end ---- // } else if (user_method->arg("log_prob")) {