[jvm-packages] call setGroup for ranking task (#2066)
* [jvm-packages] call setGroup for ranking task * passing groupData through xgBoostConfMap * fix original comment position * make groupData param * remove groupData variable, use xgBoostConfMap directly * set default groupData value * add use groupData tests * reduce rank-demo size * use TaskContext.getPartitionId() instead of mapPartitionsWithIndex * add DF use groupData test * remove unused varable
This commit is contained in:
@@ -239,4 +239,36 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = numWorkers)
|
||||
}
|
||||
|
||||
test("test DF use nested groupData") {
|
||||
val testItr = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile).iterator.
|
||||
zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
|
||||
(id, instance.features, instance.label)
|
||||
}
|
||||
val trainingDF = {
|
||||
val rowList0 = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile)
|
||||
val labeledPointsRDD0 = sc.parallelize(rowList0, numSlices = 1)
|
||||
val rowList1 = loadLabelPoints(getClass.getResource("/rank-demo-1.txt.train").getFile)
|
||||
val labeledPointsRDD1 = sc.parallelize(rowList1, numSlices = 1)
|
||||
val labeledPointsRDD = labeledPointsRDD0.union(labeledPointsRDD1)
|
||||
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
|
||||
import sparkSession.implicits._
|
||||
sparkSession.createDataset(labeledPointsRDD).toDF
|
||||
}
|
||||
val trainGroupData0: Seq[Int] = Source.fromFile(
|
||||
getClass.getResource("/rank-demo-0.txt.train.group").getFile).getLines().map(_.toInt).toList
|
||||
val trainGroupData1: Seq[Int] = Source.fromFile(
|
||||
getClass.getResource("/rank-demo-1.txt.train.group").getFile).getLines().map(_.toInt).toList
|
||||
val trainGroupData: Seq[Seq[Int]] = Seq(trainGroupData0, trainGroupData1)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
|
||||
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = 2)
|
||||
val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
|
||||
"id", "features", "label")
|
||||
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
|
||||
assert(testDF.count() === predResultsFromDF.size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import java.nio.file.Files
|
||||
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
import scala.util.Random
|
||||
import scala.concurrent.duration._
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
|
||||
@@ -341,4 +342,46 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||
assert(loadedXGBoostModel.getLabelCol == "label")
|
||||
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
||||
}
|
||||
|
||||
test("test use groupData") {
|
||||
val trainSet = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile)
|
||||
val trainingRDD = sc.parallelize(trainSet, numSlices = 1)
|
||||
val trainGroupData: Seq[Seq[Int]] = Seq(Source.fromFile(
|
||||
getClass.getResource("/rank-demo-0.txt.train.group").getFile).getLines().map(_.toInt).toList)
|
||||
val testSet = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile)
|
||||
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
|
||||
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
|
||||
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1)
|
||||
val predRDD = xgBoostModel.predict(testRDD)
|
||||
val predResult1: Array[Array[Float]] = predRDD.collect()(0)
|
||||
assert(testRDD.count() === predResult1.length)
|
||||
}
|
||||
|
||||
test("test use nested groupData") {
|
||||
val trainSet0 = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile)
|
||||
val trainingRDD0 = sc.parallelize(trainSet0, numSlices = 1)
|
||||
val trainSet1 = loadLabelPoints(getClass.getResource("/rank-demo-1.txt.train").getFile)
|
||||
val trainingRDD1 = sc.parallelize(trainSet1, numSlices = 1)
|
||||
val trainingRDD = trainingRDD0.union(trainingRDD1)
|
||||
|
||||
val trainGroupData0: Seq[Int] = Source.fromFile(
|
||||
getClass.getResource("/rank-demo-0.txt.train.group").getFile).getLines().map(_.toInt).toList
|
||||
val trainGroupData1: Seq[Int] = Source.fromFile(
|
||||
getClass.getResource("/rank-demo-1.txt.train.group").getFile).getLines().map(_.toInt).toList
|
||||
val trainGroupData: Seq[Seq[Int]] = Seq(trainGroupData0, trainGroupData1)
|
||||
|
||||
val testSet = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile)
|
||||
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
|
||||
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
|
||||
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2)
|
||||
val predRDD = xgBoostModel.predict(testRDD)
|
||||
val predResult1: Array[Array[Float]] = predRDD.collect()(0)
|
||||
assert(testRDD.count() === predResult1.length)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user