[pyspark] check use_qdm across all the workers (#8496)
commit 40a1a2ffa8
parent 5aeb8f7009
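Whether training builds a QuantileDMatrix (`use_qdm`) is decided per worker, because `is_cudf_available()` is evaluated in the worker's own Python environment, which can differ from the driver's and from other workers'. Before this change there was no check, so workers could silently disagree. Now each worker bundles its local `use_qdm` decision together with the rabit args into a single `allGather` message, and training raises a RuntimeError when the decisions are inconsistent. The GPU ordinal assignment is hoisted into the same `if use_gpu:` branch, and the now-unused `_get_args_from_message_list` helper is removed from the utils submodule.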
@@ -60,7 +60,6 @@ from .params import (
 )
 from .utils import (
     CommunicatorContext,
-    _get_args_from_message_list,
     _get_default_params_from_func,
     _get_gpu_id,
     _get_max_num_concurrent_tasks,
@@ -771,15 +770,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             from pyspark import BarrierTaskContext

             context = BarrierTaskContext.get()
-            context.barrier()

             gpu_id = None
-
-            # If cuDF is not installed, then using DMatrix instead of QDM,
-            # because without cuDF, DMatrix performs better than QDM.
-            # Note: Checking `is_cudf_available` in spark worker side because
-            # spark worker might has different python environment with driver side.
             if use_gpu:
+                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
+                booster_params["gpu_id"] = gpu_id
+                # If cuDF is not installed, then using DMatrix instead of QDM,
+                # because without cuDF, DMatrix performs better than QDM.
+                # Note: Checking `is_cudf_available` in spark worker side because
+                # spark worker might has different python environment with driver side.
                 use_qdm = use_hist and is_cudf_available()
             else:
                 use_qdm = use_hist
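For context (this sketch is not part of the commit): the comments above hinge on `is_cudf_available` answering for whichever interpreter executes it, so two workers can return different results. A probe of that kind reduces to attempting the import locally; the function name below is made up for illustration.

# Hypothetical stand-in for xgboost's is_cudf_available(), shown only to
# illustrate why the answer is environment-specific: the import is attempted
# in whichever interpreter runs this function -- driver or worker.
def cudf_importable() -> bool:
    try:
        import cudf  # noqa: F401
    except ImportError:
        return False
    return True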
@@ -787,10 +786,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

-            if use_gpu:
-                gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
-                booster_params["gpu_id"] = gpu_id
-
             _rabit_args = {}
             if context.partitionId() == 0:
                 get_logger("XGBoostPySpark").debug(
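(Nothing is lost here: the identical `gpu_id` block reappears in the earlier hunk inside `if use_gpu:`, moved up so the device is assigned before `use_qdm` is decided.)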
@@ -804,8 +799,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

                 _rabit_args = _get_rabit_args(context, num_workers)

-            messages = context.allGather(message=json.dumps(_rabit_args))
-            _rabit_args = _get_args_from_message_list(messages)
+            worker_message = {
+                "rabit_msg": _rabit_args,
+                "use_qdm": use_qdm,
+            }
+
+            messages = context.allGather(message=json.dumps(worker_message))
+            if len(set(json.loads(x)["use_qdm"] for x in messages)) != 1:
+                raise RuntimeError("The workers' cudf environments are in-consistent ")
+
+            _rabit_args = json.loads(messages[0])["rabit_msg"]

             evals_result = {}
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
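The gathered-message protocol can be exercised without Spark. A minimal sketch, assuming only that each element of `messages` is the JSON string a worker passed to `allGather` (pyspark returns the gathered list in partition order, so `messages[0]` is partition 0's message, the one that actually computed the rabit args):

import json

def check_and_extract(messages):
    # Mirrors the worker-side logic above: abort unless every worker
    # reached the same use_qdm decision, then take partition 0's rabit args.
    if len({json.loads(m)["use_qdm"] for m in messages}) != 1:
        raise RuntimeError("The workers' cudf environments are in-consistent ")
    return json.loads(messages[0])["rabit_msg"]

# Two workers that agree:
ok = [json.dumps({"rabit_msg": {"DMLC_NUM_WORKER": 2}, "use_qdm": True})] * 2
print(check_and_extract(ok))  # {'DMLC_NUM_WORKER': 2}

# A worker without cuDF disagrees -> RuntimeError:
bad = [json.dumps({"rabit_msg": {}, "use_qdm": u}) for u in (True, False)]
# check_and_extract(bad) raises RuntimeError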
The remaining hunks are in the pyspark utils submodule:

@@ -1,10 +1,9 @@
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
-import json
 import logging
 import sys
 from threading import Thread
-from typing import Any, Callable, Dict, List, Set, Type
+from typing import Any, Callable, Dict, Set, Type

 import pyspark
 from pyspark import BarrierTaskContext, SparkContext
@@ -81,16 +80,6 @@ def _get_host_ip(context: BarrierTaskContext) -> str:
     return task_ip_list[0]


-def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
-    """A function to send/recieve messages in barrier context mode"""
-    output = ""
-    for message in messages:
-        if message != "":
-            output = message
-            break
-    return json.loads(output)
-
-
 def _get_spark_session() -> SparkSession:
     """Get or create spark session. Note: This function can only be invoked from driver
     side.
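The deleted helper existed to pick the one meaningful entry out of the gathered messages: under the old protocol only partition 0 put real rabit args into `allGather`, so `_get_args_from_message_list` scanned the list for the first non-empty message. With every worker now sending a full JSON dict, partition 0's args are always at `messages[0]`, and the scan, along with utils' `json` and `List` imports removed above, becomes dead code.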