[jvm-packages] Add some documentation to xgboost4j-spark plus minor style edits (#2823)
* add scala docs to several methods
* indentation
* license formatting
* clarify distributed boosters
* address some review comments
* reduce doc lengths
* change method name, clarify doc
* reset make config
* delete most comments
* more review feedback
parent 46f2b820f1
commit a8f670d247
@@ -30,28 +30,30 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
 
+/**
+ * Rabit tracker configurations.
+ *
+ * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
+ *                                Set timeout length to zero to disable timeout.
+ *                                Use a finite, non-zero timeout value to prevent tracker from
+ *                                hanging indefinitely (in milliseconds)
+ *                                (supported by "scala" implementation only.)
+ * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
+ *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
+ *                    in Scala without Python components, and with full support of timeouts.
+ *                    The Scala implementation is currently experimental, use at your own risk.
+ */
+case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
+
 object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
 }
 
-/**
- * Rabit tracker configurations.
- * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
- *                                Set timeout length to zero to disable timeout.
- *                                Use a finite, non-zero timeout value to prevent tracker from
- *                                hanging indefinitely (in milliseconds)
- *                                (supported by "scala" implementation only.)
- * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
- *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
- *                    in Scala without Python components, and with full support of timeouts.
- *                    The Scala implementation is currently experimental, use at your own risk.
- */
-case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
-
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
-  private def fromDenseToSparseLabeledPoints(
+  private def removeMissingValues(
       denseLabeledPoints: Iterator[XGBLabeledPoint],
       missing: Float): Iterator[XGBLabeledPoint] = {
     if (!missing.isNaN) {
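A note on how the relocated `TrackerConf` is consumed: as the `tracker_conf` pattern match later in this diff shows, callers opt into the experimental Scala tracker through the params map. A minimal sketch, not part of the commit; the import path assumes this file's `ml.dmlc.xgboost4j.scala.spark` package:

    import ml.dmlc.xgboost4j.scala.spark.TrackerConf

    // Wait up to 60 seconds for all workers to reach the tracker; per the
    // scaladoc above, only the "scala" implementation honors this timeout.
    val params = Map(
      "objective" -> "binary:logistic",
      "tracker_conf" -> TrackerConf(workerConnectionTimeout = 60000L, trackerImpl = "scala"))

    // TrackerConf() keeps the defaults: no timeout, "python" implementation.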
@@ -89,7 +91,7 @@ object XGBoost extends Serializable {
       } else {
         throw new IllegalArgumentException(
           s"Encountered a partition with $nUndefined NaN base margin values. " +
-            "If you want to specify base margin, ensure all values are non-NaN.")
+            s"If you want to specify base margin, ensure all values are non-NaN.")
       }
     }
 
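For context, the check above fires per partition: base margins must be either absent (all `Float.NaN`) or fully specified. A sketch of a well-formed point, using the `ml.dmlc.xgboost4j.LabeledPoint` case class that appears at the end of this diff:

    import ml.dmlc.xgboost4j.LabeledPoint

    // A dense point (indices == null) with an explicit base margin; every other
    // point in the same partition must carry a non-NaN margin too, or the
    // IllegalArgumentException above is thrown.
    val point = LabeledPoint(1.0f, null, Array(0.3f, 0.7f), baseMargin = 0.25f)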
@@ -118,23 +120,23 @@ object XGBoost extends Serializable {
     if (labeledPoints.isEmpty) {
       throw new XGBoostError(
         s"detected an empty partition in the training data, partition ID:" +
           s" ${TaskContext.getPartitionId()}")
     }
     val cacheFileName = if (useExternalMemory) {
       s"$appName-${TaskContext.get().stageId()}-" +
         s"dtrain_cache-${TaskContext.getPartitionId()}"
     } else {
       null
     }
     rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
     Rabit.init(rabitEnv)
     val watches = Watches(params,
-      fromDenseToSparseLabeledPoints(labeledPoints, missing),
+      removeMissingValues(labeledPoints, missing),
       fromBaseMarginsToArray(baseMargins), cacheFileName)
 
     try {
       val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
         .map(_.toString.toInt).getOrElse(0)
       val booster = SXGBoost.train(watches.train, params, round,
         watches = watches.toMap, obj = obj, eval = eval,
         earlyStoppingRound = numEarlyStoppingRounds)
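The body of the renamed `removeMissingValues` is mostly outside this hunk; judging from its signature and the documented meaning of `missing`, its job is to drop feature entries equal to the user-supplied missing marker before the `DMatrix` is built. An illustrative sketch under that assumption (the name and details below are guesses, not the commit's exact code):

    import ml.dmlc.xgboost4j.LabeledPoint

    def removeMissingValuesSketch(
        points: Iterator[LabeledPoint],
        missing: Float): Iterator[LabeledPoint] = {
      if (missing.isNaN) {
        points // NaN is already XGBoost's missing-value marker; nothing to strip
      } else {
        points.map { p =>
          // Keep only the entries whose value is not the missing marker,
          // producing a sparse point with explicit column indices.
          val kept = (0 until p.values.length).filter(i => p.values(i) != missing)
          p.copy(
            indices = kept.map(i => if (p.indices == null) i else p.indices(i)).toArray,
            values = kept.map(p.values(_)).toArray)
        }
      }
    }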
@@ -147,17 +149,18 @@ object XGBoost extends Serializable {
   }
 
   /**
-   * train XGBoost model with the DataFrame-represented data
-   * @param trainingData the trainingset represented as DataFrame
+   * Train XGBoost model with the DataFrame-represented data
+   *
+   * @param trainingData the training set represented as DataFrame
    * @param params Map containing the parameters to configure XGBoost
    * @param round the number of iterations
    * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
    *                 workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
-   * @param missing the value represented the missing value in the dataset
+   * @param missing The value which represents a missing value in the dataset
    * @param featureCol the name of input column, "features" as default value
    * @param labelCol the name of output column, "label" as default value
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
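As the scaladoc above indicates, only `trainingData`, `params`, `round`, and `nWorkers` are mandatory. A minimal usage sketch, assuming a DataFrame `trainingDF` that already carries the default "features" and "label" columns:

    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    val params = Map("objective" -> "binary:logistic", "eta" -> 0.1, "max_depth" -> 6)
    val model = XGBoost.trainWithDataFrame(trainingDF, params, round = 100, nWorkers = 4)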
@@ -200,14 +203,15 @@ object XGBoost extends Serializable {
   }
 
   /**
-   * train XGBoost model with the RDD-represented data
-   * @param trainingData the trainingset represented as RDD
+   * Train XGBoost model with the RDD-represented data
+   *
+   * @param trainingData the training set represented as RDD
    * @param params Map containing the configuration entries
    * @param round the number of iterations
    * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
    *                 workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
    * @param missing the value represented the missing value in the dataset
@@ -224,8 +228,7 @@ object XGBoost extends Serializable {
       eval: EvalTrait = null,
       useExternalMemory: Boolean = false,
       missing: Float = Float.NaN): XGBoostModel = {
-    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory,
-      missing)
+    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
   }
 
   private def overrideParamsAccordingToTaskCPUs(
@@ -256,18 +259,19 @@ object XGBoost extends Serializable {
   }
 
   /**
-   * various of train()
-   * @param trainingData the trainingset represented as RDD
+   * Train XGBoost model with the RDD-represented data
+   *
+   * @param trainingData the training set represented as RDD
    * @param params Map containing the configuration entries
    * @param round the number of iterations
    * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
    *                 workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
-   * @param missing the value represented the missing value in the dataset
-   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
+   * @param missing The value which represents a missing value in the dataset
+   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
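A companion sketch for the RDD entry point documented above; `spark` is assumed to be an active `SparkSession`, and the points use the `org.apache.spark.ml.feature.LabeledPoint` alias imported at the top of this file:

    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    // A toy two-point training set; real data would come from a source RDD.
    val trainingRDD = spark.sparkContext.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(1.0, 0.5)),
      LabeledPoint(0.0, Vectors.dense(0.2, 0.9))))
    val model = XGBoost.trainWithRDD(trainingRDD,
      Map("objective" -> "binary:logistic"), round = 10, nWorkers = 2)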
@@ -300,19 +304,19 @@ object XGBoost extends Serializable {
       missing: Float = Float.NaN): XGBoostModel = {
     if (params.contains("tree_method")) {
       require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
         " for now")
     }
     require(nWorkers > 0, "you must specify more than 0 workers")
     if (obj != null) {
       require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
         " you have to specify the objective type as classification or regression with a" +
         " customized objective function")
     }
     val trackerConf = params.get("tracker_conf") match {
       case None => TrackerConf()
       case Some(conf: TrackerConf) => conf
       case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
         "instance of TrackerConf.")
     }
     val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match {
       case None => 0L
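The `obj_type` requirement above only bites when a custom objective is supplied: a user-defined `ObjectiveTrait` hides whether the task is classification or regression, so the caller must say so explicitly. A sketch, where `MyCustomObj` is a hypothetical `ObjectiveTrait` implementation:

    // Without "obj_type", trainWithRDD rejects the custom objective outright.
    val params = Map(
      "obj_type" -> "classification",
      "num_class" -> 2)
    // val model = XGBoost.trainWithRDD(trainingRDD, params, round = 10,
    //   nWorkers = 2, obj = new MyCustomObj)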
@@ -339,8 +343,7 @@ object XGBoost extends Serializable {
     val isClsTask = isClassificationTask(params)
     val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
     logger.info(s"Rabit returns with exit code $trackerReturnVal")
-    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
-      sparkJobThread, isClsTask)
+    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, sparkJobThread, isClsTask)
     if (isClsTask){
       model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
         params.getOrElse("num_class", "2").toString.toInt
@@ -352,10 +355,13 @@ object XGBoost extends Serializable {
   }
 
   private def postTrackerReturnProcessing(
-      trackerReturnVal: Int, distributedBoosters: RDD[Booster],
-      params: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
-      XGBoostModel = {
+      trackerReturnVal: Int,
+      distributedBoosters: RDD[Booster],
+      sparkJobThread: Thread,
+      isClassificationTask: Boolean): XGBoostModel = {
     if (trackerReturnVal == 0) {
+      // Copies of the finished model reside in each partition of the `distributedBoosters`.
+      // Any of them can be used to create the model. Here, just choose the first partition.
       val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
       distributedBoosters.unpersist(false)
       xgboostModel
@@ -365,7 +371,7 @@ object XGBoost extends Serializable {
           sparkJobThread.interrupt()
         }
       } catch {
-        case ie: InterruptedException =>
+        case _: InterruptedException =>
           logger.info("spark job thread is interrupted")
       }
       throw new XGBoostError("XGBoostModel training failed")
@@ -380,8 +386,10 @@ object XGBoost extends Serializable {
   }
 
   private def setGeneralModelParams(
-      featureCol: String, labelCol: String, predCol: String, xgBoostModel: XGBoostModel):
-      XGBoostModel = {
+      featureCol: String,
+      labelCol: String,
+      predCol: String,
+      xgBoostModel: XGBoostModel): XGBoostModel = {
     xgBoostModel.setFeaturesCol(featureCol)
     xgBoostModel.setLabelCol(labelCol)
     xgBoostModel.setPredictionCol(predCol)
@@ -422,13 +430,17 @@ object XGBoost extends Serializable {
       case "_reg_" =>
         val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
         setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
+      case other =>
+        throw new XGBoostError(s"Unknown model type $other. Supported types " +
+          s"are: ['_reg_', '_cls_'].")
     }
   }
 }
 
 private class Watches private(val train: DMatrix, val test: DMatrix) {
 
   def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
     .filter { case (_, matrix) => matrix.rowNum > 0 }
 
   def size: Int = toMap.size
 
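The new `case other` arm replaces a bare `scala.MatchError` with a descriptive `XGBoostError` when a persisted model carries an unrecognized type tag. The same fail-fast pattern in a standalone form:

    // Exhaustive match with an informative failure instead of a MatchError.
    def modelKind(tag: String): String = tag match {
      case "_cls_" => "classification"
      case "_reg_" => "regression"
      case other => throw new IllegalArgumentException(
        s"Unknown model type $other. Supported types are: ['_reg_', '_cls_'].")
    }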
@@ -440,6 +452,7 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
 }
 
 private object Watches {
+
   def apply(
       params: Map[String, Any],
       labeledPoints: Iterator[XGBLabeledPoint],
@@ -16,21 +16,23 @@
 
 package ml.dmlc.xgboost4j
 
-/** Labeled training data point. */
+/**
+ * Labeled training data point.
+ *
+ * @param label Label of this point.
+ * @param indices Feature indices of this point or `null` if the data is dense.
+ * @param values Feature values of this point.
+ * @param weight Weight of this point.
+ * @param group Group of this point (used for ranking) or -1.
+ * @param baseMargin Initial prediction on this point or `Float.NaN`
+ */
 case class LabeledPoint(
-    /** Label of this point. */
     label: Float,
-    /** Feature indices of this point or `null` if the data is dense. */
     indices: Array[Int],
-    /** Feature values of this point. */
     values: Array[Float],
-    /** Weight of this point. */
-    weight: Float = 1.0f,
-    /** Group of this point (used for ranking) or -1. */
+    weight: Float = 1f,
     group: Int = -1,
-    /** Initial prediction on this point or `Float.NaN`. */
-    baseMargin: Float = Float.NaN
-) extends Serializable {
+    baseMargin: Float = Float.NaN) extends Serializable {
   require(indices == null || indices.length == values.length,
     "indices and values must have the same number of elements")
 
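To round out the consolidated scaladoc, a sketch of the two ways a point can be built: dense (indices left `null`) and sparse (explicit feature indices):

    import ml.dmlc.xgboost4j.LabeledPoint

    // Dense: values are positional, one per feature column.
    val dense = LabeledPoint(1.0f, null, Array(0.5f, 1.2f, 0.0f))
    // Sparse: only features 0 and 2 are present; index and value arrays must
    // have matching lengths or the constructor's require throws.
    val sparse = LabeledPoint(0.0f, Array(0, 2), Array(0.5f, 3.3f))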