[pyspark] Use quantile dmatrix. (#8284)
This commit is contained in: parent ce0382dcb0, commit 97a5b088a5
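The change below moves the PySpark estimator onto the unified `QuantileDMatrix` whenever a histogram-based tree method (`hist` or `gpu_hist`) is selected, and routes per-partition metadata through `DataIter` batches instead of `set_info`. As background, a minimal sketch of the standalone `QuantileDMatrix` API this relies on (the synthetic data and parameter values are illustrative, not part of the diff):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(1024, 16))
y = X.sum(axis=1)

# QuantileDMatrix pre-computes the histogram cuts used by tree_method="hist",
# avoiding a second full copy of the data inside a regular DMatrix.
Xy = xgb.QuantileDMatrix(X, y, max_bin=256)

# Evaluation data should reference the training matrix so it reuses the same cuts.
Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy)

booster = xgb.train(
    {"tree_method": "hist", "objective": "reg:squarederror"},
    Xy,
    num_boost_round=10,
    evals=[(Xy_valid, "valid")],
)
```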
@@ -349,7 +349,7 @@ Specify the learning task and the corresponding learning objective. The objective
 - ``reg:squaredlogerror``: regression with squared log loss :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`. All input labels are required to be greater than -1. Also, see metric ``rmsle`` for possible issue with this objective.
 - ``reg:logistic``: logistic regression.
 - ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss.
-- ``reg:absoluteerror``: Regression with L1 error. When tree model is used, leaf value is refreshed after tree construction.
+- ``reg:absoluteerror``: Regression with L1 error. When tree model is used, leaf value is refreshed after tree construction. If used in distributed training, the leaf value is calculated as the mean value from all workers, which is not guaranteed to be optimal.
 - ``binary:logistic``: logistic regression for binary classification, output probability
 - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
 - ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities.
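The updated bullet documents that ``reg:absoluteerror`` refreshes leaf values after each tree and, under distributed training, averages the refreshed values across workers. A small, illustrative training call with that objective (data and parameters here are made up for the example, not taken from the diff):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(1)
X = rng.normal(size=(256, 8))
y = np.abs(X[:, 0]) + rng.normal(scale=0.1, size=256)

# L1 objective: leaf values are refreshed after each tree is built, as the
# documentation entry above describes.
booster = xgb.train(
    {"objective": "reg:absoluteerror", "tree_method": "hist"},
    xgb.DMatrix(X, y),
    num_boost_round=10,
)
```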
@@ -105,6 +105,11 @@ def from_cstr_to_pystr(data: CStrPptr, length: c_bst_ulong) -> List[str]:
     return res


+def make_jcargs(**kwargs: Any) -> bytes:
+    "Make JSON-based arguments for C functions."
+    return from_pystr_to_cstr(json.dumps(kwargs))
+
+
 IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int])

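The new `make_jcargs` helper centralizes the pattern of serializing keyword arguments to JSON and encoding them for the C API; the later hunks replace hand-rolled `json.dumps` + `from_pystr_to_cstr` calls with it. A self-contained sketch of the same idea (a plain UTF-8 encode stands in for the library's `from_pystr_to_cstr`):

```python
import json
from typing import Any


def make_jcargs(**kwargs: Any) -> bytes:
    """Serialize keyword arguments to a UTF-8 JSON byte string for C functions."""
    return json.dumps(kwargs).encode("utf-8")


# For example, the config passed to XGBoosterSaveModelToBuffer later in this diff:
config = make_jcargs(format="json")
assert json.loads(config) == {"format": "json"}
```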
@@ -1256,7 +1261,7 @@ class _ProxyDMatrix(DMatrix):
     def _set_data_from_cuda_interface(self, data: DataType) -> None:
         """Set data from CUDA array interface."""
         interface = data.__cuda_array_interface__
-        interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
+        interface_str = bytes(json.dumps(interface), "utf-8")
         _check_call(
             _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str)
         )
@@ -1357,6 +1362,26 @@ class QuantileDMatrix(DMatrix):
                 "Only one of the eval_qid or eval_group for each evaluation "
                 "dataset should be provided."
             )
+        if isinstance(data, DataIter):
+            if any(
+                info is not None
+                for info in (
+                    label,
+                    weight,
+                    base_margin,
+                    feature_names,
+                    feature_types,
+                    group,
+                    qid,
+                    label_lower_bound,
+                    label_upper_bound,
+                    feature_weights,
+                )
+            ):
+                raise ValueError(
+                    "If data iterator is used as input, data like label should be "
+                    "specified as batch argument."
+                )

         self._init(
             data,
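This new check forbids passing `label`, `weight`, and other meta info directly to `QuantileDMatrix` when the input is a `DataIter`; the metadata has to be supplied per batch through `input_data` instead. A hedged sketch of a conforming iterator (class and variable names are illustrative):

```python
from typing import Callable, List

import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    """Yield pre-split batches; labels travel with each batch, not the constructor."""

    def __init__(self, batches: List[np.ndarray], labels: List[np.ndarray]) -> None:
        self._batches = batches
        self._labels = labels
        self._it = 0
        super().__init__()

    def next(self, input_data: Callable) -> int:
        if self._it == len(self._batches):
            return 0  # no more batches
        input_data(data=self._batches[self._it], label=self._labels[self._it])
        self._it += 1
        return 1

    def reset(self) -> None:
        self._it = 0


rng = np.random.default_rng(0)
X = [rng.normal(size=(128, 4)) for _ in range(3)]
y = [x.sum(axis=1) for x in X]

Xy = xgb.QuantileDMatrix(BatchIter(X, y))  # metadata comes from the batches
# xgb.QuantileDMatrix(BatchIter(X, y), label=...) would raise ValueError per this check.
```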
@@ -1405,12 +1430,9 @@ class QuantileDMatrix(DMatrix):
                 "in iterator to fix this error."
             )

-        args = {
-            "nthread": self.nthread,
-            "missing": self.missing,
-            "max_bin": self.max_bin,
-        }
-        config = from_pystr_to_cstr(json.dumps(args))
+        config = make_jcargs(
+            nthread=self.nthread, missing=self.missing, max_bin=self.max_bin
+        )
         ret = _LIB.XGQuantileDMatrixCreateFromCallback(
             None,
             it.proxy.handle,
@@ -2375,7 +2397,7 @@ class Booster:
         """
         length = c_bst_ulong()
         cptr = ctypes.POINTER(ctypes.c_char)()
-        config = from_pystr_to_cstr(json.dumps({"format": raw_format}))
+        config = make_jcargs(format=raw_format)
         _check_call(
             _LIB.XGBoosterSaveModelToBuffer(
                 self.handle, config, ctypes.byref(length), ctypes.byref(cptr)
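`Booster.save_raw` now builds its configuration via `make_jcargs`; the public behaviour is unchanged. A short, illustrative round trip (model and data are synthetic):

```python
import numpy as np
import xgboost as xgb

X = np.random.default_rng(2).normal(size=(64, 4))
y = X[:, 0]
booster = xgb.train({"tree_method": "hist"}, xgb.DMatrix(X, y), num_boost_round=2)

raw = booster.save_raw(raw_format="json")   # config is built with make_jcargs internally
restored = xgb.Booster(model_file=bytearray(raw))
```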
@@ -2570,9 +2592,6 @@ class Booster:
         `n_classes`, otherwise they're scalars.
         """
         fmap = os.fspath(os.path.expanduser(fmap))
-        args = from_pystr_to_cstr(
-            json.dumps({"importance_type": importance_type, "feature_map": fmap})
-        )
         features = ctypes.POINTER(ctypes.c_char_p)()
         scores = ctypes.POINTER(ctypes.c_float)()
         n_out_features = c_bst_ulong()
@@ -2582,7 +2601,7 @@ class Booster:
         _check_call(
             _LIB.XGBoosterFeatureScore(
                 self.handle,
-                args,
+                make_jcargs(importance_type=importance_type, feature_map=fmap),
                 ctypes.byref(n_out_features),
                 ctypes.byref(features),
                 ctypes.byref(out_dim),
@@ -573,6 +573,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
         label_upper_bound: Optional[List[Any]] = None,
         feature_names: Optional[FeatureNames] = None,
         feature_types: Optional[Union[Any, List[Any]]] = None,
+        feature_weights: Optional[Any] = None,
     ) -> None:
         self._data = data
         self._label = label
@@ -583,6 +584,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
         self._label_upper_bound = label_upper_bound
         self._feature_names = feature_names
         self._feature_types = feature_types
+        self._feature_weights = feature_weights

         assert isinstance(self._data, collections.abc.Sequence)

@@ -633,6 +635,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
             label_upper_bound=self._get("_label_upper_bound"),
             feature_names=feature_names,
             feature_types=self._feature_types,
+            feature_weights=self._feature_weights,
         )
         self._iter += 1
         return 1
@@ -731,19 +734,21 @@ def _create_quantile_dmatrix(
         return d

     unzipped_dict = _get_worker_parts(parts)
-    it = DaskPartitionIter(**unzipped_dict)
+    it = DaskPartitionIter(
+        **unzipped_dict,
+        feature_types=feature_types,
+        feature_names=feature_names,
+        feature_weights=feature_weights,
+    )

     dmatrix = QuantileDMatrix(
         it,
         missing=missing,
-        feature_names=feature_names,
-        feature_types=feature_types,
         nthread=nthread,
         max_bin=max_bin,
         ref=ref,
         enable_categorical=enable_categorical,
     )
-    dmatrix.set_info(feature_weights=feature_weights)
     return dmatrix

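With this change the Dask path passes `feature_names`, `feature_types`, and `feature_weights` through `DaskPartitionIter`, so they are set while the `QuantileDMatrix` is built rather than attached afterwards with `set_info`. From the user's side the entry point is unchanged; a hedged sketch of the corresponding public API call (cluster setup and data are illustrative):

```python
from dask import array as da
from distributed import Client, LocalCluster

import xgboost as xgb

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    X = da.random.random((10_000, 8), chunks=(1_000, 8))
    y = X.sum(axis=1)
    # Feature metadata given here ends up in each worker's DaskPartitionIter.
    dtrain = xgb.dask.DaskQuantileDMatrix(
        client, X, y, feature_names=[f"f{i}" for i in range(8)]
    )
    out = xgb.dask.train(client, {"tree_method": "hist"}, dtrain, num_boost_round=10)
```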
@@ -747,6 +747,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             k: v for k, v in train_call_kwargs_params.items() if v is not None
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
+        use_qdm = booster_params.get("tree_method") in ("hist", "gpu_hist")

         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
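The estimator now decides up front whether a `QuantileDMatrix` can be used, based only on the configured tree method. A hedged restatement of that predicate with a few illustrative checks:

```python
from typing import Any, Dict


def use_quantile_dmatrix(booster_params: Dict[str, Any]) -> bool:
    # QuantileDMatrix is only valid for the histogram-based tree methods.
    return booster_params.get("tree_method") in ("hist", "gpu_hist")


assert use_quantile_dmatrix({"tree_method": "hist"})
assert use_quantile_dmatrix({"tree_method": "gpu_hist"})
assert not use_quantile_dmatrix({"tree_method": "approx"})
assert not use_quantile_dmatrix({})  # unspecified tree method falls back to DMatrix
```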
@@ -759,20 +760,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             context.barrier()

             gpu_id = None

+            if use_qdm and (booster_params.get("max_bin", None) is not None):
+                dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
+
             if use_gpu:
                 gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                 booster_params["gpu_id"] = gpu_id

-                # max_bin is needed for qdm
-                if (
-                    features_cols_names is not None
-                    and booster_params.get("max_bin", None) is not None
-                ):
-                    dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
-
             _rabit_args = {}
             if context.partitionId() == 0:
-                get_logger("XGBoostPySpark").info(
+                get_logger("XGBoostPySpark").debug(
                     "booster params: %s\n"
                     "train_call_kwargs_params: %s\n"
                     "dmatrix_kwargs: %s",
@@ -791,6 +789,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 pandas_df_iter,
                 features_cols_names,
                 gpu_id,
+                use_qdm,
                 dmatrix_kwargs,
                 enable_sparse_data_optim=enable_sparse_data_optim,
                 has_validation_col=has_validation_col,
@@ -1,13 +1,13 @@
 """Utilities for processing spark partitions."""
 from collections import defaultdict, namedtuple
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import pandas as pd
 from scipy.sparse import csr_matrix
 from xgboost.compat import concat

-from xgboost import DataIter, DeviceQuantileDMatrix, DMatrix
+from xgboost import DataIter, DMatrix, QuantileDMatrix

 from .utils import get_logger  # type: ignore

@@ -67,10 +67,13 @@ def cache_partitions(
 class PartIter(DataIter):
     """Iterator for creating Quantile DMatrix from partitions."""

-    def __init__(self, data: Dict[str, List], device_id: Optional[int]) -> None:
+    def __init__(
+        self, data: Dict[str, List], device_id: Optional[int], **kwargs: Any
+    ) -> None:
         self._iter = 0
         self._device_id = device_id
         self._data = data
+        self._kwargs = kwargs

         super().__init__()

@@ -98,6 +101,7 @@ class PartIter(DataIter):
             weight=self._fetch(self._data.get(alias.weight, None)),
             base_margin=self._fetch(self._data.get(alias.margin, None)),
             qid=self._fetch(self._data.get(alias.qid, None)),
+            **self._kwargs,
         )
         self._iter += 1
         return 1
@@ -149,24 +153,52 @@ def _read_csr_matrix_from_unwrapped_spark_vec(part: pd.DataFrame) -> csr_matrix:
     )


+def make_qdm(
+    data: Dict[str, List[np.ndarray]],
+    gpu_id: Optional[int],
+    meta: Dict[str, Any],
+    ref: Optional[DMatrix],
+    params: Dict[str, Any],
+) -> DMatrix:
+    """Handle empty partition for QuantileDMatrix."""
+    if not data:
+        return QuantileDMatrix(np.empty((0, 0)), ref=ref)
+    it = PartIter(data, gpu_id, **meta)
+    m = QuantileDMatrix(it, **params, ref=ref)
+    return m
+
+
 def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments
     iterator: Iterator[pd.DataFrame],
     feature_cols: Optional[Sequence[str]],
     gpu_id: Optional[int],
+    use_qdm: bool,
     kwargs: Dict[str, Any],  # use dict to make sure this parameter is passed.
     enable_sparse_data_optim: bool,
     has_validation_col: bool,
 ) -> Tuple[DMatrix, Optional[DMatrix]]:
-    """Create DMatrix from spark data partitions. This is not particularly efficient as
-    we need to convert the pandas series format to numpy then concatenate all the data.
+    """Create DMatrix from spark data partitions.
+
+    Parameters
+    ----------
+    iterator :
+        Pyspark partition iterator.
+    feature_cols:
+        A sequence of feature names, used only when rapids plugin is enabled.
+    gpu_id:
+        Device ordinal, used when GPU is enabled.
+    use_qdm :
+        Whether QuantileDMatrix should be used instead of DMatrix.
+    kwargs :
+        Metainfo for DMatrix.
+    enable_sparse_data_optim :
+        Whether sparse data should be unwrapped
+    has_validation:
+        Whether there's validation data.
+
+    Returns
+    -------
+    Training DMatrix and an optional validation DMatrix.
     """
     # pylint: disable=too-many-locals, too-many-statements
     train_data: Dict[str, List[np.ndarray]] = defaultdict(list)
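`make_qdm` keeps every worker in the collective step: an empty partition still yields a zero-row `QuantileDMatrix`, so the quantile-sketch AllReduce is not left waiting on that worker. A hedged, Spark-free illustration of the same guard (helper name and batch layout are invented for the example):

```python
from typing import List, Optional, Tuple

import numpy as np
import xgboost as xgb


def build_qdm(
    batches: List[Tuple[np.ndarray, np.ndarray]], ref: Optional[xgb.DMatrix] = None
) -> xgb.QuantileDMatrix:
    """Always return a QuantileDMatrix, even when this worker has no data."""
    if not batches:
        # A zero-row matrix keeps the worker participating in the sketch AllReduce.
        return xgb.QuantileDMatrix(np.empty((0, 0)), ref=ref)
    X = np.concatenate([x for x, _ in batches])
    y = np.concatenate([t for _, t in batches])
    return xgb.QuantileDMatrix(X, y, ref=ref)


rng = np.random.default_rng(0)
dtrain = build_qdm([(rng.normal(size=(32, 4)), rng.normal(size=32))])
dvalid = build_qdm([], ref=dtrain)  # empty partition still constructs a matrix
```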
@@ -206,16 +238,16 @@ def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments
         else:
             train_data[name].append(array)

-    def append_dqm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
-        """Preprocessing for DeviceQuantileDMatrix"""
+    def append_qdm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
+        """Preprocessing for QuantileDMatrix."""
         nonlocal n_features
         if name == alias.data or name in part.columns:
-            if name == alias.data:
-                cname = feature_cols
+            if name == alias.data and feature_cols is not None:
+                array = part[feature_cols]
             else:
-                cname = name
+                array = part[name]
+                array = stack_series(array)

-            array = part[cname]
             if name == alias.data:
                 if n_features == 0:
                     n_features = array.shape[1]
@@ -228,6 +260,10 @@ def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments

     def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix:
         if len(values) == 0:
+            get_logger("XGBoostPySpark").warning(
+                "Detected an empty partition in the training data. Consider to enable"
+                " repartition_random_shuffle"
+            )
             # We must construct an empty DMatrix to bypass the AllReduce
             return DMatrix(data=np.empty((0, 0)), **kwargs)

@@ -240,32 +276,62 @@ def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments
             data=data, label=label, weight=weight, base_margin=margin, qid=qid, **kwargs
         )

-    is_dmatrix = feature_cols is None
-    if is_dmatrix:
-        if enable_sparse_data_optim:
-            append_fn = append_m_sparse
-            assert "missing" in kwargs and kwargs["missing"] == 0.0
-        else:
-            append_fn = append_m
-        cache_partitions(iterator, append_fn)
-        if len(train_data) == 0:
-            get_logger("XGBoostPySpark").warning(
-                "Detected an empty partition in the training data. "
-                "Consider to enable repartition_random_shuffle"
-            )
-        dtrain = make(train_data, kwargs)
+    if enable_sparse_data_optim:
+        append_fn = append_m_sparse
+        assert "missing" in kwargs and kwargs["missing"] == 0.0
     else:
-        cache_partitions(iterator, append_dqm)
-        it = PartIter(train_data, gpu_id)
-        dtrain = DeviceQuantileDMatrix(it, **kwargs)
+        append_fn = append_m
+
+    def split_params() -> Tuple[Dict[str, Any], Dict[str, Union[int, float, bool]]]:
+        # FIXME(jiamingy): we really need a better way to bridge distributed frameworks
+        # to XGBoost native interface and prevent scattering parameters like this.
+
+        # parameters that are not related to data.
+        non_data_keys = (
+            "max_bin",
+            "missing",
+            "silent",
+            "nthread",
+            "enable_categorical",
+        )
+        non_data_params = {}
+        meta = {}
+        for k, v in kwargs.items():
+            if k in non_data_keys:
+                non_data_params[k] = v
+            else:
+                meta[k] = v
+        return meta, non_data_params
+
+    meta, params = split_params()
+
+    if feature_cols is not None:  # rapidsai plugin
+        assert gpu_id is not None
+        assert use_qdm is True
+        cache_partitions(iterator, append_qdm)
+        dtrain: DMatrix = make_qdm(train_data, gpu_id, meta, None, params)
+    elif use_qdm:
+        cache_partitions(iterator, append_qdm)
+        dtrain = make_qdm(train_data, gpu_id, meta, None, params)
+    else:
+        cache_partitions(iterator, append_fn)
+        dtrain = make(train_data, kwargs)

     # Using has_validation_col here to indicate if there is validation col
     # instead of getting it from iterator, since the iterator may be empty
     # in some special case. That is to say, we must ensure every worker
-    # construct DMatrix even there is no any data since we need to ensure every
+    # construct DMatrix even there is no data since we need to ensure every
     # worker do the AllReduce when constructing DMatrix, or else it may hang
     # forever.
-    dvalid = make(valid_data, kwargs) if has_validation_col else None
+    if has_validation_col:
+        if use_qdm:
+            dvalid: Optional[DMatrix] = make_qdm(
+                valid_data, gpu_id, meta, dtrain, params
+            )
+        else:
+            dvalid = make(valid_data, kwargs) if has_validation_col else None
+    else:
+        dvalid = None

     if dvalid is not None:
         assert dvalid.num_col() == dtrain.num_col()
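`split_params` separates the keyword arguments that configure the `QuantileDMatrix` constructor from the per-batch meta info that must travel through the iterator. A hedged, standalone restatement of that split (the key list mirrors the one in the diff):

```python
from typing import Any, Dict, Tuple

NON_DATA_KEYS = ("max_bin", "missing", "silent", "nthread", "enable_categorical")


def split_params(kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (meta for the DataIter batches, params for the QuantileDMatrix ctor)."""
    meta = {k: v for k, v in kwargs.items() if k not in NON_DATA_KEYS}
    params = {k: v for k, v in kwargs.items() if k in NON_DATA_KEYS}
    return meta, params


meta, params = split_params({"max_bin": 256, "missing": 0.0, "base_margin": None})
assert params == {"max_bin": 256, "missing": 0.0}
assert meta == {"base_margin": None}
```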
@@ -20,4 +20,6 @@ from test_spark.test_data import run_dmatrix_ctor

 @pytest.mark.skipif(**tm.no_cudf())
 def test_qdm_ctor() -> None:
-    run_dmatrix_ctor(True)
+    run_dmatrix_ctor(is_dqm=True, on_gpu=True)
+    with pytest.raises(AssertionError):
+        run_dmatrix_ctor(is_dqm=False, on_gpu=True)
@@ -188,12 +188,9 @@ def run_gpu_hist(

     # See note on `ObjFunction::UpdateTreeLeaf`.
     update_leaf = dataset.name.endswith("-l1")
-    if update_leaf and len(history) == 2:
+    if update_leaf:
         assert history[0] + 1e-2 >= history[-1]
         return
-    if update_leaf and len(history) > 2:
-        assert history[0] >= history[-1]
-        return
     else:
         assert tm.non_increasing(history)

@@ -37,7 +37,7 @@ def test_stack() -> None:
     assert b.shape == (2, 1)


-def run_dmatrix_ctor(is_dqm: bool) -> None:
+def run_dmatrix_ctor(is_dqm: bool, on_gpu: bool) -> None:
     rng = np.random.default_rng(0)
     dfs: List[pd.DataFrame] = []
     n_features = 16
@@ -57,7 +57,7 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:
         df = pd.DataFrame(
             {alias.label: y, alias.margin: m, alias.weight: w, alias.valid: valid}
         )
-        if is_dqm:
+        if on_gpu:
             for j in range(X.shape[1]):
                 df[f"feat-{j}"] = pd.Series(X[:, j])
         else:
@@ -65,14 +65,18 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:
         dfs.append(df)

     kwargs = {"feature_types": feature_types}
-    if is_dqm:
+    if on_gpu:
         cols = [f"feat-{i}" for i in range(n_features)]
         train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), cols, 0, kwargs, False, True
+            iter(dfs), cols, 0, is_dqm, kwargs, False, True
         )
+    elif is_dqm:
+        train_Xy, valid_Xy = create_dmatrix_from_partitions(
+            iter(dfs), None, None, True, kwargs, False, True
+        )
     else:
         train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), None, None, kwargs, False, True
+            iter(dfs), None, None, False, kwargs, False, True
         )

     assert valid_Xy is not None
@@ -106,7 +110,8 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:


 def test_dmatrix_ctor() -> None:
-    run_dmatrix_ctor(False)
+    run_dmatrix_ctor(is_dqm=False, on_gpu=False)
+    run_dmatrix_ctor(is_dqm=True, on_gpu=False)


 def test_read_csr_matrix_from_unwrapped_spark_vec() -> None:
@@ -1047,67 +1047,79 @@ class XgboostLocalTest(SparkTestCase):
         for row in pred_result:
             assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)

-    def test_empty_validation_data(self):
-        df_train = self.session.createDataFrame(
-            [
-                (Vectors.dense(10.1, 11.2, 11.3), 0, False),
-                (Vectors.dense(1, 1.2, 1.3), 1, False),
-                (Vectors.dense(14.0, 15.0, 16.0), 0, False),
-                (Vectors.dense(1.1, 1.2, 1.3), 1, True),
-            ],
-            ["features", "label", "val_col"],
-        )
-        classifier = SparkXGBClassifier(
-            num_workers=2,
-            min_child_weight=0.0,
-            reg_alpha=0,
-            reg_lambda=0,
-            validation_indicator_col="val_col",
-        )
-        model = classifier.fit(df_train)
-        pred_result = model.transform(df_train).collect()
-        for row in pred_result:
-            self.assertEqual(row.prediction, row.label)
+    def test_empty_validation_data(self) -> None:
+        for tree_method in [
+            "hist",
+            "approx",
+        ]:  # pytest.mark conflict with python unittest
+            df_train = self.session.createDataFrame(
+                [
+                    (Vectors.dense(10.1, 11.2, 11.3), 0, False),
+                    (Vectors.dense(1, 1.2, 1.3), 1, False),
+                    (Vectors.dense(14.0, 15.0, 16.0), 0, False),
+                    (Vectors.dense(1.1, 1.2, 1.3), 1, True),
+                ],
+                ["features", "label", "val_col"],
+            )
+            classifier = SparkXGBClassifier(
+                num_workers=2,
+                tree_method=tree_method,
+                min_child_weight=0.0,
+                reg_alpha=0,
+                reg_lambda=0,
+                validation_indicator_col="val_col",
+            )
+            model = classifier.fit(df_train)
+            pred_result = model.transform(df_train).collect()
+            for row in pred_result:
+                self.assertEqual(row.prediction, row.label)

-    def test_empty_train_data(self):
-        df_train = self.session.createDataFrame(
-            [
-                (Vectors.dense(10.1, 11.2, 11.3), 0, True),
-                (Vectors.dense(1, 1.2, 1.3), 1, True),
-                (Vectors.dense(14.0, 15.0, 16.0), 0, True),
-                (Vectors.dense(1.1, 1.2, 1.3), 1, False),
-            ],
-            ["features", "label", "val_col"],
-        )
-        classifier = SparkXGBClassifier(
-            num_workers=2,
-            min_child_weight=0.0,
-            reg_alpha=0,
-            reg_lambda=0,
-            validation_indicator_col="val_col",
-        )
-        model = classifier.fit(df_train)
-        pred_result = model.transform(df_train).collect()
-        for row in pred_result:
-            self.assertEqual(row.prediction, 1.0)
+    def test_empty_train_data(self) -> None:
+        for tree_method in [
+            "hist",
+            "approx",
+        ]:  # pytest.mark conflict with python unittest
+            df_train = self.session.createDataFrame(
+                [
+                    (Vectors.dense(10.1, 11.2, 11.3), 0, True),
+                    (Vectors.dense(1, 1.2, 1.3), 1, True),
+                    (Vectors.dense(14.0, 15.0, 16.0), 0, True),
+                    (Vectors.dense(1.1, 1.2, 1.3), 1, False),
+                ],
+                ["features", "label", "val_col"],
+            )
+            classifier = SparkXGBClassifier(
+                num_workers=2,
+                min_child_weight=0.0,
+                reg_alpha=0,
+                reg_lambda=0,
+                tree_method=tree_method,
+                validation_indicator_col="val_col",
+            )
+            model = classifier.fit(df_train)
+            pred_result = model.transform(df_train).collect()
+            for row in pred_result:
+                assert row.prediction == 1.0

     def test_empty_partition(self):
         # raw_df.repartition(4) will result int severe data skew, actually,
         # there is no any data in reducer partition 1, reducer partition 2
         # see https://github.com/dmlc/xgboost/issues/8221
-        raw_df = self.session.range(0, 100, 1, 50).withColumn(
-            "label", spark_sql_func.when(spark_sql_func.rand(1) > 0.5, 1).otherwise(0)
-        )
-        vector_assembler = (
-            VectorAssembler().setInputCols(["id"]).setOutputCol("features")
-        )
-        data_trans = vector_assembler.setHandleInvalid("keep").transform(raw_df)
-        data_trans.show(100)
+        for tree_method in [
+            "hist",
+            "approx",
+        ]:  # pytest.mark conflict with python unittest
+            raw_df = self.session.range(0, 100, 1, 50).withColumn(
+                "label",
+                spark_sql_func.when(spark_sql_func.rand(1) > 0.5, 1).otherwise(0),
+            )
+            vector_assembler = (
+                VectorAssembler().setInputCols(["id"]).setOutputCol("features")
+            )
+            data_trans = vector_assembler.setHandleInvalid("keep").transform(raw_df)

-        classifier = SparkXGBClassifier(
-            num_workers=4,
-        )
-        classifier.fit(data_trans)
+            classifier = SparkXGBClassifier(num_workers=4, tree_method=tree_method)
+            classifier.fit(data_trans)

     def test_early_stop_param_validation(self):
         classifier = SparkXGBClassifier(early_stopping_rounds=1)
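The tests iterate over `tree_method` with a plain `for` loop because, as the inline comments note, `pytest.mark.parametrize` does not compose with `unittest.TestCase` methods. Outside of unittest the same coverage could be written as (illustrative only):

```python
import pytest


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_empty_validation_data(tree_method: str) -> None:
    # Same body as the unittest method above, with tree_method injected by pytest.
    ...
```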