[pyspark] Make Xgboost estimator support using sparse matrix as optimization (#8145)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
parent
1703dc330f
commit
53d2a733b0
@ -1,7 +1,7 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for core code."""
|
||||
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
|
||||
# pylint: disable=too-few-public-methods
|
||||
# pylint: disable=too-few-public-methods, too-many-lines
|
||||
from typing import Iterator, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -37,14 +37,24 @@ from xgboost.training import train as worker_train
|
||||
import xgboost
|
||||
from xgboost import XGBClassifier, XGBRegressor
|
||||
|
||||
from .data import alias, create_dmatrix_from_partitions, stack_series
|
||||
from .data import (
|
||||
_read_csr_matrix_from_unwrapped_spark_vec,
|
||||
alias,
|
||||
create_dmatrix_from_partitions,
|
||||
stack_series,
|
||||
)
|
||||
from .model import (
|
||||
SparkXGBModelReader,
|
||||
SparkXGBModelWriter,
|
||||
SparkXGBReader,
|
||||
SparkXGBWriter,
|
||||
)
|
||||
from .params import HasArbitraryParamsDict, HasBaseMarginCol, HasFeaturesCols
|
||||
from .params import (
|
||||
HasArbitraryParamsDict,
|
||||
HasBaseMarginCol,
|
||||
HasEnableSparseDataOptim,
|
||||
HasFeaturesCols,
|
||||
)
|
||||
from .utils import (
|
||||
RabitContext,
|
||||
_get_args_from_message_list,
|
||||
@ -75,6 +85,7 @@ _pyspark_specific_params = [
|
||||
"use_gpu",
|
||||
"feature_names",
|
||||
"features_cols",
|
||||
"enable_sparse_data_optim",
|
||||
]
|
||||
|
||||
_non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
|
||||
@ -124,6 +135,7 @@ class _SparkXGBParams(
|
||||
HasArbitraryParamsDict,
|
||||
HasBaseMarginCol,
|
||||
HasFeaturesCols,
|
||||
HasEnableSparseDataOptim,
|
||||
):
|
||||
num_workers = Param(
|
||||
Params._dummy(),
|
||||
@ -237,6 +249,7 @@ class _SparkXGBParams(
|
||||
return predict_params
|
||||
|
||||
def _validate_params(self):
|
||||
# pylint: disable=too-many-branches
|
||||
init_model = self.getOrDefault(self.xgb_model)
|
||||
if init_model is not None and not isinstance(init_model, Booster):
|
||||
raise ValueError(
|
||||
@ -267,6 +280,26 @@ class _SparkXGBParams(
|
||||
"If features_cols param set, then features_col param is ignored."
|
||||
)
|
||||
|
||||
if self.getOrDefault(self.enable_sparse_data_optim):
|
||||
if self.getOrDefault(self.missing) != 0.0:
|
||||
# If DMatrix is constructed from csr / csc matrix, then inactive elements
|
||||
# in csr / csc matrix are regarded as missing value, but, in pyspark, we
|
||||
# are hard to control elements to be active or inactive in sparse vector column,
|
||||
# some spark transformers such as VectorAssembler might compress vectors
|
||||
# to be dense or sparse format automatically, and when a spark ML vector object
|
||||
# is compressed to sparse vector, then all zero value elements become inactive.
|
||||
# So we force setting missing param to be 0 when enable_sparse_data_optim config
|
||||
# is True.
|
||||
raise ValueError(
|
||||
"If enable_sparse_data_optim is True, missing param != 0 is not supported."
|
||||
)
|
||||
if self.getOrDefault(self.features_cols):
|
||||
raise ValueError(
|
||||
"If enable_sparse_data_optim is True, you cannot set multiple feature columns "
|
||||
"but you should set one feature column with values of "
|
||||
"`pyspark.ml.linalg.Vector` type."
|
||||
)
|
||||
|
||||
if self.getOrDefault(self.use_gpu):
|
||||
tree_method = self.getParam("tree_method")
|
||||
if (
|
||||
@ -363,6 +396,52 @@ def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
|
||||
return features_array_col
|
||||
|
||||
|
||||
def _get_unwrap_udt_fn():
|
||||
try:
|
||||
from pyspark.sql.functions import unwrap_udt
|
||||
|
||||
return unwrap_udt
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from pyspark.databricks.sql.functions import unwrap_udt
|
||||
|
||||
return unwrap_udt
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
"Cannot import pyspark `unwrap_udt` function. Please install pyspark>=3.4 "
|
||||
"or run on Databricks Runtime."
|
||||
) from exc
|
||||
|
||||
|
||||
def _get_unwrapped_vec_cols(feature_col):
|
||||
unwrap_udt = _get_unwrap_udt_fn()
|
||||
features_unwrapped_vec_col = unwrap_udt(feature_col)
|
||||
|
||||
# After a `pyspark.ml.linalg.VectorUDT` type column being unwrapped, it becomes
|
||||
# a pyspark struct type column, the struct fields are:
|
||||
# - `type`: byte
|
||||
# - `size`: int
|
||||
# - `indices`: array<int>
|
||||
# - `values`: array<double>
|
||||
# For sparse vector, `type` field is 0, `size` field means vector length,
|
||||
# `indices` field is the array of active element indices, `values` field
|
||||
# is the array of active element values.
|
||||
# For dense vector, `type` field is 1, `size` and `indices` fields are None,
|
||||
# `values` field is the array of the vector element values.
|
||||
return [
|
||||
features_unwrapped_vec_col.type.alias("featureVectorType"),
|
||||
features_unwrapped_vec_col.size.alias("featureVectorSize"),
|
||||
features_unwrapped_vec_col.indices.alias("featureVectorIndices"),
|
||||
# Note: the value field is double array type, cast it to float32 array type
|
||||
# for speedup following repartitioning.
|
||||
features_unwrapped_vec_col.values.cast(ArrayType(FloatType())).alias(
|
||||
"featureVectorValues"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -527,17 +606,28 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
select_cols = [label_col]
|
||||
features_cols_names = None
|
||||
if self.getOrDefault(self.features_cols):
|
||||
features_cols_names = self.getOrDefault(self.features_cols)
|
||||
features_cols = _validate_and_convert_feature_col_as_float_col_list(
|
||||
dataset, features_cols_names
|
||||
)
|
||||
select_cols.extend(features_cols)
|
||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
||||
if enable_sparse_data_optim:
|
||||
features_col_name = self.getOrDefault(self.featuresCol)
|
||||
features_col_datatype = dataset.schema[features_col_name].dataType
|
||||
if not isinstance(features_col_datatype, VectorUDT):
|
||||
raise ValueError(
|
||||
"If enable_sparse_data_optim is True, the feature column values must be "
|
||||
"`pyspark.ml.linalg.Vector` type."
|
||||
)
|
||||
select_cols.extend(_get_unwrapped_vec_cols(col(features_col_name)))
|
||||
else:
|
||||
features_array_col = _validate_and_convert_feature_col_as_array_col(
|
||||
dataset, self.getOrDefault(self.featuresCol)
|
||||
)
|
||||
select_cols.append(features_array_col)
|
||||
if self.getOrDefault(self.features_cols):
|
||||
features_cols_names = self.getOrDefault(self.features_cols)
|
||||
features_cols = _validate_and_convert_feature_col_as_float_col_list(
|
||||
dataset, features_cols_names
|
||||
)
|
||||
select_cols.extend(features_cols)
|
||||
else:
|
||||
features_array_col = _validate_and_convert_feature_col_as_array_col(
|
||||
dataset, self.getOrDefault(self.featuresCol)
|
||||
)
|
||||
select_cols.append(features_array_col)
|
||||
|
||||
if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
|
||||
select_cols.append(
|
||||
@ -589,7 +679,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
"feature_types": self.getOrDefault(self.feature_types),
|
||||
"feature_names": self.getOrDefault(self.feature_names),
|
||||
"feature_weights": self.getOrDefault(self.feature_weights),
|
||||
"missing": self.getOrDefault(self.missing),
|
||||
"missing": float(self.getOrDefault(self.missing)),
|
||||
}
|
||||
booster_params["nthread"] = cpu_per_task
|
||||
use_gpu = self.getOrDefault(self.use_gpu)
|
||||
@ -627,7 +717,11 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
evals_result = {}
|
||||
with RabitContext(_rabit_args, context):
|
||||
dtrain, dvalid = create_dmatrix_from_partitions(
|
||||
pandas_df_iter, features_cols_names, gpu_id, dmatrix_kwargs
|
||||
pandas_df_iter,
|
||||
features_cols_names,
|
||||
gpu_id,
|
||||
dmatrix_kwargs,
|
||||
enable_sparse_data_optim=enable_sparse_data_optim,
|
||||
)
|
||||
if dvalid is not None:
|
||||
dval = [(dtrain, "training"), (dvalid, "validation")]
|
||||
@ -732,6 +826,12 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
vector or array feature type. But first we need to check features_cols
|
||||
and then featuresCol
|
||||
"""
|
||||
if self.getOrDefault(self.enable_sparse_data_optim):
|
||||
feature_col_names = None
|
||||
features_col = _get_unwrapped_vec_cols(
|
||||
col(self.getOrDefault(self.featuresCol))
|
||||
)
|
||||
return features_col, feature_col_names
|
||||
|
||||
feature_col_names = self.getOrDefault(self.features_cols)
|
||||
features_col = []
|
||||
@ -783,15 +883,19 @@ class SparkXGBRegressorModel(_SparkXGBModel):
|
||||
)
|
||||
|
||||
features_col, feature_col_names = self._get_feature_col(dataset)
|
||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
||||
|
||||
@pandas_udf("double")
|
||||
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
|
||||
model = xgb_sklearn_model
|
||||
for data in iterator:
|
||||
if feature_col_names is not None:
|
||||
X = data[feature_col_names]
|
||||
if enable_sparse_data_optim:
|
||||
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
|
||||
else:
|
||||
X = stack_series(data[alias.data])
|
||||
if feature_col_names is not None:
|
||||
X = data[feature_col_names]
|
||||
else:
|
||||
X = stack_series(data[alias.data])
|
||||
|
||||
if has_base_margin:
|
||||
base_margin = data[alias.margin].to_numpy()
|
||||
@ -828,6 +932,7 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictio
|
||||
return XGBClassifier
|
||||
|
||||
def _transform(self, dataset):
|
||||
# pylint: disable=too-many-locals
|
||||
# Save xgb_sklearn_model and predict_params to be local variable
|
||||
# to avoid the `self` object to be pickled to remote.
|
||||
xgb_sklearn_model = self._xgb_sklearn_model
|
||||
@ -856,6 +961,7 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictio
|
||||
return raw_preds, class_probs
|
||||
|
||||
features_col, feature_col_names = self._get_feature_col(dataset)
|
||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
||||
|
||||
@pandas_udf(
|
||||
"rawPrediction array<double>, prediction double, probability array<double>"
|
||||
@ -865,10 +971,13 @@ class SparkXGBClassifierModel(_SparkXGBModel, HasProbabilityCol, HasRawPredictio
|
||||
) -> Iterator[pd.DataFrame]:
|
||||
model = xgb_sklearn_model
|
||||
for data in iterator:
|
||||
if feature_col_names is not None:
|
||||
X = data[feature_col_names]
|
||||
if enable_sparse_data_optim:
|
||||
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
|
||||
else:
|
||||
X = stack_series(data[alias.data])
|
||||
if feature_col_names is not None:
|
||||
X = data[feature_col_names]
|
||||
else:
|
||||
X = stack_series(data[alias.data])
|
||||
|
||||
if has_base_margin:
|
||||
base_margin = stack_series(data[alias.margin])
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tupl
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.sparse import csr_matrix
|
||||
from xgboost.compat import concat
|
||||
|
||||
from xgboost import DataIter, DeviceQuantileDMatrix, DMatrix
|
||||
@ -101,11 +102,55 @@ class PartIter(DataIter):
|
||||
self._iter = 0
|
||||
|
||||
|
||||
def _read_csr_matrix_from_unwrapped_spark_vec(part: pd.DataFrame) -> csr_matrix:
|
||||
# variables for constructing csr_matrix
|
||||
csr_indices_list, csr_indptr_list, csr_values_list = [], [0], []
|
||||
|
||||
n_features = 0
|
||||
|
||||
for vec_type, vec_size_, vec_indices, vec_values in zip(
|
||||
part.featureVectorType,
|
||||
part.featureVectorSize,
|
||||
part.featureVectorIndices,
|
||||
part.featureVectorValues,
|
||||
):
|
||||
if vec_type == 0:
|
||||
# sparse vector
|
||||
vec_size = int(vec_size_)
|
||||
csr_indices = vec_indices
|
||||
csr_values = vec_values
|
||||
else:
|
||||
# dense vector
|
||||
# Note: According to spark ML VectorUDT format,
|
||||
# when type field is 1, the size field is also empty.
|
||||
# we need to check the values field to get vector length.
|
||||
vec_size = len(vec_values)
|
||||
csr_indices = np.arange(vec_size, dtype=np.int32)
|
||||
csr_values = vec_values
|
||||
|
||||
if n_features == 0:
|
||||
n_features = vec_size
|
||||
assert n_features == vec_size
|
||||
|
||||
csr_indices_list.append(csr_indices)
|
||||
csr_indptr_list.append(csr_indptr_list[-1] + len(csr_indices))
|
||||
csr_values_list.append(csr_values)
|
||||
|
||||
csr_indptr_arr = np.array(csr_indptr_list)
|
||||
csr_indices_arr = np.concatenate(csr_indices_list)
|
||||
csr_values_arr = np.concatenate(csr_values_list)
|
||||
|
||||
return csr_matrix(
|
||||
(csr_values_arr, csr_indices_arr, csr_indptr_arr), shape=(len(part), n_features)
|
||||
)
|
||||
|
||||
|
||||
def create_dmatrix_from_partitions(
|
||||
iterator: Iterator[pd.DataFrame],
|
||||
feature_cols: Optional[Sequence[str]],
|
||||
gpu_id: Optional[int],
|
||||
kwargs: Dict[str, Any], # use dict to make sure this parameter is passed.
|
||||
enable_sparse_data_optim: bool,
|
||||
) -> Tuple[DMatrix, Optional[DMatrix]]:
|
||||
"""Create DMatrix from spark data partitions. This is not particularly efficient as
|
||||
we need to convert the pandas series format to numpy then concatenate all the data.
|
||||
@ -118,7 +163,7 @@ def create_dmatrix_from_partitions(
|
||||
Metainfo for DMatrix.
|
||||
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-locals, too-many-statements
|
||||
train_data: Dict[str, List[np.ndarray]] = defaultdict(list)
|
||||
valid_data: Dict[str, List[np.ndarray]] = defaultdict(list)
|
||||
|
||||
@ -139,6 +184,23 @@ def create_dmatrix_from_partitions(
|
||||
else:
|
||||
train_data[name].append(array)
|
||||
|
||||
def append_m_sparse(part: pd.DataFrame, name: str, is_valid: bool) -> None:
|
||||
nonlocal n_features
|
||||
|
||||
if name == alias.data or name in part.columns:
|
||||
if name == alias.data:
|
||||
array = _read_csr_matrix_from_unwrapped_spark_vec(part)
|
||||
if n_features == 0:
|
||||
n_features = array.shape[1]
|
||||
assert n_features == array.shape[1]
|
||||
else:
|
||||
array = part[name]
|
||||
|
||||
if is_valid:
|
||||
valid_data[name].append(array)
|
||||
else:
|
||||
train_data[name].append(array)
|
||||
|
||||
def append_dqm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
|
||||
"""Preprocessing for DeviceQuantileDMatrix"""
|
||||
nonlocal n_features
|
||||
@ -164,13 +226,19 @@ def create_dmatrix_from_partitions(
|
||||
label = concat_or_none(values.get(alias.label, None))
|
||||
weight = concat_or_none(values.get(alias.weight, None))
|
||||
margin = concat_or_none(values.get(alias.margin, None))
|
||||
|
||||
return DMatrix(
|
||||
data=data, label=label, weight=weight, base_margin=margin, **kwargs
|
||||
)
|
||||
|
||||
is_dmatrix = feature_cols is None
|
||||
if is_dmatrix:
|
||||
cache_partitions(iterator, append_m)
|
||||
if enable_sparse_data_optim:
|
||||
append_fn = append_m_sparse
|
||||
assert "missing" in kwargs and kwargs["missing"] == 0.0
|
||||
else:
|
||||
append_fn = append_m
|
||||
cache_partitions(iterator, append_fn)
|
||||
dtrain = make(train_data, kwargs)
|
||||
else:
|
||||
cache_partitions(iterator, append_dqm)
|
||||
|
||||
@ -50,3 +50,25 @@ class HasFeaturesCols(Params):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._setDefault(features_cols=[])
|
||||
|
||||
|
||||
class HasEnableSparseDataOptim(Params):
|
||||
|
||||
"""
|
||||
This is a Params based class that is extended by _SparkXGBParams
|
||||
and holds the variable to store the boolean config of enabling sparse data optimization.
|
||||
"""
|
||||
|
||||
enable_sparse_data_optim = Param(
|
||||
Params._dummy(),
|
||||
"enable_sparse_data_optim",
|
||||
"This stores the boolean config of enabling sparse data optimization, if enabled, "
|
||||
"Xgboost DMatrix object will be constructed from sparse matrix instead of "
|
||||
"dense matrix. This config is disabled by default. If most of examples in your "
|
||||
"training dataset contains sparse features, we suggest to enable this config.",
|
||||
typeConverter=TypeConverters.toBoolean,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._setDefault(enable_sparse_data_optim=False)
|
||||
|
||||
@ -32,9 +32,10 @@ dependencies:
|
||||
- cffi
|
||||
- pyarrow
|
||||
- protobuf
|
||||
- pyspark>=3.3.0
|
||||
- cloudpickle
|
||||
- shap
|
||||
- modin
|
||||
- pip:
|
||||
- datatable
|
||||
# TODO: Replace it with pyspark>=3.4 once 3.4 released.
|
||||
- https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz
|
||||
|
||||
@ -11,7 +11,12 @@ if tm.no_spark()["condition"]:
|
||||
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
|
||||
pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)
|
||||
|
||||
from xgboost.spark.data import alias, create_dmatrix_from_partitions, stack_series
|
||||
from xgboost.spark.data import (
|
||||
_read_csr_matrix_from_unwrapped_spark_vec,
|
||||
alias,
|
||||
create_dmatrix_from_partitions,
|
||||
stack_series,
|
||||
)
|
||||
|
||||
|
||||
def test_stack() -> None:
|
||||
@ -62,10 +67,12 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:
|
||||
kwargs = {"feature_types": feature_types}
|
||||
if is_dqm:
|
||||
cols = [f"feat-{i}" for i in range(n_features)]
|
||||
train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, 0, kwargs)
|
||||
train_Xy, valid_Xy = create_dmatrix_from_partitions(
|
||||
iter(dfs), cols, 0, kwargs, False
|
||||
)
|
||||
else:
|
||||
train_Xy, valid_Xy = create_dmatrix_from_partitions(
|
||||
iter(dfs), None, None, kwargs
|
||||
iter(dfs), None, None, kwargs, False
|
||||
)
|
||||
|
||||
assert valid_Xy is not None
|
||||
@ -100,3 +107,35 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:
|
||||
|
||||
def test_dmatrix_ctor() -> None:
|
||||
run_dmatrix_ctor(False)
|
||||
|
||||
|
||||
def test_read_csr_matrix_from_unwrapped_spark_vec() -> None:
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
pd1 = pd.DataFrame(
|
||||
{
|
||||
"featureVectorType": [0, 1, 1, 0],
|
||||
"featureVectorSize": [3, None, None, 3],
|
||||
"featureVectorIndices": [
|
||||
np.array([0, 2], dtype=np.int32),
|
||||
None,
|
||||
None,
|
||||
np.array([1, 2], dtype=np.int32),
|
||||
],
|
||||
"featureVectorValues": [
|
||||
np.array([3.0, 0.0], dtype=np.float64),
|
||||
np.array([13.0, 14.0, 0.0], dtype=np.float64),
|
||||
np.array([0.0, 24.0, 25.0], dtype=np.float64),
|
||||
np.array([0.0, 35.0], dtype=np.float64),
|
||||
],
|
||||
}
|
||||
)
|
||||
sm = _read_csr_matrix_from_unwrapped_spark_vec(pd1)
|
||||
assert isinstance(sm, csr_matrix)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
sm.data, [3.0, 0.0, 13.0, 14.0, 0.0, 0.0, 24.0, 25.0, 0.0, 35.0]
|
||||
)
|
||||
np.testing.assert_array_equal(sm.indptr, [0, 2, 5, 8, 10])
|
||||
np.testing.assert_array_equal(sm.indices, [0, 2, 0, 1, 2, 0, 1, 2, 1, 2])
|
||||
assert sm.shape == (4, 3)
|
||||
|
||||
@ -381,6 +381,26 @@ class XgboostLocalTest(SparkTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
self.reg_df_sparse_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 0.0, 3.0, 0.0, 0.0), 0),
|
||||
(Vectors.sparse(5, {1: 1.0, 3: 5.5}), 1),
|
||||
(Vectors.sparse(5, {4: -3.0}), 2),
|
||||
]
|
||||
* 10,
|
||||
["features", "label"],
|
||||
)
|
||||
|
||||
self.cls_df_sparse_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 0.0, 3.0, 0.0, 0.0), 0),
|
||||
(Vectors.sparse(5, {1: 1.0, 3: 5.5}), 1),
|
||||
(Vectors.sparse(5, {4: -3.0}), 0),
|
||||
]
|
||||
* 10,
|
||||
["features", "label"],
|
||||
)
|
||||
|
||||
def get_local_tmp_dir(self):
|
||||
return self.tempdir + str(uuid.uuid4())
|
||||
|
||||
@ -972,3 +992,35 @@ class XgboostLocalTest(SparkTestCase):
|
||||
)
|
||||
model = classifier.fit(self.cls_df_train)
|
||||
model.transform(self.cls_df_test).collect()
|
||||
|
||||
def test_regressor_with_sparse_optim(self):
|
||||
regressor = SparkXGBRegressor(missing=0.0)
|
||||
model = regressor.fit(self.reg_df_sparse_train)
|
||||
assert model._xgb_sklearn_model.missing == 0.0
|
||||
pred_result = model.transform(self.reg_df_sparse_train).collect()
|
||||
|
||||
# enable sparse optimiaztion
|
||||
regressor2 = SparkXGBRegressor(missing=0.0, enable_sparse_data_optim=True)
|
||||
model2 = regressor2.fit(self.reg_df_sparse_train)
|
||||
assert model2.getOrDefault(model2.enable_sparse_data_optim)
|
||||
assert model2._xgb_sklearn_model.missing == 0.0
|
||||
pred_result2 = model2.transform(self.reg_df_sparse_train).collect()
|
||||
|
||||
for row1, row2 in zip(pred_result, pred_result2):
|
||||
self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3))
|
||||
|
||||
def test_classifier_with_sparse_optim(self):
|
||||
cls = SparkXGBClassifier(missing=0.0)
|
||||
model = cls.fit(self.cls_df_sparse_train)
|
||||
assert model._xgb_sklearn_model.missing == 0.0
|
||||
pred_result = model.transform(self.cls_df_sparse_train).collect()
|
||||
|
||||
# enable sparse optimiaztion
|
||||
cls2 = SparkXGBClassifier(missing=0.0, enable_sparse_data_optim=True)
|
||||
model2 = cls2.fit(self.cls_df_sparse_train)
|
||||
assert model2.getOrDefault(model2.enable_sparse_data_optim)
|
||||
assert model2._xgb_sklearn_model.missing == 0.0
|
||||
pred_result2 = model2.transform(self.cls_df_sparse_train).collect()
|
||||
|
||||
for row1, row2 in zip(pred_result, pred_result2):
|
||||
self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user