[jvm-packages] Rework the train pipeline (#7401)
1. Add PreXGBoost to build RDD[Watches] from Dataset 2. Feed RDD[Watches] built from PreXGBoost to XGBoost to train
This commit is contained in:
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.scalatest.FunSuite
|
||||
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
|
||||
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
@@ -55,13 +56,13 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit
|
||||
resultDF
|
||||
})
|
||||
val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
|
||||
col("label"),
|
||||
col("features"),
|
||||
lit(1.0),
|
||||
lit(Float.NaN),
|
||||
None,
|
||||
numWorkers,
|
||||
deterministicPartition = true,
|
||||
PackedParams(col("label"),
|
||||
col("features"),
|
||||
lit(1.0),
|
||||
lit(Float.NaN),
|
||||
None,
|
||||
numWorkers,
|
||||
deterministicPartition = true),
|
||||
df
|
||||
).head)
|
||||
val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
|
||||
@@ -90,14 +91,13 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit
|
||||
val df = ss.createDataFrame(sc.parallelize(dataset)).toDF("id", "label", "features")
|
||||
|
||||
val dfRepartitioned = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
|
||||
col("label"),
|
||||
col("features"),
|
||||
lit(1.0),
|
||||
lit(Float.NaN),
|
||||
None,
|
||||
10,
|
||||
deterministicPartition = true,
|
||||
df
|
||||
PackedParams(col("label"),
|
||||
col("features"),
|
||||
lit(1.0),
|
||||
lit(Float.NaN),
|
||||
None,
|
||||
10,
|
||||
deterministicPartition = true), df
|
||||
).head
|
||||
|
||||
val partitionsSizes = dfRepartitioned
|
||||
|
||||
@@ -17,12 +17,14 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
import org.apache.spark.sql.functions.lit
|
||||
|
||||
@@ -30,13 +32,14 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
|
||||
test("distributed training with the specified worker number") {
|
||||
val trainingRDD = sc.parallelize(Classification.train)
|
||||
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD)
|
||||
val (booster, metrics) = XGBoost.trainDistributed(
|
||||
trainingRDD,
|
||||
sc,
|
||||
buildTrainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
||||
"missing" -> Float.NaN).toMap,
|
||||
hasGroup = false)
|
||||
"missing" -> Float.NaN).toMap)
|
||||
assert(booster != null)
|
||||
}
|
||||
|
||||
@@ -179,7 +182,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
// test different splits to cover the corner cases.
|
||||
for (split <- 1 to 20) {
|
||||
val trainingRDD = sc.parallelize(Ranking.train, split)
|
||||
val traingGroupsRDD = XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
||||
val traingGroupsRDD = PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
||||
val trainingGroups: Array[Array[XGBLabeledPoint]] = traingGroupsRDD.collect()
|
||||
// check the the order of the groups with group id.
|
||||
// Ranking.train has 20 groups
|
||||
@@ -201,18 +204,19 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
// make one partition empty for testing
|
||||
it.filter(_ => TaskContext.getPartitionId() != 3)
|
||||
})
|
||||
XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
||||
PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
|
||||
}
|
||||
|
||||
test("distributed training with group data") {
|
||||
val trainingRDD = sc.parallelize(Ranking.train, 5)
|
||||
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD, hasGroup = true)
|
||||
val (booster, _) = XGBoost.trainDistributed(
|
||||
trainingRDD,
|
||||
sc,
|
||||
buildTrainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6",
|
||||
"objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
||||
"missing" -> Float.NaN).toMap,
|
||||
hasGroup = true)
|
||||
"missing" -> Float.NaN).toMap)
|
||||
|
||||
assert(booster != null)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user