[pyspark] Implement SparkXGBRanker estimator (#8172)

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
WeichenXu
2022-08-23 02:35:19 +08:00
committed by GitHub
parent 35ef8abc27
commit f4628c22a4
6 changed files with 235 additions and 27 deletions

View File

@@ -24,6 +24,7 @@ from pyspark.sql import functions as spark_sql_func
from xgboost.spark import (
SparkXGBClassifier,
SparkXGBClassifierModel,
SparkXGBRanker,
SparkXGBRegressor,
SparkXGBRegressorModel,
)
@@ -380,6 +381,28 @@ class XgboostLocalTest(SparkTestCase):
"expected_prediction_with_base_margin",
],
)
self.ranker_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
],
["features", "label", "qid"],
)
self.ranker_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
],
["features", "qid", "expected_prediction"],
)
self.reg_df_sparse_train = self.session.createDataFrame(
[
@@ -1024,3 +1047,12 @@ class XgboostLocalTest(SparkTestCase):
for row1, row2 in zip(pred_result, pred_result2):
self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3))
def test_ranker(self):
    """Fit a SparkXGBRanker on the ranking fixture and check its predictions.

    Checks that a ranker constructed with only ``qid_col="qid"`` reports
    ``"rank:pairwise"`` for its ``objective`` param, then fits it on
    ``self.ranker_df_train`` and compares every predicted score for
    ``self.ranker_df_test`` against that row's ``expected_prediction``
    value (relative tolerance 1e-3).
    """
    ranker = SparkXGBRanker(qid_col="qid")
    self.assertEqual(ranker.getOrDefault(ranker.objective), "rank:pairwise")

    model = ranker.fit(self.ranker_df_train)
    pred_result = model.transform(self.ranker_df_test).collect()
    for row in pred_result:
        # Use unittest assertions, as the neighbouring tests do: they are not
        # stripped under `python -O` and the message shows the mismatching values.
        self.assertTrue(
            np.isclose(row.prediction, row.expected_prediction, rtol=1e-3),
            msg=(
                f"prediction {row.prediction} differs from expected "
                f"{row.expected_prediction}"
            ),
        )