diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index caa6e3cd0..1b42411f6 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -60,7 +60,6 @@ from .params import (
 )
 from .utils import (
     CommunicatorContext,
-    _get_args_from_message_list,
     _get_default_params_from_func,
     _get_gpu_id,
     _get_max_num_concurrent_tasks,
@@ -771,15 +770,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             from pyspark import BarrierTaskContext
 
             context = BarrierTaskContext.get()
-            context.barrier()
             gpu_id = None
-
-            # If cuDF is not installed, then using DMatrix instead of QDM,
-            # because without cuDF, DMatrix performs better than QDM.
-            # Note: Checking `is_cudf_available` in spark worker side because
-            # spark worker might has different python environment with driver side.
             if use_gpu:
+                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
+                booster_params["gpu_id"] = gpu_id
+                # If cuDF is not installed, use DMatrix instead of QDM,
+                # because without cuDF, DMatrix performs better than QDM.
+                # Note: `is_cudf_available` is checked on the Spark worker side
+                # because a worker's Python environment may differ from the driver's.
                 use_qdm = use_hist and is_cudf_available()
             else:
                 use_qdm = use_hist
@@ -787,10 +786,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
 
-            if use_gpu:
-                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
-                booster_params["gpu_id"] = gpu_id
-
             _rabit_args = {}
             if context.partitionId() == 0:
                 get_logger("XGBoostPySpark").debug(
@@ -804,8 +799,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
 
                 _rabit_args = _get_rabit_args(context, num_workers)
 
-            messages = context.allGather(message=json.dumps(_rabit_args))
-            _rabit_args = _get_args_from_message_list(messages)
+            worker_message = {
+                "rabit_msg": _rabit_args,
+                "use_qdm": use_qdm,
+            }
+
+            messages = context.allGather(message=json.dumps(worker_message))
+            if len(set(json.loads(x)["use_qdm"] for x in messages)) != 1:
+                raise RuntimeError("The workers' cuDF environments are inconsistent")
+
+            _rabit_args = json.loads(messages[0])["rabit_msg"]
+
             evals_result = {}
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index 36705459a..189396089 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -1,10 +1,9 @@
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
-import json
 import logging
 import sys
 from threading import Thread
-from typing import Any, Callable, Dict, List, Set, Type
+from typing import Any, Callable, Dict, Set, Type
 
 import pyspark
 from pyspark import BarrierTaskContext, SparkContext
@@ -81,16 +80,6 @@ def _get_host_ip(context: BarrierTaskContext) -> str:
     return task_ip_list[0]
 
 
-def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
-    """A function to send/recieve messages in barrier context mode"""
-    output = ""
-    for message in messages:
-        if message != "":
-            output = message
-            break
-    return json.loads(output)
-
-
 def _get_spark_session() -> SparkSession:
     """Get or create spark session.
     Note: This function can only be invoked from driver side.
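
Review note: the patch above folds two concerns into the existing `allGather` round. Each worker now publishes both the tracker (rabit) arguments and its local `use_qdm` decision, so a cluster where only some workers have cuDF installed fails fast instead of silently training with a mix of `DMatrix` and `QuantileDMatrix`. Below is a minimal sketch of that message protocol, runnable without Spark; the `gathered` list stands in for the return value of `BarrierTaskContext.allGather()` (one JSON string per worker, ordered by partition id), and the helper names and tracker values are illustrative, not part of the patch.

```python
# Minimal sketch of the allGather message protocol used in the patch
# (no Spark needed). Assumptions: `gathered` stands in for the list that
# BarrierTaskContext.allGather() returns; build_worker_message and
# check_and_extract are hypothetical helpers, not functions from the patch.
import json
from typing import Any, Dict, List


def build_worker_message(rabit_args: Dict[str, Any], use_qdm: bool) -> str:
    """Serialize one worker's contribution to the allGather round."""
    return json.dumps({"rabit_msg": rabit_args, "use_qdm": use_qdm})


def check_and_extract(gathered: List[str]) -> Dict[str, Any]:
    """Fail fast if workers disagree on use_qdm, then return partition 0's
    tracker args, mirroring `_rabit_args = json.loads(messages[0])["rabit_msg"]`."""
    decoded = [json.loads(msg) for msg in gathered]
    if len({d["use_qdm"] for d in decoded}) != 1:
        raise RuntimeError("The workers' cuDF environments are inconsistent")
    return decoded[0]["rabit_msg"]


# Simulated three-worker cluster: partition 0 carries the tracker args
# (illustrative values); the others send an empty dict, as in the patch.
gathered = [
    build_worker_message(
        {"DMLC_TRACKER_URI": "10.0.0.1", "DMLC_TRACKER_PORT": 9091}, True
    ),
    build_worker_message({}, True),
    build_worker_message({}, True),
]
assert check_and_extract(gathered) == {
    "DMLC_TRACKER_URI": "10.0.0.1",
    "DMLC_TRACKER_PORT": 9091,
}
```

Taking `messages[0]` is safe because the gathered list is ordered by partition id, and only partition 0, which obtained the tracker args via `_get_rabit_args`, sends a non-empty `rabit_msg`; every other worker contributes `{}`. That is also what makes the removed `_get_args_from_message_list` scan unnecessary, and the explicit `context.barrier()` call can be dropped presumably because `allGather` is itself a synchronizing collective.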