[jvm-packages] Add some documentation to xgboost4j-spark plus minor style edits (#2823)
* add scala docs to several methods
* indentation
* license formatting
* clarify distributed boosters
* address some review comments
* reduce doc lengths
* change method name, clarify doc
* reset make config
* delete most comments
* more review feedback
commit a8f670d247 (parent 46f2b820f1)
@@ -30,12 +30,10 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
 
-object TrackerConf {
-  def apply(): TrackerConf = TrackerConf(0L, "python")
-}
 
 /**
  * Rabit tracker configurations.
+ *
  * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
  *                                Set timeout length to zero to disable timeout.
  *                                Use a finite, non-zero timeout value to prevent tracker from
@@ -48,10 +46,14 @@ object TrackerConf {
  */
 case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
 
+object TrackerConf {
+  def apply(): TrackerConf = TrackerConf(0L, "python")
+}
+
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
-  private def fromDenseToSparseLabeledPoints(
+  private def removeMissingValues(
       denseLabeledPoints: Iterator[XGBLabeledPoint],
       missing: Float): Iterator[XGBLabeledPoint] = {
     if (!missing.isNaN) {
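Reviewer note: the relocated companion object keeps the defaults visible next to the case class it constructs — no connection timeout and the Python Rabit tracker. A minimal usage sketch; only `TrackerConf` itself comes from this diff, the value names are illustrative:

    import ml.dmlc.xgboost4j.scala.spark.TrackerConf

    // Default configuration: 0L disables the worker-connection timeout and
    // "python" selects the Python implementation of the Rabit tracker.
    val defaultConf = TrackerConf()

    // Fail fast if the workers cannot all connect within one minute.
    val strictConf = TrackerConf(workerConnectionTimeout = 60 * 1000L, trackerImpl = "python")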
@@ -89,7 +91,7 @@ object XGBoost extends Serializable {
     } else {
       throw new IllegalArgumentException(
         s"Encountered a partition with $nUndefined NaN base margin values. " +
-          "If you want to specify base margin, ensure all values are non-NaN.")
+          s"If you want to specify base margin, ensure all values are non-NaN.")
     }
   }
 
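The error message above belongs to a per-partition consistency check on base margins: either every point in a partition carries a base margin, or none does. A hypothetical sketch of that invariant — `nUndefined` is the only name the diff confirms, the rest is assumed for illustration:

    // Hypothetical reconstruction: a partition's base margins must be
    // either all defined or all NaN; a mixed partition is rejected.
    def checkBaseMargins(baseMargins: Array[Float]): Unit = {
      val nUndefined = baseMargins.count(_.isNaN)
      if (nUndefined != 0 && nUndefined != baseMargins.length) {
        throw new IllegalArgumentException(
          s"Encountered a partition with $nUndefined NaN base margin values. " +
            s"If you want to specify base margin, ensure all values are non-NaN.")
      }
    }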
@@ -129,7 +131,7 @@ object XGBoost extends Serializable {
     rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
     Rabit.init(rabitEnv)
     val watches = Watches(params,
-      fromDenseToSparseLabeledPoints(labeledPoints, missing),
+      removeMissingValues(labeledPoints, missing),
       fromBaseMarginsToArray(baseMargins), cacheFileName)
 
     try {
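The rename from `fromDenseToSparseLabeledPoints` to `removeMissingValues` names the intent rather than the mechanism. A sketch under the assumption that the method drops feature entries equal to `missing` — only the signature and the `if (!missing.isNaN)` guard are confirmed by the diff, and the `LabeledPoint` fields are taken from the second file in this commit:

    import ml.dmlc.xgboost4j.LabeledPoint

    // Sketch: convert a dense point into a sparse one by dropping every
    // feature whose value equals `missing`. When `missing` is NaN, XGBoost
    // handles sparsity natively, matching the guard in the diff.
    def removeMissingValuesSketch(
        points: Iterator[LabeledPoint], missing: Float): Iterator[LabeledPoint] = {
      if (missing.isNaN) points
      else points.map { p =>
        val kept = p.values.zipWithIndex.filter { case (v, _) => v != missing }
        p.copy(indices = kept.map(_._2), values = kept.map(_._1))
      }
    }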
@@ -147,17 +149,18 @@ object XGBoost extends Serializable {
   }
 
   /**
-   * train XGBoost model with the DataFrame-represented data
+   * Train XGBoost model with the DataFrame-represented data
+   *
    * @param trainingData the training set represented as DataFrame
    * @param params Map containing the parameters to configure XGBoost
    * @param round the number of iterations
    * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
    *                 workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
-   * @param missing the value represented the missing value in the dataset
+   * @param missing The value which represents a missing value in the dataset
    * @param featureCol the name of input column, "features" as default value
    * @param labelCol the name of output column, "label" as default value
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
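With the doc cleaned up, the scaladoc now reads as a complete parameter reference. A hedged call sketch assembled from it — the method name `trainWithDataFrame`, the `trainingDF` input, and the parameter-map values are assumptions for illustration, not shown in this hunk:

    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    // Standard XGBoost learning parameters; values are illustrative.
    val xgbParams = Map("eta" -> 0.1, "max_depth" -> 6,
      "objective" -> "binary:logistic")

    // trainingDF is an assumed DataFrame with the default "features" and
    // "label" columns described in the scaladoc above.
    val model = XGBoost.trainWithDataFrame(trainingDF, xgbParams,
      round = 100, nWorkers = 4, useExternalMemory = false, missing = Float.NaN)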
@@ -200,14 +203,15 @@ object XGBoost extends Serializable {
   }
 
   /**
-   * train XGBoost model with the RDD-represented data
+   * Train XGBoost model with the RDD-represented data
+   *
    * @param trainingData the training set represented as RDD
    * @param params Map containing the configuration entries
    * @param round the number of iterations
    * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
    *                 workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
    * @param missing the value represented the missing value in the dataset
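Both overloads now link `obj` and `eval` to their traits instead of the vague "user-defined function" wording. For illustration, a custom metric might look like the following sketch; the exact `EvalTrait` signature is assumed from the xgboost4j Scala API of this era and is not shown in the diff:

    import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}

    // A custom mean-absolute-error metric; assumes single-output predictions.
    class MeanAbsoluteError extends EvalTrait {
      override def getMetric: String = "mae"

      override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
        val labels = dmat.getLabel
        val errors = predicts.zip(labels).map { case (p, y) => math.abs(p(0) - y) }
        errors.sum / errors.length
      }
    }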
@@ -224,8 +228,7 @@ object XGBoost extends Serializable {
       eval: EvalTrait = null,
       useExternalMemory: Boolean = false,
       missing: Float = Float.NaN): XGBoostModel = {
-    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory,
-      missing)
+    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
   }
 
   private def overrideParamsAccordingToTaskCPUs(
@@ -256,18 +259,19 @@ object XGBoost extends Serializable {
   }
 
   /**
-   * various of train()
+   * Train XGBoost model with the RDD-represented data
+   *
    * @param trainingData the training set represented as RDD
    * @param params Map containing the configuration entries
    * @param round the number of iterations
    * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
    *                 workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
-   * @param missing the value represented the missing value in the dataset
-   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
+   * @param missing The value which represents a missing value in the dataset
+   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
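A matching call sketch for the RDD entry point. The element type follows the `MLLabeledPoint` import alias at the top of this diff; `sc` (an existing SparkContext) and the toy data are assumptions:

    import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
    import org.apache.spark.ml.linalg.Vectors
    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    // A toy two-row training set; sc is an existing SparkContext.
    val trainRDD = sc.parallelize(Seq(
      MLLabeledPoint(1.0, Vectors.dense(0.2, 0.7)),
      MLLabeledPoint(0.0, Vectors.dense(0.9, 0.1))))

    val model = XGBoost.trainWithRDD(trainRDD,
      Map("objective" -> "binary:logistic"), round = 50, nWorkers = 2)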
@@ -339,8 +343,7 @@ object XGBoost extends Serializable {
     val isClsTask = isClassificationTask(params)
     val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
     logger.info(s"Rabit returns with exit code $trackerReturnVal")
-    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
-      sparkJobThread, isClsTask)
+    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, sparkJobThread, isClsTask)
     if (isClsTask){
       model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
         params.getOrElse("num_class", "2").toString.toInt
@@ -352,10 +355,13 @@ object XGBoost extends Serializable {
   }
 
   private def postTrackerReturnProcessing(
-      trackerReturnVal: Int, distributedBoosters: RDD[Booster],
-      params: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
-    XGBoostModel = {
+      trackerReturnVal: Int,
+      distributedBoosters: RDD[Booster],
+      sparkJobThread: Thread,
+      isClassificationTask: Boolean): XGBoostModel = {
     if (trackerReturnVal == 0) {
+      // Copies of the finished model reside in each partition of the `distributedBoosters`.
+      // Any of them can be used to create the model. Here, just choose the first partition.
      val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
       distributedBoosters.unpersist(false)
       xgboostModel
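The two added comments record why `first()` is safe here: Rabit's allreduce keeps every worker's booster synchronized during training, so each partition of `distributedBoosters` ends up holding an equivalent copy of the finished model. `unpersist(false)` then releases the cached RDD asynchronously rather than blocking on the removal.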
@@ -365,7 +371,7 @@ object XGBoost extends Serializable {
         sparkJobThread.interrupt()
       }
     } catch {
-      case ie: InterruptedException =>
+      case _: InterruptedException =>
         logger.info("spark job thread is interrupted")
     }
     throw new XGBoostError("XGBoostModel training failed")
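Switching `case ie:` to `case _:` is the idiomatic Scala spelling for a binding that is never used, and it avoids unused-variable warnings under stricter compiler settings.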
@@ -380,8 +386,10 @@ object XGBoost extends Serializable {
   }
 
   private def setGeneralModelParams(
-      featureCol: String, labelCol: String, predCol: String, xgBoostModel: XGBoostModel):
-    XGBoostModel = {
+      featureCol: String,
+      labelCol: String,
+      predCol: String,
+      xgBoostModel: XGBoostModel): XGBoostModel = {
     xgBoostModel.setFeaturesCol(featureCol)
     xgBoostModel.setLabelCol(labelCol)
     xgBoostModel.setPredictionCol(predCol)
@@ -422,11 +430,15 @@ object XGBoost extends Serializable {
         case "_reg_" =>
           val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
           setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
+        case other =>
+          throw new XGBoostError(s"Unknown model type $other. Supported types " +
+            s"are: ['_reg_', '_cls_'].")
       }
     }
   }
 
 private class Watches private(val train: DMatrix, val test: DMatrix) {
 
   def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
     .filter { case (_, matrix) => matrix.rowNum > 0 }
 
@@ -440,6 +452,7 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
 }
 
 private object Watches {
 
   def apply(
       params: Map[String, Any],
       labeledPoints: Iterator[XGBLabeledPoint],
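The new default arm turns a potential silent `MatchError` into a descriptive `XGBoostError`. A hypothetical reconstruction of the dispatch it completes — `modelType` and `dataInStream` come from the surrounding load method, which the capture does not show, and the `_cls_` arm is assumed symmetric to `_reg_`:

    // Hypothetical reconstruction; only the "_reg_" arm and the new
    // default arm appear in the diff itself.
    modelType match {
      case "_cls_" =>
        val xgBoostModel = new XGBoostClassificationModel(SXGBoost.loadModel(dataInStream))
        setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
      case "_reg_" =>
        val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
        setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
      case other =>
        throw new XGBoostError(s"Unknown model type $other. Supported types " +
          s"are: ['_reg_', '_cls_'].")
    }

The final hunk below comes from the commit's second file, the `LabeledPoint` case class in package `ml.dmlc.xgboost4j`.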
@@ -16,21 +16,23 @@
 
 package ml.dmlc.xgboost4j
 
-/** Labeled training data point. */
+/**
+ * Labeled training data point.
+ *
+ * @param label Label of this point.
+ * @param indices Feature indices of this point or `null` if the data is dense.
+ * @param values Feature values of this point.
+ * @param weight Weight of this point.
+ * @param group Group of this point (used for ranking) or -1.
+ * @param baseMargin Initial prediction on this point or `Float.NaN`
+ */
 case class LabeledPoint(
-    /** Label of this point. */
     label: Float,
-    /** Feature indices of this point or `null` if the data is dense. */
     indices: Array[Int],
-    /** Feature values of this point. */
     values: Array[Float],
-    /** Weight of this point. */
-    weight: Float = 1.0f,
+    weight: Float = 1f,
-    /** Group of this point (used for ranking) or -1. */
     group: Int = -1,
-    /** Initial prediction on this point or `Float.NaN`. */
-    baseMargin: Float = Float.NaN
-) extends Serializable {
+    baseMargin: Float = Float.NaN) extends Serializable {
   require(indices == null || indices.length == values.length,
     "indices and values must have the same number of elements")
 
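A construction sketch for the documented case class, showing the dense and sparse forms the `indices` doc describes; the values are illustrative:

    import ml.dmlc.xgboost4j.LabeledPoint

    // Dense point: `indices` is null and `values` covers every feature slot.
    val dense = LabeledPoint(label = 1.0f, indices = null,
      values = Array(0.5f, 0.0f, 2.5f))

    // Sparse point: only features 0 and 2 are present. The two arrays must
    // have equal length, or the `require` in the class body throws.
    val sparse = LabeledPoint(label = 0.0f, indices = Array(0, 2),
      values = Array(0.5f, 2.5f), weight = 2f)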