Skip to content

Commit

Permalink
try to use view.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 2, 2023
1 parent 5a16d49 commit 6030aa9
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,15 +1783,31 @@ def _get_qid(
X: ArrayLike, qid: Optional[ArrayLike]
) -> Tuple[ArrayLike, Optional[ArrayLike]]:
"""Get the special qid column from X if exists."""
if (_is_pandas_df(X) or _is_cudf_df(X)) and hasattr(X, "qid"):
has_qid = hasattr(X, "qid")
if (_is_pandas_df(X) or _is_cudf_df(X)) and has_qid:
if qid is not None:
raise ValueError(
"Found both the special column `qid` in `X` and the `qid` from the"
"`fit` method. Please remove one of them."
)
if _is_cudf_df(X) and has_qid:
q_x = X.qid
X = X.drop("qid", axis=1)
return X, q_x
if _is_pandas_df(X) and has_qid:
import pandas as pd

q_x = X.qid
series = []
columns = X.columns.difference(["qid"])
for c in columns:
if c == "qid":
continue

s_view = X[c].view(X[c].dtype)
series.append(s_view)
X = pd.DataFrame(series, columns=columns)
return X, q_x
return X, qid


Expand Down

0 comments on commit 6030aa9

Please sign in to comment.