[pyspark] sort qid for SparkRanker (#8497)

* [pyspark] sort qid for SparkRanker

* resolve comments
This commit is contained in:
Bobby Wang
2022-12-02 08:40:35 +08:00
committed by GitHub
parent f747e05eac
commit 8e41ad24f5
2 changed files with 65 additions and 32 deletions

View File

@@ -1,7 +1,7 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
import json
from typing import Iterator, Optional, Tuple
@@ -729,6 +729,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
else:
dataset = dataset.repartition(num_workers)
if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
# XGBoost requires qid to be sorted for each partition
dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
train_params = self._get_distributed_train_params(dataset)
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
train_params