Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
parent
3fde9361d7
commit
4d387cbfbf
@ -64,6 +64,7 @@ from xgboost.core import Booster, _check_distributed_params
|
|||||||
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
|
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
|
||||||
from xgboost.training import train as worker_train
|
from xgboost.training import train as worker_train
|
||||||
|
|
||||||
|
from .._typing import ArrayLike
|
||||||
from .data import (
|
from .data import (
|
||||||
_read_csr_matrix_from_unwrapped_spark_vec,
|
_read_csr_matrix_from_unwrapped_spark_vec,
|
||||||
alias,
|
alias,
|
||||||
@ -1117,12 +1118,86 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
)
|
)
|
||||||
return features_col, feature_col_names
|
return features_col, feature_col_names
|
||||||
|
|
||||||
|
def _get_pred_contrib_col_name(self) -> Optional[str]:
|
||||||
|
"""Return the pred_contrib_col col name"""
|
||||||
|
pred_contrib_col_name = None
|
||||||
|
if (
|
||||||
|
self.isDefined(self.pred_contrib_col)
|
||||||
|
and self.getOrDefault(self.pred_contrib_col) != ""
|
||||||
|
):
|
||||||
|
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
|
||||||
|
|
||||||
|
return pred_contrib_col_name
|
||||||
|
|
||||||
|
def _out_schema(self) -> Tuple[bool, str]:
|
||||||
|
"""Return the bool to indicate if it's a single prediction, true is single prediction,
|
||||||
|
and the returned type of the user-defined function. The value must
|
||||||
|
be a DDL-formatted type string."""
|
||||||
|
|
||||||
|
if self._get_pred_contrib_col_name() is not None:
|
||||||
|
return False, f"{pred.prediction} double, {pred.pred_contrib} array<double>"
|
||||||
|
|
||||||
|
return True, "double"
|
||||||
|
|
||||||
|
def _get_predict_func(self) -> Callable:
|
||||||
|
"""Return the true prediction function which will be running on the executor side"""
|
||||||
|
|
||||||
|
predict_params = self._gen_predict_params_dict()
|
||||||
|
pred_contrib_col_name = self._get_pred_contrib_col_name()
|
||||||
|
|
||||||
|
def _predict(
|
||||||
|
model: XGBModel, X: ArrayLike, base_margin: Optional[ArrayLike]
|
||||||
|
) -> Union[pd.DataFrame, pd.Series]:
|
||||||
|
data = {}
|
||||||
|
preds = model.predict(
|
||||||
|
X,
|
||||||
|
base_margin=base_margin,
|
||||||
|
validate_features=False,
|
||||||
|
**predict_params,
|
||||||
|
)
|
||||||
|
data[pred.prediction] = pd.Series(preds)
|
||||||
|
|
||||||
|
if pred_contrib_col_name is not None:
|
||||||
|
contribs = pred_contribs(model, X, base_margin)
|
||||||
|
data[pred.pred_contrib] = pd.Series(list(contribs))
|
||||||
|
return pd.DataFrame(data=data)
|
||||||
|
|
||||||
|
return data[pred.prediction]
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
|
||||||
|
"""Post process of transform"""
|
||||||
|
prediction_col_name = self.getOrDefault(self.predictionCol)
|
||||||
|
single_pred, _ = self._out_schema()
|
||||||
|
|
||||||
|
if single_pred:
|
||||||
|
if prediction_col_name:
|
||||||
|
dataset = dataset.withColumn(prediction_col_name, pred_col)
|
||||||
|
else:
|
||||||
|
pred_struct_col = "_prediction_struct"
|
||||||
|
dataset = dataset.withColumn(pred_struct_col, pred_col)
|
||||||
|
|
||||||
|
if prediction_col_name:
|
||||||
|
dataset = dataset.withColumn(
|
||||||
|
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_contrib_col_name = self._get_pred_contrib_col_name()
|
||||||
|
if pred_contrib_col_name is not None:
|
||||||
|
dataset = dataset.withColumn(
|
||||||
|
pred_contrib_col_name,
|
||||||
|
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = dataset.drop(pred_struct_col)
|
||||||
|
return dataset
|
||||||
|
|
||||||
def _transform(self, dataset: DataFrame) -> DataFrame:
|
def _transform(self, dataset: DataFrame) -> DataFrame:
|
||||||
# pylint: disable=too-many-statements, too-many-locals
|
# pylint: disable=too-many-statements, too-many-locals
|
||||||
# Save xgb_sklearn_model and predict_params to be local variable
|
# Save xgb_sklearn_model and predict_params to be local variable
|
||||||
# to avoid the `self` object to be pickled to remote.
|
# to avoid the `self` object to be pickled to remote.
|
||||||
xgb_sklearn_model = self._xgb_sklearn_model
|
xgb_sklearn_model = self._xgb_sklearn_model
|
||||||
predict_params = self._gen_predict_params_dict()
|
|
||||||
|
|
||||||
has_base_margin = False
|
has_base_margin = False
|
||||||
if (
|
if (
|
||||||
@ -1137,18 +1212,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
features_col, feature_col_names = self._get_feature_col(dataset)
|
features_col, feature_col_names = self._get_feature_col(dataset)
|
||||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
||||||
|
|
||||||
pred_contrib_col_name = None
|
predict_func = self._get_predict_func()
|
||||||
if (
|
|
||||||
self.isDefined(self.pred_contrib_col)
|
|
||||||
and self.getOrDefault(self.pred_contrib_col) != ""
|
|
||||||
):
|
|
||||||
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
|
|
||||||
|
|
||||||
single_pred = True
|
_, schema = self._out_schema()
|
||||||
schema = "double"
|
|
||||||
if pred_contrib_col_name:
|
|
||||||
single_pred = False
|
|
||||||
schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"
|
|
||||||
|
|
||||||
@pandas_udf(schema) # type: ignore
|
@pandas_udf(schema) # type: ignore
|
||||||
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
|
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
|
||||||
@ -1168,48 +1234,14 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
else:
|
else:
|
||||||
base_margin = None
|
base_margin = None
|
||||||
|
|
||||||
data = {}
|
yield predict_func(model, X, base_margin)
|
||||||
preds = model.predict(
|
|
||||||
X,
|
|
||||||
base_margin=base_margin,
|
|
||||||
validate_features=False,
|
|
||||||
**predict_params,
|
|
||||||
)
|
|
||||||
data[pred.prediction] = pd.Series(preds)
|
|
||||||
|
|
||||||
if pred_contrib_col_name:
|
|
||||||
contribs = pred_contribs(model, X, base_margin)
|
|
||||||
data[pred.pred_contrib] = pd.Series(list(contribs))
|
|
||||||
yield pd.DataFrame(data=data)
|
|
||||||
else:
|
|
||||||
yield data[pred.prediction]
|
|
||||||
|
|
||||||
if has_base_margin:
|
if has_base_margin:
|
||||||
pred_col = predict_udf(struct(*features_col, base_margin_col))
|
pred_col = predict_udf(struct(*features_col, base_margin_col))
|
||||||
else:
|
else:
|
||||||
pred_col = predict_udf(struct(*features_col))
|
pred_col = predict_udf(struct(*features_col))
|
||||||
|
|
||||||
prediction_col_name = self.getOrDefault(self.predictionCol)
|
return self._post_transform(dataset, pred_col)
|
||||||
|
|
||||||
if single_pred:
|
|
||||||
dataset = dataset.withColumn(prediction_col_name, pred_col)
|
|
||||||
else:
|
|
||||||
pred_struct_col = "_prediction_struct"
|
|
||||||
dataset = dataset.withColumn(pred_struct_col, pred_col)
|
|
||||||
|
|
||||||
dataset = dataset.withColumn(
|
|
||||||
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
|
|
||||||
)
|
|
||||||
|
|
||||||
if pred_contrib_col_name:
|
|
||||||
dataset = dataset.withColumn(
|
|
||||||
pred_contrib_col_name,
|
|
||||||
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = dataset.drop(pred_struct_col)
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
|
|
||||||
class _ClassificationModel( # pylint: disable=abstract-method
|
class _ClassificationModel( # pylint: disable=abstract-method
|
||||||
@ -1221,22 +1253,21 @@ class _ClassificationModel( # pylint: disable=abstract-method
|
|||||||
.. Note:: This API is experimental.
|
.. Note:: This API is experimental.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _transform(self, dataset: DataFrame) -> DataFrame:
|
def _out_schema(self) -> Tuple[bool, str]:
|
||||||
# pylint: disable=too-many-statements, too-many-locals
|
schema = (
|
||||||
# Save xgb_sklearn_model and predict_params to be local variable
|
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
|
||||||
# to avoid the `self` object to be pickled to remote.
|
f" {pred.probability} array<double>"
|
||||||
xgb_sklearn_model = self._xgb_sklearn_model
|
|
||||||
predict_params = self._gen_predict_params_dict()
|
|
||||||
|
|
||||||
has_base_margin = False
|
|
||||||
if (
|
|
||||||
self.isDefined(self.base_margin_col)
|
|
||||||
and self.getOrDefault(self.base_margin_col) != ""
|
|
||||||
):
|
|
||||||
has_base_margin = True
|
|
||||||
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
|
|
||||||
alias.margin
|
|
||||||
)
|
)
|
||||||
|
if self._get_pred_contrib_col_name() is not None:
|
||||||
|
# We will force setting strict_shape to True when predicting contribs,
|
||||||
|
# So, it will also output 3-D shape result.
|
||||||
|
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
|
||||||
|
|
||||||
|
return False, schema
|
||||||
|
|
||||||
|
def _get_predict_func(self) -> Callable:
|
||||||
|
predict_params = self._gen_predict_params_dict()
|
||||||
|
pred_contrib_col_name = self._get_pred_contrib_col_name()
|
||||||
|
|
||||||
def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
if margins.ndim == 1:
|
if margins.ndim == 1:
|
||||||
@ -1251,45 +1282,9 @@ class _ClassificationModel( # pylint: disable=abstract-method
|
|||||||
class_probs = softmax(raw_preds, axis=1)
|
class_probs = softmax(raw_preds, axis=1)
|
||||||
return raw_preds, class_probs
|
return raw_preds, class_probs
|
||||||
|
|
||||||
features_col, feature_col_names = self._get_feature_col(dataset)
|
def _predict(
|
||||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
model: XGBModel, X: ArrayLike, base_margin: Optional[np.ndarray]
|
||||||
|
) -> Union[pd.DataFrame, pd.Series]:
|
||||||
pred_contrib_col_name = None
|
|
||||||
if (
|
|
||||||
self.isDefined(self.pred_contrib_col)
|
|
||||||
and self.getOrDefault(self.pred_contrib_col) != ""
|
|
||||||
):
|
|
||||||
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
|
|
||||||
|
|
||||||
schema = (
|
|
||||||
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
|
|
||||||
f" {pred.probability} array<double>"
|
|
||||||
)
|
|
||||||
if pred_contrib_col_name:
|
|
||||||
# We will force setting strict_shape to True when predicting contribs,
|
|
||||||
# So, it will also output 3-D shape result.
|
|
||||||
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
|
|
||||||
|
|
||||||
@pandas_udf(schema) # type: ignore
|
|
||||||
def predict_udf(
|
|
||||||
iterator: Iterator[Tuple[pd.Series, ...]]
|
|
||||||
) -> Iterator[pd.DataFrame]:
|
|
||||||
assert xgb_sklearn_model is not None
|
|
||||||
model = xgb_sklearn_model
|
|
||||||
for data in iterator:
|
|
||||||
if enable_sparse_data_optim:
|
|
||||||
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
|
|
||||||
else:
|
|
||||||
if feature_col_names is not None:
|
|
||||||
X = data[feature_col_names] # type: ignore
|
|
||||||
else:
|
|
||||||
X = stack_series(data[alias.data])
|
|
||||||
|
|
||||||
if has_base_margin:
|
|
||||||
base_margin = stack_series(data[alias.margin])
|
|
||||||
else:
|
|
||||||
base_margin = None
|
|
||||||
|
|
||||||
margins = model.predict(
|
margins = model.predict(
|
||||||
X,
|
X,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
@ -1308,19 +1303,17 @@ class _ClassificationModel( # pylint: disable=abstract-method
|
|||||||
pred.probability: pd.Series(list(class_probs)),
|
pred.probability: pd.Series(list(class_probs)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if pred_contrib_col_name:
|
if pred_contrib_col_name is not None:
|
||||||
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
|
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
|
||||||
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
|
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
|
||||||
|
|
||||||
yield pd.DataFrame(data=result)
|
return pd.DataFrame(data=result)
|
||||||
|
|
||||||
if has_base_margin:
|
return _predict
|
||||||
pred_struct = predict_udf(struct(*features_col, base_margin_col))
|
|
||||||
else:
|
|
||||||
pred_struct = predict_udf(struct(*features_col))
|
|
||||||
|
|
||||||
|
def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
|
||||||
pred_struct_col = "_prediction_struct"
|
pred_struct_col = "_prediction_struct"
|
||||||
dataset = dataset.withColumn(pred_struct_col, pred_struct)
|
dataset = dataset.withColumn(pred_struct_col, pred_col)
|
||||||
|
|
||||||
raw_prediction_col_name = self.getOrDefault(self.rawPredictionCol)
|
raw_prediction_col_name = self.getOrDefault(self.rawPredictionCol)
|
||||||
if raw_prediction_col_name:
|
if raw_prediction_col_name:
|
||||||
@ -1342,7 +1335,8 @@ class _ClassificationModel( # pylint: disable=abstract-method
|
|||||||
array_to_vector(getattr(col(pred_struct_col), pred.probability)),
|
array_to_vector(getattr(col(pred_struct_col), pred.probability)),
|
||||||
)
|
)
|
||||||
|
|
||||||
if pred_contrib_col_name:
|
pred_contrib_col_name = self._get_pred_contrib_col_name()
|
||||||
|
if pred_contrib_col_name is not None:
|
||||||
dataset = dataset.withColumn(
|
dataset = dataset.withColumn(
|
||||||
pred_contrib_col_name,
|
pred_contrib_col_name,
|
||||||
getattr(col(pred_struct_col), pred.pred_contrib),
|
getattr(col(pred_struct_col), pred.pred_contrib),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user