[pyspark] support pred_contribs (#8633)
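This adds a new `pred_contrib_col` parameter to the PySpark estimators so that
per-feature contributions (SHAP values) can be emitted as an extra output
column alongside the regular prediction columns. A minimal usage sketch,
assuming a Spark session and DataFrames with hypothetical names:

    from xgboost.spark import SparkXGBRegressor

    regressor = SparkXGBRegressor(pred_contrib_col="pred_contrib")
    model = regressor.fit(train_df)
    # transform() now also yields a "pred_contrib" vector column holding
    # the per-feature contributions plus the bias term.
    model.transform(test_df).select("prediction", "pred_contrib").show()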
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
@@ -45,6 +45,7 @@ from .data import (
     _read_csr_matrix_from_unwrapped_spark_vec,
     alias,
     create_dmatrix_from_partitions,
+    pred_contribs,
     stack_series,
 )
 from .model import (
@@ -56,6 +57,7 @@ from .model import (
 from .params import (
     HasArbitraryParamsDict,
     HasBaseMarginCol,
+    HasContribPredictionCol,
     HasEnableSparseDataOptim,
     HasFeaturesCols,
     HasQueryIdCol,
@@ -92,6 +94,7 @@ _pyspark_specific_params = [
     "enable_sparse_data_optim",
     "qid_col",
     "repartition_random_shuffle",
+    "pred_contrib_col",
 ]

 _non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@@ -140,6 +143,12 @@ _unsupported_predict_params = {
     "base_margin",  # Use pyspark base_margin_col param instead.
 }

+# Global prediction names
+Pred = namedtuple(
+    "Pred", ("prediction", "raw_prediction", "probability", "pred_contrib")
+)
+pred = Pred("prediction", "rawPrediction", "probability", "predContrib")
+

 class _SparkXGBParams(
     HasFeaturesCol,
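The `pred` constant above centralizes the Spark-facing column names so the
UDF schemas and struct-field accesses below stay consistent; illustratively,
following the definition in this hunk:

    pred.prediction      # "prediction"
    pred.raw_prediction  # "rawPrediction"
    pred.probability     # "probability"
    pred.pred_contrib    # "predContrib"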
@@ -152,6 +161,7 @@ class _SparkXGBParams(
     HasFeaturesCols,
     HasEnableSparseDataOptim,
     HasQueryIdCol,
+    HasContribPredictionCol,
 ):
     num_workers = Param(
         Params._dummy(),
@@ -993,6 +1003,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         return features_col, feature_col_names

     def _transform(self, dataset):
+        # pylint: disable=too-many-statements, too-many-locals
         # Save xgb_sklearn_model and predict_params to be local variable
         # to avoid the `self` object to be pickled to remote.
         xgb_sklearn_model = self._xgb_sklearn_model
@@ -1010,7 +1021,19 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         features_col, feature_col_names = self._get_feature_col(dataset)
         enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

-        @pandas_udf("double")
+        pred_contrib_col_name = None
+        if self.isDefined(self.pred_contrib_col) and self.getOrDefault(
+            self.pred_contrib_col
+        ):
+            pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
+
+        single_pred = True
+        schema = "double"
+        if pred_contrib_col_name:
+            single_pred = False
+            schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"
+
+        @pandas_udf(schema)
         def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
             model = xgb_sklearn_model
             for data in iterator:
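Depending on whether pred_contrib_col is set, the UDF schema above resolves to
one of two Spark SQL type strings (the values follow from the `pred` names
defined earlier):

    schema == "double"                                         # single_pred
    schema == "prediction double, predContrib array<double>"   # with contribs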
@@ -1027,22 +1050,48 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
                 else:
                     base_margin = None

+                data = {}
                 preds = model.predict(
                     X,
                     base_margin=base_margin,
                     validate_features=False,
                     **predict_params,
                 )
-                yield pd.Series(preds)
+                data[pred.prediction] = pd.Series(preds)
+
+                if pred_contrib_col_name:
+                    contribs = pred_contribs(model, X, base_margin)
+                    data[pred.pred_contrib] = pd.Series(list(contribs))
+                    yield pd.DataFrame(data=data)
+                else:
+                    yield data[pred.prediction]

         if has_base_margin:
             pred_col = predict_udf(struct(*features_col, base_margin_col))
         else:
             pred_col = predict_udf(struct(*features_col))

-        predictionColName = self.getOrDefault(self.predictionCol)
+        prediction_col_name = self.getOrDefault(self.predictionCol)

-        return dataset.withColumn(predictionColName, pred_col)
+        if single_pred:
+            dataset = dataset.withColumn(prediction_col_name, pred_col)
+        else:
+            pred_struct_col = "_prediction_struct"
+            dataset = dataset.withColumn(pred_struct_col, pred_col)
+
+            dataset = dataset.withColumn(
+                prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
+            )
+
+            if pred_contrib_col_name:
+                dataset = dataset.withColumn(
+                    pred_contrib_col_name,
+                    array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
+                )
+
+            dataset = dataset.drop(pred_struct_col)
+
+        return dataset


 class SparkXGBRegressorModel(_SparkXGBModel):
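With pred_contrib_col set, the UDF above returns a struct that is then
unpacked: the prediction field becomes the regular prediction column, and the
contribution array is converted to an ML vector via array_to_vector. A sketch
of the resulting output, assuming pred_contrib_col="contribs" (name
hypothetical):

    out = model.transform(test_df)
    # out["prediction"]: double
    # out["contribs"]:   vector of length n_features + 1 (last entry is bias)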
@@ -1069,7 +1118,9 @@ class SparkXGBRankerModel(_SparkXGBModel):
         return XGBRanker


-class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
+class SparkXGBClassifierModel(
+    _SparkXGBModel, HasProbabilityCol, HasRawPredictionCol, HasContribPredictionCol
+):
     """
     The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
@@ -1081,7 +1132,7 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
         return XGBClassifier

     def _transform(self, dataset):
-        # pylint: disable=too-many-locals
+        # pylint: disable=too-many-statements, too-many-locals
         # Save xgb_sklearn_model and predict_params to be local variable
         # to avoid the `self` object to be pickled to remote.
         xgb_sklearn_model = self._xgb_sklearn_model
@@ -1112,9 +1163,22 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
         features_col, feature_col_names = self._get_feature_col(dataset)
         enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

-        @pandas_udf(
-            "rawPrediction array<double>, prediction double, probability array<double>"
-        )
+        pred_contrib_col_name = None
+        if self.isDefined(self.pred_contrib_col) and self.getOrDefault(
+            self.pred_contrib_col
+        ):
+            pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
+
+        schema = (
+            f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
+            f" {pred.probability} array<double>"
+        )
+        if pred_contrib_col_name:
+            # strict_shape is forced to True when predicting contribs,
+            # so the contribution output also has a 3-D shape.
+            schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
+
+        @pandas_udf(schema)
         def predict_udf(
             iterator: Iterator[Tuple[pd.Series, ...]]
         ) -> Iterator[pd.DataFrame]:
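Because strict_shape is forced on for the classifier, the contribution output
keeps an explicit group dimension, which is why the schema declares
array<array<double>>. Roughly, following XGBoost's strict-shape convention:

    contribs = pred_contribs(model, X, base_margin, strict_shape=True)
    # contribs.shape == (n_rows, n_groups, n_features + 1)
    # e.g. a binary classifier over 4 features yields (n_rows, 1, 5)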
@@ -1145,13 +1209,17 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
                 # It seems that they use argmax of class probs,
                 # not of margin to get the prediction (Note: scala implementation)
                 preds = np.argmax(class_probs, axis=1)
-                yield pd.DataFrame(
-                    data={
-                        "rawPrediction": pd.Series(list(raw_preds)),
-                        "prediction": pd.Series(preds),
-                        "probability": pd.Series(list(class_probs)),
-                    }
-                )
+                data = {
+                    pred.raw_prediction: pd.Series(list(raw_preds)),
+                    pred.prediction: pd.Series(preds),
+                    pred.probability: pd.Series(list(class_probs)),
+                }
+
+                if pred_contrib_col_name:
+                    contribs = pred_contribs(model, X, base_margin, strict_shape=True)
+                    data[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
+
+                yield pd.DataFrame(data=data)

         if has_base_margin:
             pred_struct = predict_udf(struct(*features_col, base_margin_col))
@@ -1159,23 +1227,32 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol):
             pred_struct = predict_udf(struct(*features_col))

         pred_struct_col = "_prediction_struct"
-
-        rawPredictionColName = self.getOrDefault(self.rawPredictionCol)
-        predictionColName = self.getOrDefault(self.predictionCol)
-        probabilityColName = self.getOrDefault(self.probabilityCol)
         dataset = dataset.withColumn(pred_struct_col, pred_struct)
-        if rawPredictionColName:
+
+        raw_prediction_col_name = self.getOrDefault(self.rawPredictionCol)
+        if raw_prediction_col_name:
             dataset = dataset.withColumn(
-                rawPredictionColName,
-                array_to_vector(col(pred_struct_col).rawPrediction),
+                raw_prediction_col_name,
+                array_to_vector(getattr(col(pred_struct_col), pred.raw_prediction)),
             )
-        if predictionColName:
+
+        prediction_col_name = self.getOrDefault(self.predictionCol)
+        if prediction_col_name:
             dataset = dataset.withColumn(
-                predictionColName, col(pred_struct_col).prediction
+                prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
             )
-        if probabilityColName:
+
+        probability_col_name = self.getOrDefault(self.probabilityCol)
+        if probability_col_name:
             dataset = dataset.withColumn(
-                probabilityColName, array_to_vector(col(pred_struct_col).probability)
+                probability_col_name,
+                array_to_vector(getattr(col(pred_struct_col), pred.probability)),
             )
+
+        if pred_contrib_col_name:
+            dataset = dataset.withColumn(
+                pred_contrib_col_name,
+                getattr(col(pred_struct_col), pred.pred_contrib),
+            )

         return dataset.drop(pred_struct_col)
diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py
@@ -1,3 +1,4 @@
+# pylint: disable=protected-access
 """Utilities for processing spark partitions."""
 from collections import defaultdict, namedtuple
 from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
@@ -7,8 +8,10 @@ import pandas as pd
 from scipy.sparse import csr_matrix
 from xgboost.compat import concat

-from xgboost import DataIter, DMatrix, QuantileDMatrix
+from xgboost import DataIter, DMatrix, QuantileDMatrix, XGBModel

+from .._typing import ArrayLike
+from ..core import _convert_ntree_limit
 from .utils import get_logger  # type: ignore
@@ -331,3 +334,29 @@ def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments
     assert dvalid.num_col() == dtrain.num_col()

     return dtrain, dvalid
+
+
+def pred_contribs(
+    model: XGBModel,
+    data: ArrayLike,
+    base_margin: Optional[ArrayLike] = None,
+    strict_shape: bool = False,
+) -> np.ndarray:
+    """Predict feature contributions for the data using the full model."""
+    iteration_range = _convert_ntree_limit(model.get_booster(), None, None)
+    iteration_range = model._get_iteration_range(iteration_range)
+    data_dmatrix = DMatrix(
+        data,
+        base_margin=base_margin,
+        missing=model.missing,
+        nthread=model.n_jobs,
+        feature_types=model.feature_types,
+        enable_categorical=model.enable_categorical,
+    )
+    return model.get_booster().predict(
+        data_dmatrix,
+        pred_contribs=True,
+        validate_features=False,
+        iteration_range=iteration_range,
+        strict_shape=strict_shape,
+    )
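The helper predicts through the booster directly so it can skip feature-name
validation and honor the model's iteration range (e.g. early stopping). A
self-contained sketch of its behavior on plain numpy data (assuming this
change is installed; the data below is illustrative):

    import numpy as np
    from xgboost import XGBRegressor
    from xgboost.spark.data import pred_contribs

    X = np.random.rand(8, 4)
    y = np.random.rand(8)
    model = XGBRegressor(n_estimators=4).fit(X, y)

    contribs = pred_contribs(model, X)
    # 2-D in the regression case: (n_rows, n_features + 1); the last
    # column is the bias term.
    assert contribs.shape == (8, 5)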
diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py
@@ -85,3 +85,19 @@ class HasQueryIdCol(Params):
         "query id column name",
         typeConverter=TypeConverters.toString,
     )
+
+
+class HasContribPredictionCol(Params):
+    """
+    Mixin for param pred_contrib_col: contribution prediction column name.
+
+    In the classification case the output is a 3-dim array of shape
+    (rows, groups, columns + 1); in the regression case it is 2-dim.
+    """
+
+    pred_contrib_col: "Param[str]" = Param(
+        Params._dummy(),
+        "pred_contrib_col",
+        "feature contributions to individual predictions.",
+        typeConverter=TypeConverters.toString,
+    )
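Like the other column-name mixins in this module, the param is surfaced
through the standard PySpark Params API; for instance (sketch):

    clf = SparkXGBClassifier(pred_contrib_col="contribs")
    clf.getOrDefault(clf.pred_contrib_col)  # "contribs"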