__sklearn_tags__ in TabPFNClassifier #95

Open
anshulg954 opened this issue Dec 12, 2024 · 0 comments

@anshulg954

```
  File "/home/runner/work/tabpfn-server/tabpfn-server/app/services/tabpfn_prediction_service.py", line 118, in predict
    res = model.predict_proba(x_test).tolist()
  File "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/tabpfn/scripts/estimator/base.py", line 3117, in predict_proba
    return self.predict_full(X, additional_y=additional_y)["proba"]
  File "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/tabpfn/scripts/estimator/base.py", line 3066, in predict_full
    X_full, y_full, additional_y_full, eval_pos = self.predict_common_setup(
  File "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/tabpfn/scripts/estimator/base.py", line 1442, in predict_common_setup
    check_is_fitted(self)
  File "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/sklearn/utils/validation.py", line 1751, in check_is_fitted
    tags = get_tags(estimator)
  File "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/sklearn/utils/_tags.py", line 405, in get_tags
    sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator)  # type: ignore[attr-defined]
  File "/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/sklearn/base.py", line 540, in __sklearn_tags__
    tags = super().__sklearn_tags__()
AttributeError: 'super' object has no attribute '__sklearn_tags__'
```

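The error appears with scikit-learn 1.6, which replaced the old `_more_tags`/`_get_tags` mechanism with a `__sklearn_tags__` method that returns the estimator's metadata tags. `check_is_fitted` now calls `get_tags`, which invokes `__sklearn_tags__` on the classes in `TabPFNClassifier`'s hierarchy. The implementation in `sklearn/base.py` that raises (most likely `ClassifierMixin.__sklearn_tags__`) calls `super().__sklearn_tags__()`, and that call fails because the class that follows the mixin in `TabPFNClassifier`'s method resolution order (MRO) does not define `__sklearn_tags__`. This typically happens when the mixin is listed after `BaseEstimator` in the class bases.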
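
The call pattern in the traceback (`klass.__sklearn_tags__(estimator)` for each class, walking the MRO in reverse) can be reproduced without scikit-learn. The sketch below is a stdlib-only model, not scikit-learn code: `ToyBaseEstimator`, `ToyClassifierMixin`, `WrongOrder`, `RightOrder` and `collect_tags` are hypothetical names, and the tags are plain dicts rather than scikit-learn's `Tags` objects.

```python
# Stdlib-only model of the failure. ToyBaseEstimator and ToyClassifierMixin are
# hypothetical stand-ins for scikit-learn classes; tags are plain dicts here.


class ToyBaseEstimator:
    def __sklearn_tags__(self):
        # Root of the chain: builds default tags and does not call super().
        return {"estimator_type": None, "requires_y": False}


class ToyClassifierMixin:
    def __sklearn_tags__(self):
        # Cooperative: asks the *next* class in the MRO for tags, then edits them.
        tags = super().__sklearn_tags__()
        tags["estimator_type"] = "classifier"
        tags["requires_y"] = True
        return tags


class WrongOrder(ToyBaseEstimator, ToyClassifierMixin):
    """Mixin after the base: in the MRO, only `object` follows the mixin."""


class RightOrder(ToyClassifierMixin, ToyBaseEstimator):
    """Mixin before the base: the mixin's super() call reaches the base."""


def collect_tags(cls):
    """Call each class's own __sklearn_tags__ in reverse-MRO order, loosely
    mirroring `klass.__sklearn_tags__(estimator)` from the traceback, and
    return the result from the most derived provider."""
    estimator = cls()
    tags = None
    for klass in reversed(cls.__mro__):
        if "__sklearn_tags__" in vars(klass):
            tags = klass.__sklearn_tags__(estimator)
    return tags


try:
    collect_tags(WrongOrder)
except AttributeError as exc:
    print(exc)  # 'super' object has no attribute '__sklearn_tags__'

print(collect_tags(RightOrder))  # {'estimator_type': 'classifier', 'requires_y': True}
```

Calling `WrongOrder().__sklearn_tags__()` directly would not fail, because normal attribute lookup stops at `ToyBaseEstimator`; the error only surfaces when each class's implementation is invoked on its own, as the `get_tags` frame in the traceback does.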

A possible fix: add the following method to `TabPFNClassifier`.

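```python
# Method on TabPFNClassifier; requires scikit-learn >= 1.6, where
# __sklearn_tags__ returns a `sklearn.utils.Tags` dataclass rather than a dict,
# so defaults are changed by attribute assignment instead of dict.update().
# The super() call only reaches ClassifierMixin if the class is declared as
# `class TabPFNClassifier(ClassifierMixin, BaseEstimator)` (mixins first).
def __sklearn_tags__(self):
    """
    Metadata tags required by scikit-learn.
    """
    tags = super().__sklearn_tags__()
    tags.target_tags.required = True  # This estimator requires a target `y`
    tags.input_tags.two_d_array = True  # Input should be a 2D array
    tags.target_tags.multi_output = False  # Not a multi-output classifier
    tags.input_tags.allow_nan = True  # Allows NaN values in `X`
    if tags.classifier_tags is not None:
        tags.classifier_tags.multi_class = True  # Handles multiclass and binary classification
    return tags
```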
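
Whatever `__sklearn_tags__` override is added, scikit-learn's mixins call `super().__sklearn_tags__()` themselves, so the class bases also need the mixins listed before `BaseEstimator`. A minimal sketch of a check for this, assuming only the standard `__mro__` attribute: `misordered_tag_providers`, `StubBaseEstimator`, `StubClassifierMixin`, `Bad` and `Good` are hypothetical names, not part of scikit-learn or TabPFN.

```python
# Hypothetical diagnostic helper (not part of scikit-learn or TabPFN).


def misordered_tag_providers(cls, base_name="BaseEstimator"):
    """Return the names of classes that come *after* `base_name` in
    `cls.__mro__` and define their own __sklearn_tags__.

    Such a class (typically a mixin) calls super().__sklearn_tags__(), but its
    super() call can no longer reach the base estimator's implementation.
    """
    names = [klass.__name__ for klass in cls.__mro__]
    if base_name not in names:
        return []  # no base estimator in the hierarchy: nothing to compare
    after_base = cls.__mro__[names.index(base_name) + 1:]
    return [klass.__name__ for klass in after_base if "__sklearn_tags__" in vars(klass)]


# Stand-in classes for the demo; they only mimic the shape of scikit-learn's.
class StubBaseEstimator:
    def __sklearn_tags__(self):
        return {}


class StubClassifierMixin:
    def __sklearn_tags__(self):
        return super().__sklearn_tags__()


class Bad(StubBaseEstimator, StubClassifierMixin):
    pass


class Good(StubClassifierMixin, StubBaseEstimator):
    pass


print(misordered_tag_providers(Bad, base_name="StubBaseEstimator"))   # ['StubClassifierMixin']
print(misordered_tag_providers(Good, base_name="StubBaseEstimator"))  # []
```

Against the real classes, a non-empty result from `misordered_tag_providers(TabPFNClassifier)` (for example `['ClassifierMixin']`) would suggest reordering the bases to `class TabPFNClassifier(ClassifierMixin, BaseEstimator)`.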