Skip to content

Commit

Permalink
merge pr #89: accommodate outcome levels with special characters
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Sep 10, 2021
2 parents 5de108c + b573fce commit 0985737
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

To released as v0.2.2.

* Fixed bug arising from outcome levels that are not valid column
names in the multinomial classification setting.

# v0.2.1

* Updates for importing workflow sets that use the `add_variables()`
Expand Down
7 changes: 5 additions & 2 deletions R/add_candidates.R
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ add_candidates.default <- function(data_stack, candidates, name, ...) {

pred_class_idx <- grepl(pattern = ".pred_class", x = colnames(candidate_cols))

candidate_cols <- candidate_cols[,!pred_class_idx]
candidate_cols <- candidate_cols[,!pred_class_idx] %>%
setNames(., make.names(names(.)))

if (nrow(stack) == 0) {
stack <-
Expand Down Expand Up @@ -473,7 +474,9 @@ process_.config <- function(.config, df, name) {

# For racing, we only want to keep the candidates with complete resamples.
collate_predictions <- function(x) {
res <- tune::collect_predictions(x, summarize = TRUE)
res <- tune::collect_predictions(x, summarize = TRUE) %>%
dplyr::rename_with(make.names, .cols = dplyr::starts_with(".pred"))

if (inherits(x, "tune_race")) {
config_counts <-
tune::collect_metrics(x, summarize = FALSE) %>%
Expand Down
3 changes: 2 additions & 1 deletion R/fit_members.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ sanitize_classification_names <- function(model_stack, member_names) {
as.character() %>%
unique()

pred_strings <- paste0(".pred_", outcome_levels, "_")
pred_strings <- paste0(".pred_", outcome_levels, "_") %>%
make.names()

new_member_names <-
gsub(
Expand Down
3 changes: 2 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ predict.model_stack <- function(object, new_data, type = NULL, members = FALSE,
opts = opts,
type = member_type
) %>%
rlang::eval_tidy()
rlang::eval_tidy() %>%
setNames(., make.names(names(.)))

res <- stack_predict(object$equations[[type]], member_preds)

Expand Down

0 comments on commit 0985737

Please sign in to comment.