[jvm-packages] Add some documentation to xgboost4j-spark plus minor style edits (#2823)
* add scala docs to several methods
* indentation
* license formatting
* clarify distributed boosters
* address some review comments
* reduce doc lengths
* change method name, clarify doc
* reset make config
* delete most comments
* more review feedback
parent 46f2b820f1
commit a8f670d247
@@ -30,28 +30,30 @@ import org.apache.spark.sql.Dataset
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}

/**
 * Rabit tracker configurations.
 *
 * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
 *                                Set timeout length to zero to disable timeout.
 *                                Use a finite, non-zero timeout value to prevent tracker from
 *                                hanging indefinitely (in milliseconds)
 *                                (supported by "scala" implementation only.)
 * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
 *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
 *                    in Scala without Python components, and with full support of timeouts.
 *                    The Scala implementation is currently experimental, use at your own risk.
 */
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)

object TrackerConf {
  def apply(): TrackerConf = TrackerConf(0L, "python")
}

/**
 * Rabit tracker configurations.
 * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
 *                                Set timeout length to zero to disable timeout.
 *                                Use a finite, non-zero timeout value to prevent tracker from
 *                                hanging indefinitely (in milliseconds)
 *                                (supported by "scala" implementation only.)
 * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
 *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
 *                    in Scala without Python components, and with full support of timeouts.
 *                    The Scala implementation is currently experimental, use at your own risk.
 */
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)

object XGBoost extends Serializable {
  private val logger = LogFactory.getLog("XGBoostSpark")

  private def fromDenseToSparseLabeledPoints(
  private def removeMissingValues(
      denseLabeledPoints: Iterator[XGBLabeledPoint],
      missing: Float): Iterator[XGBLabeledPoint] = {
    if (!missing.isNaN) {
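TrackerConf, documented in the hunk above, is the user-facing handle for the Rabit tracker: per its companion object, the default is the Python tracker with the worker-connection timeout disabled. A minimal sketch of constructing both variants; the import path is an assumption based on this file's usual package, ml.dmlc.xgboost4j.scala.spark:

    import ml.dmlc.xgboost4j.scala.spark.TrackerConf

    // Default configuration: TrackerConf() == TrackerConf(0L, "python"),
    // i.e. the Python tracker with the worker-connection timeout disabled.
    val defaultConf = TrackerConf()

    // Experimental Scala tracker; workers must connect within 60 000 ms.
    // Timeouts are honored by the "scala" implementation only.
    val scalaConf = TrackerConf(workerConnectionTimeout = 60000L, trackerImpl = "scala")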
@@ -89,7 +91,7 @@ object XGBoost extends Serializable {
    } else {
      throw new IllegalArgumentException(
        s"Encountered a partition with $nUndefined NaN base margin values. " +
          "If you want to specify base margin, ensure all values are non-NaN.")
        s"If you want to specify base margin, ensure all values are non-NaN.")
    }
  }

@@ -118,23 +120,23 @@ object XGBoost extends Serializable {
      if (labeledPoints.isEmpty) {
        throw new XGBoostError(
          s"detected an empty partition in the training data, partition ID:" +
              s" ${TaskContext.getPartitionId()}")
            s" ${TaskContext.getPartitionId()}")
      }
      val cacheFileName = if (useExternalMemory) {
        s"$appName-${TaskContext.get().stageId()}-" +
            s"dtrain_cache-${TaskContext.getPartitionId()}"
          s"dtrain_cache-${TaskContext.getPartitionId()}"
      } else {
        null
      }
      rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
      Rabit.init(rabitEnv)
      val watches = Watches(params,
        fromDenseToSparseLabeledPoints(labeledPoints, missing),
        removeMissingValues(labeledPoints, missing),
        fromBaseMarginsToArray(baseMargins), cacheFileName)

      try {
        val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
            .map(_.toString.toInt).getOrElse(0)
          .map(_.toString.toInt).getOrElse(0)
        val booster = SXGBoost.train(watches.train, params, round,
          watches = watches.toMap, obj = obj, eval = eval,
          earlyStoppingRound = numEarlyStoppingRounds)
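The try block above also shows how early stopping is switched on: a numEarlyStoppingRounds entry in the params map is parsed via toString.toInt, defaults to 0 (disabled) when absent, and is handed to SXGBoost.train as earlyStoppingRound. The lookup in isolation, with an illustrative params map:

    // Any value whose toString parses as an Int works here, e.g. 10 or "10".
    val params: Map[String, Any] = Map("eta" -> 0.1, "numEarlyStoppingRounds" -> 10)

    val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
      .map(_.toString.toInt).getOrElse(0)
    // numEarlyStoppingRounds == 10; without the entry it would be 0.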
@@ -147,17 +149,18 @@ object XGBoost extends Serializable {
  }

  /**
   * train XGBoost model with the DataFrame-represented data
   * @param trainingData the trainingset represented as DataFrame
   * Train XGBoost model with the DataFrame-represented data
   *
   * @param trainingData the training set represented as DataFrame
   * @param params Map containing the parameters to configure XGBoost
   * @param round the number of iterations
   * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
   *                 workers equals to the partition number of trainingData RDD
   * @param obj the user-defined objective function, null by default
   * @param eval the user-defined evaluation function, null by default
   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
   * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
   *                          true, the user may save the RAM cost for running XGBoost within Spark
   * @param missing the value represented the missing value in the dataset
   * @param missing The value which represents a missing value in the dataset
   * @param featureCol the name of input column, "features" as default value
   * @param labelCol the name of output column, "label" as default value
   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
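Taken together, the parameters documented above translate into a call along the following lines. This is a hedged sketch only: the method signature sits outside this hunk, so the name trainWithDataFrame and the column contents are assumptions based on the public xgboost4j-spark API of this period:

    // trainingDF: a DataFrame with a vector column "features" and a numeric column "label"
    val params = Map(
      "eta" -> 0.1,
      "max_depth" -> 6,
      "objective" -> "binary:logistic")

    val model = XGBoost.trainWithDataFrame(
      trainingDF, params,
      round = 100,   // number of boosting iterations
      nWorkers = 4,  // 0 would follow the partition count of the input
      useExternalMemory = false,
      missing = Float.NaN)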
@@ -200,14 +203,15 @@ object XGBoost extends Serializable {
  }

  /**
   * train XGBoost model with the RDD-represented data
   * @param trainingData the trainingset represented as RDD
   * Train XGBoost model with the RDD-represented data
   *
   * @param trainingData the training set represented as RDD
   * @param params Map containing the configuration entries
   * @param round the number of iterations
   * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
   *                 workers equals to the partition number of trainingData RDD
   * @param obj the user-defined objective function, null by default
   * @param eval the user-defined evaluation function, null by default
   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
   * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
   *                          true, the user may save the RAM cost for running XGBoost within Spark
   * @param missing the value represented the missing value in the dataset
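The RDD entry point takes the same configuration. trainWithRDD appears by name in the next hunk, and the MLLabeledPoint import at the top of the file suggests Spark ML's LabeledPoint as the element type; the data below is purely illustrative, with sc a SparkContext and params a map like the one sketched earlier:

    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors

    // Two toy points; a real training set would come from a data source.
    val trainingRDD = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(1.0, 2.0)),
      LabeledPoint(0.0, Vectors.dense(0.5, 3.0))))

    val model = XGBoost.trainWithRDD(trainingRDD, params, round = 100, nWorkers = 2)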
@@ -224,8 +228,7 @@ object XGBoost extends Serializable {
      eval: EvalTrait = null,
      useExternalMemory: Boolean = false,
      missing: Float = Float.NaN): XGBoostModel = {
    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory,
      missing)
    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
  }

  private def overrideParamsAccordingToTaskCPUs(

@@ -256,18 +259,19 @@ object XGBoost extends Serializable {
  }

  /**
   * various of train()
   * @param trainingData the trainingset represented as RDD
   * Train XGBoost model with the RDD-represented data
   *
   * @param trainingData the training set represented as RDD
   * @param params Map containing the configuration entries
   * @param round the number of iterations
   * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
   *                 workers equals to the partition number of trainingData RDD
   * @param obj the user-defined objective function, null by default
   * @param eval the user-defined evaluation function, null by default
   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
   * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
   *                          true, the user may save the RAM cost for running XGBoost within Spark
   * @param missing the value represented the missing value in the dataset
   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
   * @param missing The value which represents a missing value in the dataset
   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed
   * @return XGBoostModel when successful training
   */
  @throws(classOf[XGBoostError])

@@ -300,19 +304,19 @@ object XGBoost extends Serializable {
      missing: Float = Float.NaN): XGBoostModel = {
    if (params.contains("tree_method")) {
      require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
          " for now")
        " for now")
    }
    require(nWorkers > 0, "you must specify more than 0 workers")
    if (obj != null) {
      require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
          " you have to specify the objective type as classification or regression with a" +
          " customized objective function")
        " you have to specify the objective type as classification or regression with a" +
        " customized objective function")
    }
    val trackerConf = params.get("tracker_conf") match {
      case None => TrackerConf()
      case Some(conf: TrackerConf) => conf
      case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
          "instance of TrackerConf.")
        "instance of TrackerConf.")
    }
    val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match {
      case None => 0L
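The require checks in this hunk encode the usage rules for this entry point: tree_method must not be "hist", nWorkers must be positive, a non-null custom objective demands an obj_type entry, and tracker_conf, when present, must be an actual TrackerConf instance rather than a string. A sketch of a params map that satisfies the custom-objective path; the objective instance itself is a placeholder:

    // customObj: ObjectiveTrait = ... (user-supplied implementation, not shown)
    val params: Map[String, Any] = Map(
      "eta" -> 0.1,
      // Required whenever a non-null custom objective is passed:
      "obj_type" -> "classification",
      // Optional; must be a TrackerConf instance, not a string:
      "tracker_conf" -> TrackerConf(60000L, "scala"))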
@@ -339,8 +343,7 @@ object XGBoost extends Serializable {
    val isClsTask = isClassificationTask(params)
    val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
    logger.info(s"Rabit returns with exit code $trackerReturnVal")
    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
      sparkJobThread, isClsTask)
    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, sparkJobThread, isClsTask)
    if (isClsTask){
      model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
        params.getOrElse("num_class", "2").toString.toInt

@@ -352,10 +355,13 @@ object XGBoost extends Serializable {
  }

  private def postTrackerReturnProcessing(
      trackerReturnVal: Int, distributedBoosters: RDD[Booster],
      params: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
    XGBoostModel = {
      trackerReturnVal: Int,
      distributedBoosters: RDD[Booster],
      sparkJobThread: Thread,
      isClassificationTask: Boolean): XGBoostModel = {
    if (trackerReturnVal == 0) {
      // Copies of the finished model reside in each partition of the `distributedBoosters`.
      // Any of them can be used to create the model. Here, just choose the first partition.
      val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
      distributedBoosters.unpersist(false)
      xgboostModel

@@ -365,7 +371,7 @@ object XGBoost extends Serializable {
          sparkJobThread.interrupt()
        }
      } catch {
        case ie: InterruptedException =>
        case _: InterruptedException =>
          logger.info("spark job thread is interrupted")
      }
      throw new XGBoostError("XGBoostModel training failed")

@@ -380,8 +386,10 @@ object XGBoost extends Serializable {
  }

  private def setGeneralModelParams(
      featureCol: String, labelCol: String, predCol: String, xgBoostModel: XGBoostModel):
    XGBoostModel = {
      featureCol: String,
      labelCol: String,
      predCol: String,
      xgBoostModel: XGBoostModel): XGBoostModel = {
    xgBoostModel.setFeaturesCol(featureCol)
    xgBoostModel.setLabelCol(labelCol)
    xgBoostModel.setPredictionCol(predCol)
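setGeneralModelParams is plain Spark ML column wiring: the three setters in its body are the same ones available on the returned XGBoostModel, so a caller can re-point a trained model at differently named columns. A brief illustration with arbitrary column names:

    // model: an XGBoostModel returned by one of the train* methods above
    model.setFeaturesCol("scaledFeatures")      // column holding the feature vectors
    model.setLabelCol("clicked")                // label column used for evaluation
    model.setPredictionCol("clickPrediction")   // column that transform() will write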
@@ -422,13 +430,17 @@ object XGBoost extends Serializable {
        case "_reg_" =>
          val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
          setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
        case other =>
          throw new XGBoostError(s"Unknown model type $other. Supported types " +
            s"are: ['_reg_', '_cls_'].")
      }
    }
  }

private class Watches private(val train: DMatrix, val test: DMatrix) {

  def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
      .filter { case (_, matrix) => matrix.rowNum > 0 }
    .filter { case (_, matrix) => matrix.rowNum > 0 }

  def size: Int = toMap.size

@@ -440,6 +452,7 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
  }

private object Watches {

  def apply(
      params: Map[String, Any],
      labeledPoints: Iterator[XGBLabeledPoint],

@@ -16,21 +16,23 @@

package ml.dmlc.xgboost4j

/** Labeled training data point. */
/**
 * Labeled training data point.
 *
 * @param label Label of this point.
 * @param indices Feature indices of this point or `null` if the data is dense.
 * @param values Feature values of this point.
 * @param weight Weight of this point.
 * @param group Group of this point (used for ranking) or -1.
 * @param baseMargin Initial prediction on this point or `Float.NaN`
 */
case class LabeledPoint(
    /** Label of this point. */
    label: Float,
    /** Feature indices of this point or `null` if the data is dense. */
    indices: Array[Int],
    /** Feature values of this point. */
    values: Array[Float],
    /** Weight of this point. */
    weight: Float = 1.0f,
    /** Group of this point (used for ranking) or -1. */
    weight: Float = 1f,
    group: Int = -1,
    /** Initial prediction on this point or `Float.NaN`. */
    baseMargin: Float = Float.NaN
) extends Serializable {
    baseMargin: Float = Float.NaN) extends Serializable {
  require(indices == null || indices.length == values.length,
    "indices and values must have the same number of elements")
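In use, the consolidated scaladoc above corresponds to constructions like the following sketch; the named arguments track the constructor order shown in the diff, and the require at the end of the hunk is what rejects mismatched sparse data:

    import ml.dmlc.xgboost4j.LabeledPoint

    // Dense point: indices == null, so values(i) is the value of feature i.
    val dense = LabeledPoint(label = 1.0f, indices = null, values = Array(0.5f, 1.2f, 3.4f))

    // Sparse point: only features 0 and 7 are set; indices and values must
    // have the same length, or the constructor's require(...) throws.
    val sparse = LabeledPoint(
      label = 0.0f,
      indices = Array(0, 7),
      values = Array(0.5f, 2.0f),
      weight = 2f,             // this point counts twice during training
      group = 3,               // ranking group id; -1 when unused
      baseMargin = Float.NaN)  // no initial prediction supplied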