[pyspark] Cleanup data processing. (#8088)

- Use numpy stack for handling list of arrays.
- Reuse concat function from dask.
- Prepare for `QuantileDMatrix`.
- Remove unused code.
- Use iterator for prediction to avoid re-initializing the XGBoost model for every batch.
parent 3970e4e6bb
commit 546de5efd2
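The first bullet refers to the new `stack_series` helper added in this commit. A minimal sketch of the idea (not part of the diff), using only numpy and pandas:

import numpy as np
import pandas as pd

# Each row of the "values" column holds one feature array; np.stack turns the
# object-dtype series into a dense 2-D matrix in a single call instead of
# building an intermediate Python list.
part = pd.DataFrame({"values": [np.array([1.0, 2.0]), np.array([3.0, 4.0])]})
X = np.stack(part["values"].to_numpy(copy=False))
assert X.shape == (2, 2)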
@@ -1,13 +1,14 @@
-# coding: utf-8
 # pylint: disable= invalid-name, unused-import
 """For compatibility and optional dependencies."""
-from typing import Any, Type, Dict, Optional, List
+from typing import Any, Type, Dict, Optional, List, Sequence, cast
 import sys
 import types
 import importlib.util
 import logging
 import numpy as np

+from ._typing import _T

 assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'

@@ -16,7 +17,7 @@ def py_str(x: bytes) -> str:
     return x.decode('utf-8')  # type: ignore


-def lazy_isinstance(instance: Type[object], module: str, name: str) -> bool:
+def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
     """Use string representation to identify a type."""

     # Notice, we use .__class__ as opposed to type() in order
@@ -104,11 +105,42 @@ class XGBoostLabelEncoder(LabelEncoder):
 try:
     import scipy.sparse as scipy_sparse
     from scipy.sparse import csr_matrix as scipy_csr
-    SCIPY_INSTALLED = True
 except ImportError:
     scipy_sparse = False
     scipy_csr = object
-    SCIPY_INSTALLED = False


+def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
+    """Concatenate row-wise."""
+    if isinstance(value[0], np.ndarray):
+        value_arr = cast(Sequence[np.ndarray], value)
+        return np.concatenate(value_arr, axis=0)
+    if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
+        return scipy_sparse.vstack(value, format="csr")
+    if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
+        return scipy_sparse.vstack(value, format="csc")
+    if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
+        # other sparse format will be converted to CSR.
+        return scipy_sparse.vstack(value, format="csr")
+    if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
+        return pandas_concat(value, axis=0)
+    if lazy_isinstance(value[0], "cudf.core.dataframe", "DataFrame") or lazy_isinstance(
+        value[0], "cudf.core.series", "Series"
+    ):
+        from cudf import concat as CUDF_concat  # pylint: disable=import-error
+
+        return CUDF_concat(value, axis=0)
+    if lazy_isinstance(value[0], "cupy._core.core", "ndarray"):
+        import cupy  # pylint: disable=import-error
+
+        # pylint: disable=c-extension-no-member,no-member
+        d = cupy.cuda.runtime.getDevice()
+        for v in value:
+            arr = cast(cupy.ndarray, v)
+            d_v = arr.device.id
+            assert d_v == d, "Concatenating arrays on different devices."
+        return cupy.concatenate(value, axis=0)
+    raise TypeError("Unknown type.")
+
+
 # Modified from tensorflow with added caching. There's a `LazyLoader` in
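A usage sketch for the `concat` helper introduced above. It dispatches on the type of the first element, so one call covers numpy arrays and pandas frames; the cuDF/cupy branches need a GPU and are not shown here.

import numpy as np
import pandas as pd
from xgboost.compat import concat

parts = [np.arange(4).reshape(2, 2), np.arange(4, 8).reshape(2, 2)]
assert concat(parts).shape == (4, 2)   # numpy branch

frames = [pd.DataFrame({"a": [1, 2]}), pd.DataFrame({"a": [3, 4]})]
assert concat(frames).shape == (4, 1)  # pandas branch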
@@ -58,17 +58,9 @@ from typing import (
 import numpy

 from . import config, rabit
-from ._typing import FeatureNames, FeatureTypes
+from ._typing import _T, FeatureNames, FeatureTypes
 from .callback import TrainingCallback
-from .compat import (
-    PANDAS_INSTALLED,
-    DataFrame,
-    LazyLoader,
-    Series,
-    lazy_isinstance,
-    pandas_concat,
-    scipy_sparse,
-)
+from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
 from .core import (
     Booster,
     DataIter,
@@ -234,34 +226,11 @@ class RabitContext(rabit.RabitContext):
 )


-def concat(value: Any) -> Any:  # pylint: disable=too-many-return-statements
-    """To be replaced with dask builtin."""
-    if isinstance(value[0], numpy.ndarray):
-        return numpy.concatenate(value, axis=0)
-    if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
-        return scipy_sparse.vstack(value, format="csr")
-    if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
-        return scipy_sparse.vstack(value, format="csc")
-    if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
-        # other sparse format will be converted to CSR.
-        return scipy_sparse.vstack(value, format="csr")
-    if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
-        return pandas_concat(value, axis=0)
-    if lazy_isinstance(value[0], "cudf.core.dataframe", "DataFrame") or lazy_isinstance(
-        value[0], "cudf.core.series", "Series"
-    ):
-        from cudf import concat as CUDF_concat  # pylint: disable=import-error
-
-        return CUDF_concat(value, axis=0)
-    if lazy_isinstance(value[0], "cupy._core.core", "ndarray"):
-        import cupy
-
-        # pylint: disable=c-extension-no-member,no-member
-        d = cupy.cuda.runtime.getDevice()
-        for v in value:
-            d_v = v.device.id
-            assert d_v == d, "Concatenating arrays on different devices."
-        return cupy.concatenate(value, axis=0)
-    return dd.multi.concat(list(value), axis=0)
+def dconcat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
+    """Concatenate sequence of partitions."""
+    try:
+        return concat(value)
+    except TypeError:
+        return dd.multi.concat(list(value), axis=0)

@@ -797,7 +766,7 @@ def _create_dmatrix(
     def concat_or_none(data: Sequence[Optional[T]]) -> Optional[T]:
         if any(part is None for part in data):
             return None
-        return concat(data)
+        return dconcat(data)

     unzipped_dict = _get_worker_parts(list_of_parts)
     concated_dict: Dict[str, Any] = {}
@@ -17,7 +17,9 @@ from typing import (
     Type,
     cast,
 )

 import numpy as np
+from scipy.special import softmax
+
 from .core import Booster, DMatrix, XGBoostError
 from .core import _deprecate_positional_args, _convert_ntree_limit
@@ -1540,17 +1542,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         """
         # custom obj: Do nothing as we don't know what to do.
         # softprob: Do nothing, output is proba.
-        # softmax: Unsupported by predict_proba()
+        # softmax: Use softmax from scipy
         # binary:logistic: Expand the prob vector into 2-class matrix after predict.
         # binary:logitraw: Unsupported by predict_proba()
         if self.objective == "multi:softmax":
-            # We need to run a Python implementation of softmax for it. Just ask user to
-            # use softprob since XGBoost's implementation has mitigation for floating
-            # point overflow. No need to reinvent the wheel.
-            raise ValueError(
-                "multi:softmax doesn't support `predict_proba`. "
-                "Switch to `multi:softproba` instead"
+            raw_predt = super().predict(
+                X=X,
+                ntree_limit=ntree_limit,
+                validate_features=validate_features,
+                base_margin=base_margin,
+                iteration_range=iteration_range,
+                output_margin=True
             )
+            class_prob = softmax(raw_predt, axis=1)
+            return class_prob
         class_probs = super().predict(
             X=X,
             ntree_limit=ntree_limit,
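The `multi:softmax` branch above now converts raw margins into probabilities with `scipy.special.softmax`. A small numeric sketch of that step (the values are made up):

import numpy as np
from scipy.special import softmax

# Two samples, three classes: softmax over axis=1 maps raw margins to
# probabilities that sum to one per row.
raw_predt = np.array([[0.5, 1.5, 0.1], [2.0, 0.3, 0.3]])
class_prob = softmax(raw_predt, axis=1)
np.testing.assert_allclose(class_prob.sum(axis=1), 1.0)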
@@ -2,6 +2,8 @@
 """Xgboost pyspark integration submodule for core code."""
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods
+from typing import Iterator, Tuple
+
 import numpy as np
 import pandas as pd
 from pyspark.ml import Estimator, Model
@@ -34,7 +36,7 @@ from xgboost.training import train as worker_train
 import xgboost
 from xgboost import XGBClassifier, XGBRegressor

-from .data import _convert_partition_data_to_dmatrix
+from .data import alias, create_dmatrix_from_partitions, stack_series
 from .model import (
     SparkXGBModelReader,
     SparkXGBModelWriter,
@@ -324,10 +326,10 @@ def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
             raise ValueError(
                 "If feature column is array type, its elements must be number type."
             )
-        features_array_col = features_col.cast(ArrayType(FloatType())).alias("values")
+        features_array_col = features_col.cast(ArrayType(FloatType())).alias(alias.data)
     elif isinstance(features_col_datatype, VectorUDT):
         features_array_col = vector_to_array(features_col, dtype="float32").alias(
-            "values"
+            alias.data
         )
     else:
         raise ValueError(
@@ -462,7 +464,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         params.update(fit_params)
         params["verbose_eval"] = verbose_eval
         classification = self._xgb_cls() == XGBClassifier
-        num_classes = int(dataset.select(countDistinct("label")).collect()[0][0])
+        num_classes = int(dataset.select(countDistinct(alias.label)).collect()[0][0])
         if classification and num_classes == 2:
             params["objective"] = "binary:logistic"
         elif classification and num_classes > 2:
@@ -493,37 +495,30 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def _fit(self, dataset):
         # pylint: disable=too-many-statements, too-many-locals
         self._validate_params()
-        label_col = col(self.getOrDefault(self.labelCol)).alias("label")
+        label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)

         features_array_col = _validate_and_convert_feature_col_as_array_col(
             dataset, self.getOrDefault(self.featuresCol)
         )
         select_cols = [features_array_col, label_col]

-        has_weight = False
-        has_validation = False
-        has_base_margin = False
-
         if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
-            has_weight = True
-            select_cols.append(col(self.getOrDefault(self.weightCol)).alias("weight"))
+            select_cols.append(
+                col(self.getOrDefault(self.weightCol)).alias(alias.weight)
+            )

         if self.isDefined(self.validationIndicatorCol) and self.getOrDefault(
             self.validationIndicatorCol
         ):
-            has_validation = True
             select_cols.append(
-                col(self.getOrDefault(self.validationIndicatorCol)).alias(
-                    "validationIndicator"
-                )
+                col(self.getOrDefault(self.validationIndicatorCol)).alias(alias.valid)
             )

         if self.isDefined(self.base_margin_col) and self.getOrDefault(
             self.base_margin_col
         ):
-            has_base_margin = True
             select_cols.append(
-                col(self.getOrDefault(self.base_margin_col)).alias("baseMargin")
+                col(self.getOrDefault(self.base_margin_col)).alias(alias.margin)
             )

         dataset = dataset.select(*select_cols)
@@ -551,6 +546,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         cpu_per_task = int(
             _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
         )
+
         dmatrix_kwargs = {
             "nthread": cpu_per_task,
             "feature_types": self.getOrDefault(self.feature_types),
@@ -564,9 +560,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         is_local = _is_local(_get_spark_session().sparkContext)

         def _train_booster(pandas_df_iter):
-            """
-            Takes in an RDD partition and outputs a booster for that partition after going through
-            the Rabit Ring protocol
+            """Takes in an RDD partition and outputs a booster for that partition after
+            going through the Rabit Ring protocol
             """
             from pyspark import BarrierTaskContext

@@ -586,25 +582,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             _rabit_args = _get_args_from_message_list(messages)
             evals_result = {}
             with RabitContext(_rabit_args, context):
-                dtrain, dval = None, []
-                if has_validation:
-                    dtrain, dval = _convert_partition_data_to_dmatrix(
-                        pandas_df_iter,
-                        has_weight,
-                        has_validation,
-                        has_base_margin,
-                        dmatrix_kwargs=dmatrix_kwargs,
-                    )
-                    # TODO: Question: do we need to add dtrain to dval list ?
-                    dval = [(dtrain, "training"), (dval, "validation")]
+                dtrain, dvalid = create_dmatrix_from_partitions(
+                    pandas_df_iter,
+                    None,
+                    dmatrix_kwargs,
+                )
+                if dvalid is not None:
+                    dval = [(dtrain, "training"), (dvalid, "validation")]
                 else:
-                    dtrain = _convert_partition_data_to_dmatrix(
-                        pandas_df_iter,
-                        has_weight,
-                        has_validation,
-                        has_base_margin,
-                        dmatrix_kwargs=dmatrix_kwargs,
-                    )
+                    dval = None

                 booster = worker_train(
                     params=booster_params,
@@ -619,13 +605,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 yield pd.DataFrame(
                     data={
                         "config": [booster.save_config()],
-                        "booster": [booster.save_raw("json").decode("utf-8")]
+                        "booster": [booster.save_raw("json").decode("utf-8")],
                     }
                 )

         def _run_job():
             ret = (
-                dataset.mapInPandas(_train_booster, schema="config string, booster string")
+                dataset.mapInPandas(
+                    _train_booster, schema="config string, booster string"
+                )
                 .rdd.barrier()
                 .mapPartitions(lambda x: x)
                 .collect()[0]
@@ -635,8 +623,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         (config, booster) = _run_job()

         result_xgb_model = self._convert_to_sklearn_model(
-            bytearray(booster, "utf-8"),
-            config
+            bytearray(booster, "utf-8"), config
         )
         return self._copyValues(self._create_pyspark_model(result_xgb_model))

@@ -675,12 +662,6 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
        * 'total_gain': the total gain across all splits the feature is used in.
        * 'total_cover': the total coverage across all splits the feature is used in.

-        .. note:: Feature importance is defined only for tree boosters
-
-            Feature importance is only defined when the decision tree model is chosen as base
-            learner (`booster=gbtree`). It is not defined for other base learner types, such
-            as linear learners (`booster=gblinear`).
-
         Parameters
         ----------
         importance_type: str, default 'weight'
@@ -728,21 +709,26 @@ class SparkXGBRegressorModel(_SparkXGBModel):
         ):
             has_base_margin = True
             base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
-                "baseMargin"
+                alias.margin
             )

         @pandas_udf("double")
-        def predict_udf(input_data: pd.DataFrame) -> pd.Series:
-            X = np.array(input_data["values"].tolist())
-            if has_base_margin:
-                base_margin = input_data["baseMargin"].to_numpy()
-            else:
-                base_margin = None
-
-            preds = xgb_sklearn_model.predict(
-                X, base_margin=base_margin, validate_features=False, **predict_params
-            )
-            return pd.Series(preds)
+        def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
+            model = xgb_sklearn_model
+            for data in iterator:
+                X = stack_series(data[alias.data])
+                if has_base_margin:
+                    base_margin = data[alias.margin].to_numpy()
+                else:
+                    base_margin = None
+
+                preds = model.predict(
+                    X,
+                    base_margin=base_margin,
+                    validate_features=False,
+                    **predict_params,
+                )
+                yield pd.Series(preds)

         features_col = _validate_and_convert_feature_col_as_array_col(
             dataset, self.getOrDefault(self.featuresCol)
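The regressor's UDF above switches from a one-DataFrame-per-call pandas UDF to the iterator form, so per-task setup runs once rather than per batch. A standalone sketch of the pattern (the `load_model` helper is hypothetical, not part of xgboost or pyspark):

from typing import Iterator

import numpy as np
import pandas as pd
from pyspark.sql.functions import pandas_udf


@pandas_udf("double")
def predict_batches(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    model = load_model()  # hypothetical: expensive setup happens once per task
    for series in batches:
        # Stack the array-valued column into a 2-D matrix, predict, yield one
        # result series per input batch.
        X = np.stack(series.to_numpy(copy=False))
        yield pd.Series(model.predict(X))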
@@ -781,26 +767,10 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictio
         ):
             has_base_margin = True
             base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
-                "baseMargin"
+                alias.margin
             )

-        @pandas_udf(
-            "rawPrediction array<double>, prediction double, probability array<double>"
-        )
-        def predict_udf(input_data: pd.DataFrame) -> pd.DataFrame:
-            X = np.array(input_data["values"].tolist())
-            if has_base_margin:
-                base_margin = input_data["baseMargin"].to_numpy()
-            else:
-                base_margin = None
-
-            margins = xgb_sklearn_model.predict(
-                X,
-                base_margin=base_margin,
-                output_margin=True,
-                validate_features=False,
-                **predict_params,
-            )
+        def transform_margin(margins: np.ndarray):
             if margins.ndim == 1:
                 # binomial case
                 classone_probs = expit(margins)
@@ -811,15 +781,39 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictio
                 # multinomial case
                 raw_preds = margins
                 class_probs = softmax(raw_preds, axis=1)
+            return raw_preds, class_probs
+
+        @pandas_udf(
+            "rawPrediction array<double>, prediction double, probability array<double>"
+        )
+        def predict_udf(
+            iterator: Iterator[Tuple[pd.Series, ...]]
+        ) -> Iterator[pd.DataFrame]:
+            model = xgb_sklearn_model
+            for data in iterator:
+                X = stack_series(data[alias.data])
+                if has_base_margin:
+                    base_margin = stack_series(data[alias.margin])
+                else:
+                    base_margin = None
+
+                margins = model.predict(
+                    X,
+                    base_margin=base_margin,
+                    output_margin=True,
+                    validate_features=False,
+                    **predict_params,
+                )
+                raw_preds, class_probs = transform_margin(margins)

                 # It seems that they use argmax of class probs,
                 # not of margin to get the prediction (Note: scala implementation)
                 preds = np.argmax(class_probs, axis=1)
-            return pd.DataFrame(
-                data={
-                    "rawPrediction": pd.Series(raw_preds.tolist()),
-                    "prediction": pd.Series(preds),
-                    "probability": pd.Series(class_probs.tolist()),
-                }
-            )
+                yield pd.DataFrame(
+                    data={
+                        "rawPrediction": pd.Series(list(raw_preds)),
+                        "prediction": pd.Series(preds),
+                        "probability": pd.Series(list(class_probs)),
+                    }
+                )

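For context, the `transform_margin` helper factored out above turns raw margins into raw predictions plus class probabilities. A rough sketch of both branches, assuming the surrounding (unchanged) binomial code mirrors this logic; it is an illustration, not the diff itself:

import numpy as np
from scipy.special import expit, softmax


def transform_margin_sketch(margins: np.ndarray):
    if margins.ndim == 1:
        # binomial case: one margin column, expanded to a two-class matrix
        classone_probs = expit(margins)
        class_probs = np.vstack((1.0 - classone_probs, classone_probs)).transpose()
        raw_preds = np.vstack((-margins, margins)).transpose()
    else:
        # multinomial case: softmax over the margin matrix, row-wise
        raw_preds = margins
        class_probs = softmax(raw_preds, axis=1)
    return raw_preds, class_probs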
@@ -1,194 +1,181 @@
-# type: ignore
-"""Xgboost pyspark integration submodule for data related functions."""
-# pylint: disable=too-many-arguments
-from typing import Iterator
-
-import numpy as np
-import pandas as pd
-
-from xgboost import DMatrix
-
-
-def _prepare_train_val_data(
-    data_iterator, has_weight, has_validation, has_fit_base_margin
-):
-    def gen_data_pdf():
-        for pdf in data_iterator:
-            yield pdf
-
-    return _process_data_iter(
-        gen_data_pdf(),
-        train=True,
-        has_weight=has_weight,
-        has_validation=has_validation,
-        has_fit_base_margin=has_fit_base_margin,
-        has_predict_base_margin=False,
-    )
-
-
-def _check_feature_dims(num_dims, expected_dims):
-    """
-    Check all feature vectors has the same dimension
-    """
-    if expected_dims is None:
-        return num_dims
-    if num_dims != expected_dims:
-        raise ValueError(
-            f"Rows contain different feature dimensions: Expecting {expected_dims}, got {num_dims}."
-        )
-    return expected_dims
-
-
-def _row_tuple_list_to_feature_matrix_y_w(
-    data_iterator,
-    train,
-    has_weight,
-    has_fit_base_margin,
-    has_predict_base_margin,
-    has_validation: bool = False,
-):
-    """
-    Construct a feature matrix in ndarray format, label array y and weight array w
-    from the row_tuple_list.
-    If train == False, y and w will be None.
-    If has_weight == False, w will be None.
-    If has_base_margin == False, b_m will be None.
-    Note: the row_tuple_list will be cleared during
-    executing for reducing peak memory consumption
-    """
-    # pylint: disable=too-many-locals
-    expected_feature_dims = None
-    label_list, weight_list, base_margin_list = [], [], []
-    label_val_list, weight_val_list, base_margin_val_list = [], [], []
-    values_list, values_val_list = [], []
-
-    # Process rows
-    for pdf in data_iterator:
-        if len(pdf) == 0:
-            continue
-        if train and has_validation:
-            pdf_val = pdf.loc[pdf["validationIndicator"], :]
-            pdf = pdf.loc[~pdf["validationIndicator"], :]
-
-        num_feature_dims = len(pdf["values"].values[0])
-
-        expected_feature_dims = _check_feature_dims(
-            num_feature_dims, expected_feature_dims
-        )
-
-        # Note: each element in `pdf["values"]` is an numpy array.
-        values_list.append(pdf["values"].to_list())
-        if train:
-            label_list.append(pdf["label"].to_numpy())
-            if has_weight:
-                weight_list.append(pdf["weight"].to_numpy())
-            if has_fit_base_margin or has_predict_base_margin:
-                base_margin_list.append(pdf["baseMargin"].to_numpy())
-        if has_validation:
-            values_val_list.append(pdf_val["values"].to_list())
-            if train:
-                label_val_list.append(pdf_val["label"].to_numpy())
-                if has_weight:
-                    weight_val_list.append(pdf_val["weight"].to_numpy())
-                if has_fit_base_margin or has_predict_base_margin:
-                    base_margin_val_list.append(pdf_val["baseMargin"].to_numpy())
-
-    # Construct feature_matrix
-    if expected_feature_dims is None:
-        return [], [], [], []
-
-    # Construct feature_matrix, y and w
-    feature_matrix = np.concatenate(values_list)
-    y = np.concatenate(label_list) if train else None
-    w = np.concatenate(weight_list) if has_weight else None
-    b_m = (
-        np.concatenate(base_margin_list)
-        if (has_fit_base_margin or has_predict_base_margin)
-        else None
-    )
-    if has_validation:
-        feature_matrix_val = np.concatenate(values_val_list)
-        y_val = np.concatenate(label_val_list) if train else None
-        w_val = np.concatenate(weight_val_list) if has_weight else None
-        b_m_val = (
-            np.concatenate(base_margin_val_list)
-            if (has_fit_base_margin or has_predict_base_margin)
-            else None
-        )
-        return feature_matrix, y, w, b_m, feature_matrix_val, y_val, w_val, b_m_val
-    return feature_matrix, y, w, b_m
-
-
-def _process_data_iter(
-    data_iterator: Iterator[pd.DataFrame],
-    train: bool,
-    has_weight: bool,
-    has_validation: bool,
-    has_fit_base_margin: bool = False,
-    has_predict_base_margin: bool = False,
-):
-    """
-    If input is for train and has_validation=True, it will split the train data into train dataset
-    and validation dataset, and return (train_X, train_y, train_w, train_b_m <-
-    train base margin, val_X, val_y, val_w, val_b_m <- validation base margin)
-    otherwise return (X, y, w, b_m <- base margin)
-    """
-    return _row_tuple_list_to_feature_matrix_y_w(
-        data_iterator,
-        train,
-        has_weight,
-        has_fit_base_margin,
-        has_predict_base_margin,
-        has_validation,
-    )
-
-
-def _convert_partition_data_to_dmatrix(
-    partition_data_iter,
-    has_weight,
-    has_validation,
-    has_base_margin,
-    dmatrix_kwargs=None,
-):
-    # pylint: disable=too-many-locals, unbalanced-tuple-unpacking
-    dmatrix_kwargs = dmatrix_kwargs or {}
-    # if we are not using external storage, we use the standard method of parsing data.
-    train_val_data = _prepare_train_val_data(
-        partition_data_iter, has_weight, has_validation, has_base_margin
-    )
-    if has_validation:
-        (
-            train_x,
-            train_y,
-            train_w,
-            train_b_m,
-            val_x,
-            val_y,
-            val_w,
-            val_b_m,
-        ) = train_val_data
-        training_dmatrix = DMatrix(
-            data=train_x,
-            label=train_y,
-            weight=train_w,
-            base_margin=train_b_m,
-            **dmatrix_kwargs,
-        )
-        val_dmatrix = DMatrix(
-            data=val_x,
-            label=val_y,
-            weight=val_w,
-            base_margin=val_b_m,
-            **dmatrix_kwargs,
-        )
-        return training_dmatrix, val_dmatrix
-
-    train_x, train_y, train_w, train_b_m = train_val_data
-    training_dmatrix = DMatrix(
-        data=train_x,
-        label=train_y,
-        weight=train_w,
-        base_margin=train_b_m,
-        **dmatrix_kwargs,
-    )
-    return training_dmatrix
+"""Utilities for processing spark partitions."""
+from collections import defaultdict, namedtuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple
+
+import numpy as np
+import pandas as pd
+from xgboost.compat import concat
+
+from xgboost import DataIter, DeviceQuantileDMatrix, DMatrix
+
+
+def stack_series(series: pd.Series) -> np.ndarray:
+    """Stack a series of arrays."""
+    array = series.to_numpy(copy=False)
+    array = np.stack(array)
+    return array
+
+
+# Global constant for defining column alias shared between estimator and data
+# processing procedures.
+Alias = namedtuple("Alias", ("data", "label", "weight", "margin", "valid"))
+alias = Alias("values", "label", "weight", "baseMargin", "validationIndicator")
+
+
+def concat_or_none(seq: Optional[Sequence[np.ndarray]]) -> Optional[np.ndarray]:
+    """Concatenate the data if it's not None."""
+    if seq:
+        return concat(seq)
+    return None
+
+
+def cache_partitions(
+    iterator: Iterator[pd.DataFrame], append: Callable[[pd.DataFrame, str, bool], None]
+) -> None:
+    """Extract partitions from pyspark iterator. `append` is a user defined function for
+    accepting new partition."""
+
+    def make_blob(part: pd.DataFrame, is_valid: bool) -> None:
+        append(part, alias.data, is_valid)
+        append(part, alias.label, is_valid)
+        append(part, alias.weight, is_valid)
+        append(part, alias.margin, is_valid)
+
+    has_validation: Optional[bool] = None
+
+    for part in iterator:
+        if has_validation is None:
+            has_validation = alias.valid in part.columns
+        if has_validation is True:
+            assert alias.valid in part.columns
+
+        if has_validation:
+            train = part.loc[~part[alias.valid], :]
+            valid = part.loc[part[alias.valid], :]
+        else:
+            train, valid = part, None
+
+        make_blob(train, False)
+        if valid is not None:
+            make_blob(valid, True)
+
+
+class PartIter(DataIter):
+    """Iterator for creating Quantile DMatrix from partitions."""
+
+    def __init__(self, data: Dict[str, List], on_device: bool) -> None:
+        self._iter = 0
+        self._cuda = on_device
+        self._data = data
+
+        super().__init__()
+
+    def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFrame]:
+        if not data:
+            return None
+
+        if self._cuda:
+            import cudf  # pylint: disable=import-error
+
+            return cudf.DataFrame(data[self._iter])
+
+        return data[self._iter]
+
+    def next(self, input_data: Callable) -> int:
+        if self._iter == len(self._data[alias.data]):
+            return 0
+        input_data(
+            data=self._fetch(self._data[alias.data]),
+            label=self._fetch(self._data.get(alias.label, None)),
+            weight=self._fetch(self._data.get(alias.weight, None)),
+            base_margin=self._fetch(self._data.get(alias.margin, None)),
+        )
+        self._iter += 1
+        return 1
+
+    def reset(self) -> None:
+        self._iter = 0
+
+
+def create_dmatrix_from_partitions(
+    iterator: Iterator[pd.DataFrame],
+    feature_cols: Optional[Sequence[str]],
+    kwargs: Dict[str, Any],  # use dict to make sure this parameter is passed.
+) -> Tuple[DMatrix, Optional[DMatrix]]:
+    """Create DMatrix from spark data partitions. This is not particularly efficient as
+    we need to convert the pandas series format to numpy then concatenate all the data.
+
+    Parameters
+    ----------
+    iterator :
+        Pyspark partition iterator.
+    kwargs :
+        Metainfo for DMatrix.
+
+    """
+    train_data: Dict[str, List[np.ndarray]] = defaultdict(list)
+    valid_data: Dict[str, List[np.ndarray]] = defaultdict(list)
+
+    n_features: int = 0
+
+    def append_m(part: pd.DataFrame, name: str, is_valid: bool) -> None:
+        nonlocal n_features
+        if name in part.columns:
+            array = part[name]
+            if name == alias.data:
+                array = stack_series(array)
+                if n_features == 0:
+                    n_features = array.shape[1]
+                assert n_features == array.shape[1]
+
+            if is_valid:
+                valid_data[name].append(array)
+            else:
+                train_data[name].append(array)
+
+    def append_dqm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
+        """Preprocessing for DeviceQuantileDMatrix"""
+        nonlocal n_features
+        if name == alias.data or name in part.columns:
+            if name == alias.data:
+                cname = feature_cols
+            else:
+                cname = name
+
+            array = part[cname]
+            if name == alias.data:
+                if n_features == 0:
+                    n_features = array.shape[1]
+                assert n_features == array.shape[1]
+
+            if is_valid:
+                valid_data[name].append(array)
+            else:
+                train_data[name].append(array)
+
+    def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix:
+        data = concat_or_none(values[alias.data])
+        label = concat_or_none(values.get(alias.label, None))
+        weight = concat_or_none(values.get(alias.weight, None))
+        margin = concat_or_none(values.get(alias.margin, None))
+        return DMatrix(
+            data=data, label=label, weight=weight, base_margin=margin, **kwargs
+        )
+
+    is_dmatrix = feature_cols is None
+    if is_dmatrix:
+        cache_partitions(iterator, append_m)
+        dtrain = make(train_data, kwargs)
+    else:
+        cache_partitions(iterator, append_dqm)
+        it = PartIter(train_data, True)
+        dtrain = DeviceQuantileDMatrix(it, **kwargs)
+
+    dvalid = make(valid_data, kwargs) if len(valid_data) != 0 else None
+
+    assert dtrain.num_col() == n_features
+    if dvalid is not None:
+        assert dvalid.num_col() == dtrain.num_col()
+
+    return dtrain, dvalid
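A usage sketch for `create_dmatrix_from_partitions`, modelled on the new tests added later in this diff: partitions carry an array-valued `alias.data` column plus scalar meta columns, and `feature_cols=None` selects the plain `DMatrix` path.

import numpy as np
import pandas as pd
from xgboost.spark.data import alias, create_dmatrix_from_partitions

rng = np.random.default_rng(0)
# Two toy partitions; neither has a validationIndicator column.
parts = [
    pd.DataFrame(
        {
            alias.data: pd.Series(list(rng.normal(size=(8, 4)))),
            alias.label: rng.integers(0, 2, size=8).astype(np.float64),
        }
    )
    for _ in range(2)
]
dtrain, dvalid = create_dmatrix_from_partitions(iter(parts), None, {})
assert dtrain.num_row() == 16 and dtrain.num_col() == 4
assert dvalid is None  # no validation split in these partitions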
@@ -15,13 +15,11 @@ PROJECT_ROOT = os.path.normpath(os.path.join(CURDIR, os.path.pardir, os.path.par
 def run_formatter(rel_path: str) -> bool:
     path = os.path.join(PROJECT_ROOT, rel_path)
     isort_ret = subprocess.run(["isort", "--check", "--profile=black", path]).returncode
-    black_ret = subprocess.run(
-        ["black", "--check", "./python-package/xgboost/dask.py"]
-    ).returncode
+    black_ret = subprocess.run(["black", "--check", rel_path]).returncode
     if isort_ret != 0 or black_ret != 0:
         msg = (
             "Please run the following command on your machine to address the format"
-            f" errors:\n isort --check --profile=black {rel_path}\n black {rel_path}\n"
+            f" errors:\n isort --profile=black {rel_path}\n black {rel_path}\n"
         )
         print(msg, file=sys.stdout)
         return False
@@ -38,7 +36,8 @@ def run_mypy(rel_path: str) -> bool:


 class PyLint:
-    """A helper for running pylint, mostly copied from dmlc-core/scripts. """
+    """A helper for running pylint, mostly copied from dmlc-core/scripts."""

     def __init__(self) -> None:
         self.pypackage_root = os.path.join(PROJECT_ROOT, "python-package/")
         self.pylint_cats = set(["error", "warning", "convention", "refactor"])
@@ -115,6 +114,8 @@ if __name__ == "__main__":
         for path in [
             "python-package/xgboost/dask.py",
             "python-package/xgboost/spark",
+            "tests/python/test_spark/test_data.py",
+            "tests/python-gpu/test_spark_with_gpu/test_data.py",
             "tests/ci_build/lint_python.py",
         ]
     ):
@@ -128,8 +129,10 @@ if __name__ == "__main__":
             "demo/guide-python/external_memory.py",
             "demo/guide-python/cat_in_the_dat.py",
             "tests/python/test_data_iterator.py",
+            "tests/python/test_spark/test_data.py",
             "tests/python-gpu/test_gpu_with_dask.py",
             "tests/python-gpu/test_gpu_data_iterator.py",
+            "tests/python-gpu/test_spark_with_gpu/test_data.py",
             "tests/ci_build/lint_python.py",
         ]
     ):
tests/python-gpu/test_spark_with_gpu/test_data.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+import sys
+from typing import List
+
+import numpy as np
+import pandas as pd
+import pytest
+
+sys.path.append("tests/python")
+
+import testing as tm
+
+if tm.no_spark()["condition"]:
+    pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
+if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
+    pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
+
+
+from test_spark.test_data import run_dmatrix_ctor
+
+
+@pytest.mark.skipif(**tm.no_cudf())
+def test_qdm_ctor() -> None:
+    run_dmatrix_ctor(True)
@@ -1,11 +1,9 @@
 import sys
-import tempfile
-import shutil
+from typing import List

-import pytest
 import numpy as np
 import pandas as pd
+import pytest
 import testing as tm

 if tm.no_spark()["condition"]:
@@ -13,156 +11,90 @@ if tm.no_spark()["condition"]:
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)

-from xgboost.spark.data import (
-    _row_tuple_list_to_feature_matrix_y_w,
-    _convert_partition_data_to_dmatrix,
-)
-
-from xgboost import DMatrix, XGBClassifier
-from xgboost.training import train as worker_train
-from .utils import SparkTestCase
-import logging
-
-logging.getLogger("py4j").setLevel(logging.INFO)
-
-
-class DataTest(SparkTestCase):
-    def test_sparse_dense_vector(self):
-        def row_tup_iter(data):
-            pdf = pd.DataFrame(data)
-            yield pdf
-
-        expected_ndarray = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
-        data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]}
-        feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
-            list(row_tup_iter(data)),
-            train=False,
-            has_weight=False,
-            has_fit_base_margin=False,
-            has_predict_base_margin=False,
-        )
-        self.assertIsNone(y)
-        self.assertIsNone(w)
-        self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
-
-        data["label"] = [1, 0]
-        feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
-            row_tup_iter(data),
-            train=True,
-            has_weight=False,
-            has_fit_base_margin=False,
-            has_predict_base_margin=False,
-        )
-        self.assertIsNone(w)
-        self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
-        self.assertTrue(np.array_equal(y, np.array(data["label"])))
-
-        data["weight"] = [0.2, 0.8]
-        feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w(
-            list(row_tup_iter(data)),
-            train=True,
-            has_weight=True,
-            has_fit_base_margin=False,
-            has_predict_base_margin=False,
-        )
-        self.assertTrue(np.allclose(feature_matrix, expected_ndarray))
-        self.assertTrue(np.array_equal(y, np.array(data["label"])))
-        self.assertTrue(np.array_equal(w, np.array(data["weight"])))
-
-    def test_dmatrix_creator(self):
-
-        # This function acts as a pseudo-itertools.chain()
-        def row_tup_iter(data):
-            pdf = pd.DataFrame(data)
-            yield pdf
-
-        # Standard testing DMatrix creation
-        expected_features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100)
-        expected_labels = np.array([1, 0] * 100)
-        expected_dmatrix = DMatrix(data=expected_features, label=expected_labels)
-
-        data = {
-            "values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100,
-            "label": [1, 0] * 100,
-        }
-        output_dmatrix = _convert_partition_data_to_dmatrix(
-            [pd.DataFrame(data)],
-            has_weight=False,
-            has_validation=False,
-            has_base_margin=False,
-        )
-        # You can't compare DMatrix outputs, so the only way is to predict on the two seperate DMatrices using
-        # the same classifier and making sure the outputs are equal
-        model = XGBClassifier()
-        model.fit(expected_features, expected_labels)
-        expected_preds = model.get_booster().predict(expected_dmatrix)
-        output_preds = model.get_booster().predict(output_dmatrix)
-        self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3))
-
-        # DMatrix creation with weights
-        expected_weight = np.array([0.2, 0.8] * 100)
-        expected_dmatrix = DMatrix(
-            data=expected_features, label=expected_labels, weight=expected_weight
-        )
-
-        data["weight"] = [0.2, 0.8] * 100
-        output_dmatrix = _convert_partition_data_to_dmatrix(
-            [pd.DataFrame(data)],
-            has_weight=True,
-            has_validation=False,
-            has_base_margin=False,
-        )
-
-        model.fit(expected_features, expected_labels, sample_weight=expected_weight)
-        expected_preds = model.get_booster().predict(expected_dmatrix)
-        output_preds = model.get_booster().predict(output_dmatrix)
-        self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3))
-
-    def test_external_storage(self):
-        # Instantiating base data (features, labels)
-        features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100)
-        labels = np.array([1, 0] * 100)
-        normal_dmatrix = DMatrix(features, labels)
-        test_dmatrix = DMatrix(features)
-
-        data = {
-            "values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100,
-            "label": [1, 0] * 100,
-        }
-
-        # Creating the dmatrix based on storage
-        temporary_path = tempfile.mkdtemp()
-        storage_dmatrix = _convert_partition_data_to_dmatrix(
-            [pd.DataFrame(data)],
-            has_weight=False,
-            has_validation=False,
-            has_base_margin=False,
-        )
-
-        # Testing without weights
-        normal_booster = worker_train({}, normal_dmatrix)
-        storage_booster = worker_train({}, storage_dmatrix)
-        normal_preds = normal_booster.predict(test_dmatrix)
-        storage_preds = storage_booster.predict(test_dmatrix)
-        self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3))
-        shutil.rmtree(temporary_path)
-
-        # Testing weights
-        weights = np.array([0.2, 0.8] * 100)
-        normal_dmatrix = DMatrix(data=features, label=labels, weight=weights)
-        data["weight"] = [0.2, 0.8] * 100
-
-        temporary_path = tempfile.mkdtemp()
-        storage_dmatrix = _convert_partition_data_to_dmatrix(
-            [pd.DataFrame(data)],
-            has_weight=True,
-            has_validation=False,
-            has_base_margin=False,
-        )
-
-        normal_booster = worker_train({}, normal_dmatrix)
-        storage_booster = worker_train({}, storage_dmatrix)
-        normal_preds = normal_booster.predict(test_dmatrix)
-        storage_preds = storage_booster.predict(test_dmatrix)
-        self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3))
-        shutil.rmtree(temporary_path)
+from xgboost.spark.data import alias, create_dmatrix_from_partitions, stack_series
+
+
+def test_stack() -> None:
+    a = pd.DataFrame({"a": [[1, 2], [3, 4]]})
+    b = stack_series(a["a"])
+    assert b.shape == (2, 2)
+
+    a = pd.DataFrame({"a": [[1], [3]]})
+    b = stack_series(a["a"])
+    assert b.shape == (2, 1)
+
+    a = pd.DataFrame({"a": [np.array([1, 2]), np.array([3, 4])]})
+    b = stack_series(a["a"])
+    assert b.shape == (2, 2)
+
+    a = pd.DataFrame({"a": [np.array([1]), np.array([3])]})
+    b = stack_series(a["a"])
+    assert b.shape == (2, 1)
+
+
+def run_dmatrix_ctor(is_dqm: bool) -> None:
+    rng = np.random.default_rng(0)
+    dfs: List[pd.DataFrame] = []
+    n_features = 16
+    n_samples_per_batch = 16
+    n_batches = 10
+    feature_types = ["float"] * n_features
+
+    for i in range(n_batches):
+        X = rng.normal(loc=0, size=256).reshape(n_samples_per_batch, n_features)
+        y = rng.normal(loc=0, size=n_samples_per_batch)
+        m = rng.normal(loc=0, size=n_samples_per_batch)
+        w = rng.normal(loc=0.5, scale=0.5, size=n_samples_per_batch)
+        w -= w.min()
+
+        valid = rng.binomial(n=1, p=0.5, size=16).astype(np.bool_)
+
+        df = pd.DataFrame(
+            {alias.label: y, alias.margin: m, alias.weight: w, alias.valid: valid}
+        )
+        if is_dqm:
+            for j in range(X.shape[1]):
+                df[f"feat-{j}"] = pd.Series(X[:, j])
+        else:
+            df[alias.data] = pd.Series(list(X))
+        dfs.append(df)
+
+    kwargs = {"feature_types": feature_types}
+    if is_dqm:
+        cols = [f"feat-{i}" for i in range(n_features)]
+        train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, kwargs)
+    else:
+        train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), None, kwargs)
+
+    assert valid_Xy is not None
+    assert valid_Xy.num_row() + train_Xy.num_row() == n_samples_per_batch * n_batches
+    assert train_Xy.num_col() == n_features
+    assert valid_Xy.num_col() == n_features
+
+    df = pd.concat(dfs, axis=0)
+    df_train = df.loc[~df[alias.valid], :]
+    df_valid = df.loc[df[alias.valid], :]
+
+    assert df_train.shape[0] == train_Xy.num_row()
+    assert df_valid.shape[0] == valid_Xy.num_row()
+
+    # margin
+    np.testing.assert_allclose(
+        df_train[alias.margin].to_numpy(), train_Xy.get_base_margin()
+    )
+    np.testing.assert_allclose(
+        df_valid[alias.margin].to_numpy(), valid_Xy.get_base_margin()
+    )
+    # weight
+    np.testing.assert_allclose(df_train[alias.weight].to_numpy(), train_Xy.get_weight())
+    np.testing.assert_allclose(df_valid[alias.weight].to_numpy(), valid_Xy.get_weight())
+    # label
+    np.testing.assert_allclose(df_train[alias.label].to_numpy(), train_Xy.get_label())
+    np.testing.assert_allclose(df_valid[alias.label].to_numpy(), valid_Xy.get_label())
+
+    np.testing.assert_equal(train_Xy.feature_types, feature_types)
+    np.testing.assert_equal(valid_Xy.feature_types, feature_types)
+
+
+def test_dmatrix_ctor() -> None:
+    run_dmatrix_ctor(False)
@@ -765,23 +765,22 @@ class XgboostLocalTest(SparkTestCase):
             self.reg_df_test_with_eval_weight
         ).collect()
         for row in pred_result_with_weight:
-            self.assertTrue(
-                np.isclose(
-                    row.prediction, row.expected_prediction_with_weight, atol=1e-3
-                )
-            )
+            assert np.isclose(
+                row.prediction, row.expected_prediction_with_weight, atol=1e-3
+            )
         # with eval
         regressor_with_eval = SparkXGBRegressor(**self.reg_params_with_eval)
         model_with_eval = regressor_with_eval.fit(self.reg_df_train_with_eval_weight)
-        self.assertTrue(
-            np.isclose(
-                model_with_eval._xgb_sklearn_model.best_score,
-                self.reg_with_eval_best_score,
-                atol=1e-3,
-            ),
-            f"Expected best score: {self.reg_with_eval_best_score}, "
-            f"but get {model_with_eval._xgb_sklearn_model.best_score}",
+        assert np.isclose(
+            model_with_eval._xgb_sklearn_model.best_score,
+            self.reg_with_eval_best_score,
+            atol=1e-3,
+        ), (
+            f"Expected best score: {self.reg_with_eval_best_score}, but ",
+            f"get {model_with_eval._xgb_sklearn_model.best_score}",
         )

         pred_result_with_eval = model_with_eval.transform(
             self.reg_df_test_with_eval_weight
         ).collect()
@@ -905,7 +904,7 @@ class XgboostLocalTest(SparkTestCase):
         # Check that regardless of what booster, _convert_to_model converts to the correct class type
         sklearn_classifier = classifier._convert_to_sklearn_model(
             clf_model.get_booster().save_raw("json"),
-            clf_model.get_booster().save_config()
+            clf_model.get_booster().save_config(),
         )
         assert isinstance(sklearn_classifier, XGBClassifier)
         assert sklearn_classifier.n_estimators == 200
@@ -915,7 +914,7 @@ class XgboostLocalTest(SparkTestCase):

         sklearn_regressor = regressor._convert_to_sklearn_model(
             reg_model.get_booster().save_raw("json"),
-            reg_model.get_booster().save_config()
+            reg_model.get_booster().save_config(),
         )
         assert isinstance(sklearn_regressor, XGBRegressor)
         assert sklearn_regressor.n_estimators == 200