[jvm-packages] Rework the train pipeline (#7401)

1. Add PreXGBoost to build RDD[Watches] from Dataset
2. Feed RDD[Watches] built from PreXGBoost to XGBoost to train
This commit is contained in:
Bobby Wang
2021-11-10 17:51:38 +08:00
committed by GitHub
parent 8df0a252b7
commit cb685607b2
8 changed files with 631 additions and 470 deletions

View File

@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
import org.apache.spark.sql.functions._
@@ -55,13 +56,13 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit
resultDF
})
val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
numWorkers,
deterministicPartition = true,
PackedParams(col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
numWorkers,
deterministicPartition = true),
df
).head)
val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
@@ -90,14 +91,13 @@ class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite wit
val df = ss.createDataFrame(sc.parallelize(dataset)).toDF("id", "label", "features")
val dfRepartitioned = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
10,
deterministicPartition = true,
df
PackedParams(col("label"),
col("features"),
lit(1.0),
lit(Float.NaN),
None,
10,
deterministicPartition = true), df
).head
val partitionsSizes = dfRepartitioned

View File

@@ -17,12 +17,14 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
import scala.util.Random
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.TaskContext
import org.scalatest.FunSuite
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit
@@ -30,13 +32,14 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD)
val (booster, metrics) = XGBoost.trainDistributed(
trainingRDD,
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap,
hasGroup = false)
"missing" -> Float.NaN).toMap)
assert(booster != null)
}
@@ -179,7 +182,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
// test different splits to cover the corner cases.
for (split <- 1 to 20) {
val trainingRDD = sc.parallelize(Ranking.train, split)
val traingGroupsRDD = XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
val traingGroupsRDD = PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
val trainingGroups: Array[Array[XGBLabeledPoint]] = traingGroupsRDD.collect()
// check the order of the groups with group id.
// Ranking.train has 20 groups
@@ -201,18 +204,19 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
// make one partition empty for testing
it.filter(_ => TaskContext.getPartitionId() != 3)
})
XGBoost.repartitionForTrainingGroup(trainingRDD, 4)
PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
}
test("distributed training with group data") {
val trainingRDD = sc.parallelize(Ranking.train, 5)
val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD, hasGroup = true)
val (booster, _) = XGBoost.trainDistributed(
trainingRDD,
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
"objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap,
hasGroup = true)
"missing" -> Float.NaN).toMap)
assert(booster != null)
}