predictions from a bart classification model #64
Comments
We currently don't supply any dbarts-specific functionality in bundle, but it looks like dbarts objects have an extptr that we lose access to in new sessions. Strange to me that the model will predict without it anyway.
library(tidymodels)
library(callr)
library(waldo)
# Data
data("two_class_dat", package = "modeldata")
mod_bart <- bart(mode = "classification", engine = "dbarts")
# fit two models, each with same seed
set.seed(1)
bart_fit_1 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])
set.seed(1)
bart_fit_2 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])
# some differences, but fine as long as they predict the same
compare(bart_fit_1, bart_fit_2)
#> `old$fit$fit@.xData$.->pointer` is <pointer: 0x12be4c320>
#> `new$fit$fit@.xData$.->pointer` is <pointer: 0x11bee3570>
#>
#> `old$fit$fit@.xData$pointer` is <pointer: 0x12be4c320>
#> `new$fit$fit@.xData$pointer` is <pointer: 0x11bee3570>
#>
#> `attr(old$fit$fit@.xData$state, 'runningTime')`: 0.397
#> `attr(new$fit$fit@.xData$state, 'runningTime')`: 0.403
#>
#> `old$elapsed$elapsed`: 0.44 0.01 0.45 0.00 0.00
#> `new$elapsed$elapsed`: 0.40 0.01 0.41 0.00 0.00
# ...which they do
set.seed(1)
p_fit_1 <- predict(bart_fit_1, two_class_dat[71:200, ], type = "prob")
set.seed(1)
p_fit_2 <- predict(bart_fit_2, two_class_dat[71:200, ], type = "prob")
compare(p_fit_1, p_fit_2)
#> ✔ No differences
# that pointer no longer exists in a new session
r(
function(model_object) {
model_object$fit$fit@.xData$pointer
},
args = list(
model_object = bart_fit_1
)
)
#> <pointer: 0x0>
Created on 2024-07-09 with reprex v2.1.0
Some other miscellaneous observations:
My first reflex was to see if the predictions from this model depend on the state of the RNG (without bundling). Looks like they do:
library(tidymodels)
library(waldo)
# Data
data("two_class_dat", package = "modeldata")
set.seed(1)
mod_bart <-
parsnip::bart() |>
set_mode("classification") |>
set_engine("dbarts") |>
fit(Class ~ ., data = two_class_dat[1:70, ])
compare(
predict(mod_bart, two_class_dat[71:200, ], type = "prob"),
predict(mod_bart, two_class_dat[71:200, ], type = "prob")
)
#> old vs new
#> .pred_Class1 .pred_Class2
#> - old[1, ] 0.844 0.156
#> + new[1, ] 0.830 0.170
#> - old[2, ] 0.581 0.419
#> + new[2, ] 0.570 0.430
#> - old[3, ] 0.396 0.604
#> + new[3, ] 0.377 0.623
#> - old[4, ] 0.722 0.278
#> + new[4, ] 0.733 0.267
#> - old[5, ] 0.596 0.404
#> + new[5, ] 0.605 0.395
#> - old[6, ] 0.545 0.455
#> + new[6, ] 0.522 0.478
#> old[7, ] 0.860 0.140
#> - old[8, ] 0.635 0.365
#> + new[8, ] 0.630 0.370
#> - old[9, ] 0.195 0.805
#> + new[9, ] 0.194 0.806
#> - old[10, ] 0.399 0.601
#> + new[10, ] 0.407 0.593
#> and 120 more ...
#>
#> old$.pred_Class1 | new$.pred_Class1
#> [1] 0.844 - 0.830 [1]
#> [2] 0.581 - 0.570 [2]
#> [3] 0.396 - 0.377 [3]
#> [4] 0.722 - 0.733 [4]
#> [5] 0.596 - 0.605 [5]
#> [6] 0.545 - 0.522 [6]
#> [7] 0.860 | 0.860 [7]
#> [8] 0.635 - 0.630 [8]
#> [9] 0.195 - 0.194 [9]
#> [10] 0.399 - 0.407 [10]
#> ... ... ... and 120 more ...
#>
#> old$.pred_Class2 | new$.pred_Class2
#> [1] 0.156 - 0.170 [1]
#> [2] 0.419 - 0.430 [2]
#> [3] 0.604 - 0.623 [3]
#> [4] 0.278 - 0.267 [4]
#> [5] 0.404 - 0.395 [5]
#> [6] 0.455 - 0.478 [6]
#> [7] 0.140 | 0.140 [7]
#> [8] 0.365 - 0.370 [8]
#> [9] 0.805 - 0.806 [9]
#> [10] 0.601 - 0.593 [10]
#> ... ... ... and 120 more ...
set.seed(1)
p_1 <- predict(mod_bart, two_class_dat[71:200, ], type = "prob")
set.seed(1)
p_2 <- predict(mod_bart, two_class_dat[71:200, ], type = "prob")
compare(p_1, p_2)
#> ✔ No differences
Created on 2024-07-09 with reprex v2.1.0
We don't supply any dbarts-specific functionality in bundle, so I also wondered if this was some side effect of bundling:
library(tidymodels)
library(bundle)
library(callr)
library(waldo)
tidymodels_prefer()
# Data
data("two_class_dat", package = "modeldata")
set.seed(1)
# fit the model
mod_bart <-
parsnip::bart() |>
set_mode("classification") |>
set_engine("dbarts") |>
fit(Class ~ ., data = two_class_dat[1:70, ])
set.seed(1)
bart_predictions_existing_session <-
predict(mod_bart, two_class_dat[71:200, ], type = "prob")
# Make predictions in a new R session (the model is passed as-is, without calling bundle())
bart_predictions_new_session <-
r(
function(model_object, new_data) {
library(parsnip)
set.seed(1)
predict(model_object, new_data, type = "prob")
},
args = list(
model_object = mod_bart,
new_data = two_class_dat[71:200, ]
)
)
compare(bart_predictions_existing_session, bart_predictions_new_session)
#> old vs new
#> .pred_Class1 .pred_Class2
#> - old[1, ] 0.823 0.177
#> + new[1, ] 0.520 0.480
#> - old[2, ] 0.568 0.432
#> + new[2, ] 0.519 0.481
#> - old[3, ] 0.361 0.639
#> + new[3, ] 0.519 0.481
#> - old[4, ] 0.721 0.279
#> + new[4, ] 0.508 0.492
#> - old[5, ] 0.599 0.401
#> + new[5, ] 0.498 0.502
#> - old[6, ] 0.549 0.451
#> + new[6, ] 0.478 0.522
#> - old[7, ] 0.863 0.137
#> + new[7, ] 0.474 0.526
#> - old[8, ] 0.647 0.353
#> + new[8, ] 0.522 0.478
#> - old[9, ] 0.200 0.800
#> + new[9, ] 0.502 0.498
#> - old[10, ] 0.394 0.606
#> + new[10, ] 0.507 0.493
#> and 120 more ...
#>
#> old$.pred_Class1 | new$.pred_Class1
#> [1] 0.823 - 0.520 [1]
#> [2] 0.568 - 0.519 [2]
#> [3] 0.361 - 0.519 [3]
#> [4] 0.721 - 0.508 [4]
#> [5] 0.599 - 0.498 [5]
#> [6] 0.549 - 0.478 [6]
#> [7] 0.863 - 0.474 [7]
#> [8] 0.647 - 0.522 [8]
#> [9] 0.200 - 0.502 [9]
#> [10] 0.394 - 0.507 [10]
#> ... ... ... and 120 more ...
#>
#> old$.pred_Class2 | new$.pred_Class2
#> [1] 0.177 - 0.480 [1]
#> [2] 0.432 - 0.481 [2]
#> [3] 0.639 - 0.481 [3]
#> [4] 0.279 - 0.492 [4]
#> [5] 0.401 - 0.502 [5]
#> [6] 0.451 - 0.522 [6]
#> [7] 0.137 - 0.526 [7]
#> [8] 0.353 - 0.478 [8]
#> [9] 0.800 - 0.498 [9]
#> [10] 0.606 - 0.493 [10]
#> ... ... ... and 120 more ...
Created on 2024-07-09 with reprex v2.1.0
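Since these predictions depend on the RNG state, comparing them across calls or sessions only makes sense when the seed is fixed right before each predict(). A minimal sketch of a helper for that, in base R. Note that predict_seeded() and the toy_rng_model class are hypothetical names made up for illustration; they are not part of parsnip, dbarts, or bundle, and the toy model only stands in for a fit whose predict method draws random numbers.

```r
# Hypothetical helper (an assumption, not an existing API): set the seed
# immediately before predicting so repeated calls use the same random draws.
# Note that this resets the global RNG state as a side effect.
predict_seeded <- function(object, new_data, seed = 1, ...) {
  set.seed(seed)
  predict(object, new_data, ...)
}

# Toy stand-in for a model whose predict() method depends on the RNG,
# used here only so the sketch runs without any modelling packages.
toy_model <- structure(list(), class = "toy_rng_model")
predict.toy_rng_model <- function(object, new_data, ...) {
  runif(nrow(new_data))
}

new_data <- data.frame(x = 1:5)

# Two seeded calls should agree with each other
p_1 <- predict_seeded(toy_model, new_data)
p_2 <- predict_seeded(toy_model, new_data)
identical(p_1, p_2)
```

With a fitted parsnip model, the same helper would presumably be called as predict_seeded(mod_bart, two_class_dat[71:200, ], type = "prob").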
Okay, from dbarts' docs:
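"Touch"ing the fit's state (bart_fit_1$fit$fit$state in the code that follows) before sending the model to a new session gives matching predictions:
library(tidymodels)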
library(callr)
library(waldo)
# Data
data("two_class_dat", package = "modeldata")
mod_bart <- parsnip::bart(mode = "classification", engine = "dbarts")
# fit the model
set.seed(1)
bart_fit_1 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])
set.seed(1)
p_orig <- predict(bart_fit_1, two_class_dat[71:200, ], type = "prob")
invisible(bart_fit_1$fit$fit$state)
p_new_session <- r(
function(model_object, two_class_dat) {
library(parsnip)
set.seed(1)
predict(model_object, two_class_dat[71:200, ], type = "prob")
},
args = list(
model_object = bart_fit_1,
two_class_dat = two_class_dat
)
)
compare(p_orig, p_new_session)
#> ✔ No differences
Created on 2024-07-09 with reprex v2.1.0
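The original report used saveRDS + readRDS rather than callr, so here is a rough sketch of the same workaround for that path. It assumes that touching the state before saveRDS() has the same effect as it does before callr::r(). save_bart_fit() is a hypothetical helper name, not something bundle or dbarts provides.

```r
library(tidymodels)

data("two_class_dat", package = "modeldata")
mod_bart <- parsnip::bart(mode = "classification", engine = "dbarts")

set.seed(1)
bart_fit <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])

# Hypothetical helper (an assumption, not an existing API): touch the
# state of the underlying dbarts fit, then write the model to disk.
save_bart_fit <- function(model_fit, path) {
  invisible(model_fit$fit$fit$state)
  saveRDS(model_fit, path)
  invisible(path)
}

path <- save_bart_fit(bart_fit, tempfile(fileext = ".rds"))

# Predictions in the current session
set.seed(1)
p_orig <- predict(bart_fit, two_class_dat[71:200, ], type = "prob")

# Read the file back and predict in a fresh R session
p_restored <- callr::r(
  function(path, new_data) {
    library(parsnip)
    model_object <- readRDS(path)
    set.seed(1)
    predict(model_object, new_data, type = "prob")
  },
  args = list(path = path, new_data = two_class_dat[71:200, ])
)

waldo::compare(p_orig, p_restored)
```

If the assumption holds, the comparison should report no differences. If it does not, dropping the touch inside save_bart_fit() and comparing again is a quick way to check whether the state is what matters.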
Thanks very much @simonpcouch for looking into this and responding so quickly! It would be great if bundle could include dbarts functionality. In the meantime, touching the $fit$state slot solves my problem; apologies, I wasn't aware of this. Cheers, Jacob
No worries at all! It's bundle's job to iron out these oddities for users, so we definitely ought to support this. I appreciate you pointing this one out.
Hello
I’m having trouble making probability predictions in a new R session from a bundled bart classification model object. Not sure if this is a bug or I am doing something silly.
Predicting from an unbundled bart classification model in the same R session works, but when I use saveRDS + readRDS, the new predictions are incorrect and tend to be clustered around 0.5.
I have included a reprex below to show the problem.
Thanks for all the great packages and for your help!
Warm regards
Jacob
Created on 2024-07-09 with reprex v2.1.0.9000