This commit is contained in:
parent
ca33bf6476
commit
9504f411c1
@ -80,3 +80,4 @@ List of Contributors
|
|||||||
- liuliang01 added support for the qid column for LibSVM input format. This makes ranking task easier in distributed setting.
|
- liuliang01 added support for the qid column for LibSVM input format. This makes ranking task easier in distributed setting.
|
||||||
* [Andrew Thia](https://github.com/BlueTea88)
|
* [Andrew Thia](https://github.com/BlueTea88)
|
||||||
- Andrew Thia implemented feature interaction constraints
|
- Andrew Thia implemented feature interaction constraints
|
||||||
|
* [Wei Tian](https://github.com/weitian)
|
||||||
|
|||||||
@ -497,7 +497,7 @@ private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
|
|||||||
extends AbstractIterator[XGBLabeledPointGroup] {
|
extends AbstractIterator[XGBLabeledPointGroup] {
|
||||||
|
|
||||||
private var firstPointOfNextGroup: XGBLabeledPoint = null
|
private var firstPointOfNextGroup: XGBLabeledPoint = null
|
||||||
private var isNewGroup = true
|
private var isNewGroup = false
|
||||||
|
|
||||||
override def hasNext: Boolean = {
|
override def hasNext: Boolean = {
|
||||||
return base.hasNext || isNewGroup
|
return base.hasNext || isNewGroup
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import ml.dmlc.xgboost4j.scala.DMatrix
|
|||||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||||
|
import org.apache.spark.TaskContext
|
||||||
import org.apache.spark.ml.linalg.Vectors
|
import org.apache.spark.ml.linalg.Vectors
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
@ -256,8 +257,16 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("repartitionForTrainingGroup with group data which has empty partition") {
|
||||||
|
val trainingRDD = sc.parallelize(Ranking.train, 5).mapPartitions(it => {
|
||||||
|
// make one partition empty for testing
|
||||||
|
it.filter(_ => TaskContext.getPartitionId() != 3)
|
||||||
|
})
|
||||||
|
XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
||||||
|
}
|
||||||
|
|
||||||
test("distributed training with group data") {
|
test("distributed training with group data") {
|
||||||
val trainingRDD = sc.parallelize(Ranking.train, 2)
|
val trainingRDD = sc.parallelize(Ranking.train, 5)
|
||||||
val (booster, metrics) = XGBoost.trainDistributed(
|
val (booster, metrics) = XGBoost.trainDistributed(
|
||||||
trainingRDD,
|
trainingRDD,
|
||||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user