Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Mar 23, 2023
1 parent 35ffae1 commit d21bd75
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 56 deletions.
7 changes: 0 additions & 7 deletions tests/testthat/_snaps/ipcw.md

This file was deleted.

98 changes: 49 additions & 49 deletions tests/testthat/test-ipcw.R
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@

test_that('calculate weight time', {
skip_if_not_installed("parsnip", minimum_version = "1.0.4.9003")
skip_if_not_installed("parsnip", minimum_version = "1.0.4.9006")
skip_if_not_installed("censored", minimum_version = "0.1.1.9002")

library(survival)
library(tidymodels)
library(censored)

times <- 1:10
cens <- rep(0:1, times = 5)

surv_obj <- Surv(times, cens)
n <- length(surv_obj)

eval_0 <- parsnip:::graf_weight_time(surv_obj, eval_time = 0)
eval_05 <- parsnip:::graf_weight_time(surv_obj, eval_time = 5, eps = 1)
eval_11 <- parsnip:::graf_weight_time(surv_obj, eval_time = 11, rows = 11:20, eps = 0)
eval_0 <- parsnip:::graf_weight_time_vec(surv_obj, eval_time = rep(0, n))
eval_05 <- parsnip:::graf_weight_time_vec(surv_obj, eval_time = rep(5, n), eps = 1)
eval_11 <- parsnip:::graf_weight_time_vec(surv_obj, eval_time = rep(11, n), eps = 0)

na_05 <- is.na(eval_05$weight_time)
na_11 <- is.na(eval_11$weight_time)
na_05 <- is.na(eval_05)
na_11 <- is.na(eval_11)

expect_equal(eval_0$weight_time, rep(0, 10))
expect_equal(eval_0$.row, 1:10)
expect_equal(eval_0, rep(0, 10))

expect_equal(
which(na_05),
which(times <= 5 & cens == 0)
)
expect_equal(
eval_05$weight_time[!na_05],
eval_05[!na_05],
ifelse(times[!na_05] - 1 < 5, times[!na_05] - 1, 4)
)

Expand All @@ -33,62 +35,60 @@ test_that('calculate weight time', {
which(cens == 0)
)
expect_equal(
eval_11$weight_time[!na_11],
(1:5) * 2
eval_11[!na_11],
seq(2, 10, by = 2)
)
expect_equal(eval_11$.row, 11:20)

})

test_that('compute Graf weights', {
skip_if_not_installed("parsnip", minimum_version = "1.0.4.9003")
skip_if_not_installed("parsnip", minimum_version = "1.0.4.9006")
skip_if_not_installed("censored", minimum_version = "0.1.1.9002")

library(parsnip)
library(survival)
library(tidymodels)
library(censored)
library(workflows)
library(dplyr)

times <- 1:10
cens <- c(0, rep(1, 9))
times <- c(9, 1:9)
cens <- rep(0:1, 5)
surv_obj <- Surv(times, cens)
n <- length(surv_obj)
df <- data.frame(surv = surv_obj, x = -1:8)
fit <- survival_reg() %>% fit(surv ~ x, data = df)
wflow_fit <-
workflow() %>%
add_model(survival_reg(), formula = surv ~ x) %>%
add_variables(surv, x) %>%
fit(data = df)
mod_fit <- extract_fit_parsnip(wflow_fit)

eval_times <- c(5, 1:4)

pred_surv <-
predict(mod_fit, df, type = "survival", eval_time = eval_times) %>%
bind_cols(
predict(mod_fit, df, type = "time"),
df
) %>%
slice(5)

wt_times <-
parsnip:::graf_weight_time_vec(pred_surv$surv,
eval_time = pred_surv$.pred[[1]]$.eval_time)
expect_equal(wt_times, c(NA, 0.9999999999, 1.9999999999, 2.9999999999, NA), tolerance = 0.01)

cens_probs <- predict(fit$censor_probs, time = wt_times, as_vector = TRUE)

wts <- .censoring_weights_graf(fit, pred_surv)
expect_equal(names(wts), names(pred_surv))
expect_equal(nrow(wts), nrow(pred_surv))
expect_equal(dim(wts$.pred[[1]]), c(length(eval_times), 5))
expect_equal(wts$.pred[[1]]$.eval_time, eval_times)
expect_equal(
names(wts$.pred[[1]]),
c(".eval_time", ".pred_survival", ".weight_time", ".pred_censored", ".weight_censored"))

eval_0 <- parsnip:::graf_weight_time(surv_obj, eval_time = 0)
eval_05 <- parsnip:::graf_weight_time(surv_obj, eval_time = 5, eps = 1)
eval_11 <- parsnip:::graf_weight_time(surv_obj, eval_time = 11, rows = 11:20, eps = 0)

cens_prob_00 <- predict(fit$censor_probs, time = eval_0$weight_time, as_vector = TRUE)
cens_prob_05 <- predict(fit$censor_probs, time = eval_05$weight_time, as_vector = TRUE)
cens_prob_11 <- predict(fit$censor_probs, time = eval_11$weight_time, as_vector = TRUE)

wts_00 <- .censoring_weights_graf(fit, df, 0)
wts_05 <- .censoring_weights_graf(fit, df, 5)
wts_11 <- .censoring_weights_graf(fit, df, 11)

wflow_wts_00 <- .censoring_weights_graf(wflow_fit, df, 0)
wflow_wts_05 <- .censoring_weights_graf(wflow_fit, df, 5)
wflow_wts_11 <- .censoring_weights_graf(wflow_fit, df, 11)

expect_equal(wts_00$.weight_cens, 1 / cens_prob_00)
expect_equal(wts_05$.weight_cens, 1 / cens_prob_05)
expect_equal(wts_11$.weight_cens, 1 / cens_prob_11)

expect_equal(wflow_wts_00$.weight_cens, 1 / cens_prob_00)
expect_equal(wflow_wts_05$.weight_cens, 1 / cens_prob_05)
expect_equal(wflow_wts_11$.weight_cens, 1 / cens_prob_11)

expect_true(inherits(wts_00, "data.frame"))
expect_equal(names(wts_00), c(".row", "eval_time", ".prob_cens", ".weight_cens"))
expect_equal(nrow(wts_00), nrow(df))

expect_snapshot(.censoring_weights_graf(2, df, 0), error = TRUE)
wts2 <- wts %>% unnest(.pred)
expect_equal(wts2$.weight_censored, 1 / cens_probs)

})

0 comments on commit d21bd75

Please sign in to comment.