Co-authored-by: WeichenXu <weichen.xu@databricks.com>
parent 60a8c8ebba
commit 59c54e361b
python-package/xgboost/compat.py
@@ -43,6 +43,7 @@ except ImportError:
     pandas_concat = None
     PANDAS_INSTALLED = False
 
+
 # sklearn
 try:
     from sklearn.base import BaseEstimator as XGBModelBase
@@ -72,6 +73,22 @@ except ImportError:
     XGBStratifiedKFold = None
 
 
+_logger = logging.getLogger(__name__)
+
+
+def is_cudf_available() -> bool:
+    """Check cuDF package available or not"""
+    if importlib.util.find_spec("cudf") is None:
+        return False
+    try:
+        import cudf
+
+        return True
+    except ImportError:
+        _logger.exception("Importing cuDF failed, use DMatrix instead of QDM")
+        return False
+
+
 class XGBoostLabelEncoder(LabelEncoder):
     """Label encoder with JSON serialization methods."""
 
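For context, a minimal usage sketch of the new helper (illustrative only, not part of the commit; make_matrix is a hypothetical name, while DMatrix and QuantileDMatrix are the existing xgboost classes):

    from xgboost import DMatrix, QuantileDMatrix
    from xgboost.compat import is_cudf_available

    def make_matrix(data, label, max_bin=256):
        # Prefer QuantileDMatrix only when cuDF can actually be imported;
        # otherwise fall back to a plain DMatrix, mirroring the commit's intent.
        if is_cudf_available():
            return QuantileDMatrix(data, label=label, max_bin=max_bin)
        return DMatrix(data, label=label)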
python-package/xgboost/spark/core.py
@@ -32,6 +32,7 @@ from pyspark.sql.types import (
     ShortType,
 )
 from scipy.special import expit, softmax  # pylint: disable=no-name-in-module
+from xgboost.compat import is_cudf_available
 from xgboost.core import Booster
 from xgboost.training import train as worker_train
 
@@ -759,7 +760,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             k: v for k, v in train_call_kwargs_params.items() if v is not None
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
-        use_qdm = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+
+        use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
 
         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
@@ -773,6 +775,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
 
             gpu_id = None
 
+            # If cuDF is not installed, use DMatrix instead of QDM,
+            # because without cuDF, DMatrix performs better than QDM.
+            # Note: `is_cudf_available` is checked on the Spark worker side because
+            # the worker may have a different Python environment than the driver.
+            if use_gpu:
+                use_qdm = use_hist and is_cudf_available()
+            else:
+                use_qdm = use_hist
+
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
 
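The net effect of the hunks above, written out as a standalone decision helper (a sketch for illustration; decide_use_qdm is not a function in the codebase):

    from xgboost.compat import is_cudf_available

    def decide_use_qdm(tree_method, use_gpu):
        # QDM is only used with hist-based tree methods; on GPU it additionally
        # requires cuDF, since without cuDF a plain DMatrix performs better.
        use_hist = tree_method in ("hist", "gpu_hist")
        if use_gpu:
            return use_hist and is_cudf_available()
        return use_hist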