diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 998afbf77..2150e5055 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -924,21 +924,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             # Note: Checking `is_cudf_available` in spark worker side because
             # spark worker might has different python environment with driver side.
             use_qdm = use_qdm and is_cudf_available()
+            get_logger("XGBoost-PySpark").info(
+                "Leveraging %s to train with QDM: %s",
+                booster_params["device"],
+                "on" if use_qdm else "off",
+            )
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
 
             _rabit_args = {}
             if context.partitionId() == 0:
-                get_logger("XGBoostPySpark").debug(
-                    "booster params: %s\n"
-                    "train_call_kwargs_params: %s\n"
-                    "dmatrix_kwargs: %s",
-                    booster_params,
-                    train_call_kwargs_params,
-                    dmatrix_kwargs,
-                )
-
                 _rabit_args = _get_rabit_args(context, num_workers)
 
             worker_message = {
@@ -995,7 +991,19 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             )
             return ret[0], ret[1]
 
+        get_logger("XGBoost-PySpark").info(
+            "Running xgboost-%s on %s workers with"
+            "\n\tbooster params: %s"
+            "\n\ttrain_call_kwargs_params: %s"
+            "\n\tdmatrix_kwargs: %s",
+            xgboost._py_version(),
+            num_workers,
+            booster_params,
+            train_call_kwargs_params,
+            dmatrix_kwargs,
+        )
         (config, booster) = _run_job()
+        get_logger("XGBoost-PySpark").info("Finished xgboost training!")
 
         result_xgb_model = self._convert_to_sklearn_model(
             bytearray(booster, "utf-8"), config
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index 5f3bb19ba..33a45a90e 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -104,6 +104,10 @@ def get_logger(name: str, level: str = "INFO") -> logging.Logger:
     # If the logger is configured, skip the configure
     if not logger.handlers and not logging.getLogger().handlers:
         handler = logging.StreamHandler(sys.stderr)
+        formatter = logging.Formatter(
+            "%(asctime)s %(levelname)s %(name)s: %(funcName)s %(message)s"
+        )
+        handler.setFormatter(formatter)
         logger.addHandler(handler)
     return logger
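
Note on the utils.py hunk: it attaches a logging.Formatter to the stderr handler, so the INFO messages added in core.py carry a timestamp, level, logger name, and originating function. The standalone sketch below (not the xgboost module itself; the "cuda:0" device string is a made-up placeholder) reproduces the patched configuration to show roughly what one of the new log lines looks like:

# Minimal standalone sketch reproducing the logger setup from the patch,
# to illustrate the layout of the new INFO messages.
import logging
import sys

def get_logger(name: str, level: str = "INFO") -> logging.Logger:
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Configure a stderr handler only if neither this logger nor the root
    # logger already has one, mirroring the guard in the patch.
    if not logger.handlers and not logging.getLogger().handlers:
        handler = logging.StreamHandler(sys.stderr)
        formatter = logging.Formatter(
            "%(asctime)s %(levelname)s %(name)s: %(funcName)s %(message)s"
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

if __name__ == "__main__":
    # Mirrors the QDM message emitted on each Spark worker; "cuda:0" is a
    # placeholder device value for illustration only.
    get_logger("XGBoost-PySpark").info(
        "Leveraging %s to train with QDM: %s", "cuda:0", "on"
    )
    # Example output (timestamp will differ):
    # 2024-01-01 00:00:00,000 INFO XGBoost-PySpark: <module> Leveraging cuda:0 to train with QDM: on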