[pyspark] check use_qdm across all the workers (#8496)

Bobby Wang, 2022-12-08 18:09:17 +08:00, committed by GitHub
parent 5aeb8f7009
commit 40a1a2ffa8
2 changed files with 18 additions and 25 deletions

python-package/xgboost/spark/core.py

@@ -60,7 +60,6 @@ from .params import (
 )
 from .utils import (
     CommunicatorContext,
-    _get_args_from_message_list,
     _get_default_params_from_func,
     _get_gpu_id,
     _get_max_num_concurrent_tasks,
@@ -771,15 +770,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             from pyspark import BarrierTaskContext

             context = BarrierTaskContext.get()
             context.barrier()

-            gpu_id = None
-
-            # If cuDF is not installed, then using DMatrix instead of QDM,
-            # because without cuDF, DMatrix performs better than QDM.
-            # Note: Checking `is_cudf_available` in spark worker side because
-            # spark worker might has different python environment with driver side.
             if use_gpu:
+                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
+                booster_params["gpu_id"] = gpu_id
+                # If cuDF is not installed, then using DMatrix instead of QDM,
+                # because without cuDF, DMatrix performs better than QDM.
+                # Note: Checking `is_cudf_available` in spark worker side because
+                # spark worker might has different python environment with driver side.
                 use_qdm = use_hist and is_cudf_available()
             else:
                 use_qdm = use_hist
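
The comment in the hunk above is the heart of the driver-vs-worker split: a worker may lack cuDF even when the driver has it, so the probe has to run inside the barrier task. As a rough sketch of what such a probe looks like (xgboost's real is_cudf_available helper may differ in detail, so treat this as an assumption):

def is_cudf_available() -> bool:
    # Runs in the Spark worker process; the worker's Python environment
    # can differ from the driver's, so checking on the driver proves nothing.
    try:
        import cudf  # noqa: F401
    except ImportError:
        return False
    return True
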
@@ -787,10 +786,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

-            if use_gpu:
-                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
-                booster_params["gpu_id"] = gpu_id
-
             _rabit_args = {}
             if context.partitionId() == 0:
                 get_logger("XGBoostPySpark").debug(
@@ -804,8 +799,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 _rabit_args = _get_rabit_args(context, num_workers)

-            messages = context.allGather(message=json.dumps(_rabit_args))
-            _rabit_args = _get_args_from_message_list(messages)
+            worker_message = {
+                "rabit_msg": _rabit_args,
+                "use_qdm": use_qdm,
+            }
+
+            messages = context.allGather(message=json.dumps(worker_message))
+
+            if len(set(json.loads(x)["use_qdm"] for x in messages)) != 1:
+                raise RuntimeError("The workers' cudf environments are in-consistent ")
+
+            _rabit_args = json.loads(messages[0])["rabit_msg"]

             evals_result = {}
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
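
Taken together, the change piggybacks each worker's use_qdm decision onto the rabit-args exchange that already required an allGather. A self-contained sketch of that merge-and-verify step (plain Python; the gathered list stands in for the return value of BarrierTaskContext.allGather, and the tracker key is illustrative):

import json
from typing import Any, Dict, List

def merge_worker_messages(gathered: List[str]) -> Dict[str, Any]:
    # Every worker must have reached the same use_qdm decision; a mixed
    # set means the cuDF installations differ across executors.
    if len({json.loads(msg)["use_qdm"] for msg in gathered}) != 1:
        raise RuntimeError("The workers' cudf environments are in-consistent ")
    # Only worker 0 fills in the rabit args; the code relies on allGather
    # returning messages in partition order, so index 0 is worker 0's payload.
    return json.loads(gathered[0])["rabit_msg"]

# Three workers: worker 0 carries the tracker info, the rest send stubs.
gathered = [
    json.dumps({"rabit_msg": {"DMLC_TRACKER_URI": "10.0.0.1"}, "use_qdm": True}),
    json.dumps({"rabit_msg": {}, "use_qdm": True}),
    json.dumps({"rabit_msg": {}, "use_qdm": True}),
]
assert merge_worker_messages(gathered) == {"DMLC_TRACKER_URI": "10.0.0.1"}
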

python-package/xgboost/spark/utils.py

@@ -1,10 +1,9 @@
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
-import json
 import logging
 import sys
 from threading import Thread
-from typing import Any, Callable, Dict, List, Set, Type
+from typing import Any, Callable, Dict, Set, Type

 import pyspark
 from pyspark import BarrierTaskContext, SparkContext
@@ -81,16 +80,6 @@ def _get_host_ip(context: BarrierTaskContext) -> str:
     return task_ip_list[0]

-
-def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
-    """A function to send/recieve messages in barrier context mode"""
-    output = ""
-    for message in messages:
-        if message != "":
-            output = message
-            break
-
-    return json.loads(output)

 def _get_spark_session() -> SparkSession:
     """Get or create spark session. Note: This function can only be invoked from driver
     side.
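
A note on why the helper could be dropped rather than relocated: every worker already sent json.dumps(_rabit_args), which is at least the two-character string "{}", so the first-non-empty scan always resolved to the first message anyway. The inline json.loads(messages[0])["rabit_msg"] in core.py preserves that behavior. A quick check of the equivalence (pure Python; the tracker key is illustrative):

import json

def _get_args_from_message_list(messages):
    # The removed helper: return the first non-empty message, decoded.
    output = ""
    for message in messages:
        if message != "":
            output = message
            break
    return json.loads(output)

# Worker 0 sends real args; the others serialize their empty dicts as "{}".
messages = ['{"DMLC_TRACKER_URI": "10.0.0.1"}', "{}", "{}"]
assert _get_args_from_message_list(messages) == json.loads(messages[0])
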