[pyspark] Cleanup data processing. (#8344)
* Enable additional combinations of ctor parameters.
* Unify procedures for QuantileDMatrix and DMatrix.
This commit is contained in:
parent
521086d56b
commit
3901f5d9db
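The unification is visible in the test changes below: both matrix flavors now go through one call. A rough sketch of that entry point (the `make_matrices` wrapper is hypothetical; only the `create_dmatrix_from_partitions` signature comes from this commit's tests):

```python
from typing import Iterator, Optional, Tuple

import pandas as pd
from xgboost import DMatrix
from xgboost.spark.data import create_dmatrix_from_partitions


def make_matrices(
    parts: Iterator[pd.DataFrame], use_qdm: bool
) -> Tuple[DMatrix, Optional[DMatrix]]:
    # feature_cols=None means features arrive stacked in a single column;
    # use_qdm alone decides between DMatrix and QuantileDMatrix construction.
    return create_dmatrix_from_partitions(
        parts,
        None,  # feature_cols; set when the rapids plugin supplies columns
        gpu_id=None,
        use_qdm=use_qdm,
        kwargs={},
        enable_sparse_data_optim=False,
        has_validation_col=False,
    )
```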
@@ -83,10 +83,11 @@ generate result dataset with 3 new columns:
 XGBoost PySpark GPU support
 ***************************
 
-XGBoost PySpark supports GPU training and prediction. To enable GPU support, you first need
-to install the xgboost and cudf packages. Then you can set `use_gpu` parameter to `True`.
+XGBoost PySpark supports GPU training and prediction. To enable GPU support, you first
+need to install XGBoost and the `cuDF <https://docs.rapids.ai/api/cudf/stable/>`_
+package. Then you can set the `use_gpu` parameter to `True`.
 
-Below tutorial will show you how to train a model with XGBoost PySpark GPU on Spark
+The tutorial below demonstrates how to train a model with XGBoost PySpark GPU on a Spark
 standalone cluster.
 
 
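For context, a minimal end-to-end sketch of the documented GPU flow, assuming the `xgboost.spark` estimator API; the toy DataFrame and parameter values are illustrative:

```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.getOrCreate()
# Toy data; any numeric feature columns work.
raw_df = spark.createDataFrame([(1.0, 2.0, 0), (3.0, 4.0, 1)], ["f0", "f1", "label"])
assembler = VectorAssembler(inputCols=["f0", "f1"], outputCol="features")
train_df = assembler.transform(raw_df)

clf = SparkXGBClassifier(
    features_col="features",
    label_col="label",
    use_gpu=True,  # requires xgboost and cudf in the executor environment
    num_workers=1,
)
model = clf.fit(train_df)
preds = model.transform(train_df)
```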
@@ -138,7 +139,7 @@ in PySpark. Please refer to
 conda create -y -n xgboost-env -c conda-forge conda-pack python=3.9
 conda activate xgboost-env
 pip install xgboost
-pip install cudf
+conda install cudf -c rapids -c nvidia -c conda-forge
 conda pack -f -o xgboost-env.tar.gz
 
 
@@ -220,3 +221,6 @@ Below is a simple example submit command for enabling GPU acceleration:
 --conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
 --archives xgboost-env.tar.gz#environment \
 xgboost_app.py
+
+When the RAPIDS plugin is enabled, both the JVM RAPIDS plugin and the cuDF Python
+package are required for the acceleration.
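A small, hedged sanity check (not part of the commit) that the Python half of these requirements is importable in the packed environment:

```python
import importlib.util

# Confirm both GPU-side Python dependencies are present before submitting.
for mod in ("xgboost", "cudf"):
    if importlib.util.find_spec(mod) is None:
        raise RuntimeError(f"{mod} is missing; GPU acceleration cannot be used")
```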
@@ -747,7 +747,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             k: v for k, v in train_call_kwargs_params.items() if v is not None
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
-        use_qdm = booster_params.get("tree_method") in ("hist", "gpu_hist")
+        use_qdm = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
 
         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
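The changed line keys QuantileDMatrix usage off the tree method alone. A tiny standalone illustration of the predicate (not from the source):

```python
# Only the hist-family tree methods consume a QuantileDMatrix; other
# methods fall back to a plain DMatrix.
for tm in ("hist", "gpu_hist", "approx", "exact", None):
    use_qdm = {"tree_method": tm}.get("tree_method", None) in ("hist", "gpu_hist")
    print(tm, "->", "QuantileDMatrix" if use_qdm else "DMatrix")
```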
@@ -208,14 +208,27 @@ def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments
 
     def append_m(part: pd.DataFrame, name: str, is_valid: bool) -> None:
         nonlocal n_features
-        if name in part.columns and part[name].shape[0] > 0:
-            array = part[name]
-            if name == alias.data:
+        if name == alias.data or name in part.columns:
+            if (
+                name == alias.data
+                and feature_cols is not None
+                and part[feature_cols].shape[0] > 0  # guard against empty partition
+            ):
+                array: Optional[np.ndarray] = part[feature_cols]
+            elif part[name].shape[0] > 0:
+                array = part[name]
                 array = stack_series(array)
+            else:
+                array = None
+
+            if name == alias.data and array is not None:
                 if n_features == 0:
                     n_features = array.shape[1]
                 assert n_features == array.shape[1]
 
+            if array is None:
+                return
+
             if is_valid:
                 valid_data[name].append(array)
             else:
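The `stack_series` call above is what turns a column of per-row feature vectors into a matrix. A hedged NumPy equivalent of that step:

```python
import numpy as np
import pandas as pd

# Each cell of the data column holds one row's feature vector; stacking
# them yields an (n_rows, n_features) ndarray suitable for DMatrix input.
s = pd.Series([np.array([1.0, 2.0]), np.array([3.0, 4.0])])
X = np.stack(s.to_numpy())
assert X.shape == (2, 2)
```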
@@ -238,26 +251,6 @@ def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments
             else:
                 train_data[name].append(array)
 
-    def append_qdm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
-        """Preprocessing for QuantileDMatrix."""
-        nonlocal n_features
-        if name == alias.data or name in part.columns:
-            if name == alias.data and feature_cols is not None:
-                array = part[feature_cols]
-            else:
-                array = part[name]
-                array = stack_series(array)
-
-            if name == alias.data:
-                if n_features == 0:
-                    n_features = array.shape[1]
-                assert n_features == array.shape[1]
-
-            if is_valid:
-                valid_data[name].append(array)
-            else:
-                train_data[name].append(array)
-
     def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix:
         if len(values) == 0:
             get_logger("XGBoostPySpark").warning(
@@ -305,13 +298,14 @@ def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments
 
     meta, params = split_params()
 
-    if feature_cols is not None:  # rapidsai plugin
-        assert gpu_id is not None
-        assert use_qdm is True
-        cache_partitions(iterator, append_qdm)
+    if feature_cols is not None and use_qdm:
+        cache_partitions(iterator, append_fn)
         dtrain: DMatrix = make_qdm(train_data, gpu_id, meta, None, params)
-    elif use_qdm:
-        cache_partitions(iterator, append_qdm)
+    elif feature_cols is not None and not use_qdm:
+        cache_partitions(iterator, append_fn)
+        dtrain = make(train_data, kwargs)
+    elif feature_cols is None and use_qdm:
+        cache_partitions(iterator, append_fn)
         dtrain = make_qdm(train_data, gpu_id, meta, None, params)
     else:
         cache_partitions(iterator, append_fn)
@@ -19,7 +19,9 @@ from test_spark.test_data import run_dmatrix_ctor
 
 
 @pytest.mark.skipif(**tm.no_cudf())
-def test_qdm_ctor() -> None:
-    run_dmatrix_ctor(is_dqm=True, on_gpu=True)
-    with pytest.raises(AssertionError):
-        run_dmatrix_ctor(is_dqm=False, on_gpu=True)
+@pytest.mark.parametrize(
+    "is_feature_cols,is_qdm",
+    [(True, True), (True, False), (False, True), (False, False)],
+)
+def test_dmatrix_ctor(is_feature_cols: bool, is_qdm: bool) -> None:
+    run_dmatrix_ctor(is_feature_cols, is_qdm, on_gpu=True)
@@ -18,6 +18,8 @@ from xgboost.spark.data import (
     stack_series,
 )
 
+from xgboost import DMatrix, QuantileDMatrix
+
 
 def test_stack() -> None:
     a = pd.DataFrame({"a": [[1, 2], [3, 4]]})
@@ -37,7 +39,7 @@ def test_stack() -> None:
     assert b.shape == (2, 1)
 
 
-def run_dmatrix_ctor(is_dqm: bool, on_gpu: bool) -> None:
+def run_dmatrix_ctor(is_feature_cols: bool, is_qdm: bool, on_gpu: bool) -> None:
     rng = np.random.default_rng(0)
     dfs: List[pd.DataFrame] = []
     n_features = 16
@@ -57,7 +59,7 @@ def run_dmatrix_ctor(is_dqm: bool, on_gpu: bool) -> None:
         df = pd.DataFrame(
             {alias.label: y, alias.margin: m, alias.weight: w, alias.valid: valid}
         )
-        if on_gpu:
+        if is_feature_cols:
             for j in range(X.shape[1]):
                 df[f"feat-{j}"] = pd.Series(X[:, j])
         else:
@@ -65,19 +67,27 @@ def run_dmatrix_ctor(is_dqm: bool, on_gpu: bool) -> None:
         dfs.append(df)
 
     kwargs = {"feature_types": feature_types}
-    if on_gpu:
-        cols = [f"feat-{i}" for i in range(n_features)]
-        train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), cols, 0, is_dqm, kwargs, False, True
-        )
-    elif is_dqm:
-        train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), None, None, True, kwargs, False, True
-        )
+    device_id = 0 if on_gpu else None
+    cols = [f"feat-{i}" for i in range(n_features)]
+    feature_cols = cols if is_feature_cols else None
+    train_Xy, valid_Xy = create_dmatrix_from_partitions(
+        iter(dfs),
+        feature_cols,
+        gpu_id=device_id,
+        use_qdm=is_qdm,
+        kwargs=kwargs,
+        enable_sparse_data_optim=False,
+        has_validation_col=True,
+    )
+
+    if is_qdm:
+        assert isinstance(train_Xy, QuantileDMatrix)
+        assert isinstance(valid_Xy, QuantileDMatrix)
     else:
-        train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), None, None, False, kwargs, False, True
-        )
+        assert not isinstance(train_Xy, QuantileDMatrix)
+        assert isinstance(train_Xy, DMatrix)
+        assert not isinstance(valid_Xy, QuantileDMatrix)
+        assert isinstance(valid_Xy, DMatrix)
 
     assert valid_Xy is not None
     assert valid_Xy.num_row() + train_Xy.num_row() == n_samples_per_batch * n_batches
@@ -109,9 +119,12 @@ def run_dmatrix_ctor(is_dqm: bool, on_gpu: bool) -> None:
     np.testing.assert_equal(valid_Xy.feature_types, feature_types)
 
 
-def test_dmatrix_ctor() -> None:
-    run_dmatrix_ctor(is_dqm=False, on_gpu=False)
-    run_dmatrix_ctor(is_dqm=True, on_gpu=False)
+@pytest.mark.parametrize(
+    "is_feature_cols,is_qdm",
+    [(True, True), (True, False), (False, True), (False, False)],
+)
+def test_dmatrix_ctor(is_feature_cols: bool, is_qdm: bool) -> None:
+    run_dmatrix_ctor(is_feature_cols, is_qdm, on_gpu=False)
 
 
 def test_read_csr_matrix_from_unwrapped_spark_vec() -> None:
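The parametrized tests cover all four `(is_feature_cols, is_qdm)` combinations on both CPU and GPU. A hedged convenience runner (the `-k` expression is an assumption about the test naming):

```python
import pytest

# Exercise the dmatrix ctor tests added by the parametrization above.
if __name__ == "__main__":
    raise SystemExit(pytest.main(["-v", "-k", "dmatrix_ctor"]))
```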