[pyspark] add logs for training (#9449)
This commit is contained in:
parent
7f854848d3
commit
d495a180d8
@ -924,21 +924,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
# Note: `is_cudf_available` is checked on the spark worker side because
|
||||
# the spark worker might have a different python environment from the driver side.
|
||||
use_qdm = use_qdm and is_cudf_available()
|
||||
get_logger("XGBoost-PySpark").info(
|
||||
"Leveraging %s to train with QDM: %s",
|
||||
booster_params["device"],
|
||||
"on" if use_qdm else "off",
|
||||
)
|
||||
|
||||
if use_qdm and (booster_params.get("max_bin", None) is not None):
|
||||
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
|
||||
|
||||
_rabit_args = {}
|
||||
if context.partitionId() == 0:
|
||||
get_logger("XGBoostPySpark").debug(
|
||||
"booster params: %s\n"
|
||||
"train_call_kwargs_params: %s\n"
|
||||
"dmatrix_kwargs: %s",
|
||||
booster_params,
|
||||
train_call_kwargs_params,
|
||||
dmatrix_kwargs,
|
||||
)
|
||||
|
||||
_rabit_args = _get_rabit_args(context, num_workers)
|
||||
|
||||
worker_message = {
|
||||
@ -995,7 +991,19 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
)
|
||||
return ret[0], ret[1]
|
||||
|
||||
get_logger("XGBoost-PySpark").info(
|
||||
"Running xgboost-%s on %s workers with"
|
||||
"\n\tbooster params: %s"
|
||||
"\n\ttrain_call_kwargs_params: %s"
|
||||
"\n\tdmatrix_kwargs: %s",
|
||||
xgboost._py_version(),
|
||||
num_workers,
|
||||
booster_params,
|
||||
train_call_kwargs_params,
|
||||
dmatrix_kwargs,
|
||||
)
|
||||
(config, booster) = _run_job()
|
||||
get_logger("XGBoost-PySpark").info("Finished xgboost training!")
|
||||
|
||||
result_xgb_model = self._convert_to_sklearn_model(
|
||||
bytearray(booster, "utf-8"), config
|
||||
|
||||
@ -104,6 +104,10 @@ def get_logger(name: str, level: str = "INFO") -> logging.Logger:
|
||||
# If the logger is configured, skip the configure
|
||||
if not logger.handlers and not logging.getLogger().handlers:
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s %(levelname)s %(name)s: %(funcName)s %(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
return logger
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user