[pyspark] Implement SparkXGBRanker estimator (#8172)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from pyspark.sql import functions as spark_sql_func
|
||||
from xgboost.spark import (
|
||||
SparkXGBClassifier,
|
||||
SparkXGBClassifierModel,
|
||||
SparkXGBRanker,
|
||||
SparkXGBRegressor,
|
||||
SparkXGBRegressorModel,
|
||||
)
|
||||
@@ -380,6 +381,28 @@ class XgboostLocalTest(SparkTestCase):
|
||||
"expected_prediction_with_base_margin",
|
||||
],
|
||||
)
|
||||
self.ranker_df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
|
||||
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
|
||||
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
|
||||
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
|
||||
],
|
||||
["features", "label", "qid"],
|
||||
)
|
||||
self.ranker_df_test = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
|
||||
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
|
||||
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
|
||||
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
|
||||
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
|
||||
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
|
||||
],
|
||||
["features", "qid", "expected_prediction"],
|
||||
)
|
||||
|
||||
self.reg_df_sparse_train = self.session.createDataFrame(
|
||||
[
|
||||
@@ -1024,3 +1047,12 @@ class XgboostLocalTest(SparkTestCase):
|
||||
|
||||
for row1, row2 in zip(pred_result, pred_result2):
|
||||
self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3))
|
||||
|
||||
def test_ranker(self):
|
||||
ranker = SparkXGBRanker(qid_col="qid")
|
||||
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
|
||||
model = ranker.fit(self.ranker_df_train)
|
||||
pred_result = model.transform(self.ranker_df_test).collect()
|
||||
|
||||
for row in pred_result:
|
||||
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user