predictions from a bart classification model #64

Closed
jdberson opened this issue Jul 9, 2024 · 4 comments · Fixed by #65
Labels
feature (a feature request or enhancement)

Comments


jdberson commented Jul 9, 2024

Hello

I’m having trouble making probability predictions in a new R session from a bundled bart classification model object. I’m not sure whether this is a bug or whether I’m doing something silly.

Predicting from an unbundled bart classification model in the same R session works, but when I move the model to a new R session (e.g. via saveRDS() + readRDS()), the predictions are wrong and cluster around 0.5.

I have included a reprex below to show the problem.

Thanks for all the great packages and for your help!

Warm regards

Jacob

# Example to show that bundling a bart classification model object and then 
# predicting the probabilities from this object in a new R session results in 
# different probability predictions that are clustered around 0.5.

## Setup

library(tidymodels)
library(bundle)
library(xgboost, quietly = TRUE, warn.conflicts = FALSE)
library(dbarts, quietly = TRUE, warn.conflicts = FALSE)
library(callr)
library(waldo)
library(stringr, quietly = TRUE, warn.conflicts = FALSE)

tidymodels_prefer()

# Data
data("two_class_dat", package = "modeldata")

set.seed(1)

## Fit model and make predictions

# fit the model
mod_bart <-
  parsnip::bart() |>
  set_mode("classification") |>
  set_engine("dbarts") |>
  fit(Class ~ ., data = two_class_dat[1:70, ])

# bundle the model
mod_bart_bundle <- bundle(mod_bart)

# Make predictions using the bundled model in the existing R session
mod_bart_unbundled <- unbundle(mod_bart_bundle)
bart_predictions_existing_session <- 
  predict(mod_bart_unbundled, two_class_dat[71:200, ], type = "prob")

# Make predictions using bundled model in a new R session
bart_predictions_new_session <-
  r(
    function(model_bundle, new_data) {
      library(bundle)
      library(parsnip)
      
      model_object <- unbundle(model_bundle)
      
      predict(model_object, new_data, type = "prob")
    },
    args = list(
      model_bundle = mod_bart_bundle,
      new_data = two_class_dat[71:200, ]
    )
  )

# Compare the predictions
compare(bart_predictions_existing_session, bart_predictions_new_session)
#> old vs new
#>              .pred_Class1 .pred_Class2
#> - old[1, ]          0.844        0.156
#> + new[1, ]          0.477        0.523
#> - old[2, ]          0.581        0.419
#> + new[2, ]          0.486        0.514
#> - old[3, ]          0.396        0.604
#> + new[3, ]          0.488        0.512
#> - old[4, ]          0.722        0.278
#> + new[4, ]          0.514        0.486
#> - old[5, ]          0.596        0.404
#> + new[5, ]          0.501        0.499
#> - old[6, ]          0.545        0.455
#> + new[6, ]          0.482        0.518
#> - old[7, ]          0.860        0.140
#> + new[7, ]          0.487        0.513
#> - old[8, ]          0.635        0.365
#> + new[8, ]          0.509        0.491
#> - old[9, ]          0.195        0.805
#> + new[9, ]          0.521        0.479
#> - old[10, ]         0.399        0.601
#> + new[10, ]         0.503        0.497
#> and 120 more ...
#> 
#>      old$.pred_Class1 | new$.pred_Class1                 
#>  [1] 0.844            - 0.477            [1]             
#>  [2] 0.581            - 0.486            [2]             
#>  [3] 0.396            - 0.488            [3]             
#>  [4] 0.722            - 0.514            [4]             
#>  [5] 0.596            - 0.501            [5]             
#>  [6] 0.545            - 0.482            [6]             
#>  [7] 0.860            - 0.487            [7]             
#>  [8] 0.635            - 0.509            [8]             
#>  [9] 0.195            - 0.521            [9]             
#> [10] 0.399            - 0.503            [10]            
#>  ... ...                ...              and 120 more ...
#> 
#>      old$.pred_Class2 | new$.pred_Class2                 
#>  [1] 0.156            - 0.523            [1]             
#>  [2] 0.419            - 0.514            [2]             
#>  [3] 0.604            - 0.512            [3]             
#>  [4] 0.278            - 0.486            [4]             
#>  [5] 0.404            - 0.499            [5]             
#>  [6] 0.455            - 0.518            [6]             
#>  [7] 0.140            - 0.513            [7]             
#>  [8] 0.365            - 0.491            [8]             
#>  [9] 0.805            - 0.479            [9]             
#> [10] 0.601            - 0.497            [10]            
#>  ... ...                ...              and 120 more ...

# Plot the class 1 probability predictions from the new R session against the
# class 1 probability predictions from the existing R session
bind_cols(
  bart_predictions_existing_session |>
    rename_with(.fn = \(x) str_c(x, "_existing")),
  bart_predictions_new_session |>
    rename_with(.fn = \(x) str_c(x, "_new"))
) |>
  ggplot(aes(x = .pred_Class1_existing, y = .pred_Class1_new)) +
  geom_point() +
  theme_bw()

Created on 2024-07-09 with reprex v2.1.0.9000

@simonpcouch
Collaborator

We currently don't supply any dbarts-specific functionality in bundle, but it looks like dbarts objects carry an external pointer that we lose access to in new sessions. It's strange to me that the model will still predict without it anyway.

library(tidymodels)
library(callr)
library(waldo)

# Data
data("two_class_dat", package = "modeldata")

mod_bart <- bart(mode = "classification", engine = "dbarts") 

# fit two models, each with same seed
set.seed(1)
bart_fit_1 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])

set.seed(1)
bart_fit_2 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])

# some differences, but fine as long as they predict the same
compare(bart_fit_1, bart_fit_2)
#> `old$fit$fit@.xData$.->pointer` is <pointer: 0x12be4c320>
#> `new$fit$fit@.xData$.->pointer` is <pointer: 0x11bee3570>
#> 
#> `old$fit$fit@.xData$pointer` is <pointer: 0x12be4c320>
#> `new$fit$fit@.xData$pointer` is <pointer: 0x11bee3570>
#> 
#> `attr(old$fit$fit@.xData$state, 'runningTime')`: 0.397
#> `attr(new$fit$fit@.xData$state, 'runningTime')`: 0.403
#> 
#> `old$elapsed$elapsed`: 0.44 0.01 0.45 0.00 0.00
#> `new$elapsed$elapsed`: 0.40 0.01 0.41 0.00 0.00
# ...which they do
set.seed(1)
p_fit_1 <- predict(bart_fit_1, two_class_dat[71:200, ], type = "prob")
set.seed(1)
p_fit_2 <- predict(bart_fit_2, two_class_dat[71:200, ], type = "prob")
compare(p_fit_1, p_fit_2)
#> ✔ No differences
# that pointer no longer exists in a new session
r(
  function(model_object) {
    model_object$fit$fit@.xData$pointer
  },
  args = list(
    model_object = bart_fit_1
  )
)
#> <pointer: 0x0>

Created on 2024-07-09 with reprex v2.1.0

Some other miscellaneous observations:

My first reflex was to see if the predictions from this model depend on the state of RNG (without bundling). Looks like they do:

library(tidymodels)
library(waldo)

# Data
data("two_class_dat", package = "modeldata")

set.seed(1)

mod_bart <-
  parsnip::bart() |>
  set_mode("classification") |>
  set_engine("dbarts") |>
  fit(Class ~ ., data = two_class_dat[1:70, ])

compare(
  predict(mod_bart, two_class_dat[71:200, ], type = "prob"),
  predict(mod_bart, two_class_dat[71:200, ], type = "prob")
)
#> old vs new
#>              .pred_Class1 .pred_Class2
#> - old[1, ]          0.844        0.156
#> + new[1, ]          0.830        0.170
#> - old[2, ]          0.581        0.419
#> + new[2, ]          0.570        0.430
#> - old[3, ]          0.396        0.604
#> + new[3, ]          0.377        0.623
#> - old[4, ]          0.722        0.278
#> + new[4, ]          0.733        0.267
#> - old[5, ]          0.596        0.404
#> + new[5, ]          0.605        0.395
#> - old[6, ]          0.545        0.455
#> + new[6, ]          0.522        0.478
#>   old[7, ]          0.860        0.140
#> - old[8, ]          0.635        0.365
#> + new[8, ]          0.630        0.370
#> - old[9, ]          0.195        0.805
#> + new[9, ]          0.194        0.806
#> - old[10, ]         0.399        0.601
#> + new[10, ]         0.407        0.593
#> and 120 more ...
#> 
#>      old$.pred_Class1 | new$.pred_Class1                 
#>  [1] 0.844            - 0.830            [1]             
#>  [2] 0.581            - 0.570            [2]             
#>  [3] 0.396            - 0.377            [3]             
#>  [4] 0.722            - 0.733            [4]             
#>  [5] 0.596            - 0.605            [5]             
#>  [6] 0.545            - 0.522            [6]             
#>  [7] 0.860            | 0.860            [7]             
#>  [8] 0.635            - 0.630            [8]             
#>  [9] 0.195            - 0.194            [9]             
#> [10] 0.399            - 0.407            [10]            
#>  ... ...                ...              and 120 more ...
#> 
#>      old$.pred_Class2 | new$.pred_Class2                 
#>  [1] 0.156            - 0.170            [1]             
#>  [2] 0.419            - 0.430            [2]             
#>  [3] 0.604            - 0.623            [3]             
#>  [4] 0.278            - 0.267            [4]             
#>  [5] 0.404            - 0.395            [5]             
#>  [6] 0.455            - 0.478            [6]             
#>  [7] 0.140            | 0.140            [7]             
#>  [8] 0.365            - 0.370            [8]             
#>  [9] 0.805            - 0.806            [9]             
#> [10] 0.601            - 0.593            [10]            
#>  ... ...                ...              and 120 more ...
set.seed(1)
p_1 <- predict(mod_bart, two_class_dat[71:200, ], type = "prob")

set.seed(1)
p_2 <- predict(mod_bart, two_class_dat[71:200, ], type = "prob")

compare(p_1, p_2)
#> ✔ No differences

Created on 2024-07-09 with reprex v2.1.0

We don't supply any dbarts-specific functionality in bundle, so I also wondered whether this was some side effect of bundling a model_fit or whether the behavior is replicable without bundling. It is:

library(tidymodels)
library(bundle)
library(callr)
library(waldo)

tidymodels_prefer()

# Data
data("two_class_dat", package = "modeldata")

set.seed(1)

# fit the model
mod_bart <-
  parsnip::bart() |>
  set_mode("classification") |>
  set_engine("dbarts") |>
  fit(Class ~ ., data = two_class_dat[1:70, ])

set.seed(1)
bart_predictions_existing_session <- 
  predict(mod_bart, two_class_dat[71:200, ], type = "prob")

# Make predictions using the model (no bundling involved) in a new R session
bart_predictions_new_session <-
  r(
    function(model_object, new_data) {
      library(parsnip)

      set.seed(1)
      predict(model_object, new_data, type = "prob")
    },
    args = list(
      model_object = mod_bart,
      new_data = two_class_dat[71:200, ]
    )
  )

compare(bart_predictions_existing_session, bart_predictions_new_session)
#> old vs new
#>              .pred_Class1 .pred_Class2
#> - old[1, ]          0.823        0.177
#> + new[1, ]          0.520        0.480
#> - old[2, ]          0.568        0.432
#> + new[2, ]          0.519        0.481
#> - old[3, ]          0.361        0.639
#> + new[3, ]          0.519        0.481
#> - old[4, ]          0.721        0.279
#> + new[4, ]          0.508        0.492
#> - old[5, ]          0.599        0.401
#> + new[5, ]          0.498        0.502
#> - old[6, ]          0.549        0.451
#> + new[6, ]          0.478        0.522
#> - old[7, ]          0.863        0.137
#> + new[7, ]          0.474        0.526
#> - old[8, ]          0.647        0.353
#> + new[8, ]          0.522        0.478
#> - old[9, ]          0.200        0.800
#> + new[9, ]          0.502        0.498
#> - old[10, ]         0.394        0.606
#> + new[10, ]         0.507        0.493
#> and 120 more ...
#> 
#>      old$.pred_Class1 | new$.pred_Class1                 
#>  [1] 0.823            - 0.520            [1]             
#>  [2] 0.568            - 0.519            [2]             
#>  [3] 0.361            - 0.519            [3]             
#>  [4] 0.721            - 0.508            [4]             
#>  [5] 0.599            - 0.498            [5]             
#>  [6] 0.549            - 0.478            [6]             
#>  [7] 0.863            - 0.474            [7]             
#>  [8] 0.647            - 0.522            [8]             
#>  [9] 0.200            - 0.502            [9]             
#> [10] 0.394            - 0.507            [10]            
#>  ... ...                ...              and 120 more ...
#> 
#>      old$.pred_Class2 | new$.pred_Class2                 
#>  [1] 0.177            - 0.480            [1]             
#>  [2] 0.432            - 0.481            [2]             
#>  [3] 0.639            - 0.481            [3]             
#>  [4] 0.279            - 0.492            [4]             
#>  [5] 0.401            - 0.502            [5]             
#>  [6] 0.451            - 0.522            [6]             
#>  [7] 0.137            - 0.526            [7]             
#>  [8] 0.353            - 0.478            [8]             
#>  [9] 0.800            - 0.498            [9]             
#> [10] 0.606            - 0.493            [10]            
#>  ... ...                ...              and 120 more ...

Created on 2024-07-09 with reprex v2.1.0

@juliasilge added the feature label on Jul 9, 2024
@simonpcouch
Collaborator

Okay, from dbarts' docs:

Saving: saving and loading fitted BART objects for use with predict requires that R’s serialization mechanism be able to access the underlying trees, in addition to being fit with keeptrees/keepTrees as TRUE. For memory purposes, the trees are not stored as R objects unless specifically requested. To do this, one must “touch” the sampler’s state object before saving, e.g. for a fitted object bartFit, execute invisible(bartFit$fit$state).

"Touch"ing the $fit$state slot is easy for bundle to do. Setting keeptrees = TRUE in the model fit is outside of bundle's scope, though parsnip sets keeptrees = TRUE by default. Maybe the bundle() method in this case would error if keeptrees = FALSE?

library(tidymodels)
library(callr)
library(waldo)

# Data
data("two_class_dat", package = "modeldata")

mod_bart <- parsnip::bart(mode = "classification", engine = "dbarts")

# fit the model
set.seed(1)
bart_fit_1 <- fit(mod_bart, Class ~ ., data = two_class_dat[1:70, ])

set.seed(1)
p_orig <- predict(bart_fit_1, two_class_dat[71:200, ], type = "prob")

invisible(bart_fit_1$fit$fit$state)

p_new_session <- r(
  function(model_object, two_class_dat) {
    library(parsnip)
    
    set.seed(1)
    predict(model_object, two_class_dat[71:200, ], type = "prob")
  },
  args = list(
    model_object = bart_fit_1,
    two_class_dat = two_class_dat
  )
)

compare(p_orig, p_new_session)
#> ✔ No differences

Created on 2024-07-09 with reprex v2.1.0
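
For reference, here is a rough sketch of the check-and-touch step such a bundle() method would need. The helper name is hypothetical (this is not bundle's actual API), and it assumes the dbarts sampler stored at $fit$fit is absent when the model was fit with keeptrees = FALSE.

# Hypothetical helper sketching what a dbarts-aware bundle() method would
# likely do before serializing a parsnip model_fit with the dbarts engine;
# the name and the keeptrees check here are illustrative assumptions
prepare_dbarts_fit <- function(model_fit) {
  # assume the dbarts sampler (at $fit$fit) is only retained when the
  # model was fit with keeptrees = TRUE (parsnip's default)
  if (is.null(model_fit$fit$fit)) {
    stop("Refit the model with `keeptrees = TRUE` before bundling.")
  }

  # "touch" the sampler's state so the trees are materialized as R objects
  # and survive serialization (see the Saving section in ?dbarts::bart)
  invisible(model_fit$fit$fit$state)

  model_fit
}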

@jdberson
Author

Thanks very much @simonpcouch for looking into this and responding so quickly! It would be great if bundle could include dbarts functionality. In the meantime, touching the $fit$state slot solves my problem; apologies, I wasn't aware of this.

Cheers

Jacob
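
For anyone else hitting this in the meantime, here is a minimal sketch of the same workaround applied to the saveRDS()/readRDS() workflow from the original post (the file path is illustrative):

# "touch" the sampler's state so the trees are stored as R objects and
# survive serialization (see the Saving section in ?dbarts::bart)
invisible(mod_bart$fit$fit$state)
saveRDS(mod_bart, "mod_bart.rds")

# --- then, in a new R session ---
library(parsnip)
data("two_class_dat", package = "modeldata")

mod_bart_reloaded <- readRDS("mod_bart.rds")
predict(mod_bart_reloaded, two_class_dat[71:200, ], type = "prob")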

@simonpcouch
Collaborator

apologies, I wasn't aware of this.

No worries at all! It's bundle's job to iron out these oddities for users, so we definitely ought to support this. I appreciate you pointing this one out.
