This commit is contained in:
@@ -24,6 +24,7 @@ import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
@@ -256,8 +257,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
}
|
||||
}
|
||||
|
||||
test("repartitionForTrainingGroup with group data which has empty partition") {
|
||||
val trainingRDD = sc.parallelize(Ranking.train, 5).mapPartitions(it => {
|
||||
// make one partition empty for testing
|
||||
it.filter(_ => TaskContext.getPartitionId() != 3)
|
||||
})
|
||||
XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
||||
}
|
||||
|
||||
test("distributed training with group data") {
|
||||
val trainingRDD = sc.parallelize(Ranking.train, 2)
|
||||
val trainingRDD = sc.parallelize(Ranking.train, 5)
|
||||
val (booster, metrics) = XGBoost.trainDistributed(
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
|
||||
Reference in New Issue
Block a user