Skip to content

Commit

Permalink
Merge pull request #15 from kinleyid/model-free-constants
Browse files Browse the repository at this point in the history
Make "model-free" discount function more robust
  • Loading branch information
kinleyid authored Nov 22, 2024
2 parents a72e6b3 + 6d12217 commit 86ff6f1
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 123 deletions.
16 changes: 16 additions & 0 deletions R/internal-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,19 @@ validate_td_data <- function(data, required_columns) {
}

}

# Run a discount function's optional init() hook and validate the result.
#
# disc_func: list-like discount function object. If it has an element named
#   'init', that element is called as init(disc_func, data) and its return
#   value replaces disc_func.
# data: passed through unchanged to init().
#
# Returns disc_func untouched when it has no 'init' element; otherwise returns
# the initialized object. Any failed validation raises an error (stopifnot).
initialize_discount_function <- function(disc_func, data) {
  if ('init' %in% names(disc_func)) {
    # Run init() and do some validation
    disc_func <- disc_func$init(disc_func, data)
    stopifnot(
      is.list(disc_func$par_starts),
      is.list(disc_func$par_lims),
      # identical() instead of all(a == b): `==` recycles when the lengths
      # differ and yields logical(0) (so all() is TRUE) when either side is
      # NULL, which let mismatched name sets slip through.
      identical(names(disc_func$par_starts), names(disc_func$par_lims)),
      all(vapply(disc_func$par_lims, length, integer(1)) == 2),
      is.function(disc_func$fn),
      # Same reasoning: a zero-argument fn has NULL formals names, and
      # all(NULL == c('data', 'p')) is vacuously TRUE.
      identical(names(formals(disc_func$fn)), c('data', 'p'))
    )
  }
  disc_func
}
13 changes: 4 additions & 9 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ ED50 <- function(mod, val_del = NULL) {
#' @param ... Further arguments passed to `integrate()`.
#' @return AUC value.
#' @note
#' Calculation of the area always begins from delay 0, where an indifference point of 1 is assumed.
#' An indifference point of 1 is assumed at delay 0.
#' @examples
#' \dontrun{
#' data("td_bc_single_ptpt")
Expand Down Expand Up @@ -107,10 +107,6 @@ AUC <- function(obj, min_del = 0, max_del = NULL, val_del = NULL, del_transform
# Model-based AUC
stopifnot(inherits(obj, 'td_um'))
max_del <- max_del %def% max(obj$data$del)
if (obj$config$discount_function$name == 'model-free') {
# Assume indiff = 1 at del = 0
obj$optim$par <- c(c('del_0' = 1), obj$optim$par)
}
if (is.null(val_del)) {
if ('val_del' %in% names(obj$data)) {
val_del <- mean(obj$data$val_del)
Expand Down Expand Up @@ -186,10 +182,9 @@ nonsys <- function(obj) {
if (obj$config$discount_function$name != 'model-free') {
stop('Discount function must be "model-free" to check for non-systematic discounting.')
} else {
cf <- coef(obj)
cf <- cf[grep('del_', names(cf))]
indiffs <- unname(cf)
delays <- as.numeric(gsub('del_', '', names(cf)))
data <- indiffs(obj)
indiffs <- data$indiff
delays <- data$del
}
} else {
stop('Input must be a data.frame or a model of class td_bcnm, td_ipm, or td_ddm.')
Expand Down
17 changes: 4 additions & 13 deletions R/td_bcnm.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ td_bcnm <- function(
choice_rule = c('logistic', 'probit', 'power'),
fixed_ends = F,
fit_err_rate = F,
# robust = F,
# param_ranges = NULL,
gamma_par_starts = c(0.01, 1, 100),
eps_par_starts = c(0.01, 0.25),
silent = T,
Expand Down Expand Up @@ -118,9 +116,10 @@ td_bcnm <- function(
cand_mod <- list(data = data)
class(cand_mod) <- c('td_bcnm', 'td_um')
for (cand_fn in cand_fns) {

cand_fn <- initialize_discount_function(cand_fn, data)
config <- args
config$discount_function <- cand_fn

cand_mod$config <- config

# Get prob. model with the given settings but parameter values unspecified
Expand All @@ -137,11 +136,7 @@ td_bcnm <- function(
}

# Get parameter starting values
if (is.function(config$discount_function$par_starts)) {
par_starts <- config$discount_function$par_starts(data)
} else {
par_starts <- config$discount_function$par_starts
}
par_starts <- config$discount_function$par_starts
# Add gamma start values
par_starts <- c(
par_starts,
Expand All @@ -160,11 +155,7 @@ td_bcnm <- function(
}

# Get parameter bounds
if (is.function(config$discount_function$par_lims)) {
par_lims <- config$discount_function$par_lims(data)
} else {
par_lims <- config$discount_function$par_lims
}
par_lims <- config$discount_function$par_lims
# Add gamma limits
par_lims <- c(
par_lims,
Expand Down
35 changes: 16 additions & 19 deletions R/td_ddm.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,30 +95,29 @@ td_ddm <- function(
# par_starts = list(max_abs_drift = c(0.1, 1, 10)))
# }
drift_trans$name <- drift_transform

# Get a list of candidate discount functions
disc_func_cands <- get_candidate_discount_functions(arg = discount_function)

# Run optimization for each candidate discount function
best_crit <- Inf
best_mod <- list()
for (disc_func in disc_func_cands) {


disc_func <- initialize_discount_function(disc_func, data)

# Candidate model
cand_mod <- list(data = data,
config = list(discount_function = disc_func,
drift_transform = drift_trans))
class(cand_mod) <- c('td_ddm', 'td_um')

# Get parameter bounds for discount function
if (is.function(disc_func$par_lims)) {
par_lims <- disc_func$par_lims(data)
} else {
par_lims <- disc_func$par_lims
}
# Add DDM parameters bounds and drift transform parameter bounds
# Get parameter bounds and starts for
# discount function
# DDM
# drift transform
par_lims <- c(
par_lims,
disc_func$par_lims,
list(
v = c(0, Inf),
beta = c(0, 1),
Expand All @@ -127,15 +126,8 @@ td_ddm <- function(
),
drift_trans$par_lims
)
# Get discount function parameter starting values
if (is.function(disc_func$par_starts)) {
par_starts <- disc_func$par_starts(data)
} else {
par_starts <- disc_func$par_starts
}
# Add DDM parameter starting values and drift transform parameter starting values
par_starts <- c(
par_starts,
disc_func$par_starts,
list(
v = v_par_starts,
beta = beta_par_starts,
Expand Down Expand Up @@ -213,7 +205,12 @@ get_linpred_func_ddm <- function(discount_function, drift_transform) {

linpred_func <- function(data, par) {
# Compute subjective value difference
svd <- data$val_imm - data$val_del*discount_function$fn(data, par)
tryCatch(
svd <- data$val_imm - data$val_del*discount_function$fn(data, par),
error = function(e) {
browser()
}
)
# Compute drift rate
drift <- svd*par['v']
# Transform drift rate
Expand Down
136 changes: 77 additions & 59 deletions R/td_fn.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#' @param fn Function that takes a data.frame and a vector of named parameters and returns a vector of values between 0 and 1.
#' @param par_starts A named list of vectors, each specifying possible starting values for a parameter to try when running optimization.
#' @param par_lims A named list of vectors, each specifying the bounds to impose on a parameter. Any parameter for which bounds are unspecified is assumed to be unbounded.
#' @param init A function to initialize the td_fn object. It should take 2 arguments: "self" (the td_fn object being initialized) and "data" (the data used for initialization).
#' @param ED50 A function which, given a named vector of parameters \code{p} and optionally a value of \code{del_val}, computes the ED50. If there is no closed-form solution, this should return the string "non-analytic". If the ED50 is not well-defined, this should return the string "none". As a shortcut for these latter 2 cases, the strings "non-analytic" and "none" can be directly supplied as arguments.
#' @param par_chk Optionally, this is a function that checks the parameters to ensure they meet some criteria. E.g., for the dual-systems-exponential discount function, we require k1 < k2.
#' @return An object of class `td_fn`.
Expand All @@ -20,7 +21,7 @@
#' fn = function(data, p) (1 - p['b'])*exp(-p['k']*data$del) + p['b'],
#' par_starts = list(k = c(0.001, 0.1), b = c(0.001, 0.1)),
#' par_lims = list(k = c(0, Inf), b = c(0, 1)),
#' ED50 = function(...) 'non-analytic'
#' ED50 = 'non-analytic'
#' )
#' mod <- td_bcnm(td_bc_single_ptpt, discount_function = custom_discount_function, fit_err_rate = T)
#' }
Expand All @@ -35,10 +36,11 @@ td_fn <- function(predefined = c('hyperbolic',
'nonlinear-time-exponential',
'model-free',
'constant'),
name = NULL,
name = 'unnamed',
fn = NULL,
par_starts = NULL,
par_lims = NULL,
init = NULL,
ED50 = NULL,
par_chk = NULL) {

Expand All @@ -47,47 +49,68 @@ td_fn <- function(predefined = c('hyperbolic',

if (missing(predefined)) {

if (missing(name) | missing(fn) | missing(par_starts)) {
stop('To create a custom discount funciton, "name", "fn", and "par_starts" must all be provided to td_fn')
}
stopifnot(
is.character(name),
length(name) == 1
)
out$name <- name
if (!all(names(formals(fn)) == c('data', 'p'))) {
stop('fn must take 2 arguments: data (a dataframe) and p (a vector of named parameters)')

if (missing(fn)) {
if (missing(init)) {
stop('fn must be supplied if it will not be created by init')
} else {
fn <- function(data, p) NA
}
} else {
out$fn <- fn
stopifnot(is.function(fn))
if (!all(names(formals(fn)) == c('data', 'p')))
stop('fn must take 2 arguments: data (a dataframe) and p (a vector of named parameters)')
}
out$fn <- fn

if (is.function(par_starts)) {
out$par_starts <- par_starts
} else {
if (is.null(names(par_starts))) {
stop('par_starts must be a named list')
if (missing(par_starts)) {
if (missing(init)) {
stop('par_starts must be supplied if it will not be created by init')
} else {
out$par_starts <- as.list(par_starts)
par_starts <- list(placeholder = NA)
}
} else {
stopifnot(
is.list(par_starts),
!is.null(names(par_starts)),
all(vapply(par_starts,
function(x) is.numeric(x) || is.integer(x),
logical(1))))
}
out$par_starts <- par_starts

if (is.null(par_lims)) {
if (missing(par_lims)) {
par_lims <- list()
} else {
if (is.function(par_lims)) {
out$par_lims <- par_lims
} else {
if (is.null(names(par_lims)) | any(names(par_lims) == '')) {
stop('Every element of par_lims must have a name corresponding to a different parameter')
} else if (!all(vapply(par_lims, length, numeric(1)) == 2)) {
stop('par_lims must be a list of 2-element vectors')
}

for (par_name in names(par_starts)) {
if (!(par_name %in% names(par_lims))) {
par_lims[[par_name]] <- c(-Inf, Inf)
}
}

out$par_lims <- par_lims
stopifnot(
is.list(par_lims),
!is.null(names(par_lims))
)
extra_names <- setdiff(names(par_lims), names(par_starts))
if (length(extra_names) > 0) {
stop(sprintf('parameter(s) %s exist in par_lims but not par_starts',
paste(extra_names, collapse = ' and ')))
}

}
for (par_name in names(par_starts)) {
if (!(par_name %in% names(par_lims))) {
par_lims[[par_name]] <- c(-Inf, Inf)
}
}
out$par_lims <- par_lims

if (!missing(init)) {
stopifnot(is.function(init))
if (!all(names(formals(init)) == c('self', 'data')))
stop('init must take 2 arguments: "self" (the td_fn object being initialized) and "data" (the data being used to initialize)')
out$init <- init
}

if (is.null(ED50)) {
out$ED50 <- function(...) 'non-analytic'
Expand Down Expand Up @@ -175,7 +198,7 @@ td_fn <- function(predefined = c('hyperbolic',
par_lims = list(
k = c(0, Inf),
s = c(0, Inf)),
ED50 = function(p, ...) function(p, ...) (1/p['k']) ^ (1/p['s']))
ED50 = function(p, ...) (1/p['k']) ^ (1/p['s']))

} else if (name == 'scaled-exponential') {

Expand All @@ -192,7 +215,7 @@ td_fn <- function(predefined = c('hyperbolic',
} else if (name == 'dual-systems-exponential') {

out <- td_fn(name = name,
fn = function(data, p) function(data, p) p['w']*exp(-p['k1']*data$del) + (1 - p['w'])*exp(-p['k2']*data$del),
fn = function(data, p) p['w']*exp(-p['k1']*data$del) + (1 - p['w'])*exp(-p['k2']*data$del),
par_starts = list(
w = c(0.1, 0.5, 0.9),
k1 = c(0.001, 0.01, 0.1),
Expand Down Expand Up @@ -230,33 +253,28 @@ td_fn <- function(predefined = c('hyperbolic',
} else if (name == 'model-free') {

out <- td_fn(name = name,
fn = function(data, p) {
p <- p[grep('del_', names(p))]
# Round parameters and delays to 10 decimal points for comparison
dels <- round(as.numeric(gsub('del_', '', names(p))), 10)
xout <- round(data$del, 10)
get_yout <- function(xout_value) {
if (xout_value %in% dels) {
return(p[which(dels == xout_value)])
} else {
interp_result <- approx(x = dels, y = p, xout = xout_value)
return(interp_result$y)
}
fn = function(data, p) 'placeholder',
par_starts = list(placeholder = 0),
par_lims = list(placeholder = c(0, 0)),
init = function(self, data) {
# Get unique delays
delays <- sort(unique(data$del))
# Get starts and limits for free parameters
par_names <- sprintf('indiff_%s', seq_along(delays))
par_starts <- rep(list(0.5), length(delays))
names(par_starts) <- par_names
par_lims <- rep(list(c(0, 1)), length(delays))
names(par_lims) <- par_names
# Add to self
self$par_starts <- par_starts
self$par_lims <- par_lims
# Get interpolation function
self$fn <- function(data, p) {
approx(x = c(0, delays),
y = c(1, p[sprintf('indiff_%s', seq_along(delays))]),
xout = data$del)[['y']]
}
yout <- vapply(xout, get_yout, numeric(1))
return(yout)},
par_starts = function(data) {
unique_delays <- unique(data$del)
out <- as.list(rep(0.5, length(unique_delays)))
# Round to 10 decimal points to be able to align delay values
names(out) <- sprintf('del_%.10f', unique_delays)
return(out)},
par_lims = function(data) {
unique_delays <- unique(data$del)
out <- rep(list(c(0, 1)), length(unique_delays))
# Round to 10 decimal points to be able to align delay values
names(out) <- sprintf('del_%.10f', unique_delays)
return(out)},
return(self)},
ED50 = 'none')


Expand Down
Loading

0 comments on commit 86ff6f1

Please sign in to comment.