[pyspark] check use_qdm across all the workers (#8496)
commit 40a1a2ffa8 (parent 5aeb8f7009)

The diff moves the gpu_id assignment into the `if use_gpu:` branch where use_qdm is decided, has every worker broadcast its locally computed use_qdm flag together with the rabit args through BarrierTaskContext.allGather, raises a RuntimeError when the workers disagree (i.e. their cuDF environments differ), and deletes the now-unused _get_args_from_message_list helper from the utils submodule.
python-package/xgboost/spark/core.py

@@ -60,7 +60,6 @@ from .params import (
)
from .utils import (
    CommunicatorContext,
    _get_args_from_message_list,
    _get_default_params_from_func,
    _get_gpu_id,
    _get_max_num_concurrent_tasks,
@@ -771,15 +770,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            from pyspark import BarrierTaskContext

            context = BarrierTaskContext.get()
            context.barrier()

            gpu_id = None

            # If cuDF is not installed, then using DMatrix instead of QDM,
            # because without cuDF, DMatrix performs better than QDM.
            # Note: Checking `is_cudf_available` in spark worker side because
            # spark worker might has different python environment with driver side.
            if use_gpu:
                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                booster_params["gpu_id"] = gpu_id
                # If cuDF is not installed, then using DMatrix instead of QDM,
                # because without cuDF, DMatrix performs better than QDM.
                # Note: Checking `is_cudf_available` in spark worker side because
                # spark worker might has different python environment with driver side.
                use_qdm = use_hist and is_cudf_available()
            else:
                use_qdm = use_hist
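The branch above decides per worker whether to build a QuantileDMatrix, and the comment explains why the cuDF probe has to run on the executor rather than on the driver: each executor may have its own Python environment. Below is a minimal sketch of that kind of probe; cudf_installed is a hypothetical stand-in, not the helper the commit actually imports.

import importlib.util

def cudf_installed() -> bool:
    """Return True if cuDF can be imported in the current interpreter."""
    return importlib.util.find_spec("cudf") is not None

use_gpu = True
use_hist = True  # e.g. tree_method is "hist" or "gpu_hist"
if use_gpu:
    # Fall back to DMatrix when cuDF is missing, mirroring the logic above.
    use_qdm = use_hist and cudf_installed()
else:
    use_qdm = use_hist
print(use_qdm)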
@@ -787,10 +786,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            if use_qdm and (booster_params.get("max_bin", None) is not None):
                dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

            if use_gpu:
                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                booster_params["gpu_id"] = gpu_id

            _rabit_args = {}
            if context.partitionId() == 0:
                get_logger("XGBoostPySpark").debug(
@@ -804,8 +799,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

                _rabit_args = _get_rabit_args(context, num_workers)

            messages = context.allGather(message=json.dumps(_rabit_args))
            _rabit_args = _get_args_from_message_list(messages)
            worker_message = {
                "rabit_msg": _rabit_args,
                "use_qdm": use_qdm,
            }

            messages = context.allGather(message=json.dumps(worker_message))
            if len(set(json.loads(x)["use_qdm"] for x in messages)) != 1:
                raise RuntimeError("The workers' cudf environments are in-consistent ")

            _rabit_args = json.loads(messages[0])["rabit_msg"]

            evals_result = {}
            with CommunicatorContext(context, **_rabit_args):
                dtrain, dvalid = create_dmatrix_from_partitions(
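The agreement check introduced above can be exercised outside Spark. Below is a minimal sketch: every worker contributes a JSON payload carrying its locally computed use_qdm flag, and training only proceeds when all flags match. The helper name check_use_qdm_consistent and the message contents are illustrative, not part of the commit.

import json

def check_use_qdm_consistent(messages):
    """Return the shared rabit args if all workers agree on `use_qdm`."""
    flags = {json.loads(m)["use_qdm"] for m in messages}
    if len(flags) != 1:
        raise RuntimeError("The workers' cudf environments are in-consistent ")
    # allGather returns one message per task; partition 0 carries the rabit args.
    return json.loads(messages[0])["rabit_msg"]

# Two workers agree on use_qdm, so the rabit args from partition 0 are returned.
msgs = [
    json.dumps({"rabit_msg": {"DMLC_NUM_WORKER": 2}, "use_qdm": True}),
    json.dumps({"rabit_msg": {}, "use_qdm": True}),
]
print(check_use_qdm_consistent(msgs))  # {'DMLC_NUM_WORKER': 2}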
python-package/xgboost/spark/utils.py

@@ -1,10 +1,9 @@
"""Xgboost pyspark integration submodule for helper functions."""
import inspect
import json
import logging
import sys
from threading import Thread
from typing import Any, Callable, Dict, List, Set, Type
from typing import Any, Callable, Dict, Set, Type

import pyspark
from pyspark import BarrierTaskContext, SparkContext
@@ -81,16 +80,6 @@ def _get_host_ip(context: BarrierTaskContext) -> str:
    return task_ip_list[0]


def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
    """A function to send/recieve messages in barrier context mode"""
    output = ""
    for message in messages:
        if message != "":
            output = message
            break
    return json.loads(output)


def _get_spark_session() -> SparkSession:
    """Get or create spark session. Note: This function can only be invoked from driver
    side.
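For context on the deletion above: the removed helper scanned the allGather results for the first non-empty message, whereas the new code in core.py reads the rabit args from the first gathered message (messages[0]) and unpacks its "rabit_msg" field. Below is a minimal sketch contrasting the two, with made-up messages; old_style mirrors the deleted function and new_style mirrors the replacement.

import json

def old_style(messages):
    # Behaviour of the removed _get_args_from_message_list.
    output = ""
    for message in messages:
        if message != "":
            output = message
            break
    return json.loads(output)

def new_style(messages):
    # Behaviour of the replacement logic in core.py.
    return json.loads(messages[0])["rabit_msg"]

print(old_style([json.dumps({"DMLC_NUM_WORKER": 2}), ""]))
print(new_style([json.dumps({"rabit_msg": {"DMLC_NUM_WORKER": 2}, "use_qdm": True})]))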