From a8f670d24742002ed35f8e4927d9e7b7d3ec1d14 Mon Sep 17 00:00:00 2001
From: Seth Hendrickson
Date: Thu, 2 Nov 2017 13:16:02 -0700
Subject: [PATCH] [jvm-packages] Add some documentation to xgboost4j-spark plus
 minor style edits (#2823)

* add scala docs to several methods

* indentation

* license formatting

* clarify distributed boosters

* address some review comments

* reduce doc lengths

* change method name, clarify doc

* reset make config

* delete most comments

* more review feedback
---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  | 113 ++++++++++--------
 .../ml/dmlc/xgboost4j/LabeledPoint.scala      |  22 ++--
 2 files changed, 75 insertions(+), 60 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 3e8736370..2f218aa15 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -30,28 +30,30 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
 
+
+/**
+ * Rabit tracker configurations.
+ *
+ * @param workerConnectionTimeout The timeout, in milliseconds, for all workers to connect
+ *                                to the tracker. Set the timeout to zero to disable it.
+ *                                Use a finite, non-zero value to prevent the tracker from
+ *                                hanging indefinitely
+ *                                (supported by the "scala" implementation only).
+ * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
+ *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
+ *                    in Scala without Python components, and with full support of timeouts.
+ *                    The Scala implementation is currently experimental, use at your own risk.
+ */
+case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
+
 object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
 }
 
-/**
- * Rabit tracker configurations.
- * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
- *                                Set timeout length to zero to disable timeout.
- *                                Use a finite, non-zero timeout value to prevent tracker from
- *                                hanging indefinitely (in milliseconds)
- *                                (supported by "scala" implementation only.)
- * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
- *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
- *                    in Scala without Python components, and with full support of timeouts.
- *                    The Scala implementation is currently experimental, use at your own risk.
- */
-case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
-
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
-  private def fromDenseToSparseLabeledPoints(
+  private def removeMissingValues(
       denseLabeledPoints: Iterator[XGBLabeledPoint],
       missing: Float): Iterator[XGBLabeledPoint] = {
     if (!missing.isNaN) {
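
For orientation (not part of the patch itself): the TrackerConf documented above is consumed
through the "tracker_conf" key of the training parameters, whose validation appears later in
this diff. A minimal sketch of opting into the experimental Scala tracker; the timeout value
and the other parameter entries are illustrative only:

    import ml.dmlc.xgboost4j.scala.spark.TrackerConf

    // Experimental Scala tracker with a 60-second worker-connection timeout.
    val trackerConf = TrackerConf(workerConnectionTimeout = 60000L, trackerImpl = "scala")

    // "tracker_conf" is the params key matched further below in this diff;
    // the other entries are ordinary XGBoost parameters.
    val params: Map[String, Any] = Map(
      "eta" -> 0.1,
      "objective" -> "binary:logistic",
      "tracker_conf" -> trackerConf)
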
" + - "If you want to specify base margin, ensure all values are non-NaN.") + s"If you want to specify base margin, ensure all values are non-NaN.") } } @@ -118,23 +120,23 @@ object XGBoost extends Serializable { if (labeledPoints.isEmpty) { throw new XGBoostError( s"detected an empty partition in the training data, partition ID:" + - s" ${TaskContext.getPartitionId()}") + s" ${TaskContext.getPartitionId()}") } val cacheFileName = if (useExternalMemory) { s"$appName-${TaskContext.get().stageId()}-" + - s"dtrain_cache-${TaskContext.getPartitionId()}" + s"dtrain_cache-${TaskContext.getPartitionId()}" } else { null } rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) Rabit.init(rabitEnv) val watches = Watches(params, - fromDenseToSparseLabeledPoints(labeledPoints, missing), + removeMissingValues(labeledPoints, missing), fromBaseMarginsToArray(baseMargins), cacheFileName) try { val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds") - .map(_.toString.toInt).getOrElse(0) + .map(_.toString.toInt).getOrElse(0) val booster = SXGBoost.train(watches.train, params, round, watches = watches.toMap, obj = obj, eval = eval, earlyStoppingRound = numEarlyStoppingRounds) @@ -147,17 +149,18 @@ object XGBoost extends Serializable { } /** - * train XGBoost model with the DataFrame-represented data - * @param trainingData the trainingset represented as DataFrame + * Train XGBoost model with the DataFrame-represented data + * + * @param trainingData the training set represented as DataFrame * @param params Map containing the parameters to configure XGBoost * @param round the number of iterations * @param nWorkers the number of xgboost workers, 0 by default which means that the number of * workers equals to the partition number of trainingData RDD - * @param obj the user-defined objective function, null by default - * @param eval the user-defined evaluation function, null by default + * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default + * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as * true, the user may save the RAM cost for running XGBoost within Spark - * @param missing the value represented the missing value in the dataset + * @param missing The value which represents a missing value in the dataset * @param featureCol the name of input column, "features" as default value * @param labelCol the name of output column, "label" as default value * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed @@ -200,14 +203,15 @@ object XGBoost extends Serializable { } /** - * train XGBoost model with the RDD-represented data - * @param trainingData the trainingset represented as RDD + * Train XGBoost model with the RDD-represented data + * + * @param trainingData the training set represented as RDD * @param params Map containing the configuration entries * @param round the number of iterations * @param nWorkers the number of xgboost workers, 0 by default which means that the number of * workers equals to the partition number of trainingData RDD - * @param obj the user-defined objective function, null by default - * @param eval the user-defined evaluation function, null by default + * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default + * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default * @param 
@@ -200,14 +203,15 @@ object XGBoost extends Serializable {
   }
 
   /**
-   * train XGBoost model with the RDD-represented data
-   * @param trainingData the trainingset represented as RDD
+   * Train XGBoost model with the RDD-represented data
+   *
+   * @param trainingData the training set represented as RDD
    * @param params Map containing the configuration entries
    * @param round the number of iterations
    * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
    *                 workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
    * @param missing the value represented the missing value in the dataset
@@ -224,8 +228,7 @@ object XGBoost extends Serializable {
       eval: EvalTrait = null,
       useExternalMemory: Boolean = false,
       missing: Float = Float.NaN): XGBoostModel = {
-    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory,
-      missing)
+    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
   }
 
   private def overrideParamsAccordingToTaskCPUs(
@@ -256,18 +259,19 @@ object XGBoost extends Serializable {
   }
 
   /**
-   * various of train()
-   * @param trainingData the trainingset represented as RDD
+   * Train XGBoost model with the RDD-represented data
+   *
+   * @param trainingData the training set represented as RDD
    * @param params Map containing the configuration entries
    * @param round the number of iterations
    * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
    *                 workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
-   * @param missing the value represented the missing value in the dataset
-   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
+   * @param missing The value which represents a missing value in the dataset
+   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
@@ -300,19 +304,19 @@ object XGBoost extends Serializable {
       missing: Float = Float.NaN): XGBoostModel = {
     if (params.contains("tree_method")) {
       require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
-          " for now")
+        " for now")
     }
     require(nWorkers > 0, "you must specify more than 0 workers")
     if (obj != null) {
       require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
-          " you have to specify the objective type as classification or regression with a" +
-          " customized objective function")
+        " you have to specify the objective type as classification or regression with a" +
+        " customized objective function")
     }
     val trackerConf = params.get("tracker_conf") match {
       case None => TrackerConf()
       case Some(conf: TrackerConf) => conf
       case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
-          "instance of TrackerConf.")
+        "instance of TrackerConf.")
     }
     val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match {
       case None => 0L
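
The "obj_type" require above pairs with the obj/eval parameters documented earlier: whenever a
custom ObjectiveTrait is passed, the params map must state whether the task is classification
or regression. A hedged sketch (not part of the patch); customObj and customEval stand in for
user implementations, and the import paths are assumptions based on the xgboost4j package
names used in this file:

    import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
    import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}
    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.rdd.RDD

    def trainWithCustomObjective(
        trainingData: RDD[LabeledPoint],
        customObj: ObjectiveTrait,
        customEval: EvalTrait): XGBoostModel = {
      // "obj_type" is mandatory whenever obj is non-null, per the require above.
      val params: Map[String, Any] = Map(
        "obj_type" -> "classification",
        "eta" -> 0.1)
      XGBoost.trainWithRDD(trainingData, params, round = 50, nWorkers = 4,
        obj = customObj, eval = customEval)
    }
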
@@ -339,8 +343,7 @@ object XGBoost extends Serializable {
     val isClsTask = isClassificationTask(params)
     val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
     logger.info(s"Rabit returns with exit code $trackerReturnVal")
-    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
-      sparkJobThread, isClsTask)
+    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, sparkJobThread, isClsTask)
     if (isClsTask){
       model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
         params.getOrElse("num_class", "2").toString.toInt
@@ -352,10 +355,13 @@ object XGBoost extends Serializable {
   }
 
   private def postTrackerReturnProcessing(
-      trackerReturnVal: Int, distributedBoosters: RDD[Booster],
-      params: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
-      XGBoostModel = {
+      trackerReturnVal: Int,
+      distributedBoosters: RDD[Booster],
+      sparkJobThread: Thread,
+      isClassificationTask: Boolean): XGBoostModel = {
     if (trackerReturnVal == 0) {
+      // Copies of the finished model reside in each partition of the `distributedBoosters`.
+      // Any of them can be used to create the model. Here, just choose the first partition.
       val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
       distributedBoosters.unpersist(false)
       xgboostModel
@@ -365,7 +371,7 @@ object XGBoost extends Serializable {
           sparkJobThread.interrupt()
         }
       } catch {
-        case ie: InterruptedException =>
+        case _: InterruptedException =>
           logger.info("spark job thread is interrupted")
       }
       throw new XGBoostError("XGBoostModel training failed")
@@ -380,8 +386,10 @@ object XGBoost extends Serializable {
   }
 
   private def setGeneralModelParams(
-      featureCol: String, labelCol: String, predCol: String, xgBoostModel: XGBoostModel):
-      XGBoostModel = {
+      featureCol: String,
+      labelCol: String,
+      predCol: String,
+      xgBoostModel: XGBoostModel): XGBoostModel = {
     xgBoostModel.setFeaturesCol(featureCol)
     xgBoostModel.setLabelCol(labelCol)
     xgBoostModel.setPredictionCol(predCol)
@@ -422,13 +430,17 @@ object XGBoost extends Serializable {
       case "_reg_" =>
         val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
         setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
+      case other =>
+        throw new XGBoostError(s"Unknown model type $other. Supported types " +
+          s"are: ['_reg_', '_cls_'].")
     }
   }
 }
 
 private class Watches private(val train: DMatrix, val test: DMatrix) {
+
   def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
-      .filter { case (_, matrix) => matrix.rowNum > 0 }
+    .filter { case (_, matrix) => matrix.rowNum > 0 }
 
   def size: Int = toMap.size
 
@@ -440,6 +452,7 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
 }
 
 private object Watches {
+
   def apply(
       params: Map[String, Any],
       labeledPoints: Iterator[XGBLabeledPoint],
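
That closes the XGBoost.scala changes. One of them, the rename of fromDenseToSparseLabeledPoints
to removeMissingValues, leaves the method body outside the hunk context; purely for intuition,
here is one plausible shape of such filtering over the LabeledPoint type from the next file
(a sketch, not the patch's actual implementation):

    import ml.dmlc.xgboost4j.LabeledPoint

    // Drop every (index, value) pair whose value equals the user-supplied
    // `missing` sentinel; NaN needs no filtering because it is XGBoost's
    // native missing-value marker (hence the `if (!missing.isNaN)` guard above).
    def dropMissing(point: LabeledPoint, missing: Float): LabeledPoint = {
      val kept = point.values.indices.filter(i => point.values(i) != missing)
      point.copy(
        indices = kept.map(i => if (point.indices == null) i else point.indices(i)).toArray,
        values = kept.map(i => point.values(i)).toArray)
    }
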
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/LabeledPoint.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/LabeledPoint.scala
index 48c9c8367..9a92d1b91 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/LabeledPoint.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/LabeledPoint.scala
@@ -16,21 +16,23 @@
  */
 
 package ml.dmlc.xgboost4j
 
-/** Labeled training data point. */
+/**
+ * Labeled training data point.
+ *
+ * @param label Label of this point.
+ * @param indices Feature indices of this point or `null` if the data is dense.
+ * @param values Feature values of this point.
+ * @param weight Weight of this point.
+ * @param group Group of this point (used for ranking) or -1.
+ * @param baseMargin Initial prediction on this point or `Float.NaN`
+ */
 case class LabeledPoint(
-    /** Label of this point. */
     label: Float,
-    /** Feature indices of this point or `null` if the data is dense. */
     indices: Array[Int],
-    /** Feature values of this point. */
     values: Array[Float],
-    /** Weight of this point. */
-    weight: Float = 1.0f,
-    /** Group of this point (used for ranking) or -1. */
+    weight: Float = 1f,
     group: Int = -1,
-    /** Initial prediction on this point or `Float.NaN`. */
-    baseMargin: Float = Float.NaN
-) extends Serializable {
+    baseMargin: Float = Float.NaN) extends Serializable {
   require(indices == null || indices.length == values.length,
     "indices and values must have the same number of elements")
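
To close, a short sketch (not part of the patch) constructing the two layouts this case class
supports, using the defaults documented above; the values are arbitrary:

    import ml.dmlc.xgboost4j.LabeledPoint

    // Dense point: indices == null, so values(i) belongs to feature i.
    val dense = LabeledPoint(label = 1f, indices = null, values = Array(0.5f, 0f, 1.2f))

    // Sparse point: only features 0 and 2 are present. weight, group and
    // baseMargin keep their defaults (1f, -1, Float.NaN).
    val sparse = LabeledPoint(0f, Array(0, 2), Array(0.5f, 1.2f))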