Skip to content

Commit

Permalink
Deal with NA values by simply omitting
Browse files Browse the repository at this point in the history
  • Loading branch information
kinleyid committed Dec 7, 2024
1 parent 00eeba7 commit 6d7fef8
Show file tree
Hide file tree
Showing 16 changed files with 83 additions and 39 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ importFrom(stats,approx)
importFrom(stats,approxfun)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,complete.cases)
importFrom(stats,filter)
importFrom(stats,fitted)
importFrom(stats,glm)
importFrom(stats,integrate)
importFrom(stats,median)
importFrom(stats,na.omit)
importFrom(stats,optim)
importFrom(stats,plogis)
importFrom(stats,predict)
Expand Down
5 changes: 0 additions & 5 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,6 @@ predict.td_bcnm <- function(object, newdata = NULL, type = c('link', 'response',

return(predict_indiffs(object, newdata))

# indiff_func <- object$config$discount_function$fn
# indiffs <- indiff_func(newdata, coef(object))
# names(indiffs) <- NULL
# return(indiffs)

}
}

Expand Down
15 changes: 15 additions & 0 deletions R/internal-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,24 @@ validate_td_data <- function(data, required_columns) {

# Validate columns in dataframe

# Check that all required columns are present
missing_cols <- setdiff(required_columns, names(data))
if (length(missing_cols) > 0) {
stop(sprintf('Missing required data column(s): %s', paste(missing_cols, collapse = ', ')))
}

# Omit NA rows
pre_nrow <- nrow(data)
data <- data[complete.cases(data[required_columns]), ]
n_rm <- pre_nrow - nrow(data)
if (n_rm > 0) {
warning(sprintf('Removing %s rows containing missing values', n_rm))
}
if (nrow(data) == 0) {
stop('Dataframe empty after removing missing values')
}

# Check that each column is of the expected type
expectations <- list(
'del' = list(type = c('numeric', 'integer', 'factor'),
lims = c(0, Inf)),
Expand Down Expand Up @@ -95,6 +108,8 @@ validate_td_data <- function(data, required_columns) {
data$imm_chosen <- as.logical(data$imm_chosen)
}

return(data)

}

initialize_discount_function <- function(disc_func, data) {
Expand Down
6 changes: 4 additions & 2 deletions R/td_bclm.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ td_bclm <- function(data,
...) {

# Validate data
validate_td_data(data,
required_columns = c('val_imm', 'val_del', 'del', 'imm_chosen'))
data <- validate_td_data(
data,
required_columns = c('val_imm', 'val_del', 'del', 'imm_chosen')
)
attention_checks(data, warn = T)
invariance_checks(data, warn = T)
if (length(grep('\\.B', names(data))) > 0) {
Expand Down
13 changes: 6 additions & 7 deletions R/td_bcnm.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#' @param eps_par_starts A vector of starting values to try for the "eps" parameter (which controls the error rate) during optimization. Ignored if `fit_err_rate = FALSE`.
#' @param optim_args Additional arguments to pass to \code{optim()}. Default is \code{list()}.
#' @param silent Boolean (true by default). The call to \code{optim()} occurs within a \code{try()} wrapper. The value of \code{silent} is passed along to \code{try()}.
#' @param na.action Action to take when data contains \code{NA} values. Default is \code{na.omit}.
#' @param ... Additional arguments to provide finer-grained control over the model configuration.
#' @family nonlinear binary choice model functions
#' @return An object of class \code{td_bcnm} with components \code{data} (containing the data used for fitting), \code{config} (containing the internal configuration of the model, including the \code{discount_function}), and \code{optim} (the output of \code{optim()}).
Expand Down Expand Up @@ -40,7 +39,6 @@ td_bcnm <- function(
eps_par_starts = c(0.01, 0.25),
silent = T,
optim_args = list(),
na.action = na.omit,
...) {

# From a user's POV, it's easier to specify `choice_rule` and `fixed_ends`
Expand Down Expand Up @@ -86,10 +84,11 @@ td_bcnm <- function(
}
}

# Required data columns
validate_td_data(data,
required_columns = c('val_imm', 'val_del', 'del', 'imm_chosen'))
data <- na.action(data)
# Data validation
data <- validate_td_data(
data,
required_columns = c('val_imm', 'val_del', 'del', 'imm_chosen')
)

# Ensure imm_chosen is logical
data$imm_chosen <- as.logical(data$imm_chosen)
Expand Down Expand Up @@ -172,7 +171,7 @@ td_bcnm <- function(
)
)
}
# Run optimization
# Run optimization
optimized <- run_optimization(nll_fn,
par_starts,
par_lims,
Expand Down
13 changes: 5 additions & 8 deletions R/td_ddm.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#' @param drift_transform A transform to apply to drift rates. Either \code{"none"} (no transform), \code{"sigmoid"} (sigmoidal transform described by Peters & D'Esposito, 2020, \doi{10.1371/journal.pcbi.1007615}, and Fontanesi et al., 2019, \doi{10.3758/s13423-018-1554-2}), or \code{"bias-correct"} (experimental; see note below).
#' @param optim_args Additional arguments to pass to \code{optim()}. Default is \code{list()}.
#' @param silent Boolean (true by default). The call to \code{optim()} occurs within a \code{try()} wrapper. The value of \code{silent} is passed along to \code{try()}.
#' @param na.action Action to take when data contains \code{NA} values. Default is \code{na.omit}.
#' @family drift diffusion model functions
#' @return An object of class \code{td_bcnm} with components \code{data} (containing the data used for fitting), \code{config} (containing the internal configuration of the model, including the \code{discount_function}), and \code{optim} (the output of \code{optim()}).
#' @note
Expand All @@ -47,18 +46,16 @@ td_ddm <- function(
tau_par_starts = c(0.2, 0.8),
drift_transform = c('none', 'sigmoid', 'bias-correct'),
silent = T,
optim_args = list(),
na.action = na.omit) {
optim_args = list()) {
# Input validation--------------------------

# Required data columns
validate_td_data(data,
required_columns = c('val_imm', 'val_del', 'del', 'imm_chosen', 'rt'))
data <- na.action(data)
data <- validate_td_data(data,
required_columns = c('val_imm', 'val_del', 'del', 'imm_chosen', 'rt'))

# Check that RTs are in seconds vs milliseconds
if (median(data$rt) > 500) {
stop('Median RT is greater than 500, meaning RTs are likely in units of milliseconds (or smaller). They should be in units of seconds.')
if (median(data$rt) > 100) {
stop('Median RT is greater than 100, meaning RTs are likely in units of milliseconds (or smaller). They should be in units of seconds.')
}

# Attention checks
Expand Down
6 changes: 2 additions & 4 deletions R/td_ipm.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ get_rss_fn <- function(data, discount_function) {
#' Compute a model of a single subject's indifference points.
#' @param data A data frame with columns \code{indiff} for the pre-computed indifference points and \code{del} for the delay.
#' @param discount_function A vector of strings specifying the name of the discount functions to use, or an object of class \code{td_fn} (used for creating custom discount functions), or a list of objects of class \code{td_fn}.
#' @param na.action Action to take when data contains \code{NA} values. Default is \code{na.omit}.
#' @param optim_args A list of additional args to pass to \code{optim}.
#' @param silent A Boolean specifying whether the call to \code{optim} (which occurs in a \code{try} block) should be silent on error.
#' @family indifference point model functions.
Expand Down Expand Up @@ -47,13 +46,12 @@ td_ipm <- function(
'nonlinear-time-exponential',
'model-free',
'constant'),
na.action = na.omit,
optim_args = list(),
silent = T) {

# Required data columns
validate_td_data(data, required_columns = c('indiff', 'del'))
data <- na.action(data)

data <- validate_td_data(data, required_columns = c('indiff', 'del'))

cand_fns <- get_candidate_discount_functions(discount_function)

Expand Down
2 changes: 1 addition & 1 deletion R/tempodisco-package.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

#' @importFrom stats optim predict qlogis plogis residuals integrate coef BIC AIC glm binomial fitted approx approxfun predict.glm na.omit median filter
#' @importFrom stats optim predict qlogis plogis residuals integrate coef BIC AIC glm binomial fitted approx approxfun predict.glm median filter complete.cases
#' @importFrom graphics lines points title axis
#' @importFrom methods is
#' @importFrom grDevices rgb
Expand Down
17 changes: 17 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ geomean <- function(x, ...) {
exp(mean(log(x), ...))
}

# Harmonic mean of a numeric vector: the number of elements divided by the
# sum of their reciprocals.
harmean <- function(x) {
  n_obs <- length(x)
  reciprocal_total <- sum(1 / x)
  n_obs / reciprocal_total
}

# Log-odds mean: smooth the inputs with laplace_smooth(), average them on the
# logit (qlogis) scale, then map the average back with plogis().
logoddsmean <- function(x) {
  smoothed <- laplace_smooth(x)
  mean_log_odds <- mean(qlogis(smoothed))
  plogis(mean_log_odds)
}

# Arcsine mean: average the values after the asin(sqrt(x)) transform, then
# invert the transform by squaring the sine of that average.
asinmean <- function(x) {
  angles <- asin(sqrt(x))
  mean_angle <- mean(angles)
  sin(mean_angle)^2
}

get_transform <- function(config, inverse = F) {

# From a string, get the transform applied to val_imm/val_del and to the value of the discount function
Expand Down
3 changes: 0 additions & 3 deletions man/td_bcnm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 1 addition & 4 deletions man/td_ddm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions man/td_ipm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions tests/testthat/test-td_bclm.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ while (model_idx <= length(models)) {
})
}

test_that('NAs', {
  # Row 2 gets an NA in a required column; the extra all-NA column is not a
  # required column, so it must not cause every row to be dropped (that would
  # raise an error rather than a warning).
  with_na <- df
  with_na$imm_chosen[2] <- NA
  with_na$irrelevant_column <- NA
  # Match the NA-removal message specifically so that an unrelated warning
  # cannot satisfy this expectation.
  expect_warning(
    td_bclm(with_na, model = 'hyperbolic.1'),
    'rows containing missing values',
    fixed = TRUE
  )
})

test_that('errors', {
expect_error(td_bclm(df, model = 'random'))
expect_error(td_bclm())
Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test-td_bcnm.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ while (arg_combo_idx <= nrow(arg_combos)) {
})
}

test_that('NAs', {
  # Row 2 gets an NA in a required column; the extra all-NA column is not a
  # required column, so it must not cause every row to be dropped (that would
  # raise an error rather than a warning).
  with_na <- df
  with_na$imm_chosen[2] <- NA
  with_na$irrelevant_column <- NA
  # Match the NA-removal message specifically so that an unrelated warning
  # cannot satisfy this expectation.
  expect_warning(
    td_bcnm(with_na, discount_function = 'hyperbolic'),
    'rows containing missing values',
    fixed = TRUE
  )
})

test_that('errors', {
expect_error(td_bcnm(df, choice_rule = 'random'))
expect_error(td_bcnm(df, noise_dist = 'norm'))
Expand Down
11 changes: 10 additions & 1 deletion tests/testthat/test-td_ddm.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ df <- td_bc_single_ptpt
# Test generics
# Test a few different discount functions

default_args <- list(td_bc_single_ptpt,
default_args <- list(data = td_bc_single_ptpt,
discount_function = 'exponential',
v_par_starts = 0.01,
beta_par_starts = 0.5,
Expand Down Expand Up @@ -81,6 +81,15 @@ test_that('drift transformations', {
expect_no_error(do.call(td_ddm, args))
})

test_that('NAs', {
  # Row 2 gets an NA in a required column; the extra all-NA column is not a
  # required column, so it must not cause every row to be dropped (that would
  # raise an error rather than a warning).
  with_na <- df
  with_na$imm_chosen[2] <- NA
  with_na$irrelevant_column <- NA
  args <- default_args
  args$data <- with_na
  # Match the NA-removal message specifically so that an unrelated warning
  # cannot satisfy this expectation. The fitted model is not needed here, so
  # it is not assigned.
  expect_warning(
    do.call(td_ddm, args),
    'rows containing missing values',
    fixed = TRUE
  )
})

test_that('errors', {
expect_error(td_ddm(df, discount_function = 'random'))
expect_error(td_ddm())
Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test-td_ipm.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ while (df_idx <= length(discount_functions)) {
})
}

test_that('NAs', {
  # Row 2 gets an NA in a required column; the extra all-NA column is not a
  # required column, so it must not cause every row to be dropped (that would
  # raise an error rather than a warning).
  with_na <- df
  with_na$indiff[2] <- NA
  with_na$irrelevant_column <- NA
  # Match the NA-removal message specifically so that an unrelated warning
  # cannot satisfy this expectation.
  expect_warning(
    td_ipm(with_na, discount_function = 'hyperbolic'),
    'rows containing missing values',
    fixed = TRUE
  )
})

test_that('errors', {
expect_error(td_ipm())
expect_error(td_ipm(data.frame(del = 1:10)))
Expand Down

0 comments on commit 6d7fef8

Please sign in to comment.