[pyspark] rework the log (#10077)
parent 5ac233280e
commit d24df52bb9
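The change below threads the driver-side log level for the "XGBoost-PySpark" logger through to the Spark workers. A minimal, self-contained sketch of that pattern follows; it is not the patched library code, but the helper bodies mirror the hunks below:

import logging
from typing import Optional, Union


def get_logger_level(name: str) -> Optional[int]:
    """Return the configured level, or None if the user never set one."""
    logger = logging.getLogger(name)
    return None if logger.level == logging.NOTSET else logger.level


def get_logger(name: str, level: Optional[Union[str, int]] = None) -> logging.Logger:
    """Get the logger, applying an explicit level or defaulting to INFO."""
    logger = logging.getLogger(name)
    if level is not None:
        logger.setLevel(level)
    elif logger.level == logging.NOTSET:
        logger.setLevel(logging.INFO)
    return logger


# Driver side: fit() captures the user's setting before launching the job.
logging.getLogger("XGBoost-PySpark").setLevel(logging.WARNING)
log_level = get_logger_level("XGBoost-PySpark")  # shipped to executors via the UDF closure

# Worker side: the training UDF rebuilds the logger with that level,
# so this INFO message is suppressed when the driver asked for WARNING.
get_logger("XGBoost-PySpark", log_level).info("Training on CPUs")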
@@ -95,6 +95,7 @@ from .utils import (
     deserialize_xgb_model,
     get_class_name,
     get_logger,
+    get_logger_level,
     serialize_booster,
     use_cuda,
 )
@@ -181,6 +182,8 @@ pred = Pred("prediction", "rawPrediction", "probability", "predContrib")
 
 _INIT_BOOSTER_SAVE_PATH = "init_booster.json"
 
+_LOG_TAG = "XGBoost-PySpark"
+
 
 class _SparkXGBParams(
     HasFeaturesCol,
@@ -1034,6 +1037,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
 
         num_workers = self.getOrDefault(self.num_workers)
 
+        log_level = get_logger_level(_LOG_TAG)
+
         def _train_booster(
             pandas_df_iter: Iterator[pd.DataFrame],
         ) -> Iterator[pd.DataFrame]:
@@ -1047,7 +1052,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
 
             dev_ordinal = None
             use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
-
+            msg = "Training on CPUs"
             if run_on_gpu:
                 dev_ordinal = (
                     context.partitionId() if is_local else _get_gpu_id(context)
@@ -1058,10 +1063,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 # Note: Checking `is_cudf_available` in spark worker side because
                 # spark worker might has different python environment with driver side.
                 use_qdm = use_qdm and is_cudf_available()
-                get_logger("XGBoost-PySpark").info(
-                    "Leveraging %s to train with QDM: %s",
-                    booster_params["device"],
-                    "on" if use_qdm else "off",
+                msg = (
+                    f"Leveraging {booster_params['device']} to train with "
+                    f"QDM: {'on' if use_qdm else 'off'}"
                 )
 
             if use_qdm and (booster_params.get("max_bin", None) is not None):
@@ -1070,6 +1074,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             _rabit_args = {}
             if context.partitionId() == 0:
                 _rabit_args = _get_rabit_args(context, num_workers)
+                get_logger(_LOG_TAG, log_level).info(msg)
 
             worker_message = {
                 "rabit_msg": _rabit_args,
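The `context.partitionId() == 0` guard above is what keeps the QDM/device message to a single line per job. A runnable stand-in for that behavior, using plain logging in place of the Spark barrier task context (which is assumed here):

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("XGBoost-PySpark")

msg = "Leveraging cuda:0 to train with QDM: on"
for partition_id in range(4):   # stand-in for four Spark tasks
    if partition_id == 0:       # same guard as context.partitionId() == 0
        log.info(msg)           # emitted once, not once per partition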
@@ -1127,7 +1132,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             ret = rdd_with_resource.collect()[0]
             return ret[0], ret[1]
 
-        get_logger("XGBoost-PySpark").info(
+        get_logger(_LOG_TAG).info(
             "Running xgboost-%s on %s workers with"
             "\n\tbooster params: %s"
             "\n\ttrain_call_kwargs_params: %s"
@@ -1139,7 +1144,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             dmatrix_kwargs,
         )
         (config, booster) = _run_job()
-        get_logger("XGBoost-PySpark").info("Finished xgboost training!")
+        get_logger(_LOG_TAG).info("Finished xgboost training!")
 
         result_xgb_model = self._convert_to_sklearn_model(
             bytearray(booster, "utf-8"), config
@@ -1342,7 +1347,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         # User don't set gpu configurations, just use cpu
         if gpu_per_task is None:
             if use_gpu_by_params:
                get_logger("XGBoost-PySpark").warning(
-                get_logger("XGBoost-PySpark").warning(
+                get_logger(_LOG_TAG).warning(
                     "Do the prediction on the CPUs since "
                     "no gpu configurations are set"
                 )
@@ -1377,6 +1382,8 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         is_local = _is_local(_get_spark_session().sparkContext)
         run_on_gpu = self._run_on_gpu()
 
+        log_level = get_logger_level(_LOG_TAG)
+
         @pandas_udf(schema) # type: ignore
         def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
             assert xgb_sklearn_model is not None
@@ -1413,7 +1420,8 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
             else:
                 msg = "CUDF or Cupy is unavailable, fallback the inference on the CPUs"
 
-            get_logger("XGBoost-PySpark").info(msg)
+            if context.partitionId() == 0:
+                get_logger(_LOG_TAG, log_level).info(msg)
 
             def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:
                 """Move the data to gpu if possible"""
@@ -8,7 +8,7 @@ import os
 import sys
 import uuid
 from threading import Thread
-from typing import Any, Callable, Dict, Optional, Set, Type
+from typing import Any, Callable, Dict, Optional, Set, Type, Union
 
 import pyspark
 from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
@@ -98,10 +98,15 @@ def _get_spark_session() -> SparkSession:
     return SparkSession.builder.getOrCreate()
 
 
-def get_logger(name: str, level: str = "INFO") -> logging.Logger:
+def get_logger(name: str, level: Optional[Union[str, int]] = None) -> logging.Logger:
     """Gets a logger by name, or creates and configures it for the first time."""
     logger = logging.getLogger(name)
-    logger.setLevel(level)
+    if level is not None:
+        logger.setLevel(level)
+    else:
+        # Default to info if not set.
+        if logger.level == logging.NOTSET:
+            logger.setLevel(logging.INFO)
     # If the logger is configured, skip the configure
     if not logger.handlers and not logging.getLogger().handlers:
         handler = logging.StreamHandler(sys.stderr)
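With `level` now defaulting to None, repeated calls no longer reset an explicitly configured logger back to INFO; the INFO fallback only applies to a logger whose level is still NOTSET. A short behavioral sketch, assuming the `get_logger` defined above is in scope:

import logging

logging.getLogger("XGBoost-PySpark").setLevel(logging.DEBUG)
assert get_logger("XGBoost-PySpark").level == logging.DEBUG   # preserved, not reset to INFO

assert get_logger("SomeFreshTag").level == logging.INFO       # NOTSET falls back to INFO

assert get_logger("SomeFreshTag", logging.ERROR).level == logging.ERROR  # explicit level wins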
@@ -113,6 +118,12 @@ def get_logger(name: str, level: str = "INFO") -> logging.Logger:
     return logger
 
 
+def get_logger_level(name: str) -> Optional[int]:
+    """Get the logger level for the given log name"""
+    logger = logging.getLogger(name)
+    return None if logger.level == logging.NOTSET else logger.level
+
+
 def _get_max_num_concurrent_tasks(spark_context: SparkContext) -> int:
     """Gets the current max number of concurrent tasks."""
     # pylint: disable=protected-access
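Taken together, the two helpers give users a single knob for both driver- and executor-side verbosity. A hedged end-to-end usage sketch; `SparkXGBClassifier` is the existing public estimator, and the training DataFrame is assumed rather than part of this patch:

import logging
from xgboost.spark import SparkXGBClassifier

# Quiet the PySpark XGBoost logger on the driver; after this commit the same
# level is picked up by get_logger_level() in fit()/transform() and re-applied
# on the executors, so their INFO chatter is suppressed as well.
logging.getLogger("XGBoost-PySpark").setLevel(logging.ERROR)

clf = SparkXGBClassifier(num_workers=2)
# model = clf.fit(train_df)  # train_df: an assumed Spark DataFrame with features/label columns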