[pyspark] add logs for training (#9449)

This commit is contained in:
Bobby Wang 2023-08-09 18:32:23 +08:00 committed by GitHub
parent 7f854848d3
commit d495a180d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 9 deletions

View File

@ -924,21 +924,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
# Note: Check `is_cudf_available` on the Spark worker side because the
# Spark worker might have a different Python environment from the driver.
use_qdm = use_qdm and is_cudf_available()
get_logger("XGBoost-PySpark").info(
"Leveraging %s to train with QDM: %s",
booster_params["device"],
"on" if use_qdm else "off",
)
if use_qdm and (booster_params.get("max_bin", None) is not None):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
_rabit_args = {}
if context.partitionId() == 0:
get_logger("XGBoostPySpark").debug(
"booster params: %s\n"
"train_call_kwargs_params: %s\n"
"dmatrix_kwargs: %s",
booster_params,
train_call_kwargs_params,
dmatrix_kwargs,
)
_rabit_args = _get_rabit_args(context, num_workers)
worker_message = {
@ -995,7 +991,19 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
)
return ret[0], ret[1]
get_logger("XGBoost-PySpark").info(
"Running xgboost-%s on %s workers with"
"\n\tbooster params: %s"
"\n\ttrain_call_kwargs_params: %s"
"\n\tdmatrix_kwargs: %s",
xgboost._py_version(),
num_workers,
booster_params,
train_call_kwargs_params,
dmatrix_kwargs,
)
(config, booster) = _run_job()
get_logger("XGBoost-PySpark").info("Finished xgboost training!")
result_xgb_model = self._convert_to_sklearn_model(
bytearray(booster, "utf-8"), config

View File

@ -104,6 +104,10 @@ def get_logger(name: str, level: str = "INFO") -> logging.Logger:
# If the logger is already configured, skip the configuration.
if not logger.handlers and not logging.getLogger().handlers:
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter(
"%(asctime)s %(levelname)s %(name)s: %(funcName)s %(message)s"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger