Skip to content

Commit

Permalink
Make "model-free" discount function more robust
Browse files Browse the repository at this point in the history
Previously, the correspondence between delays and indifference points was encoded in the names of the parameter vector. This wasn't great because it required storing floating point data in strings. Now the correspondence is established in an enclosing environment when an init() function is called.
  • Loading branch information
kinleyid committed Nov 22, 2024
1 parent a72e6b3 commit 6d12217
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 123 deletions.
16 changes: 16 additions & 0 deletions R/internal-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,19 @@ validate_td_data <- function(data, required_columns) {
}

}

initialize_discount_function <- function(disc_func, data) {
if ('init' %in% names(disc_func)) {
# Run init() and do some validation
disc_func <- disc_func$init(disc_func, data)
stopifnot(
is.list(disc_func$par_starts),
is.list(disc_func$par_lims),
all(names(disc_func$par_starts) == names(disc_func$par_lims)),
all(vapply(disc_func$par_lims, length, integer(1)) == 2),
is.function(disc_func$fn),
all(names(formals(disc_func$fn)) == c('data', 'p'))
)
}
return(disc_func)
}
13 changes: 4 additions & 9 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ ED50 <- function(mod, val_del = NULL) {
#' @param ... Further arguments passed to `integrate()`.
#' @return AUC value.
#' @note
#' Calculation of the area always begins from delay 0, where an indifference point of 1 is assumed.
#' An indifference point of 1 is assumed at delay 0.
#' @examples
#' \dontrun{
#' data("td_bc_single_ptpt")
Expand Down Expand Up @@ -107,10 +107,6 @@ AUC <- function(obj, min_del = 0, max_del = NULL, val_del = NULL, del_transform
# Model-based AUC
stopifnot(inherits(obj, 'td_um'))
max_del <- max_del %def% max(obj$data$del)
if (obj$config$discount_function$name == 'model-free') {
# Assume indiff = 1 at del = 0
obj$optim$par <- c(c('del_0' = 1), obj$optim$par)
}
if (is.null(val_del)) {
if ('val_del' %in% names(obj$data)) {
val_del <- mean(obj$data$val_del)
Expand Down Expand Up @@ -186,10 +182,9 @@ nonsys <- function(obj) {
if (obj$config$discount_function$name != 'model-free') {
stop('Discount function must be "model-free" to check for non-systematic discounting.')
} else {
cf <- coef(obj)
cf <- cf[grep('del_', names(cf))]
indiffs <- unname(cf)
delays <- as.numeric(gsub('del_', '', names(cf)))
data <- indiffs(obj)
indiffs <- data$indiff
delays <- data$del
}
} else {
stop('Input must be a data.frame or a model of class td_bcnm, td_ipm, or td_ddm.')
Expand Down
17 changes: 4 additions & 13 deletions R/td_bcnm.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ td_bcnm <- function(
choice_rule = c('logistic', 'probit', 'power'),
fixed_ends = F,
fit_err_rate = F,
# robust = F,
# param_ranges = NULL,
gamma_par_starts = c(0.01, 1, 100),
eps_par_starts = c(0.01, 0.25),
silent = T,
Expand Down Expand Up @@ -118,9 +116,10 @@ td_bcnm <- function(
cand_mod <- list(data = data)
class(cand_mod) <- c('td_bcnm', 'td_um')
for (cand_fn in cand_fns) {

cand_fn <- initialize_discount_function(cand_fn, data)
config <- args
config$discount_function <- cand_fn

cand_mod$config <- config

# Get prob. model with the given settings but parameter values unspecified
Expand All @@ -137,11 +136,7 @@ td_bcnm <- function(
}

# Get parameter starting values
if (is.function(config$discount_function$par_starts)) {
par_starts <- config$discount_function$par_starts(data)
} else {
par_starts <- config$discount_function$par_starts
}
par_starts <- config$discount_function$par_starts
# Add gamma start values
par_starts <- c(
par_starts,
Expand All @@ -160,11 +155,7 @@ td_bcnm <- function(
}

# Get parameter bounds
if (is.function(config$discount_function$par_lims)) {
par_lims <- config$discount_function$par_lims(data)
} else {
par_lims <- config$discount_function$par_lims
}
par_lims <- config$discount_function$par_lims
# Add gamma limits
par_lims <- c(
par_lims,
Expand Down
35 changes: 16 additions & 19 deletions R/td_ddm.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,30 +95,29 @@ td_ddm <- function(
# par_starts = list(max_abs_drift = c(0.1, 1, 10)))
# }
drift_trans$name <- drift_transform

# Get a list of candidate discount functions
disc_func_cands <- get_candidate_discount_functions(arg = discount_function)

# Run optimization for each candidate discount function
best_crit <- Inf
best_mod <- list()
for (disc_func in disc_func_cands) {


disc_func <- initialize_discount_function(disc_func, data)

# Candidate model
cand_mod <- list(data = data,
config = list(discount_function = disc_func,
drift_transform = drift_trans))
class(cand_mod) <- c('td_ddm', 'td_um')

# Get parameter bounds for discount function
if (is.function(disc_func$par_lims)) {
par_lims <- disc_func$par_lims(data)
} else {
par_lims <- disc_func$par_lims
}
# Add DDM parameters bounds and drift transform parameter bounds
# Get parameter bounds and starts for
# discount function
# DDM
# drift transform
par_lims <- c(
par_lims,
disc_func$par_lims,
list(
v = c(0, Inf),
beta = c(0, 1),
Expand All @@ -127,15 +126,8 @@ td_ddm <- function(
),
drift_trans$par_lims
)
# Get discount function parameter starting values
if (is.function(disc_func$par_starts)) {
par_starts <- disc_func$par_starts(data)
} else {
par_starts <- disc_func$par_starts
}
# Add DDM parameter starting values and drift transform parameter starting values
par_starts <- c(
par_starts,
disc_func$par_starts,
list(
v = v_par_starts,
beta = beta_par_starts,
Expand Down Expand Up @@ -213,7 +205,12 @@ get_linpred_func_ddm <- function(discount_function, drift_transform) {

linpred_func <- function(data, par) {
# Compute subjective value difference
svd <- data$val_imm - data$val_del*discount_function$fn(data, par)
tryCatch(
svd <- data$val_imm - data$val_del*discount_function$fn(data, par),
error = function(e) {
browser()
}
)
# Compute drift rate
drift <- svd*par['v']
# Transform drift rate
Expand Down
136 changes: 77 additions & 59 deletions R/td_fn.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#' @param fn Function that takes a data.frame and a vector of named parameters and returns a vector of values between 0 and 1.
#' @param par_starts A named list of vectors, each specifying possible starting values for a parameter to try when running optimization.
#' @param par_lims A named list of vectors, each specifying the bounds to impose of a parameters. Any parameter for which bounds are unspecified are assumed to be unbounded.
#' @param init A function to initialize the td_fn object. It should take 2 arguments: "self" (the td_fn object being initialized) and "data" (the data used for initialization).
#' @param ED50 A function which, given a named vector of parameters \code{p} and optionally a value of \code{del_val}, computes the ED50. If there is no closed-form solution, this should return the string "non-analytic". If the ED50 is not well-defined, this should return the string "none". As a shortcut for these latter 2 cases, the strings "non-analytic" and "none" can be directly supplied as arguments.
#' @param par_chk Optionally, this is a function that checks the parameters to ensure they meet some criteria. E.g., for the dual-systems-exponential discount function, we require k1 < k2.
#' @return An object of class `td_fn`.
Expand All @@ -20,7 +21,7 @@
#' fn = function(data, p) (1 - p['b'])*exp(-p['k']*data$del) + p['b'],
#' par_starts = list(k = c(0.001, 0.1), b = c(0.001, 0.1)),
#' par_lims = list(k = c(0, Inf), b = c(0, 1)),
#' ED50 = function(...) 'non-analytic'
#' ED50 = 'non-analytic'
#' )
#' mod <- td_bcnm(td_bc_single_ptpt, discount_function = custom_discount_function, fit_err_rate = T)
#' }
Expand All @@ -35,10 +36,11 @@ td_fn <- function(predefined = c('hyperbolic',
'nonlinear-time-exponential',
'model-free',
'constant'),
name = NULL,
name = 'unnamed',
fn = NULL,
par_starts = NULL,
par_lims = NULL,
init = NULL,
ED50 = NULL,
par_chk = NULL) {

Expand All @@ -47,47 +49,68 @@ td_fn <- function(predefined = c('hyperbolic',

if (missing(predefined)) {

if (missing(name) | missing(fn) | missing(par_starts)) {
stop('To create a custom discount funciton, "name", "fn", and "par_starts" must all be provided to td_fn')
}
stopifnot(
is.character(name),
length(name) == 1
)
out$name <- name
if (!all(names(formals(fn)) == c('data', 'p'))) {
stop('fn must take 2 arguments: data (a dataframe) and p (a vector of named parameters)')

if (missing(fn)) {
if (missing(init)) {
stop('fn must be supplied if it will not be created by init')
} else {
fn <- function(data, p) NA
}
} else {
out$fn <- fn
stopifnot(is.function(fn))
if (!all(names(formals(fn)) == c('data', 'p')))
stop('fn must take 2 arguments: data (a dataframe) and p (a vector of named parameters)')
}
out$fn <- fn

if (is.function(par_starts)) {
out$par_starts <- par_starts
} else {
if (is.null(names(par_starts))) {
stop('par_starts must be a named list')
if (missing(par_starts)) {
if (missing(init)) {
stop('par_starts must be supplied if it will not be created by init')
} else {
out$par_starts <- as.list(par_starts)
par_starts <- list(placeholder = NA)
}
} else {
stopifnot(
is.list(par_starts),
!is.null(names(par_starts)),
all(vapply(par_starts,
function(x) is.numeric(x) || is.integer(x),
logical(1))))
}
out$par_starts <- par_starts

if (is.null(par_lims)) {
if (missing(par_lims)) {
par_lims <- list()
} else {
if (is.function(par_lims)) {
out$par_lims <- par_lims
} else {
if (is.null(names(par_lims)) | any(names(par_lims) == '')) {
stop('Every element of par_lims must have a name corresponding to a different parameter')
} else if (!all(vapply(par_lims, length, numeric(1)) == 2)) {
stop('par_lims must be a list of 2-element vectors')
}

for (par_name in names(par_starts)) {
if (!(par_name %in% names(par_lims))) {
par_lims[[par_name]] <- c(-Inf, Inf)
}
}

out$par_lims <- par_lims
stopifnot(
is.list(par_lims),
!is.null(names(par_lims))
)
extra_names <- setdiff(names(par_lims), names(par_starts))
if (length(extra_names) > 0) {
stop(sprintf('parameter(s) %s exist in par_lims but not par_starts',
paste(extra_names, collapse = ' and ')))
}

}
for (par_name in names(par_starts)) {
if (!(par_name %in% names(par_lims))) {
par_lims[[par_name]] <- c(-Inf, Inf)
}
}
out$par_lims <- par_lims

if (!missing(init)) {
stopifnot(is.function(init))
if (!all(names(formals(init)) == c('self', 'data')))
stop('init must take 2 arguments: "self" (the td_fn object being initialized) and "data" (the data being used to initialize)')
out$init <- init
}

if (is.null(ED50)) {
out$ED50 <- function(...) 'non-analytic'
Expand Down Expand Up @@ -175,7 +198,7 @@ td_fn <- function(predefined = c('hyperbolic',
par_lims = list(
k = c(0, Inf),
s = c(0, Inf)),
ED50 = function(p, ...) function(p, ...) (1/p['k']) ^ (1/p['s']))
ED50 = function(p, ...) (1/p['k']) ^ (1/p['s']))

} else if (name == 'scaled-exponential') {

Expand All @@ -192,7 +215,7 @@ td_fn <- function(predefined = c('hyperbolic',
} else if (name == 'dual-systems-exponential') {

out <- td_fn(name = name,
fn = function(data, p) function(data, p) p['w']*exp(-p['k1']*data$del) + (1 - p['w'])*exp(-p['k2']*data$del),
fn = function(data, p) p['w']*exp(-p['k1']*data$del) + (1 - p['w'])*exp(-p['k2']*data$del),
par_starts = list(
w = c(0.1, 0.5, 0.9),
k1 = c(0.001, 0.01, 0.1),
Expand Down Expand Up @@ -230,33 +253,28 @@ td_fn <- function(predefined = c('hyperbolic',
} else if (name == 'model-free') {

out <- td_fn(name = name,
fn = function(data, p) {
p <- p[grep('del_', names(p))]
# Round parameters and delays to 10 decimal points for comparison
dels <- round(as.numeric(gsub('del_', '', names(p))), 10)
xout <- round(data$del, 10)
get_yout <- function(xout_value) {
if (xout_value %in% dels) {
return(p[which(dels == xout_value)])
} else {
interp_result <- approx(x = dels, y = p, xout = xout_value)
return(interp_result$y)
}
fn = function(data, p) 'placeholder',
par_starts = list(placeholder = 0),
par_lims = list(placeholder = c(0, 0)),
init = function(self, data) {
# Get unique delays
delays <- sort(unique(data$del))
# Get starts and limits for free parameters
par_names <- sprintf('indiff_%s', seq_along(delays))
par_starts <- rep(list(0.5), length(delays))
names(par_starts) <- par_names
par_lims <- rep(list(c(0, 1)), length(delays))
names(par_lims) <- par_names
# Add to self
self$par_starts <- par_starts
self$par_lims <- par_lims
# Get interpolation function
self$fn <- function(data, p) {
approx(x = c(0, delays),
y = c(1, p[sprintf('indiff_%s', seq_along(delays))]),
xout = data$del)[['y']]
}
yout <- vapply(xout, get_yout, numeric(1))
return(yout)},
par_starts = function(data) {
unique_delays <- unique(data$del)
out <- as.list(rep(0.5, length(unique_delays)))
# Round to 10 decimal points to be able to align delay values
names(out) <- sprintf('del_%.10f', unique_delays)
return(out)},
par_lims = function(data) {
unique_delays <- unique(data$del)
out <- rep(list(c(0, 1)), length(unique_delays))
# Round to 10 decimal points to be able to align delay values
names(out) <- sprintf('del_%.10f', unique_delays)
return(out)},
return(self)},
ED50 = 'none')


Expand Down
Loading

0 comments on commit 6d12217

Please sign in to comment.