[pyspark] Make QDM optional based on cuDF check (#8471)
This commit is contained in:
parent
8f97c92541
commit
67ea1c3435
@ -43,6 +43,7 @@ except ImportError:
|
||||
pandas_concat = None
|
||||
PANDAS_INSTALLED = False
|
||||
|
||||
|
||||
# sklearn
|
||||
try:
|
||||
from sklearn.base import BaseEstimator as XGBModelBase
|
||||
@ -72,6 +73,22 @@ except ImportError:
|
||||
XGBStratifiedKFold = None
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_cudf_available() -> bool:
|
||||
"""Check cuDF package available or not"""
|
||||
if importlib.util.find_spec("cudf") is None:
|
||||
return False
|
||||
try:
|
||||
import cudf
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
_logger.exception("Importing cuDF failed, use DMatrix instead of QDM")
|
||||
return False
|
||||
|
||||
|
||||
class XGBoostLabelEncoder(LabelEncoder):
|
||||
"""Label encoder with JSON serialization methods."""
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ from pyspark.sql.types import (
|
||||
ShortType,
|
||||
)
|
||||
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
|
||||
from xgboost.compat import is_cudf_available
|
||||
from xgboost.core import Booster
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
@ -755,7 +756,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
k: v for k, v in train_call_kwargs_params.items() if v is not None
|
||||
}
|
||||
dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
|
||||
use_qdm = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
|
||||
|
||||
use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
|
||||
|
||||
def _train_booster(pandas_df_iter):
|
||||
"""Takes in an RDD partition and outputs a booster for that partition after
|
||||
@ -769,6 +771,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
gpu_id = None
|
||||
|
||||
# If cuDF is not installed, then using DMatrix instead of QDM,
|
||||
# because without cuDF, DMatrix performs better than QDM.
|
||||
# Note: Checking `is_cudf_available` in spark worker side because
|
||||
# spark worker might has different python environment with driver side.
|
||||
if use_gpu:
|
||||
use_qdm = use_hist and is_cudf_available()
|
||||
else:
|
||||
use_qdm = use_hist
|
||||
|
||||
if use_qdm and (booster_params.get("max_bin", None) is not None):
|
||||
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user