Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
parent
3fde9361d7
commit
4d387cbfbf
@ -64,6 +64,7 @@ from xgboost.core import Booster, _check_distributed_params
|
||||
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
from .._typing import ArrayLike
|
||||
from .data import (
|
||||
_read_csr_matrix_from_unwrapped_spark_vec,
|
||||
alias,
|
||||
@ -1117,12 +1118,86 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
)
|
||||
return features_col, feature_col_names
|
||||
|
||||
def _get_pred_contrib_col_name(self) -> Optional[str]:
|
||||
"""Return the pred_contrib_col col name"""
|
||||
pred_contrib_col_name = None
|
||||
if (
|
||||
self.isDefined(self.pred_contrib_col)
|
||||
and self.getOrDefault(self.pred_contrib_col) != ""
|
||||
):
|
||||
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
|
||||
|
||||
return pred_contrib_col_name
|
||||
|
||||
def _out_schema(self) -> Tuple[bool, str]:
|
||||
"""Return the bool to indicate if it's a single prediction, true is single prediction,
|
||||
and the returned type of the user-defined function. The value must
|
||||
be a DDL-formatted type string."""
|
||||
|
||||
if self._get_pred_contrib_col_name() is not None:
|
||||
return False, f"{pred.prediction} double, {pred.pred_contrib} array<double>"
|
||||
|
||||
return True, "double"
|
||||
|
||||
def _get_predict_func(self) -> Callable:
|
||||
"""Return the true prediction function which will be running on the executor side"""
|
||||
|
||||
predict_params = self._gen_predict_params_dict()
|
||||
pred_contrib_col_name = self._get_pred_contrib_col_name()
|
||||
|
||||
def _predict(
|
||||
model: XGBModel, X: ArrayLike, base_margin: Optional[ArrayLike]
|
||||
) -> Union[pd.DataFrame, pd.Series]:
|
||||
data = {}
|
||||
preds = model.predict(
|
||||
X,
|
||||
base_margin=base_margin,
|
||||
validate_features=False,
|
||||
**predict_params,
|
||||
)
|
||||
data[pred.prediction] = pd.Series(preds)
|
||||
|
||||
if pred_contrib_col_name is not None:
|
||||
contribs = pred_contribs(model, X, base_margin)
|
||||
data[pred.pred_contrib] = pd.Series(list(contribs))
|
||||
return pd.DataFrame(data=data)
|
||||
|
||||
return data[pred.prediction]
|
||||
|
||||
return _predict
|
||||
|
||||
def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
|
||||
"""Post process of transform"""
|
||||
prediction_col_name = self.getOrDefault(self.predictionCol)
|
||||
single_pred, _ = self._out_schema()
|
||||
|
||||
if single_pred:
|
||||
if prediction_col_name:
|
||||
dataset = dataset.withColumn(prediction_col_name, pred_col)
|
||||
else:
|
||||
pred_struct_col = "_prediction_struct"
|
||||
dataset = dataset.withColumn(pred_struct_col, pred_col)
|
||||
|
||||
if prediction_col_name:
|
||||
dataset = dataset.withColumn(
|
||||
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
|
||||
)
|
||||
|
||||
pred_contrib_col_name = self._get_pred_contrib_col_name()
|
||||
if pred_contrib_col_name is not None:
|
||||
dataset = dataset.withColumn(
|
||||
pred_contrib_col_name,
|
||||
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
|
||||
)
|
||||
|
||||
dataset = dataset.drop(pred_struct_col)
|
||||
return dataset
|
||||
|
||||
def _transform(self, dataset: DataFrame) -> DataFrame:
|
||||
# pylint: disable=too-many-statements, too-many-locals
|
||||
# Save xgb_sklearn_model and predict_params to be local variable
|
||||
# to avoid the `self` object to be pickled to remote.
|
||||
xgb_sklearn_model = self._xgb_sklearn_model
|
||||
predict_params = self._gen_predict_params_dict()
|
||||
|
||||
has_base_margin = False
|
||||
if (
|
||||
@ -1137,18 +1212,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
features_col, feature_col_names = self._get_feature_col(dataset)
|
||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
||||
|
||||
pred_contrib_col_name = None
|
||||
if (
|
||||
self.isDefined(self.pred_contrib_col)
|
||||
and self.getOrDefault(self.pred_contrib_col) != ""
|
||||
):
|
||||
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
|
||||
predict_func = self._get_predict_func()
|
||||
|
||||
single_pred = True
|
||||
schema = "double"
|
||||
if pred_contrib_col_name:
|
||||
single_pred = False
|
||||
schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"
|
||||
_, schema = self._out_schema()
|
||||
|
||||
@pandas_udf(schema) # type: ignore
|
||||
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
|
||||
@ -1168,48 +1234,14 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
else:
|
||||
base_margin = None
|
||||
|
||||
data = {}
|
||||
preds = model.predict(
|
||||
X,
|
||||
base_margin=base_margin,
|
||||
validate_features=False,
|
||||
**predict_params,
|
||||
)
|
||||
data[pred.prediction] = pd.Series(preds)
|
||||
|
||||
if pred_contrib_col_name:
|
||||
contribs = pred_contribs(model, X, base_margin)
|
||||
data[pred.pred_contrib] = pd.Series(list(contribs))
|
||||
yield pd.DataFrame(data=data)
|
||||
else:
|
||||
yield data[pred.prediction]
|
||||
yield predict_func(model, X, base_margin)
|
||||
|
||||
if has_base_margin:
|
||||
pred_col = predict_udf(struct(*features_col, base_margin_col))
|
||||
else:
|
||||
pred_col = predict_udf(struct(*features_col))
|
||||
|
||||
prediction_col_name = self.getOrDefault(self.predictionCol)
|
||||
|
||||
if single_pred:
|
||||
dataset = dataset.withColumn(prediction_col_name, pred_col)
|
||||
else:
|
||||
pred_struct_col = "_prediction_struct"
|
||||
dataset = dataset.withColumn(pred_struct_col, pred_col)
|
||||
|
||||
dataset = dataset.withColumn(
|
||||
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
|
||||
)
|
||||
|
||||
if pred_contrib_col_name:
|
||||
dataset = dataset.withColumn(
|
||||
pred_contrib_col_name,
|
||||
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
|
||||
)
|
||||
|
||||
dataset = dataset.drop(pred_struct_col)
|
||||
|
||||
return dataset
|
||||
return self._post_transform(dataset, pred_col)
|
||||
|
||||
|
||||
class _ClassificationModel( # pylint: disable=abstract-method
|
||||
@ -1221,22 +1253,21 @@ class _ClassificationModel( # pylint: disable=abstract-method
|
||||
.. Note:: This API is experimental.
|
||||
"""
|
||||
|
||||
def _transform(self, dataset: DataFrame) -> DataFrame:
|
||||
# pylint: disable=too-many-statements, too-many-locals
|
||||
# Save xgb_sklearn_model and predict_params to be local variable
|
||||
# to avoid the `self` object to be pickled to remote.
|
||||
xgb_sklearn_model = self._xgb_sklearn_model
|
||||
predict_params = self._gen_predict_params_dict()
|
||||
def _out_schema(self) -> Tuple[bool, str]:
|
||||
schema = (
|
||||
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
|
||||
f" {pred.probability} array<double>"
|
||||
)
|
||||
if self._get_pred_contrib_col_name() is not None:
|
||||
# We will force setting strict_shape to True when predicting contribs,
|
||||
# So, it will also output 3-D shape result.
|
||||
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
|
||||
|
||||
has_base_margin = False
|
||||
if (
|
||||
self.isDefined(self.base_margin_col)
|
||||
and self.getOrDefault(self.base_margin_col) != ""
|
||||
):
|
||||
has_base_margin = True
|
||||
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
|
||||
alias.margin
|
||||
)
|
||||
return False, schema
|
||||
|
||||
def _get_predict_func(self) -> Callable:
|
||||
predict_params = self._gen_predict_params_dict()
|
||||
pred_contrib_col_name = self._get_pred_contrib_col_name()
|
||||
|
||||
def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
if margins.ndim == 1:
|
||||
@ -1251,76 +1282,38 @@ class _ClassificationModel( # pylint: disable=abstract-method
|
||||
class_probs = softmax(raw_preds, axis=1)
|
||||
return raw_preds, class_probs
|
||||
|
||||
features_col, feature_col_names = self._get_feature_col(dataset)
|
||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
||||
def _predict(
|
||||
model: XGBModel, X: ArrayLike, base_margin: Optional[np.ndarray]
|
||||
) -> Union[pd.DataFrame, pd.Series]:
|
||||
margins = model.predict(
|
||||
X,
|
||||
base_margin=base_margin,
|
||||
output_margin=True,
|
||||
validate_features=False,
|
||||
**predict_params,
|
||||
)
|
||||
raw_preds, class_probs = transform_margin(margins)
|
||||
|
||||
pred_contrib_col_name = None
|
||||
if (
|
||||
self.isDefined(self.pred_contrib_col)
|
||||
and self.getOrDefault(self.pred_contrib_col) != ""
|
||||
):
|
||||
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
|
||||
# It seems that they use argmax of class probs,
|
||||
# not of margin to get the prediction (Note: scala implementation)
|
||||
preds = np.argmax(class_probs, axis=1)
|
||||
result: Dict[str, pd.Series] = {
|
||||
pred.raw_prediction: pd.Series(list(raw_preds)),
|
||||
pred.prediction: pd.Series(preds),
|
||||
pred.probability: pd.Series(list(class_probs)),
|
||||
}
|
||||
|
||||
schema = (
|
||||
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
|
||||
f" {pred.probability} array<double>"
|
||||
)
|
||||
if pred_contrib_col_name:
|
||||
# We will force setting strict_shape to True when predicting contribs,
|
||||
# So, it will also output 3-D shape result.
|
||||
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
|
||||
if pred_contrib_col_name is not None:
|
||||
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
|
||||
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
|
||||
|
||||
@pandas_udf(schema) # type: ignore
|
||||
def predict_udf(
|
||||
iterator: Iterator[Tuple[pd.Series, ...]]
|
||||
) -> Iterator[pd.DataFrame]:
|
||||
assert xgb_sklearn_model is not None
|
||||
model = xgb_sklearn_model
|
||||
for data in iterator:
|
||||
if enable_sparse_data_optim:
|
||||
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
|
||||
else:
|
||||
if feature_col_names is not None:
|
||||
X = data[feature_col_names] # type: ignore
|
||||
else:
|
||||
X = stack_series(data[alias.data])
|
||||
return pd.DataFrame(data=result)
|
||||
|
||||
if has_base_margin:
|
||||
base_margin = stack_series(data[alias.margin])
|
||||
else:
|
||||
base_margin = None
|
||||
|
||||
margins = model.predict(
|
||||
X,
|
||||
base_margin=base_margin,
|
||||
output_margin=True,
|
||||
validate_features=False,
|
||||
**predict_params,
|
||||
)
|
||||
raw_preds, class_probs = transform_margin(margins)
|
||||
|
||||
# It seems that they use argmax of class probs,
|
||||
# not of margin to get the prediction (Note: scala implementation)
|
||||
preds = np.argmax(class_probs, axis=1)
|
||||
result: Dict[str, pd.Series] = {
|
||||
pred.raw_prediction: pd.Series(list(raw_preds)),
|
||||
pred.prediction: pd.Series(preds),
|
||||
pred.probability: pd.Series(list(class_probs)),
|
||||
}
|
||||
|
||||
if pred_contrib_col_name:
|
||||
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
|
||||
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
|
||||
|
||||
yield pd.DataFrame(data=result)
|
||||
|
||||
if has_base_margin:
|
||||
pred_struct = predict_udf(struct(*features_col, base_margin_col))
|
||||
else:
|
||||
pred_struct = predict_udf(struct(*features_col))
|
||||
return _predict
|
||||
|
||||
def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
|
||||
pred_struct_col = "_prediction_struct"
|
||||
dataset = dataset.withColumn(pred_struct_col, pred_struct)
|
||||
dataset = dataset.withColumn(pred_struct_col, pred_col)
|
||||
|
||||
raw_prediction_col_name = self.getOrDefault(self.rawPredictionCol)
|
||||
if raw_prediction_col_name:
|
||||
@ -1342,7 +1335,8 @@ class _ClassificationModel( # pylint: disable=abstract-method
|
||||
array_to_vector(getattr(col(pred_struct_col), pred.probability)),
|
||||
)
|
||||
|
||||
if pred_contrib_col_name:
|
||||
pred_contrib_col_name = self._get_pred_contrib_col_name()
|
||||
if pred_contrib_col_name is not None:
|
||||
dataset = dataset.withColumn(
|
||||
pred_contrib_col_name,
|
||||
getattr(col(pred_struct_col), pred.pred_contrib),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user