[pyspark] Cleanup the comments (#8217)

Bobby Wang 2022-09-05 16:20:12 +08:00 committed by GitHub
parent ada4a86d1c
commit 7ee10e3dbd
3 changed files with 19 additions and 5 deletions


@@ -322,7 +322,7 @@ if __name__ == '__main__':
     # - python setup.py bdist_wheel && pip install <wheel-name>
     # When XGBoost is compiled directly with CMake:
-    # - pip install . -e
+    # - pip install -e .
     # - python setup.py develop # same as above
     logging.basicConfig(level=logging.INFO)


@@ -713,6 +713,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         is_local = _is_local(_get_spark_session().sparkContext)
+        # Remove the parameters whose value is None
+        booster_params = {k: v for k, v in booster_params.items() if v is not None}
+        train_call_kwargs_params = {
+            k: v for k, v in train_call_kwargs_params.items() if v is not None
+        }
+        dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
             going through the Rabit Ring protocol
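The block added above drops every entry whose value is None before the dictionaries reach the native XGBoost training call, so unset Spark params are not forwarded as literal None values. A minimal standalone sketch of the same filtering pattern (the parameter values below are invented for illustration, not taken from the commit):

def drop_none(params):
    """Return a copy of ``params`` without the keys whose value is None."""
    return {k: v for k, v in params.items() if v is not None}


# Hypothetical inputs; only the dict-comprehension pattern mirrors the diff.
booster_params = {"max_depth": 6, "eta": None, "objective": "binary:logistic"}
dmatrix_kwargs = {"missing": None, "nthread": 4}

print(drop_none(booster_params))  # {'max_depth': 6, 'objective': 'binary:logistic'}
print(drop_none(dmatrix_kwargs))  # {'nthread': 4}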
@@ -737,6 +744,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             _rabit_args = ""
             if context.partitionId() == 0:
+                get_logger("XGBoostPySpark").info(
+                    "booster params: %s\n"
+                    "train_call_kwargs_params: %s\n"
+                    "dmatrix_kwargs: %s",
+                    booster_params,
+                    train_call_kwargs_params,
+                    dmatrix_kwargs,
+                )
                 _rabit_args = str(_get_rabit_args(context, num_workers))
             messages = context.allGather(message=str(_rabit_args))
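The logging added here runs only on partition 0, so the resolved parameters are reported once per training job rather than once per task, and it uses lazy %-style formatting so the dictionaries are only rendered when INFO is enabled. A rough standalone sketch of the same pattern using the standard logging module (get_logger in the diff is xgboost.spark's own helper; the parameter values here are made up):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("XGBoostPySpark")

# Invented example values; the real dictionaries come from the estimator.
booster_params = {"objective": "binary:logistic", "max_depth": 6}
train_call_kwargs_params = {"num_boost_round": 100, "verbose_eval": True}
dmatrix_kwargs = {"missing": float("nan")}

# Adjacent string literals concatenate; %s placeholders are filled in lazily.
logger.info(
    "booster params: %s\n"
    "train_call_kwargs_params: %s\n"
    "dmatrix_kwargs: %s",
    booster_params,
    train_call_kwargs_params,
    dmatrix_kwargs,
)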
@@ -754,7 +770,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 dval = [(dtrain, "training"), (dvalid, "validation")]
             else:
                 dval = None
             booster = worker_train(
                 params=booster_params,
                 dtrain=dtrain,
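For context, worker_train in this file appears to be xgboost's native train() routine imported under another name, and the dval list built above is the usual list of (DMatrix, name) pairs passed as evals. A minimal, self-contained sketch of the same pattern with plain xgboost on synthetic data (everything below is illustrative, not part of the commit):

import numpy as np
import xgboost as xgb

# Synthetic data standing in for the Spark partition's pandas batches.
rng = np.random.default_rng(0)
X = rng.random((200, 4))
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)

dtrain = xgb.DMatrix(X[:160], label=y[:160])
dvalid = xgb.DMatrix(X[160:], label=y[160:])

# Same shape as the dval list in the diff: (DMatrix, name) pairs.
evals = [(dtrain, "training"), (dvalid, "validation")]

booster = xgb.train(
    params={"objective": "binary:logistic", "max_depth": 3},
    dtrain=dtrain,
    num_boost_round=20,
    evals=evals,
)
print(booster.num_boosted_rounds())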


@@ -36,7 +36,7 @@ class HasBaseMarginCol(Params):
 class HasFeaturesCols(Params):
     """
-    Mixin for param featuresCols: a list of feature column names.
+    Mixin for param features_cols: a list of feature column names.
     This parameter is taken effect only when use_gpu is enabled.
     """
@@ -76,8 +76,7 @@ class HasEnableSparseDataOptim(Params):
 class HasQueryIdCol(Params):
     """
-    Mixin for param featuresCols: a list of feature column names.
-    This parameter is taken effect only when use_gpu is enabled.
+    Mixin for param qid_col: query id column name.
     """
     qid_col = Param(
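The HasQueryIdCol docstring previously duplicated the HasFeaturesCols text; the change above gives the mixin a description that matches its param. The Param definition itself is cut off in this view, so here is a hypothetical sketch of what such a pyspark mixin looks like (the class name, doc string, and type converter are illustrative, not the committed values):

from pyspark.ml.param import Param, Params, TypeConverters


class HasQueryIdColExample(Params):
    """
    Mixin for param qid_col: query id column name.
    """

    # Illustrative definition; the real arguments are truncated in the diff above.
    qid_col = Param(
        Params._dummy(),
        "qid_col",
        "query id column name",
        typeConverter=TypeConverters.toString,
    )

    def getQidCol(self):
        """Gets the value of qid_col or its default value."""
        return self.getOrDefault(self.qid_col)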