Skip to content

Commit

Permalink
Function for sequential learning (#391)
Browse files Browse the repository at this point in the history
* incremented dev version

* got rid of unused variable

* turned new_data into a list

* main part ready and passing tests

* styling and fixing docs issue

* Updating docs

* small steps

* adding user ids

* added the consistency vector to the particles

* updating examples

* updating more

* updating examples

* passing rcmdchck

* added deprecation notice for old burnin setting

* skipping the deprecation notice

* removed unnecessary brackets

* updated particle handling

* removed unused code

* refactoring complete

* ready for sequential updating

* we have a loop

* small fix

* added learning examples

* added test

* added tests

* styling and updating docs

* updated expected test output

* have to get rid of a test for now due to platform dependence
  • Loading branch information
osorensen authored Feb 27, 2024
1 parent 8e47867 commit b244710
Show file tree
Hide file tree
Showing 50 changed files with 701 additions and 299 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: BayesMallows
Type: Package
Title: Bayesian Preference Learning with the Mallows Rank Model
Version: 2.0.1.9004
Version: 2.0.1.9005
Authors@R: c(person("Oystein", "Sorensen",
email = "oystein.sorensen.1985@gmail.com",
role = c("aut", "cre"),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export(compute_consensus)
export(compute_expected_distance)
export(compute_mallows)
export(compute_mallows_mixtures)
export(compute_mallows_sequentially)
export(compute_observation_frequency)
export(compute_posterior_intervals)
export(compute_rank_distance)
Expand Down
35 changes: 20 additions & 15 deletions R/burnin.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,47 @@
#' @param model An object of class `BayesMallows` returned from
#' [compute_mallows()] or an object of class `BayesMallowsMixtures` returned
#' from [compute_mallows_mixtures()].
#' @param ... Optional arguments passed on to other methods. Currently not used.
#' @param value An integer specifying the burnin. If `model` is of class
#' `BayesMallowsMixtures`, a single value will be assumed to be the burnin
#' for each model element. Alternatively, `value` can be specified as an
#' integer vector of the same length as `model`, and hence a separate burnin
#' can be set for each number of mixture components.
#' @param ... Optional arguments passed on to other methods. Currently not used.
#'
#' @export
#' @return An object of class `BayesMallows` with burnin set.
#'
#' @family modeling
#'
#' @examples /inst/examples/burnin_example.R
`burnin<-` <- function(model, ...) UseMethod("burnin<-")
#' @example /inst/examples/burnin_example.R
#'
`burnin<-` <- function(model, ..., value) UseMethod("burnin<-")

#' @export
#' @rdname burnin-set
`burnin<-.BayesMallows` <- function(model, value) {
if(inherits(model, "SMCMallows")) {
`burnin<-.BayesMallows` <- function(model, ..., value) {
if (inherits(model, "SMCMallows")) {
stop("Cannot set burnin for SMC model.")
}
validate_integer(value)
if(value >= model$compute_options$nmc) {
if (value >= model$compute_options$nmc) {
stop("Burnin cannot be larger than the number of Monte Carlo samples.")
}
# Workaround as long as we have the deprecation notice for `$<-`
class(model) <- "list"
model$compute_options$burnin <- value
class(model) <- "BayesMallows"
model
}

#' @export
#' @rdname burnin-set
`burnin<-.BayesMallowsMixtures` <- function(model, value) {
for(v in value) validate_integer(v)
if(length(value) == 1) value <- rep(value, length(model))
if(length(value) != length(model)) stop("Wrong number of entries in value.")
`burnin<-.BayesMallowsMixtures` <- function(model, ..., value) {
for (v in value) validate_integer(v)
if (length(value) == 1) value <- rep(value, length(model))
if (length(value) != length(model)) stop("Wrong number of entries in value.")

for(i in seq_along(model)) burnin(model[[i]]) <- value[[i]]
for (i in seq_along(model)) burnin(model[[i]]) <- value[[i]]
model
}

Expand All @@ -57,21 +61,22 @@
#'
#' @family modeling
#'
#' @examples /inst/examples/burnin_example.R
#' @example /inst/examples/burnin_example.R
#'
burnin <- function(model, ...) UseMethod("burnin")

#' @rdname burnin
#' @export
burnin.BayesMallows <- function(model) {
burnin.BayesMallows <- function(model, ...) {
model$compute_options$burnin
}

#' @rdname burnin
#' @export
burnin.BayesMallowsMixtures <- function(model) {
burnin.BayesMallowsMixtures <- function(model, ...) {
lapply(model, burnin)
}

#' @rdname burnin
#' @export
burnin.SMCMallows <- function(model) 0
burnin.SMCMallows <- function(model, ...) 0
4 changes: 3 additions & 1 deletion R/compute_mallows.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ compute_mallows <- function(
validate_rankings(data)
validate_initial_values(initial_values, data)

pfun_values <- extract_pfun_values(model_options, data, pfun_estimate)
pfun_values <- extract_pfun_values(
model_options$metric, data$n_items, pfun_estimate
)

if (is.null(cl)) {
lapplyfun <- lapply
Expand Down
71 changes: 71 additions & 0 deletions R/compute_mallows_sequentially.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#' @title Estimate the Bayesian Mallows Model Sequentially
#'
#' @description Compute the posterior distributions of the parameters of the
#' Bayesian Mallows model using sequential Monte Carlo. This is based on the
#' algorithms developed in
#' \insertCite{steinSequentialInferenceMallows2023;textual}{BayesMallows}.
#' This function differs from [update_mallows()] in that it takes all the data
#' at once, and uses SMC to fit the model step-by-step. Used in this way, SMC
#' is an alternative to Metropolis-Hastings, which may work better in some
#' settings. In addition, it allows visualization of the learning process.
#'
#' @param data A list of objects of class "BayesMallowsData" returned from
#' [setup_rank_data()]. Each list element is interpreted as the data belonging
#' to a given timepoint.
#' @param initial_values An object of class "BayesMallowsPriorSamples" returned
#' from [sample_prior()].
#' @param model_options An object of class "BayesMallowsModelOptions" returned
#' from [set_model_options()].
#' @param smc_options An object of class "SMCOptions" returned from
#' [set_smc_options()].
#' @param compute_options An object of class "BayesMallowsComputeOptions"
#' returned from [set_compute_options()].
#' @param priors An object of class "BayesMallowsPriors" returned from
#' [set_priors()].
#'
#' @param pfun_estimate Object returned from [estimate_partition_function()].
#' Defaults to \code{NULL}, and will only be used for footrule, Spearman, or
#' Ulam distances when the cardinalities are not available, cf.
#' [get_cardinalities()].
#'
#' @return An object of class `SMCMallows`.
#'
#' @details This function is very new, and plotting functions and other tools
#' for visualizing the posterior distribution do not yet work. See the examples
#' for some workarounds.
#'
#'
#' @references \insertAllCited{}
#' @export
#'
#' @family modeling
#'
#' @example /inst/examples/compute_mallows_sequentially_example.R
#'
compute_mallows_sequentially <- function(
    data,
    initial_values,
    model_options = set_model_options(),
    smc_options = set_smc_options(),
    compute_options = set_compute_options(),
    priors = set_priors(),
    pfun_estimate = NULL) {
  # Validate the arguments before doing any work, so that malformed input
  # gives a clear message instead of failing at `data[[1]]` or in the sampler.
  validate_class(initial_values, "BayesMallowsPriorSamples")
  if (!is.list(data) || length(data) == 0L ||
    !all(vapply(data, inherits, logical(1), what = "BayesMallowsData"))) {
    # A single BayesMallowsData object passed directly also ends up here,
    # since its elements are not themselves BayesMallowsData objects.
    stop("data must be a non-empty list of BayesMallowsData objects.")
  }

  # The partition function is looked up from the number of items at the
  # first timepoint.
  pfun_values <- extract_pfun_values(
    model_options$metric, data[[1]]$n_items, pfun_estimate
  )

  # Draw the starting particles from the prior samples, with replacement.
  # Sampling indices (rather than the values themselves) avoids the
  # `sample()` special case where a length-one numeric `x` is treated as
  # `1:x`; for longer vectors the draws are identical. `drop = FALSE` keeps
  # `rho_init` a matrix even when a single particle is requested.
  n_particles <- smc_options$n_particles
  alpha_idx <- sample.int(
    length(initial_values$alpha), n_particles,
    replace = TRUE
  )
  alpha_init <- initial_values$alpha[alpha_idx]
  rho_cols <- sample.int(ncol(initial_values$rho), n_particles, replace = TRUE)
  rho_init <- initial_values$rho[, rho_cols, drop = FALSE]

  # Start from an emptied copy of the first timepoint's data; every
  # timepoint, including the first, is then supplied through `new_data`.
  ret <- run_smc(
    data = flush(data[[1]]),
    new_data = data,
    model_options = model_options,
    smc_options = smc_options,
    compute_options = compute_options,
    priors = priors,
    initial_values = list(
      alpha_init = alpha_init,
      rho_init = rho_init,
      aug_init = NULL
    ),
    pfun_values = pfun_values,
    pfun_estimate = pfun_estimate
  )
  class(ret) <- "SMCMallows"
  ret
}
4 changes: 2 additions & 2 deletions R/estimate_partition_function.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ estimate_partition_function <- function(
matrix(c(power, stats::lm(form, data = estimate)$coefficients), ncol = 2)
}

extract_pfun_values <- function(model_options, data, pfun_estimate) {
extract_pfun_values <- function(metric, n_items, pfun_estimate) {
tryCatch(
prepare_partition_function(model_options$metric, data$n_items),
prepare_partition_function(metric, n_items),
error = function(e) {
if (is.null(pfun_estimate)) {
stop(
Expand Down
4 changes: 0 additions & 4 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
#' @param x An object of type `BayesMallows`, returned from
#' [compute_mallows()].
#'
#' @param burnin A numeric value specifying the number of iterations
#' to discard as burn-in. Defaults to `burnin(x)`, and must be
#' provided if `burnin(x)` does not exist. See [assess_convergence()].
#'
#' @param parameter Character string defining the parameter to plot. Available
#' options are `"alpha"`, `"rho"`, `"cluster_probs"`,
#' `"cluster_assignment"`, and `"theta"`.
Expand Down
2 changes: 1 addition & 1 deletion R/predict_top_k.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#'
#' @example /inst/examples/plot_top_k_example.R
#' @family posterior quantities
predict_top_k <- function(model_fit,k = 3) {
predict_top_k <- function(model_fit, k = 3) {
validate_top_k(model_fit)
.predict_top_k(model_fit, k)
}
Expand Down
10 changes: 5 additions & 5 deletions R/setup_rank_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
#' 2 \tab 5 \tab 3\cr
#' }
#'
#' @param user_ids Optional vector of user IDs. Defaults to `NULL`, and only
#' used by [update_mallows()]. If provided, new data can consist of updated
#' partial rankings from users already in the dataset, as described in Section
#' 6 of
#' @param user_ids Optional vector of user IDs. Defaults to `character()`, and
#' only used by [update_mallows()]. If provided, new data can consist of
#' updated partial rankings from users already in the dataset, as described in
#' Section 6 of
#' \insertCite{steinSequentialInferenceMallows2023;textual}{BayesMallows}.
#'
#' @param observation_frequency A vector of observation frequencies (weights) to
Expand Down Expand Up @@ -129,7 +129,7 @@
setup_rank_data <- function(
rankings = NULL,
preferences = NULL,
user_ids = NULL,
user_ids = character(),
observation_frequency = NULL,
validate_rankings = TRUE,
na_action = c("augment", "fail", "omit"),
Expand Down
61 changes: 15 additions & 46 deletions R/smc_misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,57 +29,12 @@ extract_rho_init <- function(model, n_particles) {
model$rho_samples[, 1, thinned_inds, drop = TRUE]
}

# Combine the data stored in a fitted model with newly arrived data.
#
# `model` is read for `model$data` (rankings, user_ids, timepoint),
# `model$augmented_rankings` and `model$smc_options$n_particles`.
# `new_data` is read for its rankings, user_ids and timepoint.
#
# Returns `list(data = <merged data>, new_data = <data for new users>)`.
prepare_new_data <- function(model, new_data) {
# Case 1: both sides carry user IDs, so the merge is done per user.
if (!is.null(new_data$user_ids) && !is.null(model$data$user_ids)) {
# Users only in the model, users in both, and users only in the new data.
old_users <- setdiff(model$data$user_ids, new_data$user_ids)
updated_users <- intersect(model$data$user_ids, new_data$user_ids)
new_users <- setdiff(new_data$user_ids, model$data$user_ids)

# Rows for untouched users come from the model; rows for updated and new
# users come from the new data, in that order.
# NOTE(review): indexing by user ID presumably relies on the row names of
# the ranking matrices being the user IDs -- confirm against
# setup_rank_data().
rankings <- rbind(
model$data$rankings[old_users, , drop = FALSE],
new_data$rankings[c(updated_users, new_users), , drop = FALSE]
)

user_ids <- c(old_users, updated_users, new_users)

data <- setup_rank_data(rankings = rankings, user_ids = user_ids)
# `new_data` is narrowed to the users not previously seen by the model.
new_data <- setup_rank_data(
rankings = rankings[new_users, , drop = FALSE],
user_ids = new_users
)

if (!is.null(model$augmented_rankings)) {
# One row per merged ranking and one column per particle; every entry
# starts out as consistent and only updated users are re-examined.
consistent <- matrix(
TRUE,
nrow = nrow(rankings), ncol = model$smc_options$n_particles
)
for (uu in updated_users) {
index <- which(rownames(rankings) == uu)
# The observed (non-missing) entries of the updated ranking.
to_compare <- as.numeric(stats::na.omit(rankings[index, ]))

# For each column of the user's slice of the augmented rankings, keep
# the entries that also occur in `to_compare` and require them to equal
# `to_compare` element by element.
# NOTE(review): `index` is a row position in the merged `rankings`
# (ordered old, updated, new) but is used here on the second dimension
# of `model$augmented_rankings`; presumably that array is ordered like
# `model$data` -- verify that the two orderings coincide.
consistent[index, ] <- apply(model$augmented_rankings[, index, ], 2, function(ar) {
all(ar[ar %in% to_compare] == to_compare)
})
}
# Multiplying by 1L turns the logical matrix into 0/1 integers.
data$consistent <- consistent * 1L
}
} else {
# Case 2: user IDs are missing on at least one side. The new rankings are
# appended below the existing ones, users are numbered by row, and the
# timepoints are concatenated. `new_data` is returned unchanged.
rankings <- rbind(model$data$rankings, new_data$rankings)
data <- setup_rank_data(
rankings = rankings,
user_ids = seq_len(nrow(rankings)),
timepoint = c(model$data$timepoint, new_data$timepoint)
)
}
list(data = data, new_data = new_data)
}

run_common_part <- function(
data, new_data, model_options, smc_options, compute_options, priors,
initial_values, pfun_list, model) {
ret <- run_smc(
data = data,
new_data = new_data,
new_data = list(new_data),
model_options = model_options,
smc_options = smc_options,
compute_options = compute_options,
Expand All @@ -89,18 +44,32 @@ run_common_part <- function(
pfun_estimate = pfun_list$pfun_estimate
)

ret$alpha_samples <- ret$alpha_samples[, 1]
ret$rho_samples <- ret$rho_samples[, , 1]
ret <- c(ret, tidy_smc(ret, data$items))
ret$model_options <- model_options
ret$smc_options <- smc_options
ret$compute_options <- compute_options
class(ret$compute_options) <- "list"
ret$priors <- priors
ret$n_items <- model$n_items
ret$n_clusters <- 1
ret$data <- new_data
ret$pfun_values <- pfun_list$pfun_values
ret$pfun_estimate <- pfun_list$pfun_estimate
ret$model_options$metric <- model_options$metric
if (prod(dim(ret$augmented_rankings)) == 0) ret$augmented_rankings <- NULL
ret$items <- data$items
class(ret) <- c("SMCMallows", "BayesMallows")
ret
}

flush <- function(data) {
  # Empty index reused for every subsetting operation below.
  none <- integer()

  # Vector-valued fields: drop every entry.
  data$timepoint <- data$timepoint[none]
  data$user_ids <- data$user_ids[none]
  data$observation_frequency <- data$observation_frequency[none]

  # Matrix-valued fields: drop every row but keep the columns, so the result
  # is still a matrix with zero rows.
  data$consistent <- data$consistent[none, , drop = FALSE]
  data$rankings <- data$rankings[none, , drop = FALSE]

  # No observations remain in the returned object.
  data$n_assessors <- 0
  data
}
28 changes: 17 additions & 11 deletions R/update_mallows.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ update_mallows.BayesMallowsPriorSamples <- function(
...) {
alpha_init <- sample(model$alpha, smc_options$n_particles, replace = TRUE)
rho_init <- model$rho[, sample(ncol(model$rho), smc_options$n_particles, replace = TRUE)]
pfun_values <- extract_pfun_values(model_options, new_data, pfun_estimate)
pfun_values <- extract_pfun_values(model_options$metric, new_data$n_items, pfun_estimate)

run_common_part(
data = new_data, new_data = new_data, model_options = model_options,
data = flush(new_data), new_data = new_data, model_options = model_options,
smc_options = smc_options, compute_options = compute_options,
priors = priors,
initial_values = list(
Expand Down Expand Up @@ -81,7 +81,7 @@ update_mallows.BayesMallows <- function(
rho_init <- extract_rho_init(model, smc_options$n_particles)

run_common_part(
data = new_data, new_data = new_data, model_options = model_options,
data = flush(new_data), new_data = new_data, model_options = model_options,
smc_options = smc_options, compute_options = compute_options,
priors = priors,
initial_values = list(
Expand All @@ -99,11 +99,9 @@ update_mallows.BayesMallows <- function(
#' @export
#' @rdname update_mallows
update_mallows.SMCMallows <- function(model, new_data, ...) {
datlist <- prepare_new_data(model, new_data)

ret <- run_smc(
data = datlist$data,
new_data = datlist$new_data,
data = model$data,
new_data = list(new_data),
model_options = model$model_options,
smc_options = model$smc_options,
compute_options = model$compute_options,
Expand All @@ -116,14 +114,22 @@ update_mallows.SMCMallows <- function(model, new_data, ...) {
pfun_values = model$pfun_values,
pfun_estimate = model$pfun_estimate
)
model$alpha_samples <- ret$alpha_samples
model$rho_samples <- ret$rho_samples
model$augmented_rankings <- ret$augmented_rankings
model$alpha_samples <- ret$alpha_samples[, 1]
model$rho_samples <- ret$rho_samples[, , 1]
model$augmented_rankings <-
if (prod(dim(ret$augmented_rankings)) == 0) {
NULL
} else {
ret$augmented_rankings
}

tidy_parameters <- tidy_smc(ret, model$items)
model$alpha <- tidy_parameters$alpha
model$rho <- tidy_parameters$rho
model$augmented_rankings <- ret$augmented_rankings
model$data <- datlist$data
items <- model$data$items
model$data <- ret$data
model$data$items <- items

class(model) <- c("SMCMallows", "BayesMallows")
model
Expand Down
Loading

0 comments on commit b244710

Please sign in to comment.