diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d7ec12c78cd0..e88587983581 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -141,7 +141,7 @@ jobs: - name: Install Python packages run: | python -m pip install wheel setuptools - python -m pip install pylint cpplint numpy scipy scikit-learn + python -m pip install pylint cpplint numpy scipy scikit-learn pyspark dask[all] - name: Run lint run: | make lint diff --git a/demo/pyspark/iris.py b/demo/pyspark/iris.py new file mode 100644 index 000000000000..b7a283a6b8b7 --- /dev/null +++ b/demo/pyspark/iris.py @@ -0,0 +1,80 @@ +""" +Example of training with PySpark on CPU +======================================= + +.. versionadded:: 1.6.0 + +""" +from pyspark.sql import SparkSession +from pyspark.sql.types import * +from pyspark.ml.feature import StringIndexer +from pyspark.ml.feature import VectorAssembler +from xgboost.spark import XGBoostClassifier +import xgboost + +version = "1.5.2" + +spark = ( + SparkSession.builder.master("local[1]") + .config( + "spark.jars.packages", + f"ml.dmlc:xgboost4j_2.12:{version},ml.dmlc:xgboost4j-spark_2.12:{version}", + ) + .appName("xgboost-pyspark iris") + .getOrCreate() +) + +schema = StructType( + [ + StructField("sepal length", DoubleType(), nullable=True), + StructField("sepal width", DoubleType(), nullable=True), + StructField("petal length", DoubleType(), nullable=True), + StructField("petal width", DoubleType(), nullable=True), + StructField("class", StringType(), nullable=True), + ] +) +raw_input = spark.read.schema(schema).csv("iris.data") + +stringIndexer = StringIndexer(inputCol="class", outputCol="classIndex").fit(raw_input) +labeled_input = stringIndexer.transform(raw_input).drop("class") + +vector_assembler = ( + VectorAssembler() + .setInputCols(("sepal length", "sepal width", "petal length", "petal width")) + .setOutputCol("features") +) +xgb_input = vector_assembler.transform(labeled_input).select("features", "classIndex") + + +params = { + "objective": "multi:softprob", + "treeMethod": "hist", + "numWorkers": 1, + "numRound": 100, + "numClass": 3, + "labelCol": "classIndex", + "featuresCol": "features", +} + +classifier = XGBoostClassifier(**params) +classifier.write().overwrite().save("/tmp/xgboost_classifier") +classifier1 = XGBoostClassifier.load("/tmp/xgboost_classifier") + + +classifier = ( + XGBoostClassifier() + .setLabelCol("classIndex") + .setFeaturesCol("features") + .setTreeMethod("hist") + .setNumClass(3) + .setNumRound(100) + .setObjective("multi:softprob") +) +classifier.setNumWorkers(1) + +model = classifier.fit(xgb_input) + + +model = classifier.fit(xgb_input) +results = model.transform(xgb_input) +results.show() diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst index 9f077edbc0df..8ad26ca3906a 100644 --- a/doc/python/python_api.rst +++ b/doc/python/python_api.rst @@ -147,3 +147,29 @@ Dask API :members: :inherited-members: :show-inheritance: + + +PySpark API +----------- + +.. automodule:: xgboost.spark + +.. autoclass:: xgboost.spark.XGBoostClassifier + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: xgboost.spark.XGBoostClassificationModel + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: xgboost.spark.XGBoostRegressor + :members: + :inherited-members: + :show-inheritance: + +.. 
autoclass:: xgboost.spark.XGBoostRegressionModel
   :members:
   :inherited-members:
   :show-inheritance:
diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index d2cf979e39f3..31f0efae04b6 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -16,6 +16,7 @@ See `Awesome XGBoost `_ for mo
  Distributed XGBoost with XGBoost4J-Spark
  dask
  ray
+ pyspark
  dart
  monotonic
  rf
diff --git a/doc/tutorials/pyspark.rst b/doc/tutorials/pyspark.rst
new file mode 100644
index 000000000000..c68460643140
--- /dev/null
+++ b/doc/tutorials/pyspark.rst
@@ -0,0 +1,222 @@
################################
Distributed XGBoost with PySpark
################################

.. versionadded:: 1.6.0

.. note::

   This feature is highly experimental and currently has limited functionality.

**XGBoost PySpark** allows XGBoost to run in a PySpark environment. Although the
**XGBoost PySpark** code ships in the **XGBoost Python package**, it is a wrapper around
XGBoost4j-Spark and XGBoost4j-Spark-Gpu, which means that all data preparation, training,
and inference are routed to the logic of xgboost4j-spark or xgboost4j-spark-gpu.

.. contents::
  :backlinks: none
  :local:

********************************************
Build an ML Application with XGBoost PySpark
********************************************

Installation
============

Let's create a new Conda environment to manage all the dependencies. You can use a Python
virtual environment instead if you prefer, or skip environment management altogether.

.. code-block:: shell

  conda create -n xgboost python=3.8 -y
  conda activate xgboost
  pip install xgboost==1.6.0 pyspark==3.1.2

Data Preparation
================

In this section, we use the `Iris `_ dataset as an example to showcase how to use Spark to
transform a raw dataset so that it fits the data interface of XGBoost PySpark.

The Iris dataset is shipped in CSV format. Each instance contains 4 features, "sepal length",
"sepal width", "petal length" and "petal width". In addition, it contains the "class" column,
which is essentially the label with three possible values: "Iris Setosa", "Iris Versicolour"
and "Iris Virginica".


Start SparkSession
------------------

.. code-block:: python

  from pyspark.sql import SparkSession

  spark = SparkSession.builder\
      .master("local[1]")\
      .config("spark.jars.packages", "ml.dmlc:xgboost4j_2.12:1.6.0,ml.dmlc:xgboost4j-spark_2.12:1.6.0")\
      .appName("xgboost-pyspark iris").getOrCreate()

As mentioned above, XGBoost-PySpark is based on XGBoost4j-Spark or XGBoost4j-Spark-Gpu, so we
need to set `spark.jars.packages` to the Maven coordinates of the XGBoost4j-Spark or
XGBoost4j-Spark-Gpu jars.

If you would like to submit your XGBoost application (e.g. iris.py) to a Spark cluster, you
need to specify the packages manually:

.. code-block:: shell

  spark-submit \
    --master local[1] \
    --packages ml.dmlc:xgboost4j_2.12:1.6.0,ml.dmlc:xgboost4j-spark_2.12:1.6.0 \
    iris.py

Read Dataset with Spark's Built-In Reader
-----------------------------------------
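
The examples below assume that the raw ``iris.data`` CSV file is available in the working
directory. If you do not have it yet, one way to fetch it from the UCI Machine Learning
Repository is sketched below (the URL was valid at the time of writing):

.. code-block:: python

  import urllib.request

  # Download the raw Iris CSV next to your script.
  urllib.request.urlretrieve(
      "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
      "iris.data",
  )

The first step in data transformation is to load the dataset into Spark's structured data
abstraction, the DataFrame.

..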
code-block:: python

  from pyspark.sql.types import *

  schema = StructType([
      StructField("sepal length", DoubleType(), nullable=True),
      StructField("sepal width", DoubleType(), nullable=True),
      StructField("petal length", DoubleType(), nullable=True),
      StructField("petal width", DoubleType(), nullable=True),
      StructField("class", StringType(), nullable=True),
  ])
  raw_input = spark.read.schema(schema).csv("input_path")


Transform Raw Iris Dataset
--------------------------

To make the Iris dataset recognizable to XGBoost, we need to

1. Transform the String-typed label, i.e. "class", into a Double-typed label.
2. Assemble the feature columns into a vector to fit the data interface of the Spark ML framework.

To convert the String-typed label to Double, we can use PySpark's built-in feature transformer
`StringIndexer `_.

.. code-block:: python

  from pyspark.ml.feature import StringIndexer

  stringIndexer = StringIndexer(inputCol="class", outputCol="classIndex").fit(raw_input)
  labeled_input = stringIndexer.transform(raw_input).drop("class")

With a newly created StringIndexer instance:

1. we set the input column, i.e. the column containing the String-typed label,
2. we set the output column, i.e. the column that will contain the Double-typed label,
3. then we ``fit`` the StringIndexer on our input DataFrame ``raw_input``, so that Spark internals can gather information such as the total number of distinct values.

Now we have a StringIndexer which is ready to be applied to our input DataFrame. To execute the
transformation logic of StringIndexer, we ``transform`` the input DataFrame ``raw_input``; to keep
the DataFrame concise, we drop the column "class" and keep only the feature columns and the
transformed Double-typed label column (the last line of the above code snippet).

``fit`` and ``transform`` are two key operations in MLlib. Basically, ``fit`` produces a
"transformer", e.g. StringIndexer, and each transformer applies its ``transform`` method to a
DataFrame to add new column(s) containing transformed features/labels, prediction results, etc.
You can find more details about ``fit`` and ``transform`` `here `_.

Similarly, we can use another transformer, `VectorAssembler `_, to assemble the feature columns
"sepal length", "sepal width", "petal length" and "petal width" into a vector.

.. code-block:: python

  from pyspark.ml.feature import VectorAssembler
  vector_assembler = VectorAssembler()\
      .setInputCols(("sepal length", "sepal width", "petal length", "petal width"))\
      .setOutputCol("features")
  xgb_input = vector_assembler.transform(labeled_input).select("features", "classIndex")


Now we have a DataFrame containing only two columns: "features", which holds the vector-assembled
"sepal length", "sepal width", "petal length" and "petal width", and "classIndex", which holds the
Double-typed labels. A DataFrame like this (vector-represented features plus numeric labels) can be
fed to the training engine directly.

Training
========

XGBoost supports both regression and classification. While this tutorial uses the Iris dataset to
show how xgboost-pyspark solves a multi-class classification problem, the usage for regression is
very similar.
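
The rest of this tutorial trains and predicts on the full ``xgb_input`` DataFrame to keep the
example short. If you would like to evaluate the model on rows it has not seen during training,
you can hold out part of the data first; the snippet below is a sketch using
``DataFrame.randomSplit`` with an arbitrary ratio and seed:

.. code-block:: python

  # Optional: keep 20% of the rows aside for evaluation.
  train_df, test_df = xgb_input.randomSplit([0.8, 0.2], seed=42)

To train an XGBoost model for classification, we first create an XGBoostClassifier:

..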
code-block:: python

  from xgboost.spark import XGBoostClassifier

  params = {
      'objective': 'multi:softprob',
      'treeMethod': 'hist',
      'numWorkers': 1,
      'numRound': 100,
      'numClass': 3,
      'labelCol': 'classIndex',
      'featuresCol': 'features'
  }

  classifier = XGBoostClassifier(**params)
  classifier.write().overwrite().save("/tmp/xgboost_classifier")
  classifier1 = XGBoostClassifier.load("/tmp/xgboost_classifier")

Equivalently, we can call the corresponding **setXXX** APIs to set the parameters:

.. code-block:: python

  classifier = XGBoostClassifier()\
      .setLabelCol("classIndex")\
      .setFeaturesCol("features")\
      .setTreeMethod("hist")\
      .setNumClass(3)\
      .setNumRound(100)\
      .setObjective("multi:softprob")
  classifier.setNumWorkers(1)


After we have set the XGBoostClassifier parameters and the feature/label columns, we can build a
transformer, an XGBoostClassificationModel, by fitting the XGBoostClassifier on the input DataFrame.
This ``fit`` operation is essentially the training process, and the generated model can then be used
for prediction.

.. code-block:: python

  model = classifier.fit(xgb_input)

Prediction
==========

Once we have a model, either an XGBoostClassificationModel or an XGBoostRegressionModel, it takes a
DataFrame, reads the column containing feature vectors, predicts for each feature vector, and
outputs a new DataFrame with the following columns by default:

* XGBoostClassificationModel will output margins (``rawPredictionCol``) and probabilities
  (``probabilityCol``) for each possible label, as well as the final prediction (``predictionCol``).
* XGBoostRegressionModel will output the predicted value (``predictionCol``).

.. code-block:: python

  model = classifier.fit(xgb_input)
  results = model.transform(xgb_input)
  results.show()
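
Besides inspecting the per-row output, you can feed the prediction DataFrame into a standard
PySpark evaluator to obtain a single quality metric. The snippet below is a sketch that reuses
the ``results`` DataFrame and the ``classIndex``/``prediction`` columns from above:

.. code-block:: python

  from pyspark.ml.evaluation import MulticlassClassificationEvaluator

  # Compute overall accuracy from the transformed DataFrame.
  evaluator = MulticlassClassificationEvaluator(
      labelCol="classIndex", predictionCol="prediction", metricName="accuracy"
  )
  print("accuracy:", evaluator.evaluate(results))

The ``transform`` call produces a result DataFrame containing the margin (``rawPrediction``),
the probability for each class, and the prediction for each instance:

..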
code-block:: none + + +-----------------+----------+--------------------+--------------------+----------+ + | features|classIndex| rawPrediction| probability|prediction| + +-----------------+----------+--------------------+--------------------+----------+ + |[5.1,3.5,1.4,0.2]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[4.9,3.0,1.4,0.2]| 0.0|[3.08765506744384...|[0.99636262655258...| 0.0| + |[4.7,3.2,1.3,0.2]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[4.6,3.1,1.5,0.2]| 0.0|[3.08765506744384...|[0.99679487943649...| 0.0| + |[5.0,3.6,1.4,0.2]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[5.4,3.9,1.7,0.4]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[4.6,3.4,1.4,0.3]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[5.0,3.4,1.5,0.2]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[4.4,2.9,1.4,0.2]| 0.0|[3.08765506744384...|[0.99636262655258...| 0.0| + |[4.9,3.1,1.5,0.1]| 0.0|[3.08765506744384...|[0.99679487943649...| 0.0| + |[5.4,3.7,1.5,0.2]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[4.8,3.4,1.6,0.2]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[4.8,3.0,1.4,0.1]| 0.0|[3.08765506744384...|[0.99636262655258...| 0.0| + |[4.3,3.0,1.1,0.1]| 0.0|[3.08765506744384...|[0.99636262655258...| 0.0| + |[5.8,4.0,1.2,0.2]| 0.0|[3.08765506744384...|[0.99072486162185...| 0.0| + |[5.7,4.4,1.5,0.4]| 0.0|[3.08765506744384...|[0.99072486162185...| 0.0| + |[5.4,3.9,1.3,0.4]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[5.1,3.5,1.4,0.3]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + |[5.7,3.8,1.7,0.3]| 0.0|[3.08765506744384...|[0.99072486162185...| 0.0| + |[5.1,3.8,1.5,0.3]| 0.0|[3.08765506744384...|[0.99680268764495...| 0.0| + +-----------------+----------+--------------------+--------------------+----------+ diff --git a/jvm-packages/integration-tests/README.md b/jvm-packages/integration-tests/README.md new file mode 100644 index 000000000000..31c96a4be54b --- /dev/null +++ b/jvm-packages/integration-tests/README.md @@ -0,0 +1,39 @@ +# XGBoost4j Pyspark API Integration Tests + +This integration tests framework refers to [Nvidia/spark-rapids/integration_tests](https://github.com/NVIDIA/spark-rapids/tree/branch-22.04/integration_tests). + +## Setting Up the Environment + +The tests are based off of `pyspark` and `pytest` running on Python 3. There really are +only a small number of Python dependencies that you need to install for the tests. The +dependencies also only need to be on the driver. You can install them on all nodes +in the cluster but it is not required. + +- install python dependencies + +``` bash +pip install pytest numpy scipy +``` + +- install xgboost python package + +XGBoost4j pyspark APIs are in xgboost python package, so we need to install it first + +``` bash +cd xgboost/python-packages +python setup.py install +``` + +- compile xgboost jvm packages + +``` bash +cd xgboost/jvm-packages +mvn -Dmaven.test.skip=true -DskipTests clean package +``` + +- run integration tests + +```bash +cd xgboost/jvm-packages/integration-tests +./run_pyspark_from_build.sh +``` diff --git a/jvm-packages/integration-tests/conftest.py b/jvm-packages/integration-tests/conftest.py new file mode 100644 index 000000000000..38e9e7a573d0 --- /dev/null +++ b/jvm-packages/integration-tests/conftest.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def pytest_addoption(parser): + """Pytest hook to define command line options for pytest""" + parser.addoption( + "--platform", action="store", default="cpu", help="optional values [ cpu, gpu ]" + ) diff --git a/jvm-packages/integration-tests/pytest.ini b/jvm-packages/integration-tests/pytest.ini new file mode 100644 index 000000000000..9969adf43de5 --- /dev/null +++ b/jvm-packages/integration-tests/pytest.ini @@ -0,0 +1,17 @@ +; Copyright (c) 2022 by Contributors +; +; Licensed under the Apache License, Version 2.0 (the "License"); +; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; http://www.apache.org/licenses/LICENSE-2.0 +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, +; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +; See the License for the specific language governing permissions and +; limitations under the License. + +[pytest] +markers = + skip_by_platform(platform): skip test for the given platform diff --git a/jvm-packages/integration-tests/python/conftest.py b/jvm-packages/integration-tests/python/conftest.py new file mode 100644 index 000000000000..8c3036342181 --- /dev/null +++ b/jvm-packages/integration-tests/python/conftest.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest + +from python.spark_init_internal import get_spark + + +@pytest.fixture +def platform(request): + return request.config.getoption('platform') + + +# https://stackoverflow.com/questions/28179026/how-to-skip-a-pytest-using-an-external-fixture +@pytest.fixture(autouse=True) +def skip_by_platform(request, platform): + if request.node.get_closest_marker('skip_platform'): + if request.node.get_closest_marker('skip_platform').args[0] == platform: + pytest.skip('skipped on this platform: {}'.format(platform)) + + + +@pytest.fixture +def xgboost_tmp_path(request): + ret = '/tmp/xgboost-integration-tests/' + # Make sure it is there and accessible + sc = get_spark().sparkContext + config = sc._jsc.hadoopConfiguration() + path = sc._jvm.org.apache.hadoop.fs.Path(ret) + fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config) + fs.mkdirs(path) + yield ret + fs.delete(path) diff --git a/jvm-packages/integration-tests/python/parameter_test.py b/jvm-packages/integration-tests/python/parameter_test.py new file mode 100644 index 000000000000..1ab655562ee3 --- /dev/null +++ b/jvm-packages/integration-tests/python/parameter_test.py @@ -0,0 +1,53 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from xgboost.spark import XGBoostClassifier + + +def test_xgboost_parameters_from_dictionary(): + xgb_params = {'objective': 'multi:softprob', + 'treeMethod': 'hist', + 'numWorkers': 1, + 'labelCol': 'classIndex', + 'featuresCol': 'features', + 'numRound': 100, + 'numClass': 3} + xgb = XGBoostClassifier(**xgb_params) + assert xgb.getObjective() == 'multi:softprob' + assert xgb.getTreeMethod() == 'hist' + assert xgb.getNumWorkers() == 1 + assert xgb.getLabelCol() == 'classIndex' + assert xgb.getFeaturesCol() == 'features' + assert xgb.getNumRound() == 100 + assert xgb.getNumClass() == 3 + + +def test_xgboost_set_parameter(): + xgb = XGBoostClassifier() + xgb.setObjective('multi:softprob') + xgb.setTreeMethod('hist') + xgb.setNumWorkers(1) + xgb.setLabelCol('classIndex') + xgb.setFeaturesCol('features') + xgb.setNumRound(100) + xgb.setNumClass(3) + assert xgb.getObjective() == 'multi:softprob' + assert xgb.getTreeMethod() == 'hist' + assert xgb.getNumWorkers() == 1 + assert xgb.getLabelCol() == 'classIndex' + assert xgb.getFeaturesCol() == 'features' + assert xgb.getNumRound() == 100 + assert xgb.getNumClass() == 3 diff --git a/jvm-packages/integration-tests/python/spark_init_internal.py b/jvm-packages/integration-tests/python/spark_init_internal.py new file mode 100644 index 000000000000..4a525825d3e9 --- /dev/null +++ b/jvm-packages/integration-tests/python/spark_init_internal.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +try: + import pyspark +except ImportError as error: + import findspark + findspark.init() + import pyspark + +_DRIVER_ENV = 'PYSP_TEST_spark_driver_extraJavaOptions' + +def _spark__init(): + # Force the RapidsPlugin to be enabled, so it blows up if the classpath is not set properly + # DO NOT SET ANY OTHER CONFIGS HERE!!! + # due to bugs in pyspark/pytest it looks like any configs set here + # can be reset in the middle of a test if specific operations are done (some types of cast etc) + _sb = pyspark.sql.SparkSession.builder + + for key, value in os.environ.items(): + if key.startswith('PYSP_TEST_') and key != _DRIVER_ENV: + _sb.config(key[10:].replace('_', '.'), value) + + driver_opts = os.environ.get(_DRIVER_ENV, "") + + _sb.config('spark.driver.extraJavaOptions', driver_opts) + _handle_event_log_dir(_sb, 'gw0') + + _s = _sb.appName('xgboost4j pyspark integration tests').getOrCreate() + # TODO catch the ClassNotFound error that happens if the classpath is not set up properly and + # make it a better error message + _s.sparkContext.setLogLevel("WARN") + return _s + + +def _handle_event_log_dir(sb, wid): + if os.environ.get('SPARK_EVENTLOG_ENABLED', str(True)).lower() in [ + str(False).lower(), 'off', '0' + ]: + print('Automatic configuration for spark event log disabled') + return + + spark_conf = pyspark.SparkConf() + master_url = os.environ.get('PYSP_TEST_spark_master', + spark_conf.get("spark.master", 'local')) + event_log_config = os.environ.get('PYSP_TEST_spark_eventLog_enabled', + spark_conf.get('spark.eventLog.enabled', str(False).lower())) + event_log_codec = os.environ.get('PYSP_TEST_spark_eventLog_compression_codec', 'zstd') + + if not master_url.startswith('local') or event_log_config != str(False).lower(): + print("SPARK_EVENTLOG_ENABLED is ignored for non-local Spark master and when " + "it's pre-configured by the user") + return + d = "./eventlog_{}".format(wid) + if not os.path.exists(d): + os.makedirs(d) + + print('Spark event logs will appear under {}. Set the environmnet variable ' + 'SPARK_EVENTLOG_ENABLED=false if you want to disable it'.format(d)) + + sb\ + .config('spark.eventLog.dir', "file://{}".format(os.path.abspath(d))) \ + .config('spark.eventLog.compress', True) \ + .config('spark.eventLog.enabled', True) \ + .config('spark.eventLog.compression.codec', event_log_codec) + + +_spark = _spark__init() + + +def get_spark(): + """ + Get the current SparkSession. + """ + return _spark + + +def spark_version(): + return _spark.version diff --git a/jvm-packages/integration-tests/python/xgboost_classifier_test.py b/jvm-packages/integration-tests/python/xgboost_classifier_test.py new file mode 100644 index 000000000000..05c5720d4eca --- /dev/null +++ b/jvm-packages/integration-tests/python/xgboost_classifier_test.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.feature import StringIndexer +from pyspark.ml.linalg import Vectors +from xgboost.spark import XGBoostClassifier, XGBoostClassificationModel + +from spark_init_internal import get_spark + + +def test_save_xgboost_classifier(xgboost_tmp_path): + params = { + 'objective': 'binary:logistic', + 'numRound': 5, + 'numWorkers': 2, + 'treeMethod': 'hist' + } + classifier = XGBoostClassifier(**params) + + classifier_path = xgboost_tmp_path + "xgboost-classifier" + + classifier.write().overwrite().save(classifier_path) + classifier1 = XGBoostClassifier.load(classifier_path) + assert classifier1.getObjective() == 'binary:logistic' + assert classifier1.getNumRound() == 5 + assert classifier1.getNumWorkers() == 2 + assert classifier1.getTreeMethod() == 'hist' + + +def test_xgboost_regressor_training_without_error(xgboost_tmp_path): + spark = get_spark() + df = spark.createDataFrame([ + ("a", Vectors.dense([1.0, 2.0, 3.0, 4.0, 5.0])), + ("b", Vectors.dense([5.0, 6.0, 7.0, 8.0, 9.0]))], + ["label", "features"]) + label_name = 'label_indexed' + string_indexer = StringIndexer(inputCol="label", outputCol=label_name).fit(df) + indexed_df = string_indexer.transform(df).select(label_name, 'features') + params = { + 'objective': 'binary:logistic', + 'numRound': 5, + 'numWorkers': 1, + 'treeMethod': 'hist' + } + classifier = XGBoostClassifier(**params) \ + .setLabelCol(label_name) \ + .setFeaturesCol('features') + + classifier_path = xgboost_tmp_path + "xgboost-classifier" + classifier.write().overwrite().save(classifier_path) + classifier1 = XGBoostClassifier.load(classifier_path) + + model_path = xgboost_tmp_path + "xgboost-classifier-model" + model = classifier1.fit(indexed_df) + model.write().overwrite().save(model_path) + model1 = XGBoostClassificationModel.load(model_path) + model1.transform(df).show() diff --git a/jvm-packages/integration-tests/python/xgboost_regressor_test.py b/jvm-packages/integration-tests/python/xgboost_regressor_test.py new file mode 100644 index 000000000000..70626da0f186 --- /dev/null +++ b/jvm-packages/integration-tests/python/xgboost_regressor_test.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pyspark.ml.linalg import Vectors +from xgboost.spark import XGBoostRegressor, XGBoostRegressionModel + +from python.spark_init_internal import get_spark + + +def test_save_xgboost_regressor(xgboost_tmp_path): + params = { + 'objective': 'reg:squarederror', + 'numRound': 5, + 'numWorkers': 2, + 'treeMethod': 'hist' + } + + regressor_path = xgboost_tmp_path + "xgboost-regressor" + + classifier = XGBoostRegressor(**params) + classifier.write().overwrite().save(regressor_path) + classifier1 = XGBoostRegressor.load(regressor_path) + assert classifier1.getObjective() == 'reg:squarederror' + assert classifier1.getNumRound() == 5 + assert classifier1.getNumWorkers() == 2 + assert classifier1.getTreeMethod() == 'hist' + + +def test_xgboost_regressor_training_without_error(xgboost_tmp_path): + spark = get_spark() + df = spark.createDataFrame([ + (1.0, Vectors.dense(1.0)), + (0.0, Vectors.dense(2.0))], ["label", "features"]) + params = { + 'objective': 'reg:squarederror', + 'numRound': 5, + 'numWorkers': 1, + 'treeMethod': 'hist' + } + regressor = XGBoostRegressor(**params) \ + .setLabelCol('label') \ + .setFeaturesCol('features') + regressor_path = xgboost_tmp_path + "xgboost-regressor" + regressor.write().overwrite().save(regressor_path) + regressor1 = XGBoostRegressor.load(regressor_path) + model = regressor1.fit(df) + + model_path = xgboost_tmp_path + "xgboost-regressor-model" + model.write().overwrite().save(model_path) + model1 = XGBoostRegressionModel.load(model_path) + model1.transform(df).show() diff --git a/jvm-packages/integration-tests/run_pyspark_from_build.sh b/jvm-packages/integration-tests/run_pyspark_from_build.sh new file mode 100755 index 000000000000..d74b2c8b5945 --- /dev/null +++ b/jvm-packages/integration-tests/run_pyspark_from_build.sh @@ -0,0 +1,119 @@ +#!/bin/bash +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -ex + +SCRIPTPATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" +cd "$SCRIPTPATH" + +if [[ $( echo ${SKIP_TESTS} | tr [:upper:] [:lower:] ) == "true" ]]; +then + echo "PYTHON INTEGRATION TESTS SKIPPED..." +elif [[ -z "$SPARK_HOME" ]]; +then + >&2 echo "SPARK_HOME IS NOT SET CANNOT RUN PYTHON INTEGRATION TESTS..." +else + echo "WILL RUN TESTS WITH SPARK_HOME: ${SPARK_HOME}" + + # support alternate local jars NOT building from the source code + if [ -d "$LOCAL_JAR_PATH" ]; then + XGBOOST_4J_JAR=$(echo "$LOCAL_JAR_PATH"/xgboost4j_2.12-*.jar) + XGBOOST_4J_SPARK_JAR=$(echo "$LOCAL_JAR_PATH"/xgboost4j-spark_2.12-*.jar) + else + XGBOOST_4J_JAR=$(echo "$SCRIPTPATH"/../xgboost4j/target/xgboost4j_2.12-*.jar) + XGBOOST_4J_SPARK_JAR=$(echo "$SCRIPTPATH"/../xgboost4j-spark/target/xgboost4j-spark_2.12-*.jar) + fi + if [ ! -e $XGBOOST_4J_JAR ]; then + echo "$XGBOOST_4J_JAR does not exist" + exit 2 + fi + if [ ! 
-e $XGBOOST_4J_SPARK_JAR ]; then + echo "$XGBOOST_4J_SPARK_JAR does not exist" + exit 2 + fi + ALL_JARS="$XGBOOST_4J_JAR,$XGBOOST_4J_SPARK_JAR" + echo "AND XGBoost JARS: $ALL_JARS" + + if [[ "${TEST}" != "" ]]; + then + TEST_ARGS="-k $TEST" + fi + if [[ "${TEST_TAGS}" != "" ]]; + then + TEST_TAGS="-m $TEST_TAGS" + fi + + TEST_TYPE_PARAM="" + if [[ "${TEST_TYPE}" != "" ]]; + then + TEST_TYPE_PARAM="--test_type $TEST_TYPE" + fi + + RUN_DIR=${RUN_DIR-"$SCRIPTPATH"/target/run_dir} + mkdir -p "$RUN_DIR" + cd "$RUN_DIR" + + TEST_COMMON_OPTS=(-v + -rfExXs + "$TEST_TAGS" + --color=yes + --platform='gpu' + $TEST_TYPE_PARAM + "$TEST_ARGS" + $RUN_TEST_PARAMS + --junitxml=TEST-pytest-`date +%s%N`.xml + "$@") + + NUM_LOCAL_EXECS=${NUM_LOCAL_EXECS:-0} + MB_PER_EXEC=${MB_PER_EXEC:-1024} + CORES_PER_EXEC=${CORES_PER_EXEC:-1} + + SPARK_TASK_MAXFAILURES=1 + + export PYSP_TEST_spark_driver_extraClassPath="${ALL_JARS// /:}" + export PYSP_TEST_spark_executor_extraClassPath="${ALL_JARS// /:}" + export PYSP_TEST_spark_driver_extraJavaOptions="-ea -Duser.timezone=UTC $COVERAGE_SUBMIT_FLAGS" + export PYSP_TEST_spark_executor_extraJavaOptions='-ea -Duser.timezone=UTC' + export PYSP_TEST_spark_ui_showConsoleProgress='false' + export PYSP_TEST_spark_sql_session_timeZone='UTC' + # prevent cluster shape to change + export PYSP_TEST_spark_dynamicAllocation_enabled='false' + + # Set spark.task.maxFailures for most schedulers. + # + # Local (non-cluster) mode is the exception and does not work with `spark.task.maxFailures`. + # It requires two arguments to the master specification "local[N, K]" where + # N is the number of threads, and K is the maxFailures (otherwise this is hardcoded to 1, + # see https://issues.apache.org/jira/browse/SPARK-2083). + export PYSP_TEST_spark_task_maxFailures="1" + + if ((NUM_LOCAL_EXECS > 0)); then + export PYSP_TEST_spark_master="local-cluster[$NUM_LOCAL_EXECS,$CORES_PER_EXEC,$MB_PER_EXEC]" + else + # If a master is not specified, use "local[*, $SPARK_TASK_MAXFAILURES]" + if [ -z "${PYSP_TEST_spark_master}" ] && [[ "$SPARK_SUBMIT_FLAGS" != *"--master"* ]]; then + export PYSP_TEST_spark_master="local[*,$SPARK_TASK_MAXFAILURES]" + fi + fi + + LOCAL_ROOTDIR=${LOCAL_ROOTDIR:-"$SCRIPTPATH"} + RUN_TESTS_COMMAND=("$SCRIPTPATH"/runtests.py + --rootdir + "$LOCAL_ROOTDIR" + "$LOCAL_ROOTDIR"/python) + + exec "$SPARK_HOME"/bin/spark-submit --jars "${ALL_JARS// /,}" \ + --driver-java-options "$PYSP_TEST_spark_driver_extraJavaOptions" \ + $SPARK_SUBMIT_FLAGS "${RUN_TESTS_COMMAND[@]}" "${TEST_COMMON_OPTS[@]}" +fi diff --git a/jvm-packages/integration-tests/runtests.py b/jvm-packages/integration-tests/runtests.py new file mode 100644 index 000000000000..2cbcbff9671a --- /dev/null +++ b/jvm-packages/integration-tests/runtests.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022 Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys + +from pytest import main + +#import cProfile + +if __name__ == '__main__': + #cProfile.run('main(sys.argv[1:])', 'test_profile') + # arguments are the same as for pytest https://docs.pytest.org/en/latest/usage.html + # or run pytest -h + sys.exit(main(sys.argv[1:])) diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py new file mode 100644 index 000000000000..0492565b61a7 --- /dev/null +++ b/python-package/xgboost/spark/__init__.py @@ -0,0 +1,34 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +PySpark extenstions for distributed training +------------------------------------------- + +See :doc:`Distributed XGBoost with PySpark ` for a quick start. +""" +from .estimator import ( + XGBoostClassifier, + XGBoostClassificationModel, + XGBoostRegressor, + XGBoostRegressionModel, +) + +__all__ = [ + "XGBoostClassifier", + "XGBoostClassificationModel", + "XGBoostRegressor", + "XGBoostRegressionModel", +] diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py new file mode 100644 index 000000000000..8af0efa5f60d --- /dev/null +++ b/python-package/xgboost/spark/estimator.py @@ -0,0 +1,275 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pylint: disable=too-many-ancestors, invalid-name +""" +Estimators for PySpark interface. +""" + +import sys +import importlib +import types +from typing import Optional + +import py4j +from pyspark import keyword_only +from pyspark.ml.common import inherit_doc + +from .internal import _XGBoostClassifierBase, _XGBoostClassificationModelBase +from .internal import _XGBoostRegressorBase, _XGBoostRegressionModelBase + + +def _init_module() -> None: + """Allows Pipeline()/PipelineModel() with XGBoost stages to be loaded from disk. + Needed because they try to import Python objects from their Java location. 
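
    With this aliasing in place, a pipeline containing an XGBoost stage is expected to
    round-trip through disk roughly as follows (a sketch only; ``assembler``, ``train_df``
    and the save path are illustrative and not defined here)::

        from pyspark.ml import Pipeline, PipelineModel
        from xgboost.spark import XGBoostClassifier

        # Persist and reload a pipeline whose last stage is an XGBoost estimator.
        pipeline = Pipeline(stages=[assembler, XGBoostClassifier(numRound=10)])
        model = pipeline.fit(train_df)
        model.write().overwrite().save("/tmp/xgboost_pipeline_model")
        reloaded = PipelineModel.load("/tmp/xgboost_pipeline_model")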
+ + """ + if "ml" not in sys.modules: + sys.modules["ml"] = importlib.util.module_from_spec( + importlib.machinery.ModuleSpec(name="ml", loader=None) + ) + + def dummy_module(parent: types.ModuleType, name: str) -> None: + if not hasattr(parent, name): + setattr( + parent, name, importlib.machinery.ModuleSpec(name=name, loader=None) + ) + + dummy_module(sys.modules["ml"], "dmlc") + dummy_module(sys.modules["ml"].dmlc, "xgboost4j") + dummy_module(sys.modules["ml"].dmlc.xgboost4j, "scala") + + setattr(sys.modules["ml"].dmlc.xgboost4j.scala, "spark", sys.modules[__name__]) + sys.modules["ml.dmlc.xgboost4j.scala.spark"] = sys.modules[__name__] + + +_init_module() + + +@inherit_doc +class XGBoostClassifier(_XGBoostClassifierBase): + """XGBoostClassifier is a PySpark ML estimator. It implements the XGBoost + classification algorithm based on + `ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier` in XGBoost jvm packages, and it + can be used in PySpark Pipeline and PySpark ML meta algorithms like CrossValidator. + + .. versionadded:: 1.6.0 + + Examples + -------- + + >>> from pyspark.ml.feature import StringIndexer, VectorAssembler + >>> from pyspark.sql import SparkSession + >>> from pyspark.sql.types import * + >>> from xgboost.spark import XGBoostClassifier, XGBoostClassificationModel + >>> iris_data_path = 'iris.csv' + >>> schema = StructType([ + ... StructField("sepal length", DoubleType(), nullable=True), + ... StructField("sepal width", DoubleType(), nullable=True), + ... StructField("petal length", DoubleType(), nullable=True), + ... StructField("petal width", DoubleType(), nullable=True), + ... StructField("class", StringType(), nullable=True), + ... ]) + >>> raw_df = spark.read.schema(schema).csv(iris_data_path) + >>> stringIndexer = StringIndexer(inputCol="class", outputCol="classIndex").fit(raw_df) + >>> labeled_input = stringIndexer.transform(raw_df).drop("class") + >>> vector_assembler = VectorAssembler()\ + ... .setInputCols(("sepal length", "sepal width", "petal length", "petal width"))\ + ... .setOutputCol("features") + >>> xgb_input = vector_assembler.transform(labeled_input).select("features", "classIndex") + >>> params = { + ... 'objective': 'multi:softprob', + ... 'treeMethod': 'hist', + ... 'numWorkers': 1, + ... 'numRound': 5, + ... 'numClass': 3, + ... 'labelCol': 'classIndex', + ... 'featuresCol': 'features' + ... 
} + >>> classifier = XGBoostClassifier(**params) + >>> classifier.write().overwrite().save("/tmp/xgboost_classifier") + >>> classifier1 = XGBoostClassifier.load("/tmp/xgboost_classifier") + >>> model = classifier1.fit(xgb_input) + >>> model.write().overwrite().save("/tmp/xgboost_classifier_model") + >>> model1 = XGBoostClassificationModel.load("/tmp/xgboost_classifier_model") + >>> df = model1.transform(xgb_input) + >>> df.show(2) + +-----------------+----------+--------------------+--------------------+----------+ + | features|classIndex| rawPrediction| probability|prediction| + +-----------------+----------+--------------------+--------------------+----------+ + |[5.1,3.5,1.4,0.2]| 0.0|[1.84931623935699...|[0.82763016223907...| 0.0| + |[4.9,3.0,1.4,0.2]| 0.0|[1.84931623935699...|[0.82763016223907...| 0.0| + +-----------------+----------+--------------------+--------------------+----------+ + only showing top 2 rows + + Besides passing dictionary parameters to XGBoostClassifier, users can call set APIs + to set the parameters, + + xgb_classifier = XGBoostClassifier() \ + .setFeaturesCol("features") \ + .setLabelCol("classIndex") \ + .setNumRound(100) \ + .setNumClass(3) \ + .setObjective('multi:softprob') \ + .setTreeMethod('hist') + + """ + + # _java_class_name will be used when loading pipeline. + _java_class_name = "ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier" + + # pylint: disable=unused-argument + @keyword_only + def __init__( + self, + *, + featuresCol: Optional[str] = None, + labelCol: Optional[str] = None, + treeMethod: Optional[str] = None, + objective: Optional[str] = None, + numClass: Optional[int] = None, + numRound: Optional[int] = None, + numWorkers: Optional[int] = None + ): + super().__init__() + self._java_obj = self._new_java_obj(self.__class__._java_class_name, self.uid) + kwargs = self._input_kwargs # pylint: disable=no-member + self._set(**kwargs) + + def _create_model( + self, java_model: py4j.java_gateway.JavaObject + ) -> "XGBoostClassificationModel": + return XGBoostClassificationModel(java_model) + + +class XGBoostClassificationModel(_XGBoostClassificationModelBase): + """ + The model returned by :func:`xgboost.spark.XGBoostClassifier.fit()` + + """ + + # _java_class_name will be used when loading pipeline. + _java_class_name = "ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel" + + def __init__( + self, java_model: Optional[py4j.java_gateway.JavaObject] = None + ) -> None: + super().__init__(java_model=java_model) # type:ignore + if not java_model: + self._java_obj = self._new_java_obj( + self.__class__._java_class_name, self.uid + ) + # transfer jvm default values to python + self._transfer_params_from_java() + + +class XGBoostRegressor(_XGBoostRegressorBase): + """XGBoostRegressor is a PySpark ML estimator. It implements the XGBoost regression + algorithm based on `ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor` in XGBoost jvm + packages, and it can be used in PySpark Pipeline and PySpark ML meta algorithms like + CrossValidator. + + .. versionadded:: 1.6.0 + + Examples + -------- + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.sql import SparkSession + >>> from xgboost.spark import XGBoostRegressor, XGBoostRegressionModel + >>> input = spark.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.dense(2.0))], ["label", "features"]) + >>> params = { + ... 'objective': 'reg:squarederror', + ... 'treeMethod': 'hist', + ... 'numWorkers': 1, + ... 'numRound': 100, + ... 'labelCol': 'label', + ... 
'featuresCol': 'features' + ... } + >>> regressor = XGBoostRegressor(**params) + >>> regressor.write().overwrite().save("/tmp/xgboost_regressor") + >>> regressor1 = XGBoostRegressor.load("/tmp/xgboost_regressor") + >>> model = regressor1.fit(input) + >>> model.write().overwrite().save("/tmp/xgboost_regressor_model") + >>> model1 = XGBoostRegressionModel.load("/tmp/xgboost_regressor_model") + >>> df = model1.transform(input) + >>> df.show() + +-----+--------+--------------------+ + |label|features| prediction| + +-----+--------+--------------------+ + | 1.0| [1.0]| 0.9991162419319153| + | 0.0| [2.0]|8.837578352540731E-4| + +-----+--------+--------------------+ + + Besides passing dictionary parameters to XGBoostClassifier, users can call set APIs + to set the parameters, + + xgb_classifier = XGBoostRegressor() \ + .setFeaturesCol("features") \ + .setLabelCol("label") \ + .setNumRound(100) \ + .setObjective('reg:squarederror') \ + .setTreeMethod('hist') + + """ + + # _java_class_name will be used when loading pipeline. + _java_class_name = "ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor" + + # pylint: disable=unused-argument + @keyword_only + def __init__( + self, + *, + featuresCol: Optional[str] = None, + labelCol: Optional[str] = None, + treeMethod: Optional[str] = None, + objective: Optional[str] = None, + numRound: Optional[int] = None, + numWorkers: Optional[int] = None + ): + super().__init__() + self._java_obj = self._new_java_obj(self.__class__._java_class_name, self.uid) + kwargs = self._input_kwargs # pylint: disable=no-member + self._set(**kwargs) + + def _create_model( + self, java_model: py4j.java_gateway.JavaObject + ) -> "XGBoostRegressionModel": + return XGBoostRegressionModel(java_model) + + +class XGBoostRegressionModel(_XGBoostRegressionModelBase): + """ + The model returned by :func:`xgboost.spark.XGBoostRegressor.fit()` + + """ + + # _java_class_name will be used when loading pipeline. + _java_class_name = "ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel" + + def __init__( + self, java_model: Optional[py4j.java_gateway.JavaObject] = None + ) -> None: + super().__init__(java_model=java_model) + if not java_model: + self._java_obj = self._new_java_obj( + self.__class__._java_class_name, self.uid + ) + # transfer jvm default values to python + self._transfer_params_from_java() diff --git a/python-package/xgboost/spark/internal.py b/python-package/xgboost/spark/internal.py new file mode 100644 index 000000000000..8999908fa626 --- /dev/null +++ b/python-package/xgboost/spark/internal.py @@ -0,0 +1,179 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pylint: disable=invalid-name, too-many-ancestors, protected-access +""" +Interal mixins for pyspark module. 
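
These mixins are not part of the public API. They combine PySpark's Java wrapper classes with
the XGBoost parameter mixins; the public estimators in ``xgboost.spark.estimator`` derive from
them roughly as follows (a sketch of the class hierarchy, not executable code)::

    XGBoostClassifier          -> _XGBoostClassifierBase
    XGBoostClassificationModel -> _XGBoostClassificationModelBase
    XGBoostRegressor           -> _XGBoostRegressorBase
    XGBoostRegressionModel     -> _XGBoostRegressionModelBase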
+""" +from abc import ABCMeta + +from pyspark.ml.classification import ( + Classifier, + ProbabilisticClassifier, + ProbabilisticClassificationModel, + ClassificationModel, +) +from pyspark.ml.common import inherit_doc +from pyspark.ml.regression import Regressor, RegressionModel +from pyspark.ml.util import JavaMLReadable, JavaMLReader, JavaMLWritable +from pyspark.ml.wrapper import JavaPredictor, JavaPredictionModel + +from .shared import _XGBoostRegressorParams, _XGBoostClassifierParams + + +class _XGBJavaMLReadable(JavaMLReadable): + """ + Mixin a version of xgboost JavaMLReadable + """ + + @classmethod + def read(cls): + return _XGBJavaMLReader(cls) + + +class _XGBJavaMLReader(JavaMLReader): + """ + Mixin a version of xgboost JavaMLReader + """ + + @classmethod + def _java_loader_class(cls, clazz): + if hasattr(clazz, "_java_class_name") and clazz._java_class_name is not None: + return clazz._java_class_name + return JavaMLReader._java_loader_class(clazz) + + +@inherit_doc +class _XGBJavaClassifier(Classifier, JavaPredictor, metaclass=ABCMeta): + """ + Java Classifier for classification tasks. + Classes are indexed {0, 1, ..., numClasses - 1}. + """ + + # Copied from _JavaClassifier + def setRawPredictionCol(self, value): + """ + Sets the value of :py:attr:`rawPredictionCol`. + """ + return self._set(rawPredictionCol=value) + + +@inherit_doc +class _XGBJavaProbabilisticClassifier( + ProbabilisticClassifier, _XGBJavaClassifier, metaclass=ABCMeta +): + """ + Java Probabilistic Classifier for classification tasks. + """ + + +class _XGBoostClassifierBase( + _XGBJavaProbabilisticClassifier, + _XGBoostClassifierParams, + JavaMLWritable, + _XGBJavaMLReadable, + metaclass=ABCMeta, +): + """ + The base class of XGBoostClassifier + """ + + +@inherit_doc +class _XGBJavaClassificationModel(ClassificationModel, JavaPredictionModel): + """ + Java Model produced by a ``Classifier``. + Classes are indexed {0, 1, ..., numClasses - 1}. + To be mixed in with :class:`pyspark.ml.JavaModel` + """ + + @property + def numClasses(self): + """ + Number of classes (values which the label can take). + """ + return self._call_java("numClasses") + + def predictRaw(self, value): + """ + Raw prediction for each possible label. + """ + return self._call_java("predictRaw", value) + + +@inherit_doc +class _XGBJavaProbabilisticClassificationModel( + ProbabilisticClassificationModel, _XGBJavaClassificationModel +): + """ + Java Model produced by a ``ProbabilisticClassifier``. + """ + + def predictProbability(self, value): + """ + Predict the probability of each class given the features. + """ + return self._call_java("predictProbability", value) + + +class _XGBoostClassificationModelBase( + _XGBJavaProbabilisticClassificationModel, + _XGBoostClassifierParams, + JavaMLWritable, + _XGBJavaMLReadable, +): + """ + The base class of XGBoostClassificationModel + """ + + +@inherit_doc +class _XGBJavaRegressor(Regressor, JavaPredictor, metaclass=ABCMeta): + """ + Java Regressor for regression tasks. + + .. versionadded:: 3.0.0 + """ + + +class _XGBoostRegressorBase( + _XGBJavaRegressor, + _XGBoostRegressorParams, + JavaMLWritable, + _XGBJavaMLReadable, + metaclass=ABCMeta, +): + """ + The base class of XGBoostRegressor + """ + + +@inherit_doc +class _XGBJavaRegressionModel(RegressionModel, JavaPredictionModel, metaclass=ABCMeta): + """ + Java Model produced by a ``_JavaRegressor``. + To be mixed in with :class:`pyspark.ml.JavaModel` + + .. 
versionadded:: 3.0.0 + """ + + +class _XGBoostRegressionModelBase( + _XGBJavaRegressionModel, JavaMLWritable, _XGBJavaMLReadable +): + """ + The base class of XGBoostRegressionModel + """ diff --git a/python-package/xgboost/spark/shared.py b/python-package/xgboost/spark/shared.py new file mode 100644 index 000000000000..55a892fa8a3f --- /dev/null +++ b/python-package/xgboost/spark/shared.py @@ -0,0 +1,1145 @@ +# +# Copyright (c) 2022 by Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=invalid-name, too-many-ancestors +""" +Shared parameters between pyspark components. +""" + +from pyspark.ml.param import Params, Param, TypeConverters +from pyspark.ml.param.shared import HasWeightCol + + +class HasNumClass(Params): + """ + Mixin for param numClass: the number of classes for classifier. + """ + + numClass = Param( + Params._dummy(), + "numClass", + "number of classes.", + typeConverter=TypeConverters.toInt, + ) + + def getNumClass(self): + """ + Gets the value of numClass or its default value. + """ + return self.getOrDefault(self.numClass) + + +class _BoosterParams(Params): + """ + Booster parameters + """ + + eta = Param( + Params._dummy(), + "eta", + "The step size shrinkage used in update to prevents overfitting. " + "After each boosting step, we can directly get the weights of new features. " + "and eta actually shrinks the feature weights to make the boosting process " + "more conservative.", + typeConverter=TypeConverters.toFloat, + ) + + gamma = Param( + Params._dummy(), + "gamma", + "minimum loss reduction required to make a further partition on a leaf node " + "of the tree. the larger, the more conservative the algorithm will be.", + typeConverter=TypeConverters.toFloat, + ) + + maxDepth = Param( + Params._dummy(), + "maxDepth", + "maximum depth of a tree, increase this value will " + "make model more complex/likely to be overfitting.", + typeConverter=TypeConverters.toInt, + ) + + maxLeaves = Param( + Params._dummy(), + "maxLeaves", + "Maximum number of nodes to be added. Only relevant when " + "grow_policy=lossguide is set.", + typeConverter=TypeConverters.toInt, + ) + + minChildWeight = Param( + Params._dummy(), + "minChildWeight", + "minimum sum of instance weight(hessian) needed in a child. If the " + "tree partition step results in a leaf node with the sum of instance " + "weight less than min_child_weight, then the building process will " + "give up further partitioning. In linear regression mode, this " + "simply corresponds to minimum number of instances needed to be " + "in each node. The larger, the more conservative the algorithm " + "will be.", + typeConverter=TypeConverters.toFloat, + ) + + maxDeltaStep = Param( + Params._dummy(), + "maxDeltaStep", + "Maximum delta step we allow each tree's weight estimation to be. " + "If the value is set to 0, it means there is no constraint. If " + "it is set to a positive value, it can help making the update " + "step more conservative. 
Usually this parameter is not needed, " + "but it might help in logistic regression when class is extremely " + "imbalanced. Set it to value of 1-10 might help control the update", + typeConverter=TypeConverters.toFloat, + ) + + subsample = Param( + Params._dummy(), + "subsample", + "subsample ratio of the training instance. Setting it to 0.5 means " + "that XGBoost randomly collected half of the data instances to grow " + "trees and this will prevent overfitting.", + typeConverter=TypeConverters.toFloat, + ) + + colsampleBytree = Param( + Params._dummy(), + "colsampleBytree", + "subsample ratio of columns when constructing each tree.", + typeConverter=TypeConverters.toFloat, + ) + + colsampleBylevel = Param( + Params._dummy(), + "colsampleBylevel", + "subsample ratio of columns for each split, in each level.", + typeConverter=TypeConverters.toFloat, + ) + + alpha = Param( + Params._dummy(), + "alpha", + "L1 regularization term on weights, increase this value will make model " + "more conservative.", + typeConverter=TypeConverters.toFloat, + ) + + treeMethod = Param( + Params._dummy(), + "treeMethod", + "The tree construction algorithm used in XGBoost. " + "Options: {'auto', 'exact', 'approx','gpu_hist'} [default='auto']", + typeConverter=TypeConverters.toString, + ) + + growPolicy = Param( + Params._dummy(), + "growPolicy", + "Controls a way new nodes are added to the tree. Currently supported " + "only if tree_method is set to hist. Choices: depthwise, lossguide. " + "depthwise: split at nodes closest to the root. lossguide: split " + "at nodes with highest loss change.", + typeConverter=TypeConverters.toString, + ) + + maxBins = Param( + Params._dummy(), + "maxBins", + "maximum number of bins in histogram.", + typeConverter=TypeConverters.toInt, + ) + + singlePrecisionHistogram = Param( + Params._dummy(), + "singlePrecisionHistogram", + "whether to use single precision to build histograms.", + typeConverter=TypeConverters.toBoolean, + ) + + sketchEps = Param( + Params._dummy(), + "sketchEps", + "This is only used for approximate greedy algorithm." + "This roughly translated into O(1 / sketch_eps) number of bins. " + "Compared to directly select number of bins, this comes with " + "theoretical guarantee with sketch accuracy, [default=0.03] range: (0, 1)", + typeConverter=TypeConverters.toFloat, + ) + + scalePosWeight = Param( + Params._dummy(), + "scalePosWeight", + "Control the balance of positive and negative weights, useful for unbalanced classes." + "A typical value to consider: sum(negative cases) / sum(positive cases)", + typeConverter=TypeConverters.toFloat, + ) + + sampleType = Param( + Params._dummy(), + "sampleType", + "type of sampling algorithm, options: {'uniform', 'weighted'}", + typeConverter=TypeConverters.toString, + ) + + normalizeType = Param( + Params._dummy(), + "normalizeType", + "type of normalization algorithm, options: {'tree', 'forest'}", + typeConverter=TypeConverters.toString, + ) + + rateDrop = Param( + Params._dummy(), + "rateDrop", + "dropout rate", + typeConverter=TypeConverters.toFloat, + ) + + skipDrop = Param( + Params._dummy(), + "skipDrop", + "probability of skip dropout. 
If a dropout is skipped, new trees " + "are added in the same manner as gbtree.", + typeConverter=TypeConverters.toFloat, + ) + + lambdaBias = Param( + Params._dummy(), + "lambdaBias", + "L2 regularization term on bias, default 0 (no L1 reg on bias " + "because it is not important) ", + typeConverter=TypeConverters.toFloat, + ) + + treeLimit = Param( + Params._dummy(), + "treeLimit", + "number of trees used in the prediction; defaults to 0 (use all trees).", + typeConverter=TypeConverters.toInt, + ) + + monotoneConstraints = Param( + Params._dummy(), + "monotoneConstraints", + "a list in length of number of features, 1 indicate monotonic increasing, " + "-1 means decreasing, 0 means no constraint. If it is shorter than number " + "of features, 0 will be padded.", + typeConverter=TypeConverters.toString, + ) + + interactionConstraints = Param( + Params._dummy(), + "interactionConstraints", + "Constraints for interaction representing permitted interactions. The constraints " + "must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]], " + "where each inner list is a group of indices of features that are allowed to " + "interact with each other. See tutorial for more information", + typeConverter=TypeConverters.toString, + ) + + def getEta(self): + """ + Gets the value of eta or its default value. + """ + return self.getOrDefault(self.eta) + + def getGamma(self): + """ + Gets the value of gamma or its default value. + """ + return self.getOrDefault(self.gamma) + + def getMaxDepth(self): + """ + Gets the value of maxDepth or its default value. + """ + return self.getOrDefault(self.maxDepth) + + def getMaxLeaves(self): + """ + Gets the value of maxLeaves or its default value. + """ + return self.getOrDefault(self.maxLeaves) + + def getMinChildWeight(self): + """ + Gets the value of minChildWeight or its default value. + """ + return self.getOrDefault(self.minChildWeight) + + def getMaxDeltaStep(self): + """ + Gets the value of minChildWeight or its default value. + """ + return self.getOrDefault(self.maxDeltaStep) + + def getAlpha(self): + """ + Gets the value of alpha or its default value. + """ + return self.getOrDefault(self.alpha) + + def getSubsample(self): + """ + Gets the value of subsample or its default value. + """ + return self.getOrDefault(self.subsample) + + def getColsampleBytree(self): + """ + Gets the value of colsampleBytree or its default value. + """ + return self.getOrDefault(self.colsampleBytree) + + def getColsampleBylevel(self): + """ + Gets the value of colsampleBylevel or its default value. + """ + return self.getOrDefault(self.colsampleBylevel) + + def getAlpha(self): + """ + Gets the value of alpha or its default value. + """ + return self.getOrDefault(self.alpha) + + def getTreeMethod(self): + """ + Gets the value of treeMethod or its default value. + """ + return self.getOrDefault(self.treeMethod) + + def getGrowPolicy(self): + """ + Gets the value of growPolicy or its default value. + """ + return self.getOrDefault(self.growPolicy) + + def getMaxBins(self): + """ + Gets the value of maxBins or its default value. + """ + return self.getOrDefault(self.maxBins) + + def getSinglePrecisionHistogram(self): + """ + Gets the value of singlePrecisionHistogram or its default value. + """ + return self.getOrDefault(self.singlePrecisionHistogram) + + def getSketchEps(self): + """ + Gets the value of sketchEps or its default value. 
+ """ + return self.getOrDefault(self.sketchEps) + + def getScalePosWeight(self): + """ + Gets the value of scalePosWeight or its default value. + """ + return self.getOrDefault(self.scalePosWeight) + + def getSampleType(self): + """ + Gets the value of sampleType or its default value. + """ + return self.getOrDefault(self.sampleType) + + def getNormalizeType(self): + """ + Gets the value of normalizeType or its default value. + """ + return self.getOrDefault(self.normalizeType) + + def getRateDrop(self): + """ + Gets the value of rateDrop or its default value. + """ + return self.getOrDefault(self.rateDrop) + + def getSkipDrop(self): + """ + Gets the value of skipDrop or its default value. + """ + return self.getOrDefault(self.skipDrop) + + def getLambdaBias(self): + """ + Gets the value of lambdaBias or its default value. + """ + return self.getOrDefault(self.lambdaBias) + + def getTreeLimit(self): + """ + Gets the value of treeLimit or its default value. + """ + return self.getOrDefault(self.treeLimit) + + def getMonotoneConstraints(self): + """ + Gets the value of monotoneConstraints or its default value. + """ + return self.getOrDefault(self.monotoneConstraints) + + def getInteractionConstraints(self): + """ + Gets the value of interactionConstraints or its default value. + """ + return self.getOrDefault(self.interactionConstraints) + + +class _LearningTaskParams(Params): + """Parameters for learning """ + + objective = Param( + Params._dummy(), + "objective", + "Specify the learning task and the corresponding learning objective. " + "options: reg:squarederror, reg:squaredlogerror, reg:logistic, binary:logistic, binary:logitraw" + "count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma", + typeConverter=TypeConverters.toString, + ) + + objectiveType = Param( + Params._dummy(), + "objectiveType", + "The learning objective type of the specified custom objective and eval. " + "Corresponding type will be assigned if custom objective is defined " + "options: regression, classification.", + typeConverter=TypeConverters.toString, + ) + + baseScore = Param( + Params._dummy(), + "baseScore", + "the initial prediction score of all instances, global bias. default=0.5. 
", + typeConverter=TypeConverters.toFloat, + ) + + evalMetric = Param( + Params._dummy(), + "evalMetric", + "evaluation metrics for validation data, a default metric will be assigned " + "according to objective (rmse for regression, and error for classification, " + "mean average precision for ranking).", + typeConverter=TypeConverters.toString, + ) + + trainTestRatio = Param( + Params._dummy(), + "trainTestRatio", + "fraction of training points to use for testing.", + typeConverter=TypeConverters.toFloat, + ) + + cacheTrainingSet = Param( + Params._dummy(), + "cacheTrainingSet", + "whether caching training data.", + typeConverter=TypeConverters.toBoolean, + ) + + skipCleanCheckpoint = Param( + Params._dummy(), + "skipCleanCheckpoint", + "whether cleaning checkpoint data", + typeConverter=TypeConverters.toBoolean, + ) + + numEarlyStoppingRounds = Param( + Params._dummy(), + "numEarlyStoppingRounds", + "number of rounds of decreasing eval metric to tolerate before stopping the training", + typeConverter=TypeConverters.toInt, + ) + + maximizeEvaluationMetrics = Param( + Params._dummy(), + "maximizeEvaluationMetrics", + "define the expected optimization to the evaluation metrics, true to maximize otherwise " + "minimize it", + typeConverter=TypeConverters.toBoolean, + ) + + killSparkContextOnWorkerFailure = Param( + Params._dummy(), + "killSparkContextOnWorkerFailure", + "whether to kill SparkContext when training task fails.", + typeConverter=TypeConverters.toBoolean, + ) + + def getObjective(self): + """ + Gets the value of objective or its default value. + """ + return self.getOrDefault(self.objective) + + def getObjectiveType(self): + """ + Gets the value of objectiveType or its default value. + """ + return self.getOrDefault(self.objectiveType) + + def getBaseScore(self): + """ + Gets the value of baseScore or its default value. + """ + return self.getOrDefault(self.baseScore) + + def getEvalMetric(self): + """ + Gets the value of evalMetric or its default value. + """ + return self.getOrDefault(self.evalMetric) + + def getTrainTestRatio(self): + """ + Gets the value of trainTestRatio or its default value. + """ + return self.getOrDefault(self.trainTestRatio) + + def getCacheTrainingSet(self): + """ + Gets the value of cacheTrainingSet or its default value. + """ + return self.getOrDefault(self.cacheTrainingSet) + + def getSkipCleanCheckpoint(self): + """ + Gets the value of skipCleanCheckpoint or its default value. + """ + return self.getOrDefault(self.skipCleanCheckpoint) + + def getNumEarlyStoppingRounds(self): + """ + Gets the value of numEarlyStoppingRounds or its default value. + """ + return self.getOrDefault(self.numEarlyStoppingRounds) + + def getMaximizeEvaluationMetrics(self): + """ + Gets the value of maximizeEvaluationMetrics or its default value. + """ + return self.getOrDefault(self.maximizeEvaluationMetrics) + + def getKillSparkContextOnWorkerFailure(self): + """ + Gets the value of killSparkContextOnWorkerFailure or its default value. + """ + return self.getOrDefault(self.killSparkContextOnWorkerFailure) + + +class _GeneralParams(Params): + """ + The general parameters. 
+ """ + + numRound = Param( + Params._dummy(), + "numRound", + "The number of rounds for boosting.", + typeConverter=TypeConverters.toInt, + ) + + numWorkers = Param( + Params._dummy(), + "numWorkers", + "The number of workers used to run xgboost.", + typeConverter=TypeConverters.toInt, + ) + + nthread = Param( + Params._dummy(), + "nthread", + "The number of threads used by per worker.", + typeConverter=TypeConverters.toInt, + ) + + useExternalMemory = Param( + Params._dummy(), + "useExternalMemory", + "Whether to use external memory as cache.", + typeConverter=TypeConverters.toBoolean, + ) + + verbosity = Param( + Params._dummy(), + "verbosity", + "Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), " + "2 (info), 3 (debug).", + typeConverter=TypeConverters.toInt, + ) + + missing = Param( + Params._dummy(), + "missing", + "The value treated as missing.", + typeConverter=TypeConverters.toFloat, + ) + + allowNonZeroForMissing = Param( + Params._dummy(), + "allowNonZeroForMissing", + "Allow to have a non-zero value for missing when training or " + "predicting on a Sparse or Empty vector. Should only be used " + "if did not use Spark's VectorAssembler class to construct " + "the feature vector but instead used a method that preserves " + "zeros in your vector.", + typeConverter=TypeConverters.toBoolean, + ) + + timeoutRequestWorkers = Param( + Params._dummy(), + "timeoutRequestWorkers", + "the maximum time to request new Workers if numCores are insufficient. " + "The timeout will be disabled if this value is set smaller than or equal to 0.", + typeConverter=TypeConverters.toInt, + ) + + checkpointPath = Param( + Params._dummy(), + "checkpointPath", + "the hdfs folder to load and save checkpoints. If there are existing checkpoints " + "in checkpoint_path. The job will load the checkpoint with highest version as the " + "starting point for training. If checkpoint_interval is also set, the job will " + "save a checkpoint every a few rounds.", + typeConverter=TypeConverters.toString, + ) + + checkpointInterval = Param( + Params._dummy(), + "checkpointInterval", + "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that " + "the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` " + "must also be set if the checkpoint interval is greater than 0.", + typeConverter=TypeConverters.toInt, + ) + + seed = Param( + Params._dummy(), "seed", "Random seed", typeConverter=TypeConverters.toInt + ) + + def getNumRound(self): + """ + Gets the value of numRound or its default value. + """ + return self.getOrDefault(self.numRound) + + def getNumWorkers(self): + """ + Gets the value of numWorkers or its default value. + """ + return self.getOrDefault(self.numWorkers) + + def getNthread(self): + """ + Gets the value of nthread or its default value. + """ + return self.getOrDefault(self.nthread) + + def getUseExternalMemory(self): + """ + Gets the value of nthread or its default value. + """ + return self.getOrDefault(self.useExternalMemory) + + def getVerbosity(self): + """ + Gets the value of verbosity or its default value. + """ + return self.getOrDefault(self.verbosity) + + def getMissing(self): + """ + Gets the value of missing or its default value. + """ + return self.getOrDefault(self.missing) + + def getAllowNonZeroForMissingMissing(self): + """ + Gets the value of allowNonZeroForMissing or its default value. 
+ """ + return self.getOrDefault(self.allowNonZeroForMissing) + + def getTimeoutRequestWorkers(self): + """ + Gets the value of timeoutRequestWorkers or its default value. + """ + return self.getOrDefault(self.timeoutRequestWorkers) + + def getCheckpointPath(self): + """ + Gets the value of checkpointPath or its default value. + """ + return self.getOrDefault(self.checkpointPath) + + def getCheckpointInterval(self): + """ + Gets the value of checkpointInterval or its default value. + """ + return self.getOrDefault(self.checkpointInterval) + + def getSeed(self): + """ + Gets the value of seed or its default value. + """ + return self.getOrDefault(self.seed) + + +class HasBaseMarginCol(Params): + """ + Mixin for param baseMarginCol: baseMargin (aka base margin) column name. + """ + + baseMarginCol = Param( + Params._dummy(), + "baseMarginCol", + "base margin column name.", + typeConverter=TypeConverters.toString, + ) + + def getBaseMarginCol(self): + """ + Gets the value of baseMarginCol or its default value. + """ + return self.getOrDefault(self.baseMarginCol) + + +class HasLeafPredictionCol(Params): + """ + Mixin for param leafPredictionCol: leaf prediction column name. + """ + + leafPredictionCol = Param( + Params._dummy(), + "leafPredictionCol", + "leaf prediction column name.", + typeConverter=TypeConverters.toString, + ) + + def getLeafPredictionCol(self): + """ + Gets the value of leafPredictionCol or its default value. + """ + return self.getOrDefault(self.leafPredictionCol) + + +class HasContribPredictionCol(Params): + """ + Mixin for param contribPredictionCol: contribution column name. + """ + + contribPredictionCol = Param( + Params._dummy(), + "contribPredictionCol", + "The contribution column name.", + typeConverter=TypeConverters.toString, + ) + + def getContribPredictionCol(self): + """ + Gets the value of contribPredictionCol or its default value. + """ + return self.getOrDefault(self.contribPredictionCol) + + +class _RabitParams(Params): + """Rabit parameters passed through Rabit.Init into native layer""" + + rabitRingReduceThreshold = Param( + Params._dummy(), + "rabitRingReduceThreshold", + "threshold count to enable allreduce/broadcast with ring based topology.", + typeConverter=TypeConverters.toInt, + ) + + rabitTimeout = Param( + Params._dummy(), + "rabitTimeout", + "timeout threshold after rabit observed failures.", + typeConverter=TypeConverters.toInt, + ) + + rabitConnectRetry = Param( + Params._dummy(), + "rabitConnectRetry", + "number of retry worker do before fail.", + typeConverter=TypeConverters.toInt, + ) + + def getRabitRingReduceThreshold(self): + """ + Gets the value of rabitRingReduceThreshold or its default value. + """ + return self.getOrDefault(self.rabitRingReduceThreshold) + + def getRabitTimeout(self): + """ + Gets the value of rabitTimeout or its default value. + """ + return self.getOrDefault(self.rabitTimeout) + + def getRabitConnectRetry(self): + """ + Gets the value of rabitConnectRetry or its default value. + """ + return self.getOrDefault(self.rabitConnectRetry) + + +class _XGBoostCommonParams( + _GeneralParams, + _LearningTaskParams, + _BoosterParams, + _RabitParams, + HasWeightCol, + HasBaseMarginCol, + HasLeafPredictionCol, + HasContribPredictionCol, +): + """ + XGBoost common parameters for both XGBoostClassifier and XGBoostRegressor + """ + + def setNumRound(self, value): + """ + Sets the value of :py:attr:`numRound`. 
+ """ + self._set(numRound=value) + return self + + def setNumWorkers(self, value): + """ + Sets the value of :py:attr:`numWorkers`. + """ + self._set(numWorkers=value) + return self + + def setNthread(self, value): + """ + Sets the value of :py:attr:`nthread`. + """ + self._set(nthread=value) + return self + + def setUseExternalMemory(self, value): + """ + Sets the value of :py:attr:`useExternalMemory`. + """ + self._set(useExternalMemory=value) + return self + + def setVerbosity(self, value): + """ + Sets the value of :py:attr:`verbosity`. + """ + self._set(verbosity=value) + return self + + def setMissing(self, value): + """ + Sets the value of :py:attr:`missing`. + """ + self._set(missing=value) + return self + + def setAllowNonZeroForMissingMissing(self, value): + """ + Sets the value of :py:attr:`allowNonZeroForMissing`. + """ + self._set(allowNonZeroForMissing=value) + return self + + def setTimeoutRequestWorkers(self, value): + """ + Sets the value of :py:attr:`timeoutRequestWorkers`. + """ + self._set(timeoutRequestWorkers=value) + return self + + def setCheckpointPath(self, value): + """ + Sets the value of :py:attr:`checkpointPath`. + """ + self._set(checkpointPath=value) + return self + + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + self._set(checkpointInterval=value) + return self + + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + self._set(seed=value) + return self + + def setObjective(self, value): + """ + Sets the value of :py:attr:`objective`. + """ + return self._set(objective=value) + + def setObjectiveType(self, value): + """ + Sets the value of :py:attr:`objectiveType`. + """ + return self._set(objectiveType=value) + + def setBaseScore(self, value): + """ + Sets the value of :py:attr:`objectiveType`. + """ + return self._set(baseScore=value) + + def setEvalMetric(self, value): + """ + Sets the value of :py:attr:`evalMetric`. + """ + return self._set(evalMetric=value) + + def setTrainTestRatio(self, value): + """ + Sets the value of :py:attr:`trainTestRatio`. + """ + return self._set(trainTestRatio=value) + + def setCacheTrainingSet(self, value): + """ + Sets the value of :py:attr:`cacheTrainingSet`. + """ + return self._set(cacheTrainingSet=value) + + def setSkipCleanCheckpoint(self, value): + """ + Sets the value of :py:attr:`skipCleanCheckpoint`. + """ + return self._set(skipCleanCheckpoint=value) + + def setNumEarlyStoppingRounds(self, value): + """ + Sets the value of :py:attr:`numEarlyStoppingRounds`. + """ + return self._set(numEarlyStoppingRounds=value) + + def setMaximizeEvaluationMetrics(self, value): + """ + Sets the value of :py:attr:`maximizeEvaluationMetrics`. + """ + return self._set(maximizeEvaluationMetrics=value) + + def setKillSparkContextOnWorkerFailure(self, value): + """ + Sets the value of :py:attr:`killSparkContextOnWorkerFailure`. + """ + return self._set(killSparkContextOnWorkerFailure=value) + + def setEta(self, value): + """ + Sets the value of :py:attr:`eta`. + """ + return self._set(eta=value) + + def setGamma(self, value): + """ + Sets the value of :py:attr:`gamma`. + """ + return self._set(gamma=value) + + def setMaxDepth(self, value): + """ + Sets the value of :py:attr:`maxDepth`. + """ + return self._set(maxDepth=value) + + def setMaxLeaves(self, value): + """ + Sets the value of :py:attr:`maxLeaves`. + """ + return self._set(maxLeaves=value) + + def setMinChildWeight(self, value): + """ + Sets the value of :py:attr:`minChildWeight`. 
+ """ + return self._set(minChildWeight=value) + + def setMaxDeltaStep(self, value): + """ + Sets the value of :py:attr:`maxDeltaStep`. + """ + return self._set(maxDeltaStep=value) + + def setAlpha(self, value): + """ + Sets the value of :py:attr:`alpha`. + """ + return self._set(alpha=value) + + def setSubsample(self, value): + """ + Sets the value of :py:attr:`subsample`. + """ + return self._set(subsample=value) + + def setColsampleBytree(self, value): + """ + Sets the value of :py:attr:`colsampleBytree`. + """ + return self._set(colsampleBytree=value) + + def setColsampleBylevel(self, value): + """ + Sets the value of :py:attr:`colsampleBylevel`. + """ + return self._set(colsampleBylevel=value) + + def setAlpha(self, value): + """ + Sets the value of :py:attr:`alpha`. + """ + return self._set(alpha=value) + + def setTreeMethod(self, value): + """ + Sets the value of :py:attr:`treeMethod`. + """ + return self._set(treeMethod=value) + + def setGrowPolicy(self, value): + """ + Sets the value of :py:attr:`growPolicy`. + """ + return self._set(growPolicy=value) + + def setMaxBins(self, value): + """ + Sets the value of :py:attr:`maxBins`. + """ + return self._set(maxBins=value) + + def setSinglePrecisionHistogram(self, value): + """ + Sets the value of :py:attr:`singlePrecisionHistogram`. + """ + return self._set(singlePrecisionHistogram=value) + + def setSketchEps(self, value): + """ + Sets the value of :py:attr:`sketchEps`. + """ + return self._set(sketchEps=value) + + def setScalePosWeight(self, value): + """ + Sets the value of :py:attr:`scalePosWeight`. + """ + return self._set(scalePosWeight=value) + + def setSampleType(self, value): + """ + Sets the value of :py:attr:`sampleType`. + """ + return self._set(sampleType=value) + + def setNormalizeType(self, value): + """ + Sets the value of :py:attr:`normalizeType`. + """ + return self._set(normalizeType=value) + + def setRateDrop(self, value): + """ + Sets the value of :py:attr:`rateDrop`. + """ + return self._set(rateDrop=value) + + def setSkipDrop(self, value): + """ + Sets the value of :py:attr:`skipDrop`. + """ + return self._set(skipDrop=value) + + def setLambdaBias(self, value): + """ + Sets the value of :py:attr:`lambdaBias`. + """ + return self._set(lambdaBias=value) + + def setTreeLimit(self, value): + """ + Sets the value of :py:attr:`treeLimit`. + """ + return self._set(treeLimit=value) + + def setMonotoneConstraints(self, value): + """ + Sets the value of :py:attr:`monotoneConstraints`. + """ + return self._set(monotoneConstraints=value) + + def setInteractionConstraints(self, value): + """ + Sets the value of :py:attr:`interactionConstraints`. + """ + return self._set(interactionConstraints=value) + + +class HasGroupCol(Params): + """ + Mixin for param groupCol: group column name for regressor. + """ + + groupCol = Param( + Params._dummy(), + "groupCol", + "The group column name.", + typeConverter=TypeConverters.toString, + ) + + def getGroupCol(self): + """ + Gets the value of groupCol or its default value. + """ + return self.getOrDefault(self.groupCol) + + +class _XGBoostClassifierParams(_XGBoostCommonParams, HasNumClass): + """ + XGBoostClassifier parameters + """ + + def setNumClass(self, value): + """ + Sets the value of :py:attr:`numClass`. + """ + self._set(numClass=value) + return self + + +class _XGBoostRegressorParams(_XGBoostCommonParams, HasGroupCol): + """ + XGBoostRegressor parameters + """ + + def setGroupCol(self, value): + """ + Sets the value of :py:attr:`numClass`. 
+ """ + self._set(groupCol=value) + return self diff --git a/tests/ci_build/build_jvm_packages.sh b/tests/ci_build/build_jvm_packages.sh index 241fc445f640..e6714fc38b9f 100755 --- a/tests/ci_build/build_jvm_packages.sh +++ b/tests/ci_build/build_jvm_packages.sh @@ -24,5 +24,13 @@ if [ "x$gpu_arch" != "x" ]; then fi mvn --no-transfer-progress package -Dspark.version=${spark_version} $gpu_options + +pushd ../python-package +python setup.py install +popd + +cd integration-tests +./run_pyspark_from_build.sh + set +x set +e