[pyspark] Make QDM optional based on cuDF check (#8471)

This commit is contained in:
WeichenXu 2022-11-27 14:58:54 +08:00 committed by GitHub
parent 8f97c92541
commit 67ea1c3435
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 1 deletions

View File

@ -43,6 +43,7 @@ except ImportError:
pandas_concat = None
PANDAS_INSTALLED = False
# sklearn
try:
from sklearn.base import BaseEstimator as XGBModelBase
@ -72,6 +73,22 @@ except ImportError:
XGBStratifiedKFold = None
_logger = logging.getLogger(__name__)
def is_cudf_available() -> bool:
"""Check cuDF package available or not"""
if importlib.util.find_spec("cudf") is None:
return False
try:
import cudf
return True
except ImportError:
_logger.exception("Importing cuDF failed, use DMatrix instead of QDM")
return False
class XGBoostLabelEncoder(LabelEncoder):
"""Label encoder with JSON serialization methods."""

View File

@ -32,6 +32,7 @@ from pyspark.sql.types import (
ShortType,
)
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
from xgboost.compat import is_cudf_available
from xgboost.core import Booster
from xgboost.training import train as worker_train
@ -755,7 +756,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
k: v for k, v in train_call_kwargs_params.items() if v is not None
}
dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
use_qdm = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
def _train_booster(pandas_df_iter):
"""Takes in an RDD partition and outputs a booster for that partition after
@ -769,6 +771,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
gpu_id = None
# If cuDF is not installed, then using DMatrix instead of QDM,
# because without cuDF, DMatrix performs better than QDM.
# Note: Checking `is_cudf_available` in spark worker side because
# spark worker might has different python environment with driver side.
if use_gpu:
use_qdm = use_hist and is_cudf_available()
else:
use_qdm = use_hist
if use_qdm and (booster_params.get("max_bin", None) is not None):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]