* [pyspark] sort qid for SparkRandker * resolve comments Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
parent
58bc225657
commit
60a8c8ebba
@ -1,7 +1,7 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for core code."""
|
||||
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
|
||||
# pylint: disable=too-few-public-methods, too-many-lines
|
||||
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
|
||||
import json
|
||||
from typing import Iterator, Optional, Tuple
|
||||
|
||||
@ -728,6 +728,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
else:
|
||||
dataset = dataset.repartition(num_workers)
|
||||
|
||||
if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
|
||||
# XGBoost requires qid to be sorted for each partition
|
||||
dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
|
||||
|
||||
train_params = self._get_distributed_train_params(dataset)
|
||||
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
|
||||
train_params
|
||||
|
||||
@ -390,28 +390,6 @@ class XgboostLocalTest(SparkTestCase):
|
||||
"expected_prediction_with_base_margin",
|
||||
],
|
||||
)
|
||||
self.ranker_df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
|
||||
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
|
||||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
|
||||
],
|
||||
["features", "label", "qid"],
|
||||
)
|
||||
self.ranker_df_test = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
|
||||
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
|
||||
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
|
||||
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
|
||||
],
|
||||
["features", "qid", "expected_prediction"],
|
||||
)
|
||||
|
||||
self.reg_df_sparse_train = self.session.createDataFrame(
|
||||
[
|
||||
@ -1039,15 +1017,6 @@ class XgboostLocalTest(SparkTestCase):
|
||||
for row1, row2 in zip(pred_result, pred_result2):
|
||||
self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3))
|
||||
|
||||
def test_ranker(self):
|
||||
ranker = SparkXGBRanker(qid_col="qid")
|
||||
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
|
||||
model = ranker.fit(self.ranker_df_train)
|
||||
pred_result = model.transform(self.ranker_df_test).collect()
|
||||
|
||||
for row in pred_result:
|
||||
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
|
||||
|
||||
def test_empty_validation_data(self) -> None:
|
||||
for tree_method in [
|
||||
"hist",
|
||||
@ -1130,3 +1099,63 @@ class XgboostLocalTest(SparkTestCase):
|
||||
def test_unsupported_params(self):
|
||||
with pytest.raises(ValueError, match="evals_result"):
|
||||
SparkXGBClassifier(evals_result={})
|
||||
|
||||
|
||||
class XgboostRankerLocalTest(SparkTestCase):
|
||||
def setUp(self):
|
||||
self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
|
||||
self.ranker_df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
|
||||
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
|
||||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
|
||||
],
|
||||
["features", "label", "qid"],
|
||||
)
|
||||
self.ranker_df_test = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
|
||||
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
|
||||
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
|
||||
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
|
||||
],
|
||||
["features", "qid", "expected_prediction"],
|
||||
)
|
||||
self.ranker_df_train_1 = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9),
|
||||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9),
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, 8),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 1, 8),
|
||||
(Vectors.dense(9.0, 4.0, 8.0), 2, 8),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7),
|
||||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7),
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, 6),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 1, 6),
|
||||
(Vectors.dense(9.0, 4.0, 8.0), 2, 6),
|
||||
]
|
||||
* 4,
|
||||
["features", "label", "qid"],
|
||||
)
|
||||
|
||||
def test_ranker(self):
|
||||
ranker = SparkXGBRanker(qid_col="qid")
|
||||
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
|
||||
model = ranker.fit(self.ranker_df_train)
|
||||
pred_result = model.transform(self.ranker_df_test).collect()
|
||||
|
||||
for row in pred_result:
|
||||
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
|
||||
|
||||
def test_ranker_qid_sorted(self):
|
||||
ranker = SparkXGBRanker(qid_col="qid", num_workers=4)
|
||||
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
|
||||
model = ranker.fit(self.ranker_df_train_1)
|
||||
model.transform(self.ranker_df_test).collect()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user