[backport] [pyspark] rework transform to reuse same code (#9292) (#9558)

Co-authored-by: Bobby Wang <wbo4958@gmail.com>
Jiaming Yuan 2023-09-07 15:26:24 +08:00 committed by GitHub
parent 3fde9361d7
commit 4d387cbfbf
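
The change replaces two near-duplicate _transform implementations with one shared template method plus three overridable hooks. A minimal, hypothetical sketch of that shape (plain Python, Spark wiring elided; not the literal code from the diff):

class BaseModel:
    def _out_schema(self):
        # (single_prediction, DDL-formatted type string for the pandas UDF)
        return True, "double"

    def _get_predict_func(self):
        # returns the closure that actually runs on the executors
        def _predict(model, X, base_margin):
            return model.predict(X)
        return _predict

    def _post_transform(self, dataset, pred_col):
        # unpack/rename the UDF output into user-facing columns
        return dataset

    def _transform(self, dataset):
        _, schema = self._out_schema()
        predict_func = self._get_predict_func()
        pred_col = (schema, predict_func)  # stands in for the pandas UDF column
        return self._post_transform(dataset, pred_col)

Subclasses such as the classification model then override only the hooks.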


@@ -64,6 +64,7 @@ from xgboost.core import Booster, _check_distributed_params
 from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
 from xgboost.training import train as worker_train
 
+from .._typing import ArrayLike
 from .data import (
     _read_csr_matrix_from_unwrapped_spark_vec,
     alias,
@@ -1117,12 +1118,86 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         )
         return features_col, feature_col_names
 
+    def _get_pred_contrib_col_name(self) -> Optional[str]:
+        """Return the pred_contrib_col column name, or None if it is not set."""
+        pred_contrib_col_name = None
+        if (
+            self.isDefined(self.pred_contrib_col)
+            and self.getOrDefault(self.pred_contrib_col) != ""
+        ):
+            pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
+        return pred_contrib_col_name
+
+    def _out_schema(self) -> Tuple[bool, str]:
+        """Return a bool indicating whether this is a single prediction (True
+        means single prediction), together with the return type of the
+        user-defined function, which must be a DDL-formatted type string."""
+        if self._get_pred_contrib_col_name() is not None:
+            return False, f"{pred.prediction} double, {pred.pred_contrib} array<double>"
+        return True, "double"
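
The DDL-formatted strings returned here are the same form pandas_udf accepts as its return type. A standalone sketch (column names arbitrary, for illustration only):

import pandas as pd
from pyspark.sql.functions import pandas_udf

# A struct type written as a DDL string; Spark parses it into
# StructType([StructField("prediction", DoubleType()), ...]).
@pandas_udf("prediction double, pred_contrib array<double>")
def example(s: pd.Series) -> pd.DataFrame:
    # for struct return types, the UDF returns a DataFrame whose
    # column names match the DDL string
    return pd.DataFrame({"prediction": s, "pred_contrib": s.map(lambda v: [v])})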
+    def _get_predict_func(self) -> Callable:
+        """Return the actual prediction function, which runs on the executor side."""
+        predict_params = self._gen_predict_params_dict()
+        pred_contrib_col_name = self._get_pred_contrib_col_name()
+
+        def _predict(
+            model: XGBModel, X: ArrayLike, base_margin: Optional[ArrayLike]
+        ) -> Union[pd.DataFrame, pd.Series]:
+            data = {}
+            preds = model.predict(
+                X,
+                base_margin=base_margin,
+                validate_features=False,
+                **predict_params,
+            )
+            data[pred.prediction] = pd.Series(preds)
+            if pred_contrib_col_name is not None:
+                contribs = pred_contribs(model, X, base_margin)
+                data[pred.pred_contrib] = pd.Series(list(contribs))
+                return pd.DataFrame(data=data)
+            return data[pred.prediction]
+
+        return _predict
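
Note that _predict closes over predict_params and pred_contrib_col_name rather than touching self, so only those small values travel to the executors when Spark cloudpickles the function. The same pattern in isolation (names hypothetical):

from typing import Callable, Optional

def make_predict(predict_params: dict, contrib_col: Optional[str]) -> Callable:
    # the returned closure captures only these two plain locals, so
    # serializing it does not drag the whole estimator object along
    def _predict(model, X):
        return model.predict(X, **predict_params)
    return _predict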
+    def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
+        """Post-process the transform result, unpacking the prediction struct
+        into the user-facing output columns."""
+        prediction_col_name = self.getOrDefault(self.predictionCol)
+        single_pred, _ = self._out_schema()
+        if single_pred:
+            if prediction_col_name:
+                dataset = dataset.withColumn(prediction_col_name, pred_col)
+        else:
+            pred_struct_col = "_prediction_struct"
+            dataset = dataset.withColumn(pred_struct_col, pred_col)
+            if prediction_col_name:
+                dataset = dataset.withColumn(
+                    prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
+                )
+            pred_contrib_col_name = self._get_pred_contrib_col_name()
+            if pred_contrib_col_name is not None:
+                dataset = dataset.withColumn(
+                    pred_contrib_col_name,
+                    array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
+                )
+            dataset = dataset.drop(pred_struct_col)
+        return dataset
+
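
End to end, the user-visible behavior is intended to stay the same. A hedged usage sketch (train_df and test_df are placeholder DataFrames):

from xgboost.spark import SparkXGBRegressor

# Setting pred_contrib_col switches _out_schema to the struct schema, and
# _post_transform unpacks it into a double prediction column plus a
# vector-typed contribution column.
model = SparkXGBRegressor(label_col="label", pred_contrib_col="contribs").fit(train_df)
scored = model.transform(test_df)  # adds columns: prediction, contribs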
     def _transform(self, dataset: DataFrame) -> DataFrame:
         # pylint: disable=too-many-statements, too-many-locals
         # Save xgb_sklearn_model and predict_params as local variables to
         # avoid pickling the `self` object to the remote executors.
         xgb_sklearn_model = self._xgb_sklearn_model
         predict_params = self._gen_predict_params_dict()
 
         has_base_margin = False
         if (
@@ -1137,18 +1212,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         features_col, feature_col_names = self._get_feature_col(dataset)
         enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
 
-        pred_contrib_col_name = None
-        if (
-            self.isDefined(self.pred_contrib_col)
-            and self.getOrDefault(self.pred_contrib_col) != ""
-        ):
-            pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
+        predict_func = self._get_predict_func()
 
-        single_pred = True
-        schema = "double"
-        if pred_contrib_col_name:
-            single_pred = False
-            schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"
+        _, schema = self._out_schema()
 
         @pandas_udf(schema)  # type: ignore
         def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
@@ -1168,48 +1234,14 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
                 else:
                     base_margin = None
 
-                data = {}
-                preds = model.predict(
-                    X,
-                    base_margin=base_margin,
-                    validate_features=False,
-                    **predict_params,
-                )
-                data[pred.prediction] = pd.Series(preds)
-
-                if pred_contrib_col_name:
-                    contribs = pred_contribs(model, X, base_margin)
-                    data[pred.pred_contrib] = pd.Series(list(contribs))
-                    yield pd.DataFrame(data=data)
-                else:
-                    yield data[pred.prediction]
+                yield predict_func(model, X, base_margin)
 
         if has_base_margin:
             pred_col = predict_udf(struct(*features_col, base_margin_col))
         else:
             pred_col = predict_udf(struct(*features_col))
 
-        prediction_col_name = self.getOrDefault(self.predictionCol)
-
-        if single_pred:
-            dataset = dataset.withColumn(prediction_col_name, pred_col)
-        else:
-            pred_struct_col = "_prediction_struct"
-            dataset = dataset.withColumn(pred_struct_col, pred_col)
-            dataset = dataset.withColumn(
-                prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
-            )
-            if pred_contrib_col_name:
-                dataset = dataset.withColumn(
-                    pred_contrib_col_name,
-                    array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
-                )
-            dataset = dataset.drop(pred_struct_col)
-
-        return dataset
+        return self._post_transform(dataset, pred_col)
 
 
 class _ClassificationModel(  # pylint: disable=abstract-method
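
predict_udf above uses the iterator flavor of pandas_udf, so per-task setup such as deserializing the booster happens once per partition rather than once per batch. The pattern in a self-contained sketch:

from typing import Iterator

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def plus_one(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # one-time, per-task setup (e.g. loading a model) would go here
    for batch in batches:
        yield batch + 1.0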
@@ -1221,22 +1253,21 @@ class _ClassificationModel(  # pylint: disable=abstract-method
     .. Note:: This API is experimental.
     """
 
-    def _transform(self, dataset: DataFrame) -> DataFrame:
-        # pylint: disable=too-many-statements, too-many-locals
-        # Save xgb_sklearn_model and predict_params as local variables to
-        # avoid pickling the `self` object to the remote executors.
-        xgb_sklearn_model = self._xgb_sklearn_model
-        predict_params = self._gen_predict_params_dict()
-
-        has_base_margin = False
-        if (
-            self.isDefined(self.base_margin_col)
-            and self.getOrDefault(self.base_margin_col) != ""
-        ):
-            has_base_margin = True
-            base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
-                alias.margin
-            )
+    def _out_schema(self) -> Tuple[bool, str]:
+        schema = (
+            f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
+            f" {pred.probability} array<double>"
+        )
+        if self._get_pred_contrib_col_name() is not None:
+            # strict_shape is forced to True when predicting contribs, so the
+            # contribution output also has a 3-D shape.
+            schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
+        return False, schema
+
+    def _get_predict_func(self) -> Callable:
+        predict_params = self._gen_predict_params_dict()
+        pred_contrib_col_name = self._get_pred_contrib_col_name()
 
         def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
             if margins.ndim == 1:
@@ -1251,45 +1282,9 @@ class _ClassificationModel(  # pylint: disable=abstract-method
                 class_probs = softmax(raw_preds, axis=1)
             return raw_preds, class_probs
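
In the multiclass branch above, class probabilities are just a softmax over the raw margins. Numerically (shape assumed (n_samples, n_classes)):

import numpy as np
from scipy.special import softmax

raw_preds = np.array([[0.5, 1.5, -0.2]])  # margins for one row, three classes
class_probs = softmax(raw_preds, axis=1)  # each row sums to 1.0
assert np.allclose(class_probs.sum(axis=1), 1.0)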
-        features_col, feature_col_names = self._get_feature_col(dataset)
-        enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
-
-        pred_contrib_col_name = None
-        if (
-            self.isDefined(self.pred_contrib_col)
-            and self.getOrDefault(self.pred_contrib_col) != ""
-        ):
-            pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
-
-        schema = (
-            f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
-            f" {pred.probability} array<double>"
-        )
-        if pred_contrib_col_name:
-            # strict_shape is forced to True when predicting contribs, so the
-            # contribution output also has a 3-D shape.
-            schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
-
-        @pandas_udf(schema)  # type: ignore
-        def predict_udf(
-            iterator: Iterator[Tuple[pd.Series, ...]]
-        ) -> Iterator[pd.DataFrame]:
-            assert xgb_sklearn_model is not None
-            model = xgb_sklearn_model
-            for data in iterator:
-                if enable_sparse_data_optim:
-                    X = _read_csr_matrix_from_unwrapped_spark_vec(data)
-                else:
-                    if feature_col_names is not None:
-                        X = data[feature_col_names]  # type: ignore
-                    else:
-                        X = stack_series(data[alias.data])
-
-                if has_base_margin:
-                    base_margin = stack_series(data[alias.margin])
-                else:
-                    base_margin = None
-
+        def _predict(
+            model: XGBModel, X: ArrayLike, base_margin: Optional[np.ndarray]
+        ) -> Union[pd.DataFrame, pd.Series]:
             margins = model.predict(
                 X,
                 base_margin=base_margin,
@@ -1308,19 +1303,17 @@ class _ClassificationModel(  # pylint: disable=abstract-method
                 pred.probability: pd.Series(list(class_probs)),
             }
 
-            if pred_contrib_col_name:
+            if pred_contrib_col_name is not None:
                 contribs = pred_contribs(model, X, base_margin, strict_shape=True)
                 result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
 
-            yield pd.DataFrame(data=result)
+            return pd.DataFrame(data=result)
 
-        if has_base_margin:
-            pred_struct = predict_udf(struct(*features_col, base_margin_col))
-        else:
-            pred_struct = predict_udf(struct(*features_col))
+        return _predict
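
The strict_shape=True call above keeps an explicit class axis in the contribution output, which is why the classifier's _out_schema declares pred_contrib as array<array<double>>. The expected layout, as I read XGBoost's predict contract (a sketch, not verified against every objective):

import numpy as np

n_samples, n_classes, n_features = 4, 3, 10
# pred_contribs with strict_shape=True: per sample, one contribution row per
# class, with a trailing bias column
contribs = np.zeros((n_samples, n_classes, n_features + 1))
per_record = contribs[0].tolist()  # what pred.pred_contrib stores for one row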
+    def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
         pred_struct_col = "_prediction_struct"
-        dataset = dataset.withColumn(pred_struct_col, pred_struct)
+        dataset = dataset.withColumn(pred_struct_col, pred_col)
 
         raw_prediction_col_name = self.getOrDefault(self.rawPredictionCol)
         if raw_prediction_col_name:
@@ -1342,7 +1335,8 @@ class _ClassificationModel(  # pylint: disable=abstract-method
                 array_to_vector(getattr(col(pred_struct_col), pred.probability)),
             )
 
-        if pred_contrib_col_name:
+        pred_contrib_col_name = self._get_pred_contrib_col_name()
+        if pred_contrib_col_name is not None:
             dataset = dataset.withColumn(
                 pred_contrib_col_name,
                 getattr(col(pred_struct_col), pred.pred_contrib),