From 9504f411c1f13eff604285f9e93d25c48ef11369 Mon Sep 17 00:00:00 2001
From: weitian <5275150+weitian@users.noreply.github.com>
Date: Tue, 9 Oct 2018 09:03:22 -0700
Subject: [PATCH] [jvm-packages] For training data with group, empty RDD partition threw exception (#3749) (#3750)

---
 CONTRIBUTORS.md                                       |  1 +
 .../scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala |  2 +-
 .../xgboost4j/scala/spark/XGBoostGeneralSuite.scala   | 11 ++++++++++-
 3 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index f23e80426..edb30746e 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -80,3 +80,4 @@ List of Contributors
   - liuliang01 added support for the qid column for LibSVM input format. This makes ranking task easier in distributed setting.
 * [Andrew Thia](https://github.com/BlueTea88)
   - Andrew Thia implemented feature interaction constraints
+* [Wei Tian](https://github.com/weitian)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 038a65889..4177af88a 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -497,7 +497,7 @@ private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
   extends AbstractIterator[XGBLabeledPointGroup] {
 
   private var firstPointOfNextGroup: XGBLabeledPoint = null
-  private var isNewGroup = true
+  private var isNewGroup = false
 
   override def hasNext: Boolean = {
     return base.hasNext || isNewGroup
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 243ab4fc2..b20ccd451 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -24,6 +24,7 @@ import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.TaskContext
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
@@ -256,8 +257,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     }
   }
 
+  test("repartitionForTrainingGroup with group data which has empty partition") {
+    val trainingRDD = sc.parallelize(Ranking.train, 5).mapPartitions(it => {
+      // make one partition empty for testing
+      it.filter(_ => TaskContext.getPartitionId() != 3)
+    })
+    XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
+  }
+
   test("distributed training with group data") {
-    val trainingRDD = sc.parallelize(Ranking.train, 2)
+    val trainingRDD = sc.parallelize(Ranking.train, 5)
     val (booster, metrics) = XGBoost.trainDistributed(
       trainingRDD,
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",