[pyspark] Implement SparkXGBRanker estimator (#8172)

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
WeichenXu 2022-08-23 02:35:19 +08:00 committed by GitHub
parent 35ef8abc27
commit f4628c22a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 235 additions and 27 deletions

View File

@ -10,6 +10,7 @@ except ImportError as e:
from .estimator import (
SparkXGBClassifier,
SparkXGBClassifierModel,
SparkXGBRanker,
SparkXGBRegressor,
SparkXGBRegressorModel,
)
@ -19,4 +20,5 @@ __all__ = [
"SparkXGBClassifierModel",
"SparkXGBRegressor",
"SparkXGBRegressorModel",
"SparkXGBRanker",
]

View File

@ -35,7 +35,7 @@ from xgboost.core import Booster
from xgboost.training import train as worker_train
import xgboost
from xgboost import XGBClassifier, XGBRegressor
from xgboost import XGBClassifier, XGBRanker, XGBRegressor
from .data import (
_read_csr_matrix_from_unwrapped_spark_vec,
@ -54,6 +54,7 @@ from .params import (
HasBaseMarginCol,
HasEnableSparseDataOptim,
HasFeaturesCols,
HasQueryIdCol,
)
from .utils import (
RabitContext,
@ -86,6 +87,7 @@ _pyspark_specific_params = [
"feature_names",
"features_cols",
"enable_sparse_data_optim",
"qid_col",
]
_non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@ -116,6 +118,10 @@ _unsupported_fit_params = {
"eval_set",
"sample_weight_eval_set",
"base_margin", # Supported by spark param base_margin_col
"group", # Use spark param `qid_col` instead
"qid", # Use spark param `qid_col` instead
"eval_group", # Use spark param `qid_col` instead
"eval_qid", # Use spark param `qid_col` instead
}
_unsupported_predict_params = {
@ -136,6 +142,7 @@ class _SparkXGBParams(
HasBaseMarginCol,
HasFeaturesCols,
HasEnableSparseDataOptim,
HasQueryIdCol,
):
num_workers = Param(
Params._dummy(),
@ -572,13 +579,19 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
params["verbose_eval"] = verbose_eval
classification = self._xgb_cls() == XGBClassifier
num_classes = int(dataset.select(countDistinct(alias.label)).collect()[0][0])
if classification and num_classes == 2:
params["objective"] = "binary:logistic"
elif classification and num_classes > 2:
params["objective"] = "multi:softprob"
params["num_class"] = num_classes
if classification:
num_classes = int(
dataset.select(countDistinct(alias.label)).collect()[0][0]
)
if num_classes <= 2:
params["objective"] = "binary:logistic"
else:
params["objective"] = "multi:softprob"
params["num_class"] = num_classes
else:
params["objective"] = "reg:squarederror"
# use user specified objective or default objective.
# e.g., the default objective for Regressor is 'reg:squarederror'
params["objective"] = self.getOrDefault(self.objective)
# TODO: support "num_parallel_tree" for random forest
params["num_boost_round"] = self.getOrDefault(self.n_estimators)
@ -648,6 +661,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
col(self.getOrDefault(self.base_margin_col)).alias(alias.margin)
)
if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))
dataset = dataset.select(*select_cols)
num_workers = self.getOrDefault(self.num_workers)
@ -782,6 +798,10 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
super().__init__()
self._xgb_sklearn_model = xgb_sklearn_model
@classmethod
def _xgb_cls(cls):
raise NotImplementedError()
def get_booster(self):
"""
Return the `xgboost.core.Booster` instance.
@ -818,9 +838,6 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
"""
return SparkXGBModelReader(cls)
def _transform(self, dataset):
raise NotImplementedError()
def _get_feature_col(self, dataset) -> (list, Optional[list]):
"""XGBoost model trained with features_cols parameter can also predict
vector or array feature type. But first we need to check features_cols
@ -855,18 +872,6 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
)
return features_col, feature_col_names
class SparkXGBRegressorModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls):
return XGBRegressor
def _transform(self, dataset):
# Save xgb_sklearn_model and predict_params to be local variable
# to avoid the `self` object to be pickled to remote.
@ -920,6 +925,30 @@ class SparkXGBRegressorModel(_SparkXGBModel):
return dataset.withColumn(predictionColName, pred_col)
class SparkXGBRegressorModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls):
return XGBRegressor
class SparkXGBRankerModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls):
return XGBRanker
class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
"""
The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`

View File

@ -19,8 +19,8 @@ def stack_series(series: pd.Series) -> np.ndarray:
# Global constant for defining column alias shared between estimator and data
# processing procedures.
Alias = namedtuple("Alias", ("data", "label", "weight", "margin", "valid"))
alias = Alias("values", "label", "weight", "baseMargin", "validationIndicator")
Alias = namedtuple("Alias", ("data", "label", "weight", "margin", "valid", "qid"))
alias = Alias("values", "label", "weight", "baseMargin", "validationIndicator", "qid")
def concat_or_none(seq: Optional[Sequence[np.ndarray]]) -> Optional[np.ndarray]:
@ -41,6 +41,7 @@ def cache_partitions(
append(part, alias.label, is_valid)
append(part, alias.weight, is_valid)
append(part, alias.margin, is_valid)
append(part, alias.qid, is_valid)
has_validation: Optional[bool] = None
@ -94,6 +95,7 @@ class PartIter(DataIter):
label=self._fetch(self._data.get(alias.label, None)),
weight=self._fetch(self._data.get(alias.weight, None)),
base_margin=self._fetch(self._data.get(alias.margin, None)),
qid=self._fetch(self._data.get(alias.qid, None)),
)
self._iter += 1
return 1
@ -226,9 +228,9 @@ def create_dmatrix_from_partitions(
label = concat_or_none(values.get(alias.label, None))
weight = concat_or_none(values.get(alias.weight, None))
margin = concat_or_none(values.get(alias.margin, None))
qid = concat_or_none(values.get(alias.qid, None))
return DMatrix(
data=data, label=label, weight=weight, base_margin=margin, **kwargs
data=data, label=label, weight=weight, base_margin=margin, qid=qid, **kwargs
)
is_dmatrix = feature_cols is None

View File

@ -3,10 +3,11 @@
# pylint: disable=too-many-ancestors
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
from xgboost import XGBClassifier, XGBRegressor
from xgboost import XGBClassifier, XGBRanker, XGBRegressor
from .core import (
SparkXGBClassifierModel,
SparkXGBRankerModel,
SparkXGBRegressorModel,
_set_pyspark_xgb_cls_param_attrs,
_SparkXGBEstimator,
@ -106,6 +107,13 @@ class SparkXGBRegressor(_SparkXGBEstimator):
def _pyspark_model_cls(cls):
return SparkXGBRegressorModel
def _validate_params(self):
super()._validate_params()
if self.isDefined(self.qid_col):
raise ValueError(
"Spark Xgboost regressor estimator does not support `qid_col` param."
)
_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
@ -213,5 +221,126 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
def _pyspark_model_cls(cls):
return SparkXGBClassifierModel
def _validate_params(self):
super()._validate_params()
if self.isDefined(self.qid_col):
raise ValueError(
"Spark Xgboost classifier estimator does not support `qid_col` param."
)
_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
class SparkXGBRanker(_SparkXGBEstimator):
"""SparkXGBRanker is a PySpark ML estimator. It implements the XGBoost
classification algorithm based on XGBoost python library, and it can be used in
PySpark Pipeline and PySpark ML meta algorithms like
:py:class:`~pyspark.ml.tuning.CrossValidator`/
:py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
:py:class:`~pyspark.ml.classification.OneVsRest`
SparkXGBRanker automatically supports most of the parameters in
`xgboost.XGBClassifier` constructor and most of the parameters used in
:py:class:`xgboost.XGBClassifier` fit and predict method.
SparkXGBRanker doesn't support setting `gpu_id` but support another param `use_gpu`,
see doc below for more details.
SparkXGBRanker doesn't support setting `base_margin` explicitly as well, but support
another param called `base_margin_col`. see doc below for more details.
SparkXGBRanker doesn't support setting `output_margin`, but we can get output margin
from the raw prediction column. See `raw_prediction_col` param doc below for more details.
SparkXGBRanker doesn't support `validate_features` and `output_margin` param.
SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the `nthread`
param for each xgboost worker will be set equal to `spark.task.cpus` config value.
Parameters
----------
callbacks:
The export and import of the callback functions are at best effort. For
details, see :py:attr:`xgboost.spark.SparkXGBClassifier.callbacks` param doc.
validation_indicator_col:
For params related to `xgboost.XGBClassifier` training with
evaluation dataset's supervision,
set :py:attr:`xgboost.spark.SparkXGBClassifier.validation_indicator_col`
parameter instead of setting the `eval_set` parameter in `xgboost.XGBClassifier`
fit method.
weight_col:
To specify the weight of the training and validation dataset, set
:py:attr:`xgboost.spark.SparkXGBClassifier.weight_col` parameter instead of setting
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier`
fit method.
xgb_model:
Set the value to be the instance returned by
:func:`xgboost.spark.SparkXGBClassifierModel.get_booster`.
num_workers:
Integer that specifies the number of XGBoost workers to use.
Each XGBoost worker corresponds to one spark task.
use_gpu:
Boolean that specifies whether the executors are running on GPU
instances.
base_margin_col:
To specify the base margins of the training and validation
dataset, set :py:attr:`xgboost.spark.SparkXGBRanker.base_margin_col` parameter
instead of setting `base_margin` and `base_margin_eval_set` in the
`xgboost.XGBRanker` fit method.
qid_col:
To specify the qid of the training and validation
dataset, set :py:attr:`xgboost.spark.SparkXGBRanker.qid_col` parameter
instead of setting `qid` / `group`, `eval_qid` / `eval_group` in the
`xgboost.XGBRanker` fit method.
.. Note:: The Parameters chart above contains parameters that need special handling.
For a full list of parameters, see entries with `Param(parent=...` below.
.. Note:: This API is experimental.
Examples
--------
>>> from xgboost.spark import SparkXGBClassifier
>>> from pyspark.ml.linalg import Vectors
>>> df_train = spark.createDataFrame([
... (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
... (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
... (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
... (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
... ], ["features", "label", "isVal", "weight"])
>>> df_test = spark.createDataFrame([
... (Vectors.dense(1.0, 2.0, 3.0), ),
... ], ["features"])
>>> xgb_classifier = SparkXGBClassifier(max_depth=5, missing=0.0,
... validation_indicator_col='isVal', weight_col='weight',
... early_stopping_rounds=1, eval_metric='logloss')
>>> xgb_clf_model = xgb_classifier.fit(df_train)
>>> xgb_clf_model.transform(df_test).show()
"""
def __init__(self, **kwargs):
super().__init__()
self.setParams(**kwargs)
@classmethod
def _xgb_cls(cls):
return XGBRanker
@classmethod
def _pyspark_model_cls(cls):
return SparkXGBRankerModel
def _validate_params(self):
super()._validate_params()
if not self.isDefined(self.qid_col):
raise ValueError(
"Spark Xgboost ranker estimator requires setting `qid_col` param."
)
_set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel)

View File

@ -72,3 +72,17 @@ class HasEnableSparseDataOptim(Params):
def __init__(self):
super().__init__()
self._setDefault(enable_sparse_data_optim=False)
class HasQueryIdCol(Params):
"""
Mixin for param featuresCols: a list of feature column names.
This parameter is taken effect only when use_gpu is enabled.
"""
qid_col = Param(
Params._dummy(),
"qid_col",
"query id column name",
typeConverter=TypeConverters.toString,
)

View File

@ -24,6 +24,7 @@ from pyspark.sql import functions as spark_sql_func
from xgboost.spark import (
SparkXGBClassifier,
SparkXGBClassifierModel,
SparkXGBRanker,
SparkXGBRegressor,
SparkXGBRegressorModel,
)
@ -380,6 +381,28 @@ class XgboostLocalTest(SparkTestCase):
"expected_prediction_with_base_margin",
],
)
self.ranker_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
],
["features", "label", "qid"],
)
self.ranker_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
],
["features", "qid", "expected_prediction"],
)
self.reg_df_sparse_train = self.session.createDataFrame(
[
@ -1024,3 +1047,12 @@ class XgboostLocalTest(SparkTestCase):
for row1, row2 in zip(pred_result, pred_result2):
self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3))
def test_ranker(self):
ranker = SparkXGBRanker(qid_col="qid")
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
model = ranker.fit(self.ranker_df_train)
pred_result = model.transform(self.ranker_df_test).collect()
for row in pred_result:
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)