allow train_test_pairs to work on generic vector y? #847

Closed
SimonEnsemble opened this issue Oct 15, 2022 · 2 comments · Fixed by #848

@SimonEnsemble

After a long time, I still cannot figure out how to get

	class_labels = vcat(["spam" for _ = 1:3], ["not spam" for i = 1:9])
	ttp = train_test_pairs(kf, 1:length(class_labels), class_labels)
	for (k, (ids_train, ids_test)) in enumerate(ttp)
		# train on ids_train
		# test on ids_test
	end

to work. It would be nice to allow train_test_pairs to accept a generic vector of class labels, so that it can be used outside the context of ML workflows in this package. I'm looking for something generic like scikit-learn's stratified K-folds. Thanks!

The error I get is:

ArgumentError: Supplied target has scitype AbstractVector{ScientificTypesBase.Textual} but stratified cross-validation applies only to classification problems.

    train_test_pairs(::MLJBase.StratifiedCV, ::UnitRange{Int64}, ::Vector{String})@resampling.jl:407
    top-level scope@Local: 3[inlined]
@SimonEnsemble
Author

I needed:

train_test_pairs(kf, 1:length(y_train), y_train |> categorical)

which I discovered from #276. This is neither intuitive nor, to me, necessary for a user who is just importing a few functions from MLJBase...
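
For anyone landing here, the workaround can be sketched as a self-contained example. This is only a sketch under assumptions: `kf` is taken to be a `StratifiedCV` (its definition is not shown in the snippets above), and `nfolds=3` is my own choice so that the fold count does not exceed the three "spam" labels. The call to `train_test_pairs` is qualified with `MLJBase.` in case it is not exported.

```julia
using MLJBase            # provides StratifiedCV and train_test_pairs
using CategoricalArrays  # provides categorical

class_labels = vcat(fill("spam", 3), fill("not spam", 9))

# Assumption: `kf` is a StratifiedCV. nfolds=3 is chosen here so the number of
# folds does not exceed the size of the smallest class (3 "spam" labels).
kf = StratifiedCV(nfolds=3)

# Wrapping the raw strings in a categorical vector gives the target a finite
# (classification) scitype, which is what the stratified splitter checks for.
y = categorical(class_labels)
ttp = MLJBase.train_test_pairs(kf, 1:length(y), y)

for (k, (ids_train, ids_test)) in enumerate(ttp)
    println("fold $k: train on $(length(ids_train)) rows, test on $(length(ids_test)) rows")
end
```

Passing `class_labels` directly instead of `y` reproduces the ArgumentError quoted in the first comment.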

@ablaom
Member

ablaom commented Oct 16, 2022

@SimonEnsemble Thanks for reporting.

In MLJ more generally, care is taken to track all levels of categorical data, not just those seen in a given sample. However, in this particular case, the full class pool is not used (or needed), so I agree there's a good argument to make this more generic. I think the test triggering the error can just be removed.
