[jvm-packages][refactor] refactor XGBoost.scala (spark) (#3904)
* add back train method but mark as deprecated
* fix scalastyle error
* wrap iterators
* remove unused code
* refactor
* fix typo
parent 0cd326c1bc
commit aa48b7e903
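In short, this refactor narrows XGBoost.trainDistributed from nine parameters to three: everything except hasGroup now travels inside the params map, and the per-partition training paths move into trainForRanking/trainForNonRanking helpers. The signature change, assembled from the hunks below:

```scala
// Before this commit (removed below):
private[spark] def trainDistributed(
    trainingData: RDD[XGBLabeledPoint],
    params: Map[String, Any],
    round: Int,
    nWorkers: Int,
    obj: ObjectiveTrait = null,
    eval: EvalTrait = null,
    useExternalMemory: Boolean = false,
    missing: Float = Float.NaN,
    hasGroup: Boolean = false): (Booster, Map[String, Array[Float]])

// After this commit (added below):
private[spark] def trainDistributed(
    trainingData: RDD[XGBLabeledPoint],
    params: Map[String, Any],
    hasGroup: Boolean = false): (Booster, Map[String, Array[Float]])
```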
@@ -26,9 +26,9 @@ import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 
 import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
 import org.apache.spark.sql.SparkSession
@@ -207,21 +207,14 @@ object XGBoost extends Serializable {
     }
   }
 
-  /**
-   * @return A tuple of the booster and the metrics used to build training summary
-   */
-  @throws(classOf[XGBoostError])
-  private[spark] def trainDistributed(
-      trainingData: RDD[XGBLabeledPoint],
-      params: Map[String, Any],
-      round: Int,
-      nWorkers: Int,
-      obj: ObjectiveTrait = null,
-      eval: EvalTrait = null,
-      useExternalMemory: Boolean = false,
-      missing: Float = Float.NaN,
-      hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
-    validateSparkSslConf(trainingData.context)
+  private def parameterFetchAndValidation(params: Map[String, Any], sparkContext: SparkContext) = {
+    val nWorkers = params("num_workers").asInstanceOf[Int]
+    val round = params("num_round").asInstanceOf[Int]
+    val useExternalMemory = params("use_external_memory").asInstanceOf[Boolean]
+    val obj = params.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
+    val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
+    val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float]
+    validateSparkSslConf(sparkContext)
     if (params.contains("tree_method")) {
       require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
         " for now")
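For reference, parameterFetchAndValidation as written above makes num_workers, num_round, and use_external_memory mandatory (fetched with params(...), which throws NoSuchElementException if the key is absent), while custom_obj, custom_eval, and missing fall back to defaults via getOrElse; trackerConf and timeoutRequestWorkers also come back in its result tuple, fetched in code the hunk does not show. A minimal sketch of a map that satisfies it:

```scala
// Minimal sketch of the keys parameterFetchAndValidation reads.
val params: Map[String, Any] = Map(
  "num_workers" -> 2,              // required: params("num_workers")
  "num_round" -> 5,                // required: params("num_round")
  "use_external_memory" -> false,  // required: params("use_external_memory")
  "custom_obj" -> null,            // optional: defaults to null
  "custom_eval" -> null,           // optional: defaults to null
  "missing" -> Float.NaN,          // optional: defaults to Float.NaN
  // ordinary booster parameters pass through untouched:
  "eta" -> "1", "max_depth" -> "6", "objective" -> "binary:logistic")
```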
@@ -245,11 +238,60 @@ object XGBoost extends Serializable {
         " an instance of Long.")
     }
     val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
+    (nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
+      checkpointPath, checkpointInterval)
+  }
+
+  private def trainForNonRanking(
+      trainingData: RDD[XGBLabeledPoint],
+      params: Map[String, Any],
+      rabitEnv: java.util.Map[String, String],
+      checkpointRound: Int,
+      prevBooster: Booster) = {
+    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
+      parameterFetchAndValidation(params, trainingData.sparkContext)
+    val partitionedData = repartitionForTraining(trainingData, nWorkers)
+    partitionedData.mapPartitions(labeledPoints => {
+      val watches = Watches.buildWatches(params,
+        removeMissingValues(labeledPoints, missing),
+        getCacheDirName(useExternalMemory))
+      buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
+        obj, eval, prevBooster)
+    }).cache()
+  }
+
+  private def trainForRanking(
+      trainingData: RDD[XGBLabeledPoint],
+      params: Map[String, Any],
+      rabitEnv: java.util.Map[String, String],
+      checkpointRound: Int,
+      prevBooster: Booster) = {
+    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
+      parameterFetchAndValidation(params, trainingData.sparkContext)
+    val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
+    partitionedData.mapPartitions(labeledPointGroups => {
+      val watches = Watches.buildWatchesWithGroup(params,
+        removeMissingValuesWithGroup(labeledPointGroups, missing),
+        getCacheDirName(useExternalMemory))
+      buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
+        obj, eval, prevBooster)
+    }).cache()
+  }
+
+  /**
+   * @return A tuple of the booster and the metrics used to build training summary
+   */
+  @throws(classOf[XGBoostError])
+  private[spark] def trainDistributed(
+      trainingData: RDD[XGBLabeledPoint],
+      params: Map[String, Any],
+      hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
+    val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
+      checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
+      trainingData.sparkContext)
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
-    checkpointManager.cleanUpHigherVersions(round)
+    checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
 
     var prevBooster = checkpointManager.loadCheckpointAsBooster
     // Train for every ${savingRound} rounds and save the partially completed booster
     checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
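With the new signature, a direct call carries all training options in the params map; a minimal usage sketch, mirroring the updated test suite at the end of this commit:

```scala
// Hedged usage sketch of the new trainDistributed; trainingRDD and
// numWorkers are assumed to exist in the caller's scope.
val (booster, metrics) = XGBoost.trainDistributed(
  trainingRDD,
  Map(
    "eta" -> "1", "max_depth" -> "6", "objective" -> "binary:logistic",
    "num_round" -> 5, "num_workers" -> numWorkers,
    "custom_obj" -> null, "custom_eval" -> null,
    "use_external_memory" -> false, "missing" -> Float.NaN),
  hasGroup = false)
```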
@@ -259,27 +301,12 @@ object XGBoost extends Serializable {
       val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
       val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
       val rabitEnv = tracker.getWorkerEnvs
-      val boostersAndMetrics = hasGroup match {
-        case true => {
-          val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
-          partitionedData.mapPartitions(labeledPointGroups => {
-            val watches = Watches.buildWatchesWithGroup(overriddenParams,
-              removeMissingValuesWithGroup(labeledPointGroups, missing),
-              getCacheDirName(useExternalMemory))
-            buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
-              obj, eval, prevBooster)
-          }).cache()
-        }
-        case false => {
-          val partitionedData = repartitionForTraining(trainingData, nWorkers)
-          partitionedData.mapPartitions(labeledPoints => {
-            val watches = Watches.buildWatches(overriddenParams,
-              removeMissingValues(labeledPoints, missing),
-              getCacheDirName(useExternalMemory))
-            buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
-              obj, eval, prevBooster)
-          }).cache()
-        }
+      val boostersAndMetrics = if (hasGroup) {
+        trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
+          prevBooster)
+      } else {
+        trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
+          prevBooster)
       }
       val sparkJobThread = new Thread() {
         override def run() {
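Replacing hasGroup match { case true ... case false ... } with a plain if/else drops two levels of nesting, and the two branches now differ only in which helper they call; the repartition/buildWatches/mapPartitions/cache plumbing that used to be duplicated inline now lives once per mode in the helpers above.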
@@ -198,8 +198,7 @@ class XGBoostClassifier (
     val derivedXGBParamMap = MLlib2XGBoostParams
     // All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
     val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
-      $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
-      $(missing), hasGroup = false)
+      hasGroup = false)
     val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
     val summary = XGBoostTrainingSummary(_metrics)
     model.setSummary(summary)
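Note this call site only works if derivedXGBParamMap already carries the keys that parameterFetchAndValidation fetches; presumably MLlib2XGBoostParams translates the estimator params that were previously passed explicitly ($(numRound), $(numWorkers), and so on) into those keys. A hedged sketch of the shape that map must now have (the mapping itself is not part of this diff):

```scala
// Assumed contents of derivedXGBParamMap; key names come from
// parameterFetchAndValidation in XGBoost.scala, values are examples only.
val derivedXGBParamMap: Map[String, Any] = Map(
  "num_round" -> 100,               // was $(numRound)
  "num_workers" -> 4,               // was $(numWorkers)
  "custom_obj" -> null,             // was $(customObj)
  "custom_eval" -> null,            // was $(customEval)
  "use_external_memory" -> false,   // was $(useExternalMemory)
  "missing" -> Float.NaN,           // was $(missing)
  "objective" -> "binary:logistic") // plus regular booster params
```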
@@ -193,8 +193,7 @@ class XGBoostRegressor (
     val derivedXGBParamMap = MLlib2XGBoostParams
     // All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
     val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
-      $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
-      $(missing), hasGroup = group != lit(-1))
+      hasGroup = group != lit(-1))
     val model = new XGBoostRegressionModel(uid, _booster)
     val summary = XGBoostTrainingSummary(_metrics)
     model.setSummary(summary)
@@ -78,10 +78,10 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     val (booster, metrics) = XGBoost.trainDistributed(
       trainingRDD,
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-        "objective" -> "binary:logistic").toMap,
-      round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
-      hasGroup = false, missing = Float.NaN)
+        "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
+        "custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
+        "missing" -> Float.NaN).toMap,
+      hasGroup = false)
     assert(booster != null)
   }
 
@@ -270,9 +270,10 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     val (booster, metrics) = XGBoost.trainDistributed(
       trainingRDD,
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-        "objective" -> "binary:logistic").toMap,
-      round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
-      hasGroup = true, missing = Float.NaN)
+        "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
+        "custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
+        "missing" -> Float.NaN).toMap,
+      hasGroup = true)
 
     assert(booster != null)
   }