[jvm-packages] Support Ranker (#10823)
commit 19b55b300b (parent d7599e095b)
@@ -93,7 +93,8 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
       selectedCols.append(col)
     }
     val input = dataset.select(selectedCols.toArray: _*)
-    estimator.repartitionIfNeeded(input)
+    val repartitioned = estimator.repartitionIfNeeded(input)
+    estimator.sortPartitionIfNeeded(repartitioned)
   }

   // visible for testing
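Background for the hunk above: XGBoost's learning-to-rank objectives expect rows of the same query group to be contiguous when group information is fed to the native DMatrix. Each worker builds its DMatrix from its own partition, so a local sort is sufficient, and sortWithinPartitions is a narrow transformation that adds no shuffle. A minimal sketch of the idea (makeGroupsContiguous is an illustrative name, not from the diff):

    import org.apache.spark.sql.DataFrame

    // Sort each partition locally by the group id so query groups are contiguous.
    def makeGroupsContiguous(df: DataFrame, groupCol: String): DataFrame =
      df.sortWithinPartitions(groupCol)

The override that actually performs the sort is in XGBoostRanker, further down in this diff.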
@@ -16,14 +16,14 @@

 package ml.dmlc.xgboost4j.scala.spark

-import ai.rapids.cudf.Table
+import ai.rapids.cudf.{OrderByArg, Table}
 import ml.dmlc.xgboost4j.java.CudfColumnBatch
 import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix, XGBoost => ScalaXGBoost}
 import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite
 import ml.dmlc.xgboost4j.scala.rapids.spark.SparkSessionHolder.withSparkSession
 import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
 import org.apache.spark.ml.linalg.DenseVector
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.SparkConf

 import java.io.File
@@ -94,7 +94,9 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
     }

     // spark.rapids.sql.enabled is not set explicitly, default to true
-    withSparkSession(new SparkConf(), spark => {checkIsEnabled(spark, true)})
+    withSparkSession(new SparkConf(), spark => {
+      checkIsEnabled(spark, true)
+    })

     // set spark.rapids.sql.enabled to false
     withCpuSparkSession() { spark =>
@@ -503,6 +505,109 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
     }
   }

+  test("The group col should be sorted in each partition") {
+    withGpuSparkSession() { spark =>
+      import spark.implicits._
+      val df = Ranking.train.toDF("label", "weight", "group", "c1", "c2", "c3")
+
+      val xgboostParams: Map[String, Any] = Map(
+        "device" -> "cuda",
+        "objective" -> "rank:ndcg"
+      )
+      val features = Array("c1", "c2", "c3")
+      val label = "label"
+      val group = "group"
+
+      val ranker = new XGBoostRanker(xgboostParams)
+        .setFeaturesCol(features)
+        .setLabelCol(label)
+        .setNumWorkers(1)
+        .setNumRound(1)
+        .setGroupCol(group)
+        .setDevice("cuda")
+
+      val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df)
+      processedDf.rdd.foreachPartition { iter => {
+        var prevGroup = Int.MinValue
+        while (iter.hasNext) {
+          val curr = iter.next()
+          val group = curr.asInstanceOf[Row].getAs[Int](1)
+          assert(prevGroup <= group)
+          prevGroup = group
+        }
+      }}
+    }
+  }
+
+  test("Ranker: XGBoost-Spark should match xgboost4j") {
+    withGpuSparkSession() { spark =>
+      import spark.implicits._
+
+      val trainPath = writeFile(Ranking.train.toDF("label", "weight", "group", "c1", "c2", "c3"))
+      val testPath = writeFile(Ranking.test.toDF("label", "weight", "group", "c1", "c2", "c3"))
+
+      val df = spark.read.parquet(trainPath)
+      val testdf = spark.read.parquet(testPath)
+
+      val features = Array("c1", "c2", "c3")
+      val featuresIndices = features.map(df.schema.fieldIndex)
+      val label = "label"
+      val group = "group"
+
+      val numRound = 100
+      val xgboostParams: Map[String, Any] = Map(
+        "device" -> "cuda",
+        "objective" -> "rank:ndcg"
+      )
+
+      val ranker = new XGBoostRanker(xgboostParams)
+        .setFeaturesCol(features)
+        .setLabelCol(label)
+        .setNumRound(numRound)
+        .setLeafPredictionCol("leaf")
+        .setContribPredictionCol("contrib")
+        .setGroupCol(group)
+        .setDevice("cuda")
+
+      val xgb4jModel = withResource(new GpuColumnBatch(
+        Table.readParquet(new File(trainPath)
+        ).orderBy(OrderByArg.asc(df.schema.fieldIndex(group))))) { batch =>
+        val cb = new CudfColumnBatch(batch.select(featuresIndices),
+          batch.select(df.schema.fieldIndex(label)), null, null,
+          batch.select(df.schema.fieldIndex(group)))
+        val qdm = new QuantileDMatrix(Seq(cb).iterator, ranker.getMissing,
+          ranker.getMaxBins, ranker.getNthread)
+        ScalaXGBoost.train(qdm, xgboostParams, numRound)
+      }
+
+      val (xgb4jLeaf, xgb4jContrib, xgb4jPred) = withResource(new GpuColumnBatch(
+        Table.readParquet(new File(testPath)))) { batch =>
+        val cb = new CudfColumnBatch(batch.select(featuresIndices), null, null, null, null)
+        val qdm = new DMatrix(cb, ranker.getMissing, ranker.getNthread)
+        (xgb4jModel.predictLeaf(qdm), xgb4jModel.predictContrib(qdm),
+          xgb4jModel.predict(qdm))
+      }
+
+      val rows = ranker.fit(df).transform(testdf).collect()
+
+      // Check Leaf
+      val xgbSparkLeaf = rows.map(row => row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))
+      checkEqual(xgb4jLeaf, xgbSparkLeaf)
+
+      // Check contrib
+      val xgbSparkContrib = rows.map(row =>
+        row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))
+      checkEqual(xgb4jContrib, xgbSparkContrib)
+
+      // Check prediction
+      val xgbSparkPred = rows.map(row =>
+        Array(row.getAs[Double]("prediction").toFloat))
+      checkEqual(xgb4jPred, xgbSparkPred)
+    }
+  }
+
   def writeFile(df: Dataset[_]): String = {
     def listFiles(directory: String): Array[String] = {
       val dir = new File(directory)
@@ -81,6 +81,6 @@ object Regression extends TrainTestData {
 }

 object Ranking extends TrainTestData {
-  val train = generateRankDataset(300, 10, 555)
-  val test = generateRankDataset(150, 10, 556)
+  val train = generateRankDataset(300, 10, 12, 555)
+  val test = generateRankDataset(150, 10, 12, 556)
 }
@@ -134,6 +134,15 @@ private[spark] trait XGBoostEstimator[
     }
   }

+  /**
+   * Sort within partitions if needed; a no-op by default, overridden by XGBoostRanker.
+   * @param dataset the input dataset
+   * @return the dataset, sorted within partitions when required
+   */
+  private[spark] def sortPartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = {
+    dataset
+  }
+
   /**
    * Build the columns indices.
    */
@@ -198,10 +207,10 @@ private[spark] trait XGBoostEstimator[
       case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
       case _ =>
     }
-    val input = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))
-    val columnIndices = buildColumnIndices(input.schema)
-    (input, columnIndices)
+    val repartitioned = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))
+    val sorted = sortPartitionIfNeeded(repartitioned)
+    val columnIndices = buildColumnIndices(sorted.schema)
+    (sorted, columnIndices)
   }

   /** visible for testing */
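Note on the two XGBoostEstimator hunks above: the base trait keeps sortPartitionIfNeeded as an identity, so classifiers and regressors are untouched, and only XGBoostRanker (added below) overrides it; the hook runs after repartitionIfNeeded because repartitioning would destroy any per-partition ordering. A reduced, self-contained sketch of that hook pattern (RankAwarePreprocessing and RankerLike are illustrative names, not from the diff):

    import org.apache.spark.sql.Dataset

    trait RankAwarePreprocessing {
      // Default: leave the data unchanged.
      def sortPartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = dataset
    }

    class RankerLike(groupCol: String) extends RankAwarePreprocessing {
      // Ranker variant: sort each partition by the group column.
      override def sortPartitionIfNeeded(dataset: Dataset[_]): Dataset[_] =
        dataset.sortWithinPartitions(groupCol)
    }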
@@ -0,0 +1,124 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import org.apache.spark.ml.{PredictionModel, Predictor}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
+import org.apache.spark.ml.xgboost.SparkUtils
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+import ml.dmlc.xgboost4j.scala.Booster
+import ml.dmlc.xgboost4j.scala.spark.XGBoostRanker._uid
+import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS
+
+class XGBoostRanker(override val uid: String,
+                    private val xgboostParams: Map[String, Any])
+  extends Predictor[Vector, XGBoostRanker, XGBoostRankerModel]
+    with XGBoostEstimator[XGBoostRanker, XGBoostRankerModel] with HasGroupCol {
+
+  def this() = this(_uid, Map[String, Any]())
+
+  def this(uid: String) = this(uid, Map[String, Any]())
+
+  def this(xgboostParams: Map[String, Any]) = this(_uid, xgboostParams)
+
+  def setGroupCol(value: String): XGBoostRanker = set(groupCol, value)
+
+  xgboost2SparkParams(xgboostParams)
+
+  /**
+   * Validate the parameters before training; throws an exception if any are invalid.
+   */
+  override protected[spark] def validate(dataset: Dataset[_]): Unit = {
+    super.validate(dataset)
+
+    require(isDefinedNonEmpty(groupCol), "groupCol needs to be set")
+
+    // If the objective is set explicitly, it must be in RANKER_OBJS
+    if (isSet(objective)) {
+      val tmpObj = getObjective
+      require(RANKER_OBJS.contains(tmpObj),
+        s"Wrong objective for XGBoostRanker, supported objs: ${RANKER_OBJS.mkString(",")}")
+    } else {
+      setObjective("rank:ndcg")
+    }
+  }
+
+  /**
+   * Sort each partition by the group column so that rows of the same query
+   * group are contiguous, as the ranking objectives require.
+   *
+   * @param dataset the input dataset
+   * @return the dataset sorted within partitions
+   */
+  override private[spark] def sortPartitionIfNeeded(dataset: Dataset[_]) = {
+    dataset.sortWithinPartitions(getGroupCol)
+  }
+
+  override protected def createModel(
+      booster: Booster,
+      summary: XGBoostTrainingSummary): XGBoostRankerModel = {
+    new XGBoostRankerModel(uid, booster, Option(summary))
+  }
+
+  override protected def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType =
+    SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
+}
+
+object XGBoostRanker extends DefaultParamsReadable[XGBoostRanker] {
+  private val _uid = Identifiable.randomUID("xgbranker")
+}
+
+class XGBoostRankerModel private[ml](val uid: String,
+                                     val nativeBooster: Booster,
+                                     val summary: Option[XGBoostTrainingSummary] = None)
+  extends PredictionModel[Vector, XGBoostRankerModel]
+    with RankerRegressorBaseModel[XGBoostRankerModel] with HasGroupCol {
+
+  def this(uid: String) = this(uid, null)
+
+  def setGroupCol(value: String): XGBoostRankerModel = set(groupCol, value)
+
+  override def copy(extra: ParamMap): XGBoostRankerModel = {
+    val newModel = copyValues(new XGBoostRankerModel(uid, nativeBooster, summary), extra)
+    newModel.setParent(parent)
+  }
+
+  override def predict(features: Vector): Double = {
+    val values = predictSingleInstance(features)
+    values(0)
+  }
+}
+
+object XGBoostRankerModel extends MLReadable[XGBoostRankerModel] {
+  override def read: MLReader[XGBoostRankerModel] = new ModelReader
+
+  private class ModelReader extends XGBoostModelReader[XGBoostRankerModel] {
+    override def load(path: String): XGBoostRankerModel = {
+      val xgbModel = loadBooster(path)
+      val meta = SparkUtils.loadMetadata(path, sc)
+      val model = new XGBoostRankerModel(meta.uid, xgbModel, None)
+      meta.getAndSetParams(model)
+      model
+    }
+  }
+}
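The file above is the heart of the change: the new XGBoostRanker estimator and XGBoostRankerModel. A minimal usage sketch, assuming hypothetical DataFrames trainDf and testDf with a Vector "features" column, a numeric "label", and an Int "group" id shared by rows of the same query:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostRanker

    val ranker = new XGBoostRanker(Map("objective" -> "rank:ndcg"))
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setGroupCol("group") // required: validate() throws if it is unset
      .setNumRound(100)

    val model = ranker.fit(trainDf)      // each partition is sorted by "group" before training
    val scored = model.transform(testDf) // appends a Double "prediction" column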
@@ -0,0 +1,309 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.ml.linalg.{DenseVector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.scalatest.funsuite.AnyFunSuite
+
+import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+import ml.dmlc.xgboost4j.scala.spark.Regression.Ranking
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS
+import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
+
+class XGBoostRankerSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
+
+  test("XGBoostRanker copy") {
+    val ranker = new XGBoostRanker().setNthread(2).setNumWorkers(10)
+    val rankerCopied = ranker.copy(ParamMap.empty)
+
+    assert(ranker.uid === rankerCopied.uid)
+    assert(ranker.getNthread === rankerCopied.getNthread)
+    assert(ranker.getNumWorkers === rankerCopied.getNumWorkers)
+  }
+
+  test("XGBoostRankerModel copy") {
+    val model = new XGBoostRankerModel("hello").setNthread(2).setNumWorkers(10)
+    val modelCopied = model.copy(ParamMap.empty)
+    assert(model.uid === modelCopied.uid)
+    assert(model.getNthread === modelCopied.getNthread)
+    assert(model.getNumWorkers === modelCopied.getNumWorkers)
+  }
+
+  test("read/write") {
+    val trainDf = smallGroupVector
+    val xgbParams: Map[String, Any] = Map(
+      "max_depth" -> 5,
+      "eta" -> 0.2,
+      "objective" -> "rank:ndcg"
+    )
+
+    def check(xgboostParams: XGBoostParams[_]): Unit = {
+      assert(xgboostParams.getMaxDepth === 5)
+      assert(xgboostParams.getEta === 0.2)
+      assert(xgboostParams.getObjective === "rank:ndcg")
+    }
+
+    val rankerPath = new File(tempDir.toFile, "ranker").getPath
+    val ranker = new XGBoostRanker(xgbParams).setNumRound(1).setGroupCol("group")
+    check(ranker)
+    assert(ranker.getGroupCol === "group")
+
+    ranker.write.overwrite().save(rankerPath)
+    val loadedRanker = XGBoostRanker.load(rankerPath)
+    check(loadedRanker)
+    assert(loadedRanker.getGroupCol === "group")
+
+    val model = loadedRanker.fit(trainDf)
+    check(model)
+    assert(model.getGroupCol === "group")
+
+    val modelPath = new File(tempDir.toFile, "model").getPath
+    model.write.overwrite().save(modelPath)
+    val modelLoaded = XGBoostRankerModel.load(modelPath)
+    check(modelLoaded)
+    assert(modelLoaded.getGroupCol === "group")
+  }
+
+  test("validate") {
+    val trainDf = smallGroupVector
+    val ranker = new XGBoostRanker()
+    // must define group column
+    intercept[IllegalArgumentException](
+      ranker.validate(trainDf)
+    )
+    val ranker1 = new XGBoostRanker().setGroupCol("group")
+    ranker1.validate(trainDf)
+    assert(ranker1.getObjective === "rank:ndcg")
+  }
+
+  test("XGBoostRankerModel transformed schema") {
+    val trainDf = smallGroupVector
+    val ranker = new XGBoostRanker().setGroupCol("group").setNumRound(1)
+    val model = ranker.fit(trainDf)
+    var out = model.transform(trainDf)
+    // Transform should not discard the other columns of the transforming dataframe
+    Seq("label", "group", "margin", "weight", "features").foreach { v =>
+      assert(out.schema.names.contains(v))
+    }
+    // Ranker does not have extra columns
+    Seq("rawPrediction", "probability").foreach { v =>
+      assert(!out.schema.names.contains(v))
+    }
+    assert(out.schema.names.contains("prediction"))
+    assert(out.schema.names.length === 6)
+    model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
+    out = model.transform(trainDf)
+    assert(out.schema.names.contains("leaf"))
+    assert(out.schema.names.contains("contrib"))
+  }
+
+  test("Supported objectives") {
+    val ranker = new XGBoostRanker().setGroupCol("group")
+    val df = smallGroupVector
+    RANKER_OBJS.foreach { obj =>
+      ranker.setObjective(obj)
+      ranker.validate(df)
+    }
+
+    ranker.setObjective("binary:logistic")
+    intercept[IllegalArgumentException](
+      ranker.validate(df)
+    )
+  }
+
+  test("The group col should be sorted in each partition") {
+    val trainingDF = buildDataFrameWithGroup(Ranking.train)
+
+    val ranker = new XGBoostRanker()
+      .setNumRound(1)
+      .setNumWorkers(numWorkers)
+      .setGroupCol("group")
+
+    val (df, _) = ranker.preprocess(trainingDF)
+    df.rdd.foreachPartition { iter => {
+      var prevGroup = Int.MinValue
+      while (iter.hasNext) {
+        val curr = iter.next()
+        val group = curr.asInstanceOf[Row].getAs[Int](2)
+        assert(prevGroup <= group)
+        prevGroup = group
+      }
+    }}
+  }
+
+  private def runLengthEncode(input: Seq[Int]): Seq[Int] = {
+    if (input.isEmpty) return Seq(0)
+
+    input.indices
+      .filter(i => i == 0 || input(i) != input(i - 1)) :+ input.length
+  }
+
+  private def runRanker(ranker: XGBoostRanker, dataset: Dataset[_]): (Array[Float], Array[Int]) = {
+    val (df, indices) = ranker.preprocess(dataset)
+    val rdd = ranker.toRdd(df, indices)
+    val result = rdd.mapPartitions { iter =>
+      if (iter.hasNext) {
+        val watches = iter.next()
+        val dm = watches.toMap(Utils.TRAIN_NAME)
+        val weight = dm.getWeight
+        val group = dm.getGroup
+        watches.delete()
+        Iterator.single((weight, group))
+      } else {
+        Iterator.empty
+      }
+    }.collect()
+
+    val weight: ArrayBuffer[Float] = ArrayBuffer.empty
+    val group: ArrayBuffer[Int] = ArrayBuffer.empty
+
+    for (row <- result) {
+      weight.append(row._1: _*)
+      group.append(row._2: _*)
+    }
+    (weight.toArray, group.toArray)
+  }
+
+  Seq(None, Some("weight")).foreach { weightCol => {
+    val msg = weightCol.map(_ => "with weight").getOrElse("without weight")
+    test(s"to RDD watches with group $msg") {
+      // One instance without setting weight
+      var df = ss.createDataFrame(sc.parallelize(Seq(
+        (1.0, 0, 10, Vectors.dense(Array(1.0, 2.0, 3.0)))
+      ))).toDF("label", "group", "weight", "features")
+
+      val ranker = new XGBoostRanker()
+        .setLabelCol("label")
+        .setFeaturesCol("features")
+        .setGroupCol("group")
+        .setNumWorkers(1)
+
+      weightCol.foreach(ranker.setWeightCol)
+
+      val (weights, groupSize) = runRanker(ranker, df)
+      val expectedWeight = weightCol.map(_ => Array(10.0f)).getOrElse(Array(1.0f))
+      assert(weights === expectedWeight)
+      assert(groupSize === runLengthEncode(Seq(0)))
+
+      df = ss.createDataFrame(sc.parallelize(Seq(
+        (1.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
+        (2.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
+        (1.0, 0, 5, Vectors.dense(Array(1.0, 2.0, 3.0))),
+        (0.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
+        (1.0, 0, 5, Vectors.dense(Array(1.0, 2.0, 3.0))),
+        (2.0, 2, 7, Vectors.dense(Array(1.0, 2.0, 3.0)))
+      ))).toDF("label", "group", "weight", "features")
+
+      val groups = Array(1, 1, 0, 1, 0, 2).sorted
+      val (weights1, groupSize1) = runRanker(ranker, df)
+      val expectedWeight1 = weightCol.map(_ => Array(5.0f, 2.0f, 7.0f))
+        .getOrElse(groups.distinct.map(_ => 1.0f))
+
+      assert(groupSize1 === runLengthEncode(groups))
+      assert(weights1 === expectedWeight1)
+    }
+  }
+  }
+
+  test("XGBoost-Spark output should match XGBoost4j") {
+    val trainingDM = new DMatrix(Ranking.train.iterator)
+    val weights = Ranking.trainGroups.distinct.map(_ => 1.0f).toArray
+    trainingDM.setQueryId(Ranking.trainGroups.toArray)
+    trainingDM.setWeight(weights)
+
+    val testDM = new DMatrix(Ranking.test.iterator)
+    val trainingDF = buildDataFrameWithGroup(Ranking.train)
+    val testDF = buildDataFrameWithGroup(Ranking.test)
+    val paramMap = Map("objective" -> "rank:ndcg")
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
+  }
+
+  test("XGBoost-Spark output with weight should match XGBoost4j") {
+    val trainingDM = new DMatrix(Ranking.trainWithWeight.iterator)
+    trainingDM.setQueryId(Ranking.trainGroups.toArray)
+    trainingDM.setWeight(Ranking.trainGroups.distinct.map(_.toFloat).toArray)
+
+    val testDM = new DMatrix(Ranking.test.iterator)
+    val trainingDF = buildDataFrameWithGroup(Ranking.trainWithWeight)
+    val testDF = buildDataFrameWithGroup(Ranking.test)
+    val paramMap = Map("objective" -> "rank:ndcg")
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
+      5, paramMap, Some("weight"))
+  }
+
+  private def checkResultsWithXGBoost4j(
+      trainingDM: DMatrix,
+      testDM: DMatrix,
+      trainingDF: DataFrame,
+      testDF: DataFrame,
+      round: Int = 5,
+      xgbParams: Map[String, Any] = Map.empty,
+      weightCol: Option[String] = None): Unit = {
+    val paramMap = Map(
+      "eta" -> "1",
+      "max_depth" -> "6",
+      "base_score" -> 0.5,
+      "max_bin" -> 16) ++ xgbParams
+    val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)
+
+    val ranker = new XGBoostRanker(paramMap)
+      .setNumRound(round)
+      // Training the ranker with multiple workers would likely produce a different model
+      .setNumWorkers(1)
+      .setLeafPredictionCol("leaf")
+      .setContribPredictionCol("contrib")
+      .setGroupCol("group")
+    weightCol.foreach(weight => ranker.setWeightCol(weight))
+
+    def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
+      assert(left.size === right.size)
+      left.zipWithIndex.foreach { case (leftValue, index) =>
+        assert(leftValue.sameElements(right(index)))
+      }
+    }
+
+    val xgbSparkModel = ranker.fit(trainingDF)
+    val rows = xgbSparkModel.transform(testDF).collect()

+    // Check Leaf
+    val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
+    val xgbSparkLeaf = rows.map(row =>
+      (row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
+    checkEqual(xgb4jLeaf, xgbSparkLeaf)
+
+    // Check contrib
+    val xgb4jContrib = xgb4jModel.predictContrib(testDM)
+    val xgbSparkContrib = rows.map(row =>
+      (row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
+    checkEqual(xgb4jContrib, xgbSparkContrib)
+
+    // Check prediction
+    val xgb4jPred = xgb4jModel.predict(testDM)
+    val xgbSparkPred = rows.map(row => {
+      val pred = row.getAs[Double]("prediction").toFloat
+      (row.getAs[Int]("id"), Array(pred))
+    }).toMap
+    checkEqual(xgb4jPred, xgbSparkPred)
+  }
+}
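Note on the runLengthEncode helper in the suite above: it returns the start offset of every group run plus the total row count, which is the group-boundary layout the tests compare against DMatrix.getGroup. For example, runLengthEncode(Seq(0, 0, 1, 1, 1, 2)) yields Seq(0, 2, 5, 6): the three groups start at rows 0, 2, and 5, and the data ends at row 6.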