[pyspark] add logs for training (#9449)
This commit is contained in:
parent
7f854848d3
commit
d495a180d8
@ -924,21 +924,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
# Note: Checking `is_cudf_available` in spark worker side because
|
# Note: Checking `is_cudf_available` in spark worker side because
|
||||||
# spark worker might has different python environment with driver side.
|
# spark worker might has different python environment with driver side.
|
||||||
use_qdm = use_qdm and is_cudf_available()
|
use_qdm = use_qdm and is_cudf_available()
|
||||||
|
get_logger("XGBoost-PySpark").info(
|
||||||
|
"Leveraging %s to train with QDM: %s",
|
||||||
|
booster_params["device"],
|
||||||
|
"on" if use_qdm else "off",
|
||||||
|
)
|
||||||
|
|
||||||
if use_qdm and (booster_params.get("max_bin", None) is not None):
|
if use_qdm and (booster_params.get("max_bin", None) is not None):
|
||||||
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
|
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
|
||||||
|
|
||||||
_rabit_args = {}
|
_rabit_args = {}
|
||||||
if context.partitionId() == 0:
|
if context.partitionId() == 0:
|
||||||
get_logger("XGBoostPySpark").debug(
|
|
||||||
"booster params: %s\n"
|
|
||||||
"train_call_kwargs_params: %s\n"
|
|
||||||
"dmatrix_kwargs: %s",
|
|
||||||
booster_params,
|
|
||||||
train_call_kwargs_params,
|
|
||||||
dmatrix_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
_rabit_args = _get_rabit_args(context, num_workers)
|
_rabit_args = _get_rabit_args(context, num_workers)
|
||||||
|
|
||||||
worker_message = {
|
worker_message = {
|
||||||
@ -995,7 +991,19 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
)
|
)
|
||||||
return ret[0], ret[1]
|
return ret[0], ret[1]
|
||||||
|
|
||||||
|
get_logger("XGBoost-PySpark").info(
|
||||||
|
"Running xgboost-%s on %s workers with"
|
||||||
|
"\n\tbooster params: %s"
|
||||||
|
"\n\ttrain_call_kwargs_params: %s"
|
||||||
|
"\n\tdmatrix_kwargs: %s",
|
||||||
|
xgboost._py_version(),
|
||||||
|
num_workers,
|
||||||
|
booster_params,
|
||||||
|
train_call_kwargs_params,
|
||||||
|
dmatrix_kwargs,
|
||||||
|
)
|
||||||
(config, booster) = _run_job()
|
(config, booster) = _run_job()
|
||||||
|
get_logger("XGBoost-PySpark").info("Finished xgboost training!")
|
||||||
|
|
||||||
result_xgb_model = self._convert_to_sklearn_model(
|
result_xgb_model = self._convert_to_sklearn_model(
|
||||||
bytearray(booster, "utf-8"), config
|
bytearray(booster, "utf-8"), config
|
||||||
|
|||||||
@ -104,6 +104,10 @@ def get_logger(name: str, level: str = "INFO") -> logging.Logger:
|
|||||||
# If the logger is configured, skip the configure
|
# If the logger is configured, skip the configure
|
||||||
if not logger.handlers and not logging.getLogger().handlers:
|
if not logger.handlers and not logging.getLogger().handlers:
|
||||||
handler = logging.StreamHandler(sys.stderr)
|
handler = logging.StreamHandler(sys.stderr)
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
"%(asctime)s %(levelname)s %(name)s: %(funcName)s %(message)s"
|
||||||
|
)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user