[jvm-packages] Support Ranker (#10823)
This commit is contained in:
parent
d7599e095b
commit
19b55b300b
@ -93,7 +93,8 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
|
||||
selectedCols.append(col)
|
||||
}
|
||||
val input = dataset.select(selectedCols.toArray: _*)
|
||||
estimator.repartitionIfNeeded(input)
|
||||
val repartitioned = estimator.repartitionIfNeeded(input)
|
||||
estimator.sortPartitionIfNeeded(repartitioned)
|
||||
}
|
||||
|
||||
// visible for testing
|
||||
|
||||
@ -16,14 +16,14 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ai.rapids.cudf.Table
|
||||
import ai.rapids.cudf.{OrderByArg, Table}
|
||||
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite
|
||||
import ml.dmlc.xgboost4j.scala.rapids.spark.SparkSessionHolder.withSparkSession
|
||||
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.sql.{Dataset, SparkSession}
|
||||
import org.apache.spark.sql.{Dataset, Row, SparkSession}
|
||||
import org.apache.spark.SparkConf
|
||||
|
||||
import java.io.File
|
||||
@ -94,7 +94,9 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
|
||||
}
|
||||
|
||||
// spark.rapids.sql.enabled is not set explicitly, default to true
|
||||
withSparkSession(new SparkConf(), spark => {checkIsEnabled(spark, true)})
|
||||
withSparkSession(new SparkConf(), spark => {
|
||||
checkIsEnabled(spark, true)
|
||||
})
|
||||
|
||||
// set spark.rapids.sql.enabled to false
|
||||
withCpuSparkSession() { spark =>
|
||||
@ -503,6 +505,109 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("The group col should be sorted in each partition") {
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val df = Ranking.train.toDF("label", "weight", "group", "c1", "c2", "c3")
|
||||
|
||||
val xgboostParams: Map[String, Any] = Map(
|
||||
"device" -> "cuda",
|
||||
"objective" -> "rank:ndcg"
|
||||
)
|
||||
val features = Array("c1", "c2", "c3")
|
||||
val label = "label"
|
||||
val group = "group"
|
||||
|
||||
val ranker = new XGBoostRanker(xgboostParams)
|
||||
.setFeaturesCol(features)
|
||||
.setLabelCol(label)
|
||||
.setNumWorkers(1)
|
||||
.setNumRound(1)
|
||||
.setGroupCol(group)
|
||||
.setDevice("cuda")
|
||||
|
||||
val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df)
|
||||
processedDf.rdd.foreachPartition { iter => {
|
||||
var prevGroup = Int.MinValue
|
||||
while (iter.hasNext) {
|
||||
val curr = iter.next()
|
||||
val group = curr.asInstanceOf[Row].getAs[Int](1)
|
||||
assert(prevGroup <= group)
|
||||
prevGroup = group
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("Ranker: XGBoost-Spark should match xgboost4j") {
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
|
||||
val trainPath = writeFile(Ranking.train.toDF("label", "weight", "group", "c1", "c2", "c3"))
|
||||
val testPath = writeFile(Ranking.test.toDF("label", "weight", "group", "c1", "c2", "c3"))
|
||||
|
||||
val df = spark.read.parquet(trainPath)
|
||||
val testdf = spark.read.parquet(testPath)
|
||||
|
||||
val features = Array("c1", "c2", "c3")
|
||||
val featuresIndices = features.map(df.schema.fieldIndex)
|
||||
val label = "label"
|
||||
val group = "group"
|
||||
|
||||
val numRound = 100
|
||||
val xgboostParams: Map[String, Any] = Map(
|
||||
"device" -> "cuda",
|
||||
"objective" -> "rank:ndcg"
|
||||
)
|
||||
|
||||
val ranker = new XGBoostRanker(xgboostParams)
|
||||
.setFeaturesCol(features)
|
||||
.setLabelCol(label)
|
||||
.setNumRound(numRound)
|
||||
.setLeafPredictionCol("leaf")
|
||||
.setContribPredictionCol("contrib")
|
||||
.setGroupCol(group)
|
||||
.setDevice("cuda")
|
||||
|
||||
val xgb4jModel = withResource(new GpuColumnBatch(
|
||||
Table.readParquet(new File(trainPath)
|
||||
).orderBy(OrderByArg.asc(df.schema.fieldIndex(group))))) { batch =>
|
||||
val cb = new CudfColumnBatch(batch.select(featuresIndices),
|
||||
batch.select(df.schema.fieldIndex(label)), null, null,
|
||||
batch.select(df.schema.fieldIndex(group)))
|
||||
val qdm = new QuantileDMatrix(Seq(cb).iterator, ranker.getMissing,
|
||||
ranker.getMaxBins, ranker.getNthread)
|
||||
ScalaXGBoost.train(qdm, xgboostParams, numRound)
|
||||
}
|
||||
|
||||
val (xgb4jLeaf, xgb4jContrib, xgb4jPred) = withResource(new GpuColumnBatch(
|
||||
Table.readParquet(new File(testPath)))) { batch =>
|
||||
val cb = new CudfColumnBatch(batch.select(featuresIndices), null, null, null, null
|
||||
)
|
||||
val qdm = new DMatrix(cb, ranker.getMissing, ranker.getNthread)
|
||||
(xgb4jModel.predictLeaf(qdm), xgb4jModel.predictContrib(qdm),
|
||||
xgb4jModel.predict(qdm))
|
||||
}
|
||||
|
||||
val rows = ranker.fit(df).transform(testdf).collect()
|
||||
|
||||
// Check Leaf
|
||||
val xgbSparkLeaf = rows.map(row => row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))
|
||||
checkEqual(xgb4jLeaf, xgbSparkLeaf)
|
||||
|
||||
// Check contrib
|
||||
val xgbSparkContrib = rows.map(row =>
|
||||
row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))
|
||||
checkEqual(xgb4jContrib, xgbSparkContrib)
|
||||
|
||||
// Check prediction
|
||||
val xgbSparkPred = rows.map(row =>
|
||||
Array(row.getAs[Double]("prediction").toFloat))
|
||||
checkEqual(xgb4jPred, xgbSparkPred)
|
||||
}
|
||||
}
|
||||
|
||||
def writeFile(df: Dataset[_]): String = {
|
||||
def listFiles(directory: String): Array[String] = {
|
||||
val dir = new File(directory)
|
||||
|
||||
@ -81,6 +81,6 @@ object Regression extends TrainTestData {
|
||||
}
|
||||
|
||||
object Ranking extends TrainTestData {
|
||||
val train = generateRankDataset(300, 10, 555)
|
||||
val test = generateRankDataset(150, 10, 556)
|
||||
val train = generateRankDataset(300, 10, 12, 555)
|
||||
val test = generateRankDataset(150, 10, 12, 556)
|
||||
}
|
||||
|
||||
@ -134,6 +134,15 @@ private[spark] trait XGBoostEstimator[
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort partition for Ranker issue.
|
||||
* @param dataset
|
||||
* @return
|
||||
*/
|
||||
private[spark] def sortPartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = {
|
||||
dataset
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the columns indices.
|
||||
*/
|
||||
@ -198,10 +207,10 @@ private[spark] trait XGBoostEstimator[
|
||||
case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
|
||||
case _ =>
|
||||
}
|
||||
val input = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))
|
||||
|
||||
val columnIndices = buildColumnIndices(input.schema)
|
||||
(input, columnIndices)
|
||||
val repartitioned = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))
|
||||
val sorted = sortPartitionIfNeeded(repartitioned)
|
||||
val columnIndices = buildColumnIndices(sorted.schema)
|
||||
(sorted, columnIndices)
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
|
||||
@ -0,0 +1,124 @@
|
||||
/*
|
||||
Copyright (c) 2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.spark.ml.{PredictionModel, Predictor}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
|
||||
import org.apache.spark.ml.xgboost.SparkUtils
|
||||
import org.apache.spark.sql.Dataset
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoostRanker._uid
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS
|
||||
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
||||
|
||||
class XGBoostRanker(override val uid: String,
|
||||
private val xgboostParams: Map[String, Any])
|
||||
extends Predictor[Vector, XGBoostRanker, XGBoostRankerModel]
|
||||
with XGBoostEstimator[XGBoostRanker, XGBoostRankerModel] with HasGroupCol {
|
||||
|
||||
def this() = this(_uid, Map[String, Any]())
|
||||
|
||||
def this(uid: String) = this(uid, Map[String, Any]())
|
||||
|
||||
def this(xgboostParams: Map[String, Any]) = this(_uid, xgboostParams)
|
||||
|
||||
def setGroupCol(value: String): XGBoostRanker = set(groupCol, value)
|
||||
|
||||
xgboost2SparkParams(xgboostParams)
|
||||
|
||||
/**
|
||||
* Validate the parameters before training, throw exception if possible
|
||||
*/
|
||||
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
|
||||
super.validate(dataset)
|
||||
|
||||
require(isDefinedNonEmpty(groupCol), "groupCol needs to be set")
|
||||
|
||||
// If the objective is set explicitly, it must be in RANKER_OBJS
|
||||
if (isSet(objective)) {
|
||||
val tmpObj = getObjective
|
||||
require(RANKER_OBJS.contains(tmpObj),
|
||||
s"Wrong objective for XGBoostRanker, supported objs: ${RANKER_OBJS.mkString(",")}")
|
||||
} else {
|
||||
setObjective("rank:ndcg")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort partition for Ranker issue.
|
||||
*
|
||||
* @param dataset
|
||||
* @return
|
||||
*/
|
||||
override private[spark] def sortPartitionIfNeeded(dataset: Dataset[_]) = {
|
||||
dataset.sortWithinPartitions(getGroupCol)
|
||||
}
|
||||
|
||||
override protected def createModel(
|
||||
booster: Booster,
|
||||
summary: XGBoostTrainingSummary): XGBoostRankerModel = {
|
||||
new XGBoostRankerModel(uid, booster, Option(summary))
|
||||
}
|
||||
|
||||
override protected def validateAndTransformSchema(
|
||||
schema: StructType,
|
||||
fitting: Boolean,
|
||||
featuresDataType: DataType): StructType =
|
||||
SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
|
||||
}
|
||||
|
||||
object XGBoostRanker extends DefaultParamsReadable[XGBoostRanker] {
|
||||
private val _uid = Identifiable.randomUID("xgbranker")
|
||||
}
|
||||
|
||||
class XGBoostRankerModel private[ml](val uid: String,
|
||||
val nativeBooster: Booster,
|
||||
val summary: Option[XGBoostTrainingSummary] = None)
|
||||
extends PredictionModel[Vector, XGBoostRankerModel]
|
||||
with RankerRegressorBaseModel[XGBoostRankerModel] with HasGroupCol {
|
||||
|
||||
def this(uid: String) = this(uid, null)
|
||||
|
||||
def setGroupCol(value: String): XGBoostRankerModel = set(groupCol, value)
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostRankerModel = {
|
||||
val newModel = copyValues(new XGBoostRankerModel(uid, nativeBooster, summary), extra)
|
||||
newModel.setParent(parent)
|
||||
}
|
||||
|
||||
override def predict(features: Vector): Double = {
|
||||
val values = predictSingleInstance(features)
|
||||
values(0)
|
||||
}
|
||||
}
|
||||
|
||||
object XGBoostRankerModel extends MLReadable[XGBoostRankerModel] {
|
||||
override def read: MLReader[XGBoostRankerModel] = new ModelReader
|
||||
|
||||
private class ModelReader extends XGBoostModelReader[XGBoostRankerModel] {
|
||||
override def load(path: String): XGBoostRankerModel = {
|
||||
val xgbModel = loadBooster(path)
|
||||
val meta = SparkUtils.loadMetadata(path, sc)
|
||||
val model = new XGBoostRankerModel(meta.uid, xgbModel, None)
|
||||
meta.getAndSetParams(model)
|
||||
model
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,309 @@
|
||||
/*
|
||||
Copyright (c) 2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.ml.linalg.{DenseVector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.scala.spark.Regression.Ranking
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
|
||||
|
||||
class XGBoostRankerSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
||||
|
||||
test("XGBoostRanker copy") {
|
||||
val ranker = new XGBoostRanker().setNthread(2).setNumWorkers(10)
|
||||
val rankertCopied = ranker.copy(ParamMap.empty)
|
||||
|
||||
assert(ranker.uid === rankertCopied.uid)
|
||||
assert(ranker.getNthread === rankertCopied.getNthread)
|
||||
assert(ranker.getNumWorkers === ranker.getNumWorkers)
|
||||
}
|
||||
|
||||
test("XGBoostRankerModel copy") {
|
||||
val model = new XGBoostRankerModel("hello").setNthread(2).setNumWorkers(10)
|
||||
val modelCopied = model.copy(ParamMap.empty)
|
||||
assert(model.uid === modelCopied.uid)
|
||||
assert(model.getNthread === modelCopied.getNthread)
|
||||
assert(model.getNumWorkers === modelCopied.getNumWorkers)
|
||||
}
|
||||
|
||||
test("read/write") {
|
||||
val trainDf = smallGroupVector
|
||||
val xgbParams: Map[String, Any] = Map(
|
||||
"max_depth" -> 5,
|
||||
"eta" -> 0.2,
|
||||
"objective" -> "rank:ndcg"
|
||||
)
|
||||
|
||||
def check(xgboostParams: XGBoostParams[_]): Unit = {
|
||||
assert(xgboostParams.getMaxDepth === 5)
|
||||
assert(xgboostParams.getEta === 0.2)
|
||||
assert(xgboostParams.getObjective === "rank:ndcg")
|
||||
}
|
||||
|
||||
val rankerPath = new File(tempDir.toFile, "ranker").getPath
|
||||
val ranker = new XGBoostRanker(xgbParams).setNumRound(1).setGroupCol("group")
|
||||
check(ranker)
|
||||
assert(ranker.getGroupCol === "group")
|
||||
|
||||
ranker.write.overwrite().save(rankerPath)
|
||||
val loadedRanker = XGBoostRanker.load(rankerPath)
|
||||
check(loadedRanker)
|
||||
assert(loadedRanker.getGroupCol === "group")
|
||||
|
||||
val model = loadedRanker.fit(trainDf)
|
||||
check(model)
|
||||
assert(model.getGroupCol === "group")
|
||||
|
||||
val modelPath = new File(tempDir.toFile, "model").getPath
|
||||
model.write.overwrite().save(modelPath)
|
||||
val modelLoaded = XGBoostRankerModel.load(modelPath)
|
||||
check(modelLoaded)
|
||||
assert(modelLoaded.getGroupCol === "group")
|
||||
}
|
||||
|
||||
test("validate") {
|
||||
val trainDf = smallGroupVector
|
||||
val ranker = new XGBoostRanker()
|
||||
// must define group column
|
||||
intercept[IllegalArgumentException](
|
||||
ranker.validate(trainDf)
|
||||
)
|
||||
val ranker1 = new XGBoostRanker().setGroupCol("group")
|
||||
ranker1.validate(trainDf)
|
||||
assert(ranker1.getObjective === "rank:ndcg")
|
||||
}
|
||||
|
||||
test("XGBoostRankerModel transformed schema") {
|
||||
val trainDf = smallGroupVector
|
||||
val ranker = new XGBoostRanker().setGroupCol("group").setNumRound(1)
|
||||
val model = ranker.fit(trainDf)
|
||||
var out = model.transform(trainDf)
|
||||
// Transform should not discard the other columns of the transforming dataframe
|
||||
Seq("label", "group", "margin", "weight", "features").foreach { v =>
|
||||
assert(out.schema.names.contains(v))
|
||||
}
|
||||
// Ranker does not have extra columns
|
||||
Seq("rawPrediction", "probability").foreach { v =>
|
||||
assert(!out.schema.names.contains(v))
|
||||
}
|
||||
assert(out.schema.names.contains("prediction"))
|
||||
assert(out.schema.names.length === 6)
|
||||
model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
|
||||
out = model.transform(trainDf)
|
||||
assert(out.schema.names.contains("leaf"))
|
||||
assert(out.schema.names.contains("contrib"))
|
||||
}
|
||||
|
||||
test("Supported objectives") {
|
||||
val ranker = new XGBoostRanker().setGroupCol("group")
|
||||
val df = smallGroupVector
|
||||
RANKER_OBJS.foreach { obj =>
|
||||
ranker.setObjective(obj)
|
||||
ranker.validate(df)
|
||||
}
|
||||
|
||||
ranker.setObjective("binary:logistic")
|
||||
intercept[IllegalArgumentException](
|
||||
ranker.validate(df)
|
||||
)
|
||||
}
|
||||
|
||||
test("The group col should be sorted in each partition") {
|
||||
val trainingDF = buildDataFrameWithGroup(Ranking.train)
|
||||
|
||||
val ranker = new XGBoostRanker()
|
||||
.setNumRound(1)
|
||||
.setNumWorkers(numWorkers)
|
||||
.setGroupCol("group")
|
||||
|
||||
val (df, _) = ranker.preprocess(trainingDF)
|
||||
df.rdd.foreachPartition { iter => {
|
||||
var prevGroup = Int.MinValue
|
||||
while (iter.hasNext) {
|
||||
val curr = iter.next()
|
||||
val group = curr.asInstanceOf[Row].getAs[Int](2)
|
||||
assert(prevGroup <= group)
|
||||
prevGroup = group
|
||||
}
|
||||
}}
|
||||
}
|
||||
|
||||
private def runLengthEncode(input: Seq[Int]): Seq[Int] = {
|
||||
if (input.isEmpty) return Seq(0)
|
||||
|
||||
input.indices
|
||||
.filter(i => i == 0 || input(i) != input(i - 1)) :+ input.length
|
||||
}
|
||||
|
||||
private def runRanker(ranker: XGBoostRanker, dataset: Dataset[_]): (Array[Float], Array[Int]) = {
|
||||
val (df, indices) = ranker.preprocess(dataset)
|
||||
val rdd = ranker.toRdd(df, indices)
|
||||
val result = rdd.mapPartitions { iter =>
|
||||
if (iter.hasNext) {
|
||||
val watches = iter.next()
|
||||
val dm = watches.toMap(Utils.TRAIN_NAME)
|
||||
val weight = dm.getWeight
|
||||
val group = dm.getGroup
|
||||
watches.delete()
|
||||
Iterator.single((weight, group))
|
||||
} else {
|
||||
Iterator.empty
|
||||
}
|
||||
}.collect()
|
||||
|
||||
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
|
||||
val group: ArrayBuffer[Int] = ArrayBuffer.empty
|
||||
|
||||
for (row <- result) {
|
||||
weight.append(row._1: _*)
|
||||
group.append(row._2: _*)
|
||||
}
|
||||
(weight.toArray, group.toArray)
|
||||
}
|
||||
|
||||
Seq(None, Some("weight")).foreach { weightCol => {
|
||||
val msg = weightCol.map(_ => "with weight").getOrElse("without weight")
|
||||
test(s"to RDD watches with group $msg") {
|
||||
// One instance without setting weight
|
||||
var df = ss.createDataFrame(sc.parallelize(Seq(
|
||||
(1.0, 0, 10, Vectors.dense(Array(1.0, 2.0, 3.0)))
|
||||
))).toDF("label", "group", "weight", "features")
|
||||
|
||||
val ranker = new XGBoostRanker()
|
||||
.setLabelCol("label")
|
||||
.setFeaturesCol("features")
|
||||
.setGroupCol("group")
|
||||
.setNumWorkers(1)
|
||||
|
||||
weightCol.foreach(ranker.setWeightCol)
|
||||
|
||||
val (weights, groupSize) = runRanker(ranker, df)
|
||||
val expectedWeight = weightCol.map(_ => Array(10.0f)).getOrElse(Array(1.0f))
|
||||
assert(weights === expectedWeight)
|
||||
assert(groupSize === runLengthEncode(Seq(0)))
|
||||
|
||||
df = ss.createDataFrame(sc.parallelize(Seq(
|
||||
(1.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
|
||||
(2.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
|
||||
(1.0, 0, 5, Vectors.dense(Array(1.0, 2.0, 3.0))),
|
||||
(0.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
|
||||
(1.0, 0, 5, Vectors.dense(Array(1.0, 2.0, 3.0))),
|
||||
(2.0, 2, 7, Vectors.dense(Array(1.0, 2.0, 3.0)))
|
||||
))).toDF("label", "group", "weight", "features")
|
||||
|
||||
val groups = Array(1, 1, 0, 1, 0, 2).sorted
|
||||
val (weights1, groupSize1) = runRanker(ranker, df)
|
||||
val expectedWeight1 = weightCol.map(_ => Array(5.0f, 2.0f, 7.0f))
|
||||
.getOrElse(groups.distinct.map(_ => 1.0f))
|
||||
|
||||
assert(groupSize1 === runLengthEncode(groups))
|
||||
assert(weights1 === expectedWeight1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("XGBoost-Spark output should match XGBoost4j") {
|
||||
val trainingDM = new DMatrix(Ranking.train.iterator)
|
||||
val weights = Ranking.trainGroups.distinct.map(_ => 1.0f).toArray
|
||||
trainingDM.setQueryId(Ranking.trainGroups.toArray)
|
||||
trainingDM.setWeight(weights)
|
||||
|
||||
val testDM = new DMatrix(Ranking.test.iterator)
|
||||
val trainingDF = buildDataFrameWithGroup(Ranking.train)
|
||||
val testDF = buildDataFrameWithGroup(Ranking.test)
|
||||
val paramMap = Map("objective" -> "rank:ndcg")
|
||||
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
|
||||
}
|
||||
|
||||
test("XGBoost-Spark output with weight should match XGBoost4j") {
|
||||
val trainingDM = new DMatrix(Ranking.trainWithWeight.iterator)
|
||||
trainingDM.setQueryId(Ranking.trainGroups.toArray)
|
||||
trainingDM.setWeight(Ranking.trainGroups.distinct.map(_.toFloat).toArray)
|
||||
|
||||
val testDM = new DMatrix(Ranking.test.iterator)
|
||||
val trainingDF = buildDataFrameWithGroup(Ranking.trainWithWeight)
|
||||
val testDF = buildDataFrameWithGroup(Ranking.test)
|
||||
val paramMap = Map("objective" -> "rank:ndcg")
|
||||
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
|
||||
5, paramMap, Some("weight"))
|
||||
}
|
||||
|
||||
private def checkResultsWithXGBoost4j(
|
||||
trainingDM: DMatrix,
|
||||
testDM: DMatrix,
|
||||
trainingDF: DataFrame,
|
||||
testDF: DataFrame,
|
||||
round: Int = 5,
|
||||
xgbParams: Map[String, Any] = Map.empty,
|
||||
weightCol: Option[String] = None): Unit = {
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"base_score" -> 0.5,
|
||||
"max_bin" -> 16) ++ xgbParams
|
||||
val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)
|
||||
|
||||
val ranker = new XGBoostRanker(paramMap)
|
||||
.setNumRound(round)
|
||||
// If we use multi workers to train the ranking, the result probably will be different
|
||||
.setNumWorkers(1)
|
||||
.setLeafPredictionCol("leaf")
|
||||
.setContribPredictionCol("contrib")
|
||||
.setGroupCol("group")
|
||||
weightCol.foreach(weight => ranker.setWeightCol(weight))
|
||||
|
||||
def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
|
||||
assert(left.size === right.size)
|
||||
left.zipWithIndex.foreach { case (leftValue, index) =>
|
||||
assert(leftValue.sameElements(right(index)))
|
||||
}
|
||||
}
|
||||
|
||||
val xgbSparkModel = ranker.fit(trainingDF)
|
||||
val rows = xgbSparkModel.transform(testDF).collect()
|
||||
|
||||
// Check Leaf
|
||||
val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
|
||||
val xgbSparkLeaf = rows.map(row =>
|
||||
(row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
|
||||
checkEqual(xgb4jLeaf, xgbSparkLeaf)
|
||||
|
||||
// Check contrib
|
||||
val xgb4jContrib = xgb4jModel.predictContrib(testDM)
|
||||
val xgbSparkContrib = rows.map(row =>
|
||||
(row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
|
||||
checkEqual(xgb4jContrib, xgbSparkContrib)
|
||||
|
||||
// Check prediction
|
||||
val xgb4jPred = xgb4jModel.predict(testDM)
|
||||
val xgbSparkPred = rows.map(row => {
|
||||
val pred = row.getAs[Double]("prediction").toFloat
|
||||
(row.getAs[Int]("id"), Array(pred))
|
||||
}).toMap
|
||||
checkEqual(xgb4jPred, xgbSparkPred)
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user