[pyspark] rework the log (#10077)

Bobby Wang 2024-02-29 16:47:31 +08:00 committed by GitHub
parent 5ac233280e
commit d24df52bb9
2 changed files with 31 additions and 12 deletions

View File

@@ -95,6 +95,7 @@ from .utils import (
     deserialize_xgb_model,
     get_class_name,
     get_logger,
+    get_logger_level,
     serialize_booster,
     use_cuda,
 )
@@ -181,6 +182,8 @@ pred = Pred("prediction", "rawPrediction", "probability", "predContrib")
 _INIT_BOOSTER_SAVE_PATH = "init_booster.json"
+_LOG_TAG = "XGBoost-PySpark"

 class _SparkXGBParams(
     HasFeaturesCol,
@@ -1034,6 +1037,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         num_workers = self.getOrDefault(self.num_workers)
+        log_level = get_logger_level(_LOG_TAG)

         def _train_booster(
             pandas_df_iter: Iterator[pd.DataFrame],
         ) -> Iterator[pd.DataFrame]:
@@ -1047,7 +1052,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             dev_ordinal = None
             use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
+            msg = "Training on CPUs"
             if run_on_gpu:
                 dev_ordinal = (
                     context.partitionId() if is_local else _get_gpu_id(context)
@@ -1058,10 +1063,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 # Note: Checking `is_cudf_available` in spark worker side because
                 # spark worker might has different python environment with driver side.
                 use_qdm = use_qdm and is_cudf_available()
-                get_logger("XGBoost-PySpark").info(
-                    "Leveraging %s to train with QDM: %s",
-                    booster_params["device"],
-                    "on" if use_qdm else "off",
+                msg = (
+                    f"Leveraging {booster_params['device']} to train with "
+                    f"QDM: {'on' if use_qdm else 'off'}"
                 )

             if use_qdm and (booster_params.get("max_bin", None) is not None):
@@ -1070,6 +1074,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             _rabit_args = {}
             if context.partitionId() == 0:
                 _rabit_args = _get_rabit_args(context, num_workers)
+                get_logger(_LOG_TAG, log_level).info(msg)

             worker_message = {
                 "rabit_msg": _rabit_args,
@@ -1127,7 +1132,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             ret = rdd_with_resource.collect()[0]
             return ret[0], ret[1]

-        get_logger("XGBoost-PySpark").info(
+        get_logger(_LOG_TAG).info(
             "Running xgboost-%s on %s workers with"
             "\n\tbooster params: %s"
             "\n\ttrain_call_kwargs_params: %s"
@@ -1139,7 +1144,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             dmatrix_kwargs,
         )
         (config, booster) = _run_job()
-        get_logger("XGBoost-PySpark").info("Finished xgboost training!")
+        get_logger(_LOG_TAG).info("Finished xgboost training!")

         result_xgb_model = self._convert_to_sklearn_model(
             bytearray(booster, "utf-8"), config
@@ -1342,7 +1347,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         # User don't set gpu configurations, just use cpu
         if gpu_per_task is None:
             if use_gpu_by_params:
-                get_logger("XGBoost-PySpark").warning(
+                get_logger(_LOG_TAG).warning(
                     "Do the prediction on the CPUs since "
                     "no gpu configurations are set"
                 )
@@ -1377,6 +1382,8 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         is_local = _is_local(_get_spark_session().sparkContext)
         run_on_gpu = self._run_on_gpu()
+        log_level = get_logger_level(_LOG_TAG)

         @pandas_udf(schema)  # type: ignore
         def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
             assert xgb_sklearn_model is not None
@@ -1413,7 +1420,8 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
             else:
                 msg = "CUDF or Cupy is unavailable, fallback the inference on the CPUs"
-            get_logger("XGBoost-PySpark").info(msg)
+            if context.partitionId() == 0:
+                get_logger(_LOG_TAG, log_level).info(msg)

             def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:
                 """Move the data to gpu if possible"""

View File

@@ -8,7 +8,7 @@ import os
 import sys
 import uuid
 from threading import Thread
-from typing import Any, Callable, Dict, Optional, Set, Type
+from typing import Any, Callable, Dict, Optional, Set, Type, Union

 import pyspark
 from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
@@ -98,10 +98,15 @@ def _get_spark_session() -> SparkSession:
     return SparkSession.builder.getOrCreate()


-def get_logger(name: str, level: str = "INFO") -> logging.Logger:
+def get_logger(name: str, level: Optional[Union[str, int]] = None) -> logging.Logger:
     """Gets a logger by name, or creates and configures it for the first time."""
     logger = logging.getLogger(name)
-    logger.setLevel(level)
+    if level is not None:
+        logger.setLevel(level)
+    else:
+        # Default to info if not set.
+        if logger.level == logging.NOTSET:
+            logger.setLevel(logging.INFO)
     # If the logger is configured, skip the configure
     if not logger.handlers and not logging.getLogger().handlers:
         handler = logging.StreamHandler(sys.stderr)
@@ -113,6 +118,12 @@ def get_logger(name: str, level: str = "INFO") -> logging.Logger:
     return logger


+def get_logger_level(name: str) -> Optional[int]:
+    """Get the logger level for the given log name"""
+    logger = logging.getLogger(name)
+    return None if logger.level == logging.NOTSET else logger.level
+
+
 def _get_max_num_concurrent_tasks(spark_context: SparkContext) -> int:
     """Gets the current max number of concurrent tasks."""
     # pylint: disable=protected-access
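On the driver side, the reworked get_logger contract is the piece worth calling out: passing level=None no longer resets an already-configured logger back to INFO, and the new get_logger_level reports None until a level has been set, which is exactly the value the estimator captures and ships to the executors. A small sketch, assuming the helpers can be imported from xgboost.spark.utils (the private module the first file imports them from):

```python
import logging

# Private module; the import path is inferred from the "from .utils import ..." hunk above.
from xgboost.spark.utils import get_logger, get_logger_level

TAG = "XGBoost-PySpark"

print(get_logger_level(TAG))  # None while the logger's level is still NOTSET

logging.getLogger(TAG).setLevel(logging.DEBUG)  # user tunes verbosity on the driver

logger = get_logger(TAG)  # level=None: the DEBUG choice survives
print(logger.level == logging.DEBUG)  # True; the old helper would have reset this to INFO

print(get_logger_level(TAG))  # 10 (logging.DEBUG), the value passed to workers as log_level
```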