[pyspark] sort qid for SparkRanker (#8497)
* [pyspark] sort qid for SparkRandker * resolve comments
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for core code."""
|
||||
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
|
||||
# pylint: disable=too-few-public-methods, too-many-lines
|
||||
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
|
||||
import json
|
||||
from typing import Iterator, Optional, Tuple
|
||||
|
||||
@@ -729,6 +729,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
else:
|
||||
dataset = dataset.repartition(num_workers)
|
||||
|
||||
if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
|
||||
# XGBoost requires qid to be sorted for each partition
|
||||
dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
|
||||
|
||||
train_params = self._get_distributed_train_params(dataset)
|
||||
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
|
||||
train_params
|
||||
|
||||
Reference in New Issue
Block a user