[pyspark] check use_qdm across all the workers (#8496)

Bobby Wang 2022-12-08 18:09:17 +08:00 committed by GitHub
parent 5aeb8f7009
commit 40a1a2ffa8
2 changed files with 18 additions and 25 deletions

python-package/xgboost/spark/core.py

@@ -60,7 +60,6 @@ from .params import (
 )
 from .utils import (
     CommunicatorContext,
-    _get_args_from_message_list,
     _get_default_params_from_func,
     _get_gpu_id,
     _get_max_num_concurrent_tasks,
@@ -771,15 +770,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             from pyspark import BarrierTaskContext

             context = BarrierTaskContext.get()
-            context.barrier()
             gpu_id = None

-            # If cuDF is not installed, then using DMatrix instead of QDM,
-            # because without cuDF, DMatrix performs better than QDM.
-            # Note: Checking `is_cudf_available` in spark worker side because
-            # spark worker might has different python environment with driver side.
             if use_gpu:
+                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
+                booster_params["gpu_id"] = gpu_id
+
+                # If cuDF is not installed, then using DMatrix instead of QDM,
+                # because without cuDF, DMatrix performs better than QDM.
+                # Note: Checking `is_cudf_available` in spark worker side because
+                # spark worker might has different python environment with driver side.
                 use_qdm = use_hist and is_cudf_available()
             else:
                 use_qdm = use_hist
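The hunk above decides `use_qdm` on each worker rather than on the driver: QuantileDMatrix only pays off when cuDF is importable, and every executor may run its own Python environment, so the probe has to happen worker-side. A minimal sketch of what such a probe typically looks like; the name `cudf_is_importable` is illustrative, the real helper is the `is_cudf_available` referenced in the diff:

```python
def cudf_is_importable() -> bool:
    """Return True if cuDF can be imported in this worker's Python environment."""
    try:
        import cudf  # noqa: F401  # only the import success matters here
    except ImportError:
        return False
    return True
```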
@@ -787,10 +786,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

-            if use_gpu:
-                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
-                booster_params["gpu_id"] = gpu_id
-
             _rabit_args = {}
             if context.partitionId() == 0:
                 get_logger("XGBoostPySpark").debug(
@@ -804,8 +799,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 _rabit_args = _get_rabit_args(context, num_workers)

-            messages = context.allGather(message=json.dumps(_rabit_args))
-            _rabit_args = _get_args_from_message_list(messages)
+            worker_message = {
+                "rabit_msg": _rabit_args,
+                "use_qdm": use_qdm,
+            }
+            messages = context.allGather(message=json.dumps(worker_message))
+            if len(set(json.loads(x)["use_qdm"] for x in messages)) != 1:
+                raise RuntimeError("The workers' cudf environments are in-consistent ")
+
+            _rabit_args = json.loads(messages[0])["rabit_msg"]
+
             evals_result = {}
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
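This last hunk is the point of the commit: every barrier task packs its local `use_qdm` decision into the `allGather` payload, and training aborts early if the workers disagree instead of failing later with mismatched DMatrix types. Below is a self-contained sketch of the same consistency-check pattern, assuming a Spark 3.0+ installation with barrier execution mode; the function and variable names (`assert_flag_consistent`, `train_task`) are illustrative and not part of XGBoost or PySpark:

```python
import json

from pyspark import BarrierTaskContext
from pyspark.sql import SparkSession


def assert_flag_consistent(local_flag: bool) -> bool:
    """Gather a boolean flag from every barrier task and fail fast on a mismatch."""
    context = BarrierTaskContext.get()
    messages = context.allGather(message=json.dumps({"flag": local_flag}))
    flags = {json.loads(msg)["flag"] for msg in messages}
    if len(flags) != 1:
        raise RuntimeError("Workers disagree on the flag; check their environments.")
    return flags.pop()


def train_task(iterator):
    # In the commit the flag is `use_qdm`; a constant stands in for it here.
    yield assert_flag_consistent(local_flag=True)


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[2]").getOrCreate()
    partitions = spark.sparkContext.parallelize(range(2), numSlices=2)
    print(partitions.barrier().mapPartitions(train_task).collect())  # [True, True]
```

Because `allGather` blocks until every task has contributed its message, the check is race-free: either all workers proceed with the same flag, or every one of them raises.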

python-package/xgboost/spark/utils.py

@@ -1,10 +1,9 @@
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
-import json
 import logging
 import sys
 from threading import Thread
-from typing import Any, Callable, Dict, List, Set, Type
+from typing import Any, Callable, Dict, Set, Type

 import pyspark
 from pyspark import BarrierTaskContext, SparkContext
@@ -81,16 +80,6 @@ def _get_host_ip(context: BarrierTaskContext) -> str:
     return task_ip_list[0]


-def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
-    """A function to send/recieve messages in barrier context mode"""
-    output = ""
-    for message in messages:
-        if message != "":
-            output = message
-            break
-    return json.loads(output)
-
-
 def _get_spark_session() -> SparkSession:
     """Get or create spark session. Note: This function can only be invoked from driver
     side.
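With every worker now sending a JSON document through `allGather`, the `_get_args_from_message_list` helper removed above has no remaining caller: the new code in core.py reads the rabit args directly from `messages[0]`, the entry it expects to come from partition 0 (the partition that actually generated them). A rough, illustration-only restatement of the two extraction styles as standalone functions:

```python
import json
from typing import Any, Dict, List


def first_nonempty_message(messages: List[str]) -> Dict[str, Any]:
    """Roughly what the removed `_get_args_from_message_list` did:
    parse the first non-empty string in the allGather results."""
    return json.loads(next(msg for msg in messages if msg != ""))


def rabit_args_from_first_worker(messages: List[str]) -> Dict[str, Any]:
    """The replacement style in core.py: every worker sends a dict,
    so the rabit args are taken from the first worker's message."""
    return json.loads(messages[0])["rabit_msg"]
```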