[jvm-packages] For training data with group, empty RDD partition threw exception (#3749) (#3750)

This commit is contained in:
weitian 2018-10-09 09:03:22 -07:00 committed by Nan Zhu
parent ca33bf6476
commit 9504f411c1
3 changed files with 12 additions and 2 deletions

View File

@ -80,3 +80,4 @@ List of Contributors
- liuliang01 added support for the qid column for LibSVM input format. This makes ranking tasks easier in a distributed setting.
* [Andrew Thia](https://github.com/BlueTea88)
- Andrew Thia implemented feature interaction constraints
* [Wei Tian](https://github.com/weitian)

View File

@ -497,7 +497,7 @@ private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
extends AbstractIterator[XGBLabeledPointGroup] {
private var firstPointOfNextGroup: XGBLabeledPoint = null
private var isNewGroup = true
private var isNewGroup = false
override def hasNext: Boolean = {
return base.hasNext || isNewGroup

View File

@ -24,6 +24,7 @@ import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.TaskContext
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql._
import org.scalatest.FunSuite
@ -256,8 +257,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
}
}
test("repartitionForTrainingGroup with group data which has empty partition") {
  // Index of the input partition whose records are all dropped below.
  val emptiedPartitionId = 3
  // Spread the ranking data over 5 partitions, then filter every record out of
  // one of them so the RDD handed to the repartitioning step has an empty partition.
  val trainingRDD = sc.parallelize(Ranking.train, 5).mapPartitions { points =>
    points.filter(_ => TaskContext.getPartitionId() != emptiedPartitionId)
  }
  // No assertion on the result: the call itself is what this test exercises.
  XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
}
test("distributed training with group data") {
val trainingRDD = sc.parallelize(Ranking.train, 2)
val trainingRDD = sc.parallelize(Ranking.train, 5)
val (booster, metrics) = XGBoost.trainDistributed(
trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",